From 92f618a5d2de25b389c31730dec77d445e4a6a68 Mon Sep 17 00:00:00 2001 From: Benedikt Sauer Date: Sun, 31 Aug 2008 20:05:35 +0200 Subject: [PATCH] Implemented locking so the threading version works too. Still todo: - Control-C doesn't work. --- server/blinken_handler.py | 4 +- server/my_threading.py | 4 + server/server.py | 21 +- server/stackless_socket.py | 1110 ++++++++++++++++++++++---------------------- 4 files changed, 574 insertions(+), 565 deletions(-) diff --git a/server/blinken_handler.py b/server/blinken_handler.py index 7d1b307..a313e95 100644 --- a/server/blinken_handler.py +++ b/server/blinken_handler.py @@ -1,3 +1,4 @@ +from __future__ import with_statement from SocketServer import BaseRequestHandler __version__ = "i" @@ -15,7 +16,8 @@ class BlinkenHandler(BaseRequestHandler): data = recv(4096) coords = tuple(int(i) for i in data.split(':')) if len(coords) == 2: - self.server.connections[coords] = self.request + with self.server.lock: + self.server.connections[coords] = self.request send(M_OKAY) else: send(M_NOPE) diff --git a/server/my_threading.py b/server/my_threading.py index dd6180f..5021b6e 100644 --- a/server/my_threading.py +++ b/server/my_threading.py @@ -16,6 +16,9 @@ try: sleepingTasklets = [] + class Lock(object): + __enter__ = lambda *args: None + __exit__ = lambda *args: None def sleep(seconds): channel = stackless.channel() @@ -57,6 +60,7 @@ try: except ImportError: from threading import Thread as PyThread + from threading import Lock from SocketServer import ThreadingMixIn class Thread(PyThread): diff --git a/server/server.py b/server/server.py index 4023b2e..542803c 100644 --- a/server/server.py +++ b/server/server.py @@ -1,6 +1,7 @@ +from __future__ import with_statement from projector import DummyProjector from blinken_handler import BlinkenHandler -from my_threading import Thread, schedule, ThreadingMixIn +from my_threading import Thread, schedule, ThreadingMixIn, Lock from SocketServer import StreamRequestHandler, TCPServer @@ -37,6 +38,7 @@ class BlinkenServer(object): self.set_projector(DummyProjector()) self._blinken = Server(('', self.blinken_port), BlinkenHandler) self._blinken.connections = {} + self._blinken.lock = Lock() self._blinken.close_request = lambda request : None self._control = Server(('', self.control_port), ControlHandler) @@ -56,14 +58,15 @@ class BlinkenServer(object): while self.is_running: for frame in self.current: to_clean = [] - for pixel, socket in self._blinken.connections.iteritems(): - try: - socket.send(("on" if pixel in frame else "off") + "\r\n") - except: - to_clean.append(pixel) - - for i in to_clean: - del self._blinken.connections[i] + with self._blinken.lock: + for pixel, socket in self._blinken.connections.iteritems(): + try: + socket.send(("on" if pixel in frame else "off") + "\r\n") + except: + to_clean.append(pixel) + + for i in to_clean: + del self._blinken.connections[i] if self.next_ready: if self.current.priority < self.next.priority or self.force_next: diff --git a/server/stackless_socket.py b/server/stackless_socket.py index 808d1ac..7724858 100644 --- a/server/stackless_socket.py +++ b/server/stackless_socket.py @@ -1,555 +1,555 @@ -# -# Stackless compatible socket module: -# -# Author: Richard Tew -# -# This code was written to serve as an example of Stackless Python usage. -# Feel free to email me with any questions, comments, or suggestions for -# improvement. -# -# This wraps the asyncore module and the dispatcher class it provides in order -# write a socket module replacement that uses channels to allow calls to it to -# block until a delayed event occurs. -# -# Not all aspects of the socket module are provided by this file. Examples of -# it in use can be seen at the bottom of this file. -# -# NOTE: Versions of the asyncore module from Python 2.4 or later include bug -# fixes and earlier versions will not guarantee correct behaviour. -# Specifically, it monitors for errors on sockets where the version in -# Python 2.3.3 does not. -# - -# Possible improvements: -# - More correct error handling. When there is an error on a socket found by -# poll, there is no idea what it actually is. -# - Launching each bit of incoming data in its own tasklet on the recvChannel -# send is a little over the top. It should be possible to add it to the -# rest of the queued data - -import stackless -import asyncore, weakref -import socket as stdsocket # We need the "socket" name for the function we export. - -# If we are to masquerade as the socket module, we need to provide the constants. -if "__all__" in stdsocket.__dict__: - __all__ = stdsocket.__dict__ - for k, v in stdsocket.__dict__.iteritems(): - if k in __all__: - globals()[k] = v - elif k == "EBADF": - globals()[k] = v -else: - for k, v in stdsocket.__dict__.iteritems(): - if k.upper() == k: - globals()[k] = v - error = stdsocket.error - timeout = stdsocket.timeout - # WARNING: this function blocks and is not thread safe. - # The only solution is to spawn a thread to handle all - # getaddrinfo requests. Implementing a stackless DNS - # lookup service is only second best as getaddrinfo may - # use other methods. - getaddrinfo = stdsocket.getaddrinfo - -# urllib2 apparently uses this directly. We need to cater for that. -_fileobject = stdsocket._fileobject - -# Someone needs to invoke asyncore.poll() regularly to keep the socket -# data moving. The "ManageSockets" function here is a simple example -# of such a function. It is started by StartManager(), which uses the -# global "managerRunning" to ensure that no more than one copy is -# running. -# -# If you think you can do this better, register an alternative to -# StartManager using stacklesssocket_manager(). Your function will be -# called every time a new socket is created; it's your responsibility -# to ensure it doesn't start multiple copies of itself unnecessarily. -# - -managerRunning = False - -def ManageSockets(): - global managerRunning - - while len(asyncore.socket_map): - # Check the sockets for activity. - asyncore.poll(0.05) - # Yield to give other tasklets a chance to be scheduled. - stackless.schedule() - - managerRunning = False - -def StartManager(): - global managerRunning - if not managerRunning: - managerRunning = True - stackless.tasklet(ManageSockets)() - -_manage_sockets_func = StartManager - -def stacklesssocket_manager(mgr): - global _manage_sockets_func - _manage_sockets_func = mgr - -def socket(*args, **kwargs): - import sys - if "socket" in sys.modules and sys.modules["socket"] is not stdsocket: - raise RuntimeError("Use 'stacklesssocket.install' instead of replacing the 'socket' module") - -_realsocket_old = stdsocket._realsocket -_socketobject_old = stdsocket._socketobject - -class _socketobject_new(_socketobject_old): - def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None): - # We need to do this here. - if _sock is None: - _sock = _realsocket_old(family, type, proto) - _sock = _fakesocket(_sock) - _manage_sockets_func() - _socketobject_old.__init__(self, family, type, proto, _sock) - if not isinstance(self._sock, _fakesocket): - raise RuntimeError("bad socket") - - def accept(self): - sock, addr = self._sock.accept() - sock = _fakesocket(sock) - sock.wasConnected = True - return _socketobject_new(_sock=sock), addr - - accept.__doc__ = _socketobject_old.accept.__doc__ - - -def check_still_connected(f): - " Decorate socket functions to check they are still connected. " - def new_f(self, *args, **kwds): - if not self.connected: - # The socket was never connected. - if not self.wasConnected: - raise error(10057, "Socket is not connected") - # The socket has been closed already. - raise error(EBADF, 'Bad file descriptor') - return f(self, *args, **kwds) - return new_f - - -def install(): - if stdsocket._realsocket is socket: - raise StandardError("Still installed") - stdsocket._realsocket = socket - stdsocket.socket = stdsocket.SocketType = stdsocket._socketobject = _socketobject_new - -def uninstall(): - stdsocket._realsocket = _realsocket_old - stdsocket.socket = stdsocket.SocketType = stdsocket._socketobject = _socketobject_old - - -class _fakesocket(asyncore.dispatcher): - connectChannel = None - acceptChannel = None - recvChannel = None - wasConnected = False - - def __init__(self, realSocket): - # This is worth doing. I was passing in an invalid socket which - # was an instance of _fakesocket and it was causing tasklet death. - if not isinstance(realSocket, _realsocket_old): - raise StandardError("An invalid socket passed to fakesocket %s" % realSocket.__class__) - - # This will register the real socket in the internal socket map. - asyncore.dispatcher.__init__(self, realSocket) - self.socket = realSocket - - self.recvChannel = stackless.channel() - self.readString = '' - self.readIdx = 0 - - self.sendBuffer = '' - self.sendToBuffers = [] - - def __del__(self): - # There are no more users (sockets or files) of this fake socket, we - # are safe to close it fully. If we don't, asyncore will choke on - # the weakref failures. - self.close() - - # The asyncore version of this function depends on socket being set - # which is not the case when this fake socket has been closed. - def __getattr__(self, attr): - if not hasattr(self, "socket"): - raise AttributeError("socket attribute unset on '"+ attr +"' lookup") - return getattr(self.socket, attr) - - def add_channel(self, map=None): - if map is None: - map = self._map - map[self._fileno] = weakref.proxy(self) - - def writable(self): - if self.socket.type != SOCK_DGRAM and not self.connected: - return True - return len(self.sendBuffer) or len(self.sendToBuffers) - - def accept(self): - if not self.acceptChannel: - self.acceptChannel = stackless.channel() - return self.acceptChannel.receive() - - def connect(self, address): - asyncore.dispatcher.connect(self, address) - # UDP sockets do not connect. - if self.socket.type != SOCK_DGRAM and not self.connected: - if not self.connectChannel: - self.connectChannel = stackless.channel() - # Prefer the sender. Do not block when sending, given that - # there is a tasklet known to be waiting, this will happen. - self.connectChannel.preference = 1 - self.connectChannel.receive() - - @check_still_connected - def send(self, data, flags=0): - self.sendBuffer += data - # stackless.schedule() - return len(data) - - @check_still_connected - def sendall(self, data, flags=0): - # WARNING: this will busy wait until all data is sent - # It should be possible to do away with the busy wait with - # the use of a channel. - self.sendBuffer += data - while self.sendBuffer: - stackless.schedule() - return len(data) - - def sendto(self, sendData, sendArg1=None, sendArg2=None): - # sendto(data, address) - # sendto(data [, flags], address) - if sendArg2 is not None: - flags = sendArg1 - sendAddress = sendArg2 - else: - flags = 0 - sendAddress = sendArg1 - - waitChannel = None - for idx, (data, address, channel, sentBytes) in enumerate(self.sendToBuffers): - if address == sendAddress: - self.sendToBuffers[idx] = (data + sendData, address, channel, sentBytes) - waitChannel = channel - break - if waitChannel is None: - waitChannel = stackless.channel() - self.sendToBuffers.append((sendData, sendAddress, waitChannel, 0)) - return waitChannel.receive() - - # Read at most byteCount bytes. - def recv(self, byteCount, flags=0): - # recv() must not concatenate two or more data fragments sent with - # send() on the remote side. Single fragment sent with single send() - # call should be split into strings of length less than or equal - # to 'byteCount', and returned by one or more recv() calls. - - remainingBytes = self.readIdx != len(self.readString) - # TODO: Verify this connectivity behaviour. - - if not self.connected: - # Sockets which have never been connected do this. - if not self.wasConnected: - raise error(10057, 'Socket is not connected') - - # Sockets which were connected, but no longer are, use - # up the remaining input. Observed this with urllib.urlopen - # where it closes the socket and then allows the caller to - # use a file to access the body of the web page. - elif not remainingBytes: - self.readString = self.recvChannel.receive() - self.readIdx = 0 - remainingBytes = len(self.readString) - - if byteCount == 1 and remainingBytes: - ret = self.readString[self.readIdx] - self.readIdx += 1 - elif self.readIdx == 0 and byteCount >= len(self.readString): - ret = self.readString - self.readString = "" - else: - idx = self.readIdx + byteCount - ret = self.readString[self.readIdx:idx] - self.readString = self.readString[idx:] - self.readIdx = 0 - - # ret will be '' when EOF. - return ret - - def recvfrom(self, byteCount, flags=0): - if self.socket.type == SOCK_STREAM: - return self.recv(byteCount), None - - # recvfrom() must not concatenate two or more packets. - # Each call should return the first 'byteCount' part of the packet. - data, address = self.recvChannel.receive() - return data[:byteCount], address - - def close(self): - asyncore.dispatcher.close(self) - - self.connected = False - self.accepting = False - self.sendBuffer = None # breaks the loop in sendall - - # Clear out all the channels with relevant errors. - while self.acceptChannel and self.acceptChannel.balance < 0: - self.acceptChannel.send_exception(error, 9, 'Bad file descriptor') - while self.connectChannel and self.connectChannel.balance < 0: - self.connectChannel.send_exception(error, 10061, 'Connection refused') - while self.recvChannel and self.recvChannel.balance < 0: - # The closing of a socket is indicted by receiving nothing. The - # exception would have been sent if the server was killed, rather - # than closed down gracefully. - self.recvChannel.send("") - #self.recvChannel.send_exception(error, 10054, 'Connection reset by peer') - - # asyncore doesn't support this. Why not? - def fileno(self): - return self.socket.fileno() - - def handle_accept(self): - if self.acceptChannel and self.acceptChannel.balance < 0: - t = asyncore.dispatcher.accept(self) - if t is None: - return - t[0].setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) - stackless.tasklet(self.acceptChannel.send)(t) - - # Inform the blocked connect call that the connection has been made. - def handle_connect(self): - if self.socket.type != SOCK_DGRAM: - self.wasConnected = True - self.connectChannel.send(None) - - # Asyncore says its done but self.readBuffer may be non-empty - # so can't close yet. Do nothing and let 'recv' trigger the close. - def handle_close(self): - pass - - # Some error, just close the channel and let that raise errors to - # blocked calls. - def handle_expt(self): - self.close() - - def handle_read(self): - try: - if self.socket.type == SOCK_DGRAM: - ret = self.socket.recvfrom(20000) - else: - ret = asyncore.dispatcher.recv(self, 20000) - # Not sure this is correct, but it seems to give the - # right behaviour. Namely removing the socket from - # asyncore. - if not ret: - self.close() - stackless.tasklet(self.recvChannel.send)(ret) - except stdsocket.error, err: - # If there's a read error assume the connection is - # broken and drop any pending output - if self.sendBuffer: - self.sendBuffer = "" - self.recvChannel.send_exception(stdsocket.error, err) - - def handle_write(self): - if len(self.sendBuffer): - sentBytes = asyncore.dispatcher.send(self, self.sendBuffer[:512]) - self.sendBuffer = self.sendBuffer[sentBytes:] - elif len(self.sendToBuffers): - data, address, channel, oldSentBytes = self.sendToBuffers[0] - sentBytes = self.socket.sendto(data, address) - totalSentBytes = oldSentBytes + sentBytes - if len(data) > sentBytes: - self.sendToBuffers[0] = data[sentBytes:], address, channel, totalSentBytes - else: - del self.sendToBuffers[0] - stackless.tasklet(channel.send)(totalSentBytes) - - -if __name__ == '__main__': - import sys - import struct - # Test code goes here. - testAddress = "127.0.0.1", 3000 - info = -12345678 - data = struct.pack("i", info) - dataLength = len(data) - - def TestTCPServer(address): - global info, data, dataLength - - print "server listen socket creation" - listenSocket = stdsocket.socket(AF_INET, SOCK_STREAM) - listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) - listenSocket.bind(address) - listenSocket.listen(5) - - NUM_TESTS = 2 - - i = 1 - while i < NUM_TESTS + 1: - # No need to schedule this tasklet as the accept should yield most - # of the time on the underlying channel. - print "server connection wait", i - currentSocket, clientAddress = listenSocket.accept() - print "server", i, "listen socket", currentSocket.fileno(), "from", clientAddress - - if i == 1: - print "server closing (a)", i, "fd", currentSocket.fileno(), "id", id(currentSocket) - currentSocket.close() - print "server closed (a)", i - elif i == 2: - print "server test", i, "send" - currentSocket.send(data) - print "server test", i, "recv" - if currentSocket.recv(4) != "": - print "server recv(1)", i, "FAIL" - break - # multiple empty recvs are fine - if currentSocket.recv(4) != "": - print "server recv(2)", i, "FAIL" - break - else: - print "server closing (b)", i, "fd", currentSocket.fileno(), "id", id(currentSocket) - currentSocket.close() - - print "server test", i, "OK" - i += 1 - - if i != NUM_TESTS+1: - print "server: FAIL", i - else: - print "server: OK", i - - print "Done server" - - def TestTCPClient(address): - global info, data, dataLength - - # Attempt 1: - clientSocket = stdsocket.socket() - clientSocket.connect(address) - print "client connection (1) fd", clientSocket.fileno(), "id", id(clientSocket._sock), "waiting to recv" - if clientSocket.recv(5) != "": - print "client test", 1, "FAIL" - else: - print "client test", 1, "OK" - - # Attempt 2: - clientSocket = stdsocket.socket() - clientSocket.connect(address) - print "client connection (2) fd", clientSocket.fileno(), "id", id(clientSocket._sock), "waiting to recv" - s = clientSocket.recv(dataLength) - if s == "": - print "client test", 2, "FAIL (disconnect)" - else: - t = struct.unpack("i", s) - if t[0] == info: - print "client test", 2, "OK" - else: - print "client test", 2, "FAIL (wrong data)" - - print "client exit" - - def TestMonkeyPatchUrllib(uri): - # replace the system socket with this module - #oldSocket = sys.modules["socket"] - #sys.modules["socket"] = __import__(__name__) - install() - try: - import urllib # must occur after monkey-patching! - f = urllib.urlopen(uri) - if not isinstance(f.fp._sock, _fakesocket): - raise AssertionError("failed to apply monkeypatch, got %s" % f.fp._sock.__class__) - s = f.read() - if len(s) != 0: - print "Fetched", len(s), "bytes via replaced urllib" - else: - raise AssertionError("no text received?") - finally: - #sys.modules["socket"] = oldSocket - uninstall() - - def TestMonkeyPatchUDP(address): - # replace the system socket with this module - #oldSocket = sys.modules["socket"] - #sys.modules["socket"] = __import__(__name__) - install() - try: - def UDPServer(address): - listenSocket = stdsocket.socket(AF_INET, SOCK_DGRAM) - listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) - listenSocket.bind(address) - - # Apparently each call to recvfrom maps to an incoming - # packet and if we only ask for part of that packet, the - # rest is lost. We really need a proper unittest suite - # which tests this module against the normal socket - # module. - print "waiting to receive" - data, address = listenSocket.recvfrom(256) - print "received", data, len(data) - if len(data) != 256: - raise StandardError("Unexpected UDP packet size") - - def UDPClient(address): - clientSocket = stdsocket.socket(AF_INET, SOCK_DGRAM) - # clientSocket.connect(address) - print "sending 512 byte packet" - sentBytes = clientSocket.sendto("-"+ ("*" * 510) +"-", address) - print "sent 512 byte packet", sentBytes - - stackless.tasklet(UDPServer)(address) - stackless.tasklet(UDPClient)(address) - stackless.run() - finally: - #sys.modules["socket"] = oldSocket - uninstall() - - if len(sys.argv) == 2: - if sys.argv[1] == "client": - print "client started" - TestTCPClient(testAddress) - print "client exited" - elif sys.argv[1] == "slpclient": - print "client started" - stackless.tasklet(TestTCPClient)(testAddress) - stackless.run() - print "client exited" - elif sys.argv[1] == "server": - print "server started" - TestTCPServer(testAddress) - print "server exited" - elif sys.argv[1] == "slpserver": - print "server started" - stackless.tasklet(TestTCPServer)(testAddress) - stackless.run() - print "server exited" - else: - print "Usage:", sys.argv[0], "[client|server|slpclient|slpserver]" - - sys.exit(1) - else: - print "* Running client/server test" - install() - try: - stackless.tasklet(TestTCPServer)(testAddress) - stackless.tasklet(TestTCPClient)(testAddress) - stackless.run() - finally: - uninstall() - - print "* Running urllib test" - stackless.tasklet(TestMonkeyPatchUrllib)("http://python.org/") - stackless.run() - - print "* Running udp test" - TestMonkeyPatchUDP(testAddress) - - print "result: SUCCESS" +# +# Stackless compatible socket module: +# +# Author: Richard Tew +# +# This code was written to serve as an example of Stackless Python usage. +# Feel free to email me with any questions, comments, or suggestions for +# improvement. +# +# This wraps the asyncore module and the dispatcher class it provides in order +# write a socket module replacement that uses channels to allow calls to it to +# block until a delayed event occurs. +# +# Not all aspects of the socket module are provided by this file. Examples of +# it in use can be seen at the bottom of this file. +# +# NOTE: Versions of the asyncore module from Python 2.4 or later include bug +# fixes and earlier versions will not guarantee correct behaviour. +# Specifically, it monitors for errors on sockets where the version in +# Python 2.3.3 does not. +# + +# Possible improvements: +# - More correct error handling. When there is an error on a socket found by +# poll, there is no idea what it actually is. +# - Launching each bit of incoming data in its own tasklet on the recvChannel +# send is a little over the top. It should be possible to add it to the +# rest of the queued data + +import stackless +import asyncore, weakref +import socket as stdsocket # We need the "socket" name for the function we export. + +# If we are to masquerade as the socket module, we need to provide the constants. +if "__all__" in stdsocket.__dict__: + __all__ = stdsocket.__dict__ + for k, v in stdsocket.__dict__.iteritems(): + if k in __all__: + globals()[k] = v + elif k == "EBADF": + globals()[k] = v +else: + for k, v in stdsocket.__dict__.iteritems(): + if k.upper() == k: + globals()[k] = v + error = stdsocket.error + timeout = stdsocket.timeout + # WARNING: this function blocks and is not thread safe. + # The only solution is to spawn a thread to handle all + # getaddrinfo requests. Implementing a stackless DNS + # lookup service is only second best as getaddrinfo may + # use other methods. + getaddrinfo = stdsocket.getaddrinfo + +# urllib2 apparently uses this directly. We need to cater for that. +_fileobject = stdsocket._fileobject + +# Someone needs to invoke asyncore.poll() regularly to keep the socket +# data moving. The "ManageSockets" function here is a simple example +# of such a function. It is started by StartManager(), which uses the +# global "managerRunning" to ensure that no more than one copy is +# running. +# +# If you think you can do this better, register an alternative to +# StartManager using stacklesssocket_manager(). Your function will be +# called every time a new socket is created; it's your responsibility +# to ensure it doesn't start multiple copies of itself unnecessarily. +# + +managerRunning = False + +def ManageSockets(): + global managerRunning + + while len(asyncore.socket_map): + # Check the sockets for activity. + asyncore.poll(0.05) + # Yield to give other tasklets a chance to be scheduled. + stackless.schedule() + + managerRunning = False + +def StartManager(): + global managerRunning + if not managerRunning: + managerRunning = True + stackless.tasklet(ManageSockets)() + +_manage_sockets_func = StartManager + +def stacklesssocket_manager(mgr): + global _manage_sockets_func + _manage_sockets_func = mgr + +def socket(*args, **kwargs): + import sys + if "socket" in sys.modules and sys.modules["socket"] is not stdsocket: + raise RuntimeError("Use 'stacklesssocket.install' instead of replacing the 'socket' module") + +_realsocket_old = stdsocket._realsocket +_socketobject_old = stdsocket._socketobject + +class _socketobject_new(_socketobject_old): + def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None): + # We need to do this here. + if _sock is None: + _sock = _realsocket_old(family, type, proto) + _sock = _fakesocket(_sock) + _manage_sockets_func() + _socketobject_old.__init__(self, family, type, proto, _sock) + if not isinstance(self._sock, _fakesocket): + raise RuntimeError("bad socket") + + def accept(self): + sock, addr = self._sock.accept() + sock = _fakesocket(sock) + sock.wasConnected = True + return _socketobject_new(_sock=sock), addr + + accept.__doc__ = _socketobject_old.accept.__doc__ + + +def check_still_connected(f): + " Decorate socket functions to check they are still connected. " + def new_f(self, *args, **kwds): + if not self.connected: + # The socket was never connected. + if not self.wasConnected: + raise error(10057, "Socket is not connected") + # The socket has been closed already. + raise error(EBADF, 'Bad file descriptor') + return f(self, *args, **kwds) + return new_f + + +def install(): + if stdsocket._realsocket is socket: + raise StandardError("Still installed") + stdsocket._realsocket = socket + stdsocket.socket = stdsocket.SocketType = stdsocket._socketobject = _socketobject_new + +def uninstall(): + stdsocket._realsocket = _realsocket_old + stdsocket.socket = stdsocket.SocketType = stdsocket._socketobject = _socketobject_old + + +class _fakesocket(asyncore.dispatcher): + connectChannel = None + acceptChannel = None + recvChannel = None + wasConnected = False + + def __init__(self, realSocket): + # This is worth doing. I was passing in an invalid socket which + # was an instance of _fakesocket and it was causing tasklet death. + if not isinstance(realSocket, _realsocket_old): + raise StandardError("An invalid socket passed to fakesocket %s" % realSocket.__class__) + + # This will register the real socket in the internal socket map. + asyncore.dispatcher.__init__(self, realSocket) + self.socket = realSocket + + self.recvChannel = stackless.channel() + self.readString = '' + self.readIdx = 0 + + self.sendBuffer = '' + self.sendToBuffers = [] + + def __del__(self): + # There are no more users (sockets or files) of this fake socket, we + # are safe to close it fully. If we don't, asyncore will choke on + # the weakref failures. + self.close() + + # The asyncore version of this function depends on socket being set + # which is not the case when this fake socket has been closed. + def __getattr__(self, attr): + if not hasattr(self, "socket"): + raise AttributeError("socket attribute unset on '"+ attr +"' lookup") + return getattr(self.socket, attr) + + def add_channel(self, map=None): + if map is None: + map = self._map + map[self._fileno] = weakref.proxy(self) + + def writable(self): + if self.socket.type != SOCK_DGRAM and not self.connected: + return True + return len(self.sendBuffer) or len(self.sendToBuffers) + + def accept(self): + if not self.acceptChannel: + self.acceptChannel = stackless.channel() + return self.acceptChannel.receive() + + def connect(self, address): + asyncore.dispatcher.connect(self, address) + # UDP sockets do not connect. + if self.socket.type != SOCK_DGRAM and not self.connected: + if not self.connectChannel: + self.connectChannel = stackless.channel() + # Prefer the sender. Do not block when sending, given that + # there is a tasklet known to be waiting, this will happen. + self.connectChannel.preference = 1 + self.connectChannel.receive() + + @check_still_connected + def send(self, data, flags=0): + self.sendBuffer += data + # stackless.schedule() + return len(data) + + @check_still_connected + def sendall(self, data, flags=0): + # WARNING: this will busy wait until all data is sent + # It should be possible to do away with the busy wait with + # the use of a channel. + self.sendBuffer += data + while self.sendBuffer: + stackless.schedule() + return len(data) + + def sendto(self, sendData, sendArg1=None, sendArg2=None): + # sendto(data, address) + # sendto(data [, flags], address) + if sendArg2 is not None: + flags = sendArg1 + sendAddress = sendArg2 + else: + flags = 0 + sendAddress = sendArg1 + + waitChannel = None + for idx, (data, address, channel, sentBytes) in enumerate(self.sendToBuffers): + if address == sendAddress: + self.sendToBuffers[idx] = (data + sendData, address, channel, sentBytes) + waitChannel = channel + break + if waitChannel is None: + waitChannel = stackless.channel() + self.sendToBuffers.append((sendData, sendAddress, waitChannel, 0)) + return waitChannel.receive() + + # Read at most byteCount bytes. + def recv(self, byteCount, flags=0): + # recv() must not concatenate two or more data fragments sent with + # send() on the remote side. Single fragment sent with single send() + # call should be split into strings of length less than or equal + # to 'byteCount', and returned by one or more recv() calls. + + remainingBytes = self.readIdx != len(self.readString) + # TODO: Verify this connectivity behaviour. + + if not self.connected: + # Sockets which have never been connected do this. + if not self.wasConnected: + raise error(10057, 'Socket is not connected') + + # Sockets which were connected, but no longer are, use + # up the remaining input. Observed this with urllib.urlopen + # where it closes the socket and then allows the caller to + # use a file to access the body of the web page. + elif not remainingBytes: + self.readString = self.recvChannel.receive() + self.readIdx = 0 + remainingBytes = len(self.readString) + + if byteCount == 1 and remainingBytes: + ret = self.readString[self.readIdx] + self.readIdx += 1 + elif self.readIdx == 0 and byteCount >= len(self.readString): + ret = self.readString + self.readString = "" + else: + idx = self.readIdx + byteCount + ret = self.readString[self.readIdx:idx] + self.readString = self.readString[idx:] + self.readIdx = 0 + + # ret will be '' when EOF. + return ret + + def recvfrom(self, byteCount, flags=0): + if self.socket.type == SOCK_STREAM: + return self.recv(byteCount), None + + # recvfrom() must not concatenate two or more packets. + # Each call should return the first 'byteCount' part of the packet. + data, address = self.recvChannel.receive() + return data[:byteCount], address + + def close(self): + asyncore.dispatcher.close(self) + + self.connected = False + self.accepting = False + self.sendBuffer = None # breaks the loop in sendall + + # Clear out all the channels with relevant errors. + while self.acceptChannel and self.acceptChannel.balance < 0: + self.acceptChannel.send_exception(error, 9, 'Bad file descriptor') + while self.connectChannel and self.connectChannel.balance < 0: + self.connectChannel.send_exception(error, 10061, 'Connection refused') + while self.recvChannel and self.recvChannel.balance < 0: + # The closing of a socket is indicted by receiving nothing. The + # exception would have been sent if the server was killed, rather + # than closed down gracefully. + self.recvChannel.send("") + #self.recvChannel.send_exception(error, 10054, 'Connection reset by peer') + + # asyncore doesn't support this. Why not? + def fileno(self): + return self.socket.fileno() + + def handle_accept(self): + if self.acceptChannel and self.acceptChannel.balance < 0: + t = asyncore.dispatcher.accept(self) + if t is None: + return + t[0].setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) + stackless.tasklet(self.acceptChannel.send)(t) + + # Inform the blocked connect call that the connection has been made. + def handle_connect(self): + if self.socket.type != SOCK_DGRAM: + self.wasConnected = True + self.connectChannel.send(None) + + # Asyncore says its done but self.readBuffer may be non-empty + # so can't close yet. Do nothing and let 'recv' trigger the close. + def handle_close(self): + pass + + # Some error, just close the channel and let that raise errors to + # blocked calls. + def handle_expt(self): + self.close() + + def handle_read(self): + try: + if self.socket.type == SOCK_DGRAM: + ret = self.socket.recvfrom(20000) + else: + ret = asyncore.dispatcher.recv(self, 20000) + # Not sure this is correct, but it seems to give the + # right behaviour. Namely removing the socket from + # asyncore. + if not ret: + self.close() + stackless.tasklet(self.recvChannel.send)(ret) + except stdsocket.error, err: + # If there's a read error assume the connection is + # broken and drop any pending output + if self.sendBuffer: + self.sendBuffer = "" + self.recvChannel.send_exception(stdsocket.error, err) + + def handle_write(self): + if len(self.sendBuffer): + sentBytes = asyncore.dispatcher.send(self, self.sendBuffer[:512]) + self.sendBuffer = self.sendBuffer[sentBytes:] + elif len(self.sendToBuffers): + data, address, channel, oldSentBytes = self.sendToBuffers[0] + sentBytes = self.socket.sendto(data, address) + totalSentBytes = oldSentBytes + sentBytes + if len(data) > sentBytes: + self.sendToBuffers[0] = data[sentBytes:], address, channel, totalSentBytes + else: + del self.sendToBuffers[0] + stackless.tasklet(channel.send)(totalSentBytes) + + +if __name__ == '__main__': + import sys + import struct + # Test code goes here. + testAddress = "127.0.0.1", 3000 + info = -12345678 + data = struct.pack("i", info) + dataLength = len(data) + + def TestTCPServer(address): + global info, data, dataLength + + print "server listen socket creation" + listenSocket = stdsocket.socket(AF_INET, SOCK_STREAM) + listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) + listenSocket.bind(address) + listenSocket.listen(5) + + NUM_TESTS = 2 + + i = 1 + while i < NUM_TESTS + 1: + # No need to schedule this tasklet as the accept should yield most + # of the time on the underlying channel. + print "server connection wait", i + currentSocket, clientAddress = listenSocket.accept() + print "server", i, "listen socket", currentSocket.fileno(), "from", clientAddress + + if i == 1: + print "server closing (a)", i, "fd", currentSocket.fileno(), "id", id(currentSocket) + currentSocket.close() + print "server closed (a)", i + elif i == 2: + print "server test", i, "send" + currentSocket.send(data) + print "server test", i, "recv" + if currentSocket.recv(4) != "": + print "server recv(1)", i, "FAIL" + break + # multiple empty recvs are fine + if currentSocket.recv(4) != "": + print "server recv(2)", i, "FAIL" + break + else: + print "server closing (b)", i, "fd", currentSocket.fileno(), "id", id(currentSocket) + currentSocket.close() + + print "server test", i, "OK" + i += 1 + + if i != NUM_TESTS+1: + print "server: FAIL", i + else: + print "server: OK", i + + print "Done server" + + def TestTCPClient(address): + global info, data, dataLength + + # Attempt 1: + clientSocket = stdsocket.socket() + clientSocket.connect(address) + print "client connection (1) fd", clientSocket.fileno(), "id", id(clientSocket._sock), "waiting to recv" + if clientSocket.recv(5) != "": + print "client test", 1, "FAIL" + else: + print "client test", 1, "OK" + + # Attempt 2: + clientSocket = stdsocket.socket() + clientSocket.connect(address) + print "client connection (2) fd", clientSocket.fileno(), "id", id(clientSocket._sock), "waiting to recv" + s = clientSocket.recv(dataLength) + if s == "": + print "client test", 2, "FAIL (disconnect)" + else: + t = struct.unpack("i", s) + if t[0] == info: + print "client test", 2, "OK" + else: + print "client test", 2, "FAIL (wrong data)" + + print "client exit" + + def TestMonkeyPatchUrllib(uri): + # replace the system socket with this module + #oldSocket = sys.modules["socket"] + #sys.modules["socket"] = __import__(__name__) + install() + try: + import urllib # must occur after monkey-patching! + f = urllib.urlopen(uri) + if not isinstance(f.fp._sock, _fakesocket): + raise AssertionError("failed to apply monkeypatch, got %s" % f.fp._sock.__class__) + s = f.read() + if len(s) != 0: + print "Fetched", len(s), "bytes via replaced urllib" + else: + raise AssertionError("no text received?") + finally: + #sys.modules["socket"] = oldSocket + uninstall() + + def TestMonkeyPatchUDP(address): + # replace the system socket with this module + #oldSocket = sys.modules["socket"] + #sys.modules["socket"] = __import__(__name__) + install() + try: + def UDPServer(address): + listenSocket = stdsocket.socket(AF_INET, SOCK_DGRAM) + listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) + listenSocket.bind(address) + + # Apparently each call to recvfrom maps to an incoming + # packet and if we only ask for part of that packet, the + # rest is lost. We really need a proper unittest suite + # which tests this module against the normal socket + # module. + print "waiting to receive" + data, address = listenSocket.recvfrom(256) + print "received", data, len(data) + if len(data) != 256: + raise StandardError("Unexpected UDP packet size") + + def UDPClient(address): + clientSocket = stdsocket.socket(AF_INET, SOCK_DGRAM) + # clientSocket.connect(address) + print "sending 512 byte packet" + sentBytes = clientSocket.sendto("-"+ ("*" * 510) +"-", address) + print "sent 512 byte packet", sentBytes + + stackless.tasklet(UDPServer)(address) + stackless.tasklet(UDPClient)(address) + stackless.run() + finally: + #sys.modules["socket"] = oldSocket + uninstall() + + if len(sys.argv) == 2: + if sys.argv[1] == "client": + print "client started" + TestTCPClient(testAddress) + print "client exited" + elif sys.argv[1] == "slpclient": + print "client started" + stackless.tasklet(TestTCPClient)(testAddress) + stackless.run() + print "client exited" + elif sys.argv[1] == "server": + print "server started" + TestTCPServer(testAddress) + print "server exited" + elif sys.argv[1] == "slpserver": + print "server started" + stackless.tasklet(TestTCPServer)(testAddress) + stackless.run() + print "server exited" + else: + print "Usage:", sys.argv[0], "[client|server|slpclient|slpserver]" + + sys.exit(1) + else: + print "* Running client/server test" + install() + try: + stackless.tasklet(TestTCPServer)(testAddress) + stackless.tasklet(TestTCPClient)(testAddress) + stackless.run() + finally: + uninstall() + + print "* Running urllib test" + stackless.tasklet(TestMonkeyPatchUrllib)("http://python.org/") + stackless.run() + + print "* Running udp test" + TestMonkeyPatchUDP(testAddress) + + print "result: SUCCESS" -- 2.11.4.GIT