yuck
[vis-pairs.git] / init.lua
blob27cf638fbe585e94d6f11001e7a29d7036aa2540
1 -- SPDX-License-Identifier: GPL-3.0-or-later
2 -- © 2020 Georgi Kirilov
4 require("vis")
5 local vis = vis
7 local l = require("lpeg")
9 local progname = ...
11 -- XXX: in Lua 5.2 unpack() was moved into table
12 local unpack = table.unpack or unpack
-- Forward declaration of the module table; it is assigned near the end of the
-- file, after the helpers that close over it have been defined.
local M

--- Built-in delimiter pairs, keyed by the character that selects them.
-- Each entry is {opening, closing}, optionally with a `name` used as help text.
local builtin_textobjects = {
	["("] = {"(", ")"},
	["<"] = {"<", ">"},
	["["] = {"[", "]"},
	["{"] = {"{", "}"},
	["`"] = {"`", "`", name = "A backtick delimited string"},
	["'"] = {"'", "'", name = "A single quoted string"},
	['"'] = {'"', '"', name = "A quoted string"},
}

--- Built-in motion targets, keyed first by motion prefix and then by key.
-- The pair tables are shared with builtin_textobjects, not copied.
local builtin_motions = {}
builtin_motions["["] = {["("] = builtin_textobjects["("], ["{"] = builtin_textobjects["{"]}
builtin_motions["]"] = {[")"] = builtin_textobjects["("], ["}"] = builtin_textobjects["{"]}

--- Keys that select the same built-in pair as another key
-- (closing brackets and the b / B shorthands).
local alias = {
	b = "(",
	B = "{",
	[")"] = "(",
	["]"] = "[",
	["}"] = "{",
	[">"] = "<",
}
--- Look up the delimiter descriptor bound to `key` in window `win`.
-- Search order: the window's per-syntax map, the generic map (M.map[1]), the
-- built-in pairs, the built-in pair behind an alias, and finally - for any key
-- that is not alphanumeric - the key paired with itself.
-- @return a descriptor table, or a falsy value when nothing applies
local function get_pair(key, win)
	local by_syntax, generic = M.map[win.syntax], M.map[1]
	local found = by_syntax and by_syntax[key]
		or generic and generic[key]
		or builtin_textobjects[key]
		or builtin_textobjects[alias[key]]
	if found then
		return found
	end
	return not key:match("%w") and {key, key}
