new samples: synthesis of source file from an AST, weaving sources and bits of genera...
[metalua.git] / src / samples / synth.mlua
blob02c75c93809df1722d48780a24a1360f3c1b4cc1
1 require 'strict'
3 -{ extension 'match' }
5 synth = { }
6 synth.__index = synth
8 function synth.new ()
9    local self = {
10       _acc           = { },
11       current_indent = 0,
12       indent_step    = "   "
13    }
14    return setmetatable (self, synth)
15 end
17 function synth:run (ast)
18    --if getmetatable (self) ~= synth and not ast then
19    if not ast then
20       self, ast = synth.new(), self
21    end
22    self._acc = { }
23    self:node (ast)
24    return table.concat (self._acc)
25 end
27 function synth:acc (x)
28    if x then table.insert (self._acc, x) end
29 end
31 function synth:nl ()
32    if self.current_indent == 0 then
33       self:acc "\n"
34    end
35    self:acc ("\n" .. self.indent_step:rep (self.current_indent))
36 end
38 function synth:nlindent ()
39    self.current_indent = self.current_indent + 1
40    self:nl ()
41 end
43 function synth:nldedent ()
44    self.current_indent = self.current_indent - 1
45    self:acc ("\n" .. self.indent_step:rep (self.current_indent))
46 end
48 local keywords = table.transpose {
49    "and",
50    "break",
51    "do",
52    "else",
53    "elseif",
54    "end",
55    "false",
56    "for",
57    "function",
58    "if",
59    "in",
60    "local",
61    "nil",
62    "not",
63    "or",
64    "repeat",
65    "return",
66    "then",
67    "true",
68    "until",
69    "while"
72 local function is_ident (id)
73    return id:strmatch "^[%a_][%w_]*$" and not keywords[id]
74 end
76 local function is_idx_stack (ast)
77    match ast with
78    | `Id{ _ } -> return true
79    | `Index{ left, `String{ _ } } -> return is_idx_stack (left)
80    | _ -> return false
81    end
82 end
84 local op_preprec = {
85    { "or", "and" },
86    { "lt", "le", "eq", "ne" },
87    { "concat" }, 
88    { "add", "sub" },
89    { "mul", "div", "mod" },
90    { "unary", "not", "len" },
91    { "pow" },
92    { "index" } }
94 local op_prec = { }
96 for prec, ops in ipairs (op_preprec) do
97    for op in ivalues (ops) do
98       op_prec[op] = prec
99    end
102 local op_symbol = {
103    add = " + ",
104    sub = " - ",
105    mul = " * ",
106    div = " / ",
107    mod = " % ",
108    pow = " ^ ",
109    concat = " .. ",
110    eq = " == ",
111    ne = " ~= ",
112    lt = " < ",
113    le = " <= ",
114    ["and"] = " and ",
115    ["or"] = " or ",
116    ["not"] = "not ",
117    len = "# "
120 function synth:node (node)
121    assert (self~=synth and self._acc)
122    if not node.tag then
123       self:list (node, self.nl)
124    else
125       local f = synth[node.tag]
126       if type (f) == "function" then
127          f (self, node, unpack (node))
128       elseif type (f) == "string" then
129          self:acc (f)
130       else
131          self:acc " -{ "
132          self:acc (table.tostring (node, "nohash"))
133          self:acc " }"
134       end
135    end
138 function synth:list (list, sep, start)
139    for i = start or 1, # list do
140       self:node (list[i])
141       if list[i + 1] then
142          if not sep then            
143          elseif type (sep) == "function" then sep (self)
144          elseif type (sep) == "string"   then self:acc (sep)
145          else   error "Invalid list separator" end
146       end
147    end
150 function synth:Do (node)
151    self:acc "do"
152    self:nlindent ()
153    self:list (node, self.nl)
154    self:nldedent ()
155    self:acc "end"
158 function synth:Set (node)
159    match node with
160    | `Set{ { `Index{ lhs, `String{ method } } }, 
161            { `Function{ { `Id "self", ... } == params, body } } } 
162       if is_idx_stack (lhs) and is_ident (method) ->
163       -- function foo:bar(...) ... end
164       self:acc      "function "
165       self:node     (lhs)
166       self:acc      ":"
167       self:acc      (method)
168       self:acc      " ("
169       self:list    (params, ", ", 2)
170       self:acc      ")"
171       self:nlindent ()
172       self:list    (body, self.nl)
173       self:nldedent ()
174       self:acc      "end"
176    | `Set{ { lhs }, { `Function{ params, body } } } if is_idx_stack (lhs) ->
177       -- function foo(...) ... end
178       self:acc      "function "
179       self:node    (lhs)
180       self:acc      " ("
181       self:list    (params, ", ")
182       self:acc      ")"
183       self:nlindent ()
184       self:list    (body, self.nl)
185       self:nldedent ()
186       self:acc      "end"
188    | `Set{ { `Id{ lhs1name } == lhs1, ... } == lhs, rhs } if not is_ident (lhs1name) ->
189       -- foo, ... = ... when foo isn't a valid identifier
190       self:acc      "("
191       self:node    (lhs1)
192       self:acc      ")"
193       if lhs[2] then 
194          self:acc   ", "
195          self:list (lhs, ", ", 2)
196       end
197       self:acc      " = "
198       self:list    (rhs, ", ")
200    | `Set{ lhs, rhs } ->
201       -- ... = ...
202       self:list (lhs, ", ")
203       self:acc   " = "
204       self:list (rhs, ", ")
205    end
208 function synth:While (node, cond, body)
209    self:acc      "while "
210    self:node     (cond)
211    self:acc      " do"
212    self:nlindent ()
213    self:list     (body, self.nl)
214    self:nldedent ()
215    self:acc      "end"
218 function synth:Repeat (node, body, cond)
219    self:acc      "repeat"
220    self:nlindent ()
221    self:list     (body, self.nl)
222    self:nldedent ()
223    self:acc      "until "
224    self:node     (cond)
227 function synth:If (node)
228    for i = 1, #node-1, 2 do
229       local cond, body = node[i], node[i+1]
230       self:acc      (i==1 and "if " or "elseif ")
231       self:node     (cond)
232       self:acc      " then"
233       self:nlindent ()
234       self:list     (body, self.nl)
235       self:nldedent ()
236    end
237    if #node%2 == 1 then
238       self:acc      "else"
239       self:nlindent ()
240       self:list     (node[#node], self.nl)
241       self:nldedent ()
242    end
243    self:acc "end"
246 function synth:Fornum (node, var, first, last)
247    local body = node[#node]
248    self:acc      "for "
249    self:node     (var)
250    self:acc      " = "
251    self:node     (first)
252    self:acc      ", "
253    self:node     (last)
254    if #node==5 then
255       self:acc   ", "
256       self:node  (node[4]) -- step increment
257    end
258    self:acc      " do"
259    self:nlindent ()
260    self:list     (body, self.nl)
261    self:nldedent ()
262    self:acc      "end"
265 function synth:Forin (node, vars, generators, body)
266    self:acc      "for "
267    self:list     (vars, ", ")
268    self:acc      " in "
269    self:list     (generators, ", ")
270    self:acc      " do"
271    self:nlindent ()
272    self:list     (body, self.nl)
273    self:nldedent ()
274    self:acc      "end"
277 function synth:Local (node, lhs, rhs)
278    self:acc     "local "
279    self:list    (lhs, ", ")
280    if rhs[1] then
281       self:acc  " = "
282       self:list (rhs, ", ")
283    end
286 function synth:Localrec (node, lhs, rhs)
287    match node with
288    | `Localrec{ { `Id{name} }, { `Function{ params, body } } } if is_ident (name) ->
289       self:acc      "local function "
290       self:acc      (name)
291       self:acc      " ("
292       self:list     (params, ", ")
293       self:acc      ")"
294       self:nlindent ()
295       self:list     (body, self.nl)
296       self:nldedent ()
297       self:acc      "end"
299    | _ -> 
300       self:acc "-{ "
301       self:acc (table.tostring (node, 'nohash', 80))
302       self:acc " }"
303    end
307 function synth:Call (node, f)
308    local parens
309    match node with
310    | `Call{ _, `String{_} }
311    | `Call{ _, `Table{...}} -> parens = false
312    | _ -> parens = true
313    end
314    self:node (f)
315    if parens then
316       self:acc " ("
317    else
318       self:acc " "
319    end
320    self:list (node, ", ", 2)
321    if parens then
322       self:acc ")"
323    end
326 function synth:Invoke (node, f, method)
327    local parens
328    match node with
329    | `Invoke{ _, _, `String{_} }
330    | `Invoke{ _, _, `Table{...}} -> parens = false
331    | _ -> parens = true
332    end
333    self:node   (f)
334    self:acc    ":"
335    self:acc    (method[1])
336    self:acc    (parens and " (" or " ")
337    self:list   (node, ", ", 3)
338    if parens then
339       self:acc ")"
340    end
343 function synth:Return (node)
344    self:acc  "return "
345    self:list (node, ", ")
348 synth.Break = "break"
349 synth.Nil = "nil"
350 synth.False = "false"
351 synth.True = "true"
352 synth.Dots = "..."
354 function synth:Number (node, n)
355    self:acc (tostring (n))
358 function synth:String (node, str)
359    self:acc (string.format ("%q", str):gsub ("\\\n", "\\n"))
362 function synth:Function (node, params, body)
363    self:acc "function "
364    self:acc " ("
365    self:list (params, ", ")
366    self:acc ")"
367    self:nlindent ()
368    self:list (body, self.nl)
369    self:nldedent ()
370    self:acc "end"
373 function synth:Table (node)
374    if not node[1] then self:acc "{ }" else
375       self:acc "{"
376       self:nlindent ()
377       for i, elem in ipairs (node) do
378          match elem with
379          | `Pair{ `String{ key }, value } if is_ident (key) ->
380             self:acc  (key)
381             self:acc  " = "
382             self:node (value)
384          | `Pair{ key, value } ->
385             self:acc  "["
386             self:node (key)
387             self:acc  "] = "
388             self:node (value)
390          | _ -> 
391             self:node (elem)
392          end
393          if node[i + 1] then
394             self:acc ","
395             self:nl  ()
396          end
397       end
398       self:nldedent  ()
399       self:acc       "}"
400    end
403 function synth:Op (node, op, a, b)
404    -- Transform not (a == b) into a ~= b
405    match node with
406    | `Op{ "not", `Op{ "eq", _a, _b } } 
407    | `Op{ "not", `Paren{ `Op{ "eq", _a, _b } } } ->  
408       op, a, b = "ne", _a, _b
409    | _ ->
410    end
412    if b then -- binary
413       local left_paren, right_paren
414       match a with
415       | `Op{ op_a, ...} if op_prec[op] >= op_prec[op_a] -> left_paren = true
416       | _ -> left_paren = false
417       end
419       match b with -- FIXME: might not work with right assoc operators
420       | `Op{ op_b, ...} if op_prec[op] >= op_prec[op_b] -> right_paren = true
421       | _ -> right_paren = false
422       end
424       self:acc   (left_paren and "(")
425       self:node (a)
426       self:acc   (left_paren and ")")
428       self:acc   (op_symbol [op])
430       self:acc   (right_paren and "(")
431       self:node (b)
432       self:acc   (right_paren and ")")
434    else -- unary
435       
436       local paren
437       match a with
438       | `Op{ op_a, ... } if op_prec[op] >= op_prec[op_a] -> paren = true
439       | _ -> paren = false
440       end
441       self:acc   (op_symbol[op])
442       self:acc   (paren and "(")
443       self:node (a)
444       self:acc   (paren and ")")
445    end
448 function synth:Paren (node, content)
449    self:acc  "("
450    self:node (content)
451    self:acc  ")"
454 function synth:Index (node, table, key)
455    local paren_table
456    match table with
457    | `Op{ op, ... } if op_prec[op] < op_prec.index -> paren_table = true
458    | _ -> paren_table = false
459    end
461    self:acc   (paren_table and "(")
462    self:node (table)
463    self:acc   (paren_table and ")")
465    match key with
466    | `String{ field } if is_ident (field) -> 
467       self:acc "."
468       self:acc (field)
469    | _ -> 
470       self:acc   "["
471       self:node (key)
472       self:acc   "]"
473    end
476 function synth:Id (node, name)
477    if is_ident (name) then
478       self:acc (name)
479    else
480       self:acc    "-{`Id "
481       self:String (node, name)
482       self:acc    "}"
483    end 
486 require 'mlc'
487 local filename = (arg[2] or arg[1]) or arg[0]
488 local ast = mlc.luafile_to_ast (filename)
489 local status, src = xpcall (function () print(synth.run(ast)) end,
490                             function (err)
491                                print (err)
492                                print (debug.traceback())
493                             end)
495 --print (synth.run (ast))