Actually call on_reset callback
[lsnes.git] / include / library / lua-class.hpp
blob0e30add05b4e09504a9c96752a4c96fd9b6df70b
1 #ifndef _library__lua_class__hpp__included__
2 #define _library__lua_class__hpp__included__
#include <cstdint>
#include <functional>
#include "lua-base.hpp"
#include "lua-pin.hpp"
8 namespace lua
//Forward declarations: class_base is defined further below in this header, parameters is only
//referenced (by reference) here.
class parameters;
class class_base;
13 /**
14 * Group of classes.
16 class class_group
18 public:
19 /**
20 * Create a group.
22 class_group();
23 /**
24 * Destroy a group.
26 ~class_group();
27 /**
28 * Add a class to group.
30 void do_register(const std::string& name, class_base& fun);
31 /**
32 * Drop a class from group.
34 void do_unregister(const std::string& name, class_base& fun);
35 /**
36 * Request callbacks on all currently registered functions.
38 void request_callback(std::function<void(std::string, class_base*)> cb);
39 /**
40 * Bind a callback.
42 * Callbacks for all registered functions are immediately called.
44 int add_callback(std::function<void(std::string, class_base*)> cb,
45 std::function<void(class_group*)> dcb);
46 /**
47 * Unbind a calback.
49 void drop_callback(int handle);
50 private:
51 char dummy;
54 struct class_ops
56 bool (*is)(state& _state, int index);
57 const std::string& (*name)();
58 std::string (*print)(state& _state, int index);
61 std::list<class_ops>& userdata_recogn_fns();
62 std::string try_recognize_userdata(state& _state, int index);
63 std::string try_print_userdata(state& _state, int index);
64 std::unordered_map<std::type_index, void*>& class_types();
66 /**
67 * Helper class containing binding data for Lua class call.
69 template<class T> struct class_binding
71 /**
72 * The pointer to call.
74 int (T::*fn)(state& lstate, lua::parameters& P);
75 /**
76 * The state to call it in.
78 state* _state;
79 /**
80 * The name of the method to pass.
82 char fname[];
85 /**
86 * Helper class containing binding data for Lua static class call.
88 struct static_binding
90 /**
91 * The pointer to call.
93 int (*fn)(state& lstate, parameters& P);
94 /**
95 * The state to call it in.
97 state* _state;
98 /**
99 * The name of the method to pass.
101 char fname[];
104 template<class T> class _class;
107 * Function to obtain class object for given Lua class.
109 template<class T> _class<T>& objclass()
111 auto& type = typeid(T);
112 if(!class_types().count(type))
113 throw std::runtime_error("Internal error: Lua class not found!");
114 return *reinterpret_cast<_class<T>*>(class_types()[type]);
118 * A class method.
120 template<class T> struct class_method
123 * Name.
125 const char* name;
127 * Function.
129 int (T::*fn)(state& LS, lua::parameters& P);
133 * A static class method.
135 struct static_method
138 * Name.
140 const char* name;
142 * Function.
144 int (*fn)(state& LS, parameters& P);
148 * Virtual base of Lua classes
150 class class_base
152 public:
154 * Create a new Lua class.
156 * Parameter _group: The group the class will be in.
157 * Parameter _name: The name of the class.
159 class_base(class_group& _group, const std::string& _name);
161 * Dtor.
163 virtual ~class_base() throw();
165 * Lookup by name in given Lua state.
167 * Parameter _L: The Lua state to look in.
168 * Parameter _name: The name of the class.
169 * Returns: The class instance, or NULL if no match.
171 static class_base* lookup(state& L, const std::string& _name);
173 * Push class table to stack.
175 static bool lookup_and_push(state& L, const std::string& _name);
177 * Get set of all classes.
179 static std::set<std::string> all_classes(state& L);
181 * Register in given Lua state.
183 virtual void register_state(state& L) = 0;
185 * Lookup static methods in class.
187 virtual std::list<static_method> static_methods() = 0;
189 * Lookup class methods in class.
191 virtual std::set<std::string> class_methods() = 0;
193 * Get name of class.
195 const std::string& get_name() { return name; }
196 protected:
197 void delayed_register();
198 void register_static(state& L);
199 private:
200 class_group& group;
201 std::string name;
202 bool registered;
static const size_t overcommit_std_align = 32;

/**
 * Align a overcommit pointer.
 *
 * Returns the first address at or after th + sizeof(T) that is a multiple of overcommit_std_align,
 * i.e. the start of the aligned overcommit area that follows an object. The padding is between 0 and
 * overcommit_std_align - 1 bytes, so callers' overcommit sizes presumably need to allow for it.
 *
 * Parameter th: Pointer to the object the overcommit area follows.
 * Returns: Aligned pointer into the overcommit area.
 */
