From 096d41c072517e6e78ffcc97ba113bd54273b7fe Mon Sep 17 00:00:00 2001 From: Roy Hyunjin Han Date: Sun, 3 Nov 2013 01:01:29 -0800 Subject: [PATCH] Fixes #23 #21 #20 #18 --- TODO.goals | 16 +- socketIO_client/__init__.py | 774 ++++++++++++++++++++---------------- socketIO_client/tests.py | 62 +-- 3 files changed, 467 insertions(+), 385 deletions(-) diff --git a/TODO.goals b/TODO.goals index 8cd11fc..87f2b73 100644 --- a/TODO.goals +++ b/TODO.goals @@ -1,8 +1,8 @@ -= Investigate coroutine replacement for threads -Make client robust if server disconnects -Signal on_disconnect -Add on_noop -Add query_string -Add more transports -Replace print with logging -Test with django-socketio += Test with django-socketio ++ Investigate coroutine replacement for threads ++ Make client robust if server disconnects ++ Signal on_disconnect ++ Add on_noop ++ Add query_string ++ Add more transports ++ Replace print with logging diff --git a/socketIO_client/__init__.py b/socketIO_client/__init__.py index 59b107f..6dbbc7f 100644 --- a/socketIO_client/__init__.py +++ b/socketIO_client/__init__.py @@ -1,47 +1,66 @@ +import logging +import json import requests import socket -from json import dumps, loads -from threading import Thread, Event -from time import sleep -from websocket import WebSocketConnectionClosedException, create_connection +import time +import websocket +from collections import namedtuple -PROTOCOL = 1 # socket.io protocol version +_Session = namedtuple('_Session', [ + 'id', + 'heartbeat_timeout', + 'server_supported_transports', +]) +_log = logging.getLogger(__name__) +TRANSPORTS = 'websocket', 'xhr-polling', 'jsonp-polling' +PROTOCOL_VERSION = 1 -class BaseNamespace(object): # pragma: no cover - 'Define socket.io behavior' +class BaseNamespace(object): + 'Define client behavior' - def __init__(self, _socketIO, path): - self._socketIO = _socketIO + def __init__(self, _transport, path): + self._transport = _transport self._path = path - self._callbackByEvent = {} + self._callback_by_event = {} self.initialize() def initialize(self): 'Initialize custom variables here; you can override this method' pass + def message(self, data='', callback=None): + self._transport.message(self._path, data, callback) + + def emit(self, event, *args, **kw): + callback, args = find_callback(args, kw) + self._transport.emit(self._path, event, args, callback) + + def on(self, event, callback): + 'Define a callback to handle a custom event emitted by the server' + self._callback_by_event[event] = callback + def on_connect(self): - 'Called when socket is connecting; you can override this method' - pass + 'Called after server connects; you can override this method' + _log.debug('[connect]') def on_disconnect(self): - 'Called when socket is disconnecting; you can override this method' - pass + 'Called after server disconnects; you can override this method' + _log.debug('[disconnect]') - def on_error(self, reason, advice): - 'Called when server sends an error; you can override this method' - print '[Error] %s' % advice + def on_heartbeat(self): + 'Called after server sends a heartbeat; you can override this method' + _log.debug('[heartbeat]') def on_message(self, data): - 'Called when server sends a message; you can override this method' - print '[Message] %s' % data + 'Called after server sends a message; you can override this method' + _log.info('[message] %s', data) def on_event(self, event, *args): """ - Called when server emits an event; you can override this method. - Called only if the program cannot find a more specific event handler, + Called after server sends an event; you can override this method. + Called only if a custom event handler does not exist, such as one defined by namespace.on('my_event', my_function). """ callback, args = find_callback(args) @@ -49,389 +68,387 @@ class BaseNamespace(object): # pragma: no cover if callback: arguments.append('callback(*args)') callback(*args) - print '[Event] %s(%s)' % (event, ', '.join(arguments)) + _log.info('[event] %s(%s)', event, ', '.join(arguments)) + + def on_error(self, reason, advice): + 'Called after server sends an error; you can override this method' + _log.info('[error] %s', advice) + + def on_noop(self): + 'Called after server sends a noop; you can override this method' + _log.info('[noop]') def on_open(self, *args): - print '[Open]', args + _log.info('[open] %s', args) def on_close(self, *args): - print '[Close]', args + _log.info('[close] %s', args) def on_retry(self, *args): - print '[Retry]', args + _log.info('[retry] %s', args) def on_reconnect(self, *args): - print '[Reconnect]', args + _log.info('[reconnect] %s', args) - def message(self, data='', callback=None): - self._socketIO.message(data, callback, path=self._path) - - def emit(self, event, *args, **kw): - kw['path'] = self._path - self._socketIO.emit(event, *args, **kw) - - def on(self, event, callback): - 'Define a callback to handle a custom event emitted by the server' - self._callbackByEvent[event] = callback - - def _get_eventCallback(self, event): + def _find_event_callback(self, event): # Check callbacks defined by on() try: - return self._callbackByEvent[event] + return self._callback_by_event[event] except KeyError: pass # Check callbacks defined explicitly or use on_event() - callback = lambda *args: self.on_event(event, *args) - return getattr(self, 'on_' + event.replace(' ', '_'), callback) + return getattr( + self, + 'on_' + event.replace(' ', '_'), + lambda *args: self.on_event(event, *args)) class SocketIO(object): - def __init__(self, host, port, Namespace=BaseNamespace, secure=False, headers=None, proxies=None): + def __init__( + self, host, port, Namespace=BaseNamespace, secure=False, + wait_for_connection=True, transports=TRANSPORTS, **kw): """ Create a socket.io client that connects to a socket.io server - at the specified host and port. Set secure=True to use HTTPS / WSS. + at the specified host and port. + - Define the behavior of the client by specifying a custom Namespace. + - Set secure=True to use HTTPS / WSS. + - Set wait_for_connection=True to block until we have a connection. + - List the transports you want to use (%s). + - Pass query params, headers, cookies, proxies as keyword arguments. - SocketIO('localhost', 8000, secure=True, - proxies={'https': 'https://proxy.example.com:8080'}) - """ - self._socketIO = _SocketIO(host, port, secure, headers, proxies) - self._namespaceByPath = {} + SocketIO('localhost', 8000, proxies={ + 'https': 'https://proxy.example.com:8080'}) + """ % ', '.join(TRANSPORTS) + self.base_url = '%s:%d/socket.io/%s' % (host, port, PROTOCOL_VERSION) + self.secure = secure + self.wait_for_connection = wait_for_connection + self._namespace_by_path = {} + self.client_supported_transports = transports + self.kw = kw self.define(Namespace) - self._rhythmicThread = _RhythmicThread( - self._socketIO.heartbeatInterval, - self._socketIO.send_heartbeat) - self._rhythmicThread.start() - - self._listenerThread = _ListenerThread( - self._socketIO, - self._namespaceByPath) - self._listenerThread.start() - def __enter__(self): return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, *exception_pack): self.disconnect() def __del__(self): - self.disconnect(close=False) - - @property - def connected(self): - return self._socketIO.connected - - def disconnect(self, path='', close=True): - if self.connected: - self._socketIO.disconnect(path, close) - if path: - del self._namespaceByPath[path] - else: - self._rhythmicThread.cancel() - self._listenerThread.cancel() + self.disconnect() def define(self, Namespace, path=''): if path: - self._socketIO.connect(path) - namespace = Namespace(self._socketIO, path) - self._namespaceByPath[path] = namespace + self._transport.connect(path) + namespace = Namespace(self._transport, path) + self._namespace_by_path[path] = namespace return namespace - def get_namespace(self, path=''): - return self._namespaceByPath[path] - def on(self, event, callback, path=''): return self.get_namespace(path).on(event, callback) def message(self, data='', callback=None, path=''): - self._socketIO.message(data, callback, path) + self._transport.message(path, data, callback) def emit(self, event, *args, **kw): - self._socketIO.emit(event, *args, **kw) + path = kw.get('path', '') + callback, args = find_callback(args, kw) + self._transport.emit(path, event, args, callback) - def wait(self, seconds=None): - if seconds: - self._listenerThread.wait(seconds) - else: - try: - while self.connected: - sleep(1) - except KeyboardInterrupt: - pass - - def wait_for_callbacks(self, seconds=None): - self._listenerThread.wait_for_callbacks(seconds) - - -class _RhythmicThread(Thread): - 'Execute call every few seconds' - - daemon = True - - def __init__(self, intervalInSeconds, call, *args, **kw): - super(_RhythmicThread, self).__init__() - self.intervalInSeconds = intervalInSeconds - self.call = call - self.args = args - self.kw = kw - self.done = Event() - - def run(self): - while not self.done.is_set(): - self.call(*self.args, **self.kw) - self.done.wait(self.intervalInSeconds) - - def cancel(self): - self.done.set() - - -class _ListenerThread(Thread): - 'Process messages from socket.io server' - - daemon = True - - def __init__(self, _socketIO, _namespaceByPath): - super(_ListenerThread, self).__init__() - self._socketIO = _socketIO - self._namespaceByPath = _namespaceByPath - self.done = Event() - self.ready = Event() - self.ready.set() - - def cancel(self): - self.done.set() - - def wait(self, seconds): - self.done.wait(seconds) - - def wait_for_callbacks(self, seconds): - self.ready.clear() - self.ready.wait(seconds) - - def get_ackCallback(self, packetID): - return lambda *args: self._socketIO.ack(packetID, *args) - - def run(self): - while not self.done.is_set(): - try: - code, packetID, path, data = self._socketIO.recv_packet() - except SocketIOConnectionError, error: - print error - self.cancel() - break - except SocketIOPacketError, error: - print error - continue - try: - namespace = self._namespaceByPath[path] - except KeyError: - print 'Received unexpected path (%s)' % path - continue - try: - delegate = { - '0': self.on_disconnect, - '1': self.on_connect, - '2': self.on_heartbeat, - '3': self.on_message, - '4': self.on_json, - '5': self.on_event, - '6': self.on_ack, - '7': self.on_error, - }[code] - except KeyError: - print 'Received unexpected code (%s)' % code - continue - delegate(packetID, namespace._get_eventCallback, data) - - def on_disconnect(self, packetID, get_eventCallback, data): - get_eventCallback('disconnect')() - - def on_connect(self, packetID, get_eventCallback, data): - get_eventCallback('connect')() - - def on_heartbeat(self, packetID, get_eventCallback, data): - pass - - def on_message(self, packetID, get_eventCallback, data): - args = [data] - if packetID: - args.append(self.get_ackCallback(packetID)) - get_eventCallback('message')(*args) - - def on_json(self, packetID, get_eventCallback, data): - args = [loads(data)] - if packetID: - args.append(self.get_ackCallback(packetID)) - get_eventCallback('message')(*args) - - def on_event(self, packetID, get_eventCallback, data): - valueByName = loads(data) - event = valueByName['name'] - args = valueByName.get('args', []) - if packetID: - args.append(self.get_ackCallback(packetID)) - get_eventCallback(event)(*args) - - def on_ack(self, packetID, get_eventCallback, data): - dataParts = data.split('+', 1) - messageID = int(dataParts[0]) - args = loads(dataParts[1]) if len(dataParts) > 1 else [] - callback = self._socketIO.get_messageCallback(messageID) - if not callback: - return - callback(*args) - if not self._socketIO.has_messageCallback: - self.ready.set() - - def on_error(self, packetID, get_eventCallback, data): - reason, advice = data.split('+', 1) - get_eventCallback('error')(reason, advice) - - -class _SocketIO(object): - 'Low-level interface to remove cyclic references in child threads' - - messageID = 0 - - def __init__(self, host, port, secure, headers, proxies): - baseURL = '%s:%d/socket.io/%s' % (host, port, PROTOCOL) - targetScheme = 'https' if secure else 'http' - targetURL = '%s://%s/' % (targetScheme, baseURL) + def wait(self, seconds=None, for_callbacks=False): try: - response = requests.get( - targetURL, - headers=headers, - proxies=proxies) - except IOError: # pragma: no cover - raise SocketIOError('Could not start connection') - if 200 != response.status_code: # pragma: no cover - raise SocketIOError('Could not establish connection') - responseParts = response.text.split(':') - sessionID = responseParts[0] - heartbeatTimeout = int(responseParts[1]) - # connectionTimeout = int(responseParts[2]) - supportedTransports = responseParts[3].split(',') - if 'websocket' not in supportedTransports: - raise SocketIOError('Could not parse handshake') - socketScheme = 'wss' if secure else 'ws' - socketURL = '%s://%s/websocket/%s' % (socketScheme, baseURL, sessionID) - self.connection = create_connection(socketURL) - self.heartbeatInterval = heartbeatTimeout - 2 - self.callbackByMessageID = {} - - def __del__(self): - self.disconnect(close=False) - - def disconnect(self, path='', close=True): - if not self.connected: - return - if path: - self.send_packet(0, path) - elif close: - self.connection.close() - - def connect(self, path): - self.send_packet(1, path) - - def send_heartbeat(self): - try: - self.send_packet(2) - except SocketIOPacketError: - print 'Could not send heartbeat' + warning_screen = _yield_warning_screen(seconds, sleep=1) + for elapsed_time in warning_screen: + try: + if for_callbacks and not self._transport.has_ack_callback: + break + try: + self._process_packet(self._transport.recv_packet()) + except _TimeoutError: + pass + except _PacketError as error: + _log.warn('[packet error] %s', error) + self.heartbeat_pacemaker.send(elapsed_time) + except SocketIOConnectionError as error: + self.disconnect() + warning = Exception('[connection error] %s' % error) + warning_screen.throw(warning) + except KeyboardInterrupt: pass - def message(self, data, callback, path): - if isinstance(data, basestring): - code = 3 - packetData = data - else: - code = 4 - packetData = dumps(data, ensure_ascii=False) - self.send_packet(code, path, packetData, callback) + def wait_for_callbacks(self, seconds=None): + self.wait(seconds, for_callbacks=True) - def emit(self, event, *args, **kw): - callback, args = find_callback(args, kw) - packetData = dumps(dict(name=event, args=args), ensure_ascii=False) - path = kw.get('path', '') - self.send_packet(5, path, packetData, callback) - - def ack(self, packetID, *args): - packetID = packetID.rstrip('+') - packetData = '%s+%s' % ( - packetID, - dumps(args, ensure_ascii=False), - ) if args else packetID - self.send_packet(6, data=packetData) - - def set_messageCallback(self, callback): - 'Set callback that will be called after receiving an acknowledgment' - self.messageID += 1 - self.callbackByMessageID[self.messageID] = callback - return '%s+' % self.messageID - - def get_messageCallback(self, messageID): - try: - callback = self.callbackByMessageID[messageID] - del self.callbackByMessageID[messageID] - return callback - except KeyError: - return - - @property - def has_messageCallback(self): - return True if self.callbackByMessageID else False - - def recv_packet(self): - try: - packet = self.connection.recv() - except WebSocketConnectionClosedException: - text = 'Lost connection (Connection closed)' - raise SocketIOConnectionError(text) - except socket.timeout: - text = 'Lost connection (Connection timed out)' - raise SocketIOConnectionError(text) - except socket.error: - text = 'Lost connection' - raise SocketIOConnectionError(text) - try: - packetParts = packet.split(':', 3) - except AttributeError: - raise SocketIOPacketError('Received invalid packet (%s)' % packet) - packetCount = len(packetParts) - code, packetID, path, data = None, None, None, None - if 4 == packetCount: - code, packetID, path, data = packetParts - elif 3 == packetCount: - code, packetID, path = packetParts - elif 1 == packetCount: - code = packetParts[0] - return code, packetID, path, data - - def send_packet(self, code, path='', data='', callback=None): - packetID = self.set_messageCallback(callback) if callback else '' - packetParts = [str(code), packetID, path, data] - try: - packet = ':'.join(packetParts) - self.connection.send(packet) - except socket.error: - raise SocketIOPacketError('Could not send packet') + def disconnect(self, path=''): + if self.connected: + self._transport.disconnect(path) + namespace = self._namespace_by_path[path] + namespace.on_disconnect() + if path: + del self._namespace_by_path[path] @property def connected(self): - return self.connection.connected + return self.__transport.connected + + @property + def _transport(self): + try: + if self.connected: + return self.__transport + except AttributeError: + pass + warning_screen = _yield_warning_screen(seconds=None, sleep=1) + for elapsed_time in warning_screen: + try: + self.__transport = self._get_transport() + break + except SocketIOConnectionError as error: + if not self.wait_for_connection: + raise + warning = Exception('[waiting for connection] %s' % error) + warning_screen.throw(warning) + return self.__transport + + def _get_transport(self): + self.session = _get_session(self.secure, self.base_url, **self.kw) + _log.debug('[transports available] %s', ' '.join( + self.session.server_supported_transports)) + # Initialize heartbeat_pacemaker + self.heartbeat_pacemaker = self._make_heartbeat_pacemaker( + heartbeat_interval=self.session.heartbeat_timeout - 2) + self.heartbeat_pacemaker.next() + # Negotiate transport + transport = _negotiate_transport( + self.client_supported_transports, self.session, + self.secure, self.base_url, **self.kw) + # Update namespaces + for namespace in self._namespace_by_path.values(): + namespace._transport = transport + return transport + + def _make_heartbeat_pacemaker(self, heartbeat_interval): + heartbeat_time = 0 + while True: + elapsed_time = (yield) + if elapsed_time - heartbeat_time > heartbeat_interval: + heartbeat_time = elapsed_time + self._transport.send_heartbeat() + + def _process_packet(self, packet): + code, packet_id, path, data = packet + namespace = self.get_namespace(path) + delegate = self._get_delegate(code) + delegate(packet_id, data, namespace._find_event_callback) + + def get_namespace(self, path=''): + try: + return self._namespace_by_path[path] + except KeyError: + raise _PacketError('unexpected namespace path (%s)' % path) + + def _get_delegate(self, code): + try: + return { + '0': self._on_disconnect, + '1': self._on_connect, + '2': self._on_heartbeat, + '3': self._on_message, + '4': self._on_json, + '5': self._on_event, + '6': self._on_ack, + '7': self._on_error, + '8': self._on_noop, + }[code] + except KeyError: + raise _PacketError('unexpected code (%s)' % code) + + def _on_disconnect(self, packet_id, data, find_event_callback): + find_event_callback('disconnect')() + + def _on_connect(self, packet_id, data, find_event_callback): + find_event_callback('connect')() + + def _on_heartbeat(self, packet_id, data, find_event_callback): + find_event_callback('heartbeat')() + + def _on_message(self, packet_id, data, find_event_callback): + args = [data] + if packet_id: + args.append(self._prepare_to_send_ack(packet_id)) + find_event_callback('message')(*args) + + def _on_json(self, packet_id, data, find_event_callback): + args = [json.loads(data)] + if packet_id: + args.append(self._prepare_to_send_ack(packet_id)) + find_event_callback('message')(*args) + + def _on_event(self, packet_id, data, find_event_callback): + value_by_name = json.loads(data) + event = value_by_name['name'] + args = value_by_name.get('args', []) + if packet_id: + args.append(self._prepare_to_send_ack(packet_id)) + find_event_callback(event)(*args) + + def _on_ack(self, packet_id, data, find_event_callback): + data_parts = data.split('+', 1) + packet_id = data_parts[0] + try: + ack_callback = self._transport.get_ack_callback(packet_id) + except KeyError: + return + args = json.loads(data_parts[1]) if len(data_parts) > 1 else [] + ack_callback(*args) + + def _on_error(self, packet_id, data, find_event_callback): + reason, advice = data.split('+', 1) + find_event_callback('error')(reason, advice) + + def _on_noop(self, packet_id, data, find_event_callback): + find_event_callback('noop')() + + def _prepare_to_send_ack(self, packet_id): + 'Return function that acknowledges the server' + return lambda *args: self._transport.ack(packet_id, *args) class SocketIOError(Exception): pass +class _TimeoutError(Exception): + pass + + +class _PacketError(SocketIOError): + pass + + class SocketIOConnectionError(SocketIOError): pass -class SocketIOPacketError(SocketIOError): - pass +class _AbstractTransport(object): + + def __init__(self): + self._packet_id = 0 + self._callback_by_packet_id = {} + + def disconnect(self, path=''): + if not self.connected: + return + if path: + self.send_packet(0, path) + else: + self.connection.close() + + def connect(self, path): + self.send_packet(1, path) + + def send_heartbeat(self): + self.send_packet(2) + + def message(self, path, data, callback): + if isinstance(data, basestring): + code = 3 + else: + code = 4 + data = json.dumps(data, ensure_ascii=False) + self.send_packet(code, path, data, callback) + + def emit(self, path, event, args, callback): + data = json.dumps(dict(name=event, args=args), ensure_ascii=False) + self.send_packet(5, path, data, callback) + + def ack(self, packet_id, *args): + packet_id = packet_id.rstrip('+') + data = '%s+%s' % ( + packet_id, + json.dumps(args, ensure_ascii=False), + ) if args else packet_id + self.send_packet(6, data=data) + + def noop(self): + self.send_packet(8) + + def send_packet(self, code, path='', data='', callback=None): + packet_id = self.set_ack_callback(callback) if callback else '' + packet_parts = str(code), packet_id, path, data + packet_text = ':'.join(packet_parts) + self.send(packet_text) + _log.debug('[packet sent] %s', packet_text) + + def recv_packet(self): + code, packet_id, path, data = None, None, None, None + packet_text = self.recv() + _log.debug('[packet received] %s', packet_text) + try: + packet_parts = packet_text.split(':', 3) + except AttributeError: + raise _PacketError('invalid packet (%s)' % packet_text) + packet_count = len(packet_parts) + if 4 == packet_count: + code, packet_id, path, data = packet_parts + elif 3 == packet_count: + code, packet_id, path = packet_parts + elif 1 == packet_count: + code = packet_parts[0] + return code, packet_id, path, data + + def set_ack_callback(self, callback): + 'Set callback to be called after server sends an acknowledgment' + self._packet_id += 1 + self._callback_by_packet_id[str(self._packet_id)] = callback + return '%s+' % self._packet_id + + def get_ack_callback(self, packet_id): + 'Get callback to be called after server sends an acknowledgment' + callback = self._callback_by_packet_id[packet_id] + del self._callback_by_packet_id[packet_id] + return callback + + @property + def has_ack_callback(self): + return True if self._callback_by_packet_id else False + + +class _WebsocketTransport(_AbstractTransport): + + def __init__(self, session, secure, base_url, **kw): + super(_WebsocketTransport, self).__init__() + url = '%s://%s/websocket/%s' % ( + 'wss' if secure else 'ws', + base_url, session.id) + _log.debug('[transport selected] %s', url) + try: + self.connection = websocket.create_connection(url) + except socket.timeout as error: + raise SocketIOConnectionError(error) + except socket.error as error: + raise SocketIOConnectionError(error) + self.connection.settimeout(1) + + @property + def connected(self): + return self.connection.connected + + def recv(self): + try: + return self.connection.recv() + except socket.timeout: + raise _TimeoutError + except socket.error as error: + raise SocketIOConnectionError(error) + except websocket.WebSocketConnectionClosedException: + raise SocketIOConnectionError('server closed connection') + + def send(self, packet_text): + try: + self.connection.send(packet_text) + except socket.error: + raise SocketIOConnectionError('could not send %s' % packet_text) def find_callback(args, kw=None): @@ -442,3 +459,68 @@ def find_callback(args, kw=None): return kw['callback'], args except (KeyError, TypeError): return None, args + + +def _yield_warning_screen(seconds=None, sleep=0): + last_warning = None + for elapsed_time in _yield_elapsed_time(seconds): + try: + yield elapsed_time + except Exception as warning: + warning = str(warning) + if last_warning != warning: + last_warning = warning + _log.warn(warning) + time.sleep(sleep) + + +def _yield_elapsed_time(seconds=None): + if seconds is None: + while True: + yield float('inf') + start_time = time.time() + while time.time() - start_time < seconds: + yield time.time() - start_time + + +def _get_session(secure, base_url, **kw): + server_url = '%s://%s/' % ('https' if secure else 'http', base_url) + try: + response = requests.get(server_url, **kw) + except requests.exceptions.ConnectionError: + raise SocketIOConnectionError('could not start connection') + status = response.status_code + if 200 != status: + raise SocketIOConnectionError('unexpected status code (%s)' % status) + response_parts = response.text.split(':') + return _Session( + id=response_parts[0], + heartbeat_timeout=int(response_parts[1]), + server_supported_transports=response_parts[3].split(',')) + + +def _negotiate_transport( + client_supported_transports, session, + secure, base_url, **kw): + server_supported_transports = session.server_supported_transports + for supported_transport in client_supported_transports: + if supported_transport in server_supported_transports: + return { + 'websocket': _WebsocketTransport, + # 'xhr-polling': + # 'jsonp-polling': + }[supported_transport](session, secure, base_url, **kw) + raise SocketIOError(' '.join([ + 'could not negotiate a transport:', + 'client supports %s but' % ', '.join(client_supported_transports), + 'server supports %s' % ', '.join(server_supported_transports), + ])) + + +if __name__ == '__main__': + requests_log = logging.getLogger('requests') + requests_log.setLevel(logging.WARNING) + logging.basicConfig(level=logging.DEBUG) + socketIO = SocketIO('localhost', 8000) + socketIO.emit('aaa') + socketIO.wait() diff --git a/socketIO_client/tests.py b/socketIO_client/tests.py index bdbfbd5..d16e7cd 100644 --- a/socketIO_client/tests.py +++ b/socketIO_client/tests.py @@ -1,3 +1,4 @@ +import logging from socketIO_client import SocketIO, BaseNamespace, find_callback from unittest import TestCase @@ -6,6 +7,7 @@ HOST = 'localhost' PORT = 8000 DATA = 'xxx' PAYLOAD = {'xxx': 'yyy'} +logging.basicConfig(level=logging.DEBUG) class TestSocketIO(TestCase): @@ -25,24 +27,18 @@ class TestSocketIO(TestCase): else: self.assertEqual(arg, DATA) - def is_connected(self, socketIO, connected): - childThreads = [ - socketIO._rhythmicThread, - socketIO._listenerThread, - ] - for childThread in childThreads: - self.assertEqual(not connected, childThread.done.is_set()) - self.assertEqual(connected, socketIO.connected) - def test_disconnect(self): - 'Terminate child threads after disconnect' - self.is_connected(self.socketIO, True) + 'Disconnect' + self.assertTrue(self.socketIO.connected) self.socketIO.disconnect() - self.is_connected(self.socketIO, False) + self.assertFalse(self.socketIO.connected) # Use context manager - with SocketIO(HOST, PORT) as self.socketIO: - self.is_connected(self.socketIO, True) - self.is_connected(self.socketIO, False) + with SocketIO(HOST, PORT, Namespace) as self.socketIO: + namespace = self.socketIO.get_namespace() + self.assertFalse(namespace.called_on_disconnect) + self.assertTrue(self.socketIO.connected) + self.assertTrue(namespace.called_on_disconnect) + self.assertFalse(self.socketIO.connected) def test_message(self): 'Message' @@ -72,20 +68,20 @@ class TestSocketIO(TestCase): 'Message with callback' self.socketIO.message(callback=self.on_response) self.socketIO.wait_for_callbacks(seconds=0.1) - self.assertEqual(self.called_on_response, True) + self.assertTrue(self.called_on_response) def test_message_with_callback_with_data(self): 'Message with callback with data' self.socketIO.message(DATA, self.on_response) self.socketIO.wait_for_callbacks(seconds=0.1) - self.assertEqual(self.called_on_response, True) + self.assertTrue(self.called_on_response) def test_emit(self): 'Emit' self.socketIO.define(Namespace) self.socketIO.emit('emit') self.socketIO.wait(0.1) - self.assertEqual(self.socketIO.get_namespace().argsByEvent, { + self.assertEqual(self.socketIO.get_namespace().args_by_event, { 'emit_response': (), }) @@ -94,7 +90,7 @@ class TestSocketIO(TestCase): self.socketIO.define(Namespace) self.socketIO.emit('emit_with_payload', PAYLOAD) self.socketIO.wait(0.1) - self.assertEqual(self.socketIO.get_namespace().argsByEvent, { + self.assertEqual(self.socketIO.get_namespace().args_by_event, { 'emit_with_payload_response': (PAYLOAD,), }) @@ -103,7 +99,7 @@ class TestSocketIO(TestCase): self.socketIO.define(Namespace) self.socketIO.emit('emit_with_multiple_payloads', PAYLOAD, PAYLOAD) self.socketIO.wait(0.1) - self.assertEqual(self.socketIO.get_namespace().argsByEvent, { + self.assertEqual(self.socketIO.get_namespace().args_by_event, { 'emit_with_multiple_payloads_response': (PAYLOAD, PAYLOAD), }) @@ -111,35 +107,35 @@ class TestSocketIO(TestCase): 'Emit with callback' self.socketIO.emit('emit_with_callback', self.on_response) self.socketIO.wait_for_callbacks(seconds=0.1) - self.assertEqual(self.called_on_response, True) + self.assertTrue(self.called_on_response) def test_emit_with_callback_with_payload(self): 'Emit with callback with payload' self.socketIO.emit('emit_with_callback_with_payload', self.on_response) self.socketIO.wait_for_callbacks(seconds=0.1) - self.assertEqual(self.called_on_response, True) + self.assertTrue(self.called_on_response) def test_emit_with_callback_with_multiple_payloads(self): 'Emit with callback with multiple payloads' self.socketIO.emit('emit_with_callback_with_multiple_payloads', self.on_response) self.socketIO.wait_for_callbacks(seconds=0.1) - self.assertEqual(self.called_on_response, True) + self.assertTrue(self.called_on_response) def test_emit_with_event(self): 'Emit to trigger an event' self.socketIO.on('emit_with_event_response', self.on_response) self.socketIO.emit('emit_with_event', PAYLOAD) - self.socketIO.wait_for_callbacks(0.1) - self.assertEqual(self.called_on_response, True) + self.socketIO.wait(0.1) + self.assertTrue(self.called_on_response) def test_ack(self): 'Trigger server callback' self.socketIO.define(Namespace) self.socketIO.emit('ack', PAYLOAD) self.socketIO.wait(0.1) - self.assertEqual(self.socketIO.get_namespace().argsByEvent, { + self.assertEqual(self.socketIO.get_namespace().args_by_event, { 'ack_response': (PAYLOAD,), 'ack_callback_response': (PAYLOAD,), }) @@ -151,9 +147,9 @@ class TestSocketIO(TestCase): newsNamespace = self.socketIO.define(Namespace, '/news') newsNamespace.emit('emit_with_payload', PAYLOAD) self.socketIO.wait(0.1) - self.assertEqual(mainNamespace.argsByEvent, {}) - self.assertEqual(chatNamespace.argsByEvent, {}) - self.assertEqual(newsNamespace.argsByEvent, { + self.assertEqual(mainNamespace.args_by_event, {}) + self.assertEqual(chatNamespace.args_by_event, {}) + self.assertEqual(newsNamespace.args_by_event, { 'emit_with_payload_response': (PAYLOAD,), }) @@ -162,7 +158,11 @@ class Namespace(BaseNamespace): def initialize(self): self.response = None - self.argsByEvent = {} + self.args_by_event = {} + self.called_on_disconnect = False + + def on_disconnect(self): + self.called_on_disconnect = True def on_message(self, data): self.response = data @@ -171,4 +171,4 @@ class Namespace(BaseNamespace): callback, args = find_callback(args) if callback: callback(*args) - self.argsByEvent[event] = args + self.args_by_event[event] = args