Lua: Don't lua_error() out of context with pending dtors
[lsnes.git] / src / library / mathexpr.cpp
blobe1ebbda78c9ca6ca8c419a16313243b53a585d7e
#include "mathexpr.hpp"
#include "string.hpp"
#include <map>
#include <set>
#include <utility>
6 namespace mathexpr
8 operinfo::operinfo(std::string funcname)
9 : fnname(funcname), is_operator(false), operands(0), precedence(0), rtl(false)
12 operinfo::operinfo(std::string opername, unsigned _operands, int _percedence, bool _rtl)
13 : fnname(opername), is_operator(true), operands(_operands), precedence(_percedence), rtl(_rtl)
16 operinfo::~operinfo()
20 typeinfo::~typeinfo()
24 mathexpr::mathexpr(typeinfo* _type)
25 : type(*_type)
27 owns_operator = false;
28 state = UNDEFINED;
29 _value = NULL;
30 fn = (operinfo*)0xDEADBEEF;
33 mathexpr::mathexpr(typeinfo* _type, GC::pointer<mathexpr> fwd)
34 : type(*_type)
36 owns_operator = false;
37 state = FORWARD;
38 _value = type.allocate();
39 arguments.push_back(&*fwd);
40 fn = NULL;
43 mathexpr::mathexpr(value _val)
44 : type(*_val.type)
46 owns_operator = false;
47 state = FIXED;
48 _value = type.copy_allocate(_val._value);
49 fn = NULL;
52 mathexpr::mathexpr(typeinfo* _type, const std::string& _val, bool string)
53 : type(*_type)
55 owns_operator = false;
56 state = FIXED;
57 _value = type.parse(_val, string);
58 fn = NULL;
61 mathexpr::mathexpr(typeinfo* _type, operinfo* _fn, std::vector<GC::pointer<mathexpr>> _args, bool _owns_operator)
62 : type(*_type), fn(_fn), owns_operator(_owns_operator)
64 try {
65 for(auto& i : _args)
66 arguments.push_back(&*i);
67 _value = type.allocate();
68 state = TO_BE_EVALUATED;
69 } catch(...) {
70 if(owns_operator)
71 delete fn;
72 throw;
76 mathexpr::~mathexpr()
78 if(owns_operator && fn)
79 delete fn;
80 type.deallocate(_value);
83 void mathexpr::reset()
85 if(state == TO_BE_EVALUATED || state == FIXED || state == UNDEFINED || state == FORWARD)
86 return;
87 if(state == FORWARD_EVALD || state == FORWARD_EVALING) {
88 state = FORWARD;
89 return;
91 state = TO_BE_EVALUATED;
92 for(auto i : arguments)
93 i->reset();
96 mathexpr::mathexpr(const mathexpr& m)
97 : state(m.state), type(m.type), fn(m.fn), _error(m._error), arguments(m.arguments)
99 _value = m._value ? type.copy_allocate(m._value) : NULL;
100 if(state == EVALUATING) state = TO_BE_EVALUATED;
103 mathexpr& mathexpr::operator=(const mathexpr& m)
105 if(this == &m)
106 return *this;
107 std::string _xerror = m._error;
108 std::vector<mathexpr*> _arguments = m.arguments;
109 if(m._value) {
110 if(!_value)
111 _value = m.type.copy_allocate(m._value);
112 else
113 m.type.copy(_value, m._value);
114 } else if(_value) {
115 m.type.deallocate(_value);
116 _value = NULL;
117 } else
118 _value = NULL;
119 type = m.type;
120 fn = m.fn;
121 state = m.state;
122 owns_operator = m.owns_operator;
123 m.owns_operator = false;
124 std::swap(arguments, _arguments);
125 std::swap(_error, _xerror);
126 return *this;
129 value mathexpr::evaluate()
131 value ret;
132 ret.type = &type;
133 switch(state) {
134 case TO_BE_EVALUATED:
135 //Need to evaluate.
136 try {
137 for(auto i : arguments) {
138 if(&i->type != &type) {
139 throw error(error::TYPE_MISMATCH,
140 "Types for function mismatch");
143 state = EVALUATING;
144 std::vector<std::function<value()>> promises;
145 for(auto i : arguments) {
146 mathexpr* m = i;
147 promises.push_back([m]() { return m->evaluate(); });
149 value tmp;
150 tmp.type = &type;
151 tmp._value = _value;
152 fn->evaluate(tmp, promises);
153 state = EVALUATED;
154 } catch(error& e) {
155 state = FAILED;
156 errcode = e.get_code();
157 _error = e.what();
158 throw;
159 } catch(std::exception& e) {
160 state = FAILED;
161 errcode = error::UNKNOWN;
162 _error = e.what();
163 throw;
164 } catch(...) {
165 state = FAILED;
166 errcode = error::UNKNOWN;
167 _error = "Unknown error";
168 throw;
170 ret._value = _value;
171 return ret;
172 case EVALUATING:
173 case FORWARD_EVALING:
174 //Circular dependency.
175 mark_error_and_throw(error::CIRCULAR, "Circular dependency");
176 case EVALUATED:
177 case FIXED:
178 case FORWARD_EVALD:
179 ret._value = _value;
180 return ret;
181 case UNDEFINED:
182 throw error(error::UNDEFINED, "Undefined variable");
183 case FAILED:
184 throw error(errcode, _error);
185 case FORWARD:
186 try {
187 state = FORWARD_EVALING;
188 value v = arguments[0]->evaluate();
189 type.copy(_value, v._value);
190 state = FORWARD_EVALD;
191 return v;
192 } catch(...) {
193 state = FORWARD;
194 throw;
197 throw error(error::INTERNAL, "Internal error (shouldn't be here)");
200 void mathexpr::trace()
202 for(auto i : arguments)
203 i->mark();
206 void mathexpr::mark_error_and_throw(error::errorcode _errcode, const std::string& _xerror)
208 if(state == EVALUATING) {
209 state = FAILED;
210 errcode = _errcode;
211 _error = _xerror;
213 if(state == FORWARD_EVALING) {
214 state = FORWARD;
216 throw error(_errcode, _error);
219 namespace
222 X_EXPR -> VALUE
223 X_EXPR -> STRING
224 X_EXPR -> NONARY-OP
225 X_EXPR -> FUNCTION X_ARGS
226 X_EXPR -> UNARY-OP X_EXPR
227 X_EXPR -> X_LAMBDA X_EXPR
228 X_LAMBDA -> X_EXPR BINARY-OP
229 X_ARGS -> OPEN-PAREN CLOSE_PAREN
230 X_ARGS -> OPEN-PAREN X_TAIL
231 X_TAIL -> X_PAIR X_TAIL
232 X_TAIL -> X_EXPR CLOSE-PAREN
233 X_PAIR -> X_EXPR COMMA
236 //SUBEXPRESSION -> VALUE
237 //SUBEXPRESSION -> STRING
238 //SUBEXPRESSION -> FUNCTION OPEN-PAREN CLOSE-PAREN
239 //SUBEXPRESSION -> FUNCTION OPEN-PAREN (SUBEXPRESSION COMMA)* SUBEXPRESSION CLOSE-PAREN
240 //SUBEXPRESSION -> OPEN-PAREN SUBEXPRESSION CLOSE-PAREN
241 //SUBEXPRESSION -> SUBEXPRESSION BINARY-OP SUBEXPRESSION
242 //SUBEXPRESSION -> UNARY-OP SUBEXPRESSION
243 //SUBEXPRESSION -> NONARY-OP
//True for characters that may appear in a bare value token or a $variable name: ASCII letters and digits,
//'_' and '.' (so "1.5" lexes as a single token).
bool is_alphanumeric(char ch)
{
	const bool digit = ch >= '0' && ch <= '9';
	const bool lower = ch >= 'a' && ch <= 'z';
	const bool upper = ch >= 'A' && ch <= 'Z';
	return digit || lower || upper || ch == '_' || ch == '.';
}
//Kinds of tokens produced by tokenize().
enum token_kind
{
	TT_OPEN_PAREN = 0,	//'('
	TT_CLOSE_PAREN = 1,	//')'
	TT_COMMA = 2,		//',' separating function arguments
	TT_FUNCTION = 3,	//function name directly followed by '('
	TT_OPERATOR = 4,	//operator symbol (arity resolved later)
	TT_VARIABLE = 5,	//$name or ${name}
	TT_VALUE = 6,		//bare run of alphanumeric characters
	TT_STRING = 7,		//double-quoted string literal, escapes removed
};
267 struct operations_set
269 operations_set(std::set<operinfo*>& ops)
270 : operations(ops)
273 operinfo* find_function(const std::string& name)
275 operinfo* fn = NULL;
276 for(auto j : operations) {
277 if(name == j->fnname && !j->is_operator)
278 fn = j;
280 if(!fn) throw std::runtime_error("No such function '" + name + "'");
281 return fn;
283 operinfo* find_operator(const std::string& name, unsigned arity)
285 for(auto j : operations) {
286 if(name == j->fnname && j->is_operator && j->operands == arity)
287 return j;
289 return NULL;
291 private:
292 std::set<operinfo*>& operations;
295 struct subexpression
297 subexpression(token_kind k) : kind(k) {}
298 subexpression(token_kind k, const std::string& str) : kind(k), string(str) {}
299 token_kind kind;
300 std::string string;
303 size_t find_last_in_sub(std::vector<subexpression>& ss, size_t first)
305 size_t depth;
306 switch(ss[first].kind) {
307 case TT_FUNCTION:
308 if(first + 1 == ss.size() || ss[first + 1].kind != TT_OPEN_PAREN)
309 throw std::runtime_error("Function requires argument list");
310 first++;
311 case TT_OPEN_PAREN:
312 depth = 0;
313 while(first < ss.size()) {
314 if(ss[first].kind == TT_OPEN_PAREN)
315 depth++;
316 if(ss[first].kind == TT_CLOSE_PAREN)
317 if(!--depth) break;
318 first++;
320 if(first == ss.size())
321 throw std::runtime_error("Unmatched '('");
322 return first;
323 case TT_CLOSE_PAREN:
324 throw std::runtime_error("Unmatched ')'");
325 case TT_COMMA:
326 throw std::runtime_error("',' only allowed in function arguments");
327 case TT_VALUE:
328 case TT_STRING:
329 case TT_OPERATOR:
330 case TT_VARIABLE:
331 return first;
333 throw std::runtime_error("Internal error (shouldn't be here)");
336 size_t find_end_of_arg(std::vector<subexpression>& ss, size_t first)
338 size_t depth = 0;
339 while(first < ss.size()) {
340 if(depth == 0 && ss[first].kind == TT_COMMA)
341 return first;
342 if(ss[first].kind == TT_OPEN_PAREN)
343 depth++;
344 if(ss[first].kind == TT_CLOSE_PAREN) {
345 if(depth == 0)
346 return first;
347 depth--;
349 first++;
351 return ss.size();
354 struct expr_or_op
356 expr_or_op(GC::pointer<mathexpr> e) : expr(e), typei(NULL) {}
357 expr_or_op(std::string o) : op(o), typei(NULL) {}
358 GC::pointer<mathexpr> expr;
359 std::string op;
360 operinfo* typei;
363 GC::pointer<mathexpr> parse_rec(typeinfo& _type, std::vector<expr_or_op>& operands,
364 size_t first, size_t last)
366 if(operands.empty())
367 return GC::pointer<mathexpr>(GC::obj_tag(), &_type);
368 if(last - first > 1) {
369 //Find the highest percedence operator.
370 size_t best = last;
371 for(size_t i = first; i < last; i++) {
372 if(operands[i].typei) {
373 if(best == last)
374 best = i;
375 else if(operands[i].typei->precedence < operands[best].typei->precedence) {
376 best = i;
377 } else if(!operands[best].typei->rtl &&
378 operands[i].typei->precedence == operands[best].typei->precedence) {
379 best = i;
383 if(best == last) throw std::runtime_error("Internal error: No operands?");
384 if(operands[best].typei->operands == 1) {
385 //The operator is unary, collect up all following unary operators.
386 size_t j = first;
387 while(operands[j].typei)
388 j++;
389 std::vector<GC::pointer<mathexpr>> args;
390 args.push_back(parse_rec(_type, operands, first + 1, j + 1));
391 return GC::pointer<mathexpr>(GC::obj_tag(), &_type,
392 operands[best].typei, args);
393 } else {
394 //Binary operator.
395 std::vector<GC::pointer<mathexpr>> args;
396 args.push_back(parse_rec(_type, operands, first, best));
397 args.push_back(parse_rec(_type, operands, best + 1, last));
398 return GC::pointer<mathexpr>(GC::obj_tag(), &_type,
399 operands[best].typei, args);
402 return operands[first].expr;
405 GC::pointer<mathexpr> parse_rec(typeinfo& _type, std::vector<subexpression>& ss,
406 std::set<operinfo*>& operations,
407 std::function<GC::pointer<mathexpr>(const std::string&)> vars, size_t first, size_t last)
409 operations_set opset(operations);
410 std::vector<expr_or_op> operands;
411 std::vector<GC::pointer<mathexpr>> args;
412 operinfo* fn;
413 for(size_t i = first; i < last; i++) {
414 size_t l = find_last_in_sub(ss, i);
415 if(l >= last) throw std::runtime_error("Internal error: Improper nesting");
416 switch(ss[i].kind) {
417 case TT_OPEN_PAREN:
418 operands.push_back(parse_rec(_type, ss, operations, vars, i + 1, l));
419 break;
420 case TT_VALUE:
421 operands.push_back(GC::pointer<mathexpr>(GC::obj_tag(), &_type,
422 ss[i].string, false));
423 break;
424 case TT_STRING:
425 operands.push_back(GC::pointer<mathexpr>(GC::obj_tag(), &_type,
426 ss[i].string, true));
427 break;
428 case TT_VARIABLE:
429 //We have to warp this is identify transform to make the evaluation lazy.
430 operands.push_back(GC::pointer<mathexpr>(GC::obj_tag(), &_type,
431 vars(ss[i].string)));
432 break;
433 case TT_FUNCTION:
434 fn = opset.find_function(ss[i].string);
435 i += 2;
436 while(ss[i].kind != TT_CLOSE_PAREN) {
437 size_t k = find_end_of_arg(ss, i);
438 args.push_back(parse_rec(_type, ss, operations, vars, i, k));
439 if(k < ss.size() && ss[k].kind == TT_COMMA)
440 i = k + 1;
441 else
442 i = k;
444 operands.push_back(GC::pointer<mathexpr>(GC::obj_tag(), &_type, fn,
445 args));
446 args.clear();
447 break;
448 case TT_OPERATOR:
449 operands.push_back(ss[i].string);
450 break;
451 case TT_CLOSE_PAREN:
452 case TT_COMMA:
453 ; //Can't appen.
455 i = l;
457 if(operands.empty())
458 throw std::runtime_error("Empty subexpression");
459 //Translate nonary operators to values.
460 for(auto& i : operands) {
461 if(!(bool)i.expr) {
462 auto fn = opset.find_operator(i.op, 0);
463 if(fn)
464 i.expr = GC::pointer<mathexpr>(GC::obj_tag(), &_type, fn,
465 std::vector<GC::pointer<mathexpr>>());
468 //Check that there aren't two consequtive subexpressions and mark operators.
469 bool was_operand = false;
470 for(auto& i : operands) {
471 bool is_operand = (bool)i.expr;
472 if(!is_operand && !was_operand)
473 if(!(i.typei = opset.find_operator(i.op, 1)))
474 throw std::runtime_error("'" + i.op + "' is not an unary operator");
475 if(!is_operand && was_operand)
476 if(!(i.typei = opset.find_operator(i.op, 2)))
477 throw std::runtime_error("'" + i.op + "' is not a binary operator");
478 if(was_operand && is_operand)
479 throw std::runtime_error("Expected operator, got operand");
480 was_operand = is_operand;
482 if(!was_operand)
483 throw std::runtime_error("Expected operand, got end of subexpression");
484 //Okay, now the expression has been reduced into series of operators and subexpressions.
485 //If there are multiple consequtive operators, the first (except as first item) is binary,
486 //and the others are unary.
487 return parse_rec(_type, operands, 0, operands.size());
490 void tokenize(const std::string& expr, std::set<operinfo*>& operations,
491 std::vector<subexpression>& tokenization)
493 for(size_t i = 0; i < expr.length();) {
494 if(expr[i] == '(') {
495 tokenization.push_back(subexpression(TT_OPEN_PAREN));
496 i++;
497 } else if(expr[i] == ')') {
498 tokenization.push_back(subexpression(TT_CLOSE_PAREN));
499 i++;
500 } else if(expr[i] == ',') {
501 tokenization.push_back(subexpression(TT_COMMA));
502 i++;
503 } else if(expr[i] == ' ') {
504 i++;
505 } else if(expr[i] == '$') {
506 //Variable. If the next character is {, parse until }, otherwise parse until
507 //non-alphanum.
508 std::string varname = "";
509 if(i + 1 < expr.length() && expr[i + 1] == '{') {
510 //Terminate by '}'.
511 i++;
512 while(i + 1 < expr.length() && expr[i + 1] != '}')
513 varname += std::string(1, expr[++i]);
514 if(i + 1 >= expr.length() || expr[i + 1] != '}')
515 throw std::runtime_error("${ without matching }");
516 i++;
517 } else {
518 //Terminate by non-alphanum.
519 while(i + 1 < expr.length() && is_alphanumeric(expr[i + 1]))
520 varname += std::string(1, expr[++i]);
522 tokenization.push_back(subexpression(TT_VARIABLE, varname));
523 i++;
524 } else if(expr[i] == '"') {
525 bool escape = false;
526 size_t endpos = i;
527 endpos++;
528 while(endpos < expr.length() && (escape || expr[endpos] != '"')) {
529 if(!escape) {
530 if(expr[endpos] == '\\')
531 escape = true;
532 endpos++;
533 } else {
534 escape = false;
535 endpos++;
538 if(endpos == expr.length())
539 throw std::runtime_error("Unmatched \"");
540 //Copy (i,endpos-1) and descape.
541 std::string tmp;
542 escape = false;
543 for(size_t j = i + 1; j < endpos; j++) {
544 if(!escape) {
545 if(expr[j] != '\\')
546 tmp += std::string(1, expr[j]);
547 else
548 escape = true;
549 } else {
550 tmp += std::string(1, expr[j]);
551 escape = false;
554 tokenization.push_back(subexpression(TT_STRING, tmp));
555 i = endpos + 1;
556 } else {
557 bool found = false;
558 //Function names are only recognized if it begins here and is followed by '('.
559 for(auto j : operations) {
560 if(j->is_operator) continue; //Not a function.
561 if(i + j->fnname.length() + 1 > expr.length()) continue; //Too long.
562 if(expr[i + j->fnname.length()] != '(') continue; //Not followed by '('.
563 for(size_t k = 0; k < j->fnname.length(); k++)
564 if(expr[i + k] != j->fnname[k]) goto nomatch; //No match.
565 tokenization.push_back(subexpression(TT_FUNCTION, j->fnname));
566 i += j->fnname.length();
567 found = true;
568 break;
569 nomatch: ;
571 if(found) continue;
572 //Operators. These use longest match rule.
573 size_t longest_match = 0;
574 std::string op;
575 for(auto j : operations) {
576 if(!j->is_operator) continue; //Not an operator.
577 if(i + j->fnname.length() > expr.length()) continue; //Too long.
578 for(size_t k = 0; k < j->fnname.length(); k++)
579 if(expr[i + k] != j->fnname[k]) goto next; //No match.
580 if(j->fnname.length() <= longest_match) continue; //Not longest.
581 found = true;
582 op = j->fnname;
583 longest_match = op.length();
584 next: ;
586 if(found) {
587 tokenization.push_back(subexpression(TT_OPERATOR, op));
588 i += op.length();
589 continue;
591 //Okay, token until next non-alphanum.
592 std::string tmp;
593 while(i < expr.length() && is_alphanumeric(expr[i]))
594 tmp += std::string(1, expr[i++]);
595 if(tmp.length()) {
596 tokenization.push_back(subexpression(TT_VALUE, tmp));
597 continue;
599 std::string summary;
600 size_t j;
601 size_t utfcount = 0;
602 for(j = i; j < expr.length() && (j < i + 20 || utfcount); j++) {
603 if(utfcount) utfcount--;
604 summary += std::string(1, expr[j]);
605 if((uint8_t)expr[j] >= 0xF0) utfcount = 3;
606 if((uint8_t)expr[j] >= 0xE0) utfcount = 2;
607 if((uint8_t)expr[j] >= 0xC0) utfcount = 1;
608 if((uint8_t)expr[j] < 0x80) utfcount = 0;
610 if(j < expr.length()) summary += "[...]";
611 throw std::runtime_error("Expression parse error, at '" + summary + "'");
617 GC::pointer<mathexpr> mathexpr::parse(typeinfo& _type, const std::string& expr,
618 std::function<GC::pointer<mathexpr>(const std::string&)> vars)
620 if(expr == "")
621 throw std::runtime_error("Empty expression");
622 auto operations = _type.operations();
623 std::vector<subexpression> tokenization;
624 tokenize(expr, operations, tokenization);
625 return parse_rec(_type, tokenization, operations, vars, 0, tokenization.size());