template<typename T, typename U> U* align_overcommit(T* th)
{
	//std::uintptr_t is the integer type meant for pointer <-> integer round trips (size_t need not be).
	std::uintptr_t end = reinterpret_cast<std::uintptr_t>(th) + sizeof(T);
	std::uintptr_t pad = (overcommit_std_align - end % overcommit_std_align) % overcommit_std_align;
	return reinterpret_cast<U*>(end + pad);
}
217 * The type of Lua classes.
219 template<class T> class _class : public class_base
221 template<typename... U> T* _create(state& _state, U... args)
223 size_t overcommit = T::overcommit(args...);
224 void* obj = NULL;
225 auto st = &_state;
226 _state.run_interruptable([st, overcommit, &obj]() {
227 obj = st->newuserdata(sizeof(T) + overcommit);
228 }, 0, 1);
229 load_metatable(_state);
230 _state.setmetatable(-2);
231 T* _obj = reinterpret_cast<T*>(obj);
232 try {
233 new(_obj) T(_state, args...);
234 } catch(...) {
235 //CTOR FAILED. Get rid of the dtor (since it would error) and then dump the object.
236 _state.newtable();
237 _state.setmetatable(-2);
238 _state.pop(1);
239 throw;
241 return _obj;
244 static int class_bind_trampoline(state& L)
246 class_binding<T>* b = (class_binding<T>*)L.touserdata(L.trampoline_upval(1));
247 T* p = _class<T>::get(L, 1, b->fname);
248 lua::parameters P(L, b->fname);
249 return (p->*(b->fn))(L, P);
252 T* _get(state& _state, int arg, const std::string& fname, bool optional = false)
254 if(_state.type(arg) == LUA_TNONE || _state.type(arg) == LUA_TNIL) {
255 if(optional)
256 return NULL;
257 else
258 goto badtype;
260 load_metatable(_state);
261 if(!_state.getmetatable(arg))
262 goto badtype;
263 if(!_state.rawequal(-1, -2))
264 goto badtype;
265 _state.pop(2);
266 return reinterpret_cast<T*>(_state.touserdata(arg));
267 badtype:
268 (stringfmt() << "argument #" << arg << " to " << fname << " must be " << name).throwex();
269 return NULL; //Never reached.
272 bool _is(state& _state, int arg)
274 if(_state.type(arg) != LUA_TUSERDATA)
275 return false;
276 load_metatable(_state);
277 if(!_state.getmetatable(arg)) {
278 _state.pop(1);
279 return false;
281 bool ret = _state.rawequal(-1, -2);
282 _state.pop(2);
283 return ret;
286 objpin<T> _pin(state& _state, int arg, const std::string& fname)
288 T* obj = get(_state, arg, fname);
289 _state.pushvalue(arg);
290 objpin<T> t(_state, obj);
291 _state.pop(1);
292 return t;
295 void bind(state& _state, const char* keyname, int (T::*fn)(state& LS, lua::parameters& P))
297 load_metatable(_state);
298 _state.pushstring(keyname);
299 std::string fname = name + std::string("::") + keyname;
300 void* ptr = _state.newuserdata(sizeof(class_binding<T>) + fname.length() + 1);
301 class_binding<T>* bdata = reinterpret_cast<class_binding<T>*>(ptr);
302 bdata->fn = fn;
303 bdata->_state = &_state.get_master();
304 std::copy(fname.begin(), fname.end(), bdata->fname);
305 bdata->fname[fname.length()] = 0;
306 _state.push_trampoline(class_bind_trampoline, 1);
307 _state.rawset(-3);
308 _state.pop(1);
310 protected:
311 void register_state(state& L)
313 static char once_key;
314 register_static(L);
315 if(L.do_once(&once_key))
316 for(auto i : cmethods)
317 bind(L, i.name, i.fn);
319 public:
321 * Create a new Lua class.
323 * Parameter _group: The group the class will be in.
324 * Parameter _name: The name of the class.
325 * Parameter _smethods: Static methods of the class.
326 * Parameter _cmethods: Class methods of the class.
327 * Parameter _print: The print method.
329 _class(class_group& _group, const std::string& _name, std::initializer_list<static_method> _smethods,
330 std::initializer_list<class_method<T>> _cmethods = {}, std::string (T::*_print)() = NULL)
331 : class_base(_group, _name), smethods(_smethods), cmethods(_cmethods)
333 name = _name;
334 class_ops m;
335 printmeth = _print;
336 m.is = _class<T>::is;
337 m.name = _class<T>::get_name;
338 m.print = _class<T>::print;
339 userdata_recogn_fns().push_back(m);
340 auto& type = typeid(T);
341 class_types()[type] = this;
342 delayed_register();
345 * Dtor
347 ~_class() throw()
349 auto& type = typeid(T);
350 class_types().erase(type);
351 auto& fns = userdata_recogn_fns();
352 for(auto i = fns.begin(); i != fns.end(); i++) {
353 if(i->is == _class<T>::is) {
354 fns.erase(i);
355 break;
360 * Create a new instance of object.
362 * Parameter _state: The Lua state to create the object in.
363 * Parameter args: The arguments to pass to class constructor.
365 template<typename... U> static T* create(state& _state, U... args)
367 return objclass<T>()._create(_state, args...);
371 * Get a pointer to the object.
373 * Parameter _state: The Lua state.
374 * Parameter arg: Argument index.
375 * Parameter fname: The name of function for error messages.
376 * Parameter optional: If true and argument is NIL or none, return NULL.
377 * Throws std::runtime_error: Wrong type.
379 static T* get(state& _state, int arg, const std::string& fname, bool optional = false)
381 return objclass<T>()._get(_state, arg, fname, optional);
385 * Identify if object is of this type.
387 * Parameter _state: The Lua state.
388 * Parameter arg: Argument index.
389 * Returns: True if object is of specified type, false if not.
391 static bool is(state& _state, int arg) throw()
393 try {
394 return objclass<T>()._is(_state, arg);
395 } catch(...) {
396 return false;
400 * Get name of class.
402 static const std::string& get_name()
404 try {
405 return objclass<T>().name;
406 } catch(...) {
407 static std::string foo = "???";
408 return foo;
412 * Format instance of this class as string.
414 static std::string print(state& _state, int index)
416 T* obj = get(_state, index, "__internal_print");
417 try {
418 auto pmeth = objclass<T>().printmeth;
419 if(pmeth)
420 return (obj->*pmeth)();
421 else
422 return "";
423 } catch(...) {
424 return "";
428 * Get a pin of object against Lua GC.
430 * Parameter _state: The Lua state.
431 * Parameter arg: Argument index.
432 * Parameter fname: Name of function for error message purposes.
433 * Throws std::runtime_error: Wrong type.
435 static objpin<T> pin(state& _state, int arg, const std::string& fname)
437 return objclass<T>()._pin(_state, arg, fname);
440 * Lookup static methods.
442 std::list<static_method> static_methods()
444 return smethods;
447 * Lookup class methods.
449 std::set<std::string> class_methods()
451 std::set<std::string> r;
452 for(auto& i : cmethods)
453 r.insert(i.name);
454 return r;
456 private:
457 static int dogc(state& L)
459 T* obj = reinterpret_cast<T*>(L.touserdata(1));
460 obj->~T();
461 return 0;
464 static int newindex(state& L)
466 throw std::runtime_error("Writing metatables of classes is not allowed");
469 static int index(state& L)
471 L.getmetatable(1);
472 L.pushvalue(2);
473 L.rawget(-2);
474 if(L.type(-1) == LUA_TNIL) {
475 std::string err = std::string("Class '") + L.tostring(L.trampoline_upval(1)) +
476 "' does not have class method '" + L.tostring(2) + "'";
477 throw std::runtime_error(err);
479 return 1;
482 void load_metatable(state& _state)
484 again:
485 _state.pushlightuserdata(this);
486 _state.rawget(LUA_REGISTRYINDEX);
487 if(_state.type(-1) == LUA_TNIL) {
488 _state.pop(1);
489 _state.pushlightuserdata(this);
490 _state.newtable();
491 _state.pushvalue(-1);
492 _state.setmetatable(-2);
493 _state.pushstring("__gc");
494 _state.push_trampoline(&_class<T>::dogc, 0);
495 _state.rawset(-3);
496 _state.pushstring("__newindex");
497 _state.push_trampoline(&_class<T>::newindex, 0);
498 _state.rawset(-3);
499 _state.pushstring("__index");
500 _state.pushlstring(name);
501 _state.push_trampoline(&_class<T>::index, 1);
502 _state.rawset(-3);
503 _state.rawset(LUA_REGISTRYINDEX);
504 goto again;
507 std::string name;
508 std::list<static_method> smethods;
509 std::list<class_method<T>> cmethods;
510 std::string (T::*printmeth)();
511 _class(const _class<T>&);
512 _class& operator=(const _class<T>&);
516 #endif