Implemented locking so the threading version works too.
[blynken.git] / server / stackless_socket.py
blob 772485858819f96179f7d515d0c6bf9f9ade4f33
2 # Stackless compatible socket module:
4 # Author: Richard Tew <richard.m.tew@gmail.com>
6 # This code was written to serve as an example of Stackless Python usage.
7 # Feel free to email me with any questions, comments, or suggestions for
8 # improvement.
# This wraps the asyncore module and the dispatcher class it provides in order
# to write a socket module replacement that uses channels to allow calls to it
# to block until a delayed event occurs.
14 # Not all aspects of the socket module are provided by this file. Examples of
15 # it in use can be seen at the bottom of this file.
# NOTE: Versions of the asyncore module from Python 2.4 or later include bug
# fixes; earlier versions will not guarantee correct behaviour. Specifically,
# the newer asyncore monitors sockets for errors, which the version in
# Python 2.3.3 does not.
# Possible improvements:
# - More correct error handling. When poll finds an error on a socket, there
#   is no way to tell what the error actually is.
# - Launching each bit of incoming data in its own tasklet on the recvChannel
#   send is a little over the top. It should be possible to add it to the
#   rest of the queued data.
30 import stackless
31 import asyncore, weakref
32 import socket as stdsocket # We need the "socket" name for the function we export.
# If we are to masquerade as the socket module, we need to provide the constants.
if "__all__" in stdsocket.__dict__:
    # Export exactly the names the real socket module exports. This must be
    # the module's __all__ list: using stdsocket.__dict__ here made every
    # name in the real module (including __name__, __file__ and __doc__)
    # count as exported. Those names were then copied over this module's own
    # globals, which renamed this module to "socket" and stopped the
    # self-test at the bottom of the file from ever running.
    __all__ = list(stdsocket.__all__)
    for k, v in stdsocket.__dict__.items():
        # EBADF is not part of socket.__all__, but this module needs it.
        if k in __all__ or k == "EBADF":
            globals()[k] = v
else:
    # Socket modules without __all__: copy the upper-case constants.
    for k, v in stdsocket.__dict__.items():
        if k.upper() == k:
            globals()[k] = v

error = stdsocket.error
timeout = stdsocket.timeout
48 # WARNING: this function blocks and is not thread safe.
49 # The only solution is to spawn a thread to handle all
50 # getaddrinfo requests. Implementing a stackless DNS
51 # lookup service is only second best as getaddrinfo may
52 # use other methods.
53 getaddrinfo = stdsocket.getaddrinfo
55 # urllib2 apparently uses this directly. We need to cater for that.
56 _fileobject = stdsocket._fileobject
58 # Someone needs to invoke asyncore.poll() regularly to keep the socket
59 # data moving. The "ManageSockets" function here is a simple example
60 # of such a function. It is started by StartManager(), which uses the
61 # global "managerRunning" to ensure that no more than one copy is
62 # running.
64 # If you think you can do this better, register an alternative to
65 # StartManager using stacklesssocket_manager(). Your function will be
66 # called every time a new socket is created; it's your responsibility
67 # to ensure it doesn't start multiple copies of itself unnecessarily.
# True while the default ManageSockets tasklet is alive (see StartManager).
managerRunning = False


def ManageSockets():
    """Poll asyncore repeatedly until no sockets remain registered.

    Intended to run as a tasklet started by StartManager(). Each iteration
    polls asyncore once (up to 0.05 seconds) and then yields to the other
    tasklets. managerRunning is cleared when the loop ends, including when
    it ends because of an exception, so that StartManager() can launch a
    new manager later.
    """
    global managerRunning

    try:
        while asyncore.socket_map:
            # Check the sockets for activity.
            asyncore.poll(0.05)
            # Yield to give other tasklets a chance to be scheduled.
            stackless.schedule()
    finally:
        # Reset even if poll() raised or the tasklet was killed. Otherwise
        # managerRunning would stay True and StartManager() would never
        # start another manager, stalling all socket traffic.
        managerRunning = False
def StartManager():
    """Launch the ManageSockets tasklet unless one is already running."""
    global managerRunning
    if managerRunning:
        return
    managerRunning = True
    stackless.tasklet(ManageSockets)()
# Called by _socketobject_new every time it creates a brand new socket.
_manage_sockets_func = StartManager


def stacklesssocket_manager(mgr):
    """Register mgr as the callable used instead of StartManager.

    mgr is called with no arguments whenever a new socket is created. It is
    responsible for not starting redundant copies of itself.
    """
    global _manage_sockets_func
    _manage_sockets_func = mgr
def socket(*args, **kwargs):
    """Guard against this module being swapped in for 'socket' by hand.

    Despite its name, this creates nothing. It raises RuntimeError when
    sys.modules["socket"] holds something other than the real socket
    module, and otherwise returns None. install() stores this function as
    socket._realsocket and compares against it to tell whether the patch
    is already in place.
    """
    import sys as _sys
    # Defaulting to stdsocket means "not registered at all" passes the check,
    # while any registered value other than the real module (even None) fails.
    registered = _sys.modules.get("socket", stdsocket)
    if registered is not stdsocket:
        raise RuntimeError("Use 'stacklesssocket.install' instead of replacing the 'socket' module")
# The real socket module's classes, captured before install() replaces them.
_realsocket_old = stdsocket._realsocket
_socketobject_old = stdsocket._socketobject


class _socketobject_new(_socketobject_old):
    """socket.socket replacement whose underlying _sock is a _fakesocket."""

    def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None):
        # We need to do this here.
        if _sock is None:
            # A brand new socket: build the OS-level socket with the original
            # class, wrap it, and make sure something is polling asyncore.
            _sock = _fakesocket(_realsocket_old(family, type, proto))
            _manage_sockets_func()
        _socketobject_old.__init__(self, family, type, proto, _sock)
        if not isinstance(self._sock, _fakesocket):
            raise RuntimeError("bad socket")

    def accept(self):
        conn, peer = self._sock.accept()
        wrapped = _fakesocket(conn)
        # An accepted socket starts out connected.
        wrapped.wasConnected = True
        return _socketobject_new(_sock=wrapped), peer

    accept.__doc__ = _socketobject_old.accept.__doc__
def check_still_connected(f):
    """Decorate socket methods so they fail fast on unconnected sockets.

    The wrapped method runs only while ``self.connected`` is true. Otherwise
    it raises ``error(10057, "Socket is not connected")`` if the socket was
    never connected (``self.wasConnected`` is false). If the socket was
    connected and has since been closed, it raises
    ``error(EBADF, 'Bad file descriptor')``. The wrapper keeps the wrapped
    method's name and docstring.
    """
    import functools

    @functools.wraps(f)
    def new_f(self, *args, **kwds):
        if not self.connected:
            # The socket was never connected. (10057 appears to be the
            # Winsock WSAENOTCONN value -- NOTE(review): POSIX uses a
            # different ENOTCONN number; confirm this is intended.)
            if not self.wasConnected:
                raise error(10057, "Socket is not connected")
            # The socket has been closed already.
            raise error(EBADF, 'Bad file descriptor')
        return f(self, *args, **kwds)
    return new_f
def install():
    """Monkey-patch the real socket module to hand out stackless sockets.

    Raises StandardError if the patch is already in place.
    """
    already_patched = stdsocket._realsocket is socket
    if already_patched:
        raise StandardError("Still installed")
    stdsocket._realsocket = socket
    replacement = _socketobject_new
    stdsocket.socket = replacement
    stdsocket.SocketType = replacement
    stdsocket._socketobject = replacement
def uninstall():
    """Restore the socket module's original classes saved at import time."""
    original = _socketobject_old
    stdsocket._realsocket = _realsocket_old
    stdsocket.socket = original
    stdsocket.SocketType = original
    stdsocket._socketobject = original
