Bug 1732032 [wpt PR 30868] - fix: canShare() return false if not allowed to use,...
[gecko.git] / storage / mozStorageSQLFunctions.cpp
blobd160ddb18b4bfc3a422e63dde9998ab8ce895d06
1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*-
2 * vim: sw=2 ts=2 et lcs=trail\:.,tab\:>~ :
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
7 #include "mozilla/ArrayUtils.h"
9 #include "mozStorageSQLFunctions.h"
10 #include "nsTArray.h"
11 #include "nsUnicharUtils.h"
12 #include <algorithm>
14 namespace mozilla {
15 namespace storage {
17 ////////////////////////////////////////////////////////////////////////////////
18 //// Local Helper Functions
20 namespace {
22 /**
23 * Performs the LIKE comparison of a string against a pattern. For more detail
24 * see http://www.sqlite.org/lang_expr.html#like.
26 * @param aPatternItr
27 * An iterator at the start of the pattern to check for.
28 * @param aPatternEnd
29 * An iterator at the end of the pattern to check for.
30 * @param aStringItr
31 * An iterator at the start of the string to check for the pattern.
32 * @param aStringEnd
33 * An iterator at the end of the string to check for the pattern.
34 * @param aEscapeChar
35 * The character to use for escaping symbols in the pattern.
36 * @return 1 if the pattern is found, 0 otherwise.
38 int likeCompare(nsAString::const_iterator aPatternItr,
39 nsAString::const_iterator aPatternEnd,
40 nsAString::const_iterator aStringItr,
41 nsAString::const_iterator aStringEnd, char16_t aEscapeChar) {
42 const char16_t MATCH_ALL('%');
43 const char16_t MATCH_ONE('_');
45 bool lastWasEscape = false;
46 while (aPatternItr != aPatternEnd) {
47 /**
48 * What we do in here is take a look at each character from the input
49 * pattern, and do something with it. There are 4 possibilities:
50 * 1) character is an un-escaped match-all character
51 * 2) character is an un-escaped match-one character
52 * 3) character is an un-escaped escape character
53 * 4) character is not any of the above
55 if (!lastWasEscape && *aPatternItr == MATCH_ALL) {
56 // CASE 1
57 /**
58 * Now we need to skip any MATCH_ALL or MATCH_ONE characters that follow a
59 * MATCH_ALL character. For each MATCH_ONE character, skip one character
60 * in the pattern string.
62 while (*aPatternItr == MATCH_ALL || *aPatternItr == MATCH_ONE) {
63 if (*aPatternItr == MATCH_ONE) {
64 // If we've hit the end of the string we are testing, no match
65 if (aStringItr == aStringEnd) return 0;
66 aStringItr++;
68 aPatternItr++;
71 // If we've hit the end of the pattern string, match
72 if (aPatternItr == aPatternEnd) return 1;
74 while (aStringItr != aStringEnd) {
75 if (likeCompare(aPatternItr, aPatternEnd, aStringItr, aStringEnd,
76 aEscapeChar)) {
77 // we've hit a match, so indicate this
78 return 1;
80 aStringItr++;
83 // No match
84 return 0;
85 } else if (!lastWasEscape && *aPatternItr == MATCH_ONE) {
86 // CASE 2
87 if (aStringItr == aStringEnd) {
88 // If we've hit the end of the string we are testing, no match
89 return 0;
91 aStringItr++;
92 lastWasEscape = false;
93 } else if (!lastWasEscape && *aPatternItr == aEscapeChar) {
94 // CASE 3
95 lastWasEscape = true;
96 } else {
97 // CASE 4
98 if (::ToUpperCase(*aStringItr) != ::ToUpperCase(*aPatternItr)) {
99 // If we've hit a point where the strings don't match, there is no match
100 return 0;
102 aStringItr++;
103 lastWasEscape = false;
106 aPatternItr++;
109 return aStringItr == aStringEnd;
113 * Compute the Levenshtein Edit Distance between two strings.
115 * @param aStringS
116 * a string
117 * @param aStringT
118 * another string
119 * @param _result
120 * an outparam that will receive the edit distance between the arguments
121 * @return a Sqlite result code, e.g. SQLITE_OK, SQLITE_NOMEM, etc.
123 int levenshteinDistance(const nsAString& aStringS, const nsAString& aStringT,
124 int* _result) {
125 // Set the result to a non-sensical value in case we encounter an error.
126 *_result = -1;
128 const uint32_t sLen = aStringS.Length();
129 const uint32_t tLen = aStringT.Length();
131 if (sLen == 0) {
132 *_result = tLen;
133 return SQLITE_OK;
135 if (tLen == 0) {
136 *_result = sLen;
137 return SQLITE_OK;
140 // Notionally, Levenshtein Distance is computed in a matrix. If we
141 // assume s = "span" and t = "spam", the matrix would look like this:
142 // s -->
143 // t s p a n
144 // | 0 1 2 3 4
145 // V s 1 * * * *
146 // p 2 * * * *
147 // a 3 * * * *
148 // m 4 * * * *
150 // Note that the row width is sLen + 1 and the column height is tLen + 1,
151 // where sLen is the length of the string "s" and tLen is the length of "t".
152 // The first row and the first column are initialized as shown, and
153 // the algorithm computes the remaining cells row-by-row, and
154 // left-to-right within each row. The computation only requires that
155 // we be able to see the current row and the previous one.
157 // Allocate memory for two rows.
158 AutoTArray<int, nsAutoString::kStorageSize> row1;
159 AutoTArray<int, nsAutoString::kStorageSize> row2;
161 // Declare the raw pointers that will actually be used to access the memory.
162 int* prevRow = row1.AppendElements(sLen + 1);
163 int* currRow = row2.AppendElements(sLen + 1);
165 // Initialize the first row.
166 for (uint32_t i = 0; i <= sLen; i++) prevRow[i] = i;
168 const char16_t* s = aStringS.BeginReading();
169 const char16_t* t = aStringT.BeginReading();
171 // Compute the empty cells in the "matrix" row-by-row, starting with
172 // the second row.
173 for (uint32_t ti = 1; ti <= tLen; ti++) {
174 // Initialize the first cell in this row.
175 currRow[0] = ti;
177 // Get the character from "t" that corresponds to this row.
178 const char16_t tch = t[ti - 1];
180 // Compute the remaining cells in this row, left-to-right,
181 // starting at the second column (and first character of "s").
182 for (uint32_t si = 1; si <= sLen; si++) {
183 // Get the character from "s" that corresponds to this column,
184 // compare it to the t-character, and compute the "cost".
185 const char16_t sch = s[si - 1];
186 int cost = (sch == tch) ? 0 : 1;
188 // ............ We want to calculate the value of cell "d" from
189 // ...ab....... the previously calculated (or initialized) cells
190 // ...cd....... "a", "b", and "c", where d = min(a', b', c').
191 // ............
192 int aPrime = prevRow[si - 1] + cost;
193 int bPrime = prevRow[si] + 1;
194 int cPrime = currRow[si - 1] + 1;
195 currRow[si] = std::min(aPrime, std::min(bPrime, cPrime));
198 // Advance to the next row. The current row becomes the previous
199 // row and we recycle the old previous row as the new current row.
200 // We don't need to re-initialize the new current row since we will
201 // rewrite all of its cells anyway.
202 int* oldPrevRow = prevRow;
203 prevRow = currRow;
204 currRow = oldPrevRow;
207 // The final result is the value of the last cell in the last row.
208 // Note that that's now in the "previous" row, since we just swapped them.
209 *_result = prevRow[sLen];
210 return SQLITE_OK;
213 // This struct is used only by registerFunctions below, but ISO C++98 forbids
214 // instantiating a template dependent on a locally-defined type. Boo-urns!
215 struct Functions {
216 const char* zName;
217 int nArg;
218 int enc;
219 void* pContext;
220 void (*xFunc)(::sqlite3_context*, int, sqlite3_value**);
223 } // namespace
225 ////////////////////////////////////////////////////////////////////////////////
226 //// Exposed Functions
228 int registerFunctions(sqlite3* aDB) {
229 Functions functions[] = {
230 {"lower", 1, SQLITE_UTF16, 0, caseFunction},
231 {"lower", 1, SQLITE_UTF8, 0, caseFunction},
232 {"upper", 1, SQLITE_UTF16, (void*)1, caseFunction},
233 {"upper", 1, SQLITE_UTF8, (void*)1, caseFunction},
235 {"like", 2, SQLITE_UTF16, 0, likeFunction},
236 {"like", 2, SQLITE_UTF8, 0, likeFunction},
237 {"like", 3, SQLITE_UTF16, 0, likeFunction},
238 {"like", 3, SQLITE_UTF8, 0, likeFunction},
240 {"levenshteinDistance", 2, SQLITE_UTF16, 0, levenshteinDistanceFunction},
241 {"levenshteinDistance", 2, SQLITE_UTF8, 0, levenshteinDistanceFunction},
243 {"utf16Length", 1, SQLITE_UTF16, 0, utf16LengthFunction},
244 {"utf16Length", 1, SQLITE_UTF8, 0, utf16LengthFunction},
247 int rv = SQLITE_OK;
248 for (size_t i = 0; SQLITE_OK == rv && i < ArrayLength(functions); ++i) {
249 struct Functions* p = &functions[i];
250 rv = ::sqlite3_create_function(aDB, p->zName, p->nArg, p->enc, p->pContext,
251 p->xFunc, nullptr, nullptr);
254 return rv;
257 ////////////////////////////////////////////////////////////////////////////////
258 //// SQL Functions
260 void caseFunction(sqlite3_context* aCtx, int aArgc, sqlite3_value** aArgv) {
261 NS_ASSERTION(1 == aArgc, "Invalid number of arguments!");
263 const char16_t* value =
264 static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[0]));
265 nsAutoString data(value,
266 ::sqlite3_value_bytes16(aArgv[0]) / sizeof(char16_t));
267 bool toUpper = ::sqlite3_user_data(aCtx) ? true : false;
269 if (toUpper)
270 ::ToUpperCase(data);
271 else
272 ::ToLowerCase(data);
274 // Set the result.
275 ::sqlite3_result_text16(aCtx, data.get(), data.Length() * sizeof(char16_t),
276 SQLITE_TRANSIENT);
280 * This implements the like() SQL function. This is used by the LIKE operator.
281 * The SQL statement 'A LIKE B' is implemented as 'like(B, A)', and if there is
282 * an escape character, say E, it is implemented as 'like(B, A, E)'.
284 void likeFunction(sqlite3_context* aCtx, int aArgc, sqlite3_value** aArgv) {
285 NS_ASSERTION(2 == aArgc || 3 == aArgc, "Invalid number of arguments!");
287 if (::sqlite3_value_bytes(aArgv[0]) > SQLITE_MAX_LIKE_PATTERN_LENGTH) {
288 ::sqlite3_result_error(aCtx, "LIKE or GLOB pattern too complex",
289 SQLITE_TOOBIG);
290 return;
293 if (!::sqlite3_value_text16(aArgv[0]) || !::sqlite3_value_text16(aArgv[1]))
294 return;
296 const char16_t* a =
297 static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[1]));
298 int aLen = ::sqlite3_value_bytes16(aArgv[1]) / sizeof(char16_t);
299 nsDependentString A(a, aLen);
301 const char16_t* b =
302 static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[0]));
303 int bLen = ::sqlite3_value_bytes16(aArgv[0]) / sizeof(char16_t);
304 nsDependentString B(b, bLen);
305 NS_ASSERTION(!B.IsEmpty(), "LIKE string must not be null!");
307 char16_t E = 0;
308 if (3 == aArgc)
309 E = static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[2]))[0];
311 nsAString::const_iterator itrString, endString;
312 A.BeginReading(itrString);
313 A.EndReading(endString);
314 nsAString::const_iterator itrPattern, endPattern;
315 B.BeginReading(itrPattern);
316 B.EndReading(endPattern);
317 ::sqlite3_result_int(
318 aCtx, likeCompare(itrPattern, endPattern, itrString, endString, E));
321 void levenshteinDistanceFunction(sqlite3_context* aCtx, int aArgc,
322 sqlite3_value** aArgv) {
323 NS_ASSERTION(2 == aArgc, "Invalid number of arguments!");
325 // If either argument is a SQL NULL, then return SQL NULL.
326 if (::sqlite3_value_type(aArgv[0]) == SQLITE_NULL ||
327 ::sqlite3_value_type(aArgv[1]) == SQLITE_NULL) {
328 ::sqlite3_result_null(aCtx);
329 return;
332 const char16_t* a =
333 static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[0]));
334 int aLen = ::sqlite3_value_bytes16(aArgv[0]) / sizeof(char16_t);
336 const char16_t* b =
337 static_cast<const char16_t*>(::sqlite3_value_text16(aArgv[1]));
338 int bLen = ::sqlite3_value_bytes16(aArgv[1]) / sizeof(char16_t);
340 // Compute the Levenshtein Distance, and return the result (or error).
341 int distance = -1;
342 const nsDependentString A(a, aLen);
343 const nsDependentString B(b, bLen);
344 int status = levenshteinDistance(A, B, &distance);
345 if (status == SQLITE_OK) {
346 ::sqlite3_result_int(aCtx, distance);
347 } else if (status == SQLITE_NOMEM) {
348 ::sqlite3_result_error_nomem(aCtx);
349 } else {
350 ::sqlite3_result_error(aCtx, "User function returned error code", -1);
354 void utf16LengthFunction(sqlite3_context* aCtx, int aArgc,
355 sqlite3_value** aArgv) {
356 NS_ASSERTION(1 == aArgc, "Invalid number of arguments!");
358 int len = ::sqlite3_value_bytes16(aArgv[0]) / sizeof(char16_t);
360 // Set the result.
361 ::sqlite3_result_int(aCtx, len);
364 } // namespace storage
365 } // namespace mozilla