Added support for websockets via upgrade paradigm. Also added support for series of packets in responses rather than assuming single packets each time. Added support for all message fields in socket.io protocol

This commit is contained in:
Sean Arietta 2014-12-22 16:29:58 -08:00
commit 57971b5f71
3 changed files with 205 additions and 155 deletions

View file

@ -4,6 +4,7 @@ import logging
import json
import multiprocessing
import parser
from parser import Message, Packet, MessageType, PacketType
import requests
import time
@ -13,7 +14,7 @@ except ImportError:
from urlparse import urlparse as parse_url
from .exceptions import ConnectionError, TimeoutError, PacketError
from .transports import _get_response, _negotiate_transport, TRANSPORTS
from .transports import _get_response
_SocketIOSession = namedtuple('_SocketIOSession', [
@ -139,11 +140,10 @@ class SocketIO(object):
def __init__(
self, host, port=None, Namespace=BaseNamespace,
wait_for_connection=True, transports=TRANSPORTS, **kw):
wait_for_connection=True, **kw):
self.is_secure, self.base_url = _parse_host(host, port)
self.wait_for_connection = wait_for_connection
self._namespace_by_path = {}
self.client_supported_transports = transports
self.kw = kw
# These two fields work to control the heartbeat thread.
self.heartbeat_terminator = None;
@ -250,10 +250,15 @@ class SocketIO(object):
_log.warn('[packet error] %s', e)
def _process_packet(self, packet):
code, packet_id, path, data = packet
code, packet_id, path, data, p = packet
namespace = self.get_namespace(path)
delegate = self._get_delegate(code)
delegate(packet, namespace._find_event_callback)
delegate = None;
try:
delegate = self._get_delegate(code)
except:
pass;
if delegate is not None:
delegate(packet, namespace._find_event_callback)
def _stop_waiting(self, for_callbacks):
# Use __transport to make sure that we do not reconnect inadvertently
@ -300,14 +305,40 @@ class SocketIO(object):
_log.warn(warning)
return self.__transport
def _get_transport(self):
self.session = _get_socketIO_session(self.is_secure, self.base_url, **self.kw)
_log.debug('[transports available] %s', ' '.join(self.session.server_supported_transports))
def _upgrade(self):
websocket = transports.WebsocketTransport(self.session, self.is_secure, self.base_url, **self.kw);
websocket.send_packet(PacketType.PING, "", "probe");
for packet in websocket.recv_packet():
_log.debug("[websocket] Packet: %s" % str(packet));
(code, packet_id, path, data, p) = packet;
if code == PacketType.PONG:
packet = p;
_log.debug("[PONG] %s" % repr(packet));
# Negotiate transport
transport = _negotiate_transport(
self.client_supported_transports, self.session,
self.is_secure, self.base_url, **self.kw)
self.heartbeat_terminator.set();
# Technically we would need to pause the current
# transport (which should be polling in this
# implementation), but since we haven't actually
# started a polling yet, we can upgrade without that.
_log.debug("[upgrading] Sending upgrade request");
websocket.send_packet(PacketType.UPGRADE);
self._start_heartbeat(websocket);
return websocket;
def _start_heartbeat(self, transport):
_log.debug("[start heartbeat pacemaker]");
self.heartbeat_terminator = multiprocessing.Event();
self.heartbeat_thread = multiprocessing.Process(
target = _make_heartbeat_pacemaker,
args = (self.heartbeat_terminator, transport, self.session.heartbeat_interval / 2));
self.heartbeat_thread.start();
def _get_transport(self):
self.session = _get_socketIO_session(self.is_secure, self.base_url, **self.kw);
# Negotiate initial transport
transport = transports.XHR_PollingTransport(self.session, self.is_secure, self.base_url, **self.kw);
# Update namespaces
for path, namespace in self._namespace_by_path.items():
namespace._transport = transport
@ -316,12 +347,17 @@ class SocketIO(object):
transport.set_timeout(self.session.connection_timeout);
# Start the heartbeat pacemaker (PING).
_log.debug("[start heartbeat pacemaker]");
self.heartbeat_terminator = multiprocessing.Event();
self.heartbeat_thread = multiprocessing.Process(
target = _make_heartbeat_pacemaker,
args = (self.heartbeat_terminator, transport, self.session.heartbeat_interval / 2));
self.heartbeat_thread.start();
self._start_heartbeat(transport);
# If websocket is available, upgrade to it immediately.
# TODO(sean): We could run this on a separate thread for
# maximum efficiency although that would require some
# synchronization to ensure buffers are flushed, etc.
if "websocket" in self.session.server_supported_transports:
try:
return self._upgrade();
except:
pass;
return transport
@ -371,10 +407,19 @@ class SocketIO(object):
find_event_callback('message')(*args)
def _on_event(self, packet, find_event_callback):
code, packet_id, path, data = packet
value_by_name = json.loads(data)
event = value_by_name['name']
args = value_by_name.get('args', [])
code, packet_id, path, data, p = packet
packet = p;
# Accoding to the documentation
# (https://github.com/automattic/socket.io-protocol#event),
# the event name is the first entry in the message array, and
# the arguments are the rest of the entries.
event = packet.payload.message[0];
args = packet.payload.message[1:] if len(packet.payload.message) > 1 else [];
_log.debug("[event] %s (%s)" % (repr(event), repr(args)));
if packet_id:
args.append(self._prepare_to_send_ack(path, packet_id))
find_event_callback(event)(*args)
@ -455,12 +500,11 @@ def _get_socketIO_session(is_secure, base_url, **kw):
raise ConnectionError(e)
_log.debug("[response] %s", response.text);
packet = parser.decode_response(response);
_log.debug("[decoded] %s", repr(packet));
if packet.type is not parser.PacketType.OPEN:
_log.warn("Got unexpected packet during connection handshake: %d" % packet.type);
return None;
for packet in parser.decode_response(response):
_log.debug("[decoded] %s", str(packet));
if packet.type is not parser.PacketType.OPEN:
_log.warn("Got unexpected packet during connection handshake: %d" % packet.type);
return None;
handshake = json.loads(packet.payload);
@ -468,7 +512,7 @@ def _get_socketIO_session(is_secure, base_url, **kw):
id = handshake["sid"],
heartbeat_interval = int(handshake["pingInterval"]) / 1000,
connection_timeout = int(handshake["pingTimeout"]) / 1000,
server_supported_transports = ["xhr-polling"]);#decoded["payload"]["upgrades"]);
server_supported_transports = handshake["upgrades"]);
def _make_heartbeat_pacemaker(terminator, transport, heartbeat_interval):
while True:

