Lua: loadfile/dofile: Support basenames
[lsnes.git] / src / lua / loadfile.cpp
blob cbad9c2b57b6f8fe9a7091bea4d018b9a1928f81
#include "lua/internal.hpp"
#include "library/minmax.hpp"
#include "library/zip.hpp"
#include "core/memorymanip.hpp"
#include <functional>
#include <memory>
7 namespace
/**
 * Build a Lua long bracket: the bracket character, `dashes` '=' signs, the bracket again.
 * Example: dashstring(']', 2) == "]==]", dashstring('[', 0) == "[[".
 */
std::string dashstring(char ch, int dashes)
{
	std::string bracket(1, ch);
	bracket.append(dashes, '=');
	bracket += ch;
	return bracket;
}
/**
 * Streaming search-and-replace over a chunked byte source.
 *
 * run() pulls input chunks from a callback and hands back output chunks in which every
 * occurrence of `pattern` is replaced by `target`. Matching state (partial match, pending
 * flush, pending target emission) persists between run() calls, so a pattern split across
 * input chunks or across output buffers is still replaced.
 */
struct replace
{
	//Default target is "[[]]".
	replace() = default;
	explicit replace(const std::string& _target)
		: target(_target)
	{
	}
	/**
	 * Produce the next output chunk.
	 *
	 * fn: input source. Returns (pointer, size) of the next input chunk; (NULL, 0) means end of
	 *     input, a non-NULL empty chunk is skipped. The chunk's memory must stay valid until fn
	 *     is called again (possibly during a later run() call).
	 * Returns: (pointer, size) of up to sizeof(buffer) output bytes, valid until the next run()
	 *          call, or (NULL, 0) once all output has been produced.
	 */
	std::pair<const char*, size_t> run(std::function<std::pair<const char*, size_t>()> fn);
private:
	char buffer[4096];			//Output buffer returned by run().
	std::string target = "[[]]";		//Replacement text for pattern.
	size_t matched = 0;			//Length of the pattern prefix matched so far.
	size_t copied = 0;			//Bytes already emitted from pattern (state 1) or target (state 2).
	int source = 0;				//0 = scanning input, 1 = flushing failed partial match, 2 = emitting target.
	const char* upper_buf = nullptr;	//Current input chunk (memory owned by the source callback).
	size_t upper_ptr = 0;			//Read position in upper_buf.
	size_t upper_size = 0;			//Size of upper_buf.
	bool upper_eof = false;			//Source has reported end of input.
};

//Placeholder in script text that the loader replaces with the script's filename.
const std::string pattern = "@@LUA_SCRIPT_FILENAME@@";

