some updates
[iv.d.git] / gengdtw.d
blob810ec22f729a88dae2be48125828748a6ca4d5a5
/*
 * Copyright (c) 2009, Thomas Jaeger <ThJaeger@gmail.com>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 * adaptation by Ketmar // Invisible Vector <ketmar@ketmar.no-ip.org>
 */
18 // DTW-based gesture recognizer
19 module iv.gengdtw /*is aliced*/;
20 private:
22 import iv.alice;
23 import iv.vfs;
25 //version = gengdtw_debug_mempools;
28 // ////////////////////////////////////////////////////////////////////////// //
// floating point type used for all glyph coordinates, costs and scores
public alias DTWFloat = float;


// ////////////////////////////////////////////////////////////////////////// //
// cost value that `DTWGlyph.compare()` returns (and treats) as "no match"
enum DTWFloat dtwInfinity = 0.2;
// tolerance used by `DTWGlyph.compare()` when comparing normalized path positions
enum DTWFloat EPS = 0.000001;
37 // ////////////////////////////////////////////////////////////////////////// //
// Gesture glyph: a polyline of points that can be "finished" (normalized)
// and then compared with other finished glyphs.
public final class DTWGlyph {
public:
  // minimal `score()` value that counts as a match
  enum MinMatchScore = cast(DTWFloat)0.85;

private:
  // one point of the glyph polyline
  static struct Point {
    DTWFloat x, y; // coordinates: as appended/loaded; rescaled in-place by `finish()`
    DTWFloat t = 0.0, dt = 0.0; // normalized path position and its delta to the next point (set by `finish()`)
    DTWFloat alpha = 0.0; // direction to the next point, atan2()/PI (set by `finish()`)
    @property bool valid () const pure nothrow @safe @nogc { pragma(inline, true); import std.math : isNaN; return (!x.isNaN && !y.isNaN); }
  }

private:
  Point[] points;
  bool mFinished;
  string mName;

private:
  // append `v` to `arr`; if the array storage moved and the new pointer is
  // the base of a GC block, mark that block as NO_INTERIOR
  static void unsafeArrayAppend(T) (ref T[] arr, auto ref T v) {
    auto optr = arr.ptr;
    arr ~= v;
    if (optr !is arr.ptr) {
      import core.memory : GC;
      optr = arr.ptr;
      if (optr !is null && optr is GC.addrOf(optr)) GC.setAttr(optr, GC.BlkAttr.NO_INTERIOR);
    }
  }

  // duplicate `arr`; if the copy starts at the base of a GC block, mark that
  // block as NO_INTERIOR
  static T[] unsafeArrayDup(T) (const(T)[] arr) {
    auto res = arr.dup;
    if (res.ptr !is null) {
      import core.memory : GC;
      if (res.ptr is GC.addrOf(res.ptr)) GC.setAttr(res.ptr, GC.BlkAttr.NO_INTERIOR);
    }
    return res;
  }

public:
  this () nothrow @safe @nogc {}
  this (string aname) nothrow @safe @nogc { mName = aname; }

  @property const pure nothrow @safe @nogc {
    // glyph has at least two points
    bool valid () { pragma(inline, true); return (points.length >= 2); }
    // glyph is valid and was finished
    bool finished () { pragma(inline, true); return (points.length >= 2 && mFinished); }
    string name () const { pragma(inline, true); return mName; }
    usize length () const pure nothrow @safe @nogc { return points.length; }
    alias opDollar = length;
    // out-of-range index yields `Point.init` instead of failing
    Point opIndex (usize idx) { pragma(inline, true); return (idx < points.length ? points[idx] : Point.init); }
  }

  // set glyph name; non-`string` char slices are copied (only if different)
  @property void name(T:const(char)[]) (T v) nothrow @safe {
    static if (is(T == typeof(null))) mName = null;
    else static if (is(T == string)) mName = v;
    else { if (mName != v) mName = v.idup; }
  }

  // you can use this to "reset" points array without actually deleting it
  // otherwise it is identical to "clear"
  auto reset () nothrow @trusted {
    if (points.length) { points.length = 0; points.assumeSafeAppend; }
    mName = null;
    mFinished = false;
    return this;
  }

