Update Clear Key to support keyids formatted init_data
[chromium-blink-merge.git] / media / cdm / json_web_key.cc
blob21df837798a38e6416c17dbada85a61ba7268f69
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "media/cdm/json_web_key.h"
7 #include "base/base64.h"
8 #include "base/json/json_reader.h"
9 #include "base/json/json_string_value_serializer.h"
10 #include "base/json/string_escape.h"
11 #include "base/logging.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "base/strings/string_number_conversions.h"
14 #include "base/strings/string_util.h"
15 #include "base/values.h"
17 namespace media {
// Member names and values used inside a JSON Web Key (JWK) and a JWK Set.
const char kKeysTag[] = "keys";
const char kKeyTypeTag[] = "kty";
const char kKeyTypeOct[] = "oct";  // Octet sequence.
const char kAlgTag[] = "alg";
const char kAlgA128KW[] = "A128KW";  // AES key wrap using a 128-bit key.
const char kKeyTag[] = "k";
const char kKeyIdTag[] = "kid";

// Member name of the key ID list, used both in "keyids" init data and in
// license requests.
const char kKeyIdsTag[] = "kids";

// Member name and accepted string values describing the session type.
const char kTypeTag[] = "type";
const char kTemporarySession[] = "temporary";
const char kPersistentLicenseSession[] = "persistent-license";
const char kPersistentReleaseMessageSession[] = "persistent-release-message";

// Characters that differ between standard base64 and base64url. base64url
// uses '-' and '_' in place of '+' and '/', and EME disallows '=' padding, so
// any character of "+/=" marks a string as invalid base64url.
const char kBase64Padding = '=';
const char kBase64Plus[] = "+";
const char kBase64UrlPlusReplacement[] = "-";
const char kBase64Slash[] = "/";
const char kBase64UrlSlashReplacement[] = "_";
const char kBase64UrlInvalid[] = "+/=";
38 // Encodes |input| into a base64url string without padding.
39 static std::string EncodeBase64Url(const uint8* input, int input_length) {
40 std::string encoded_text;
41 base::Base64Encode(
42 std::string(reinterpret_cast<const char*>(input), input_length),
43 &encoded_text);
45 // Remove any padding characters added by Base64Encode().
46 size_t found = encoded_text.find_last_not_of(kBase64Padding);
47 if (found != std::string::npos)
48 encoded_text.erase(found + 1);
50 // base64url encoding means the characters '-' and '_' must be used
51 // instead of '+' and '/', respectively.
52 base::ReplaceChars(encoded_text, kBase64Plus, kBase64UrlPlusReplacement,
53 &encoded_text);
54 base::ReplaceChars(encoded_text, kBase64Slash, kBase64UrlSlashReplacement,
55 &encoded_text);
57 return encoded_text;
60 // Decodes a base64url string. Returns empty string on error.
61 static std::string DecodeBase64Url(const std::string& encoded_text) {
62 // EME spec doesn't allow '+', '/', or padding characters.
63 if (encoded_text.find_first_of(kBase64UrlInvalid) != std::string::npos) {
64 DVLOG(1) << "Invalid base64url format: " << encoded_text;
65 return std::string();
68 // Since base::Base64Decode() requires padding characters, add them so length
69 // of |encoded_text| is exactly a multiple of 4.
70 size_t num_last_grouping_chars = encoded_text.length() % 4;
71 std::string modified_text = encoded_text;
72 if (num_last_grouping_chars > 0)
73 modified_text.append(4 - num_last_grouping_chars, kBase64Padding);
75 // base64url encoding means the characters '-' and '_' must be used
76 // instead of '+' and '/', respectively, so replace them before calling
77 // base::Base64Decode().
78 base::ReplaceChars(modified_text, kBase64UrlPlusReplacement, kBase64Plus,
79 &modified_text);
80 base::ReplaceChars(modified_text, kBase64UrlSlashReplacement, kBase64Slash,
81 &modified_text);
83 std::string decoded_text;
84 if (!base::Base64Decode(modified_text, &decoded_text)) {
85 DVLOG(1) << "Base64 decoding failed on: " << modified_text;
86 return std::string();
89 return decoded_text;
92 static std::string ShortenTo64Characters(const std::string& input) {
93 // Convert |input| into a string with escaped characters replacing any
94 // non-ASCII characters. Limiting |input| to the first 65 characters so
95 // we don't waste time converting a potentially long string and then
96 // throwing away the excess.
97 std::string escaped_str =
98 base::EscapeBytesAsInvalidJSONString(input.substr(0, 65), false);
99 if (escaped_str.length() <= 64u)
100 return escaped_str;
102 // This may end up truncating an escaped character, but the first part of
103 // the string should provide enough information.
104 return escaped_str.substr(0, 61).append("...");
107 std::string GenerateJWKSet(const uint8* key, int key_length,
108 const uint8* key_id, int key_id_length) {
109 // Both |key| and |key_id| need to be base64 encoded strings in the JWK.
110 std::string key_base64 = EncodeBase64Url(key, key_length);
111 std::string key_id_base64 = EncodeBase64Url(key_id, key_id_length);
113 // Create the JWK, and wrap it into a JWK Set.
114 scoped_ptr<base::DictionaryValue> jwk(new base::DictionaryValue());
115 jwk->SetString(kKeyTypeTag, kKeyTypeOct);
116 jwk->SetString(kAlgTag, kAlgA128KW);
117 jwk->SetString(kKeyTag, key_base64);
118 jwk->SetString(kKeyIdTag, key_id_base64);
119 scoped_ptr<base::ListValue> list(new base::ListValue());
120 list->Append(jwk.release());
121 base::DictionaryValue jwk_set;
122 jwk_set.Set(kKeysTag, list.release());
124 // Finally serialize |jwk_set| into a string and return it.
125 std::string serialized_jwk;
126 JSONStringValueSerializer serializer(&serialized_jwk);
127 serializer.Serialize(jwk_set);
128 return serialized_jwk;
131 // Processes a JSON Web Key to extract the key id and key value. Sets |jwk_key|
132 // to the id/value pair and returns true on success.
133 static bool ConvertJwkToKeyPair(const base::DictionaryValue& jwk,
134 KeyIdAndKeyPair* jwk_key) {
135 std::string type;
136 if (!jwk.GetString(kKeyTypeTag, &type) || type != kKeyTypeOct) {
137 DVLOG(1) << "Missing or invalid '" << kKeyTypeTag << "': " << type;
138 return false;
141 std::string alg;
142 if (!jwk.GetString(kAlgTag, &alg) || alg != kAlgA128KW) {
143 DVLOG(1) << "Missing or invalid '" << kAlgTag << "': " << alg;
144 return false;
147 // Get the key id and actual key parameters.
148 std::string encoded_key_id;
149 std::string encoded_key;
150 if (!jwk.GetString(kKeyIdTag, &encoded_key_id)) {
151 DVLOG(1) << "Missing '" << kKeyIdTag << "' parameter";
152 return false;
154 if (!jwk.GetString(kKeyTag, &encoded_key)) {
155 DVLOG(1) << "Missing '" << kKeyTag << "' parameter";
156 return false;
159 // Key ID and key are base64-encoded strings, so decode them.
160 std::string raw_key_id = DecodeBase64Url(encoded_key_id);
161 if (raw_key_id.empty()) {
162 DVLOG(1) << "Invalid '" << kKeyIdTag << "' value: " << encoded_key_id;
163 return false;
166 std::string raw_key = DecodeBase64Url(encoded_key);
167 if (raw_key.empty()) {
168 DVLOG(1) << "Invalid '" << kKeyTag << "' value: " << encoded_key;
169 return false;
172 // Add the decoded key ID and the decoded key to the list.
173 *jwk_key = std::make_pair(raw_key_id, raw_key);
174 return true;
177 bool ExtractKeysFromJWKSet(const std::string& jwk_set,
178 KeyIdAndKeyPairs* keys,
179 MediaKeys::SessionType* session_type) {
180 if (!base::IsStringASCII(jwk_set)) {
181 DVLOG(1) << "Non ASCII JWK Set: " << jwk_set;
182 return false;
185 scoped_ptr<base::Value> root(base::JSONReader().ReadToValue(jwk_set));
186 if (!root.get() || root->GetType() != base::Value::TYPE_DICTIONARY) {
187 DVLOG(1) << "Not valid JSON: " << jwk_set << ", root: " << root.get();
188 return false;
191 // Locate the set from the dictionary.
192 base::DictionaryValue* dictionary =
193 static_cast<base::DictionaryValue*>(root.get());
194 base::ListValue* list_val = NULL;
195 if (!dictionary->GetList(kKeysTag, &list_val)) {
196 DVLOG(1) << "Missing '" << kKeysTag
197 << "' parameter or not a list in JWK Set";
198 return false;
201 // Create a local list of keys, so that |jwk_keys| only gets updated on
202 // success.
203 KeyIdAndKeyPairs local_keys;
204 for (size_t i = 0; i < list_val->GetSize(); ++i) {
205 base::DictionaryValue* jwk = NULL;
206 if (!list_val->GetDictionary(i, &jwk)) {
207 DVLOG(1) << "Unable to access '" << kKeysTag << "'[" << i
208 << "] in JWK Set";
209 return false;
211 KeyIdAndKeyPair key_pair;
212 if (!ConvertJwkToKeyPair(*jwk, &key_pair)) {
213 DVLOG(1) << "Error from '" << kKeysTag << "'[" << i << "]";
214 return false;
216 local_keys.push_back(key_pair);
219 // Successfully processed all JWKs in the set. Now check if "type" is
220 // specified.
221 base::Value* value = NULL;
222 std::string session_type_id;
223 if (!dictionary->Get(kTypeTag, &value)) {
224 // Not specified, so use the default type.
225 *session_type = MediaKeys::TEMPORARY_SESSION;
226 } else if (!value->GetAsString(&session_type_id)) {
227 DVLOG(1) << "Invalid '" << kTypeTag << "' value";
228 return false;
229 } else if (session_type_id == kTemporarySession) {
230 *session_type = MediaKeys::TEMPORARY_SESSION;
231 } else if (session_type_id == kPersistentLicenseSession) {
232 *session_type = MediaKeys::PERSISTENT_LICENSE_SESSION;
233 } else if (session_type_id == kPersistentReleaseMessageSession) {
234 *session_type = MediaKeys::PERSISTENT_RELEASE_MESSAGE_SESSION;
235 } else {
236 DVLOG(1) << "Invalid '" << kTypeTag << "' value: " << session_type_id;
237 return false;
240 // All done.
241 keys->swap(local_keys);
242 return true;
245 bool ExtractKeyIdsFromKeyIdsInitData(const std::string& input,
246 KeyIdList* key_ids,
247 std::string* error_message) {
248 if (!base::IsStringASCII(input)) {
249 error_message->assign("Non ASCII: ");
250 error_message->append(ShortenTo64Characters(input));
251 return false;
254 scoped_ptr<base::Value> root(base::JSONReader().ReadToValue(input));
255 if (!root.get() || root->GetType() != base::Value::TYPE_DICTIONARY) {
256 error_message->assign("Not valid JSON: ");
257 error_message->append(ShortenTo64Characters(input));
258 return false;
261 // Locate the set from the dictionary.
262 base::DictionaryValue* dictionary =
263 static_cast<base::DictionaryValue*>(root.get());
264 base::ListValue* list_val = NULL;
265 if (!dictionary->GetList(kKeyIdsTag, &list_val)) {
266 error_message->assign("Missing '");
267 error_message->append(kKeyIdsTag);
268 error_message->append("' parameter or not a list");
269 return false;
272 // Create a local list of key ids, so that |key_ids| only gets updated on
273 // success.
274 KeyIdList local_key_ids;
275 for (size_t i = 0; i < list_val->GetSize(); ++i) {
276 std::string encoded_key_id;
277 if (!list_val->GetString(i, &encoded_key_id)) {
278 error_message->assign("'");
279 error_message->append(kKeyIdsTag);
280 error_message->append("'[");
281 error_message->append(base::UintToString(i));
282 error_message->append("] is not string.");
283 return false;
286 // Key ID is a base64-encoded string, so decode it.
287 std::string raw_key_id = DecodeBase64Url(encoded_key_id);
288 if (raw_key_id.empty()) {
289 error_message->assign("'");
290 error_message->append(kKeyIdsTag);
291 error_message->append("'[");
292 error_message->append(base::UintToString(i));
293 error_message->append("] is not valid base64url encoded. Value: ");
294 error_message->append(ShortenTo64Characters(encoded_key_id));
295 return false;
298 // Add the decoded key ID to the list.
299 local_key_ids.push_back(std::vector<uint8>(
300 raw_key_id.data(), raw_key_id.data() + raw_key_id.length()));
303 // All done.
304 key_ids->swap(local_key_ids);
305 error_message->clear();
306 return true;
309 void CreateLicenseRequest(const KeyIdList& key_ids,
310 MediaKeys::SessionType session_type,
311 std::vector<uint8>* license) {
312 // Create the license request.
313 scoped_ptr<base::DictionaryValue> request(new base::DictionaryValue());
314 scoped_ptr<base::ListValue> list(new base::ListValue());
315 for (const auto& key_id : key_ids)
316 list->AppendString(EncodeBase64Url(&key_id[0], key_id.size()));
317 request->Set(kKeyIdsTag, list.release());
319 switch (session_type) {
320 case MediaKeys::TEMPORARY_SESSION:
321 request->SetString(kTypeTag, kTemporarySession);
322 break;
323 case MediaKeys::PERSISTENT_LICENSE_SESSION:
324 request->SetString(kTypeTag, kPersistentLicenseSession);
325 break;
326 case MediaKeys::PERSISTENT_RELEASE_MESSAGE_SESSION:
327 request->SetString(kTypeTag, kPersistentReleaseMessageSession);
328 break;
331 // Serialize the license request as a string.
332 std::string json;
333 JSONStringValueSerializer serializer(&json);
334 serializer.Serialize(*request);
336 // Convert the serialized license request into std::vector and return it.
337 std::vector<uint8> result(json.begin(), json.end());
338 license->swap(result);
341 bool ExtractFirstKeyIdFromLicenseRequest(const std::vector<uint8>& license,
342 std::vector<uint8>* first_key) {
343 const std::string license_as_str(
344 reinterpret_cast<const char*>(!license.empty() ? &license[0] : NULL),
345 license.size());
346 if (!base::IsStringASCII(license_as_str)) {
347 DVLOG(1) << "Non ASCII license: " << license_as_str;
348 return false;
351 scoped_ptr<base::Value> root(base::JSONReader().ReadToValue(license_as_str));
352 if (!root.get() || root->GetType() != base::Value::TYPE_DICTIONARY) {
353 DVLOG(1) << "Not valid JSON: " << license_as_str;
354 return false;
357 // Locate the set from the dictionary.
358 base::DictionaryValue* dictionary =
359 static_cast<base::DictionaryValue*>(root.get());
360 base::ListValue* list_val = NULL;
361 if (!dictionary->GetList(kKeyIdsTag, &list_val)) {
362 DVLOG(1) << "Missing '" << kKeyIdsTag << "' parameter or not a list";
363 return false;
366 // Get the first key.
367 if (list_val->GetSize() < 1) {
368 DVLOG(1) << "Empty '" << kKeyIdsTag << "' list";
369 return false;
372 std::string encoded_key;
373 if (!list_val->GetString(0, &encoded_key)) {
374 DVLOG(1) << "First entry in '" << kKeyIdsTag << "' not a string";
375 return false;
378 std::string decoded_string = DecodeBase64Url(encoded_key);
379 if (decoded_string.empty()) {
380 DVLOG(1) << "Invalid '" << kKeyIdsTag << "' value: " << encoded_key;
381 return false;
384 std::vector<uint8> result(decoded_string.begin(), decoded_string.end());
385 first_key->swap(result);
386 return true;
389 } // namespace media