std::pair<const char*, size_t> replace::run(std::function<std::pair<const char*, size_t>()> fn)
{
	size_t emitted = 0;
	while(emitted < sizeof(buffer)) {
		//Fetch the next input chunk once the current one is used up.
		while(upper_ptr == upper_size && !upper_eof) {
			auto g = fn();
			upper_buf = g.first;
			upper_size = g.second;
			upper_ptr = 0;
			if(!upper_buf && !upper_size)
				upper_eof = true;
		}
		//Input exhausted while scanning: stop, unless a partial match is still held back,
		//in which case it is emitted verbatim.
		if(upper_ptr == upper_size && source == 0) {
			if(!matched)
				break;
			copied = 0;
			source = 1;
		}
		switch(source) {
		case 0: //Scanning input.
			if(upper_buf[upper_ptr] == pattern[matched]) {
				//Extend the partial match; the character is held back until the match resolves.
				matched++;
				upper_ptr++;
				if(matched == pattern.length()) {
					source = 2;
					copied = 0;
				}
			} else if(matched) {
				//Mismatch after a partial match: flush the held-back prefix, then re-examine
				//this character (upper_ptr is deliberately not advanced).
				source = 1;
				copied = 0;
			} else {
				buffer[emitted++] = upper_buf[upper_ptr++];
			}
			break;
		case 1: //Flushing a failed partial match (copied from pattern).
			if(matched == 2 && upper_ptr < upper_size && upper_buf[upper_ptr] == '@') {
				//"@@@": emit the first '@' and keep the last two as a new "@@" partial match.
				//Relies on pattern starting with "@@" followed by a non-'@' character.
				buffer[emitted++] = '@';
				upper_ptr++;
				matched = 2;
				source = 0;
			} else if(copied == matched) {
				//Prefix fully flushed.
				matched = 0;
				source = 0;
			} else {
				buffer[emitted++] = pattern[copied++];
			}
			break;
		case 2: //Emitting the replacement text.
			if(copied == target.size()) {
				//Replacement fully emitted.
				matched = 0;
				source = 0;
			} else {
				buffer[emitted++] = target[copied++];
			}
			break;
		}
	}
	if(!emitted)
		return std::pair<const char*, size_t>(nullptr, 0);
	return std::make_pair(buffer, emitted);
}
116 struct reader
118 reader(std::istream& _s, const std::string& fn)
119 : s(_s)
121 int dashes = 0;
122 while(true) {
123 std::string tmpl = dashstring(']', dashes);
124 if(fn.find(tmpl) == std::string::npos)
125 break;
127 rpl = replace(dashstring('[', dashes) + fn + dashstring(']', dashes));
129 const char* rfn(lua_State* L, size_t* size);
130 static const char* rfn(lua_State* L, void* data, size_t* size)
132 return reinterpret_cast<reader*>(data)->rfn(L, size);
134 private:
135 std::istream& s;
136 replace rpl;
139 const char* reader::rfn(lua_State* L, size_t* size)
141 auto g = rpl.run([this]() -> std::pair<const char*, size_t> {
142 size_t size;
143 static char buffer[4096];
144 if(!this->s)
145 return std::make_pair(reinterpret_cast<const char*>(NULL), 0);
146 this->s.read(buffer, sizeof(buffer));
147 size = this->s.gcount();
148 if(!size) {
149 return std::make_pair(reinterpret_cast<const char*>(NULL), 0);
151 return std::make_pair(buffer, size);
153 *size = g.second;
154 return g.first;
157 void load_chunk(lua_state& L, const std::string& fname)
159 std::string file2;
160 std::string file1 = L.get_string(1, fname.c_str());
161 if(L.type(2) != LUA_TNIL && L.type(2) != LUA_TNONE)
162 file2 = L.get_string(2, fname.c_str());
163 std::string absfilename = resolve_file_relative(file1, file2);
164 std::istream& file = open_file_relative(file1, file2);
165 std::string chunkname;
166 if(file2 != "")
167 chunkname = file2 + "[" + file1 + "]";
168 else
169 chunkname = file1;
170 reader rc(file, absfilename);
171 int r = lua_load(L.handle(), reader::rfn, &rc, chunkname.c_str()
172 #if LUA_VERSION_NUM == 502
173 , "t"
174 #endif
176 delete &file;
177 if(r == 0) {
178 return;
179 } else if(r == LUA_ERRSYNTAX) {
180 (stringfmt() << "Syntax error: " << L.tostring(-1)).throwex();
181 } else if(r == LUA_ERRMEM) {
182 (stringfmt() << "Out of memory: " << L.tostring(-1)).throwex();
183 } else {
184 (stringfmt() << "Unknown error: " << L.tostring(-1)).throwex();
188 function_ptr_luafun loadfile2(LS, "loadfile2", [](lua_state& L, const std::string& fname)
189 -> int {
190 load_chunk(L, fname);
191 return 1;
194 function_ptr_luafun dofile2(LS, "dofile2", [](lua_state& L, const std::string& fname)
195 -> int {
196 load_chunk(L, fname);
197 int old_sp = lua_gettop(L.handle());
198 lua_call(L.handle(), 0, LUA_MULTRET);
199 int new_sp = lua_gettop(L.handle());
200 return new_sp - (old_sp - 1);