View file

@ -29,8 +29,11 @@ class Packet():
self.type = packet_type;
self.payload = payload;
def __str__(self):
return "PACKET{type: " + str(self.type) + ", payload: " + str(self.payload) + "}";
class Message():
def __init__(self, message_type, message, path = ""):
def __init__(self, message_type, message, path = "", attachments = "", message_id = None):
self.type = message_type;
if isinstance(message, basestring):
try:
@ -41,6 +44,23 @@ class Message():
self.message = message;
self.path = path;
self.attachments = attachments;
self.id = message_id;
def __str__(self):
if self.id is not None:
return "MESSAGE{" + \
"id: " + str(self.id) + ", " + \
"type: " + str(self.type) + ", " + \
"message: " + str(self.message) + ", " + \
"path: " + self.path + \
"}";
else:
return "MESSAGE{" + \
"type: " + str(self.type) + ", " + \
"message: " + str(self.message) + ", " + \
"path: " + self.path + \
"}";
def encode_as_json(self):
"""Encodes a Message to be sent to socket.io server.
@ -61,17 +81,78 @@ def decode_message(payload):
""" Decodes a message encoded via socket.io
"""
message_type = int(payload[0]);
message = payload[1:];
_log.debug("[decode payload] %s" % repr(payload));
return Message(message_type, message);
i = 0;
message_type = int(payload[i]);
message = "";
path = "";
attachments = "";
message_id = None;
i += 1;
if len(payload) > i:
if message_type == MessageType.BINARY_EVENT or message_type == MessageType.BINARY_ACK:
while (payload[i] != "-"):
attachments += payload[i];
i += 1;
if len(payload) > i:
# This is kind of odd, but it is how socket.io-parser works (see
# https://github.com/Automattic/socket.io-parser/blob/master/index.js#L292
# @0ae9a4f).
if payload[i] == "/":
if "," in payload:
split_point = payload.index(",");
path = payload[i:split_point];
i += split_point;
else:
path = payload[i:];
i += len(path);
if len(payload) > i:
# This is the same pecularity as above.
if "," in payload[i:]:
split_point = payload.index(",");
message_id = int(payload[i:split_point]);
i += split_point;
if len(payload) > i:
message = payload[i:];
return Message(message_type, message, path, attachments, message_id);
def decode_response(response):
"""Decodes a response from requests lib.
"""
# TODO(sean): Should we use the 'raw' stream instead?
return decode_packet(response.content);
if isinstance(response, basestring):
_log.debug("[decode response (string)] Response: %s" % str(response));
packet = decode_packet_string(response);
yield packet;
else:
content = response.content;
total_length = len(content);
processed = 0;
while processed < total_length:
_log.debug("[decode response] Content: %s" % str(content));
(read, packet) = decode_packet(content);
content = content[read:];
processed += read;
yield packet;
def decode_packet_string(packet):
packet_type = int(packet[0]);
payload = packet[1:];
if packet_type == PacketType.MESSAGE:
message = decode_message(payload);
payload = message;
return Packet(packet_type, payload);
def decode_packet(packet):
"""Decodes a packet sent via engine.io.
@ -108,15 +189,11 @@ def decode_packet(packet):
if packet_type is PacketType.MESSAGE:
message = decode_message(payload);
payload = message;
return Packet(packet_type, payload);
return offset + length, Packet(packet_type, payload);
else:
pass;
return "";
pass;
def encode_packet_string(code, path, data):
"""Encodes packet to be sent to socket.io server.
"""