  // drop all points (the array is deleted), the name and the finished flag
  auto clear () nothrow @trusted {
    delete points;
    mName = null;
    mFinished = false;
    return this;
  }

  // create an independent copy (name, points, finished flag)
  auto clone () const @trusted {
    auto res = new DTWGlyph(mName);
    if (points.length > 0) res.points = unsafeArrayDup(points);
    res.mFinished = mFinished;
    return res;
  }

  // always returns a new (cloned) glyph; the clone is finished if `this` is not
  // throws if `this` is not finished and has less than two points
  // NOTE: for a degenerate glyph (see `finish()`) the result stays unfinished
  DTWGlyph getFinished () const @safe {
    if (mFinished) {
      return this.clone;
    } else {
      if (points.length < 2) throw new Exception("invalid glyph");
      return this.clone.finish;
    }
  }

  // append new point; throws if the glyph is already finished
  // points closer than `MinPointDistance` to the last point are silently dropped
  auto appendPoint (int x, int y) @trusted {
    if (mFinished) throw new Exception("can't append points to finished glyph");
    enum MinPointDistance = 4;
    if (points.length > 0) {
      // check distance and don't add points that are too close to each other
      immutable lx = x-points[$-1].x, ly = y-points[$-1].y;
      if (lx*lx+ly*ly < MinPointDistance*MinPointDistance) return this;
    }
    unsafeArrayAppend(points, Point(x, y));
    mFinished = false;
    return this;
  }

  // "finish" (finalize) gesture in-place:
  //   * `t` becomes the path length from the first point, divided by the total length
  //   * coordinates are centered on the bounding box and divided by its bigger side
  //   * `dt` and `alpha` are computed for every segment (i.e. for all points but the last)
  // does nothing if the glyph is already finished or has less than two points
  // a glyph whose total path length is not a positive finite number is left
  // unfinished: dividing by such length would put NaNs into `t`
  auto finish () nothrow @trusted @nogc {
    import std.math : hypot, atan2, PI;
    if (mFinished || points.length < 2) return this;
    immutable usize last = points.length-1;
    // accumulate path length
    DTWFloat total = 0.0;
    points[0].t = 0.0;
    foreach (immutable idx; 0..last) {
      total += hypot(points.ptr[idx+1].x-points.ptr[idx].x, points.ptr[idx+1].y-points.ptr[idx].y);
      points.ptr[idx+1].t = total;
    }
    // degenerate path (zero length, NaN or infinite coordinates): cannot normalize
    if (!(total > 0 && total < DTWFloat.infinity)) return this;
    foreach (ref Point v; points[]) v.t /= total;
    // bounding box
    DTWFloat minX, minY, maxX, maxY;
    minX = maxX = points.ptr[0].x;
    minY = maxY = points.ptr[0].y;
    foreach (immutable idx; 1..points.length) {
      if (points.ptr[idx].x < minX) minX = points.ptr[idx].x;
      if (points.ptr[idx].x > maxX) maxX = points.ptr[idx].x;
      if (points.ptr[idx].y < minY) minY = points.ptr[idx].y;
      if (points.ptr[idx].y > maxY) maxY = points.ptr[idx].y;
    }
    immutable DTWFloat scaleX = maxX-minX;
    immutable DTWFloat scaleY = maxY-minY;
    DTWFloat scale = (scaleX > scaleY ? scaleX : scaleY);
    if (scale < 0.001) scale = 1;
    immutable DTWFloat mx2 = (minX+maxX)/2.0;
    immutable DTWFloat my2 = (minY+maxY)/2.0;
    foreach (immutable idx; 0..points.length) {
      points.ptr[idx].x = (points.ptr[idx].x-mx2)/scale+0.5;
      points.ptr[idx].y = (points.ptr[idx].y-my2)/scale+0.5;
    }
    // per-segment data; must start from the very first segment, as `compare()`
    // reads `alpha` of point 0 (starting from 1 left the first segment without
    // direction, so any two 2-point glyphs were considered identical)
    foreach (immutable idx; 0..last) {
      points.ptr[idx].dt = points.ptr[idx+1].t-points.ptr[idx].t;
      points.ptr[idx].alpha = atan2(points.ptr[idx+1].y-points.ptr[idx].y, points.ptr[idx+1].x-points.ptr[idx].x)/PI;
    }
    mFinished = true;
    return this;
  }

