Squash commits
[vis-surround.git] / init.lua
blob29e598edc35ab8dff5cec21028323a6f793d39c3
1 -- SPDX-License-Identifier: GPL-3.0-or-later
2 -- © 2020 Georgi Kirilov
4 require("vis")
5 local vis = vis
-- Name this chunk was loaded under; used to tag the help text of key mappings.
local progname = ...

-- Public module table.
-- prefix.<operation> = {normal-mode key sequence, visual-mode key sequence}.
local M = {
	prefix = {
		add = {"ys", "S"},
		change = {"cs", "C"},
		delete = {"ds", "D"},
	},
}
-- Build one table entry: the delimiter pair at index 1 plus the numeric id of
-- the matching "outer" text object (the ids mirror the constants in vis.h).
local function outer_object(opening, closing, id)
	return {{opening, closing}, id = id}
end

-- Delimiters handled without the help of vis-pairs, keyed by opening delimiter.
-- The entry at index 1 is the fallback used for unknown keys.
local builtin_textobjects = {
	["["] = outer_object("[", "]", 7),  -- +/VIS_TEXTOBJECT_OUTER_SQUARE_BRACKET vis.h
	["{"] = outer_object("{", "}", 9),  -- +/VIS_TEXTOBJECT_OUTER_CURLY_BRACKET vis.h
	["<"] = outer_object("<", ">", 11), -- +/VIS_TEXTOBJECT_OUTER_ANGLE_BRACKET vis.h
	["("] = outer_object("(", ")", 13), -- +/VIS_TEXTOBJECT_OUTER_PARENTHESIS vis.h
	['"'] = outer_object('"', '"', 15), -- +/VIS_TEXTOBJECT_OUTER_QUOTE vis.h
	["'"] = outer_object("'", "'", 17), -- +/VIS_TEXTOBJECT_OUTER_SINGLE_QUOTE vis.h
	["`"] = outer_object("`", "`", 19), -- +/VIS_TEXTOBJECT_OUTER_BACKTICK vis.h
	[1] = outer_object("", "", 28),     -- +/VIS_TEXTOBJECT_INVALID vis.h
}
-- Make every closing delimiter select the same entry as its opening one.
-- Entries whose opening and closing delimiters are equal need no alias.
-- (The fallback entry at index 1 ends up aliased under the empty string.)
local aliases = {}
for key, data in pairs(builtin_textobjects) do
	local closing = data[1][2]
	if key ~= closing then
		aliases[closing] = data
	else
		aliases[closing] = nil
	end
end

-- Merged in a second pass so the table is not extended while it is traversed.
for alias, data in pairs(aliases) do
	builtin_textobjects[alias] = data
end

-- Letter shorthands: B for curly braces, b for parentheses.
builtin_textobjects.B = builtin_textobjects["{"]
builtin_textobjects.b = builtin_textobjects["("]
--- Look up the {opening, closing} delimiter pair registered for a key.
-- Callers may pass further arguments (a position); this built-in lookup
-- ignores them. The INIT handler may replace this function.
-- @param key delimiter key typed by the user
-- @return the pair table, or nil when the key is unknown
local function get_pair(key)
	local entry = builtin_textobjects[key]
	return entry and entry[1]