class _fakesocket(asyncore.dispatcher):
    """asyncore dispatcher that lets tasklets block on socket operations.

    Blocking calls (accept, connect, sendto, recv, recvfrom) wait on stackless
    channels. The asyncore handle_* callbacks, run by whatever polls asyncore
    (see ManageSockets), deliver results or errors on those channels. Stream
    output is buffered in sendBuffer and datagram output in sendToBuffers
    until handle_write sends it.
    """

    # accept() and connect() create their channels on first use; __init__
    # creates recvChannel.
    connectChannel = None
    acceptChannel = None
    recvChannel = None
    # Set once the socket has been connected (or was produced by accept()),
    # so a closed socket can be told apart from a never-connected one.
    wasConnected = False

    def __init__(self, realSocket):
        # This is worth doing. I was passing in an invalid socket which
        # was an instance of _fakesocket and it was causing tasklet death.
        if not isinstance(realSocket, _realsocket_old):
            raise StandardError("An invalid socket passed to fakesocket %s" % realSocket.__class__)

        # This will register the real socket in the internal socket map.
        asyncore.dispatcher.__init__(self, realSocket)
        self.socket = realSocket

        self.recvChannel = stackless.channel()
        # Stream data received but not yet returned by recv(), and the
        # read position within it.
        self.readString = ''
        self.readIdx = 0

        # Pending stream output, and pending datagrams stored as
        # (data, address, channel, bytesSentSoFar) tuples.
        self.sendBuffer = ''
        self.sendToBuffers = []

    def __del__(self):
        # There are no more users (sockets or files) of this fake socket, we
        # are safe to close it fully. If we don't, asyncore will choke on
        # the weakref failures.
        self.close()

    # The asyncore version of this function depends on socket being set,
    # which is not the case when this fake socket has been closed.
    def __getattr__(self, attr):
        # Only called for attributes not found the normal way. Read "socket"
        # straight from the instance dictionary. The previous
        # hasattr(self, "socket") re-entered __getattr__ whenever socket was
        # unset and recursed until the interpreter's recursion limit.
        realSocket = self.__dict__.get("socket")
        if realSocket is None:
            raise AttributeError("socket attribute unset on '"+ attr +"' lookup")
        return getattr(realSocket, attr)

    def add_channel(self, map=None):
        # Register only a weak proxy, so the asyncore map does not keep this
        # object alive and __del__ can close it once it is unreferenced.
        if map is None:
            map = self._map
        map[self._fileno] = weakref.proxy(self)

    def writable(self):
        # Unconnected stream sockets always ask for write events
        # (presumably so asyncore can report completion of a pending
        # connect()). Otherwise only ask when output is queued.
        if self.socket.type != SOCK_DGRAM and not self.connected:
            return True
        return len(self.sendBuffer) or len(self.sendToBuffers)

    def accept(self):
        # Block until handle_accept() delivers a (socket, address) pair.
        if not self.acceptChannel:
            self.acceptChannel = stackless.channel()
        return self.acceptChannel.receive()

    def connect(self, address):
        asyncore.dispatcher.connect(self, address)
        # UDP sockets do not connect.
        if self.socket.type != SOCK_DGRAM and not self.connected:
            if not self.connectChannel:
                self.connectChannel = stackless.channel()
                # Prefer the sender. Do not block when sending, given that
                # there is a tasklet known to be waiting, this will happen.
                self.connectChannel.preference = 1
            # Block until handle_connect() signals, or close() sends an error.
            self.connectChannel.receive()

    @check_still_connected
    def send(self, data, flags=0):
        # Queue the data for handle_write and report it all as sent.
        self.sendBuffer += data
        return len(data)

    @check_still_connected
    def sendall(self, data, flags=0):
        # WARNING: this will busy wait until all data is sent
        # It should be possible to do away with the busy wait with
        # the use of a channel.
        self.sendBuffer += data
        while self.sendBuffer:
            stackless.schedule()
        return len(data)

    def sendto(self, sendData, sendArg1=None, sendArg2=None):
        # sendto(data, address)
        # sendto(data [, flags], address)
        # The flags argument is accepted for compatibility but not used.
        if sendArg2 is not None:
            sendAddress = sendArg2
        else:
            sendAddress = sendArg1

        # Queue every call as its own datagram with its own wait channel.
        # The previous code merged data for an address that was already
        # queued into that entry. That joined separate datagrams together,
        # and because handle_write wakes the entry's channel only once, every
        # caller but one stayed blocked forever.
        waitChannel = stackless.channel()
        self.sendToBuffers.append((sendData, sendAddress, waitChannel, 0))
        # Block until handle_write reports how many bytes were sent.
        return waitChannel.receive()

    # Read at most byteCount bytes.
    def recv(self, byteCount, flags=0):
        # recv() must not concatenate two or more data fragments sent with
        # send() on the remote side. Single fragment sent with single send()
        # call should be split into strings of length less than or equal
        # to 'byteCount', and returned by one or more recv() calls.

        remainingBytes = self.readIdx != len(self.readString)
        # TODO: Verify this connectivity behaviour.

        if not self.connected:
            # Sockets which have never been connected do this.
            if not self.wasConnected:
                raise error(10057, 'Socket is not connected')
            # Sockets which were connected, but no longer are, use
            # up the remaining input. Observed this with urllib.urlopen
            # where it closes the socket and then allows the caller to
            # use a file to access the body of the web page.
        elif not remainingBytes:
            # Connected with nothing buffered: wait for the next fragment.
            self.readString = self.recvChannel.receive()
            self.readIdx = 0
            remainingBytes = len(self.readString)

        if byteCount == 1 and remainingBytes:
            # Single-byte reads advance an index instead of re-slicing.
            ret = self.readString[self.readIdx]
            self.readIdx += 1
        elif self.readIdx == 0 and byteCount >= len(self.readString):
            # The whole buffered fragment fits: hand it over as-is.
            ret = self.readString
            self.readString = ""
        else:
            idx = self.readIdx + byteCount
            ret = self.readString[self.readIdx:idx]
            self.readString = self.readString[idx:]
            self.readIdx = 0

        # ret will be '' when EOF.
        return ret

    def recvfrom(self, byteCount, flags=0):
        if self.socket.type == SOCK_STREAM:
            return self.recv(byteCount), None

        # recvfrom() must not concatenate two or more packets.
        # Each call should return the first 'byteCount' part of the packet.
        data, address = self.recvChannel.receive()
        return data[:byteCount], address

    def close(self):
        asyncore.dispatcher.close(self)

        self.connected = False
        self.accepting = False
        self.sendBuffer = None  # breaks the loop in sendall

        # Clear out all the channels with relevant errors.
        while self.acceptChannel and self.acceptChannel.balance < 0:
            self.acceptChannel.send_exception(error, EBADF, 'Bad file descriptor')
        while self.connectChannel and self.connectChannel.balance < 0:
            # 10061 appears to be the Winsock WSAECONNREFUSED value.
            self.connectChannel.send_exception(error, 10061, 'Connection refused')
        while self.recvChannel and self.recvChannel.balance < 0:
            # The closing of a socket is indicated by receiving nothing. The
            # exception (error 10054, 'Connection reset by peer') would only
            # be appropriate if the server was killed rather than closed
            # down gracefully.
            self.recvChannel.send("")

    # asyncore doesn't support this. Why not?
    def fileno(self):
        return self.socket.fileno()

    def handle_accept(self):
        # Only accept when an accept() call is actually waiting.
        if self.acceptChannel and self.acceptChannel.balance < 0:
            t = asyncore.dispatcher.accept(self)
            if t is None:
                return
            t[0].setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
            stackless.tasklet(self.acceptChannel.send)(t)

    # Inform the blocked connect call that the connection has been made.
    def handle_connect(self):
        if self.socket.type != SOCK_DGRAM:
            self.wasConnected = True
            # Only wake connect() if it is actually blocked on the channel.
            # asyncore may call this from inside dispatcher.connect() when
            # the connection completes at once, before connect() has made
            # the channel (NOTE(review): depends on the asyncore version).
            # Sending then raised AttributeError on the missing channel.
            # Sending with no waiter would block the polling tasklet.
            channel = self.connectChannel
            if channel is not None and channel.balance < 0:
                channel.send(None)

    # Asyncore says its done but self.readBuffer may be non-empty
    # so can't close yet. Do nothing and let 'recv' trigger the close.
    def handle_close(self):
        pass

    # Some error, just close the channel and let that raise errors to
    # blocked calls.
    def handle_expt(self):
        self.close()

    def handle_read(self):
        try:
            if self.socket.type == SOCK_DGRAM:
                ret = self.socket.recvfrom(20000)
            else:
                ret = asyncore.dispatcher.recv(self, 20000)
                # Not sure this is correct, but it seems to give the
                # right behaviour. Namely removing the socket from
                # asyncore.
                if not ret:
                    self.close()
            stackless.tasklet(self.recvChannel.send)(ret)
        except stdsocket.error as err:
            # If there's a read error assume the connection is
            # broken and drop any pending output
            if self.sendBuffer:
                self.sendBuffer = ""
            # Deliver the error from a new tasklet, as the success path
            # does. Sending directly would block the polling tasklet until
            # some tasklet called recv().
            stackless.tasklet(self.recvChannel.send_exception)(stdsocket.error, err)

    def handle_write(self):
        if len(self.sendBuffer):
            # Stream output: send at most 512 bytes per write event.
            sentBytes = asyncore.dispatcher.send(self, self.sendBuffer[:512])
            self.sendBuffer = self.sendBuffer[sentBytes:]
        elif len(self.sendToBuffers):
            data, address, channel, oldSentBytes = self.sendToBuffers[0]
            sentBytes = self.socket.sendto(data, address)
            totalSentBytes = oldSentBytes + sentBytes
            if len(data) > sentBytes:
                self.sendToBuffers[0] = data[sentBytes:], address, channel, totalSentBytes
            else:
                del self.sendToBuffers[0]
                # Wake the sendto() caller from a new tasklet so the polling
                # tasklet never blocks on the channel.
                stackless.tasklet(channel.send)(totalSentBytes)
