From 683533c580f7700867912248f22199732188387e Mon Sep 17 00:00:00 2001 From: "(no author)" <(no author)@41a61cd8-c433-0410-bb1c-e256eeef9e11> Date: Fri, 14 Dec 2007 23:47:01 +0000 Subject: [PATCH] r1407@opsdev009 (orig r74442): dreiss | 2007-12-14 15:46:47 -0800 Thrift: Python support for Unix-domain sockets, and eager timeout setting. Reviewed By: mcslee Test Plan: Ran the test script. Revert Plan: ok Other Notes: Contributed by Ben Maurer. git-svn-id: http://svn.facebook.com/svnroot/thrift/trunk@722 41a61cd8-c433-0410-bb1c-e256eeef9e11 --- lib/py/src/transport/TSocket.py | 32 ++++++++++++++++---- test/py/TestSocket.py | 65 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 6 deletions(-) create mode 100755 test/py/TestSocket.py diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py index 0b44344..146820d 100644 --- a/lib/py/src/transport/TSocket.py +++ b/lib/py/src/transport/TSocket.py @@ -13,11 +13,21 @@ class TSocket(TTransportBase): """Socket implementation of TTransport base.""" - def __init__(self, host='localhost', port=9090): + def __init__(self, host='localhost', port=9090, unix_socket=None): + """Initialize a TSocket + + @param host(str) The host to connect to. + @param port(int) The (TCP) port to connect to. + @param unix_socket(str) The filename of a unix socket to connect to. + (host and port will be ignored.) + """ + self.host = host self.port = port self.handle = None - + self._unix_socket = unix_socket + self._timeout = None + def setHandle(self, h): self.handle = h @@ -25,16 +35,26 @@ class TSocket(TTransportBase): return self.handle != None def setTimeout(self, ms): - if (self.handle != None): - self.handle.settimeout(ms/1000.00) + if ms is None: + self._timeout = None else: - raise TTransportException(TTransportException.NOT_OPEN, 'No handle yet in TSocket') + self._timeout = ms/1000.0 + + if (self.handle != None): + self.handle.settimeout(self._timeout) + def _resolveAddr(self): + if self._unix_socket is not None: + return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] + else: + return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + def open(self): try: - res0 = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + res0 = self._resolveAddr() for res in res0: self.handle = socket.socket(res[0], res[1]) + self.handle.settimeout(self._timeout) try: self.handle.connect(res[4]) except socket.error, e: diff --git a/test/py/TestSocket.py b/test/py/TestSocket.py new file mode 100755 index 0000000..7cbdf5c --- /dev/null +++ b/test/py/TestSocket.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python + +import sys, glob +sys.path.insert(0, './gen-py') +sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) + +from ThriftTest import ThriftTest +from ThriftTest.ttypes import * +from thrift.transport import TTransport +from thrift.transport import TSocket +from thrift.protocol import TBinaryProtocol +import unittest +import time +import socket +import random +from optparse import OptionParser + +class TimeoutTest(unittest.TestCase): + def setUp(self): + for i in xrange(50): + try: + # find a port we can use + self.listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.port = random.randint(10000, 30000) + self.listen_sock.bind(('localhost', self.port)) + self.listen_sock.listen(5) + break + except: + if i == 49: + raise + + def testConnectTimeout(self): + starttime = time.time() + + try: + leaky = [] + for i in xrange(100): + socket = TSocket.TSocket('localhost', self.port) + socket.setTimeout(10) + socket.open() + leaky.append(socket) + except: + assert time.time() - starttime < 5.0 + + def testWriteTimeout(self): + starttime = time.time() + + try: + socket = TSocket.TSocket('localhost', self.port) + socket.setTimeout(10) + socket.open() + lsock = self.listen_sock.accept() + while True: + socket.write("hi" * 100) + + except: + assert time.time() - starttime < 5.0 + +suite = unittest.TestSuite() +loader = unittest.TestLoader() + +suite.addTest(loader.loadTestsFromTestCase(TimeoutTest)) + +testRunner = unittest.TextTestRunner(verbosity=2) +testRunner.run(suite) -- 2.11.4.GIT