From d64e947aac64f0c0b5001feb4294a832d494c22d Mon Sep 17 00:00:00 2001 From: Sean Arietta Date: Mon, 22 Dec 2014 00:36:09 -0800 Subject: [PATCH] Checkpoint for updates. Namespaces are working. Events are working. Reconnects working. Disconnects working. Need to implemenet ACKs / callbacks, WebSocket transport, and JSONP transport --- socketIO_client/__init__.py | 124 ++++++++++++++++++++------- socketIO_client/parser.py | 117 +++++++++++++++++++++---- socketIO_client/transports.py | 155 +++++++++++++++++++++++----------- 3 files changed, 302 insertions(+), 94 deletions(-) diff --git a/socketIO_client/__init__.py b/socketIO_client/__init__.py index e4d5fa1..bd3556a 100644 --- a/socketIO_client/__init__.py +++ b/socketIO_client/__init__.py @@ -1,6 +1,8 @@ from collections import namedtuple +import copy import logging import json +import multiprocessing import parser import requests import time @@ -16,7 +18,8 @@ from .transports import _get_response, _negotiate_transport, TRANSPORTS _SocketIOSession = namedtuple('_SocketIOSession', [ 'id', - 'heartbeat_timeout', + 'heartbeat_interval', + 'connection_timeout', 'server_supported_transports', ]) _log = logging.getLogger(__name__) @@ -42,7 +45,11 @@ class BaseNamespace(object): def emit(self, event, *args, **kw): callback, args = find_callback(args, kw) - self._transport.emit(self.path, event, args, callback) + + if callback is not None: + _log.warn("Callback was specified but is not supported."); + + self._transport.emit(self.path, event, args, None) def disconnect(self): self._transport.disconnect(self.path) @@ -138,6 +145,16 @@ class SocketIO(object): self._namespace_by_path = {} self.client_supported_transports = transports self.kw = kw + # These two fields work to control the heartbeat thread. + self.heartbeat_terminator = None; + self.heartbeat_thread = None; + # Saved session information. + self.session = None; + # This is stores the set of paths (namespaces) that need to be + # reconnected to. + self.reconnect_paths = {}; + # This sets of a chain of events that attempts to connect to + # the server at the base namespace. self.define(Namespace) def __enter__(self): @@ -145,9 +162,17 @@ class SocketIO(object): def __exit__(self, *exception_pack): self.disconnect() + self._terminate_heartbeat(); def __del__(self): self.disconnect() + self._terminate_heartbeat(); + + def _terminate_heartbeat(self): + if self.heartbeat_terminator is not None: + self.heartbeat_terminator.set(); + #time.sleep(self.session.heartbeat_interval); + self.heartbeat_thread.join(); def define(self, Namespace, path=''): if path: @@ -167,6 +192,19 @@ class SocketIO(object): callback, args = find_callback(args, kw) self._transport.emit(path, event, args, callback) + def reconnect(self): + """Reconnects to a set of namespaces. + + """ + for path in self.reconnect_paths: + # We avoid reconnecting to the default namespace because + # socketIO_client connects to that already. + if (len(self.reconnect_paths) > 1 and path is ''): + continue; + _log.debug("Reconnecting to path: %s" % repr(path)) + self._transport.connect(path); + self.reconnect_paths = {}; + def wait(self, seconds=None, for_callbacks=False): """Wait in a loop and process events as defined in the namespaces. @@ -181,14 +219,28 @@ class SocketIO(object): pass if self._stop_waiting(for_callbacks): break - self.heartbeat_pacemaker.send(elapsed_time) + + # We will end up here in the case that we + # disconnected, then reconnected AND we were + # successful. + if len(self.reconnect_paths) > 0: + self.reconnect(); except ConnectionError as e: try: + # This is where we end up if the connection was + # severed. The client will disconnect here. + if len(self.reconnect_paths) is 0: + self.reconnect_paths = copy.deepcopy(self._namespace_by_path); + + self._terminate_heartbeat(); + warning = Exception('[connection error] %s' % e) + self._transport._connected = False; warning_screen.throw(warning) except StopIteration: _log.warn(warning) self.disconnect() + _log.debug("[wait canceled]"); def _process_events(self): for packet in self._transport.recv_packet(): @@ -249,31 +301,29 @@ class SocketIO(object): return self.__transport def _get_transport(self): - socketIO_session = _get_socketIO_session( - self.is_secure, self.base_url, **self.kw) - _log.debug('[transports available] %s', ' '.join( - socketIO_session.server_supported_transports)) - # Initialize heartbeat_pacemaker - self.heartbeat_pacemaker = self._make_heartbeat_pacemaker( - heartbeat_interval=socketIO_session.heartbeat_timeout / 2) - next(self.heartbeat_pacemaker) + self.session = _get_socketIO_session(self.is_secure, self.base_url, **self.kw) + _log.debug('[transports available] %s', ' '.join(self.session.server_supported_transports)) + # Negotiate transport transport = _negotiate_transport( - self.client_supported_transports, socketIO_session, + self.client_supported_transports, self.session, self.is_secure, self.base_url, **self.kw) # Update namespaces for path, namespace in self._namespace_by_path.items(): namespace._transport = transport transport.connect(path) - return transport + + transport.set_timeout(self.session.connection_timeout); - def _make_heartbeat_pacemaker(self, heartbeat_interval): - heartbeat_time = 0 - while True: - elapsed_time = (yield) - if elapsed_time - heartbeat_time > heartbeat_interval: - heartbeat_time = elapsed_time - self._transport.send_heartbeat() + # Start the heartbeat pacemaker (PING). + _log.debug("[start heartbeat pacemaker]"); + self.heartbeat_terminator = multiprocessing.Event(); + self.heartbeat_thread = multiprocessing.Process( + target = _make_heartbeat_pacemaker, + args = (self.heartbeat_terminator, transport, self.session.heartbeat_interval / 2)); + self.heartbeat_thread.start(); + + return transport def get_namespace(self, path=''): try: @@ -369,7 +419,7 @@ def _parse_host(host, port): url_pack = parse_url(host) is_secure = url_pack.scheme == 'https' port = port or url_pack.port or (443 if is_secure else 80) - base_url = '%s:%d%s/socket.io/%s' % (url_pack.hostname, port, url_pack.path, PROTOCOL_VERSION) + base_url = '%s:%d%s/socket.io' % (url_pack.hostname, port, url_pack.path) return is_secure, base_url @@ -396,7 +446,8 @@ def _yield_elapsed_time(seconds=None): def _get_socketIO_session(is_secure, base_url, **kw): - server_url = '%s://%s/?transport=polling' % ('https' if is_secure else 'http', base_url) + server_url = '%s://%s/?EIO=%d&transport=polling' \ + % ('https' if is_secure else 'http', base_url, parser.ENGINE_PROTOCOL) _log.debug('[session] %s', server_url) try: response = _get_response(requests.get, server_url, **kw) @@ -404,15 +455,28 @@ def _get_socketIO_session(is_secure, base_url, **kw): raise ConnectionError(e) _log.debug("[response] %s", response.text); - decoded = parser.decode_response(response); - _log.debug("[decoded] %s", repr(decoded)); + packet = parser.decode_response(response); + _log.debug("[decoded] %s", repr(packet)); + + if packet.type is not parser.PacketType.OPEN: + _log.warn("Got unexpected packet during connection handshake: %d" % packet.type); + return None; + + handshake = json.loads(packet.payload); return _SocketIOSession( - id = decoded["payload"]["sid"], - heartbeat_timeout = int(decoded["payload"]["pingInterval"]), + id = handshake["sid"], + heartbeat_interval = int(handshake["pingInterval"]) / 1000, + connection_timeout = int(handshake["pingTimeout"]) / 1000, server_supported_transports = ["xhr-polling"]);#decoded["payload"]["upgrades"]); - #return _SocketIOSession( - # id=response_parts[0], - # heartbeat_timeout=int(response_parts[1]), - # server_supported_transports=response_parts[3].split(',')) +def _make_heartbeat_pacemaker(terminator, transport, heartbeat_interval): + while True: + if terminator.wait(heartbeat_interval): + break; + _log.debug("[hearbeat]"); + try: + transport.send_heartbeat(); + except: + pass; + _log.debug("[heartbeat terminated]"); diff --git a/socketIO_client/parser.py b/socketIO_client/parser.py index 2231481..6896c40 100644 --- a/socketIO_client/parser.py +++ b/socketIO_client/parser.py @@ -1,21 +1,94 @@ +from enum import Enum import logging import json _log = logging.getLogger(__name__) -""" Decodes a response from requests lib. -""" -def decode_response(response): - # TODO(sean): Should we use the 'raw' stream instead? - raw_bytes = response.content; - packet_type = "string" if ord(raw_bytes[0]) == 0 else "binary"; - _log.debug("Packet type: %s" % packet_type); +ENGINE_PROTOCOL = 3; - if packet_type is "string": +class PacketType(Enum): + OPEN = 0; + CLOSE = 1; + PING = 2; + PONG = 3; + MESSAGE = 4; + UPGRADE = 5; + NOOP = 6; + +class MessageType(Enum): + CONNECT = 0; + DISCONNECT = 1; + EVENT = 2; + ACK = 3; + ERROR = 4; + BINARY_EVENT = 5; + BINARY_ACK = 6; + +class Packet(): + def __init__(self, packet_type, payload): + self.type = packet_type; + self.payload = payload; + +class Message(): + def __init__(self, message_type, message, path = ""): + self.type = message_type; + if isinstance(message, basestring): + try: + self.message = json.loads(message); + except: + self.message = message; + else: + self.message = message; + + self.path = path; + + def encode_as_json(self): + """Encodes a Message to be sent to socket.io server. + + Assumes the message payload will be dumped as a json string. + """ + if self.path == "": + return str(self.type) + json.dumps(self.message); + return str(self.type) + self.path + "," + json.dumps(self.message); + + def encode_as_string(self): + """Same as the encode_as_string method except it doesn't encode things as a JSON string""" + if self.path == "": + return str(self.type) + self.message; + return str(self.type) + self.path + "," + self.message; + +def decode_message(payload): + """ Decodes a message encoded via socket.io + """ + + message_type = int(payload[0]); + message = payload[1:]; + + return Message(message_type, message); + +def decode_response(response): + """Decodes a response from requests lib. + + """ + # TODO(sean): Should we use the 'raw' stream instead? + return decode_packet(response.content); + +def decode_packet(packet): + """Decodes a packet sent via engine.io. + + If the packet is a message, this method assumes the message was + encoded by socket.io and will parse it as such. + + """ + + packet_format = "string" if ord(packet[0]) == 0 else "binary"; + _log.debug("Packet type: %s" % packet_format); + + if packet_format is "string": length_bytes = []; offset = 1; - while ord(raw_bytes[offset]) is not 255: - length_bytes.append(ord(raw_bytes[offset])); + while ord(packet[offset]) is not 255: + length_bytes.append(ord(packet[offset])); offset += 1; offset += 1; @@ -26,16 +99,30 @@ def decode_response(response): base *= 10; _log.debug("Packet length: %d" % length); - message_type = raw_bytes[offset]; + packet_type = int(packet[offset]); offset += 1; - message = {"type": message_type, "payload": json.loads(raw_bytes[offset:offset + length - 1])}; - _log.debug("Message: %s" % repr(message)); - return message; + payload = packet[offset:offset + length - 1]; + _log.debug("Payload: %s" % repr(payload)); + + if packet_type is PacketType.MESSAGE: + message = decode_message(payload); + payload = message; + + return Packet(packet_type, payload); else: pass; return ""; -def decode_packet(packet): pass; + +def encode_packet_string(code, path, data): + """Encodes packet to be sent to socket.io server. + """ + + code_length = len(str(code)); + data_length = len(data); + length = code_length + data_length; + + return str(length) + ":" + str(code) + str(data); diff --git a/socketIO_client/transports.py b/socketIO_client/transports.py index e1a0064..a2f37e8 100644 --- a/socketIO_client/transports.py +++ b/socketIO_client/transports.py @@ -1,6 +1,7 @@ import json import logging import parser +from parser import Message, MessageType, PacketType import re import requests import six @@ -13,7 +14,7 @@ from .exceptions import SocketIOError, ConnectionError, TimeoutError TRANSPORTS = 'websocket', 'xhr-polling', 'jsonp-polling' BOUNDARY = six.u('\ufffd') -TIMEOUT_IN_SECONDS = 3 +TIMEOUT_IN_SECONDS = 300 _log = logging.getLogger(__name__) @@ -24,6 +25,10 @@ class _AbstractTransport(object): self._callback_by_packet_id = {} self._wants_to_disconnect = False self._packets = [] + self._timeout = TIMEOUT_IN_SECONDS; + + def set_timeout(self, timeout): + self._timeout = timeout; def disconnect(self, path=''): if not path: @@ -31,15 +36,20 @@ class _AbstractTransport(object): if not self.connected: return if path: - self.send_packet(0, path) + self.send_packet(PacketType.CLOSE, path) else: self.close() def connect(self, path): - self.send_packet(1, path) + if path != "": + _log.debug("Connecting to path: %s" % path); + data = Message(MessageType.CONNECT, path).encode_as_string(); + self.send_packet(PacketType.MESSAGE, path, data); + else: + self.send_packet(PacketType.OPEN, path, data); def send_heartbeat(self): - self.send_packet(2) + self.send_packet(PacketType.PING) def message(self, path, data, callback): if isinstance(data, basestring): @@ -50,8 +60,8 @@ class _AbstractTransport(object): self.send_packet(code, path, data, callback) def emit(self, path, event, args, callback): - data = json.dumps(dict(name=event, args=args), ensure_ascii=False) - self.send_packet(5, path, data, callback) + message = Message(MessageType.EVENT, [event, args], path); + self.send_packet(PacketType.MESSAGE, path, message.encode_as_json(), callback) def ack(self, path, packet_id, *args): packet_id = packet_id.rstrip('+') @@ -59,15 +69,13 @@ class _AbstractTransport(object): packet_id, json.dumps(args, ensure_ascii=False), ) if args else packet_id - self.send_packet(6, path, data) + #self.send_packet(6, path, data) def noop(self, path=''): - self.send_packet(8, path) + self.send_packet(PacketType.NOOP, path) def send_packet(self, code, path='', data='', callback=None): - packet_id = self.set_ack_callback(callback) if callback else '' - packet_parts = str(code), packet_id, path, data - packet_text = ':'.join(packet_parts) + packet_text = parser.encode_packet_string(code, path, data); self.send(packet_text) _log.debug('[packet sent] %s', packet_text) @@ -77,22 +85,48 @@ class _AbstractTransport(object): yield self._packets.pop(0) except IndexError: pass - for packet_text in self.recv(): - _log.debug('[packet received] %s', packet_text) + for response in self.recv(): + _log.debug('[packet received] %s', response.text); try: - #packet = parser.decode_response(packet_text); - packet_parts = packet_text.split(':', 3) + packet = parser.decode_response(response); except AttributeError: - _log.warn('[packet error] %s', packet_text) + _log.warn('[packet error] %s', response.text) continue - code, packet_id, path, data = None, None, None, None - packet_count = len(packet_parts) - if 4 == packet_count: - code, packet_id, path, data = packet_parts - elif 3 == packet_count: - code, packet_id, path = packet_parts - elif 1 == packet_count: - code = packet_parts[0] + code, packet_id, path, data = None, None, '', None + + if packet.type is PacketType.OPEN: + code = '1'; + continue; + elif packet.type is PacketType.CLOSE: + code = '0'; + elif packet.type is PacketType.PING: + code = '2'; + elif packet.type is PacketType.PONG: + code = '2'; + elif packet.type is PacketType.UPGRADE: + _log.warn("Don't know how to handle upgrade packets"); + yield code, packet_id, path, data; + elif packet.type is PacketType.NOOP: + code = '8'; + elif packet.type is PacketType.MESSAGE: + if packet.payload.type is MessageType.CONNECT: + code = '1'; + elif packet.payload.type is MessageType.DISCONNECT: + code = '0'; + elif packet.payload.type is MessageType.EVENT: + code = '5'; + data = json.dumps({"name": packet.payload.message[0], "args": []}); + elif packet.payload.type is MessageType.ACK: + code = '6'; + elif packet.payload.type is MessageType.ERROR: + code = '7'; + else: + _log.warn("Don't know how to handle message type: %d" % packet.payload.type); + yield code, packet_id, path, data; + else: + _log.warn("Don't know how to handle packet type: %d" % packet.type); + yield code, packet_id, path, data; + yield code, packet_id, path, data def _enqueue_packet(self, packet): @@ -169,14 +203,16 @@ class _XHR_PollingTransport(_AbstractTransport): def __init__(self, socketIO_session, is_secure, base_url, **kw): super(_XHR_PollingTransport, self).__init__() - self._url = '%s://%s/?transport=polling&sid=%s' % ( + self._url = '%s://%s/?EIO=%d&transport=polling&sid=%s' % ( 'https' if is_secure else 'http', - base_url, socketIO_session.id) + base_url, parser.ENGINE_PROTOCOL, socketIO_session.id) self._connected = True self._http_session = _prepare_http_session(kw) + self._waiting = False; + # Create connection - for packet in self.recv_packet(): - self._enqueue_packet(packet) + #for packet in self.recv_packet(): + # self._enqueue_packet(packet) @property def connected(self): @@ -184,35 +220,54 @@ class _XHR_PollingTransport(_AbstractTransport): @property def _params(self): - return dict(t=int(time.time())) + return dict(t=int(time.time() * 1000)) def send(self, packet_text): - _get_response( - self._http_session.post, - self._url, - params=self._params, - data=packet_text, - timeout=TIMEOUT_IN_SECONDS) + uri = self._url + "&" + '&'.join("%s=%s" % (k, v) for (k, v) in self._params.iteritems()); + response = None; + try: + response = requests.post(uri, data=packet_text); + except requests.exceptions.Timeout as e: + message = 'timed out while sending %s (%s)' % (packet_text, e) + _log.warn(message) + raise TimeoutError(e) + except requests.exceptions.ConnectionError as e: + message = 'disconnected while sending %s (%s)' % (packet_text, e) + _log.warn(message) + raise ConnectionError(message) + except requests.exceptions.SSLError as e: + raise ConnectionError('could not negotiate SSL (%s)' % e) + status = response.status_code + if 200 != status: + raise ConnectionError('unexpected status code (%s)' % status) + return response def recv(self): + if self._waiting: + return; + + self._waiting = True; response = _get_response( self._http_session.get, self._url, - params=self._params, - timeout=TIMEOUT_IN_SECONDS) - #response_text = response.content - response_text = response.text - if not response_text.startswith(BOUNDARY): - yield response_text - return - for packet_text in _yield_text_from_framed_data(response_text): - yield packet_text + params = self._params, + timeout = self._timeout) + + self._waiting = False; + if response is None: + return; + + response_text = response; + #response_text = response.text + #if not response_text.startswith(BOUNDARY): + yield response_text + return + #for packet_text in _yield_text_from_framed_data(response_text): + # yield packet_text def close(self): - _get_response( - self._http_session.get, - self._url, - params=dict(self._params.items() + [('disconnect', True)])) + self.send_packet(41) + self.send_packet(1) self._connected = False @@ -310,8 +365,10 @@ def _yield_text_from_framed_data(framed_data, parse=lambda x: x): def _get_response(request, *args, **kw): + response = None; try: - response = request(*args, **kw) + response = request(*args, **kw); + response.close(); except requests.exceptions.Timeout as e: raise TimeoutError(e) except requests.exceptions.ConnectionError as e: