From 01bfefdd8ba4bc597b16d6cbf84ea365d2081125 Mon Sep 17 00:00:00 2001 From: Roy Hyunjin Han Date: Wed, 13 Feb 2013 08:27:58 -0800 Subject: [PATCH] Trying to refactor the damn code to remove cyclic references --- README.rst | 27 +-- socketIO_client/__init__.py | 381 +++++++++++++++++++----------------- socketIO_client/tests.py | 33 ++-- 3 files changed, 233 insertions(+), 208 deletions(-) diff --git a/README.rst b/README.rst index 504e0d5..a1984c2 100644 --- a/README.rst +++ b/README.rst @@ -70,7 +70,8 @@ Define events in a namespace. :: def on_ddd(self, *args): self.socketIO.emit('eee', {'fff': 'ggg'}) - socketIO = SocketIO('localhost', 8000, Namespace) + socketIO = SocketIO('localhost', 8000) + socketIO.define(Namespace) socketIO.wait() # Loop until CTRL-C Define standard events. :: @@ -79,7 +80,7 @@ Define standard events. :: class Namespace(BaseNamespace): - def on_connect(self, socketIO): + def on_connect(self): print '[Connected]' def on_disconnect(self): @@ -91,32 +92,36 @@ Define standard events. :: def on_message(self, id, message): print '[Message] %s: %s' % (id, message) - socketIO = SocketIO('localhost', 8000, Namespace) + socketIO = SocketIO('localhost', 8000) + socketIO.define(Namespace) socketIO.wait() # Loop until CTRL-C -Define different behavior for different channels on a single socket. :: +Define different namespaces on a single socket. :: from socketIO_client import SocketIO, BaseNamespace - class MainNamespace(BaseNamespace): + class MainNamespace(Channel): def on_aaa(self, *args): print 'aaa', args - class ChatNamespace(BaseNamespace): + class ChatNamespace(Channel): def on_bbb(self, *args): print 'bbb', args - class NewsNamespace(BaseNamespace): + class NewsNamespace(Channel): def on_ccc(self, *args): print 'ccc', args - mainSocket = SocketIO('localhost', 8000, MainNamespace) - chatSocket = mainSocket.connect('/chat', ChatNamespace) - newsSocket = mainSocket.connect('/news', NewsNamespace) - mainSocket.wait() # Loop until CTRL-C + socketIO = SocketIO('localhost', 8000) + socketIO.define(MainNamespace) + chatSocket = socketIO.define(ChatNamespace, '/chat') + chatSocket.emit('bbb') + newsSocket = socketIO.define(NewsNamespace, '/news') + newsSocket.emit('ccc') + socketIO.wait() # Loop until CTRL-C Open secure websockets (HTTPS / WSS) behind a proxy. :: diff --git a/socketIO_client/__init__.py b/socketIO_client/__init__.py index 9ae06a2..b247fb4 100644 --- a/socketIO_client/__init__.py +++ b/socketIO_client/__init__.py @@ -1,7 +1,5 @@ import socket -import weakref from anyjson import dumps, loads -from functools import partial from threading import Event, Thread from time import sleep from urllib import urlopen @@ -13,8 +11,8 @@ PROTOCOL = 1 # SocketIO protocol version class BaseNamespace(object): # pragma: no cover - def __init__(self, socketIO): - self.socketIO = socketIO + def __init__(self, _socketIO): + self._socketIO = _socketIO def on_connect(self): pass @@ -44,40 +42,21 @@ class BaseNamespace(object): # pragma: no cover print '[Reconnect]', args -class Channel(object): - - def __init__(self, socketIO, channelName, Namespace): - self._socketIO = weakref.proxy(socketIO) - self._channelName = channelName - self._namespace = Namespace(self) - self._callbackByEvent = {} - - def disconnect(self): - self._socketIO.disconnect(self._channelName) - - def emit(self, eventName, *eventArguments): - self._socketIO.emit(eventName, *eventArguments, channelName=self._channelName) - - def message(self, messageData, callback=None): - self._socketIO.message(messageData, callback, channelName=self._channelName) - - def on(self, eventName, eventCallback): - self._callbackByEvent[eventName] = eventCallback - - class SocketIO(object): - def __init__(self, host, port, Namespace=BaseNamespace, secure=False, proxies=None): + def __init__(self, host, port, secure=False, proxies=None): self._socketIO = _SocketIO(host, port, secure, proxies) + self._channelByPath = {} - self._heartbeatThread = _RhythmicThread( + self._rhythmicThread = _RhythmicThread( self._socketIO.heartbeatTimeout, self._socketIO.send_heartbeat) - self._heartbeatThread.start() + self._rhythmicThread.start() - self._namespace = Namespace(self._socketIO) - self._namespaceThread = _ListenerThread(self._socketIO) - self._namespaceThread.start() + self._listenerThread = _ListenerThread( + self._socketIO, + self._channelByPath) + self._listenerThread.start() def __enter__(self): return self @@ -88,20 +67,30 @@ class SocketIO(object): def __del__(self): self.disconnect() - def disconnect(self, channelName=''): - self._send_packet(0, channelName) - if channelName: - del self.channelByName[channelName] - else: - self._heartbeatThread.cancel() - self._namespaceThread.cancel() + @property + def connected(self): + return self._socketIO.connection.connected - def connect(self, channelName, Namespace=BaseNamespace): - channel = Channel(self, channelName, Namespace) - self.channelByName[channelName] = channel - self.send_packet(1, channelName) + def disconnect(self, channelPath=''): + self._socketIO.disconnect(channelPath) + if channelPath: + del self._channelByPath[channelPath] + else: + self._rhythmicThread.cancel() + self._listenerThread.cancel() + + def define(self, Namespace, channelPath=''): + self._socketIO.connect(channelPath) + channel = Channel(self._socketIO, Namespace, channelPath) + self._channelByPath[channelPath] = channel return channel + def get_namespace(self, channelPath=''): + return self._channelByPath[channelPath].get_namespace() + + def on(self, eventName, eventCallback): + self._callbackByEvent[eventName] = callback + def message(self, messageData, callback=None, channelName=''): if isinstance(messageData, basestring): code = 3 @@ -111,7 +100,8 @@ class SocketIO(object): data = dumps(messageData) self._send_packet(code, channelName, data, callback) - def emit(self, eventName, *eventArguments, **eventKeywords): + def emit(self, eventName, *eventArguments): + self._socketIO.emit(eventName, *eventArguments) code = 5 callback = None if eventArguments and callable(eventArguments[-1]): @@ -121,12 +111,9 @@ class SocketIO(object): data = dumps(dict(name=eventName, args=eventArguments)) self._send_packet(code, channelName, data, callback) - def on(self, eventName, callback): - self._callbackByEvent[eventName] = callback - def wait(self, seconds=None, forCallbacks=False): if forCallbacks: - self._namespaceThread.wait_for_callbacks(seconds) + self.__listenerThread.wait_for_callbacks(seconds) elif seconds: sleep(seconds) else: @@ -137,115 +124,6 @@ class SocketIO(object): pass -class _SocketIO(object): - 'Low-level interface to remove cyclic references in child threads' - - messageID = 0 - - def __init__(self, host, port, secure=False, proxies=None): - self.connect(host, port, secure, proxies) - self.callbackByMessageID = {} - self.callbackByEvent = {} - self.channelByName = {} - - def __del__(self): - self.connection.close() - - def connect(self, host, port, secure, proxies): - baseURL = '%s:%d/socket.io/%s' % (host, port, PROTOCOL) - targetScheme = 'https' if secure else 'http' - targetURL = '%s://%s/' % (targetScheme, baseURL) - try: - response = urlopen(targetURL, proxies=proxies) - except IOError: # pragma: no cover - raise SocketIOError('Could not start connection') - if 200 != response.getcode(): # pragma: no cover - raise SocketIOError('Could not establish connection') - responseParts = response.readline().split(':') - sessionID = responseParts[0] - heartbeatTimeout = int(responseParts[1]) - # connectionTimeout = int(responseParts[2]) - supportedTransports = responseParts[3].split(',') - if 'websocket' not in supportedTransports: - raise SocketIOError('Could not parse handshake') # pragma: no cover - socketScheme = 'wss' if secure else 'ws' - socketURL = '%s://%s/websocket/%s' % (socketScheme, baseURL, sessionID) - self.connection = create_connection(socketURL) - self.heartbeatInterval = heartbeatTimeout - 2 - - def recv_packet(self): - code, packetID, channelName, data = -1, None, None, None - try: - packet = self.connection.recv() - except WebSocketConnectionClosedException: - raise SocketIOConnectionError('Lost connection (Connection closed)') - except socket.timeout: - raise SocketIOConnectionError('Lost connection (Connection timed out)') - try: - packetParts = packet.split(':', 3) - except AttributeError: - raise SocketIOPacketError('Received invalid packet (%s)' % packet) - packetCount = len(packetParts) - if 4 == packetCount: - code, packetID, channelName, data = packetParts - elif 3 == packetCount: - code, packetID, channelName = packetParts - elif 1 == packetCount: # pragma: no cover - code = packetParts[0] - return int(code), packetID, channelName, data - - def send_packet(self, code, channelName='', data='', callback=None): - callbackNumber = self.set_messageID_callback(callback) if callback else '' - packetParts = [str(code), callbackNumber, channelName, data] - try: - self.connection.send(':'.join(packetParts)) - except socket.error: - raise SocketIOPacketError('Could not send packet') - - def set_messageID_callback(self, callback): - 'Set callback that will be called after receiving an acknowledgment' - self.messageID += 1 - self.callbackByMessageID[self.messageID] = callback - return '%s+' % self.messageID - - def get_messageID_callback(self, messageID): - 'Get callback associated with messageID' - try: - callback = self.callbackByMessageID[messageID] - del self.callbackByMessageID[messageID] - return callback - except KeyError: - return - - @property - def has_messageID_callback(self): - return True if self.callbackByMessageID else False - - def get_event_callback(self, channelName, eventName): - 'Get callback associated with channelName and eventName' - _socketIO = self.channelByName[channelName] if channelName else self - try: - return _socketIO.callbackByEvent[eventName] - except KeyError: - pass - - def callback_(*eventArguments): - return _socketIO.namespace.on_(eventName, *eventArguments) - callbackName = 'on_' + eventName.replace(' ', '_') - return getattr(_socketIO.namespace, callbackName, callback_) - - @property - def connected(self): - return self.connection.connected - - def send_heartbeat(self): - try: - self.send_packet(2) - except SocketIOPacketError: - print 'Could not send heartbeat' - pass - - class _RhythmicThread(Thread): 'Execute call every few seconds' @@ -273,23 +151,32 @@ class _ListenerThread(Thread): daemon = True - def __init__(self, _socketIO, get_event_callback, ): + def __init__(self, _socketIO, _channelByPath): super(_ListenerThread, self).__init__() self._socketIO = _socketIO + self._channelByPath = _channelByPath self.done = Event() - self.waitingForCallbacks = Event() + self.waiting = Event() + + def cancel(self): + self.done.set() + + def wait_for_callbacks(self, seconds): + self.waiting.set() + # Block callingThread until listenerThread terminates + self.join(seconds) def run(self): while not self.done.is_set(): try: - code, packetID, channelName, data = self._socketIO.recv_packet() + code, packetID, channelPath, data = self._socketIO.recv_packet() except SocketIOConnectionError, error: print error return except SocketIOPacketError, error: print error continue - get_channel_callback = partial(self._socketIO.get_event_callback, channelName) + channel = self._channelByPath[channelPath] try: delegate = { 0: self.on_disconnect, @@ -303,50 +190,174 @@ class _ListenerThread(Thread): }[code] except KeyError: continue - delegate(packetID, get_channel_callback, data) + delegate(packetID, channel._get_eventCallback, data) - def cancel(self): - self.done.set() + def on_disconnect(self, packetID, get_eventCallback, data): + get_eventCallback('disconnect')() - def wait_for_callbacks(self, seconds): - self.waitingForCallbacks.set() - self.join(seconds) + def on_connect(self, packetID, get_eventCallback, data): + get_eventCallback('connect')() - def on_disconnect(self, packetID, get_channel_callback, data): - get_channel_callback('disconnect')() + def on_heartbeat(self, packetID, get_eventCallback, data): + get_eventCallback('heartbeat')() - def on_connect(self, packetID, get_channel_callback, data): - get_channel_callback('connect')() + def on_message(self, packetID, get_eventCallback, data): + get_eventCallback('message')(data) - def on_heartbeat(self, packetID, get_channel_callback, data): - pass + def on_json(self, packetID, get_eventCallback, data): + get_eventCallback('message')(loads(data)) - def on_message(self, packetID, get_channel_callback, data): - get_channel_callback('message')(data) - - def on_json(self, packetID, get_channel_callback, data): - get_channel_callback('message')(loads(data)) - - def on_event(self, packetID, get_channel_callback, data): + def on_event(self, packetID, get_eventCallback, data): valueByName = loads(data) eventName = valueByName['name'] eventArguments = valueByName['args'] - get_channel_callback(eventName)(*eventArguments) + get_eventCallback(eventName)(*eventArguments) - def on_acknowledgment(self, packetID, get_channel_callback, data): + def on_acknowledgment(self, packetID, get_eventCallback, data): dataParts = data.split('+', 1) messageID = int(dataParts[0]) arguments = loads(dataParts[1]) or [] - callback = self._socketIO.get_messageID_callback(messageID) + callback = self._socketIO.get_messageCallback(messageID) if not callback: return callback(*arguments) - if self.waitingForCallbacks.is_set() and not self._socketIO.has_messageID_callback: + if self.waiting.is_set() and not self._socketIO.has_messageCallback: self.cancel() - def on_error(self, packetID, get_channel_callback, data): + def on_error(self, packetID, get_eventCallback, data): reason, advice = data.split('+', 1) - get_channel_callback('error')(reason, advice) + get_eventCallback('error')(reason, advice) + + +class _SocketIO(object): + 'Low-level interface to remove cyclic references in child threads' + + messageID = 0 + self.callbackByMessageID = {} + self.callbackByEvent = {} + + + def __init__(self, host, port, secure, proxies): + baseURL = '%s:%d/socket.io/%s' % (host, port, PROTOCOL) + targetScheme = 'https' if secure else 'http' + targetURL = '%s://%s/' % (targetScheme, baseURL) + try: + response = urlopen(targetURL, proxies=proxies) + except IOError: # pragma: no cover + raise SocketIOError('Could not start connection') + if 200 != response.getcode(): # pragma: no cover + raise SocketIOError('Could not establish connection') + responseParts = response.readline().split(':') + sessionID = responseParts[0] + heartbeatTimeout = int(responseParts[1]) + # connectionTimeout = int(responseParts[2]) + supportedTransports = responseParts[3].split(',') + if 'websocket' not in supportedTransports: + raise SocketIOError('Could not parse handshake') # pragma: no cover + socketScheme = 'wss' if secure else 'ws' + socketURL = '%s://%s/websocket/%s' % (socketScheme, baseURL, sessionID) + self.connection = create_connection(socketURL) + self.heartbeatInterval = heartbeatTimeout - 2 + + def __del__(self): + self.connection.close() + + def get_channel(self, channelPath): + + def get_channel(self, channelName): + return self.channelByName[channelName] + pass + + def recv_packet(self): + code, packetID, channelName, data = -1, None, None, None + try: + packet = self.connection.recv() + except WebSocketConnectionClosedException: + raise SocketIOConnectionError('Lost connection (Connection closed)') + except socket.timeout: + raise SocketIOConnectionError('Lost connection (Connection timed out)') + try: + packetParts = packet.split(':', 3) + except AttributeError: + raise SocketIOPacketError('Received invalid packet (%s)' % packet) + packetCount = len(packetParts) + if 4 == packetCount: + code, packetID, channelName, data = packetParts + elif 3 == packetCount: + code, packetID, channelName = packetParts + elif 1 == packetCount: # pragma: no cover + code = packetParts[0] + return int(code), packetID, channelName, data + + def send_packet(self, code, channelName='', data='', callback=None): + callbackNumber = self.set_messageCallback(callback) if callback else '' + packetParts = [str(code), callbackNumber, channelName, data] + try: + self.connection.send(':'.join(packetParts)) + except socket.error: + raise SocketIOPacketError('Could not send packet') + + def disconnect(self, channelPath): + self.send_packet(0, channelPath) + + def connect(self, channelPath): + self.send_packet(1, channelPath) + + def send_heartbeat(self): + try: + self.send_packet(2) + except SocketIOPacketError: + print 'Could not send heartbeat' + pass + + def set_messageCallback(self, callback): + 'Set callback that will be called after receiving an acknowledgment' + self.messageID += 1 + self.callbackByMessageID[self.messageID] = callback + return '%s+' % self.messageID + + def get_messageCallback(self, messageID): + try: + callback = self.callbackByMessageID[messageID] + del self.callbackByMessageID[messageID] + return callback + except KeyError: + return + + @property + def has_messageCallback(self): + return True if self.callbackByMessageID else False + + +class Channel(object): + + def __init__(self, _socketIO, Namespace, channelPath): + self._socketIO = _socketIO + self._namespace = Namespace(_socketIO) + self._channelPath = channelPath + self._callbackByEvent = {} + + def on(self, eventName, eventCallback): + self._callbackByEvent[eventName] = eventCallback + + def message(self, messageData, messageCallback=None): + self._socketIO.message(messageData, messageCallback, channelPath=self._channelPath) + + def emit(self, eventName, *eventArguments): + self._socketIO.emit(eventName, *eventArguments, channelPath=self._channelPath) + + def get_namespace(self): + return self._namespace + + def _get_eventCallback(self, eventName): + # Check callbacks defined by on() + try: + return self._callbackByEvent[eventName] + except KeyError: + pass + # Check callbacks defined explicitly or use on_() + defaultCallback = lambda *eventArguments: self.get_namespace().on_(eventName, *eventArguments) + return getattr(self, 'on_' + eventName.replace(' ', '_'), defaultCallback) class SocketIOError(Exception): diff --git a/socketIO_client/tests.py b/socketIO_client/tests.py index 53a6cf6..1cbf3b6 100644 --- a/socketIO_client/tests.py +++ b/socketIO_client/tests.py @@ -13,18 +13,26 @@ class TestSocketIO(TestCase): socketIO = SocketIO('localhost', 8000) socketIO.disconnect() self.assertEqual(socketIO.connected, False) + childThreads = [ + socketIO._rhythmicThread, + socketIO._listenerThread, + ] + for childThread in childThreads: + self.assertEqual(True, childThread.done.is_set()) def test_emit(self): - socketIO = SocketIO('localhost', 8000, Namespace) + socketIO = SocketIO('localhost', 8000) + socketIO.define(Namespace) socketIO.emit('aaa') sleep(0.5) - self.assertEqual(socketIO._namespace.payload, '') + self.assertEqual(socketIO.get_namespace().payload, '') def test_emit_with_payload(self): - socketIO = SocketIO('localhost', 8000, Namespace) + socketIO = SocketIO('localhost', 8000) + socketIO.define(Namespace) socketIO.emit('aaa', PAYLOAD) sleep(0.5) - self.assertEqual(socketIO._namespace.payload, PAYLOAD) + self.assertEqual(socketIO.get_namespace().payload, PAYLOAD) def test_emit_with_callback(self): global ON_RESPONSE_CALLED @@ -44,20 +52,21 @@ class TestSocketIO(TestCase): self.assertEqual(ON_RESPONSE_CALLED, True) def test_channels(self): - mainSocket = SocketIO('localhost', 8000, Namespace) - chatSocket = mainSocket.connect('/chat', Namespace) - newsSocket = mainSocket.connect('/news', Namespace) + socketIO = SocketIO('localhost', 8000) + mainSocket = socketIO.define(Namespace) + chatSocket = socketIO.define(Namespace, '/chat') + newsSocket = socketIO.define(Namespace, '/news') newsSocket.emit('aaa', PAYLOAD) sleep(0.5) - self.assertNotEqual(mainSocket._namespace.payload, PAYLOAD) - self.assertNotEqual(chatSocket._namespace.payload, PAYLOAD) - self.assertEqual(newsSocket._namespace.payload, PAYLOAD) + self.assertNotEqual(mainSocket.get_namespace().payload, PAYLOAD) + self.assertNotEqual(chatSocket.get_namespace().payload, PAYLOAD) + self.assertEqual(newsSocket.get_namespace().payload, PAYLOAD) def test_delete(self): socketIO = SocketIO('localhost', 8000) childThreads = [ - socketIO._heartbeatThread, - socketIO._namespaceThread, + socketIO._rhythmicThread, + socketIO._listenerThread, ] del socketIO for childThread in childThreads: