move aux functions out of the constructor
[lisp-parkour.git] / node.lua
blob235105a8d4679a8afee4ffccdfa9b8438e87c52d
-- Bind LPeg through the value returned by `require` instead of reading the
-- `lpeg` global afterwards: LPeg builds for Lua 5.2+ do not necessarily
-- create that global, which would leave `l` nil.
local l = require'lpeg'
local P, S, V, Cc, Cmt, Cp, Ct = l.P, l.S, l.V, l.Cc, l.Cmt, l.Cp, l.Ct

-- whitespace classes; words inside strings/comments are runs of anything else
local hspace = S" \t"
local newline = S"\n\r"
-- LPeg match-time predicate (used with Cmt/Cc): succeeds while the current
-- match position has not moved beyond `pos`.
local function past(_, position, pos)
	return pos >= position
end
local M = {}

-- Key extractors passed to the binary searches below.
local function startof(node)
	local first = node.start
	return first
end

local function finishof(node)
	local last = node.finish
	return last
end
14 function M.new(read)
-- LPeg match-time callback used by `word_at`. `start` and `finish` are the
-- positions captured before and after a word; `range` holds offsets relative
-- to the matched text. When range.start + 1 >= start and both range.start and
-- range.finish are below `finish`, the match succeeds at `position` and yields
-- `start - 1` and `finish - 1`; otherwise nothing is returned and the Cmt fails.
local function at_pos(_, position, start, finish, range)
	local accepted = start <= range.start + 1
		and range.start < finish
		and range.finish < finish
	if not accepted then
		return
	end
	return position, start - 1, finish - 1
end
-- `pos > val`, except that a missing `val` (nil/false) is passed through
-- unchanged and therefore counts as false.
local function gt(pos, val)
	return val and val < pos
end
-- `pos <= val`, except that a missing `val` (nil/false) is passed through
-- unchanged and therefore counts as false.
local function le(pos, val)
	return val and val >= pos
end
-- binary search a list for the nearest node before pos
-- Looks for the element whose key is < pos while the following element's key
-- (if there is one) is >= pos.
--   key:  function(node) returning the number to compare (e.g. startof/finishof);
--         a nil key never satisfies the comparisons (see gt/le)
--   skip: optional predicate; while it holds, step to the previous element
-- Returns the element and its index, or nothing when no element qualifies.
local function before(t, pos, key, skip)
	local lo, hi = 1, #t
	while lo <= hi do
		local mid = math.floor(lo + (hi - lo) / 2)
		local mid_key = key(t[mid])
		local hit = false
		if gt(pos, mid_key) then
			hit = mid == #t or le(pos, key(t[mid + 1]))
		end
		if hit then
			local idx = mid
			while skip and t[idx] and skip(t[idx]) do
				idx = idx - 1
			end
			return t[idx], idx
		end
		if le(pos, mid_key) then
			hi = mid - 1
		else
			lo = mid + 1
		end
	end
end
-- binary search a list for the nearest node after pos
-- Looks for the element whose key is >= pos while the preceding element's key
-- (if there is one) is < pos.
--   key:  function(node) returning the number to compare (e.g. startof)
--   skip: optional predicate; while it holds, step to the next element
-- Returns the element and its index, or nothing when no element qualifies.
local function after(t, pos, key, skip)
	if t.is_root then
		local count = #t
		if count == 0 or pos >= t[count].start then
			-- XXX: if we are at the last parsed top-level node, try to access the next one
			-- to get it parsed as well. Otherwise the search will stop prematurely.
			local _ = t[count + 1]
		end
	end
	local lo, hi = 1, #t
	while lo <= hi do
		local mid = math.floor(lo + (hi - lo) / 2)
		local mid_key = key(t[mid])
		local hit = false
		if le(pos, mid_key) then
			hit = mid == 1 or gt(pos, key(t[mid - 1]))
		end
		if hit then
			local idx = mid
			while skip and t[idx] and skip(t[idx]) do
				idx = idx + 1
			end
			return t[idx], idx
		end
		if le(pos, mid_key) then
			hi = mid - 1
		else
			lo = mid + 1
		end
	end
