2 # Stackless compatible socket module:
4 # Author: Richard Tew <richard.m.tew@gmail.com>
6 # This code was written to serve as an example of Stackless Python usage.
7 # Feel free to email me with any questions, comments, or suggestions for
10 # This wraps the asyncore module and the dispatcher class it provides in order
11 # write a socket module replacement that uses channels to allow calls to it to
12 # block until a delayed event occurs.
14 # Not all aspects of the socket module are provided by this file. Examples of
15 # it in use can be seen at the bottom of this file.
17 # NOTE: Versions of the asyncore module from Python 2.4 or later include bug
18 # fixes and earlier versions will not guarantee correct behaviour.
19 # Specifically, it monitors for errors on sockets where the version in
20 # Python 2.3.3 does not.
23 # Possible improvements:
24 # - More correct error handling. When there is an error on a socket found by
25 # poll, there is no idea what it actually is.
26 # - Launching each bit of incoming data in its own tasklet on the recvChannel
27 # send is a little over the top. It should be possible to add it to the
28 # rest of the queued data
31 import asyncore
, weakref
32 import socket
as stdsocket
# We need the "socket" name for the function we export.
34 # If we are to masquerade as the socket module, we need to provide the constants.
35 if "__all__" in stdsocket
.__dict
__:
36 __all__
= stdsocket
.__dict
__
37 for k
, v
in stdsocket
.__dict
__.iteritems():
43 for k
, v
in stdsocket
.__dict
__.iteritems():
46 error
= stdsocket
.error
47 timeout
= stdsocket
.timeout
48 # WARNING: this function blocks and is not thread safe.
49 # The only solution is to spawn a thread to handle all
50 # getaddrinfo requests. Implementing a stackless DNS
51 # lookup service is only second best as getaddrinfo may
53 getaddrinfo
= stdsocket
.getaddrinfo
55 # urllib2 apparently uses this directly. We need to cater for that.
56 _fileobject
= stdsocket
._fileobject
58 # Someone needs to invoke asyncore.poll() regularly to keep the socket
59 # data moving. The "ManageSockets" function here is a simple example
60 # of such a function. It is started by StartManager(), which uses the
61 # global "managerRunning" to ensure that no more than one copy is
64 # If you think you can do this better, register an alternative to
65 # StartManager using stacklesssocket_manager(). Your function will be
66 # called every time a new socket is created; it's your responsibility
67 # to ensure it doesn't start multiple copies of itself unnecessarily.
70 managerRunning
= False
75 while len(asyncore
.socket_map
):
76 # Check the sockets for activity.
78 # Yield to give other tasklets a chance to be scheduled.
81 managerRunning
= False
85 if not managerRunning
:
87 stackless
.tasklet(ManageSockets
)()
89 _manage_sockets_func
= StartManager
91 def stacklesssocket_manager(mgr
):
92 global _manage_sockets_func
93 _manage_sockets_func
= mgr
95 def socket(*args
, **kwargs
):
97 if "socket" in sys
.modules
and sys
.modules
["socket"] is not stdsocket
:
98 raise RuntimeError("Use 'stacklesssocket.install' instead of replacing the 'socket' module")
100 _realsocket_old
= stdsocket
._realsocket
101 _socketobject_old
= stdsocket
._socketobject
103 class _socketobject_new(_socketobject_old
):
104 def __init__(self
, family
=AF_INET
, type=SOCK_STREAM
, proto
=0, _sock
=None):
105 # We need to do this here.
107 _sock
= _realsocket_old(family
, type, proto
)
108 _sock
= _fakesocket(_sock
)
109 _manage_sockets_func()
110 _socketobject_old
.__init
__(self
, family
, type, proto
, _sock
)
111 if not isinstance(self
._sock
, _fakesocket
):
112 raise RuntimeError("bad socket")
115 sock
, addr
= self
._sock
.accept()
116 sock
= _fakesocket(sock
)
117 sock
.wasConnected
= True
118 return _socketobject_new(_sock
=sock
), addr
120 accept
.__doc
__ = _socketobject_old
.accept
.__doc
__
123 def check_still_connected(f
):
124 " Decorate socket functions to check they are still connected. "
125 def new_f(self
, *args
, **kwds
):
126 if not self
.connected
:
127 # The socket was never connected.
128 if not self
.wasConnected
:
129 raise error(10057, "Socket is not connected")
130 # The socket has been closed already.
131 raise error(EBADF
, 'Bad file descriptor')
132 return f(self
, *args
, **kwds
)
137 if stdsocket
._realsocket
is socket
:
138 raise StandardError("Still installed")
139 stdsocket
._realsocket
= socket
140 stdsocket
.socket
= stdsocket
.SocketType
= stdsocket
._socketobject
= _socketobject_new
143 stdsocket
._realsocket
= _realsocket_old
144 stdsocket
.socket
= stdsocket
.SocketType
= stdsocket
._socketobject
= _socketobject_old
147 class _fakesocket(asyncore
.dispatcher
):
148 connectChannel
= None
153 def __init__(self
, realSocket
):
154 # This is worth doing. I was passing in an invalid socket which
155 # was an instance of _fakesocket and it was causing tasklet death.
156 if not isinstance(realSocket
, _realsocket_old
):
157 raise StandardError("An invalid socket passed to fakesocket %s" % realSocket
.__class
__)
159 # This will register the real socket in the internal socket map.
160 asyncore
.dispatcher
.__init
__(self
, realSocket
)
161 self
.socket
= realSocket
163 self
.recvChannel
= stackless
.channel()
168 self
.sendToBuffers
= []
171 # There are no more users (sockets or files) of this fake socket, we
172 # are safe to close it fully. If we don't, asyncore will choke on
173 # the weakref failures.
176 # The asyncore version of this function depends on socket being set
177 # which is not the case when this fake socket has been closed.
178 def __getattr__(self
, attr
):
179 if not hasattr(self
, "socket"):
180 raise AttributeError("socket attribute unset on '"+ attr
+"' lookup")
181 return getattr(self
.socket
, attr
)
183 def add_channel(self
, map=None):
186 map[self
._fileno
] = weakref
.proxy(self
)
189 if self
.socket
.type != SOCK_DGRAM
and not self
.connected
:
191 return len(self
.sendBuffer
) or len(self
.sendToBuffers
)
194 if not self
.acceptChannel
:
195 self
.acceptChannel
= stackless
.channel()
196 return self
.acceptChannel
.receive()
198 def connect(self
, address
):
199 asyncore
.dispatcher
.connect(self
, address
)
200 # UDP sockets do not connect.
201 if self
.socket
.type != SOCK_DGRAM
and not self
.connected
:
202 if not self
.connectChannel
:
203 self
.connectChannel
= stackless
.channel()
204 # Prefer the sender. Do not block when sending, given that
205 # there is a tasklet known to be waiting, this will happen.
206 self
.connectChannel
.preference
= 1
207 self
.connectChannel
.receive()
209 @check_still_connected
210 def send(self
, data
, flags
=0):
211 self
.sendBuffer
+= data
212 # stackless.schedule()
215 @check_still_connected
216 def sendall(self
, data
, flags
=0):
217 # WARNING: this will busy wait until all data is sent
218 # It should be possible to do away with the busy wait with
219 # the use of a channel.
220 self
.sendBuffer
+= data
221 while self
.sendBuffer
:
225 def sendto(self
, sendData
, sendArg1
=None, sendArg2
=None):
226 # sendto(data, address)
227 # sendto(data [, flags], address)
228 if sendArg2
is not None:
230 sendAddress
= sendArg2
233 sendAddress
= sendArg1
236 for idx
, (data
, address
, channel
, sentBytes
) in enumerate(self
.sendToBuffers
):
237 if address
== sendAddress
:
238 self
.sendToBuffers
[idx
] = (data
+ sendData
, address
, channel
, sentBytes
)
239 waitChannel
= channel
241 if waitChannel
is None:
242 waitChannel
= stackless
.channel()
243 self
.sendToBuffers
.append((sendData
, sendAddress
, waitChannel
, 0))
244 return waitChannel
.receive()
246 # Read at most byteCount bytes.
247 def recv(self
, byteCount
, flags
=0):
248 # recv() must not concatenate two or more data fragments sent with
249 # send() on the remote side. Single fragment sent with single send()
250 # call should be split into strings of length less than or equal
251 # to 'byteCount', and returned by one or more recv() calls.
253 remainingBytes
= self
.readIdx
!= len(self
.readString
)
254 # TODO: Verify this connectivity behaviour.
256 if not self
.connected
:
257 # Sockets which have never been connected do this.
258 if not self
.wasConnected
:
259 raise error(10057, 'Socket is not connected')
261 # Sockets which were connected, but no longer are, use
262 # up the remaining input. Observed this with urllib.urlopen
263 # where it closes the socket and then allows the caller to
264 # use a file to access the body of the web page.
265 elif not remainingBytes
:
266 self
.readString
= self
.recvChannel
.receive()
268 remainingBytes
= len(self
.readString
)
270 if byteCount
== 1 and remainingBytes
:
271 ret
= self
.readString
[self
.readIdx
]
273 elif self
.readIdx
== 0 and byteCount
>= len(self
.readString
):
274 ret
= self
.readString
277 idx
= self
.readIdx
+ byteCount
278 ret
= self
.readString
[self
.readIdx
:idx
]
279 self
.readString
= self
.readString
[idx
:]
282 # ret will be '' when EOF.
285 def recvfrom(self
, byteCount
, flags
=0):
286 if self
.socket
.type == SOCK_STREAM
:
287 return self
.recv(byteCount
), None
289 # recvfrom() must not concatenate two or more packets.
290 # Each call should return the first 'byteCount' part of the packet.
291 data
, address
= self
.recvChannel
.receive()
292 return data
[:byteCount
], address
295 asyncore
.dispatcher
.close(self
)
297 self
.connected
= False
298 self
.accepting
= False
299 self
.sendBuffer
= None # breaks the loop in sendall
301 # Clear out all the channels with relevant errors.
302 while self
.acceptChannel
and self
.acceptChannel
.balance
< 0:
303 self
.acceptChannel
.send_exception(error
, 9, 'Bad file descriptor')
304 while self
.connectChannel
and self
.connectChannel
.balance
< 0:
305 self
.connectChannel
.send_exception(error
, 10061, 'Connection refused')
306 while self
.recvChannel
and self
.recvChannel
.balance
< 0:
307 # The closing of a socket is indicted by receiving nothing. The
308 # exception would have been sent if the server was killed, rather
309 # than closed down gracefully.
310 self
.recvChannel
.send("")
311 #self.recvChannel.send_exception(error, 10054, 'Connection reset by peer')
313 # asyncore doesn't support this. Why not?
315 return self
.socket
.fileno()
317 def handle_accept(self
):
318 if self
.acceptChannel
and self
.acceptChannel
.balance
< 0:
319 t
= asyncore
.dispatcher
.accept(self
)
322 t
[0].setsockopt(SOL_SOCKET
, SO_REUSEADDR
, 1)
323 stackless
.tasklet(self
.acceptChannel
.send
)(t
)
325 # Inform the blocked connect call that the connection has been made.
326 def handle_connect(self
):
327 if self
.socket
.type != SOCK_DGRAM
:
328 self
.wasConnected
= True
329 self
.connectChannel
.send(None)
331 # Asyncore says its done but self.readBuffer may be non-empty
332 # so can't close yet. Do nothing and let 'recv' trigger the close.
333 def handle_close(self
):
336 # Some error, just close the channel and let that raise errors to
338 def handle_expt(self
):
341 def handle_read(self
):
343 if self
.socket
.type == SOCK_DGRAM
:
344 ret
= self
.socket
.recvfrom(20000)
346 ret
= asyncore
.dispatcher
.recv(self
, 20000)
347 # Not sure this is correct, but it seems to give the
348 # right behaviour. Namely removing the socket from
352 stackless
.tasklet(self
.recvChannel
.send
)(ret
)
353 except stdsocket
.error
, err
:
354 # If there's a read error assume the connection is
355 # broken and drop any pending output
358 self
.recvChannel
.send_exception(stdsocket
.error
, err
)
360 def handle_write(self
):
361 if len(self
.sendBuffer
):
362 sentBytes
= asyncore
.dispatcher
.send(self
, self
.sendBuffer
[:512])
363 self
.sendBuffer
= self
.sendBuffer
[sentBytes
:]
364 elif len(self
.sendToBuffers
):
365 data
, address
, channel
, oldSentBytes
= self
.sendToBuffers
[0]
366 sentBytes
= self
.socket
.sendto(data
, address
)
367 totalSentBytes
= oldSentBytes
+ sentBytes
368 if len(data
) > sentBytes
:
369 self
.sendToBuffers
[0] = data
[sentBytes
:], address
, channel
, totalSentBytes
371 del self
.sendToBuffers
[0]
372 stackless
.tasklet(channel
.send
)(totalSentBytes
)
375 if __name__
== '__main__':
378 # Test code goes here.
379 testAddress
= "127.0.0.1", 3000
381 data
= struct
.pack("i", info
)
382 dataLength
= len(data
)
384 def TestTCPServer(address
):
385 global info
, data
, dataLength
387 print "server listen socket creation"
388 listenSocket
= stdsocket
.socket(AF_INET
, SOCK_STREAM
)
389 listenSocket
.setsockopt(SOL_SOCKET
, SO_REUSEADDR
, 1)
390 listenSocket
.bind(address
)
391 listenSocket
.listen(5)
396 while i
< NUM_TESTS
+ 1:
397 # No need to schedule this tasklet as the accept should yield most
398 # of the time on the underlying channel.
399 print "server connection wait", i
400 currentSocket
, clientAddress
= listenSocket
.accept()
401 print "server", i
, "listen socket", currentSocket
.fileno(), "from", clientAddress
404 print "server closing (a)", i
, "fd", currentSocket
.fileno(), "id", id(currentSocket
)
405 currentSocket
.close()
406 print "server closed (a)", i
408 print "server test", i
, "send"
409 currentSocket
.send(data
)
410 print "server test", i
, "recv"
411 if currentSocket
.recv(4) != "":
412 print "server recv(1)", i
, "FAIL"
414 # multiple empty recvs are fine
415 if currentSocket
.recv(4) != "":
416 print "server recv(2)", i
, "FAIL"
419 print "server closing (b)", i
, "fd", currentSocket
.fileno(), "id", id(currentSocket
)
420 currentSocket
.close()
422 print "server test", i
, "OK"
426 print "server: FAIL", i
428 print "server: OK", i
432 def TestTCPClient(address
):
433 global info
, data
, dataLength
436 clientSocket
= stdsocket
.socket()
437 clientSocket
.connect(address
)
438 print "client connection (1) fd", clientSocket
.fileno(), "id", id(clientSocket
._sock
), "waiting to recv"
439 if clientSocket
.recv(5) != "":
440 print "client test", 1, "FAIL"
442 print "client test", 1, "OK"
445 clientSocket
= stdsocket
.socket()
446 clientSocket
.connect(address
)
447 print "client connection (2) fd", clientSocket
.fileno(), "id", id(clientSocket
._sock
), "waiting to recv"
448 s
= clientSocket
.recv(dataLength
)
450 print "client test", 2, "FAIL (disconnect)"
452 t
= struct
.unpack("i", s
)
454 print "client test", 2, "OK"
456 print "client test", 2, "FAIL (wrong data)"
460 def TestMonkeyPatchUrllib(uri
):
461 # replace the system socket with this module
462 #oldSocket = sys.modules["socket"]
463 #sys.modules["socket"] = __import__(__name__)
466 import urllib
# must occur after monkey-patching!
467 f
= urllib
.urlopen(uri
)
468 if not isinstance(f
.fp
._sock
, _fakesocket
):
469 raise AssertionError("failed to apply monkeypatch, got %s" % f
.fp
._sock
.__class
__)
472 print "Fetched", len(s
), "bytes via replaced urllib"
474 raise AssertionError("no text received?")
476 #sys.modules["socket"] = oldSocket
479 def TestMonkeyPatchUDP(address
):
480 # replace the system socket with this module
481 #oldSocket = sys.modules["socket"]
482 #sys.modules["socket"] = __import__(__name__)
485 def UDPServer(address
):
486 listenSocket
= stdsocket
.socket(AF_INET
, SOCK_DGRAM
)
487 listenSocket
.setsockopt(SOL_SOCKET
, SO_REUSEADDR
, 1)
488 listenSocket
.bind(address
)
490 # Apparently each call to recvfrom maps to an incoming
491 # packet and if we only ask for part of that packet, the
492 # rest is lost. We really need a proper unittest suite
493 # which tests this module against the normal socket
495 print "waiting to receive"
496 data
, address
= listenSocket
.recvfrom(256)
497 print "received", data
, len(data
)
499 raise StandardError("Unexpected UDP packet size")
501 def UDPClient(address
):
502 clientSocket
= stdsocket
.socket(AF_INET
, SOCK_DGRAM
)
503 # clientSocket.connect(address)
504 print "sending 512 byte packet"
505 sentBytes
= clientSocket
.sendto("-"+ ("*" * 510) +"-", address
)
506 print "sent 512 byte packet", sentBytes
508 stackless
.tasklet(UDPServer
)(address
)
509 stackless
.tasklet(UDPClient
)(address
)
512 #sys.modules["socket"] = oldSocket
515 if len(sys
.argv
) == 2:
516 if sys
.argv
[1] == "client":
517 print "client started"
518 TestTCPClient(testAddress
)
519 print "client exited"
520 elif sys
.argv
[1] == "slpclient":
521 print "client started"
522 stackless
.tasklet(TestTCPClient
)(testAddress
)
524 print "client exited"
525 elif sys
.argv
[1] == "server":
526 print "server started"
527 TestTCPServer(testAddress
)
528 print "server exited"
529 elif sys
.argv
[1] == "slpserver":
530 print "server started"
531 stackless
.tasklet(TestTCPServer
)(testAddress
)
533 print "server exited"
535 print "Usage:", sys
.argv
[0], "[client|server|slpclient|slpserver]"
539 print "* Running client/server test"
542 stackless
.tasklet(TestTCPServer
)(testAddress
)
543 stackless
.tasklet(TestTCPClient
)(testAddress
)
548 print "* Running urllib test"
549 stackless
.tasklet(TestMonkeyPatchUrllib
)("http://python.org/")
552 print "* Running udp test"
553 TestMonkeyPatchUDP(testAddress
)
555 print "result: SUCCESS"