View file

@ -45,6 +45,19 @@ class _AbstractTransport(object):
_log.debug("Connecting to path: %s" % path);
data = Message(MessageType.CONNECT, path).encode_as_string();
self.send_packet(PacketType.MESSAGE, path, data);
# Wait for response.
responded = False;
while not responded:
for packet in self.recv_packet():
_log.debug("[connect wait] Waiting for confirmation");
(code, packet_id, ignore, data, p) = packet;
packet = p;
if (packet.type == PacketType.MESSAGE
and packet.payload.type == MessageType.CONNECT
and packet.payload.path == path):
_log.debug("Connected to path: %s" % path);
responded = True;
else:
self.send_packet(PacketType.OPEN, path, data);
@ -85,27 +98,20 @@ class _AbstractTransport(object):
yield self._packets.pop(0)
except IndexError:
pass
for response in self.recv():
_log.debug('[packet received] %s', response.text);
try:
packet = parser.decode_response(response);
except AttributeError:
_log.warn('[packet error] %s', response.text)
continue
for packet in self.recv():
code, packet_id, path, data = None, None, '', None
if packet.type is PacketType.OPEN:
code = '1';
continue;
elif packet.type is PacketType.CLOSE:
code = '0';
elif packet.type is PacketType.PING:
code = '2';
elif packet.type is PacketType.PONG:
code = '2';
code = PacketType.PONG;
elif packet.type is PacketType.UPGRADE:
_log.warn("Don't know how to handle upgrade packets");
yield code, packet_id, path, data;
yield code, packet_id, path, data, packet;
elif packet.type is PacketType.NOOP:
code = '8';
elif packet.type is PacketType.MESSAGE:
@ -122,12 +128,12 @@ class _AbstractTransport(object):
code = '7';
else:
_log.warn("Don't know how to handle message type: %d" % packet.payload.type);
yield code, packet_id, path, data;
yield code, packet_id, path, data, packet;
else:
_log.warn("Don't know how to handle packet type: %d" % packet.type);
yield code, packet_id, path, data;
yield code, packet_id, path, data, packet;
yield code, packet_id, path, data
yield code, packet_id, path, data, packet
def _enqueue_packet(self, packet):
self._packets.append(packet)
@ -149,15 +155,17 @@ class _AbstractTransport(object):
return True if self._callback_by_packet_id else False
class _WebsocketTransport(_AbstractTransport):
class WebsocketTransport(_AbstractTransport):
def __init__(self, socketIO_session, is_secure, base_url, **kw):
super(_WebsocketTransport, self).__init__()
url = '%s://%s/websocket/%s' % (
super(WebsocketTransport, self).__init__()
self._url = '%s://%s/?EIO=%d&transport=websocket&sid=%s' % (
'wss' if is_secure else 'ws',
base_url, socketIO_session.id)
base_url, parser.ENGINE_PROTOCOL, socketIO_session.id)
try:
self._connection = websocket.create_connection(url)
self._connection = websocket.create_connection(self._url)
except socket.timeout as e:
raise ConnectionError(e)
except socket.error as e:
@ -168,6 +176,11 @@ class _WebsocketTransport(_AbstractTransport):
def connected(self):
return self._connection.connected
def send_packet(self, code, path="", data='', callback=None):
packet_text = Message(code, data).encode_as_string();
self.send(packet_text)
_log.debug('[packet sent] %s', packet_text)
def send(self, packet_text):
try:
self._connection.send(packet_text)
@ -182,7 +195,14 @@ class _WebsocketTransport(_AbstractTransport):
def recv(self):
try:
yield self._connection.recv()
response = self._connection.recv();
try:
for packet in parser.decode_response(response):
_log.debug('[websocket packet received] %s', str(packet));
yield packet;
except AttributeError:
_log.warn('[packet error] %s', repr(response))
return;
except websocket.WebSocketTimeoutException as e:
raise TimeoutError(e)
except websocket.SSLError as e:
@ -199,10 +219,10 @@ class _WebsocketTransport(_AbstractTransport):
self._connection.close()
class _XHR_PollingTransport(_AbstractTransport):
class XHR_PollingTransport(_AbstractTransport):
def __init__(self, socketIO_session, is_secure, base_url, **kw):
super(_XHR_PollingTransport, self).__init__()
super(XHR_PollingTransport, self).__init__()
self._url = '%s://%s/?EIO=%d&transport=polling&sid=%s' % (
'https' if is_secure else 'http',
base_url, parser.ENGINE_PROTOCOL, socketIO_session.id)
@ -210,10 +230,6 @@ class _XHR_PollingTransport(_AbstractTransport):
self._http_session = _prepare_http_session(kw)
self._waiting = False;
# Create connection
#for packet in self.recv_packet():
# self._enqueue_packet(packet)
@property
def connected(self):
return self._connected
@ -257,102 +273,15 @@ class _XHR_PollingTransport(_AbstractTransport):
if response is None:
return;
response_text = response;
#response_text = response.text
#if not response_text.startswith(BOUNDARY):
yield response_text
for packet in parser.decode_response(response):
yield packet;
return
#for packet_text in _yield_text_from_framed_data(response_text):
# yield packet_text
def close(self):
self.send_packet(41)
self.send_packet(1)
self._connected = False
class _JSONP_PollingTransport(_AbstractTransport):
RESPONSE_PATTERN = re.compile(r'io.j\[(\d+)\]\("(.*)"\);')
def __init__(self, socketIO_session, is_secure, base_url, **kw):
super(_JSONP_PollingTransport, self).__init__()
self._url = '%s://%s/jsonp-polling/%s' % (
'https' if is_secure else 'http',
base_url, socketIO_session.id)
self._connected = True
self._http_session = _prepare_http_session(kw)
self._id = 0
# Create connection
for packet in self.recv_packet():
self._enqueue_packet(packet)
@property
def connected(self):
return self._connected
@property
def _params(self):
return dict(t=int(time.time()), i=self._id)
def send(self, packet_text):
_get_response(
self._http_session.post,
self._url,
params=self._params,
data='d=%s' % requests.utils.quote(json.dumps(packet_text)),
headers={'content-type': 'application/x-www-form-urlencoded'},
timeout=TIMEOUT_IN_SECONDS)
def recv(self):
'Decode the JavaScript response so that we can parse it as JSON'
response = _get_response(
self._http_session.get,
self._url,
params=self._params,
headers={'content-type': 'text/javascript; charset=UTF-8'},
timeout=TIMEOUT_IN_SECONDS)
response_text = response.text
try:
self._id, response_text = self.RESPONSE_PATTERN.match(
response_text).groups()
except AttributeError:
_log.warn('[packet error] %s', response_text)
return
if not response_text.startswith(BOUNDARY):
yield response_text.decode('unicode_escape')
return
for packet_text in _yield_text_from_framed_data(
response_text, parse=lambda x: x.decode('unicode_escape')):
yield packet_text
def close(self):
_get_response(
self._http_session.get,
self._url,
params=dict(self._params.items() + [('disconnect', True)]))
self._connected = False
def _negotiate_transport(
client_supported_transports, session,
is_secure, base_url, **kw):
server_supported_transports = session.server_supported_transports
for supported_transport in client_supported_transports:
if supported_transport in server_supported_transports:
_log.debug('[transport selected] %s', supported_transport)
return {
'websocket': _WebsocketTransport,
'xhr-polling': _XHR_PollingTransport,
'jsonp-polling': _JSONP_PollingTransport,
}[supported_transport](session, is_secure, base_url, **kw)
raise SocketIOError(' '.join([
'could not negotiate a transport:',
'client supports %s but' % ', '.join(client_supported_transports),
'server supports %s' % ', '.join(server_supported_transports),
]))
def _yield_text_from_framed_data(framed_data, parse=lambda x: x):
parts = [parse(x) for x in framed_data.split(BOUNDARY)]
for text_length, text in zip(parts[1::2], parts[2::2]):