Fix race between killing object and drawing object
[lsnes.git] / include / library / lua-class.hpp
blob5de1e40d29bb6867a946329c06ff1de5c948409b
1 #ifndef _library__lua_class__hpp__included__
2 #define _library__lua_class__hpp__included__
4 #include "lua-base.hpp"
5 #include "lua-pin.hpp"
7 namespace lua
class class_base;
class parameters;

/**
 * A collection of Lua classes that are registered and unregistered as a unit.
 */
class class_group
{
public:
/**
 * Construct a new, empty group.
 */
	class_group();
/**
 * Tear down the group.
 */
	~class_group();
/**
 * Register a class with this group.
 *
 * Parameter cname: Name the class is registered under.
 * Parameter cls: The class being registered.
 */
	void do_register(const std::string& cname, class_base& cls);
/**
 * Remove a class from this group.
 *
 * Parameter cname: Name the class was registered under.
 * Parameter cls: The class being removed.
 */
	void do_unregister(const std::string& cname, class_base& cls);
/**
 * Invoke a callback for each class currently registered in the group.
 *
 * Parameter cb: Receives the class name and the class object.
 */
	void request_callback(std::function<void(std::string, class_base*)> cb);
/**
 * Attach a callback to the group.
 *
 * The callback is invoked immediately for every class already registered.
 *
 * Parameter cb: Called with the name and object of each class.
 * Parameter dcb: Callback taking the group itself. NOTE(review): presumably invoked when the group
 *	goes away — confirm against the implementation.
 * Returns: A handle identifying this binding (the value drop_callback() takes).
 */
	int add_callback(std::function<void(std::string, class_base*)> cb,
		std::function<void(class_group*)> dcb);
/**
 * Detach a callback previously attached with add_callback().
 *
 * Parameter handle: The handle returned by add_callback().
 */
	void drop_callback(int handle);
private:
	char dummy;
};
53 struct class_ops
55 bool (*is)(state& _state, int index);
56 const std::string& (*name)();
57 std::string (*print)(state& _state, int index);
60 std::list<class_ops>& userdata_recogn_fns();
61 std::string try_recognize_userdata(state& _state, int index);
62 std::string try_print_userdata(state& _state, int index);
63 std::unordered_map<std::type_index, void*>& class_types();
65 /**
66 * Helper class containing binding data for Lua class call.
68 template<class T> struct class_binding
70 /**
71 * The pointer to call.
73 int (T::*fn)(state& lstate, lua::parameters& P);
74 /**
75 * The state to call it in.
77 state* _state;
78 /**
79 * The name of the method to pass.
81 char fname[];
84 /**
85 * Helper class containing binding data for Lua static class call.
87 struct static_binding
89 /**
90 * The pointer to call.
92 int (*fn)(state& lstate, parameters& P);
93 /**
94 * The state to call it in.
96 state* _state;
97 /**
98 * The name of the method to pass.
100 char fname[];
103 template<class T> class _class;
106 * Function to obtain class object for given Lua class.
108 template<class T> _class<T>& objclass()
110 auto& type = typeid(T);
111 if(!class_types().count(type))
112 throw std::runtime_error("Internal error: Lua class not found!");
113 return *reinterpret_cast<_class<T>*>(class_types()[type]);
117 * A class method.
119 template<class T> struct class_method
122 * Name.
124 const char* name;
126 * Function.
128 int (T::*fn)(state& LS, lua::parameters& P);
132 * A static class method.
134 struct static_method
137 * Name.
139 const char* name;
141 * Function.
143 int (*fn)(state& LS, parameters& P);
147 * Virtual base of Lua classes
149 class class_base
151 public:
153 * Create a new Lua class.
155 * Parameter _group: The group the class will be in.
156 * Parameter _name: The name of the class.
158 class_base(class_group& _group, const std::string& _name);
160 * Dtor.
162 virtual ~class_base() throw();
164 * Lookup by name in given Lua state.
166 * Parameter _L: The Lua state to look in.
167 * Parameter _name: The name of the class.
168 * Returns: The class instance, or NULL if no match.
170 static class_base* lookup(state& L, const std::string& _name);
172 * Push class table to stack.
174 static bool lookup_and_push(state& L, const std::string& _name);
176 * Get set of all classes.
178 static std::set<std::string> all_classes(state& L);
180 * Register in given Lua state.
182 virtual void register_state(state& L) = 0;
184 * Lookup static methods in class.
186 virtual std::list<static_method> static_methods() = 0;
188 * Lookup class methods in class.
190 virtual std::set<std::string> class_methods() = 0;
192 * Get name of class.
194 const std::string& get_name() { return name; }
195 protected:
196 void delayed_register();
197 void register_static(state& L);
198 private:
199 class_group& group;
200 std::string name;
201 bool registered;
204 static const size_t overcommit_std_align = 32;
207 * Align a overcommit pointer.
209 template<typename T, typename U> U* align_overcommit(T* th)
211 size_t ptr = reinterpret_cast<size_t>(th) + sizeof(T);
212 return reinterpret_cast<U*>(ptr + (overcommit_std_align - ptr % overcommit_std_align) % overcommit_std_align);
216 * The type of Lua classes.
218 template<class T> class _class : public class_base
220 template<typename... U> T* _create(state& _state, U... args)
222 size_t overcommit = T::overcommit(args...);
223 void* obj = _state.newuserdata(sizeof(T) + overcommit);
224 load_metatable(_state);
225 _state.setmetatable(-2);
226 T* _obj = reinterpret_cast<T*>(obj);
227 try {
228 new(_obj) T(_state, args...);
229 } catch(...) {
230 //CTOR FAILED. Get rid of the dtor (since it would error) and then dump the object.
231 _state.newtable();
232 _state.setmetatable(-2);
233 _state.pop(1);
234 throw;
236 return _obj;
239 static int class_bind_trampoline(lua_State* LS)
241 try {
242 class_binding<T>* b = (class_binding<T>*)lua_touserdata(LS, lua_upvalueindex(1));
243 state L(*b->_state, LS);
244 T* p = _class<T>::get(L, 1, b->fname);
245 lua::parameters P(L, b->fname);
246 return (p->*(b->fn))(L, P);
247 } catch(std::exception& e) {
248 std::string err = e.what();
249 lua_pushlstring(LS, err.c_str(), err.length());
250 lua_error(LS);
252 return 0; //NOTREACHED
255 T* _get(state& _state, int arg, const std::string& fname, bool optional = false)
257 if(_state.type(arg) == LUA_TNONE || _state.type(arg) == LUA_TNIL) {
258 if(optional)
259 return NULL;
260 else
261 goto badtype;
263 load_metatable(_state);
264 if(!_state.getmetatable(arg))
265 goto badtype;
266 if(!_state.rawequal(-1, -2))
267 goto badtype;
268 _state.pop(2);
269 return reinterpret_cast<T*>(_state.touserdata(arg));
270 badtype:
271 (stringfmt() << "argument #" << arg << " to " << fname << " must be " << name).throwex();
272 return NULL; //Never reached.
275 bool _is(state& _state, int arg)
277 if(_state.type(arg) != LUA_TUSERDATA)
278 return false;
279 load_metatable(_state);
280 if(!_state.getmetatable(arg)) {
281 _state.pop(1);
282 return false;
284 bool ret = _state.rawequal(-1, -2);
285 _state.pop(2);
286 return ret;
289 objpin<T> _pin(state& _state, int arg, const std::string& fname)
291 T* obj = get(_state, arg, fname);
292 _state.pushvalue(arg);
293 objpin<T> t(_state, obj);
294 _state.pop(1);
295 return t;
298 void bind(state& _state, const char* keyname, int (T::*fn)(state& LS, lua::parameters& P))
300 load_metatable(_state);
301 _state.pushstring(keyname);
302 std::string fname = name + std::string("::") + keyname;
303 void* ptr = _state.newuserdata(sizeof(class_binding<T>) + fname.length() + 1);
304 class_binding<T>* bdata = reinterpret_cast<class_binding<T>*>(ptr);
305 bdata->fn = fn;
306 bdata->_state = &_state.get_master();
307 std::copy(fname.begin(), fname.end(), bdata->fname);
308 bdata->fname[fname.length()] = 0;
309 _state.pushcclosure(class_bind_trampoline, 1);
310 _state.rawset(-3);
311 _state.pop(1);
313 protected:
314 void register_state(state& L)
316 static char once_key;
317 register_static(L);
318 if(L.do_once(&once_key))
319 for(auto i : cmethods)
320 bind(L, i.name, i.fn);
322 public:
324 * Create a new Lua class.
326 * Parameter _group: The group the class will be in.
327 * Parameter _name: The name of the class.
328 * Parameter _smethods: Static methods of the class.
329 * Parameter _cmethods: Class methods of the class.
330 * Parameter _print: The print method.
332 _class(class_group& _group, const std::string& _name, std::initializer_list<static_method> _smethods,
333 std::initializer_list<class_method<T>> _cmethods = {}, std::string (T::*_print)() = NULL)
334 : class_base(_group, _name), smethods(_smethods), cmethods(_cmethods)
336 name = _name;
337 class_ops m;
338 printmeth = _print;
339 m.is = _class<T>::is;
340 m.name = _class<T>::get_name;
341 m.print = _class<T>::print;
342 userdata_recogn_fns().push_back(m);
343 auto& type = typeid(T);
344 class_types()[type] = this;
345 delayed_register();
348 * Dtor
350 ~_class() throw()
352 auto& type = typeid(T);
353 class_types().erase(type);
354 auto& fns = userdata_recogn_fns();
355 for(auto i = fns.begin(); i != fns.end(); i++) {
356 if(i->is == _class<T>::is) {
357 fns.erase(i);
358 break;
363 * Create a new instance of object.
365 * Parameter _state: The Lua state to create the object in.
366 * Parameter args: The arguments to pass to class constructor.
368 template<typename... U> static T* create(state& _state, U... args)
370 return objclass<T>()._create(_state, args...);
374 * Get a pointer to the object.
376 * Parameter _state: The Lua state.
377 * Parameter arg: Argument index.
378 * Parameter fname: The name of function for error messages.
379 * Parameter optional: If true and argument is NIL or none, return NULL.
380 * Throws std::runtime_error: Wrong type.
382 static T* get(state& _state, int arg, const std::string& fname, bool optional = false)
383 throw(std::bad_alloc, std::runtime_error)
385 return objclass<T>()._get(_state, arg, fname, optional);
389 * Identify if object is of this type.
391 * Parameter _state: The Lua state.
392 * Parameter arg: Argument index.
393 * Returns: True if object is of specified type, false if not.
395 static bool is(state& _state, int arg) throw()
397 try {
398 return objclass<T>()._is(_state, arg);
399 } catch(...) {
400 return false;
404 * Get name of class.
406 static const std::string& get_name()
408 try {
409 return objclass<T>().name;
410 } catch(...) {
411 static std::string foo = "???";
412 return foo;
416 * Format instance of this class as string.
418 static std::string print(state& _state, int index)
420 T* obj = get(_state, index, "__internal_print");
421 try {
422 auto pmeth = objclass<T>().printmeth;
423 if(pmeth)
424 return (obj->*pmeth)();
425 else
426 return "";
427 } catch(...) {
428 return "";
432 * Get a pin of object against Lua GC.
434 * Parameter _state: The Lua state.
435 * Parameter arg: Argument index.
436 * Parameter fname: Name of function for error message purposes.
437 * Throws std::runtime_error: Wrong type.
439 static objpin<T> pin(state& _state, int arg, const std::string& fname) throw(std::bad_alloc,
440 std::runtime_error)
442 return objclass<T>()._pin(_state, arg, fname);
445 * Lookup static methods.
447 std::list<static_method> static_methods()
449 return smethods;
452 * Lookup class methods.
454 std::set<std::string> class_methods()
456 std::set<std::string> r;
457 for(auto& i : cmethods)
458 r.insert(i.name);
459 return r;
461 private:
462 static int dogc(lua_State* LS)
464 T* obj = reinterpret_cast<T*>(lua_touserdata(LS, 1));
465 obj->~T();
466 return 0;
469 static int newindex(lua_State* LS)
471 lua_pushstring(LS, "Writing metatables of classes is not allowed");
472 lua_error(LS);
473 return 0;
476 static int index(lua_State* LS)
478 lua_getmetatable(LS, 1);
479 lua_pushvalue(LS, 2);
480 lua_rawget(LS, -2);
481 if(lua_type(LS, -1) == LUA_TNIL) {
482 std::string err = std::string("Class '") + lua_tostring(LS, lua_upvalueindex(1)) +
483 "' does not have class method '" + lua_tostring(LS, 2) + "'";
484 lua_pushstring(LS, err.c_str());
485 lua_error(LS);
487 return 1;
490 void load_metatable(state& _state)
492 again:
493 _state.pushlightuserdata(this);
494 _state.rawget(LUA_REGISTRYINDEX);
495 if(_state.type(-1) == LUA_TNIL) {
496 _state.pop(1);
497 _state.pushlightuserdata(this);
498 _state.newtable();
499 _state.pushvalue(-1);
500 _state.setmetatable(-2);
501 _state.pushstring("__gc");
502 _state.pushcfunction(&_class<T>::dogc);
503 _state.rawset(-3);
504 _state.pushstring("__newindex");
505 _state.pushcfunction(&_class<T>::newindex);
506 _state.rawset(-3);
507 _state.pushstring("__index");
508 _state.pushlstring(name);
509 _state.pushcclosure(&_class<T>::index, 1);
510 _state.rawset(-3);
511 _state.rawset(LUA_REGISTRYINDEX);
512 goto again;
515 std::string name;
516 std::list<static_method> smethods;
517 std::list<class_method<T>> cmethods;
518 std::string (T::*printmeth)();
519 _class(const _class<T>&);
520 _class& operator=(const _class<T>&);
524 #endif