end
--- Resolve the delimiter pair that should actually be inserted.
--
-- `d` is a pair description. When `d[3]` is not a table, `d` itself is the
-- pair. Otherwise `d[3]` is a two-element template; if it contains the UTF-8
-- replacement character (U+FFFD) as a placeholder, the user is asked for a
-- parameter through `vis-menu` (with `d[4]` as the optional prompt label) and
-- every placeholder is replaced by the answer.
--
-- @param _ unused (callers pass the current window)
-- @param d pair description, may be nil
-- @return a table {opening, closing}; `d` unchanged when it has no template;
--         nil when the template is malformed or the prompt was cancelled
local function take_param(_, d)
	if not d or type(d[3]) ~= "table" then
		return d
	end
	local template = d[3]
	if #template ~= 2 then
		return nil
	end
	local placeholder = "\xef\xbf\xbd"
	if not table.concat(template):find(placeholder, 1, true) then
		return template
	end

	local prompt = ""
	if d[4] then
		-- The label is embedded in a single-quoted shell word, so any single
		-- quote inside it has to be closed, escaped and reopened.
		local label = tostring(d[4]):gsub("'", "'\\''")
		prompt = " -p '" .. label .. ":'"
	end
	local status, out = vis:pipe(nil, nil, "vis-menu" .. prompt)
	if status ~= 0 or not out then
		return nil
	end

	-- Drop the trailing newline printed by vis-menu.
	local param = out:sub(1, -2)
	-- A function replacement inserts the text verbatim; a replacement *string*
	-- would give "%" in the user's input a special meaning.
	local function substitute() return param end
	-- Parentheses keep only the resulting string, not gsub's match count.
	return {
		(template[1]:gsub(placeholder, substitute)),
		(template[2]:gsub(placeholder, substitute)),
	}
end
--- Work out the padding to insert next to new delimiters.
--
-- In visual-line mode the delimiters go on lines of their own, so a newline
-- is returned unless the opening delimiter already is one. In any mode other
-- than visual and visual-line, trailing whitespace is excluded from the range
-- by moving `range.finish` back (the range is modified in place).
--
-- @param file file object providing `content(range)`
-- @param range table with `start` and `finish`
-- @param d delimiter pair {opening, closing}
-- @return "\n" or ""
local function adjust_spacing(file, range, d)
	local mode, modes = vis.mode, vis.modes
	if mode == modes.VISUAL_LINE then
		if d[1] ~= "\n" then
			return "\n"
		end
		return ""
	end
	if mode ~= modes.VISUAL then
		local tail = file:content(range):match("(%s*)$")
		if #tail > 0 then
			range.finish = range.finish - #tail
		end
	end
	return ""
end
--- Operator handler: wrap the range in the delimiter pair chosen by the user.
-- The pair is looked up from the first pending key (`M.key[1]`).
-- @param file file being edited
-- @param range table with `start` and `finish`
-- @param pos current cursor position
-- @return new cursor position: the range start after a successful insert,
--         otherwise `pos` unchanged
local function add(file, range, pos)
	local is_empty = range.finish <= range.start
	if is_empty then
		return pos
	end

	local pair = take_param(vis.win, get_pair(M.key[1], pos))
	if not pair then
		return pos
	end

	local pad = adjust_spacing(file, range, pair)
	-- Insert the closing delimiter first so the start offset stays valid.
	file:insert(range.finish, pair[2] .. pad)
	file:insert(range.start, pair[1] .. pad)
	return range.start
end
--- Escape every Lua pattern magic character in `text` with a leading "%",
-- so the result can be embedded in a pattern and match `text` literally.
-- @param text string to escape
-- @return the escaped string (a single value)
local function escape(text)
	-- The parentheses discard gsub's second result (the substitution count),
	-- which would otherwise leak into the caller's argument or value lists.
	return (text:gsub("[][^$)(%%.*+?-]", "%%%0"))
