Blink roll 148954:148987
[chromium-blink-merge.git] / chrome_frame / http_negotiate.cc
blobd77bd923850cacb8fa652cd1b4cb068ebf77749d
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome_frame/http_negotiate.h"
7 #include <atlbase.h>
8 #include <atlcom.h>
9 #include <htiframe.h>
11 #include "base/logging.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "base/string_util.h"
14 #include "base/stringprintf.h"
15 #include "base/utf_string_conversions.h"
16 #include "chrome_frame/bho.h"
17 #include "chrome_frame/exception_barrier.h"
18 #include "chrome_frame/html_utils.h"
19 #include "chrome_frame/urlmon_moniker.h"
20 #include "chrome_frame/urlmon_url_request.h"
21 #include "chrome_frame/utils.h"
22 #include "chrome_frame/vtable_patch_manager.h"
23 #include "net/http/http_response_headers.h"
24 #include "net/http/http_util.h"
26 bool HttpNegotiatePatch::modify_user_agent_ = true;
// HTTP header names, spelled in lower case. The header helpers in this file
// look fields up by their lower-case name.
const char kUACompatibleHttpHeader[] = "x-ua-compatible";
const char kLowerCaseUserAgent[] = "user-agent";

// Value of BINDSTATUS_SERVER_MIMETYPEAVAILABLE as found in newer versions of
// urlmon.h. The LOCAL_ prefix keeps the name from clashing with the real
// enumerator for anyone building against a newer Windows SDK.
// TODO(robertshield): Remove this once we update our SDK version.
const int LOCAL_BINDSTATUS_SERVER_MIMETYPEAVAILABLE = 54;