end
--- Keep a capture table only if it encloses the selection `pos`.
-- `t` is a list of LPeg capture positions whose first and last entries are the
-- outer bounds of a match; `pos` carries `start` / `finish` fields.
-- @return `t` when it encloses `pos`, otherwise nothing
local function at_pos(t, pos)
	local first, last = t[1], t[#t]
	if first <= pos.start + 1 and pos.finish < last then
		return t
	end
end
--- Build a pattern for a pair whose opening and closing delimiter are the same `d`.
-- For a single-character `d`, an opening delimiter preceded by a backslash is
-- rejected and backslash-escaped characters inside the pair are stepped over.
-- `escaped` (comment/string rules, may be nil) is consumed as a unit, so
-- delimiters inside it are ignored. The resulting capture table holds four
-- positions and survives only if it encloses `pos` (see at_pos).
local function asymmetric(d, escaped, pos)
	local I = l.Cp()
	local any = l.P(1)
	if escaped then
		any = escaped + any
	end
	local body
	if #d == 1 then
		body = (d - l.B"\\") * I * ("\\" * l.P(1) + (any - d))^0 * I * d
	else
		body = d * I * (any - d)^0 * I * d
	end
	return l.Ct(I * body * I) * l.Cc(pos) / at_pos
end
--- Build a pattern for a nestable pair with distinct delimiters `d1` / `d2`.
-- The grammar recurses through l.V(1) so inner pairs nest; `escaped`
-- (comment/string rules, may be nil) is consumed as a unit. Each level yields
-- a capture table {p1, p2, [inner], p3, p4} that survives only if it encloses
-- `pos` (see at_pos), so at most one enclosing inner table is kept per level.
local function symmetric(d1, d2, escaped, pos)
	local I = l.Cp()
	local any = l.P(1)
	if escaped then
		any = escaped + any
	end
	local filler = any - d1 - d2
	local level = l.Ct(I * d1 * I * (filler + l.V(1))^0 * I * d2 * I) * l.Cc(pos) / at_pos
	return l.P{level}
end
--- Pick the `count`-th innermost pair out of a nested capture hierarchy.
-- A level with five entries {p1, p2, inner, p3, p4} has an enclosing inner
-- level at index 3; other levels are leaves.
-- @return opening {p1, p2}, closing {p3, p4}, and - only when the hierarchy is
--   shallower than `count` - the number of levels still missing
local function nth_innermost(t, count)
	local start, finish, remaining = 0, 0, count
	if #t == 5 then
		start, finish, remaining = nth_innermost(t[3], count)
	end
	if not remaining then
		return start, finish
	end
	local left_over = nil
	if remaining > 1 then
		left_over = remaining - 1
	end
	return {t[1], t[2]}, {t[#t - 1], t[#t]}, left_over
end
--- For a position inside a region of the given lexer kind, the other kinds
-- that must still be skipped while searching: inside a comment, strings are
-- still skipped; inside a string, nothing is.
local precedence = {
	[vis.lexers.STRING] = {},
	[vis.lexers.COMMENT] = {vis.lexers.STRING},
}
--- Return the range of the selection in `win` whose cursor is at `pos`.
-- @return the selection's `range`, or nothing if no selection sits at `pos`
local function selection_range(win, pos)
	local found
	for sel in win:selections_iterator() do
		if sel.pos == pos then
			found = sel
			break
		end
	end
	if found then
		return found.range
	end
end
-- Shared scan state for the LPeg match-time callbacks below. It holds the end
-- position of the last complete match that was rejected (it did not enclose the
-- cursor). Callers reset it to 0 before each scan.
local prev_match

--- Match-time callback that accepts a hierarchy table and records rejections.
-- `t` is either the capture table that survived at_pos, or the numeric
-- abutment allowance (0 or 1) when at_pos rejected the match.
-- @return `position, t` for an accepted table; otherwise nothing (match fails)
local function any_captures(_, position, t)
	if type(t) ~= "table" then
		if t then
			prev_match = position - t
		end
		return
	end
	return position, t
end

--- Match-time callback that advances the scan past the last rejected match,
-- but never beyond `pos`.
-- @return the new position, or false to stop scanning
local function not_past(_, position, pos)
	local newpos = position
	if prev_match > position then
		newpos = prev_match
	end
	if newpos > pos then
		return false
	end
	return newpos
end
--- Find the match of `pattern` in `str` that encloses the selection `pos`.
-- At each position the scan tries `pattern`. A complete match that does not
-- enclose `pos` (see at_pos) is recorded in prev_match, and the scan jumps past
-- it. The scan stops once it would move beyond `pos.start + 1`.
-- @return the enclosing match's first and last LPeg positions minus one, or nothing
local function match_at(str, pattern, pos)
	prev_match = 0
	local I = l.Cp()
	local candidate = l.Ct(I * (pattern / 0) * I) * l.Cc(pos) / at_pos
	local try_here = l.Cmt(candidate * l.Cc(0), any_captures)
	local step = 1 * l.Cmt(l.Cc(pos.start + 1), not_past) * l.V(1)
	local search = l.P{try_here + step}
	local found = search:match(str)
	if not found then
		return
	end
	return found[1] - 1, found[#found] - 1
end
--- Build the grammar rule name for the lexer rule `name`: "<lexer name>.<name>".
local function rule_id(lexer, name)
	local prefix = lexer._name
	return prefix .. "." .. name
end
--- Fetch lexer rule `id`, returning nil instead of asserting when it is missing.
-- Proxy lexers are resolved to their `_lexer` parent. Both the `_RULES` and the
-- `_rules` table names are accepted. "whitespace" is special-cased to a grammar
-- reference (l.V) named via rule_id.
local function get_rule(lexer, id)
	local parent = lexer._lexer or lexer
	if id == "whitespace" then
		return l.V(rule_id(parent, id))
	end
	local rules = parent._RULES or parent._rules
	return rules[id]
end
--- Work out which comment/string rules to skip while searching for delimiters,
-- and whether the selection `range` lies inside such a region of `data`.
-- @param lexer  lexer loaded for the window's syntax
-- @param range  the current selection range, or nil when retrying with a
--               synthetic position
-- @param data   the whole file content
-- @return a table with:
--   `escape` - LPeg pattern of regions to skip (nil when nothing applies)
--   `range`  - {first, last} 1-based inclusive bounds of the comment/string
--              region that encloses `range`, if any
local function escaping_context(lexer, range, data)
	local kinds = {vis.lexers.COMMENT, vis.lexers.STRING}
	local any_escape
	for _, name in ipairs(kinds) do
		local rule = get_rule(lexer, name)
		if rule then
			any_escape = any_escape and any_escape + rule / 0 or rule / 0
		end
	end
	if not any_escape then return {} end
	if not range then return {escape = any_escape} end -- means we are retrying with a "fake" pos
	local e1, e2 = match_at(data, any_escape, range)
	if not (e1 and e2) then return {escape = any_escape} end
	local escaped_range = {e1 + 1, e2}
	local escaped_data = data:sub(e1 + 1, e2)
	for _, level in ipairs(kinds) do
		-- A lexer may define only one of the two kinds; skip the missing one
		-- instead of doing arithmetic on nil.
		local level_rule = get_rule(lexer, level)
		if level_rule and l.match(level_rule / 0 * -1, escaped_data) then
			-- Inside this region, only the kinds that outrank it are still skipped.
			local inner_escape
			for _, name in ipairs(precedence[level]) do
				local rule = get_rule(lexer, name)
				if rule then
					inner_escape = inner_escape and inner_escape + rule / 0 or rule / 0
				end
			end
			return {escape = inner_escape, range = escaped_range}
		end
	end
	-- The enclosing region could not be attributed to either kind on its own.
	-- Treat the position as unescaped (the same answer as when no region
	-- encloses it) instead of returning nil, which the caller would index.
	return {escape = any_escape}
end
--- Locate the pair bound to `key` that encloses byte offset `pos` in `win`.
-- If the cursor's selection is inside a comment/string (escaping_context), the
-- search is first confined to that region. When nothing is found there, the
-- search is repeated from the byte just before the region.
-- @param file_data whole file content
-- @param count     which enclosing level to return (1 = innermost, default 1)
-- @return opening and closing {from, to} tables (0-based offsets), or nothing
local function get_range(key, win, pos, file_data, count)
	if not win.syntax then return end
	local delims = get_pair(key, win)
	if not delims then return end
	local lexer = vis.lexers.load(win.syntax)
	while true do
		local sel = selection_range(win, pos)
		local ctx = escaping_context(lexer, sel, file_data)
		-- Offsets inside the searched region are relative to its first byte.
		local shift = (ctx.range or {1, #file_data})[1] - 1
		pos = pos - shift
		if sel then
			sel.start = sel.start - shift
			sel.finish = sel.finish - shift
		else
			sel = {start = pos + 1, finish = pos + 2}
		end
		local block
		if delims[1] ~= delims[2] then
			block = symmetric(delims[1], delims[2], ctx.escape, sel)
		else
			block = asymmetric(delims[1], ctx.escape, sel)
		end
		-- Ad-hoc single-character self-pairs (not from any map) may start right
		-- where a rejected match ended.
		local can_abut = false
		if delims[1] == delims[2] and #delims[1] == 1 then
			can_abut = not (builtin_textobjects[key] or M.map[1][key] or M.map[win.syntax] and M.map[win.syntax][key])
		end
		local skip = ctx.escape and ctx.escape + 1 or 1
		local data = file_data
		if ctx.range then
			data = file_data:sub(unpack(ctx.range))
		end
		local search = l.P{l.Cmt(block * l.Cc(can_abut and 1 or 0), any_captures) + skip * l.Cmt(l.Cc(pos + 1), not_past) * l.V(1)}
		prev_match = 0
		local hierarchy = search:match(data)
		if hierarchy then
			local offsets = {nth_innermost(hierarchy, count or 1)}
			offsets[3] = nil -- a leftover from calling nth_innermost() with count higher than the hierarchy depth.
			for _, pair in ipairs(offsets) do
				for i, v in ipairs(pair) do
					pair[i] = v - 1 + shift
				end
			end
			return unpack(offsets)
		end
		pos = shift - 1
		if pos < 0 then
			return
		end
	end
end
--- Folding function for lpeg.Cf: keep the first value in slot 1 and overwrite
-- slot 2 with every later value, so the accumulator ends as {first, last}.
local function keep_last(acc, cur)
	local slot = (#acc == 0) and 1 or 2
	acc[slot] = cur
	return acc
end
--- In VISUAL_LINE mode, narrow an inner range to line boundaries.
-- Scans `content` from byte `start + 1` while skipping the window lexer's
-- comment/string rules. It records the LPeg position of the first and of the
-- last "\n" it meets (keep_last), and stops once it would step past `finish`
-- (not_past). Outside VISUAL_LINE mode the range is returned unchanged.
-- NOTE(review): with fewer than two line breaks in the scanned span, `finish`
-- (and possibly `start`) comes back nil; that behaviour is kept as-is here.
-- @return start, finish
local function barf_linewise(win, content, start, finish)
	if vis.mode ~= vis.modes.VISUAL_LINE then
		return start, finish
	end
	local skip
	if win.syntax then
		-- get_rule copes with `_RULES` as well as `_rules` lexers and with proxy
		-- lexers; indexing `._RULES` directly fails on lexers without it.
		local lexer = vis.lexers.load(win.syntax)
		for _, name in ipairs({vis.lexers.COMMENT, vis.lexers.STRING}) do
			local rule = get_rule(lexer, name)
			if rule then
				skip = skip and skip + rule / 0 or rule / 0
			end
		end
	end
	skip = skip and skip + 1 or 1
	-- not_past also honours prev_match, which a previous get_range/match_at scan
	-- leaves behind; clear it so that only `finish` bounds this scan.
	prev_match = 0
	local newline = l.Cp() * l.P"\n"
	local other = skip * l.Cmt(l.Cc(finish), not_past)
	local line_breaks = l.Cf(l.Cc({}) * (newline + other)^0, keep_last)
	start, finish = unpack(l.match(line_breaks, content, start + 1))
	return start, finish
end
--- Resolve the delimiter text bound to `key` for window `win` at offset `pos`.
-- Plain string pairs (and a missing binding) are returned as looked up. For
-- pattern-based pairs, the actual opening/closing text around `pos` is located
-- with get_range and read back from the file.
-- @param count enclosing level to use; defaults to vis.count
-- @return {open, close, d[3], d.prompt}; {nil, nil, d[3], d.prompt} when nothing
--   encloses `pos` but the descriptor has a third entry; otherwise nothing
local function get_delimiters(key, win, pos, count)
	local d = get_pair(key, win)
	if not d then return d end
	if type(d[1]) == "string" and type(d[2]) == "string" then return d end
	local file = win.file
	local start, finish = get_range(key, win, pos, file:content(0, file.size), count or vis.count)
	if start and finish then
		local open_text = file:content(start[1], start[2] - start[1])
		local close_text = file:content(finish[1], finish[2] - finish[1])
		return {open_text, close_text, d[3], d.prompt}
	end
	if #d > 2 then
		return {nil, nil, d[3], d.prompt}
	end
end
--- Text-object handler: the pair for M.key including its delimiters.
-- @return 0-based start and end offsets, or nothing
local function outer(win, pos, content, count)
	local start, finish = get_range(M.key, win, pos, content, count)
	if not (start and finish) then
		return
	end
	return start[1], finish[2]
end
--- Text-object handler: the content between the delimiters of the pair for
-- M.key, narrowed to line boundaries in VISUAL_LINE mode (barf_linewise).
-- @return start and end offsets, or nothing
local function inner(win, pos, content, count)
	local start, finish = get_range(M.key, win, pos, content, count)
	if not (start and finish) then
		return
	end
	return barf_linewise(win, content, start[2], finish[1])
end
--- Motion handler: move to the opening delimiter of the enclosing pair.
-- The target is shifted one byte right in operator-pending mode once the
-- cursor is already past the opening delimiter, and in visual mode while the
-- cursor is before it.
-- @return target position, plus (operator-pending only) whether the motion is
--   exclusive; `pos` unchanged when no pair encloses the cursor
local function opening(win, pos, content, count)
	local start = get_range(M.key, win, pos, content, count)
	if not start then
		return pos
	end
	local after_open = start[2]
	local pending = vis.mode == vis.modes.OPERATOR_PENDING
	local exclusive = (pending and pos >= after_open) or (vis.mode == vis.modes.VISUAL and pos < after_open - 1)
	local target = after_open - 1
	if exclusive then
		target = target + 1
	end
	return target, pending and pos >= after_open
end
--- Motion handler: move to the closing delimiter of the enclosing pair; in
-- visual mode from beyond it, stop one byte before it.
-- @return target position, or `pos` unchanged when no pair encloses the cursor
local function closing(win, pos, content, count)
	local _, finish = get_range(M.key, win, pos, content, count)
	if not finish then
		return pos
	end
	local close_at = finish[1]
	if vis.mode == vis.modes.VISUAL and pos > close_at then
		return close_at - 1
	end
	return close_at
end
-- Toggled by bail_early() on every call made while a count > 1 is pending.
local done_once

--- With a count greater than 1, report true on every second call (and false
-- otherwise); without such a count, always false.
-- NOTE(review): presumably vis invokes the handler more than once per count,
-- and the handler (prep) uses this to skip the duplicate - confirm.
local function bail_early()
	if not (vis.count and vis.count > 1) then
		return false
	end
	if done_once then
		done_once = nil
		return true
	end
	done_once = true
	return false
end
--- Build a per-window mapper for `prefix`: it binds `binding` in visual and
-- operator-pending mode, and additionally in normal mode unless `textobject`.
local function win_map(textobject, prefix, binding, help)
	return function(win)
		local modes = vis.modes
		if not textobject then
			win:map(modes.NORMAL, prefix, binding, help)
		end
		for _, mode in ipairs({modes.VISUAL, modes.OPERATOR_PENDING}) do
			win:map(mode, prefix, binding, help)
		end
	end
end
--- Build a key handler that selects built-in pair `key` (via M.key) and then
-- runs the registered motion/text object `id` through `execute`.
local function bind_builtin(key, execute, id)
	local function run()
		M.key = key
		execute(vis, id)
	end
	return run
end
--- Wrap a motion/text-object handler `func` for registration with vis.
-- Calls `func(win, pos, content, vis.count)`. The call is retried with count 2
-- (the next enclosing pair) when the first result would be a no-op: in
-- visual mode without a count, a range not larger than the current selection;
-- for a motion (start without finish), a target equal to `pos`.
-- @return start, finish (or `pos` when bail_early() says to skip this call)
local function prep(func)
	return function(win, pos)
		if bail_early() then
			return pos
		end
		local content = win.file:content(0, win.file.size)
		local start, finish = func(win, pos, content, vis.count)
		local visual_without_count = not vis.count and vis.mode == vis.modes.VISUAL
		if not (visual_without_count or (start and not finish)) then
			return start, finish
		end
		local old = selection_range(win, pos)
		local same_or_smaller = finish and start >= old.start and finish <= old.finish
		local didnt_move = not finish and start == pos
		if same_or_smaller or didnt_move then
			start, finish = func(win, pos, content, 2)
		end
		return start, finish
	end
end
--- Prefix help text `msg` with the module name this file was required as.
local function h(msg)
	return ("|@%s| %s"):format(progname, msg)
end
-- Per-window mapper functions (from win_map), applied on every WIN_OPEN.
local mappings = {}

--- Register `handler` as a vis motion or text object and bind its keys.
-- @param execute  vis.motion or vis.textobject
-- @param register the matching vis.*_register function
-- @param prefix   key prefix; when nil, only the registration happens
-- @param help     help text, also used as the lead-in for motion help entries
-- @return the registered id, or false when registration failed
local function new(execute, register, prefix, handler, help)
	local id = register(vis, prep(handler))
	if id < 0 then
		return false
	end
	if not prefix then
		return id
	end

	local is_textobject = execute == vis.textobject
	local is_motion = execute == vis.motion

	-- Generic binding: the single key typed after the prefix selects the pair.
	local function binding(keys)
		local n = #keys
		if n < 1 then
			return -1
		end
		if n == 1 then
			M.key = keys
			execute(vis, id)
		end
		return n
	end
	mappings[#mappings + 1] = win_map(is_textobject, prefix, binding, help)

	-- Global bindings for the built-in pairs, each with its own help entry.
	local builtin = is_motion and builtin_motions[prefix] or builtin_textobjects
	local variant = prefix == M.prefix.outer and " (outer variant)" or prefix == M.prefix.inner and " (inner variant)" or ""
	for key, d in pairs(builtin) do
		local simple = type(d[1]) == "string" and type(d[2]) == "string" and d[1] .. d[2]
		local desc = (is_motion and help or "") .. (d.name or (simple or "pattern-delimited") .. " block")
		local action = bind_builtin(key, execute, id)
		if not is_textobject then
			vis:map(vis.modes.NORMAL, prefix .. key, action, h(desc))
		end
		vis:map(vis.modes.VISUAL, prefix .. key, action, h(desc .. variant))
		vis:map(vis.modes.OPERATOR_PENDING, prefix .. key, action, h(desc .. variant))
	end
	return id
end
--- Per-window setup: apply the prefix mappings, publish the window's unpair
-- query in M.unpair, and (with autopairs) bind pair-aware Backspace/Delete.
vis.events.subscribe(vis.events.WIN_OPEN, function(win)
	for _, map_keys in ipairs(mappings) do
		map_keys(win)
	end
	-- Build a handler that inspects the character `direction` bytes before each
	-- cursor (1: Backspace, 0: Delete). If that character is a delimiter of an
	-- empty pair with the cursor between the delimiters, the whole pair is
	-- targeted; otherwise just that character.
	-- @return map of selection number -> extra bytes beyond one that would be
	--   removed; with do_delete those bytes are deleted as well
	local function delete_pair(direction, do_delete)
		return function()
			local locations = {}
			for selection in win:selections_iterator() do
				local pos = selection.pos
				local left = pos - direction
				if left < 0 then
					-- Nothing before this cursor. The query form keeps answering nil,
					-- as callers of M.unpair have always seen. The editing form only
					-- skips this cursor, so the other cursors still get their deletion
					-- (previously none of them did).
					if not do_delete then return end
				else
					local key = win.file:content(left, 1)
					local p = M.map[win.syntax] and M.map[win.syntax][key]
						or M.map[1] and M.map[1][key]
						or builtin_textobjects[key]
						or builtin_textobjects[alias[key]]
					local len = #key
					if p and (key == p[1] or key == p[2]) then
						M.key = p[1]
						local start, finish = inner(win, pos, win.file:content(0, win.file.size))
						if start and start == finish and pos == start then
							left = start - #p[1]
							len = #p[1] + #p[2]
						end
					end
					locations[selection.number] = len - 1
					if do_delete then
						win.file:delete(left, len)
						selection.pos = left
					end
				end
			end
			return locations
		end
	end
	M.unpair[win] = delete_pair(1)
	-- NOTE(review): this installs the keys when vis_parkour is absent or claims
	-- the window, whereas the INPUT handler backs off when vis_parkour claims
	-- it; confirm the asymmetry is intended.
	if M.autopairs and (not vis_parkour or vis_parkour(win)) then
		win:map(vis.modes.INSERT, "<Backspace>", delete_pair(1, true))
		win:map(vis.modes.INSERT, "<Delete>", delete_pair(0, true))
	end
end)
--- Drop the closed window's unpair query from M.unpair.
vis.events.subscribe(vis.events.WIN_CLOSE, function(closed)
	M.unpair[closed] = nil
end)
--- One-time setup on INIT: merge the preset delimiter maps into M.map,
-- register the motions and text objects, and (with M.autopairs) install the
-- INPUT handler that inserts/overtypes closing delimiters.
vis.events.subscribe(vis.events.INIT, function()
	local function cmp(_, _, c1, c2) return c1 == c2 end
	local function casecmp(_, _, c1, c2) return c1:lower() == c2:lower() end
	-- Closing tag s1 .. name .. s2 whose name must compare equal (cmpfunc) to
	-- the opening name stored in the "t" group.
	local function end_tag(s1, s2, cmpfunc) return l.Cmt(s1 * l.Cb("t") * l.C((1 - l.P(s2))^1) * s2, cmpfunc) end
	local tex_environment = {"\\begin{" * l.Cg(l.R("az", "AZ")^1, "t") * "}", end_tag("\\end{", "}", cmp), {"\\begin{\xef\xbf\xbd}", "\\end{\xef\xbf\xbd}"}, prompt = "environment name"}
	local tag_name = (l.S"_:" + l.R("az", "AZ")) * (l.R("az", "AZ", "09") + l.S"_:.-")^0
	-- HTML tags that are implicitly closed or void; is_not() rejects them as the
	-- opening of an html_tag pair.
	local noslash = {--[[implicit:]] p=1, dt=1, dd=1, li=1, --[[void:]] area=1, base=1, br=1, col=1, embed=1, hr=1, img=1, input=1, link=1, meta=1, param=1, source=1, track=1, wbr=1}
	local function is_not(_, _, v) return v ~= 1 end
	local html_tag = {"<" * l.Cg(l.Cmt(tag_name / string.lower / noslash, is_not), "t") * (1 - l.S"><")^0 * (">" - l.B"/"), end_tag("</", ">", casecmp), {"<\xef\xbf\xbd>", "</\xef\xbf\xbd>"}, prompt = "tag name"}
	local xml_tag = {"<" * l.Cg(tag_name, "t") * (1 - l.S"><")^0 * (">" - l.B"/"), end_tag("</", ">", cmp), {"<\xef\xbf\xbd>", "</\xef\xbf\xbd>"}, prompt = "tag name", name = "<tag></tag> block"}
	-- Any opening character from `set`, closed by its builtin_textobjects partner.
	local function any_pair(set, default) return {l.Cg(l.S(set), "s"), l.Cmt(l.Cb("s") * l.C(1), function(_, _, c1, c2) return builtin_textobjects[c1][2] == c2 end), builtin_textobjects[default]} end
	local any_bracket = any_pair("({[", "(")
	local presets = {
		{t = xml_tag},
		xml = {t = xml_tag},
		html = {t = html_tag},
		markdown = {t = html_tag, ["_"] = {"_", "_"}, ["*"] = {"*", "*"}},
		asp = {t = html_tag},
		jsp = {t = html_tag},
		php = {t = html_tag},
		rhtml = {t = html_tag},
		scheme = {b = any_bracket},
		clojure = {b = any_bracket},
		fennel = {b = any_bracket},
		latex = {e = tex_environment},
	}
	-- User-supplied bindings win; presets only fill in what is missing.
	for syntax, bindings in pairs(presets) do
		if not M.map[syntax] then
			M.map[syntax] = bindings
		else
			for key, pattern in pairs(bindings) do
				if not M.map[syntax][key] then M.map[syntax][key] = pattern end
			end
		end
	end
	-- builtin_motions is keyed by the default "[" and "]" prefixes. Re-key it by
	-- the configured prefixes. Otherwise a customised M.prefix.opening/closing
	-- makes the loop below index nil, and new() falls back to the text-object
	-- table for the motion bindings. With the default prefixes this is a no-op.
	local opening_motions, closing_motions = builtin_motions["["], builtin_motions["]"]
	builtin_motions["["], builtin_motions["]"] = nil, nil
	builtin_motions[M.prefix.opening] = opening_motions
	builtin_motions[M.prefix.closing] = closing_motions
	-- Generic (M.map[1]) pairs also become built-in text objects and motions.
	for key, d in pairs(M.map[1]) do
		builtin_textobjects[key] = {d[1], d[2], name = d.name}
		builtin_motions[M.prefix.opening][key] = builtin_textobjects[key]
		builtin_motions[M.prefix.closing][key] = builtin_textobjects[key]
	end
	M.motion = {
		opening = new(vis.motion, vis.motion_register, M.prefix.opening, opening, "Move cursor to the beginning of a "),
		closing = new(vis.motion, vis.motion_register, M.prefix.closing, closing, "Move cursor to the end of a "),
	}
	M.textobject = {
		inner = new(vis.textobject, vis.textobject_register, M.prefix.inner, inner, "Delimited block (inner variant)"),
		outer = new(vis.textobject, vis.textobject_register, M.prefix.outer, outer, "Delimited block (outer variant)"),
	}
	if M.autopairs then
		-- Typing an opening delimiter inserts the closing one after the cursor;
		-- typing a closing delimiter right before an existing one removes the
		-- existing one, so the typed key overtypes it.
		vis.events.subscribe(vis.events.INPUT, function(key)
			if vis.mode == vis.modes.REPLACE then return end
			local win = vis.win
			if vis_parkour and vis_parkour(win) then return end
			local p = M.map[win.syntax] and M.map[win.syntax][key]
				or M.map[1] and M.map[1][key]
				or builtin_textobjects[key]
				or builtin_textobjects[alias[key]]
			if not p then return end
			if M.no_autopairs[key] and M.no_autopairs[key][win.syntax or ""] then return end
			for selection in win:selections_iterator() do
				local pos = selection.pos
				M.key = key
				local _, finish = outer(win, pos, win.file:content(0, win.file.size))
				if key == p[1] and p[1] ~= p[2] or p[1] == p[2] and pos + 1 ~= finish then
					win.file:insert(pos, p[2])
					selection.pos = pos
				elseif key == p[2] and pos + 1 == finish then
					win.file:delete(pos, #p[2])
					selection.pos = pos
				end
			end
		end)
	end
end)
--- Public module table, also exposed as the global `vis_pairs`.
--   map             - per-syntax delimiter maps; index 1 applies to every syntax
--   prefix          - key prefixes of the text objects (outer/inner) and
--                     motions (opening/closing)
--   autopairs       - enable the closing-delimiter INPUT handler and the
--                     pair-aware Backspace/Delete keys
--   no_autopairs    - keys exempt from autopairing, per syntax ("" = no syntax)
--   unpair          - per-window pair-deletion query, filled on WIN_OPEN
--   get_pair        - get_delimiters
--   get_range_inner - inner text-object range function
--   get_range_outer - outer text-object range function
M = {
	map = {},
	prefix = {
		outer = "a",
		inner = "i",
		opening = "[",
		closing = "]",
	},
	autopairs = true,
	no_autopairs = {["'"] = {markdown = true, [""] = true}},
	unpair = {},
	get_pair = get_delimiters,
	get_range_inner = inner,
	get_range_outer = outer,
}

vis_pairs = M

return M