end
-- binary search a list for the node that contains pos
-- A node matches when range.start >= node.start and
-- range.finish <= node.finish + 1. If the next node starts at or before
-- range.start it is preferred instead - asdf|(qwer) resolves to (qwer).
-- Returns the node and its index, or nothing (also when `range` lacks
-- start or finish).
local function around(t, range)
	if not range.start or not range.finish then
		return
	end
	local lo, hi = 1, #t
	while lo <= hi do
		local mid = math.floor(lo + (hi - lo) / 2)
		local node = t[mid]
		local following = t[mid + 1]
		local inside = node.start and range.start >= node.start
			and node.finish and range.finish <= node.finish + 1
		if inside and (not following or following.start > range.start) then
			return node, mid
		end
		if node.start and range.start < node.start then
			hi = mid - 1
		else
			lo = mid + 1
		end
	end
end
-- Starting from the first child at or after selection.start (via t.after),
-- walk forward and return the first child satisfying `pred` with its index.
-- Returns nothing when `stop` (optional) accepts a child first or the
-- children run out.
local function find_after(t, selection, pred, stop)
	local _, idx = t.after(selection.start, startof)
	local node = t[idx]
	while node do
		if pred(node) then
			return node, idx
		end
		if stop and stop(node) then
			return
		end
		idx = idx + 1
		node = t[idx]
	end
end
-- Starting from the last child ending before selection.finish (via t.before),
-- walk backward and return the first child satisfying `pred` with its index.
-- Returns nothing when `stop` (optional) accepts a child first or the
-- children run out.
local function find_before(t, selection, pred, stop)
	local _, idx = t.before(selection.finish, finishof)
	local node = t[idx]
	while node do
		if pred(node) then
			return node, idx
		end
		if stop and stop(node) then
			return
		end
		idx = idx - 1
		node = t[idx]
	end