// Index of BeginningTransaction in the IHttpNegotiate vtable; consumed by the
// vtable patch table that follows.
const int kHttpNegotiateBeginningTransactionIndex = 3;
38 BEGIN_VTABLE_PATCHES(IHttpNegotiate)
39 VTABLE_PATCH_ENTRY(kHttpNegotiateBeginningTransactionIndex,
40 HttpNegotiatePatch::BeginningTransaction)
41 END_VTABLE_PATCHES()
43 namespace {
45 class SimpleBindStatusCallback : public CComObjectRootEx<CComSingleThreadModel>,
46 public IBindStatusCallback {
47 public:
48 BEGIN_COM_MAP(SimpleBindStatusCallback)
49 COM_INTERFACE_ENTRY(IBindStatusCallback)
50 END_COM_MAP()
52 // IBindStatusCallback implementation
53 STDMETHOD(OnStartBinding)(DWORD reserved, IBinding* binding) {
54 return E_NOTIMPL;
57 STDMETHOD(GetPriority)(LONG* priority) {
58 return E_NOTIMPL;
60 STDMETHOD(OnLowResource)(DWORD reserved) {
61 return E_NOTIMPL;
64 STDMETHOD(OnProgress)(ULONG progress, ULONG max_progress,
65 ULONG status_code, LPCWSTR status_text) {
66 return E_NOTIMPL;
68 STDMETHOD(OnStopBinding)(HRESULT result, LPCWSTR error) {
69 return E_NOTIMPL;
72 STDMETHOD(GetBindInfo)(DWORD* bind_flags, BINDINFO* bind_info) {
73 return E_NOTIMPL;
76 STDMETHOD(OnDataAvailable)(DWORD flags, DWORD size, FORMATETC* formatetc,
77 STGMEDIUM* storage) {
78 return E_NOTIMPL;
80 STDMETHOD(OnObjectAvailable)(REFIID iid, IUnknown* object) {
81 return E_NOTIMPL;
85 // Returns the full user agent header from the HTTP header strings passed to
86 // IHttpNegotiate::BeginningTransaction. Looks first in |additional_headers|
87 // and if it can't be found there looks in |headers|.
88 std::string GetUserAgentFromHeaders(LPCWSTR headers,
89 LPCWSTR additional_headers) {
90 using net::HttpUtil;
92 std::string ascii_headers;
93 if (additional_headers) {
94 ascii_headers = WideToASCII(additional_headers);
97 // Extract "User-Agent" from |additional_headers| or |headers|.
98 HttpUtil::HeadersIterator headers_iterator(ascii_headers.begin(),
99 ascii_headers.end(), "\r\n");
100 std::string user_agent_value;
101 if (headers_iterator.AdvanceTo(kLowerCaseUserAgent)) {
102 user_agent_value = headers_iterator.values();
103 } else if (headers != NULL) {
104 // See if there's a user-agent header specified in the original headers.
105 std::string original_headers(WideToASCII(headers));
106 HttpUtil::HeadersIterator original_it(original_headers.begin(),
107 original_headers.end(), "\r\n");
108 if (original_it.AdvanceTo(kLowerCaseUserAgent))
109 user_agent_value = original_it.values();
112 return user_agent_value;
115 // Removes the named header |field| from a set of headers. |field| must be
116 // lower-case.
117 std::string ExcludeFieldFromHeaders(const std::string& old_headers,
118 const char* field) {
119 using net::HttpUtil;
120 std::string new_headers;
121 new_headers.reserve(old_headers.size());
122 HttpUtil::HeadersIterator headers_iterator(old_headers.begin(),
123 old_headers.end(), "\r\n");
124 while (headers_iterator.GetNext()) {
125 if (!LowerCaseEqualsASCII(headers_iterator.name_begin(),
126 headers_iterator.name_end(),
127 field)) {
128 new_headers.append(headers_iterator.name_begin(),
129 headers_iterator.name_end());
130 new_headers += ": ";
131 new_headers.append(headers_iterator.values_begin(),
132 headers_iterator.values_end());
133 new_headers += "\r\n";
137 return new_headers;
140 std::string MutateCFUserAgentString(LPCWSTR headers,
141 LPCWSTR additional_headers,
142 bool add_user_agent) {
143 std::string user_agent_value(GetUserAgentFromHeaders(headers,
144 additional_headers));
146 // Use the default "User-Agent" if none was provided.
147 if (user_agent_value.empty())
148 user_agent_value = http_utils::GetDefaultUserAgent();
150 // Now add chromeframe to it.
151 user_agent_value = add_user_agent ?
152 http_utils::AddChromeFrameToUserAgentValue(user_agent_value) :
153 http_utils::RemoveChromeFrameFromUserAgentValue(user_agent_value);
155 // Build a new set of additional headers, skipping the existing user agent
156 // value if present.
157 return ReplaceOrAddUserAgent(additional_headers, user_agent_value);
160 } // end namespace
163 std::string AppendCFUserAgentString(LPCWSTR headers,
164 LPCWSTR additional_headers) {
165 return MutateCFUserAgentString(headers, additional_headers, true);
169 // Looks for a user agent header found in |headers| or |additional_headers|
170 // then returns |additional_headers| with a modified user agent header that does
171 // not include the chromeframe token.
172 std::string RemoveCFUserAgentString(LPCWSTR headers,
173 LPCWSTR additional_headers) {
174 return MutateCFUserAgentString(headers, additional_headers, false);
178 // Unconditionally adds the specified |user_agent_value| to the given set of
179 // |headers|, removing any that were already there.
180 std::string ReplaceOrAddUserAgent(LPCWSTR headers,
181 const std::string& user_agent_value) {
182 std::string new_headers;
183 if (headers) {
184 std::string ascii_headers(WideToASCII(headers));
185 // Build new headers, skip the existing user agent value from
186 // existing headers.
187 new_headers = ExcludeFieldFromHeaders(ascii_headers, kLowerCaseUserAgent);
189 new_headers += "User-Agent: ";
190 new_headers += user_agent_value;
191 new_headers += "\r\n";
192 return new_headers;
195 HttpNegotiatePatch::HttpNegotiatePatch() {
198 HttpNegotiatePatch::~HttpNegotiatePatch() {
201 // static
202 bool HttpNegotiatePatch::Initialize() {
203 if (IS_PATCHED(IHttpNegotiate)) {
204 DLOG(WARNING) << __FUNCTION__ << " called more than once.";
205 return true;
207 // Use our SimpleBindStatusCallback class as we need a temporary object that
208 // implements IBindStatusCallback.
209 CComObjectStackEx<SimpleBindStatusCallback> request;
210 base::win::ScopedComPtr<IBindCtx> bind_ctx;
211 HRESULT hr = CreateAsyncBindCtx(0, &request, NULL, bind_ctx.Receive());
212 DCHECK(SUCCEEDED(hr)) << "CreateAsyncBindCtx";
213 if (bind_ctx) {
214 base::win::ScopedComPtr<IUnknown> bscb_holder;
215 bind_ctx->GetObjectParam(L"_BSCB_Holder_", bscb_holder.Receive());
216 if (bscb_holder) {
217 hr = PatchHttpNegotiate(bscb_holder);
218 } else {
219 NOTREACHED() << "Failed to get _BSCB_Holder_";
220 hr = E_UNEXPECTED;
222 bind_ctx.Release();
225 return SUCCEEDED(hr);
228 // static
229 void HttpNegotiatePatch::Uninitialize() {
230 vtable_patch::UnpatchInterfaceMethods(IHttpNegotiate_PatchInfo);
233 // static
234 HRESULT HttpNegotiatePatch::PatchHttpNegotiate(IUnknown* to_patch) {
235 DCHECK(to_patch);
236 DCHECK_IS_NOT_PATCHED(IHttpNegotiate);
238 base::win::ScopedComPtr<IHttpNegotiate> http;
239 HRESULT hr = http.QueryFrom(to_patch);
240 if (FAILED(hr)) {
241 hr = DoQueryService(IID_IHttpNegotiate, to_patch, http.Receive());
244 if (http) {
245 hr = vtable_patch::PatchInterfaceMethods(http, IHttpNegotiate_PatchInfo);
246 DLOG_IF(ERROR, FAILED(hr))
247 << base::StringPrintf("HttpNegotiate patch failed 0x%08X", hr);
248 } else {
249 DLOG(WARNING)
250 << base::StringPrintf("IHttpNegotiate not supported 0x%08X", hr);
252 return hr;
255 // static
256 HRESULT HttpNegotiatePatch::BeginningTransaction(
257 IHttpNegotiate_BeginningTransaction_Fn original, IHttpNegotiate* me,
258 LPCWSTR url, LPCWSTR headers, DWORD reserved, LPWSTR* additional_headers) {
259 DVLOG(1) << __FUNCTION__ << " " << url << " headers:\n" << headers;
261 HRESULT hr = original(me, url, headers, reserved, additional_headers);
263 if (FAILED(hr)) {
264 DLOG(WARNING) << __FUNCTION__ << " Delegate returned an error";
265 return hr;
267 if (modify_user_agent_) {
268 std::string updated_headers;
270 if (IsGcfDefaultRenderer() &&
271 RendererTypeForUrl(url) == RENDERER_TYPE_CHROME_DEFAULT_RENDERER) {
272 // Replace the user-agent header with Chrome's.
273 updated_headers = ReplaceOrAddUserAgent(*additional_headers,
274 http_utils::GetChromeUserAgent());
275 } else if (ShouldRemoveUAForUrl(url)) {
276 updated_headers = RemoveCFUserAgentString(headers, *additional_headers);
277 } else {
278 updated_headers = AppendCFUserAgentString(headers, *additional_headers);
281 *additional_headers = reinterpret_cast<wchar_t*>(::CoTaskMemRealloc(
282 *additional_headers,
283 (updated_headers.length() + 1) * sizeof(wchar_t)));
284 lstrcpyW(*additional_headers, ASCIIToWide(updated_headers).c_str());
285 } else {
286 // TODO(erikwright): Remove the user agent if it is present (i.e., because
287 // of PostPlatform setting in the registry).
289 return S_OK;