changed copyright years in source files
[fegdk.git] / core / code / system / f_mathexpression.cpp
blob0d42709c1899ab9ad17f61febc82b594e3c9bde1
1 /*
2 fegdk: FE Game Development Kit
3 Copyright (C) 2001-2008 Alexey "waker" Yakovenko
5 This library is free software; you can redistribute it and/or
6 modify it under the terms of the GNU Library General Public
7 License as published by the Free Software Foundation; either
8 version 2 of the License, or (at your option) any later version.
10 This library is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13 Library General Public License for more details.
15 You should have received a copy of the GNU Library General Public
16 License along with this library; if not, write to the Free
17 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 Alexey Yakovenko
20 waker@users.sourceforge.net
23 #include "pch.h"
24 #include "f_string.h"
25 #include "f_mathexpression.h"
26 #include "f_error.h"
27 #include "f_engine.h"
28 #include "f_parmmgr.h"
29 #include "f_helpers.h"
30 #include <assert.h>
31 #ifdef _DEBUG
32 #include <stdio.h>
33 #endif
34 #include "f_console.h"
35 #include "cvars.h"
37 namespace fe
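40 // NOTE: 1D table lookups ("name[index]") are implemented as op_lookup in buildtree/evaluate, interpolating linearly between neighbouring table entries; the notes below are the original design sketch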
42 table lookup details:
44 new op:
46 op_table1d dst, src0, src1
47 dst -- destination register
48 src0 -- pointer to a table of floats
49 src1.x -- index into table
50 src1.yzw -- undefined
52 tree semantics:
54 op_table1d
55 / \
56 src0 src1
58 script syntax example:
60 sintable[time*0.1]
63 // TODO: intrinsic sin/cos (purposes???)
64 // TODO: 2d/3d/4d lookups (purposes???)
67 op semantics:
69 add dst, src0, src1
70 sub dst, src0, src1
71 mul dst, src0, src1
72 div dst, src0, src1
73 dot dst, src0, src1
74 mov dst, src
78 bytecode format
79 char nconsts; // number of constants
81 float constants[nconsts]; // array of constants
83 char ncmds; // number of commands
85 struct {
86 ((uchar)operator_t) op;
87 uchar destreg;
88 ushort parm1;
89 ushort parm2;
90 } commands[ncmds]; // array of commands
92 parm format:
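93 bits 0..13 - index of parameter (can be constant, register or variable); for variables, bit 13 set means a cvar index and clear means a parmManager index (see getparamidx)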
94 bit 14..15 - 00 - constant; 01 - register; 10 - variable; 11 - reserved
98 app-defined variables are accessed through parmManager
99 the way it works should be as follows:
101 float *fval = engine->varManager ()->getFloat4ByName ("name");
105 ushort index = engine->varManager ()->getParmIndexByName ("name");
106 float *fval = engine->varManager ()->getFloat4ByIndex (index);
108 the parm manager should (?) provide type-casting accessors, e.g.:
110 float* getFloatByIndex (ushort index);
111 float4* getFloat4ByIndex (ushort index);
112 float3* getFloat3ByIndex (ushort index);
113 ... etc ...
118 registers are local to each evaluate () call
119 registers are not guaranteed to keep meaningful values between programs
122 void mathExpression::getwords (const cStr &in, std::vector< cStr > &expression)
124 expression.resize (0);
126 const char *s = in.c_str ();
127 for (;;)
129 if (*s == 0)
130 break;
131 else if (*s == '*'
132 || *s == '/'
133 || *s == '+'
134 || *s == '-'
135 || *s == '|'
136 || *s == '['
137 || *s == ']'
138 || *s == '('
139 || *s == ')')
141 cStr ss;
142 ss += *s;
143 expression.push_back (ss);
144 s++;
146 else if (*s <= 32)
147 s++;
148 else
150 const char *e = s;
151 while (*e && *e > 32
152 && *e != '*'
153 && *e != '/'
154 && *e != '+'
155 && *e != '-'
156 && *e != '|'
157 && *e != '['
158 && *e != ']'
159 && *e != '('
160 && *e != ')')
162 e++;
164 if (e == s)
165 break;
166 cStr ss;
167 for (; s != e; s++)
168 ss += *s;
170 expression.push_back (ss);
175 bool mathExpression::validate (const std::vector<cStr> &expr)
177 // +001: no imparity of brackets, e.g. "(()"
178 // +002: no double-operator cases, e.g. "a+*b"
179 // +003: no double-value cases, e.g. "a b"
180 // 004: no operator-only branches, e.g. "-"
181 // +005: no empty branches, e.g. "()"
182 // 006: no branches starting with operators other than "+" or "-"
184 int cnt = 0;
185 int bcnt = 0;
186 int scnt = 0;
187 int sbcnt = 0;
188 size_t i;
190 for (i = 0; i < expr.size (); i++)
192 // 001
193 if (expr[i] == "(")
195 cnt++;
196 bcnt++;
198 else if (expr[i] == ")")
200 cnt--;
201 bcnt++;
204 // 002
205 else if (expr[i] == "*"
206 || expr[i] == "/"
207 || expr[i] == "+"
208 || expr[i] == "-"
209 || expr[i] == "|")
211 if (i < expr.size () - 1)
213 if (expr[i+1] == "*"
214 || expr[i+1] == "/"
215 || expr[i+1] == "+"
216 || expr[i+1] == "-"
217 || expr[i+1] == "|")
219 return false;
224 else if (expr[i] == "[")
226 scnt++;
227 sbcnt++;
230 else if (expr[i] == "]")
232 scnt--;
233 sbcnt++;
236 // 003
237 else // normal value
239 if (i < expr.size () - 1)
241 if (expr[i+1] != "*"
242 && expr[i+1] != "/"
243 && expr[i+1] != "+"
244 && expr[i+1] != "-"
245 && expr[i+1] != "|"
246 && expr[i+1] != "["
247 && expr[i+1] != "]"
248 && expr[i+1] != "("
249 && expr[i+1] != ")")
251 return false;
256 if (cnt) // 001
257 return false;
259 if (scnt)
260 return false;
262 // 004
263 if (expr.size () == 1
264 && (expr[0] == "*"
265 || expr[0] == "/"
266 || expr[0] == "+"
267 || expr[0] == "-"
268 || expr[0] == "|"))
270 return false;
273 // 005
274 if (bcnt == (int)expr.size ())
275 return false;
277 if (sbcnt == (int)expr.size ())
278 return false;
280 return true;
283 cStr mathExpression::opstring (int op)
285 cStr s[] = { "", "+", "-", "*", "/" };
286 return s[op];
289 mathExpression::operator_t mathExpression::getoperator (const std::vector<cStr> &expression, size_t curword)
291 cStr s = expression[curword];
292 if (s == "*")
293 return op_mul;
294 if (s == "/")
295 return op_div;
296 if (s == "+")
297 return op_add;
298 if (s == "-")
299 return op_sub;
300 if (s == "|")
301 return op_dot;
302 return op_no;
305 bool mathExpression::iscomplex (const cStr &s)
307 if (s == "(")
308 return true;
309 return false;
312 mathExpression::node* mathExpression::buildtree (std::vector<cStr> &expr)
314 // handle branch
315 while (expr[0] == "(" && expr.back () == ")")
317 // handle patalogical cases like "((3.2)+ (1))"
318 size_t i;
319 int cnt = 1;
320 for (i = 1; i < expr.size (); i++)
322 if (expr[i] == "(")
323 cnt++;
324 else if (expr[i] == ")")
325 cnt--;
326 if (cnt == 0 && i != expr.size () - 1)
327 break;
329 // remove brackets
330 if (i == expr.size ())
332 expr.erase (expr.begin ());
333 expr.pop_back ();
334 continue;
336 break;
339 // handle table lookups
340 // this check ensures that we don't have a case like "a[x] + b[y]"
341 if (expr.size () > 1
342 && getoperator (expr, 0) == op_no
343 && expr[1] == "["
344 && expr.back () == "]")
346 bool lookup = true;
347 for (size_t i = 2; i < expr.size (); i++)
349 if (expr[i] == "]"
350 && i != expr.size () - 1)
352 lookup = false;
353 break;
357 if (lookup)
359 std::vector<cStr> l;
360 std::vector<cStr> r;
361 l.push_back (expr[0]);
362 for (size_t i = 2; i < expr.size () - 1; i++)
364 r.push_back (expr[i]);
366 node *n = new node;
367 n->op = op_lookup;
368 n->left = buildtree (l);
369 n->right = buildtree (r);
370 return n;
374 // handle expressions started w/ "+" or "-" signs
375 if (expr[0] == "+")
376 expr.erase (expr.begin ());
377 else if (expr[0] == "-")
379 expr.erase (expr.begin ());
380 expr[0] = "-" + expr[0];
383 // handle single value
384 if (expr.size () == 1)
386 node *n = new node;
387 n->leaf = true;
388 // n->value = (float)atof (expr[0]);
389 n->strvalue = expr[0];
390 return n;
393 // find lowest-priority op
394 // FIXME: should select op nearest to a center of expression. dunno if it matters..
395 size_t index = (unsigned)-1;
396 operator_t minop = op_max;
397 size_t i;
398 for (i = 0; i < expr.size (); i++)
400 if (expr[i] == "(")
402 // skip to next ")"
403 int cnt = 1;
404 size_t k;
405 for (k = i+1; k < expr.size (); k++)
407 if (expr[k] == ")")
409 cnt--;
410 if (!cnt)
411 break;
413 else if (expr[k] == "(")
414 cnt++;
416 i = k;
417 continue;
420 if (expr[i] == "[")
422 // skip to next "]"
423 int cnt = 1;
424 size_t k;
425 for (k = i+1; k < expr.size (); k++)
427 if (expr[k] == "]")
429 cnt--;
430 if (!cnt)
431 break;
433 else if (expr[k] == "[")
434 cnt++;
436 i = k;
437 continue;
440 operator_t op = getoperator (expr, i);
441 if (op == op_no)
442 continue;
444 if (op < minop)
446 minop = op;
447 index = i;
451 // dbg
452 #ifdef _DEBUG
453 cStr l;
454 cStr r;
455 #endif
457 // split
458 std::vector<cStr> leftexpr, rightexpr;
459 for (i = 0; i < expr.size (); i++)
461 if (i < index)
463 leftexpr.push_back (expr[i]);
464 #ifdef _DEBUG
465 l += expr[i];
466 #endif
468 else if (i > index)
470 rightexpr.push_back (expr[i]);
471 #ifdef _DEBUG
472 r += expr[i];
473 #endif
477 node *n = new node;
478 n->op = minop;
479 if (!leftexpr.empty ())
480 n->left = buildtree (leftexpr);
481 if (!rightexpr.empty ())
482 n->right = buildtree (rightexpr);
484 return n;
487 cStr mathExpression::strop (operator_t op) const
489 cStr s[] = { "", "add", "sub", "mul", "div", "dot" };
490 return s[op];
493 bool mathExpression::isconst (const char *s)
495 if (*s == '-' || *s == '+')
496 s++;
497 while (*s)
499 if (!isdigit (*s) && *s != '.')
500 return false;
501 s++;
504 return true;
507 int mathExpression::getparamidx (const char *n) const
509 int parmIdx = g_engine->getParmMgr ()->getParmIndexByName (n);
510 if (parmIdx >= (1<<13))
511 Con_Printf ("param index is too big: %s(%d)\n", n, parmIdx);
512 if (parmIdx == -1)
514 parmIdx = g_engine->getCVarManager ()->getCVarIdx (n);
515 if (parmIdx == -1)
516 Con_Printf ("cvar not found: %s\n", n);
517 if (parmIdx >= (1<<13))
518 Con_Printf ("param index is too big: %s(%d)\n", n, parmIdx);
519 parmIdx |= (1<<13);
521 return parmIdx | varmask;
524 float* mathExpression::getvarvalueptr (int v) const
526 int i = v &~ parmtypemask;
527 if (i & (1<<13))
528 return &g_engine->getCVarManager ()->cvarForIdx (i &~ (1<<13))->fvalue;
529 return g_engine->getParmMgr ()->getFloat (i);
532 void mathExpression::traverse (node *n, std::vector< mathExpression::command_t > &commands, std::vector< float > &constants, int reg)
534 assert (reg < maxregisters);
535 if (n->root && n->leaf)
537 command_t c;
539 c.opcode = op_mov;
541 c.dst = reg;
543 if (isconst (n->strvalue))
545 c.src0 = (ushort)constants.size ();
546 c.src0 |= constmask;
547 constants.push_back ((float)atof (n->strvalue));
549 else // var
551 c.src0 = getparamidx (n->strvalue);
554 // printf ("mov r%d, %f\n", reg, n->value);
556 commands.push_back (c);
558 return;
560 else
562 if (!n->left->leaf)
563 traverse (n->left, commands, constants, reg);
564 if (!n->right->leaf)
565 traverse (n->right, commands, constants, n->left->leaf ? reg : reg+1);
566 assert (!n->leaf); // do nothing to leaf nodes, they already have correct values
567 // n->strvalue.printf ("r%d", reg);
568 n->reg = reg;
570 command_t c;
571 c.opcode = n->op;
572 c.dst = reg;
574 if (!n->left->leaf) // reg
576 c.src0 = n->left->reg | regmask;
578 else if (isconst (n->left->strvalue))
580 c.src0 = (ushort)constants.size ();
581 c.src0 |= constmask;
582 constants.push_back ((float)atof (n->left->strvalue));
584 else // var
586 c.src0 = getparamidx (n->left->strvalue);
589 if (!n->right->leaf) // reg
591 c.src1 = n->right->reg | regmask;
593 else if (isconst (n->right->strvalue))
595 c.src1 = (ushort)constants.size ();
596 c.src1 |= constmask;
597 constants.push_back ((float)atof (n->right->strvalue));
599 else // var
601 c.src1 = getparamidx (n->right->strvalue);
604 commands.push_back (c);
606 // printf ("%s r%d, %s, %s\n", strop (n->op).c_str (), reg, n->left->strvalue.c_str (), n->right->strvalue.c_str ());
#ifdef _DEBUG
// Debug helper: dumps a compiled program to stdout, one instruction per line,
// e.g. "mul r0, 2.000000, r1". Constant operands are printed as their value,
// variable operands as their current value, and register operands as "rN".
//
// Variables are read through getvarvalueptr, the same decoding evaluate uses,
// so cvar-backed variables print correctly. The old mov case passed a float*
// to printf("%f"), which is undefined behaviour, and op_lookup was not
// printed at all.
void mathExpression::printProgram (const std::vector< mathExpression::command_t > &commands, const std::vector< float > &constants) const
{
	for (size_t i = 0; i < commands.size (); i++)
	{
		const command_t &c = commands[i];
		int nsrc;
		switch (c.opcode)
		{
		case op_mov:
			printf ("mov r%d, ", c.dst);
			nsrc = 1;
			break;
		case op_lookup:
			printf ("lookup r%d, ", c.dst);
			nsrc = 2;
			break;
		case op_mul:
		case op_div:
		case op_add:
		case op_sub:
		case op_dot:
			printf ("%s r%d, ", strop ((operator_t)c.opcode).c_str (), c.dst);
			nsrc = 2;
			break;
		default:
			printf ("<unknown opcode %d>\n", (int)c.opcode);
			continue;
		}

		const int src[2] = { c.src0, nsrc > 1 ? c.src1 : 0 };
		for (int k = 0; k < nsrc; k++)
		{
			const char *sep = (k + 1 < nsrc) ? ", " : "\n";
			const int type = src[k] & parmtypemask;
			const int idx = src[k] &~ parmtypemask;
			if (type == constmask)
				printf ("%f%s", constants[idx], sep);
			else if (type == varmask)
				printf ("%f%s", *getvarvalueptr (src[k]), sep);
			else if (type == regmask)
				printf ("r%d%s", idx, sep);
			else
				printf ("?%s", sep);
		}
	}
}
#endif
653 void mathExpression::deltree (node *n)
655 if (n->left)
656 deltree (n->left);
657 if (n->right)
658 deltree (n->right);
659 delete n;
662 mathExpression::mathExpression (void)
666 mathExpression::mathExpression (const char *expr)
668 mName = expr;
669 std::vector< cStr > expression;
670 getwords (expr, expression);
671 bool valid = validate (expression);
672 if (!valid)
673 sys_error ("error in expression \"%s\".\n", expr);
675 node *n = buildtree (expression);
676 assert (n);
677 if (!n)
678 sys_error ("failed to build binary tree for expression \"%s\".\n", expr);
679 n->root = true;
681 traverse (n, mCommands, mConstants);
682 deltree (n);
684 #ifdef _DEBUG
685 printProgram (mCommands, mConstants);
686 #endif
689 mathExpression::~mathExpression (void)
693 float mathExpression::evaluate (void)
695 float registers[maxregisters];
696 float *p1, *p2;
698 for (size_t i = 0; i < mCommands.size (); i++)
700 command_t c = mCommands[i];
702 switch (c.opcode)
704 case op_mov:
705 if (constmask == (c.src0 & parmtypemask))
706 registers[c.dst] = mConstants[c.src0 &~ parmtypemask];
707 else
708 registers[c.dst] = *g_engine->getParmMgr ()->getFloat (c.src0 &~ parmtypemask);
709 break;
710 default:
711 if (constmask == (c.src0 & parmtypemask))
712 p1 = &mConstants[c.src0 &~ parmtypemask];
713 else if (varmask == (c.src0 & parmtypemask))
714 p1 = getvarvalueptr (c.src0);
715 else if (regmask == (c.src0 & parmtypemask))
716 p1 = &registers[c.src0 &~ parmtypemask];
717 else
718 sys_error ("[mathExpression::evaluate] something wrong with mathexpression: '%s'", mName.c_str ());
720 if (constmask == (c.src1 & parmtypemask))
721 p2 = &mConstants[c.src1 &~ parmtypemask];
722 else if (varmask == (c.src1 & parmtypemask))
723 p2 = getvarvalueptr (c.src1);
724 else if (regmask == (c.src1 & parmtypemask))
725 p2 = &registers[c.src1 &~ parmtypemask];
726 else
727 sys_error ("[mathExpression::evaluate] something wrong with mathexpression: '%s'", mName.c_str ());
729 switch (c.opcode)
731 case op_lookup:
733 int low = (int)floor (*p2);
734 int hi = (int)ceil (*p2);
735 float frac = *p2 - low;
737 int sz = g_engine->getParmMgr ()->getParmSize (c.src0 &~ parmtypemask);
739 low = low % sz;
740 hi = hi % sz;
742 float l = p1[low];
743 float r = p1[hi];
744 registers[c.dst] = l * (1-frac) + r * (frac);
746 break;
747 case op_mul:
748 registers[c.dst] = (*p1) * (*p2);
749 break;
750 case op_div:
751 registers[c.dst] = (*p1) / (*p2);
752 break;
753 case op_add:
754 registers[c.dst] = (*p1) + (*p2);
755 break;
756 case op_sub:
757 registers[c.dst] = (*p1) - (*p2);
758 break;
759 case op_dot:
761 // float d;
762 // d = p1->x * p2->x + p1->y * p2->y + p1->z * p2->z + p1->w * p2->w;
763 // registers[c.dst].x = registers[c.dst].y = registers[c.dst].z = registers[c.dst].w = d;
764 registers[c.dst] = 1;
766 break;
768 break;
771 return registers[0];