375 if __name__ == '__main__':
376 import sys
377 import struct
378 # Test code goes here.
379 testAddress = "127.0.0.1", 3000
380 info = -12345678
381 data = struct.pack("i", info)
382 dataLength = len(data)
384 def TestTCPServer(address):
385 global info, data, dataLength
387 print "server listen socket creation"
388 listenSocket = stdsocket.socket(AF_INET, SOCK_STREAM)
389 listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
390 listenSocket.bind(address)
391 listenSocket.listen(5)
393 NUM_TESTS = 2
395 i = 1
396 while i < NUM_TESTS + 1:
397 # No need to schedule this tasklet as the accept should yield most
398 # of the time on the underlying channel.
399 print "server connection wait", i
400 currentSocket, clientAddress = listenSocket.accept()
401 print "server", i, "listen socket", currentSocket.fileno(), "from", clientAddress
403 if i == 1:
404 print "server closing (a)", i, "fd", currentSocket.fileno(), "id", id(currentSocket)
405 currentSocket.close()
406 print "server closed (a)", i
407 elif i == 2:
408 print "server test", i, "send"
409 currentSocket.send(data)
410 print "server test", i, "recv"
411 if currentSocket.recv(4) != "":
412 print "server recv(1)", i, "FAIL"
413 break
414 # multiple empty recvs are fine
415 if currentSocket.recv(4) != "":
416 print "server recv(2)", i, "FAIL"
417 break
418 else:
419 print "server closing (b)", i, "fd", currentSocket.fileno(), "id", id(currentSocket)
420 currentSocket.close()
422 print "server test", i, "OK"
423 i += 1
425 if i != NUM_TESTS+1:
426 print "server: FAIL", i
427 else:
428 print "server: OK", i
430 print "Done server"
432 def TestTCPClient(address):
433 global info, data, dataLength
435 # Attempt 1:
436 clientSocket = stdsocket.socket()
437 clientSocket.connect(address)
438 print "client connection (1) fd", clientSocket.fileno(), "id", id(clientSocket._sock), "waiting to recv"
439 if clientSocket.recv(5) != "":
440 print "client test", 1, "FAIL"
441 else:
442 print "client test", 1, "OK"
444 # Attempt 2:
445 clientSocket = stdsocket.socket()
446 clientSocket.connect(address)
447 print "client connection (2) fd", clientSocket.fileno(), "id", id(clientSocket._sock), "waiting to recv"
448 s = clientSocket.recv(dataLength)
449 if s == "":
450 print "client test", 2, "FAIL (disconnect)"
451 else:
452 t = struct.unpack("i", s)
453 if t[0] == info:
454 print "client test", 2, "OK"
455 else:
456 print "client test", 2, "FAIL (wrong data)"
458 print "client exit"
460 def TestMonkeyPatchUrllib(uri):
461 # replace the system socket with this module
462 #oldSocket = sys.modules["socket"]
463 #sys.modules["socket"] = __import__(__name__)
464 install()
465 try:
466 import urllib # must occur after monkey-patching!
467 f = urllib.urlopen(uri)
468 if not isinstance(f.fp._sock, _fakesocket):
469 raise AssertionError("failed to apply monkeypatch, got %s" % f.fp._sock.__class__)
470 s = f.read()
471 if len(s) != 0:
472 print "Fetched", len(s), "bytes via replaced urllib"
473 else:
474 raise AssertionError("no text received?")
475 finally:
476 #sys.modules["socket"] = oldSocket
477 uninstall()
479 def TestMonkeyPatchUDP(address):
480 # replace the system socket with this module
481 #oldSocket = sys.modules["socket"]
482 #sys.modules["socket"] = __import__(__name__)
483 install()
484 try:
485 def UDPServer(address):
486 listenSocket = stdsocket.socket(AF_INET, SOCK_DGRAM)
487 listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
488 listenSocket.bind(address)
490 # Apparently each call to recvfrom maps to an incoming
491 # packet and if we only ask for part of that packet, the
492 # rest is lost. We really need a proper unittest suite
493 # which tests this module against the normal socket
494 # module.
495 print "waiting to receive"
496 data, address = listenSocket.recvfrom(256)
497 print "received", data, len(data)
498 if len(data) != 256:
499 raise StandardError("Unexpected UDP packet size")
501 def UDPClient(address):
502 clientSocket = stdsocket.socket(AF_INET, SOCK_DGRAM)
503 # clientSocket.connect(address)
504 print "sending 512 byte packet"
505 sentBytes = clientSocket.sendto("-"+ ("*" * 510) +"-", address)
506 print "sent 512 byte packet", sentBytes
508 stackless.tasklet(UDPServer)(address)
509 stackless.tasklet(UDPClient)(address)
510 stackless.run()
511 finally:
512 #sys.modules["socket"] = oldSocket
513 uninstall()
515 if len(sys.argv) == 2:
516 if sys.argv[1] == "client":
517 print "client started"
518 TestTCPClient(testAddress)
519 print "client exited"
520 elif sys.argv[1] == "slpclient":
521 print "client started"
522 stackless.tasklet(TestTCPClient)(testAddress)
523 stackless.run()
524 print "client exited"
525 elif sys.argv[1] == "server":
526 print "server started"
527 TestTCPServer(testAddress)
528 print "server exited"
529 elif sys.argv[1] == "slpserver":
530 print "server started"
531 stackless.tasklet(TestTCPServer)(testAddress)
532 stackless.run()
533 print "server exited"
534 else:
535 print "Usage:", sys.argv[0], "[client|server|slpclient|slpserver]"
537 sys.exit(1)
538 else:
539 print "* Running client/server test"
540 install()
541 try:
542 stackless.tasklet(TestTCPServer)(testAddress)
543 stackless.tasklet(TestTCPClient)(testAddress)
544 stackless.run()
545 finally:
546 uninstall()
548 print "* Running urllib test"
549 stackless.tasklet(TestMonkeyPatchUrllib)("http://python.org/")
550 stackless.run()
552 print "* Running udp test"
553 TestMonkeyPatchUDP(testAddress)
555 print "result: SUCCESS"