From 0a5b069cddfbfcaf37970f3991b226a6653483fb Mon Sep 17 00:00:00 2001 From: Roy Hyunjin Han Date: Sat, 9 Feb 2013 19:12:21 -0800 Subject: [PATCH] Added test to check that child threads die when parent dies --- CHANGES.rst | 3 + README.rst | 13 +- TODO.rst | 7 +- serve_tests.py | 17 ++- setup.py | 1 - socketIO_client/__init__.py | 263 ++++++++++++++++++++---------------- socketIO_client/tests.py | 26 +++- 7 files changed, 193 insertions(+), 137 deletions(-) mode change 100755 => 100644 serve_tests.py mode change 100755 => 100644 setup.py diff --git a/CHANGES.rst b/CHANGES.rst index 8afe01b..1e5d0f7 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,6 @@ +0.4 +--- + 0.3 --- - Added support for secure connections diff --git a/README.rst b/README.rst index 5ed4179..504e0d5 100644 --- a/README.rst +++ b/README.rst @@ -35,10 +35,9 @@ Activate isolated environment. :: Emit. :: from socketIO_client import SocketIO - - socketIO = SocketIO('localhost', 8000) - socketIO.emit('aaa', {'bbb': 'ccc'}) - socketIO.wait(seconds=1) # Exit after one second + with SocketIO('localhost', 8000) as socketIO: + socketIO.emit('aaa') + socketIO.wait(1) # Wait a second Emit with callback. :: @@ -47,9 +46,9 @@ Emit with callback. :: def on_response(*args): print args - socketIO = SocketIO('localhost', 8000) - socketIO.emit('aaa', {'bbb': 'ccc'}, on_response) - socketIO.wait(forCallbacks=True) # Exit after callbacks run + with SocketIO('localhost', 8000) as socketIO: + socketIO.emit('aaa', {'bbb': 'ccc'}, on_response) + socketIO.wait(seconds=1, forCallbacks=True) # Wait for callback Define events. :: diff --git a/TODO.rst b/TODO.rst index bf51bc3..e362fcd 100644 --- a/TODO.rst +++ b/TODO.rst @@ -1,5 +1,8 @@ -Let user define a proxy #5 -Let user emit without arguments #5 ++ Fix unittests ++ Fix exceptions when websocket server disappears + +Fix thread exceptions + Integrate Zac's fork #6 Integrate Sajal's fork #7 Integrate Francis's fork #10 diff --git a/serve_tests.py b/serve_tests.py old mode 100755 new mode 100644 index 47c946a..faea52e --- a/serve_tests.py +++ b/serve_tests.py @@ -1,7 +1,14 @@ 'Launch this server in another terminal window before running tests' -from socketio import socketio_manage -from socketio.namespace import BaseNamespace -from socketio.server import SocketIOServer +import sys +try: + from socketio import socketio_manage + from socketio.namespace import BaseNamespace + from socketio.server import SocketIOServer +except ImportError: + from setuptools.command import easy_install + easy_install.main(['-U', 'gevent-socketio']) + print('\nPlease run the script again to launch the test server.') + sys.exit(1) class Namespace(BaseNamespace): @@ -25,5 +32,7 @@ class Application(object): if __name__ == '__main__': - socketIOServer = SocketIOServer(('0.0.0.0', 8000), Application()) + port = 8000 + print 'Starting server at port %s' % port + socketIOServer = SocketIOServer(('0.0.0.0', port), Application()) socketIOServer.serve_forever() diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 index 0ca72de..0f79923 --- a/setup.py +++ b/setup.py @@ -24,7 +24,6 @@ setup( url='https://github.com/invisibleroads/socketIO-client', install_requires=[ 'anyjson', - 'gevent-socketio', 'websocket-client', ], packages=find_packages(), diff --git a/socketIO_client/__init__.py b/socketIO_client/__init__.py index e32996d..493748a 100644 --- a/socketIO_client/__init__.py +++ b/socketIO_client/__init__.py @@ -1,11 +1,16 @@ -import websocket +import sys +import traceback + +import socket from anyjson import dumps, loads +from functools import partial from threading import Thread, Event from time import sleep from urllib import urlopen +from websocket import WebSocketConnectionClosedException, create_connection -__version__ = '0.3' +__version__ = '0.4' PROTOCOL = 1 # SocketIO protocol version @@ -16,7 +21,7 @@ class BaseNamespace(object): # pragma: no cover def __init__(self, socketIO): self.socketIO = socketIO - def on_connect(self, socketIO): + def on_connect(self): pass def on_disconnect(self): @@ -46,54 +51,69 @@ class BaseNamespace(object): # pragma: no cover class SocketIO(object): - messageID = 0 + _messageID = 0 - def __init__(self, host, port, Namespace=BaseNamespace, secure=False): - self.host = host - self.port = int(port) - self.namespace = Namespace(self) - self.secure = secure - self.__connect() + def __init__(self, host, port, Namespace=BaseNamespace, secure=False, proxies=None): + self._host = host + self._port = int(port) + self._namespace = Namespace(self) + self._secure = secure + self._proxies = proxies + self._connect() - heartbeatInterval = self.heartbeatTimeout - 2 - self.heartbeatThread = RhythmicThread(heartbeatInterval, - self._send_heartbeat) - self.heartbeatThread.start() + heartbeatInterval = self._heartbeatTimeout - 2 + self._heartbeatThread = RhythmicThread(heartbeatInterval, self._send_heartbeat) + self._heartbeatThread.start() - self.channelByName = {} - self.callbackByEvent = {} - self.namespaceThread = ListenerThread(self) - self.namespaceThread.start() + self._channelByName = {} + self._callbackByEvent = {} + self._namespaceThread = ListenerThread(self._recv_packet, self._get_callback) + self._namespaceThread.start() - def __del__(self): # pragma: no cover - self.heartbeatThread.cancel() - self.namespaceThread.cancel() - self.connection.close() + def __enter__(self): + return self - def __connect(self): - baseURL = '%s:%d/socket.io/%s' % (self.host, self.port, PROTOCOL) + def __exit__(self, exc_type, exc_value, traceback): + self.__del__() + + def __del__(self): + self._heartbeatThread.cancel() + self._namespaceThread.cancel() + self._connection.close() + + def _connect(self): + baseURL = '%s:%d/socket.io/%s' % (self._host, self._port, PROTOCOL) try: response = urlopen('%s://%s/' % ( - 'https' if self.secure else 'http', baseURL)) + 'https' if self._secure else 'http', baseURL), + proxies=self._proxies) except IOError: # pragma: no cover raise SocketIOError('Could not start connection') if 200 != response.getcode(): # pragma: no cover raise SocketIOError('Could not establish connection') responseParts = response.readline().split(':') - self.sessionID = responseParts[0] - self.heartbeatTimeout = int(responseParts[1]) - self.connectionTimeout = int(responseParts[2]) - self.supportedTransports = responseParts[3].split(',') - if 'websocket' not in self.supportedTransports: + self._sessionID = responseParts[0] + self._heartbeatTimeout = int(responseParts[1]) + self._connectionTimeout = int(responseParts[2]) + self._supportedTransports = responseParts[3].split(',') + if 'websocket' not in self._supportedTransports: raise SocketIOError('Could not parse handshake') # pragma: no cover socketURL = '%s://%s/websocket/%s' % ( - 'wss' if self.secure else 'ws', baseURL, self.sessionID) - self.connection = websocket.create_connection(socketURL) + 'wss' if self._secure else 'ws', baseURL, self._sessionID) + self._connection = create_connection(socketURL) def _recv_packet(self): code, packetID, channelName, data = -1, None, None, None - packet = self.connection.recv() - packetParts = packet.split(':', 3) + try: + packet = self._connection.recv() + except WebSocketConnectionClosedException: + raise SocketIOConnectionError('Lost connection (Connection closed)') + except socket.timeout: + raise SocketIOConnectionError('Lost connection (Connection timed out)') + try: + packetParts = packet.split(':', 3) + except AttributeError: + raise SocketIOPacketError('Received invalid packet (%s)' % packet) packetCount = len(packetParts) if 4 == packetCount: code, packetID, channelName, data = packetParts @@ -104,34 +124,36 @@ class SocketIO(object): return int(code), packetID, channelName, data def _send_packet(self, code, channelName='', data='', callback=None): - self.connection.send(':'.join([ - str(code), - self.set_callback(callback) if callback else '', - channelName, - data])) + callbackNumber = self._set_callback(callback) if callback else '' + packetParts = [str(code), callbackNumber, channelName, data] + try: + self._connection.send(':'.join(packetParts)) + except socket.error: + raise SocketIOPacketError('Could not send packet') def disconnect(self, channelName=''): self._send_packet(0, channelName) if channelName: - del self.channelByName[channelName] + del self._channelByName[channelName] else: self.__del__() @property def connected(self): - return self.connection.connected + return self._connection.connected def connect(self, channelName, Namespace=BaseNamespace): channel = Channel(self, channelName, Namespace) - self.channelByName[channelName] = channel + self._channelByName[channelName] = channel self._send_packet(1, channelName) return channel def _send_heartbeat(self): try: self._send_packet(2) - except: - self.__del__() + except SocketIOPacketError: + print 'Could not send heartbeat' + pass def message(self, messageData, callback=None, channelName=''): if isinstance(messageData, basestring): @@ -144,40 +166,39 @@ class SocketIO(object): def emit(self, eventName, *eventArguments, **eventKeywords): code = 5 - if callable(eventArguments[-1]): + callback = None + if eventArguments and callable(eventArguments[-1]): callback = eventArguments[-1] eventArguments = eventArguments[:-1] - else: - callback = None channelName = eventKeywords.get('channelName', '') data = dumps(dict(name=eventName, args=eventArguments)) self._send_packet(code, channelName, data, callback) - def get_callback(self, channelName, eventName): + def _get_callback(self, channelName, eventName): 'Get callback associated with channelName and eventName' - socketIO = self.channelByName[channelName] if channelName else self + socketIO = self._channelByName[channelName] if channelName else self try: - return socketIO.callbackByEvent[eventName] + return socketIO._callbackByEvent[eventName] except KeyError: pass - namespace = socketIO.namespace def callback_(*eventArguments): - return namespace.on_(eventName, *eventArguments) - return getattr(namespace, name_callback(eventName), callback_) + return socketIO._namespace.on_(eventName, *eventArguments) + callbackName = 'on_' + eventName.replace(' ', '_') + return getattr(socketIO._namespace, callbackName, callback_) - def set_callback(self, callback): + def _set_callback(self, callback): 'Set callback that will be called after receiving an acknowledgment' - self.messageID += 1 - self.namespaceThread.set_callback(self.messageID, callback) - return '%s+' % self.messageID + self._messageID += 1 + self._namespaceThread.set_callback(self._messageID, callback) + return '%s+' % self._messageID def on(self, eventName, callback): - self.callbackByEvent[eventName] = callback + self._callbackByEvent[eventName] = callback def wait(self, seconds=None, forCallbacks=False): if forCallbacks: - self.namespaceThread.wait_for_callbacks(seconds) + self._namespaceThread.wait_for_callbacks(seconds) elif seconds: sleep(seconds) else: @@ -191,24 +212,22 @@ class SocketIO(object): class Channel(object): def __init__(self, socketIO, channelName, Namespace): - self.socketIO = socketIO - self.channelName = channelName - self.namespace = Namespace(self) - self.callbackByEvent = {} + self._socketIO = socketIO + self._channelName = channelName + self._namespace = Namespace(self) + self._callbackByEvent = {} def disconnect(self): - self.socketIO.disconnect(self.channelName) + self._socketIO.disconnect(self._channelName) def emit(self, eventName, *eventArguments): - self.socketIO.emit(eventName, *eventArguments, - channelName=self.channelName) + self._socketIO.emit(eventName, *eventArguments, channelName=self._channelName) def message(self, messageData, callback=None): - self.socketIO.message(messageData, callback, - channelName=self.channelName) + self._socketIO.message(messageData, callback, channelName=self._channelName) def on(self, eventName, eventCallback): - self.callbackByEvent[eventName] = eventCallback + self._callbackByEvent[eventName] = eventCallback class ListenerThread(Thread): @@ -216,34 +235,43 @@ class ListenerThread(Thread): daemon = True - def __init__(self, socketIO): + def __init__(self, recv_packet, get_callback): super(ListenerThread, self).__init__() - self.socketIO = socketIO self.done = Event() self.waitingForCallbacks = Event() self.callbackByMessageID = {} - self.get_callback = self.socketIO.get_callback + self.recv_packet = recv_packet + self.get_callback = get_callback def run(self): - while not self.done.is_set(): - try: - code, packetID, channelName, data = self.socketIO._recv_packet() - except: - continue - try: - delegate = { - 0: self.on_disconnect, - 1: self.on_connect, - 2: self.on_heartbeat, - 3: self.on_message, - 4: self.on_json, - 5: self.on_event, - 6: self.on_acknowledgment, - 7: self.on_error, - }[code] - except KeyError: - continue - delegate(packetID, channelName, data) + try: + while not self.done.is_set(): + try: + code, packetID, channelName, data = self.recv_packet() + except SocketIOConnectionError, error: + print error + return + except SocketIOPacketError, error: + print error + continue + get_channel_callback = partial(self.get_callback, channelName) + try: + delegate = { + 0: self.on_disconnect, + 1: self.on_connect, + 2: self.on_heartbeat, + 3: self.on_message, + 4: self.on_json, + 5: self.on_event, + 6: self.on_acknowledgment, + 7: self.on_error, + }[code] + except KeyError: + continue + delegate(packetID, get_channel_callback, data) + except: + exc_type, exc_value, exc_traceback = sys.exc_info() + open('tracebacks.log', 'a+t').write('\n'.join(traceback.format_tb(exc_traceback))) def cancel(self): self.done.set() @@ -255,33 +283,28 @@ class ListenerThread(Thread): def set_callback(self, messageID, callback): self.callbackByMessageID[messageID] = callback - def on_disconnect(self, packetID, channelName, data): - callback = self.get_callback(channelName, 'disconnect') - callback() + def on_disconnect(self, packetID, get_channel_callback, data): + get_channel_callback('disconnect')() - def on_connect(self, packetID, channelName, data): - callback = self.get_callback(channelName, 'connect') - callback(self.socketIO) + def on_connect(self, packetID, get_channel_callback, data): + get_channel_callback('connect')() - def on_heartbeat(self, packetID, channelName, data): + def on_heartbeat(self, packetID, get_channel_callback, data): pass - def on_message(self, packetID, channelName, data): - callback = self.get_callback(channelName, 'message') - callback(data) + def on_message(self, packetID, get_channel_callback, data): + get_channel_callback('message')(data) - def on_json(self, packetID, channelName, data): - callback = self.get_callback(channelName, 'message') - callback(loads(data)) + def on_json(self, packetID, get_channel_callback, data): + get_channel_callback('message')(loads(data)) - def on_event(self, packetID, channelName, data): + def on_event(self, packetID, get_channel_callback, data): valueByName = loads(data) eventName = valueByName['name'] eventArguments = valueByName['args'] - callback = self.get_callback(channelName, eventName) - callback(*eventArguments) + get_channel_callback(eventName)(*eventArguments) - def on_acknowledgment(self, packetID, channelName, data): + def on_acknowledgment(self, packetID, get_channel_callback, data): dataParts = data.split('+', 1) messageID = int(dataParts[0]) arguments = loads(dataParts[1]) or [] @@ -296,21 +319,20 @@ class ListenerThread(Thread): if self.waitingForCallbacks.is_set() and not callbackCount: self.cancel() - def on_error(self, packetID, channelName, data): + def on_error(self, packetID, get_channel_callback, data): reason, advice = data.split('+', 1) - callback = self.get_callback(channelName, 'error') - callback(reason, advice) + get_channel_callback('error')(reason, advice) class RhythmicThread(Thread): - 'Execute rhythmicFunction every few seconds' + 'Execute call every few seconds' daemon = True - def __init__(self, intervalInSeconds, rhythmicFunction, *args, **kw): + def __init__(self, intervalInSeconds, call, *args, **kw): super(RhythmicThread, self).__init__() self.intervalInSeconds = intervalInSeconds - self.rhythmicFunction = rhythmicFunction + self.call = call self.args = args self.kw = kw self.done = Event() @@ -318,10 +340,11 @@ class RhythmicThread(Thread): def run(self): try: while not self.done.is_set(): - self.rhythmicFunction(*self.args, **self.kw) + self.call(*self.args, **self.kw) self.done.wait(self.intervalInSeconds) except: - pass + exc_type, exc_value, exc_traceback = sys.exc_info() + open('tracebacks.log', 'a+t').write('\n'.join(traceback.format_tb(exc_traceback))) def cancel(self): self.done.set() @@ -331,5 +354,9 @@ class SocketIOError(Exception): pass -def name_callback(eventName): - return 'on_' + eventName.replace(' ', '_') +class SocketIOConnectionError(SocketIOError): + pass + + +class SocketIOPacketError(SocketIOError): + pass diff --git a/socketIO_client/tests.py b/socketIO_client/tests.py index a85c48e..53a6cf6 100644 --- a/socketIO_client/tests.py +++ b/socketIO_client/tests.py @@ -15,10 +15,16 @@ class TestSocketIO(TestCase): self.assertEqual(socketIO.connected, False) def test_emit(self): + socketIO = SocketIO('localhost', 8000, Namespace) + socketIO.emit('aaa') + sleep(0.5) + self.assertEqual(socketIO._namespace.payload, '') + + def test_emit_with_payload(self): socketIO = SocketIO('localhost', 8000, Namespace) socketIO.emit('aaa', PAYLOAD) sleep(0.5) - self.assertEqual(socketIO.namespace.payload, PAYLOAD) + self.assertEqual(socketIO._namespace.payload, PAYLOAD) def test_emit_with_callback(self): global ON_RESPONSE_CALLED @@ -43,16 +49,26 @@ class TestSocketIO(TestCase): newsSocket = mainSocket.connect('/news', Namespace) newsSocket.emit('aaa', PAYLOAD) sleep(0.5) - self.assertNotEqual(mainSocket.namespace.payload, PAYLOAD) - self.assertNotEqual(chatSocket.namespace.payload, PAYLOAD) - self.assertEqual(newsSocket.namespace.payload, PAYLOAD) + self.assertNotEqual(mainSocket._namespace.payload, PAYLOAD) + self.assertNotEqual(chatSocket._namespace.payload, PAYLOAD) + self.assertEqual(newsSocket._namespace.payload, PAYLOAD) + + def test_delete(self): + socketIO = SocketIO('localhost', 8000) + childThreads = [ + socketIO._heartbeatThread, + socketIO._namespaceThread, + ] + del socketIO + for childThread in childThreads: + self.assertEqual(True, childThread.done.is_set()) class Namespace(BaseNamespace): payload = None - def on_ddd(self, data): + def on_ddd(self, data=''): self.payload = data