diff --git a/socketIO_client/__init__.py b/socketIO_client/__init__.py index faa0a1f..18cb1fe 100644 --- a/socketIO_client/__init__.py +++ b/socketIO_client/__init__.py @@ -12,9 +12,9 @@ PROTOCOL = 1 # socket.io protocol version class BaseNamespace(object): # pragma: no cover 'Define socket.io behavior' - def __init__(self, _socketIO, namespacePath): + def __init__(self, _socketIO, path): self._socketIO = _socketIO - self._namespacePath = namespacePath + self._path = path self._callbackByEvent = {} def on_connect(self): @@ -26,11 +26,16 @@ class BaseNamespace(object): # pragma: no cover def on_error(self, reason, advice): print '[Error] %s' % advice - def on_message(self, messageData): - print '[Message] %s' % messageData + def on_message(self, data): + print '[Message] %s' % data - def on_default(self, eventName, *eventArguments): - print '[Event] %s%s' % (eventName, eventArguments) + def on_default(self, event, *args): + callback, args = find_callback(args) + arguments = [str(_) for _ in args] + if callback: + arguments.append('callback(*args)') + callback() + print '[Event] %s(%s)' % (event, ', '.join(arguments)) def on_open(self, *args): print '[Open]', args @@ -44,28 +49,25 @@ class BaseNamespace(object): # pragma: no cover def on_reconnect(self, *args): print '[Reconnect]', args - def message(self, messageData, messageCallback=None): - self._socketIO.message( - messageData, messageCallback, namespacePath=self._namespacePath) + def message(self, data, callback=None): + self._socketIO.message(data, callback, path=self._path) - def emit(self, eventName, *eventArguments): - self._socketIO.emit( - eventName, *eventArguments, namespacePath=self._namespacePath) + def emit(self, event, *args, **kw): + kw['path'] = self._path + self._socketIO.emit(event, *args, **kw) - def on(self, eventName, eventCallback): - self._callbackByEvent[eventName] = eventCallback + def on(self, event, callback): + self._callbackByEvent[event] = callback - def _get_eventCallback(self, eventName): + def _get_eventCallback(self, event): # Check callbacks defined by on() try: - return self._callbackByEvent[eventName] + return self._callbackByEvent[event] except KeyError: pass - # Check callbacks defined explicitly or use on_default() - def callback(*eventArguments): - return self.on_default(eventName, *eventArguments) - return getattr(self, 'on_' + eventName.replace(' ', '_'), callback) + callback = lambda *args: self.on_default(event, *args) + return getattr(self, 'on_' + event.replace(' ', '_'), callback) class SocketIO(object): @@ -99,38 +101,38 @@ class SocketIO(object): self.disconnect() def __del__(self): - self.disconnect(closeSocket=False) + self.disconnect(close=False) @property def connected(self): return self._socketIO.connected - def disconnect(self, namespacePath='', closeSocket=True): + def disconnect(self, path='', close=True): if self.connected: - self._socketIO.disconnect(namespacePath, closeSocket) - if namespacePath: - del self._namespaceByPath[namespacePath] + self._socketIO.disconnect(path, close) + if path: + del self._namespaceByPath[path] else: self._rhythmicThread.cancel() self._listenerThread.cancel() - def define(self, Namespace, namespacePath=''): - self._socketIO.connect(namespacePath) - namespace = Namespace(self._socketIO, namespacePath) - self._namespaceByPath[namespacePath] = namespace + def define(self, Namespace, path=''): + self._socketIO.connect(path) + namespace = Namespace(self._socketIO, path) + self._namespaceByPath[path] = namespace return namespace - def get_namespace(self, namespacePath=''): - return self._namespaceByPath[namespacePath] + def get_namespace(self, path=''): + return self._namespaceByPath[path] - def on(self, eventName, eventCallback, namespacePath=''): - return self.get_namespace(namespacePath).on(eventName, eventCallback) + def on(self, event, callback, path=''): + return self.get_namespace(path).on(event, callback) - def message(self, messageData, messageCallback=None, namespacePath=''): - self._socketIO.message(messageData, messageCallback, namespacePath) + def message(self, data, callback=None, path=''): + self._socketIO.message(data, callback, path) - def emit(self, eventName, *eventArguments, **eventKeywords): - self._socketIO.emit(eventName, *eventArguments, **eventKeywords) + def emit(self, event, *args, **kw): + self._socketIO.emit(event, *args, **kw) def wait(self, seconds=None, forCallbacks=False): if forCallbacks: @@ -187,10 +189,13 @@ class _ListenerThread(Thread): # Block callingThread until listenerThread terminates self.join(seconds) + def get_ackCallback(self, packetID): + return lambda *args: self._socketIO.ack(packetID, *args) + def run(self): while not self.done.is_set(): try: - code, packetID, namespacePath, data = self._socketIO.recv_packet() + code, packetID, path, data = self._socketIO.recv_packet() except SocketIOConnectionError, error: print error return @@ -198,9 +203,9 @@ class _ListenerThread(Thread): print error continue try: - namespace = self._namespaceByPath[namespacePath] + namespace = self._namespaceByPath[path] except KeyError: - print 'Received unexpected namespacePath (%s)' % namespacePath + print 'Received unexpected path (%s)' % path continue try: delegate = { @@ -228,25 +233,33 @@ class _ListenerThread(Thread): pass def on_message(self, packetID, get_eventCallback, data): - get_eventCallback('message')(data) + args = [data] + if packetID: + args.append(self.get_ackCallback(packetID)) + get_eventCallback('message')(args) def on_json(self, packetID, get_eventCallback, data): - get_eventCallback('message')(loads(data)) + args = [loads(data)] + if packetID: + args.append(self.get_ackCallback(packetID)) + get_eventCallback('message')(args) def on_event(self, packetID, get_eventCallback, data): valueByName = loads(data) - eventName = valueByName['name'] - eventArguments = valueByName.get('args', []) - get_eventCallback(eventName)(*eventArguments) + event = valueByName['name'] + args = valueByName.get('args', []) + if packetID: + args.append(self.get_ackCallback(packetID)) + get_eventCallback(event)(*args) def on_acknowledgment(self, packetID, get_eventCallback, data): dataParts = data.split('+', 1) messageID = int(dataParts[0]) - arguments = loads(dataParts[1]) or [] - messageCallback = self._socketIO.get_messageCallback(messageID) - if not messageCallback: + args = loads(dataParts[1]) or [] + callback = self._socketIO.get_messageCallback(messageID) + if not callback: return - messageCallback(*arguments) + callback(*args) if self.waiting.is_set() and not self._socketIO.has_messageCallback: self.cancel() @@ -284,18 +297,18 @@ class _SocketIO(object): self.callbackByMessageID = {} def __del__(self): - self.disconnect(closeSocket=False) + self.disconnect(close=False) - def disconnect(self, namespacePath='', closeSocket=True): + def disconnect(self, path='', close=True): if not self.connected: return - if namespacePath: - self.send_packet(0, namespacePath) - elif closeSocket: + if path: + self.send_packet(0, path) + elif close: self.connection.close() - def connect(self, namespacePath): - self.send_packet(1, namespacePath) + def connect(self, path): + self.send_packet(1, path) def send_heartbeat(self): try: @@ -304,24 +317,25 @@ class _SocketIO(object): print 'Could not send heartbeat' pass - def message(self, messageData, messageCallback, namespacePath): - if isinstance(messageData, basestring): + def message(self, data, callback, path): + if isinstance(data, basestring): code = 3 - data = messageData + packetData = data else: code = 4 - data = dumps(messageData, ensure_ascii=False) - self.send_packet(code, namespacePath, data, messageCallback) + packetData = dumps(data, ensure_ascii=False) + self.send_packet(code, path, packetData, callback) - def emit(self, eventName, *eventArguments, **eventKeywords): - if eventArguments and callable(eventArguments[-1]): - messageCallback = eventArguments[-1] - eventArguments = eventArguments[:-1] - else: - messageCallback = None - namespacePath = eventKeywords.get('namespacePath', '') - data = dumps(dict(name=eventName, args=eventArguments), ensure_ascii=False) - self.send_packet(5, namespacePath, data, messageCallback) + def emit(self, event, *args, **kw): + callback, args = find_callback(args, kw) + packetData = dumps(dict(name=event, args=args), ensure_ascii=False) + path = kw.get('path', '') + self.send_packet(5, path, packetData, callback) + + def ack(self, packetID, *args): + packetID = packetID.rstrip('+') + packetData = '%s+%s' % (packetID, dumps(args, ensure_ascii=False)) if args else packetID + self.send_packet(6, data=packetData) def set_messageCallback(self, callback): 'Set callback that will be called after receiving an acknowledgment' @@ -355,18 +369,18 @@ class _SocketIO(object): except AttributeError: raise SocketIOPacketError('Received invalid packet (%s)' % packet) packetCount = len(packetParts) - code, packetID, namespacePath, data = None, None, None, None + code, packetID, path, data = None, None, None, None if 4 == packetCount: - code, packetID, namespacePath, data = packetParts + code, packetID, path, data = packetParts elif 3 == packetCount: - code, packetID, namespacePath = packetParts + code, packetID, path = packetParts elif 1 == packetCount: code = packetParts[0] - return code, packetID, namespacePath, data + return code, packetID, path, data - def send_packet(self, code, namespacePath='', data='', messageCallback=None): - callbackNumber = self.set_messageCallback(messageCallback) if messageCallback else '' - packetParts = [str(code), callbackNumber, namespacePath, data] + def send_packet(self, code, path='', data='', callback=None): + packetID = self.set_messageCallback(callback) if callback else '' + packetParts = [str(code), packetID, path, data] try: self.connection.send(':'.join(packetParts)) except socket.error: @@ -387,3 +401,13 @@ class SocketIOConnectionError(SocketIOError): class SocketIOPacketError(SocketIOError): pass + + +def find_callback(args, kw=None): + 'Return callback whether passed as a last argument or as a keyword' + if args and callable(args[-1]): + return args[-1], args[:-1] + try: + return kw['callback'], args + except (KeyError, TypeError): + return None, args diff --git a/socketIO_client/tests.py b/socketIO_client/tests.py index cc9d3e7..92ffe5f 100644 --- a/socketIO_client/tests.py +++ b/socketIO_client/tests.py @@ -1,9 +1,8 @@ -from socketIO_client import SocketIO, BaseNamespace +from socketIO_client import SocketIO, BaseNamespace, find_callback from time import sleep from unittest import TestCase -ON_RESPONSE_CALLED = False PORT = 8000 PAYLOAD = {'xxx': 'yyy'} @@ -11,13 +10,28 @@ PAYLOAD = {'xxx': 'yyy'} class TestSocketIO(TestCase): def setUp(self): - global ON_RESPONSE_CALLED - ON_RESPONSE_CALLED = False self.socketIO = SocketIO('localhost', PORT) + self.called_on_response = False def tearDown(self): del self.socketIO + def on_response(self, *args): + self.called_on_response = True + callback, args = find_callback(args) + if callback: + callback(*args) + + def test_disconnect(self): + childThreads = [ + self.socketIO._rhythmicThread, + self.socketIO._listenerThread, + ] + self.socketIO.disconnect() + for childThread in childThreads: + self.assertEqual(True, childThread.done.is_set()) + self.assertEqual(False, self.socketIO.connected) + def test_emit(self): self.socketIO.define(Namespace) self.socketIO.emit('aaa') @@ -31,20 +45,26 @@ class TestSocketIO(TestCase): self.assertEqual(self.socketIO.get_namespace().payload, PAYLOAD) def test_emit_with_callback(self): - self.socketIO.emit('aaa', PAYLOAD, on_response) - self.socketIO.wait(forCallbacks=True) - self.assertEqual(ON_RESPONSE_CALLED, True) + self.socketIO.emit('aaa', PAYLOAD, self.on_response) + self.socketIO.wait(seconds=0.1, forCallbacks=True) + self.assertEqual(self.called_on_response, True) - def test_message(self): - self.socketIO.message(PAYLOAD, on_response) - self.socketIO.wait(forCallbacks=True) - self.assertEqual(ON_RESPONSE_CALLED, True) - - def test_events(self): - self.socketIO.on('aaa_response', on_response) + def test_emit_with_event(self): + self.socketIO.on('aaa_response', self.on_response) self.socketIO.emit('aaa', PAYLOAD) sleep(0.1) - self.assertEqual(ON_RESPONSE_CALLED, True) + self.assertEqual(self.called_on_response, True) + + def test_message(self): + self.socketIO.message(PAYLOAD, self.on_response) + self.socketIO.wait(seconds=0.1, forCallbacks=True) + self.assertEqual(self.called_on_response, True) + + def test_ack(self): + self.socketIO.on('bbb_response', self.on_response) + self.socketIO.emit('bbb', PAYLOAD) + sleep(0.1) + self.assertEqual(self.called_on_response, True) def test_namespaces(self): mainNamespace = self.socketIO.define(Namespace) @@ -59,19 +79,9 @@ class TestSocketIO(TestCase): def test_namespaces_with_callback(self): mainNamespace = self.socketIO.get_namespace() - mainNamespace.message(PAYLOAD, on_response) + mainNamespace.message(PAYLOAD, self.on_response) sleep(0.1) - self.assertEqual(ON_RESPONSE_CALLED, True) - - def test_disconnect(self): - childThreads = [ - self.socketIO._rhythmicThread, - self.socketIO._listenerThread, - ] - self.socketIO.disconnect() - for childThread in childThreads: - self.assertEqual(True, childThread.done.is_set()) - self.assertEqual(False, self.socketIO.connected) + self.assertEqual(self.called_on_response, True) class Namespace(BaseNamespace): @@ -81,8 +91,3 @@ class Namespace(BaseNamespace): def on_aaa_response(self, data=''): print '[Event] aaa_response(%s)' % data self.payload = data - - -def on_response(*args): - global ON_RESPONSE_CALLED - ON_RESPONSE_CALLED = True