Completed bootstrap, without separate rings for metalevels; improved clopts, metalua...
[metalua.git] / src / lib / extension-compiler / match.mlua
bloba19752b153ab364f657f2a0c686622084bc5b787
1 ----------------------------------------------------------------------
2 -- Metalua samples:  $Id$
3 --
4 -- Summary: Structural pattern matching for metalua ADT.
5 --
6 ----------------------------------------------------------------------
7 --
8 -- Copyright (c) 2006, Fabien Fleutot <metalua@gmail.com>.
9 --
10 -- This software is released under the MIT Licence, see licence.txt
11 -- for details.
13 --------------------------------------------------------------------------------
15 -- This extension, borrowed from ML dialects, allows in a single operation to
16 -- analyze the structure of nested ADT, and bind local variables to subtrees
17 -- of the analyzed ADT before executing a block of statements chosen depending
18 -- on the tested term's structure.
20 -- The general form of a pattern matching statement is:
22 -- match <tested_term> with
23 -- | <pattern_1_1> | <pattern_1_2> | <pattern_1_3> -> <block_1>
24 -- | <pattern_2> -> <block_2>
25 -- | <pattern_3_1> | <pattern_3_2> if <some_condition> -> <block_3> 
26 -- end
27 -- 
28 -- If one of the patterns <pattern_1_x> accurately describes the
29 -- structure of <tested_term>, then <block_1> is executed (and no
30 -- other block of the match statement is tested). If none of
31 -- <pattern_1_x> patterns mathc <tested_term>, but <pattern_2> does,
32 -- then <block_2> is evaluated before exiting. If no pattern matches,
33 -- the whole <match> statemetn does nothing. If more than one pattern
34 -- matches, the first one wins.
35 -- 
36 -- When an additional condition, introduced by [if], is put after
37 -- the patterns, this condition is evaluated if one of the patterns matches,
38 -- and the case is considered successful only if the condition returns neither
39 -- [nil] nor [false].
41 -- Terminology
42 -- ===========
44 -- The whole compound statement is called a match; Each schema is
45 -- called a pattern; Each sequence (list of patterns, optional guard,
46 -- statements block) is called a case.
48 -- Patterns
49 -- ========
50 -- Patterns can consist of:
52 -- - numbers, booleans, strings: they only match terms equal to them
54 -- - variables: they match everything, and bind it, i.e. the variable
55 --   will be set to the corresponding tested value when the block will
56 --   be executed (if the whole pattern and the guard match). If a
57 --   variable appears more than once in a single pattern, all captured
58 --   values have to be equal, in the sense of the "==" operator.
60 -- - tables: a table matches if all these conditions are met:
61 --   * the tested term is a table;
62 --   * all of the pattern's keys are strings or integer or implicit indexes;
63 --   * all of the pattern's values are valid patterns, except maybe the
64 --     last value with implicit integer key, which can also be [...];
65 --   * every value in the tested term is matched by the corresponding
66 --     sub-pattern;
67 --   * There are as many integer-indexed values in the tested term as in
68 --     the pattern, or there is a [...] at the end of the table pattern.
69 -- 
70 -- Pattern examples
71 -- ================
73 -- Pattern { 1, a } matches term { 1, 2 }, and binds [a] to [2].
74 -- It doesn't match term { 1, 2, 3 } (wrong number of parameters).
76 -- Pattern { 1, a, ... } matches term { 1, 2 } as well as { 1, 2, 3 }
77 -- (the trailing [...] suppresses the same-length condition)
78 -- 
79 -- `Foo{ a, { bar = 2, b } } matches `Foo{ 1, { bar = 2, "THREE" } }, 
80 -- and binds [a] to [1], [b] to ["THREE"] (the syntax sugar for [tag] fields
81 -- is available in patterns as well as in regular terms).
83 -- Implementation hints
84 -- ====================
86 -- Since the control flow quickly becomes hairy, it's implemented with
87 -- gotos and labels. [on_success] holds the label name where the
88 -- control flow must go when the currently parsed pattern
89 -- matches. [on_failure] is the next position to reach if the current
90 -- pattern mismatches: either the next pattern in a multiple-patterns
91 -- case, or the next case when parsing the last pattern of a case, or
92 -- the end of the match code for the last pattern of the last case.
94 -- [case_vars] is the list of variables created for the current
95 -- case. It's kept to generate the local variables declaration.
96 -- [pattern_vars] is also kept, to detect non-linear variables
97 -- (variables which appear more than once in a given pattern, and
98 -- therefore require an == test).
100 --------------------------------------------------------------------------------
102 -- TODO:
104 -- [CHECK WHETHER IT'S STILL TRUE AFTER TESTS INVERSION]
105 -- - Optimize jumps: the bytecode generated often contains several
106 --   [OP_JMP 1] in a row, which is quite silly. That might be due to the
107 --   implementation of [goto], but something can also probably be done
108 --   in pattern matching implementation.
110 ----------------------------------------------------------------------
112 ----------------------------------------------------------------------
113 -- Convert a tested term and a list of (pattern, statement) pairs
114 -- into a pattern-matching AST.
115 ----------------------------------------------------------------------
116 local function match_builder (tested_terms_list, cases)
118    local local_vars = { }
119    local var = |n| `Id{ "$v" .. n }
120    local on_failure -- current target upon pattern mismatch
122    local literal_tags = { String=1, Number=1, Boolean=1 }
124    local current_code -- list where instructions are accumulated
125    local pattern_vars -- list where local vars are accumulated
126    local case_vars    -- list where local vars are accumulated
128    -------------------------------------------------------------------
129    -- Accumulate statements in [current_code]
130    -------------------------------------------------------------------
131    local function acc (x) 
132       --printf ("%s", disp.ast (x))
133       table.insert (current_code, x) end
134    local function acc_test (it) -- the test must fail for match to succeeed.
135       acc +{stat: if -{it} then -{`Goto{ on_failure }} end } end
136    local function acc_assign (lhs, rhs)
137       local_vars[lhs[1]] = true
138       acc (`Let{ {lhs}, {rhs} }) end
140    -------------------------------------------------------------------
141    -- Set of variables bound in the current pattern, to find
142    -- non-linear patterns.
143    -------------------------------------------------------------------
144    local function handle_id (id, val)
145       assert (id.tag=="Id")
146       if id[1] == "_" then 
147          -- "_" is used as a dummy var ==> no assignment, no == checking
148          case_vars["_"] = true
149       elseif pattern_vars[id[1]] then 
150          -- This var is already bound ==> test for equality
151          acc_test +{ -{val} ~= -{id} }
152       else
153          -- Free var ==> bind it, and remember it for latter linearity checking
154          acc_assign (id, val) 
155          pattern_vars[id[1]] = true
156          case_vars[id[1]]    = true
157       end
158    end
160    -------------------------------------------------------------------
161    -- Turn a pattern into a list of tests and assignments stored into
162    -- [current_code]. [n] is the depth of the subpattern into the
163    -- toplevel pattern; [pattern] is the AST of a pattern, or a
164    -- subtree of that pattern when [n>0].
165    -------------------------------------------------------------------
166    local function pattern_builder (n, pattern)
167       local v = var(n)
168       if literal_tags[pattern.tag]  then acc_test +{ -{v} ~= -{pattern} }
169       elseif "Id"    == pattern.tag then handle_id (pattern, v)
170       elseif "Op"    == pattern.tag and "div" == pattern[1] then
171          local n2 = n>0 and n+1 or 1
172          local _, regexp, sub_pattern = unpack(pattern)
173          if sub_pattern.tag=="Id" then sub_pattern = `Table{ sub_pattern } end
174          -- Sanity checks --
175          assert (regexp.tag=="String", 
176                  "Left hand side operand for '/' in a pattern must be "..
177                  "a literal string representing a regular expression")
178          assert (sub_pattern.tag=="Table",
179                  "Right hand side operand for '/' in a pattern must be "..
180                  "an identifier or a list of identifiers")
181          for x in ivalues(sub_pattern) do
182             assert (x.tag=="Id" or x.tag=='Dots',
183                  "Right hand side operand for '/' in a pattern must be "..
184                  "a list of identifiers")
185          end
187          -- Can only match strings
188          acc_test +{ type(-{v}) ~= 'string' }
189          -- put all captures in a list
190          local capt_list  = +{ { string.strmatch(-{v}, -{regexp}) } }
191          -- save them in a var_n for recursive decomposition
192          acc +{stat: local -{var(n2)} = -{capt_list} }
193          -- was capture successful?
194          acc_test +{ not next (-{var(n2)}) }
195          pattern_builder (n2, sub_pattern)
196       elseif "Table" == pattern.tag then
197          local seen_dots, len = false, 0
198          acc_test +{ type( -{v} ) ~= "table" } 
199          for i = 1, #pattern do
200             local key, sub_pattern
201             if pattern[i].tag=="Key" then -- Explicit key
202                key, sub_pattern = unpack (pattern[i])
203                assert (literal_tags[key.tag], "Invalid key")
204             else -- Implicit key
205                len, key, sub_pattern = len+1, `Number{ len+1 }, pattern[i]
206             end
207             assert (not seen_dots, "Wrongly placed `...' ")
208             if sub_pattern.tag == "Id" then 
209                -- Optimization: save a useless [ v(n+1)=v(n).key ]
210                handle_id (sub_pattern, `Index{ v, key })
211                if sub_pattern[1] ~= "_" then 
212                   acc_test +{ -{sub_pattern} == nil } 
213                end
214             elseif sub_pattern.tag == "Dots" then
215                -- Remember to suppress arity checking
216                seen_dots = true
217             else
218                -- Business as usual:
219                local n2 = n>0 and n+1 or 1
220                acc_assign (var(n2), `Index{ v, key })
221                pattern_builder (n2, sub_pattern)
222             end
223          end
224          if not seen_dots then -- Check arity
225             acc_test +{ #-{v} ~= -{`Number{len}} }
226          end
227       else 
228          error ("Invalid pattern: "..table.tostring(pattern, "nohash"))
229       end
230    end
232    local end_of_match = mlp.gensym "_end_of_match"
233    local arity = #tested_terms_list
234    local x = `Local{ { }, { } }
235    for i=1,arity do 
236       x[1][i]=var(-i)
237       x[2][i]= tested_terms_list[i]
238    end
239    local complete_code = `Do{ x }
241    -- Foreach [{patterns, guard, block}]:
242    for i = 1, #cases do
243       local patterns, guard, block = unpack (cases[i])
244    
245       -- Reset accumulators
246       local local_decl_stat = { }
247       current_code = `Do{ `Local { local_decl_stat, { } } } -- reset code accumulator
248       case_vars = { }
249       table.insert (complete_code, current_code)
251       local on_success = mlp.gensym "_on_success" -- 1 success target per case
253       -----------------------------------------------------------
254       -- Foreach [pattern] in [patterns]:
255       -- on failure go to next pattern if any, 
256       -- next case if no more pattern in current case.
257       -- on success (i.e. no [goto on_failure]), go to after last pattern test
258       -- if there is a guard, test it before the block: it's common to all patterns,
259       -----------------------------------------------------------
260       for j = 1, #patterns do
261          if #patterns[j] ~= arity then 
262             error( "Invalid match: pattern has only "..
263                    #patterns[j].." elements, "..
264                    arity.." were expected")
265          end
266          pattern_vars = { }
267          on_failure = mlp.gensym "_on_failure" -- 1 failure target per pattern
268          
269          for k = 1, arity do pattern_builder (-k, patterns[j][k]) end
270          if j<#patterns then 
271             acc (`Goto{on_success}) 
272             acc (`Label{on_failure}) 
273          end
274       end
275       acc (`Label{on_success})
276       if guard then acc_test (`Op{ "not", guard}) end
277       acc (block)
278       acc (`Goto{end_of_match}) 
279       acc (`Label{on_failure})
281       -- fill local variables declaration:
282       local v1 = var(1)[1]
283       for k, _ in pairs(case_vars) do
284          if k[1] ~= v1 then table.insert (local_decl_stat, `Id{k}) end
285       end
287    end
288    acc +{error "mismatch"} -- cause a mismatch error after last case failed
289    table.insert(complete_code, `Label{ end_of_match })
290    return complete_code
293 ----------------------------------------------------------------------
294 -- Sugar: add the syntactic extension that makes pattern matching
295 --        pleasant to read and write.
296 ----------------------------------------------------------------------
298 mlp.lexer:add{ "match", "with", "->" }
299 mlp.block.terminators:add "|"
301 mlp.stat:add{ name = "match statement",
302    "match", mlp.expr_list, "with",
303    gg.optkeyword "|",
304    gg.list{ name = "match cases list",
305       primary     = gg.sequence{ name = "match case",
306          gg.list{ name = "patterns",
307             primary = mlp.expr_list,
308             separators = "|",
309             terminators = { "->", "if" } },
310          gg.onkeyword{ "if", mlp.expr, consume = true },
311          "->",
312          mlp.block },
313       separators  = "|",
314       terminators = "end" },
315    "end",
316    builder = |x| match_builder (x[1], x[3]) }