1 -----------------------------------------------------------------------------
2 -- A hacked dispatcher module
3 -- LuaSocket sample files
5 -----------------------------------------------------------------------------
7 local table = require("table")
8 local socket
= require("socket")
9 local coroutine
= require("coroutine")
12 -- if too much time goes by without any activity in one of our sockets, we
16 -----------------------------------------------------------------------------
17 -- We implement 3 types of dispatchers:
21 -- The user can choose whatever one is needed
22 -----------------------------------------------------------------------------
25 -- default handler is coroutine
26 function newhandler(mode
)
27 mode
= mode
or "coroutine"
28 return handlert
[mode
]()
31 local function seqstart(self
, func
)
35 -- sequential handler simply calls the functions and doesn't wrap I/O
36 function handlert
.sequential()
43 -----------------------------------------------------------------------------
44 -- Mega hack. Don't try to do this at home.
45 -----------------------------------------------------------------------------
46 -- we can't yield across calls to protect, so we rewrite it with coxpcall
47 -- make sure you don't require any module that uses socket.protect before
49 function socket
.protect(f
)
51 local co
= coroutine
.create(f
)
53 local results
= {coroutine
.resume(co
, base
.unpack(arg
))}
54 local status
= table.remove(results
, 1)
56 if base
.type(results
[1]) == 'table' then
57 return nil, results
[1][1]
58 else base
.error(results
[1]) end
60 if coroutine
.status(co
) == "suspended" then
61 arg
= {coroutine
.yield(base
.unpack(results
))}
63 return base
.unpack(results
)
69 -----------------------------------------------------------------------------
70 -- Simple set data structure. O(1) everything.
71 -----------------------------------------------------------------------------
72 local function newset()
75 return base
.setmetatable(set
, {__index
= {
76 insert
= function(set
, value
)
77 if not reverse
[value
] then
78 table.insert(set
, value
)
79 reverse
[value
] = table.getn(set
)
82 remove = function(set
, value
)
83 local index
= reverse
[value
]
86 local top
= table.remove(set
)
96 -----------------------------------------------------------------------------
97 -- socket.tcp() wrapper for the coroutine dispatcher
98 -----------------------------------------------------------------------------
99 local function cowrap(dispatcher
, tcp
, error)
100 if not tcp
then return nil, error end
101 -- put it in non-blocking mode right away
103 -- metatable for wrap produces new methods on demand for those that we
104 -- don't override explicitly.
105 local metat
= { __index
= function(table, key
)
106 table[key
] = function(...)
108 return tcp
[key
](base
.unpack(arg
))
112 -- does our user want to do his own non-blocking I/O?
114 -- create a wrap object that will behave just like a real socket object
116 -- we ignore settimeout to preserve our 0 timeout, but record whether
117 -- the user wants to do his own non-blocking I/O
118 function wrap
:settimeout(value
, mode
)
119 if value
== 0 then zero
= true
120 else zero
= false end
123 -- send in non-blocking mode and yield on timeout
124 function wrap
:send(data
, first
, last
)
125 first
= (first
or 1) - 1
128 -- return control to dispatcher and tell it we want to send
129 -- if upon return the dispatcher tells us we timed out,
130 -- return an error to whoever called us
131 if coroutine
.yield(dispatcher
.sending
, tcp
) == "timeout" then
132 return nil, "timeout"
135 result
, error, first
= tcp
:send(data
, first
+1, last
)
136 -- if we are done, or there was an unexpected error,
137 -- break away from loop
138 if error ~= "timeout" then return result
, error, first
end
141 -- receive in non-blocking mode and yield on timeout
142 -- or simply return partial read, if user requested timeout = 0
143 function wrap
:receive(pattern
, partial
)
144 local error = "timeout"
147 -- return control to dispatcher and tell it we want to receive
148 -- if upon return the dispatcher tells us we timed out,
149 -- return an error to whoever called us
150 if coroutine
.yield(dispatcher
.receiving
, tcp
) == "timeout" then
151 return nil, "timeout"
154 value
, error, partial
= tcp
:receive(pattern
, partial
)
155 -- if we are done, or there was an unexpected error,
156 -- break away from loop. also, if the user requested
157 -- zero timeout, return all we got
158 if (error ~= "timeout") or zero
then
159 return value
, error, partial
163 -- connect in non-blocking mode and yield on timeout
164 function wrap
:connect(host
, port
)
165 local result
, error = tcp
:connect(host
, port
)
166 if error == "timeout" then
167 -- return control to dispatcher. we will be writable when
168 -- connection succeeds.
169 -- if upon return the dispatcher tells us we have a
170 -- timeout, just abort
171 if coroutine
.yield(dispatcher
.sending
, tcp
) == "timeout" then
172 return nil, "timeout"
174 -- when we come back, check if connection was successful
175 result
, error = tcp
:connect(host
, port
)
176 if result
or error == "already connected" then return 1
177 else return nil, "non-blocking connect failed" end
178 else return result
, error end
180 -- accept in non-blocking mode and yield on timeout
181 function wrap
:accept()
183 -- return control to dispatcher. we will be readable when a
184 -- connection arrives.
185 -- if upon return the dispatcher tells us we have a
186 -- timeout, just abort
187 if coroutine
.yield(dispatcher
.receiving
, tcp
) == "timeout" then
188 return nil, "timeout"
190 local client
, error = tcp
:accept()
191 if error ~= "timeout" then
192 return cowrap(dispatcher
, client
, error)
196 -- remove cortn from context
197 function wrap
:close()
198 dispatcher
.stamp
[tcp
] = nil
199 dispatcher
.sending
.set
:remove(tcp
)
200 dispatcher
.sending
.cortn
[tcp
] = nil
201 dispatcher
.receiving
.set
:remove(tcp
)
202 dispatcher
.receiving
.cortn
[tcp
] = nil
205 return base
.setmetatable(wrap
, metat
)
209 -----------------------------------------------------------------------------
210 -- Our coroutine dispatcher
211 -----------------------------------------------------------------------------
212 local cometat
= { __index
= {} }
214 function schedule(cortn
, status
, operation
, tcp
)
216 if cortn
and operation
then
217 operation
.set
:insert(tcp
)
218 operation
.cortn
[tcp
] = cortn
219 operation
.stamp
[tcp
] = socket
.gettime()
221 else base
.error(operation
) end
224 function kick(operation
, tcp
)
225 operation
.cortn
[tcp
] = nil
226 operation
.set
:remove(tcp
)
229 function wakeup(operation
, tcp
)
230 local cortn
= operation
.cortn
[tcp
]
231 -- if cortn is still valid, wake it up
234 return cortn
, coroutine
.resume(cortn
)
235 -- othrewise, just get scheduler not to do anything
241 function abort(operation
, tcp
)
242 local cortn
= operation
.cortn
[tcp
]
245 coroutine
.resume(cortn
, "timeout")
249 -- step through all active cortns
250 function cometat
.__index
:step()
251 -- check which sockets are interesting and act on them
252 local readable
, writable
= socket
.select(self
.receiving
.set
,
254 -- for all readable connections, resume their cortns and reschedule
255 -- when they yield back to us
256 for _
, tcp
in base
.ipairs(readable
) do
257 schedule(wakeup(self
.receiving
, tcp
))
259 -- for all writable connections, do the same
260 for _
, tcp
in base
.ipairs(writable
) do
261 schedule(wakeup(self
.sending
, tcp
))
263 -- politely ask replacement I/O functions in idle cortns to
264 -- return reporting a timeout
265 local now
= socket
.gettime()
266 for tcp
, stamp
in base
.pairs(self
.stamp
) do
267 if tcp
.class
== "tcp{client}" and now
- stamp
> TIMEOUT
then
268 abort(self
.sending
, tcp
)
269 abort(self
.receiving
, tcp
)
274 function cometat
.__index
:start(func
)
275 local cortn
= coroutine
.create(func
)
276 schedule(cortn
, coroutine
.resume(cortn
))
279 function handlert
.coroutine()
296 function dispatcher
.tcp()
297 return cowrap(dispatcher
, socket
.tcp())
299 return base
.setmetatable(dispatcher
, cometat
)