  private {
    // do avoid frequent malloc()/free() calls, we'll use memory pools instead
    // statics are thread-local, so it is thread-safe, and this function cannot be called recursively
    // we will leak this memory, but it doesn't matter in practice
    static void*[3] mempools;
    static usize[3] mempoolsizes;

    // return pool `poolidx` memory big enough for `count` items of type `T`
    // the pool is grown on demand, and shrunk when less than a half of it is requested
    // contents are NOT preserved in any meaningful way and are not initialized
    static T* poolAlloc(T, ubyte poolidx) (usize count) nothrow @trusted @nogc {
      import core.exception : onOutOfMemoryError;
      static assert(poolidx < mempools.length);
      if (count == 0) count = 1; // just in case
      if (count > usize.max/T.sizeof) onOutOfMemoryError(); // byte size would overflow
      usize sz = T.sizeof*count;
      if (sz > mempoolsizes.ptr[poolidx]) {
        import core.stdc.stdlib : realloc;
        version(gengdtw_debug_mempools) {{ import core.stdc.stdio; stderr.fprintf("pool #%u: reallocing %u to %u\n", cast(uint)poolidx, cast(uint)mempoolsizes.ptr[poolidx], cast(uint)sz); }}
        auto data = realloc(mempools.ptr[poolidx], sz);
        if (data is null) onOutOfMemoryError();
        mempools.ptr[poolidx] = data;
        mempoolsizes.ptr[poolidx] = sz;
      } else if (sz < mempoolsizes.ptr[poolidx]/2) {
        // shrink used memory, 'cause why not?
        import core.stdc.stdlib : realloc;
        immutable usize newsz = sz+sz/3; // cannot overflow: it is less than the current pool size
        auto data = realloc(mempools.ptr[poolidx], newsz);
        if (data !is null) {
          version(gengdtw_debug_mempools) {{ import core.stdc.stdio; stderr.fprintf("pool #%u: shrinking %u to %u\n", cast(uint)poolidx, cast(uint)mempoolsizes.ptr[poolidx], cast(uint)newsz); }}
          mempools.ptr[poolidx] = data;
          // remember the real block size, so the slack is reused instead of reallocated
          mempoolsizes.ptr[poolidx] = newsz;
        }
      } else {
        version(gengdtw_debug_mempools) {{ import core.stdc.stdio; stderr.fprintf("pool #%u: reusing %u of %u\n", cast(uint)poolidx, cast(uint)sz, cast(uint)mempoolsizes.ptr[poolidx]); }}
      }
      T* res = cast(T*)(mempools.ptr[poolidx]);
      assert(res !is null);
      return res;
    }

    static void poolFree(ubyte poolidx) () nothrow @trusted @nogc {
      // nothing to do here
    }

    version(gengdtw_debug_mempools) static ~this () {
      import core.stdc.stdio;
      foreach (immutable pidx; 0..mempools.length) {
        stderr.fprintf("max pool #%u size: %u\n", cast(uint)pidx, cast(uint)mempoolsizes[pidx]);
      }
    }
  }

