lsnes rr2-β24
[lsnes.git] / src / library / mathexpr.cpp
blobcfdd927937c74cf557a87675924e02a115170f82
1 #include "mathexpr.hpp"
2 #include "string.hpp"
3 #include <functional>
4 #include <set>
5 #include <map>
7 namespace mathexpr
9 operinfo::operinfo(std::string funcname)
10 : fnname(funcname), is_operator(false), operands(0), precedence(0), rtl(false)
13 operinfo::operinfo(std::string opername, unsigned _operands, int _percedence, bool _rtl)
14 : fnname(opername), is_operator(true), operands(_operands), precedence(_percedence), rtl(_rtl)
17 operinfo::~operinfo()
21 typeinfo::~typeinfo()
25 mathexpr::mathexpr(typeinfo* _type)
26 : type(*_type)
28 owns_operator = false;
29 state = UNDEFINED;
30 _value = NULL;
31 fn = (operinfo*)0xDEADBEEF;
34 mathexpr::mathexpr(typeinfo* _type, GC::pointer<mathexpr> fwd)
35 : type(*_type)
37 owns_operator = false;
38 state = FORWARD;
39 _value = type.allocate();
40 arguments.push_back(&*fwd);
41 fn = NULL;
44 mathexpr::mathexpr(value _val)
45 : type(*_val.type)
47 owns_operator = false;
48 state = FIXED;
49 _value = type.copy_allocate(_val._value);
50 fn = NULL;
53 mathexpr::mathexpr(typeinfo* _type, const std::string& _val, bool string)
54 : type(*_type)
56 owns_operator = false;
57 state = FIXED;
58 _value = type.parse(_val, string);
59 fn = NULL;
62 mathexpr::mathexpr(typeinfo* _type, operinfo* _fn, std::vector<GC::pointer<mathexpr>> _args, bool _owns_operator)
63 : type(*_type), fn(_fn), owns_operator(_owns_operator)
65 try {
66 for(auto& i : _args)
67 arguments.push_back(&*i);
68 _value = type.allocate();
69 state = TO_BE_EVALUATED;
70 } catch(...) {
71 if(owns_operator)
72 delete fn;
73 throw;
77 mathexpr::~mathexpr()
79 if(owns_operator && fn)
80 delete fn;
81 type.deallocate(_value);
84 void mathexpr::reset()
86 if(state == TO_BE_EVALUATED || state == FIXED || state == UNDEFINED || state == FORWARD)
87 return;
88 if(state == FORWARD_EVALD || state == FORWARD_EVALING) {
89 state = FORWARD;
90 return;
92 state = TO_BE_EVALUATED;
93 for(auto i : arguments)
94 i->reset();
97 mathexpr::mathexpr(const mathexpr& m)
98 : state(m.state), type(m.type), fn(m.fn), _error(m._error), arguments(m.arguments)
100 _value = m._value ? type.copy_allocate(m._value) : NULL;
101 if(state == EVALUATING) state = TO_BE_EVALUATED;
104 mathexpr& mathexpr::operator=(const mathexpr& m)
106 if(this == &m)
107 return *this;
108 std::string _xerror = m._error;
109 std::vector<mathexpr*> _arguments = m.arguments;
110 if(m._value) {
111 if(!_value)
112 _value = m.type.copy_allocate(m._value);
113 else
114 m.type.copy(_value, m._value);
115 } else if(_value) {
116 m.type.deallocate(_value);
117 _value = NULL;
118 } else
119 _value = NULL;
120 type = m.type;
121 fn = m.fn;
122 state = m.state;
123 owns_operator = m.owns_operator;
124 m.owns_operator = false;
125 std::swap(arguments, _arguments);
126 std::swap(_error, _xerror);
127 return *this;
130 value mathexpr::evaluate()
132 value ret;
133 ret.type = &type;
134 switch(state) {
135 case TO_BE_EVALUATED:
136 //Need to evaluate.
137 try {
138 for(auto i : arguments) {
139 if(&i->type != &type) {
140 throw error(error::TYPE_MISMATCH,
141 "Types for function mismatch");
144 state = EVALUATING;
145 std::vector<std::function<value()>> promises;
146 for(auto i : arguments) {
147 mathexpr* m = i;
148 promises.push_back([m]() { return m->evaluate(); });
150 value tmp;
151 tmp.type = &type;
152 tmp._value = _value;
153 fn->evaluate(tmp, promises);
154 state = EVALUATED;
155 } catch(error& e) {
156 state = FAILED;
157 errcode = e.get_code();
158 _error = e.what();
159 throw;
160 } catch(std::exception& e) {
161 state = FAILED;
162 errcode = error::UNKNOWN;
163 _error = e.what();
164 throw;
165 } catch(...) {
166 state = FAILED;
167 errcode = error::UNKNOWN;
168 _error = "Unknown error";
169 throw;
171 ret._value = _value;
172 return ret;
173 case EVALUATING:
174 case FORWARD_EVALING:
175 //Circular dependency.
176 mark_error_and_throw(error::CIRCULAR, "Circular dependency");
177 case EVALUATED:
178 case FIXED:
179 case FORWARD_EVALD:
180 ret._value = _value;
181 return ret;
182 case UNDEFINED:
183 throw error(error::UNDEFINED, "Undefined variable");
184 case FAILED:
185 throw error(errcode, _error);
186 case FORWARD:
187 try {
188 state = FORWARD_EVALING;
189 value v = arguments[0]->evaluate();
190 type.copy(_value, v._value);
191 state = FORWARD_EVALD;
192 return v;
193 } catch(...) {
194 state = FORWARD;
195 throw;
198 throw error(error::INTERNAL, "Internal error (shouldn't be here)");
201 void mathexpr::trace()
203 for(auto i : arguments)
204 i->mark();
207 void mathexpr::mark_error_and_throw(error::errorcode _errcode, const std::string& _xerror)
209 if(state == EVALUATING) {
210 state = FAILED;
211 errcode = _errcode;
212 _error = _xerror;
214 if(state == FORWARD_EVALING) {
215 state = FORWARD;
217 throw error(_errcode, _error);
220 namespace
223 X_EXPR -> VALUE
224 X_EXPR -> STRING
225 X_EXPR -> NONARY-OP
226 X_EXPR -> FUNCTION X_ARGS
227 X_EXPR -> UNARY-OP X_EXPR
228 X_EXPR -> X_LAMBDA X_EXPR
229 X_LAMBDA -> X_EXPR BINARY-OP
230 X_ARGS -> OPEN-PAREN CLOSE_PAREN
231 X_ARGS -> OPEN-PAREN X_TAIL
232 X_TAIL -> X_PAIR X_TAIL
233 X_TAIL -> X_EXPR CLOSE-PAREN
234 X_PAIR -> X_EXPR COMMA
237 //SUBEXPRESSION -> VALUE
238 //SUBEXPRESSION -> STRING
239 //SUBEXPRESSION -> FUNCTION OPEN-PAREN CLOSE-PAREN
240 //SUBEXPRESSION -> FUNCTION OPEN-PAREN (SUBEXPRESSION COMMA)* SUBEXPRESSION CLOSE-PAREN
241 //SUBEXPRESSION -> OPEN-PAREN SUBEXPRESSION CLOSE-PAREN
242 //SUBEXPRESSION -> SUBEXPRESSION BINARY-OP SUBEXPRESSION
243 //SUBEXPRESSION -> UNARY-OP SUBEXPRESSION
244 //SUBEXPRESSION -> NONARY-OP
246 bool is_alphanumeric(char ch)
248 if(ch >= '0' && ch <= '9') return true;
249 if(ch >= 'a' && ch <= 'z') return true;
250 if(ch >= 'A' && ch <= 'Z') return true;
251 if(ch == '_') return true;
252 if(ch == '.') return true;
253 return false;
256 enum token_kind
258 TT_OPEN_PAREN,
259 TT_CLOSE_PAREN,
260 TT_COMMA,
261 TT_FUNCTION,
262 TT_OPERATOR,
263 TT_VARIABLE,
264 TT_VALUE,
265 TT_STRING,
268 struct operations_set
270 operations_set(std::set<operinfo*>& ops)
271 : operations(ops)
274 operinfo* find_function(const std::string& name)
276 operinfo* fn = NULL;
277 for(auto j : operations) {
278 if(name == j->fnname && !j->is_operator)
279 fn = j;
281 if(!fn) throw std::runtime_error("No such function '" + name + "'");
282 return fn;
284 operinfo* find_operator(const std::string& name, unsigned arity)
286 for(auto j : operations) {
287 if(name == j->fnname && j->is_operator && j->operands == arity)
288 return j;
290 return NULL;
292 private:
293 std::set<operinfo*>& operations;
296 struct subexpression
298 subexpression(token_kind k) : kind(k) {}
299 subexpression(token_kind k, const std::string& str) : kind(k), string(str) {}
300 token_kind kind;
301 std::string string;
304 size_t find_last_in_sub(std::vector<subexpression>& ss, size_t first)
306 size_t depth;
307 switch(ss[first].kind) {
308 case TT_FUNCTION:
309 if(first + 1 == ss.size() || ss[first + 1].kind != TT_OPEN_PAREN)
310 throw std::runtime_error("Function requires argument list");
311 first++;
312 case TT_OPEN_PAREN:
313 depth = 0;
314 while(first < ss.size()) {
315 if(ss[first].kind == TT_OPEN_PAREN)
316 depth++;
317 if(ss[first].kind == TT_CLOSE_PAREN)
318 if(!--depth) break;
319 first++;
321 if(first == ss.size())
322 throw std::runtime_error("Unmatched '('");
323 return first;
324 case TT_CLOSE_PAREN:
325 throw std::runtime_error("Unmatched ')'");
326 case TT_COMMA:
327 throw std::runtime_error("',' only allowed in function arguments");
328 case TT_VALUE:
329 case TT_STRING:
330 case TT_OPERATOR:
331 case TT_VARIABLE:
332 return first;
334 throw std::runtime_error("Internal error (shouldn't be here)");
337 size_t find_end_of_arg(std::vector<subexpression>& ss, size_t first)
339 size_t depth = 0;
340 while(first < ss.size()) {
341 if(depth == 0 && ss[first].kind == TT_COMMA)
342 return first;
343 if(ss[first].kind == TT_OPEN_PAREN)
344 depth++;
345 if(ss[first].kind == TT_CLOSE_PAREN) {
346 if(depth == 0)
347 return first;
348 depth--;
350 first++;
352 return ss.size();
355 struct expr_or_op
357 expr_or_op(GC::pointer<mathexpr> e) : expr(e), typei(NULL) {}
358 expr_or_op(std::string o) : op(o), typei(NULL) {}
359 GC::pointer<mathexpr> expr;
360 std::string op;
361 operinfo* typei;
364 GC::pointer<mathexpr> parse_rec(typeinfo& _type, std::vector<expr_or_op>& operands,
365 size_t first, size_t last)
367 if(operands.empty())
368 return GC::pointer<mathexpr>(GC::obj_tag(), &_type);
369 if(last - first > 1) {
370 //Find the highest percedence operator.
371 size_t best = last;
372 for(size_t i = first; i < last; i++) {
373 if(operands[i].typei) {
374 if(best == last)
375 best = i;
376 else if(operands[i].typei->precedence < operands[best].typei->precedence) {
377 best = i;
378 } else if(!operands[best].typei->rtl &&
379 operands[i].typei->precedence == operands[best].typei->precedence) {
380 best = i;
384 if(best == last) throw std::runtime_error("Internal error: No operands?");
385 if(operands[best].typei->operands == 1) {
386 //The operator is unary, collect up all following unary operators.
387 size_t j = first;
388 while(operands[j].typei)
389 j++;
390 std::vector<GC::pointer<mathexpr>> args;
391 args.push_back(parse_rec(_type, operands, first + 1, j + 1));
392 return GC::pointer<mathexpr>(GC::obj_tag(), &_type,
393 operands[best].typei, args);
394 } else {
395 //Binary operator.
396 std::vector<GC::pointer<mathexpr>> args;
397 args.push_back(parse_rec(_type, operands, first, best));
398 args.push_back(parse_rec(_type, operands, best + 1, last));
399 return GC::pointer<mathexpr>(GC::obj_tag(), &_type,
400 operands[best].typei, args);
403 return operands[first].expr;
406 GC::pointer<mathexpr> parse_rec(typeinfo& _type, std::vector<subexpression>& ss,
407 std::set<operinfo*>& operations,
408 std::function<GC::pointer<mathexpr>(const std::string&)> vars, size_t first, size_t last)
410 operations_set opset(operations);
411 std::vector<expr_or_op> operands;
412 std::vector<GC::pointer<mathexpr>> args;
413 operinfo* fn;
414 for(size_t i = first; i < last; i++) {
415 size_t l = find_last_in_sub(ss, i);
416 if(l >= last) throw std::runtime_error("Internal error: Improper nesting");
417 switch(ss[i].kind) {
418 case TT_OPEN_PAREN:
419 operands.push_back(parse_rec(_type, ss, operations, vars, i + 1, l));
420 break;
421 case TT_VALUE:
422 operands.push_back(GC::pointer<mathexpr>(GC::obj_tag(), &_type,
423 ss[i].string, false));
424 break;
425 case TT_STRING:
426 operands.push_back(GC::pointer<mathexpr>(GC::obj_tag(), &_type,
427 ss[i].string, true));
428 break;
429 case TT_VARIABLE:
430 //We have to warp this is identify transform to make the evaluation lazy.
431 operands.push_back(GC::pointer<mathexpr>(GC::obj_tag(), &_type,
432 vars(ss[i].string)));
433 break;
434 case TT_FUNCTION:
435 fn = opset.find_function(ss[i].string);
436 i += 2;
437 while(ss[i].kind != TT_CLOSE_PAREN) {
438 size_t k = find_end_of_arg(ss, i);
439 args.push_back(parse_rec(_type, ss, operations, vars, i, k));
440 if(k < ss.size() && ss[k].kind == TT_COMMA)
441 i = k + 1;
442 else
443 i = k;
445 operands.push_back(GC::pointer<mathexpr>(GC::obj_tag(), &_type, fn,
446 args));
447 args.clear();
448 break;
449 case TT_OPERATOR:
450 operands.push_back(ss[i].string);
451 break;
452 case TT_CLOSE_PAREN:
453 case TT_COMMA:
454 ; //Can't appen.
456 i = l;
458 if(operands.empty())
459 throw std::runtime_error("Empty subexpression");
460 //Translate nonary operators to values.
461 for(auto& i : operands) {
462 if(!(bool)i.expr) {
463 auto fn = opset.find_operator(i.op, 0);
464 if(fn)
465 i.expr = GC::pointer<mathexpr>(GC::obj_tag(), &_type, fn,
466 std::vector<GC::pointer<mathexpr>>());
469 //Check that there aren't two consequtive subexpressions and mark operators.
470 bool was_operand = false;
471 for(auto& i : operands) {
472 bool is_operand = (bool)i.expr;
473 if(!is_operand && !was_operand)
474 if(!(i.typei = opset.find_operator(i.op, 1)))
475 throw std::runtime_error("'" + i.op + "' is not an unary operator");
476 if(!is_operand && was_operand)
477 if(!(i.typei = opset.find_operator(i.op, 2)))
478 throw std::runtime_error("'" + i.op + "' is not a binary operator");
479 if(was_operand && is_operand)
480 throw std::runtime_error("Expected operator, got operand");
481 was_operand = is_operand;
483 if(!was_operand)
484 throw std::runtime_error("Expected operand, got end of subexpression");
485 //Okay, now the expression has been reduced into series of operators and subexpressions.
486 //If there are multiple consequtive operators, the first (except as first item) is binary,
487 //and the others are unary.
488 return parse_rec(_type, operands, 0, operands.size());
491 void tokenize(const std::string& expr, std::set<operinfo*>& operations,
492 std::vector<subexpression>& tokenization)
494 for(size_t i = 0; i < expr.length();) {
495 if(expr[i] == '(') {
496 tokenization.push_back(subexpression(TT_OPEN_PAREN));
497 i++;
498 } else if(expr[i] == ')') {
499 tokenization.push_back(subexpression(TT_CLOSE_PAREN));
500 i++;
501 } else if(expr[i] == ',') {
502 tokenization.push_back(subexpression(TT_COMMA));
503 i++;
504 } else if(expr[i] == ' ') {
505 i++;
506 } else if(expr[i] == '$') {
507 //Variable. If the next character is {, parse until }, otherwise parse until
508 //non-alphanum.
509 std::string varname = "";
510 if(i + 1 < expr.length() && expr[i + 1] == '{') {
511 //Terminate by '}'.
512 i++;
513 while(i + 1 < expr.length() && expr[i + 1] != '}')
514 varname += std::string(1, expr[++i]);
515 if(i + 1 >= expr.length() || expr[i + 1] != '}')
516 throw std::runtime_error("${ without matching }");
517 i++;
518 } else {
519 //Terminate by non-alphanum.
520 while(i + 1 < expr.length() && is_alphanumeric(expr[i + 1]))
521 varname += std::string(1, expr[++i]);
523 tokenization.push_back(subexpression(TT_VARIABLE, varname));
524 i++;
525 } else if(expr[i] == '"') {
526 bool escape = false;
527 size_t endpos = i;
528 endpos++;
529 while(endpos < expr.length() && (escape || expr[endpos] != '"')) {
530 if(!escape) {
531 if(expr[endpos] == '\\')
532 escape = true;
533 endpos++;
534 } else {
535 escape = false;
536 endpos++;
539 if(endpos == expr.length())
540 throw std::runtime_error("Unmatched \"");
541 //Copy (i,endpos-1) and descape.
542 std::string tmp;
543 escape = false;
544 for(size_t j = i + 1; j < endpos; j++) {
545 if(!escape) {
546 if(expr[j] != '\\')
547 tmp += std::string(1, expr[j]);
548 else
549 escape = true;
550 } else {
551 tmp += std::string(1, expr[j]);
552 escape = false;
555 tokenization.push_back(subexpression(TT_STRING, tmp));
556 i = endpos + 1;
557 } else {
558 bool found = false;
559 //Function names are only recognized if it begins here and is followed by '('.
560 for(auto j : operations) {
561 if(j->is_operator) continue; //Not a function.
562 if(i + j->fnname.length() + 1 > expr.length()) continue; //Too long.
563 if(expr[i + j->fnname.length()] != '(') continue; //Not followed by '('.
564 for(size_t k = 0; k < j->fnname.length(); k++)
565 if(expr[i + k] != j->fnname[k]) goto nomatch; //No match.
566 tokenization.push_back(subexpression(TT_FUNCTION, j->fnname));
567 i += j->fnname.length();
568 found = true;
569 break;
570 nomatch: ;
572 if(found) continue;
573 //Operators. These use longest match rule.
574 size_t longest_match = 0;
575 std::string op;
576 for(auto j : operations) {
577 if(!j->is_operator) continue; //Not an operator.
578 if(i + j->fnname.length() > expr.length()) continue; //Too long.
579 for(size_t k = 0; k < j->fnname.length(); k++)
580 if(expr[i + k] != j->fnname[k]) goto next; //No match.
581 if(j->fnname.length() <= longest_match) continue; //Not longest.
582 found = true;
583 op = j->fnname;
584 longest_match = op.length();
585 next: ;
587 if(found) {
588 tokenization.push_back(subexpression(TT_OPERATOR, op));
589 i += op.length();
590 continue;
592 //Okay, token until next non-alphanum.
593 std::string tmp;
594 while(i < expr.length() && is_alphanumeric(expr[i]))
595 tmp += std::string(1, expr[i++]);
596 if(tmp.length()) {
597 tokenization.push_back(subexpression(TT_VALUE, tmp));
598 continue;
600 std::string summary;
601 size_t j;
602 size_t utfcount = 0;
603 for(j = i; j < expr.length() && (j < i + 20 || utfcount); j++) {
604 if(utfcount) utfcount--;
605 summary += std::string(1, expr[j]);
606 if((uint8_t)expr[j] >= 0xF0) utfcount = 3;
607 if((uint8_t)expr[j] >= 0xE0) utfcount = 2;
608 if((uint8_t)expr[j] >= 0xC0) utfcount = 1;
609 if((uint8_t)expr[j] < 0x80) utfcount = 0;
611 if(j < expr.length()) summary += "[...]";
612 throw std::runtime_error("Expression parse error, at '" + summary + "'");
618 GC::pointer<mathexpr> mathexpr::parse(typeinfo& _type, const std::string& expr,
619 std::function<GC::pointer<mathexpr>(const std::string&)> vars)
621 if(expr == "")
622 throw std::runtime_error("Empty expression");
623 auto operations = _type.operations();
624 std::vector<subexpression> tokenization;
625 tokenize(expr, operations, tokenization);
626 return parse_rec(_type, tokenization, operations, vars, 0, tokenization.size());