end
-- True when `range` partially overlaps node `t` (measured from after its
-- prefix `t.p`): either the node starts inside [range.start, range.finish)
-- and finishes at/after range.finish, or it starts before range.start and
-- finishes in [range.start, range.finish).
local function intersects(t, range)
	local body = t.start + (t.p and #t.p or 0)
	if body >= range.start then
		return t.finish >= range.finish and body < range.finish
	end
	return t.finish < range.finish and range.start <= t.finish
end
-- True when range.start lies between t.start and the end of its prefix `t.p`
-- (both inclusive), or when the node finishes right before range.finish.
local function touches(t, range)
	local on_opening = range.start <= t.start + (t.p and #t.p or 0)
		and range.start >= t.start
	return on_opening or t.finish == range.finish - 1
end
-- Does node `t` contain `range`?
-- Lists: strictly inside, i.e. range.start after the prefix-adjusted start and
-- range.finish - 1 before t.finish. Other nodes: inclusive on both ends.
local function contains(t, range)
	if not t.is_list then
		return t.start <= range.start and t.finish >= range.finish - 1
	end
	return t.start + (t.p and #t.p or 0) < range.start and t.finish > range.finish - 1
end
-- Descend from `t` (the root, or a list that contains `range`) to the
-- innermost child around `range`. The descent does not enter a child list that
-- `touches` the range. Returns child, parent, index (child/index may be nil);
-- returns nothing if `t` is neither the root nor a list containing `range`.
local function sexp_around(t, range)
	local node = t
	while node.is_root or (node.is_list and contains(node, range)) do
		local child, nth = node.around(range)
		if not (child and child.is_list and not touches(child, range)) then
			return child, node, nth
		end
		node = child
	end
end
-- save the current logical position in terms of sexps, not offsets
-- Walks down from `t` towards `range`, appending at each level the index of
-- the child around the range to `indices` and the child itself to `nodes`.
-- The walk continues into a child list unless `touches(child, range)`; when a
-- level has no child around the range, `false` is appended to both lists.
-- Returns indices, nodes.
local function _sexp_path(t, range, indices, nodes)
	local node = t
	while node.is_root or (node.is_list and contains(node, range)) do
		local child, nth = node.around(range)
		indices[#indices + 1] = nth
		nodes[#nodes + 1] = child
		if not child then
			indices[#indices + 1] = false
			nodes[#nodes + 1] = false
			break
		end
		if not child.is_list or touches(child, range) then
			break
		end
		node = child
	end
	return indices, nodes
end
-- The sexp path of `range` within `t`, collected into two fresh lists
-- (child indices and child nodes).
local function sexp_path(t, range)
	local indices, nodes = {}, {}
	return _sexp_path(t, range, indices, nodes)
end
-- find a sexp by following a previously-saved "sexp path"
-- Indexes from `root` with each element of `path` in turn. Returns the node
-- reached, its parent and the last index used; an empty path yields
-- root, nil, nil.
local function goto_path(root, path)
	local parent, nth
	local dest = root
	for _, index in ipairs(path) do
		parent = dest
		dest = dest[index]
		nth = index
	end
	return dest, parent, nth
end
-- If range.finish is set and lies at or past tree.first_invalid, ask the tree
-- to parse up to range.finish + 1; otherwise do nothing.
local function catchup(tree, range)
	local finish = range.finish
	if not finish or finish < tree.first_invalid then
		return
	end
	tree.parse_to(finish + 1)
end
-- Turn `func(t, range, ...)` into a root-method factory: given the root `t`,
-- the returned function first runs `catchup` so the tree is parsed far enough,
-- then forwards all its arguments to `func`.
local function ensure_reparse(func)
	return function(t)
		return function(range, ...)
			-- TODO: get rid of this type check. It is here because .after and .before take pos, not range
			local wanted = range
			if type(range) == "number" then
				wanted = {finish = range}
			end
			catchup(t, wanted)
			return func(t, range, ...)
		end
	end
end
-- Search methods of the root node; root_new wraps each one with ensure_reparse.
local root_methods = {}
root_methods.around = around
root_methods.before = before
root_methods.after = after
root_methods.find_after = find_after
root_methods.find_before = find_before
root_methods.sexp_at = sexp_around
root_methods.sexp_path = sexp_path
-- Curry `func` on its first argument: bind(func)(t) returns a function that
-- calls func(t, ...) with whatever arguments it receives.
local function bind(func)
	local function bind_to(t)
		local function bound(...)
			return func(t, ...)
		end
		return bound
	end
	return bind_to
end
-- Methods of list nodes. `dispatch` calls each entry with the node, so
-- is_list/is_empty read as plain values and the bound searches as functions.
local list_methods = {
	is_list = function() return true end,
	is_empty = function(t) return #t == 0 end,
}
list_methods.around = bind(around)
list_methods.before = bind(before)
list_methods.after = bind(after)
list_methods.find_after = bind(find_after)
list_methods.find_before = bind(find_before)
-- Method tables for atoms and for "quasi atoms" (a word inside a string or comment).
local atom_methods, quasiatom_methods = {}, {}
-- Source text of a node: `read` from t.start for (t.finish - t.start + 1)
-- units, i.e. t.finish is treated as inclusive.
local function text(t)
	return read(t.start, t.finish - t.start + 1)
end
-- atoms, lists and quasi atoms all expose their source text the same way
for _, vtable in ipairs({atom_methods, list_methods, quasiatom_methods}) do
	vtable.text = text
end
-- Build an __index metamethod: a lookup of `key` calls vtable[key](self) and
-- yields its result; keys absent from `vtable` yield nothing.
local function dispatch(vtable)
	return function(self, key)
		local getter = vtable[key]
		if not getter then
			return
		end
		return getter(self)
	end
end
-- Metatable for quasi atoms (the word nodes built by word_at).
local quasiatom_node = {}
quasiatom_node.__index = dispatch(quasiatom_methods)
-- Word-wise navigation inside a "quasi list" (a string or comment). The inner
-- text `t.itext` is matched with LPeg; positions taken and returned are
-- absolute offsets, converted to/from 1-based indices into `t.itext`.
local quasilist_methods = (function(word) -- "quasi list" is a string or comment
	-- Absolute offset of the first character of `t.itext`. It must skip the
	-- prefix `t.p` as well as the opening delimiter, exactly as `itext` and
	-- `quasilist_is_empty` do; using only `t.start + #t.d` shifted every result
	-- by #t.p for prefixed strings/comments.
	local function inner_start(t)
		return t.start + (t.p and #t.p or 0) + #t.d
	end

	return {
		-- offset of the last word start that lies before `pos`, or nil
		start_before = function(t, pos)
			local base = inner_start(t)
			pos = pos - base
			local stops = P{Ct(((1 - word)^0 * Cp() * Cmt(Cc(pos), past) * word)^0)}:match(t.itext)
			local newpos = stops and stops[#stops]
			return newpos and (base + newpos - 1)
		end,

		-- offset of the next separator-preceded word start after `pos`; when
		-- `pos` precedes the inner text, the first such word or the inner start
		start_after = function(t, pos)
			local base = inner_start(t)
			if pos >= base then
				local stop = P{(1 - word) * Cp() * word + 1 * V(1)}:match(t.itext, pos - base + 1)
				return stop and (base + stop - 1)
			else
				local stop = P{(1 - word)^1 * Cp() * word}:match(t.itext)
				return stop and (base + stop - 1) or base
			end
		end,

		-- offset following the last word end that lies before `pos`, or nil
		finish_before = function(t, pos)
			local base = inner_start(t)
			pos = pos - base
			local stops = P{Ct(((1 - word)^0 * word * Cp() * Cmt(Cc(pos), past))^0)}:match(t.itext)
			local newpos = stops and stops[#stops]
			return newpos and (base + newpos - 1)
		end,

		-- offset following the end of the first word matched at or after `pos`
		finish_after = function(t, pos)
			local base = inner_start(t)
			local stop = P{word * Cp() + 1 * V(1)}:match(t.itext, (pos > base and pos - base + 1 or nil))
			return stop and (base + stop - 1)
		end,

		-- the word (as a quasi atom node) accepted by at_pos for `range`, or nil
		word_at = function(t, range)
			local base = inner_start(t)
			local r = {}
			r.start = range.start - base
			r.finish = range.finish - base
			local start, finish = P{Cmt(Cp() * word * Cp() * Cc(r), at_pos) +
				Cmt(1 * Cc(r.finish + 1), past) * V(1)}:match(t.itext)
			if not start then return end
			local node = {start = base + start, finish = base + finish - 1}
			return setmetatable(node, quasiatom_node)
		end,
	}
end)(P{"\\" * P(1 - hspace - newline) + P(1 - hspace - newline)^1})
-- A string/comment is empty when no word ends anywhere after the start of
-- its inner text (past prefix `t.p` and opening delimiter `t.d`).
local function quasilist_is_empty(t)
	local inner_start = t.start + (t.p and #t.p or 0) + #t.d
	local first_word_end = t.finish_after(inner_start)
	return not first_word_end
end
-- Build the metatable for strings/comments. `opposite[t.d]` is the closing
-- delimiter whose length is excluded from the inner text `itext`.
local function quasilist_new(opposite)
	local vtable = {}
	for name, impl in pairs(quasilist_methods) do
		vtable[name] = bind(impl)
	end
	vtable.is_empty = quasilist_is_empty
	vtable.text = text

	-- text between the opening delimiter (after prefix `t.p`) and the closing
	-- one; "" when `read` returns nothing
	vtable.itext = function(t)
		local from = t.start + (t.p and #t.p or 0) + #t.d
		local closing = opposite[t.d]
		return read(from, t.finish + 1 - #closing - from) or ""
	end

	return {__index = dispatch(vtable)}
end
-- Metatable for atom nodes.
local atom_node = {}
atom_node.__index = dispatch(atom_methods)
-- Metatable for list nodes.
local list_node = {}
list_node.__index = dispatch(list_methods)
-- Returns a function that moves the root's `first_invalid` mark back to
-- `index`. Moving it forward is an error.
local function root_rewind(t)
	return function(index)
		-- Build the message only on failure: formatting it on every call is
		-- wasted work, and "%d" rejects non-integral numbers on Lua 5.3+ even
		-- when the check itself passes.
		if index > t.first_invalid then
			error(("%d > %d - You can only rewind backwards!"):format(index, t.first_invalid))
		end
		t.first_invalid = index
	end
end
-- Returns a predicate: is offset `index` before the root's `first_invalid` mark?
local function root_is_parsed(t)
	local function is_parsed(index)
		return t.first_invalid > index
	end
	return is_parsed
end
-- Returns a function that resolves a path recorded by `sexp_path` against the
-- root `t`. It first catches the parse up to the end of the top-level node
-- the path starts with.
-- A path that begins with `false` (sexp_path found no top-level node around
-- the range), or with an index that has no node, skips the catch-up instead
-- of indexing a nil node.
local function root_goto_path(t)
	return function(path)
		local first = path[1] and t[path[1]]
		if first then
			catchup(t, {start = first.finish, finish = first.finish})
		end
		return goto_path(t, path)
	end
end
-- Build the method table for the root node. `opposite[t.d]` is the closing
-- delimiter of a node opened with `t.d`. Each entry of root_methods is wrapped
-- with ensure_reparse; rewind/is_parsed/goto_path/unbalanced_delimiters take
-- the root and return a closure.
local function root_new(opposite)
	local methods = {}
	for name, impl in pairs(root_methods) do
		methods[name] = ensure_reparse(impl)
	end

	methods.rewind = root_rewind
	methods.is_parsed = root_is_parsed
	methods.goto_path = root_goto_path

	-- Walk `t` (the root, or a node with a delimiter `t.d`), children first,
	-- into every delimited child that contains or intersects `range`, and
	-- coroutine.yield the delimiters that `range` cuts through:
	--   {start, finish, closing = true} when the range starts inside the body
	--     and runs past its end;
	--   {start, finish, opening = true} when the range starts before the body
	--     and ends within the node.
	local function unbalanced_delimiters(t, range)
		if t.d or t.is_root then
			for i = 1, #t do
				if t[i].d and (contains(t[i], range) or intersects(t[i], range)) then
					unbalanced_delimiters(t[i], range)
				end
			end
			if not t.is_root then
				local start = t.start + (t.p and #t.p or 0) + #t.d
				local finish = t.finish + 1 - #opposite[t.d]
				if start <= range.start and finish < range.finish then
					coroutine.yield({start = finish, finish = t.finish + 1, closing = true})
				elseif start > range.start and t.finish >= range.finish then
					coroutine.yield({start = t.start, finish = start, opening = true})
				end
			end
		end
	end

	-- Returns a function mapping a range to the list of delimiter slices
	-- reported by unbalanced_delimiters (after catching the parse up).
	function methods.unbalanced_delimiters(t)
		return function(range)
			local skips = {}
			catchup(t, range)
			-- A fresh coroutine per call: one created up front is dead after
			-- its first full traversal, so reusing the returned function would
			-- raise "cannot resume dead coroutine".
			for slice in coroutine.wrap(function() unbalanced_delimiters(t, range) end) do
				skips[#skips + 1] = slice
			end
			return skips
		end
	end

	return methods
end
-- Result of M.new: metatables for atom and list nodes, the factories for
-- quasi-list metatables and root method tables (both take `opposite`), and
-- two internal search helpers under underscore-prefixed names.
local exports = {
	atom = atom_node,
	list = list_node,
	quasilist = quasilist_new,
	root = root_new,
}
exports._before = before
exports._around = around
return exports
end

return M