  /* To compare two gestures, we use dynamic programming to minimize (an approximation)
   * of the integral over square of the angle difference among (roughly) all
   * reparametrizations whose slope is always between 1/2 and 2.
   * Use `isGoodScore()` to check if something was (somewhat) reliably matched.
   *
   * Returns the accumulated cost; `dtwInfinity` means "no match" (it is also
   * returned when either glyph is not finished, or `b` is null).
   */
  DTWFloat compare (const(DTWGlyph) b) const nothrow @nogc {
    // two-dimensional array that lives in memory pool #`poolidx`
    static struct A2D(T, ubyte poolidx) {
      static assert(poolidx < DTWGlyph.mempools.length);
    private:
      usize dim0, dim1;
      T* data;

    public nothrow @trusted @nogc:
      @disable this ();
      @disable this (this);

      // dimensions less than 1 are bumped to 1; all items are set to `initV`
      this (usize d0, usize d1, T initV=T.init) {
        import core.exception : onOutOfMemoryError;
        if (d0 < 1) d0 = 1;
        if (d1 < 1) d1 = 1;
        if (d1 > usize.max/d0) onOutOfMemoryError(); // item count would overflow
        data = DTWGlyph.poolAlloc!(T, poolidx)(d0*d1);
        dim0 = d0;
        dim1 = d1;
        data[0..d0*d1] = initV;
      }

      ~this () {
        DTWGlyph.poolFree!poolidx();
        // just in case
        data = null;
        dim0 = dim1 = 0;
      }

      T opIndex (in usize i0, in usize i1) const {
        //pragma(inline, true);
        if (i0 >= dim0 || i1 >= dim1) assert(0); // the thing that should not be
        return data[i1*dim0+i0];
      }

      T opIndexAssign (in T v, in usize i0, in usize i1) {
        //pragma(inline, true);
        if (i0 >= dim0 || i1 >= dim1) assert(0); // the thing that should not be
        return (data[i1*dim0+i0] = v);
      }
    }

    const(DTWGlyph) a = this;
    if (!finished || b is null || !b.finished) return dtwInfinity;
    immutable m = a.points.length-1;
    immutable n = b.points.length-1;
    // dist[x, y]: the best cost found so far for matching a[0..x] against b[0..y]
    // (the back-pointer tables that used to be filled here were never read, so they are gone)
    auto dist = A2D!(DTWFloat, 0)(m+1, n+1, dtwInfinity);
    dist[0, 0] = 0.0;
    foreach (immutable x; 0..m) {
      foreach (immutable y; 0..n) {
        if (dist[x, y] >= dtwInfinity) continue; // unreachable cell
        DTWFloat tx = a.points[x].t;
        DTWFloat ty = b.points[y].t;
        usize maxX = x, maxY = y;
        usize k = 0; // number of accepted steps from this cell (see `step()`)
        while (k < 4) {
          // try to move from (x, y) to (x2, y2); rejected if the slope is out of range
          void step (usize x2, usize y2) nothrow @nogc {
            DTWFloat dtx = a.points[x2].t-tx;
            DTWFloat dty = b.points[y2].t-ty;
            if (dtx >= dty*2.2 || dty >= dtx*2.2 || dtx < EPS || dty < EPS) return;
            ++k;
            DTWFloat d = 0.0;
            usize i = x, j = y;
            DTWFloat nexttx = (a.points[i+1].t-tx)/dtx;
            DTWFloat nextty = (b.points[j+1].t-ty)/dty;
            DTWFloat curT = 0.0;
            // walk both paths in parallel over the rescaled [0..1] interval,
            // summing squared angle differences weighted by interval length
            for (;;) {
              immutable DTWFloat ad = sqr(angleDiff(a.points[i].alpha, b.points[j].alpha));
              DTWFloat nextt = (nexttx < nextty ? nexttx : nextty);
              bool done = (nextt >= 1.0-EPS);
              if (done) nextt = 1.0;
              d += (nextt-curT)*ad;
              if (done) break;
              curT = nextt;
              if (nexttx < nextty) {
                nexttx = (a.points[++i+1].t-tx)/dtx;
              } else {
                nextty = (b.points[++j+1].t-ty)/dty;
              }
            }
            DTWFloat newDist = dist[x, y]+d*(dtx+dty);
            if (newDist != newDist) assert(0); /*???*/ // NaN cost
            if (newDist >= dist[x2, y2]) return;
            dist[x2, y2] = newDist;
          }

          if (a.points[maxX+1].t-tx > b.points[maxY+1].t-ty) {
            ++maxY;
            if (maxY == n) { step(m, n); break; }
            foreach (usize x2; x+1..maxX+1) step(x2, maxY);
          } else {
            ++maxX;
            if (maxX == m) { step(m, n); break; }
            foreach (usize y2; y+1..maxY+1) step(maxX, y2);
          }
        }
      }
    }
    return dist[m, n]; // cost
  }

  // convert `compare()` cost to a score:
  //   -1.0: no match at all (or one of the glyphs is not finished / is null)
  //   otherwise 1.0-2.5*cost, clamped at 0.0
  DTWFloat score (const(DTWGlyph) pat) const nothrow {
    if (!finished || pat is null || !pat.finished) return -1.0;
    DTWFloat cost = pat.compare(this), score;
    if (cost >= dtwInfinity) return -1.0;
    score = 1.0-2.5*cost;
    if (score <= 0.0) return 0.0;
    return score;
  }