end
--- Locate the delimiters that currently surround `range`.
--
-- In visual-line mode the opening delimiter is searched on a line near the top
-- of the block and the closing one on the last line. With `get_padding` set,
-- the reported spans also cover the whitespace and newline around each
-- delimiter. In every other mode the delimiters must sit exactly at the two
-- ends of the range.
--
-- @param file file object providing `content`
-- @param range table with `start` and `finish`
-- @param pos cursor position, forwarded to the pair lookup
-- @param key delimiter key typed by the user
-- @param get_padding include surrounding whitespace (visual-line mode only)
-- @return start, slen, finish, flen: offset and length of the opening and of
--         the closing delimiter; nothing (or nils) when they are not found
local function delimiters_in_place(file, range, pos, key, get_padding)
	local start, slen, finish, flen
	if vis.mode == vis.modes.VISUAL_LINE then
		local block = file:content(range)
		vis.count = nil
		-- A block without any newline (e.g. an unterminated last line) cannot
		-- match the line-wise patterns below; bail out instead of doing
		-- arithmetic on the nil returned by find().
		local first_newline = block:find("\n", 1, true)
		if not first_newline then return end
		local d = get_pair(key, range.start + first_newline)
		if not (d and d[1] and d[2]) then return end
		local d1, d2 = escape(d[1]), escape(d[2])
		-- Opening delimiter: first try it alone on the first line, then
		-- anywhere before a line end.
		local sl = table.pack(block:match("^()[ \t]*()" .. d1 .. "[ \t]-\n()"))
		if sl[1] == nil then
			sl = table.pack(block:match("()[ \t]*()" .. d1 .. "[ \t]-()\n"))
		end
		-- Closing delimiter: first try it alone on the last line, then
		-- followed by other text on the last line.
		local el = table.pack(block:match("()\n[ \t]*()" .. d2 .. "()[ \t]*\n$"))
		if el[1] == nil then
			el = table.pack(block:match("\n[ \t]*()()" .. d2 .. "[ \t]*()[^\n]-\n$"))
		end
		if sl[1] == nil or el[1] == nil then return end
		-- Captures are 1-based positions inside the block; convert them to
		-- file offsets. Capture 1 includes the padding, capture 2 does not.
		local from = get_padding and 1 or 2
		start = range.start + sl[from] - 1
		slen = get_padding and sl[3] - sl[1] or #d[1]
		finish = range.start + el[from] - 1
		flen = get_padding and el[3] - el[1] or #d[2]
	else
		local d = get_pair(key, pos)
		if not (d and d[1] and d[2]) then return end
		if file:content(range.start, #d[1]):find(d[1], 1, true)
			and file:content(range.finish - #d[2], #d[2]):find(d[2], 1, true) then
			start, slen, finish, flen = range.start, #d[1], range.finish - #d[2], #d[2]
		end
	end
	return start, slen, finish, flen
end
--- Operator handler: replace the delimiters around the range with another pair.
-- The existing pair is identified by `M.key[1]`, the new one by `M.key[2]`.
-- @param file file being edited
-- @param range table with `start` and `finish`
-- @param pos current cursor position
-- @return cursor position adjusted for the change in delimiter lengths, or
--         `pos` unchanged when nothing was replaced
local function change(file, range, pos)
	if range.finish <= range.start then
		return pos
	end
	local start, slen, finish, flen = delimiters_in_place(file, range, pos, M.key[1])
	if not start then
		return pos
	end
	local n = take_param(vis.win, get_pair(M.key[2], pos))
	if not n then
		return pos
	end

	-- Closing delimiter first, so the offsets of the opening one stay valid.
	file:delete(finish, flen)
	file:insert(finish, n[2])
	file:delete(start, slen)
	file:insert(start, n[1])

	local open_len, close_len = #n[1], #n[2]
	if pos < range.start + slen then
		-- Cursor was on the old opening delimiter.
		local keep = (pos < range.start + open_len and pos < range.start + slen - 1)
			or slen == 1
		if keep then
			return pos
		end
		return range.start + open_len - 1
	end
	if pos >= range.finish - flen then
		-- Cursor was on the old closing delimiter.
		if pos < range.finish - flen + close_len and pos < range.finish - 1 then
			return pos - slen + open_len
		end
		return range.finish - slen - flen + open_len + close_len - 1
	end
	-- Cursor was between the delimiters: shift by the opening length change.
	return pos - slen + open_len
end
--- Operator handler: remove the delimiters around the range.
-- The pair is identified by `M.key[1]`; padding around the delimiters is
-- requested from `delimiters_in_place` as well.
-- @param file file being edited
-- @param range table with `start` and `finish`
-- @param pos current cursor position
-- @return the range start after a successful removal, otherwise `pos`
local function delete(file, range, pos)
	if range.finish <= range.start then
		return pos
	end
	local open_at, open_len, close_at, close_len =
		delimiters_in_place(file, range, pos, M.key[1], true)
	if not open_at then
		return pos
	end
	-- Remove the closing delimiter first so `open_at` is still accurate.
	file:delete(close_at, close_len)
	file:delete(open_at, open_len)
	return range.start
end
--- Return the id of the "outer" text object for a delimiter key, falling back
-- to the id of the entry at index 1 for unknown keys.
-- The INIT handler may replace this function.
-- @param key delimiter key typed by the user
-- @return text object id
local function outer(key)
	local entry = builtin_textobjects[key]
	local id = entry and entry.id
	if id then
		return id
	end
	return builtin_textobjects[1].id
end
--- Build a key-mapping callback that waits for `nargs` further keys, stores
-- them one byte per entry in `M.key`, and then starts operator `id`.
-- @param id operator id to start
-- @param nargs number of keys the operator needs
-- @param needs_range when truthy, also apply the "outer" text object of the
--        first key so the operator gets its range straight away
-- @return function(keys) returning -1 while more keys are needed, otherwise
--         the length of `keys`
local function va_call(id, nargs, needs_range)
	return function(keys)
		local typed = #keys
		if typed < nargs then
			return -1
		end
		if typed == nargs then
			local collected = {}
			for i = 1, typed do
				collected[i] = keys:sub(i, i)
			end
			M.key = collected
			vis:operator(id)
			if needs_range then
				vis:textobject(outer(collected[1]))
			end
		end
		return typed
	end
end
--- Format the help text of a key mapping, tagged with this module's name.
-- @param msg description of the mapping
-- @return string of the form "|@<progname>| <msg>"
local function h(msg)
	local template = "|@%s| %s"
	return template:format(progname, msg)
end
--- Register `handler` as a vis operator and bind it to its key prefixes.
-- @param prefix table {normal-mode keys, visual-mode keys}; either may be
--        absent, and a non-table value registers the operator without mappings
-- @param handler operator function (file, range, pos)
-- @param nargs number of delimiter keys the operator expects
-- @param help description shown for the mappings
-- @return the operator id, or false when registration failed
local function operator_new(prefix, handler, nargs, help)
	local id = vis:operator_register(handler)
	if id < 0 then
		return false
	end
	if type(prefix) ~= "table" then
		return id
	end

	-- Only `change` and `delete` work on delimiters already in the text, so in
	-- normal mode only they get their range from a text object.
	local needs_range
	if handler == change or handler == delete then
		needs_range = true
	end

	local normal_keys, visual_keys = prefix[1], prefix[2]
	if normal_keys then
		vis:map(vis.modes.NORMAL, normal_keys, va_call(id, nargs, needs_range), h(help))
	end
	if visual_keys then
		vis:map(vis.modes.VISUAL, visual_keys, va_call(id, nargs), h(help))
	end
	return id
end
-- Runs on the INIT event: registers the three operators and, when the
-- vis-pairs plugin has been loaded, delegates pair and text object lookup
-- to it.
local function on_init()
	-- Registered in the order add, change, delete.
	local add_id = operator_new(M.prefix.add, add, 1, "Add delimiters at range boundaries")
	local change_id = operator_new(M.prefix.change, change, 2, "Change delimiters at range boundaries")
	local delete_id = operator_new(M.prefix.delete, delete, 1, "Delete delimiters at range boundaries")
	M.operator = {
		add = add_id,
		change = change_id,
		delete = delete_id,
	}

	local vis_pairs = package.loaded["pairs"] or package.loaded["vis-pairs"]
	if not vis_pairs then
		return
	end
	-- Replace the built-in lookups with the ones provided by vis-pairs.
	get_pair = function(key, pos)
		return vis_pairs.get_pair(key, vis.win, pos)
	end
	outer = function(key)
		vis_pairs.key = key
		return vis_pairs.textobject.outer
	end
end

vis.events.subscribe(vis.events.INIT, on_init)
192 return M