  // `true` if the score against `pat` is at least `MinMatchScore`
  bool match (const(DTWGlyph) pat) const nothrow { return (pat !is null ? pat.score(this) >= MinMatchScore : false); }

  // `true` if `score` is not NaN and is at least `MinMatchScore`
  static bool isGoodScore (DTWFloat score) {
    pragma(inline, true);
    import std.math : isNaN;
    return (!score.isNaN ? score >= MinMatchScore : false);
  }

private:
  // `alpha` of the given point, or 0 for invalid index
  DTWFloat angle (usize idx) const pure nothrow @safe @nogc { return (idx < points.length ? points[idx].alpha : 0.0); }

  // absolute angle difference between our point `idx0` and point `idx1` of `b`; NaN on invalid arguments
  DTWFloat angleDiff (const(DTWGlyph) b, usize idx0, usize idx1) const pure nothrow @safe @nogc {
    import std.math : abs;
    return (b !is null && idx0 < points.length && idx1 < b.points.length ? abs(angleDiff(angle(idx0), b.angle(idx1))) : DTWFloat.nan);
  }

  static DTWFloat sqr (in DTWFloat x) pure nothrow @safe @nogc { pragma(inline, true); return x*x; }

  // signed difference of two angles (in units of PI), wrapped once by 2.0
  static DTWFloat angleDiff (in DTWFloat alpha, in DTWFloat beta) pure nothrow @safe @nogc {
    // return 1.0-cos((alpha-beta)*PI);
    pragma(inline, true);
    DTWFloat d = alpha-beta;
    if (d < cast(DTWFloat)-1.0) d += cast(DTWFloat)2.0; else if (d > cast(DTWFloat)1.0) d -= cast(DTWFloat)2.0;
    return d;
  }

public:
  // find the best matching glyph in `list` (null and invalid entries are skipped)
  // returns the list item with the highest score that is at least `MinMatchScore`, or `null`
  // `*outscore` (if given) is set to that score, or to NaN when nothing was found
  const(DTWGlyph) findMatch (const(DTWGlyph)[] list, DTWFloat* outscore=null) const {
    DTWFloat bestScore = cast(DTWFloat)-1.0;
    DTWGlyph res = null;
    if (outscore !is null) *outscore = DTWFloat.nan;
    if (valid) {
      auto me = getFinished;
      scope(exit) delete me;
      foreach (const DTWGlyph gs; list) {
        if (gs is null || !gs.valid) continue;
        DTWFloat score;
        if (gs.finished) {
          // already finished pattern can be used as is, no need to clone it
          score = gs.score(me);
        } else {
          auto g = gs.getFinished;
          scope(exit) delete g;
          score = g.score(me);
        }
        if (score >= MinMatchScore && score > bestScore) {
          bestScore = score;
          res = cast(DTWGlyph)gs; // sorry
        }
      }
    }
    if (res !is null && outscore !is null) *outscore = bestScore;
    return res;
  }
}
398 // ////////////////////////////////////////////////////////////////////////// //
// Load glyph library from `fl`.
// Layout, as read below: uint signature ("K8SL"), ubyte library version (must be 0),
// uint glyph count, then for each glyph: uint name length (at most 1024),
// name bytes, glyph record (see `readGlyph()`).
// Throws `Exception` on any invalid header field.
public DTWGlyph[] gstLibLoad (VFile fl) {
  // Read one glyph record. Two leading uints: if the second one has bit 0x80
  // set, the first one is the point format version (1 or 2), bit 0x01 of the
  // second is the "finished" flag, and the point count follows as one more uint.
  // Otherwise it is format 0: the first uint is the point count, the second is
  // the "finished" flag.
  DTWGlyph readGlyph () {
    auto glyph = new DTWGlyph();
    uint fmt = 0;
    uint npoints = fl.readNum!uint;
    uint flags = fl.readNum!uint;
    if (flags&0x80) {
      // what we have read as point count is actually a format version
      if (npoints != 1 && npoints != 2) throw new Exception("invalid glyph version");
      fmt = npoints;
      flags &= 0x01;
      npoints = fl.readNum!uint;
    }
    if (npoints > 0x7fff_ffff) throw new Exception("invalid glyph point count");
    glyph.points.length = npoints;
    glyph.mFinished = (glyph.points.length > 1 && flags);
    foreach (ref p; glyph.points) {
      if (fmt == 0) {
        // format 0: five doubles per point
        p.x = fl.readNum!double;
        p.y = fl.readNum!double;
        p.t = fl.readNum!double;
        p.dt = fl.readNum!double;
        p.alpha = fl.readNum!double;
        continue;
      }
      // formats 1 and 2: floats; `t`, `dt` and `alpha` are present
      // for format 1, and for finished glyphs
      p.x = fl.readNum!float;
      p.y = fl.readNum!float;
      if (fmt != 1 && !flags) continue;
      p.t = fl.readNum!float;
      p.dt = fl.readNum!float;
      p.alpha = fl.readNum!float;
    }
    return glyph;
  }

  DTWGlyph[] glyphs;
  immutable uint magic = fl.readNum!uint;
  if (magic != 0x4C53384Bu) throw new Exception("invalid glyph library signature"); // "K8SL"
  immutable ubyte libver = fl.readNum!ubyte;
  if (libver != 0) throw new Exception("invalid glyph library version");
  immutable uint total = fl.readNum!uint;
  if (total > 0x7fff_ffff) throw new Exception("too many glyphs in library");
  foreach (immutable gidx; 0..total) {
    // name
    immutable uint namelen = fl.readNum!uint();
    if (namelen > 1024) throw new Exception("glyph name too long");
    string gname;
    if (namelen > 0) {
      auto chars = new char[](namelen);
      fl.rawReadExact(chars);
      gname = cast(string)chars; // it is safe to cast here
    }
    // glyph
    auto glyph = readGlyph();
    glyph.mName = gname;
    glyphs ~= glyph;
  }
  return glyphs;
}
464 // ////////////////////////////////////////////////////////////////////////// //
// Save glyph library to `fl` (in the format `gstLibLoad()` reads).
// Throws `Exception` if the list is too long, or if any glyph is null, invalid,
// has a name longer than 1024 bytes, or has too many points.
// All checks are done before the first byte is written, so a bad glyph in the
// middle of the list cannot leave a truncated library in the file.
public void gstLibSave (VFile fl, const(DTWGlyph)[] list) {
  // validation pass
  if (list.length > uint.max/16) throw new Exception("too many glyphs");
  foreach (const DTWGlyph g; list) {
    if (g is null || !g.valid) throw new Exception("can't save invalid glyph");
    if (g.name.length > 1024) throw new Exception("glyph name too long");
    if (g.points.length > uint.max/64) throw new Exception("too many points in glyph");
  }
  // header
  fl.rawWriteExact("K8SL");
  fl.writeNum!ubyte(0); // version
  fl.writeNum!uint(cast(uint)list.length);
  foreach (const DTWGlyph g; list) {
    // name
    fl.writeNum!uint(cast(uint)g.name.length);
    fl.rawWriteExact(g.name);
    // points
    ubyte ptver;
    static if (DTWFloat.sizeof == float.sizeof) {
      // v1 (finished, all point fields) or v2 (not finished, coordinates only)
      immutable bool fin = g.finished;
      ptver = (fin ? 1 : 2);
      fl.writeNum!uint(ptver);
      fl.writeNum!uint(fin ? 0x81 : 0x80); // 0x80: "versioned record" mark; 0x01: "finished" flag
      fl.writeNum!uint(cast(uint)g.points.length);
    } else {
      // v0
      ptver = 0;
      fl.writeNum!uint(cast(uint)g.points.length);
      fl.writeNum!uint(g.finished ? 1 : 0);
    }
    foreach (immutable pt; g.points) {
      if (ptver == 0) {
        fl.writeNum!double(pt.x);
        fl.writeNum!double(pt.y);
        fl.writeNum!double(pt.t);
        fl.writeNum!double(pt.dt);
        fl.writeNum!double(pt.alpha);
      } else {
        fl.writeNum!float(pt.x);
        fl.writeNum!float(pt.y);
        if (ptver == 1) {
          fl.writeNum!float(pt.t);
          fl.writeNum!float(pt.dt);
          fl.writeNum!float(pt.alpha);
        }
      }
    }
  }
}