Add locks to fix concurrency issues

This commit is contained in:
Roy Hyunjin Han 2015-04-15 14:38:25 -04:00
commit a6f9260964
5 changed files with 128 additions and 95 deletions

View file

@ -40,10 +40,12 @@ class EngineIO(LoggingMixin):
self._wait_for_connection = wait_for_connection
self._client_transports = transports
self._hurry_interval_in_seconds = hurry_interval_in_seconds
self._kw = kw
self._http_session = prepare_http_session(kw)
self._log_name = self._url
self._wants_to_close = False
self._opened = False
if Namespace:
self.define(Namespace)
self._transport
@ -63,7 +65,6 @@ class EngineIO(LoggingMixin):
def _get_engineIO_session(self):
warning_screen = self._yield_warning_screen()
self._http_session = prepare_http_session(self._kw)
for elapsed_time in warning_screen:
transport = XHR_PollingTransport(
self._http_session, self._is_secure, self._url)
@ -76,7 +77,7 @@ class EngineIO(LoggingMixin):
raise
warning = Exception('[waiting for connection] %s' % e)
warning_screen.throw(warning)
assert engineIO_packet_type == 0
assert engineIO_packet_type == 0 # engineIO_packet_type == open
return parse_engineIO_session(engineIO_packet_data)
def _negotiate_transport(self):
@ -93,6 +94,8 @@ class EngineIO(LoggingMixin):
transport.send_packet(5, '')
self._transport_instance = transport
self.transport_name = 'websocket'
else:
self._warn('unexpected engine.io packet')
except Exception:
pass
self._debug('[transport selected] %s', self.transport_name)
@ -100,8 +103,9 @@ class EngineIO(LoggingMixin):
def _reset_heartbeat(self):
try:
self._heartbeat_thread.halt()
hurried = self._heartbeat_thread.hurried
except AttributeError:
pass
hurried = False
ping_interval = self._engineIO_session.ping_interval
if self.transport_name.endswith('-polling'):
# Use ping/pong to unblock recv for polling transport
@ -114,6 +118,8 @@ class EngineIO(LoggingMixin):
relax_interval_in_seconds=ping_interval,
hurry_interval_in_seconds=hurry_interval_in_seconds)
self._heartbeat_thread.start()
if hurried:
self._heartbeat_thread.hurry()
self._debug('[heartbeat reset]')
def _connect_namespaces(self):
@ -161,7 +167,6 @@ class EngineIO(LoggingMixin):
def send(self, engineIO_packet_data):
self._message(engineIO_packet_data)
@retry
def _open(self):
engineIO_packet_type = 0
self._transport_instance.send_packet(engineIO_packet_type)
@ -172,34 +177,36 @@ class EngineIO(LoggingMixin):
if not self._opened:
return
engineIO_packet_type = 1
self._transport_instance.send_packet(engineIO_packet_type)
try:
self._transport_instance.send_packet(engineIO_packet_type)
except (TimeoutError, ConnectionError):
pass
self._opened = False
@retry
def _ping(self, engineIO_packet_data=''):
engineIO_packet_type = 2
self._transport_instance.send_packet(
engineIO_packet_type, engineIO_packet_data)
@retry
def _pong(self, engineIO_packet_data=''):
engineIO_packet_type = 3
self._transport_instance.send_packet(
engineIO_packet_type, engineIO_packet_data)
@retry
def _message(self, engineIO_packet_data):
def _message(self, engineIO_packet_data, with_transport_instance=False):
engineIO_packet_type = 4
self._transport_instance.send_packet(
engineIO_packet_type, engineIO_packet_data)
if with_transport_instance:
transport = self._transport_instance
else:
transport = self._transport
transport.send_packet(engineIO_packet_type, engineIO_packet_data)
self._debug('[socket.io packet sent] %s', engineIO_packet_data)
@retry
def _upgrade(self):
engineIO_packet_type = 5
self._transport_instance.send_packet(engineIO_packet_type)
@retry
def _noop(self):
engineIO_packet_type = 6
self._transport_instance.send_packet(engineIO_packet_type)
@ -223,6 +230,7 @@ class EngineIO(LoggingMixin):
except TimeoutError:
pass
except ConnectionError as e:
self._opened = False
try:
warning = Exception('[connection error] %s' % e)
warning_screen.throw(warning)
@ -263,31 +271,31 @@ class EngineIO(LoggingMixin):
except KeyError:
raise PacketError(
'unexpected engine.io packet type (%s)' % engineIO_packet_type)
delegate(engineIO_packet_data, namespace._find_packet_callback)
delegate(engineIO_packet_data, namespace)
if engineIO_packet_type is 4:
return engineIO_packet_data
def _on_open(self, data, find_packet_callback):
find_packet_callback('open')()
def _on_open(self, data, namespace):
namespace._find_packet_callback('open')()
def _on_close(self, data, find_packet_callback):
find_packet_callback('close')()
def _on_close(self, data, namespace):
namespace._find_packet_callback('close')()
def _on_ping(self, data, find_packet_callback):
def _on_ping(self, data, namespace):
self._pong(data)
find_packet_callback('ping')(data)
namespace._find_packet_callback('ping')(data)
def _on_pong(self, data, find_packet_callback):
find_packet_callback('pong')(data)
def _on_pong(self, data, namespace):
namespace._find_packet_callback('pong')(data)
def _on_message(self, data, find_packet_callback):
find_packet_callback('message')(data)
def _on_message(self, data, namespace):
namespace._find_packet_callback('message')(data)
def _on_upgrade(self, data, find_packet_callback):
find_packet_callback('upgrade')()
def _on_upgrade(self, data, namespace):
namespace._find_packet_callback('upgrade')()
def _on_noop(self, data, find_packet_callback):
find_packet_callback('noop')()
def _on_noop(self, data, namespace):
namespace._find_packet_callback('noop')()
class SocketIO(EngineIO):
@ -329,7 +337,7 @@ class SocketIO(EngineIO):
for path, namespace in self._namespace_by_path.items():
namespace._transport = self._transport_instance
if path:
self.connect(path)
self.connect(path, with_transport_instance=True)
def __exit__(self, *exception_pack):
self.disconnect()
@ -342,9 +350,10 @@ class SocketIO(EngineIO):
# Define
def define(self, Namespace, path=''):
self._namespace_by_path[path] = namespace = Namespace(self, path)
if path:
self.connect(path)
self._namespace_by_path[path] = namespace = Namespace(self, path)
self.wait(for_connect=True)
return namespace
def on(self, event, callback, path=''):
@ -362,20 +371,23 @@ class SocketIO(EngineIO):
# Act
def connect(self, path):
def connect(self, path, with_transport_instance=False):
socketIO_packet_type = 0
socketIO_packet_data = format_socketIO_packet_data(path)
self._message(str(socketIO_packet_type) + socketIO_packet_data)
self._message(
str(socketIO_packet_type) + socketIO_packet_data,
with_transport_instance)
def disconnect(self, path=''):
if not self._opened:
return
if path:
if not path or not self._opened:
self._close()
elif path:
socketIO_packet_type = 1
socketIO_packet_data = format_socketIO_packet_data(path)
self._message(str(socketIO_packet_type) + socketIO_packet_data)
else:
self._close()
try:
self._message(str(socketIO_packet_type) + socketIO_packet_data)
except (TimeoutError, ConnectionError):
pass
try:
namespace = self._namespace_by_path.pop(path)
namespace.on_disconnect()
@ -405,13 +417,17 @@ class SocketIO(EngineIO):
# React
def wait(self, seconds=None, for_callbacks=False):
super(SocketIO, self).wait(seconds, for_callbacks=for_callbacks)
def wait_for_callbacks(self, seconds=None):
self.wait(seconds, for_callbacks=True)
def _should_stop_waiting(self, for_callbacks):
def _should_stop_waiting(self, for_connect=False, for_callbacks=False):
if for_connect:
for namespace in self._namespace_by_path.values():
is_namespace_connected = getattr(
namespace, '_connected', False)
if not is_namespace_connected:
return False
return True
if for_callbacks and not self._has_ack_callback:
return True
return super(SocketIO, self)._should_stop_waiting()
@ -439,16 +455,18 @@ class SocketIO(EngineIO):
except KeyError:
raise PacketError(
'unexpected socket.io packet type (%s)' % socketIO_packet_type)
delegate(socketIO_packet_data, namespace._find_packet_callback)
delegate(socketIO_packet_data, namespace)
return socketIO_packet_data
def _on_connect(self, data, find_packet_callback):
find_packet_callback('connect')()
def _on_connect(self, data, namespace):
namespace._connected = True
namespace._find_packet_callback('connect')()
def _on_disconnect(self, data, find_packet_callback):
find_packet_callback('disconnect')()
def _on_disconnect(self, data, namespace):
namespace._connected = False
namespace._find_packet_callback('disconnect')()
def _on_event(self, data, find_packet_callback):
def _on_event(self, data, namespace):
data_parsed = parse_socketIO_packet_data(data)
args = data_parsed.args
try:
@ -458,9 +476,9 @@ class SocketIO(EngineIO):
if data_parsed.ack_id is not None:
args.append(self._prepare_to_send_ack(
data_parsed.path, data_parsed.ack_id))
find_packet_callback(event)(*args)
namespace._find_packet_callback(event)(*args)
def _on_ack(self, data, find_packet_callback):
def _on_ack(self, data, namespace):
data_parsed = parse_socketIO_packet_data(data)
try:
ack_callback = self._get_ack_callback(data_parsed.ack_id)
@ -468,13 +486,13 @@ class SocketIO(EngineIO):
return
ack_callback(*data_parsed.args)
def _on_error(self, data, find_packet_callback):
find_packet_callback('error')(data)
def _on_error(self, data, namespace):
namespace._find_packet_callback('error')(data)
def _on_binary_event(self, data, find_packet_callback):
def _on_binary_event(self, data, namespace):
self._warn('[not implemented] binary event')
def _on_binary_ack(self, data, find_packet_callback):
def _on_binary_ack(self, data, namespace):
self._warn('[not implemented] binary ack')
def _prepare_to_send_ack(self, path, ack_id):

View file

@ -1,3 +1,4 @@
import logging
from threading import Thread, Event
from .exceptions import ConnectionError, TimeoutError
@ -22,17 +23,17 @@ class HeartbeatThread(Thread):
def run(self):
try:
while not self._halt.is_set():
try:
self._send_heartbeat()
except TimeoutError:
pass
if self._adrenaline.is_set():
interval_in_seconds = self._hurry_interval_in_seconds
else:
interval_in_seconds = self._relax_interval_in_seconds
self._rest.wait(interval_in_seconds)
try:
self._send_heartbeat()
except TimeoutError:
pass
except ConnectionError:
pass
logging.debug('[heartbeat connection error]')
def relax(self):
self._adrenaline.clear()
@ -42,6 +43,10 @@ class HeartbeatThread(Thread):
self._rest.set()
self._rest.clear()
@property
def hurried(self):
return self._adrenaline.is_set()
def halt(self):
self._rest.set()
self._halt.set()

View file

@ -141,31 +141,31 @@ class SocketIONamespace(EngineIONamespace):
class LoggingEngineIONamespace(EngineIONamespace):
def on_open(self):
self._debug('[open]')
self._debug('[engine.io open]')
super(LoggingEngineIONamespace, self).on_open()
def on_close(self):
self._debug('[close]')
self._debug('[engine.io close]')
super(LoggingEngineIONamespace, self).on_close()
def on_ping(self, data):
self._debug('[ping] %s', data)
self._debug('[engine.io ping] %s', data)
super(LoggingEngineIONamespace, self).on_ping(data)
def on_pong(self, data):
self._debug('[pong] %s', data)
self._debug('[engine.io pong] %s', data)
super(LoggingEngineIONamespace, self).on_pong(data)
def on_message(self, data):
self._debug('[message] %s', data)
self._debug('[engine.io message] %s', data)
super(LoggingEngineIONamespace, self).on_message(data)
def on_upgrade(self):
self._debug('[upgrade]')
self._debug('[engine.io upgrade]')
super(LoggingEngineIONamespace, self).on_upgrade()
def on_noop(self):
self._debug('[noop]')
self._debug('[engine.io noop]')
super(LoggingEngineIONamespace, self).on_noop()
def on_event(self, event, *args):
@ -173,22 +173,25 @@ class LoggingEngineIONamespace(EngineIONamespace):
arguments = [repr(_) for _ in args]
if callback:
arguments.append('callback(*args)')
self._info('[event] %s(%s)', event, ', '.join(arguments))
self._info('[engine.io event] %s(%s)', event, ', '.join(arguments))
super(LoggingEngineIONamespace, self).on_event(event, *args)
class LoggingSocketIONamespace(SocketIONamespace):
class LoggingSocketIONamespace(SocketIONamespace, LoggingEngineIONamespace):
def on_connect(self):
self._debug('%s[connect]', _make_logging_header(self.path))
self._debug(
'%s[socket.io connect]', _make_logging_header(self.path))
super(LoggingSocketIONamespace, self).on_connect()
def on_reconnect(self):
self._debug('%s[reconnect]', _make_logging_header(self.path))
self._debug(
'%s[socket.io reconnect]', _make_logging_header(self.path))
super(LoggingSocketIONamespace, self).on_reconnect()
def on_disconnect(self):
self._debug('%s[disconnect]', _make_logging_header(self.path))
self._debug(
'%s[socket.io disconnect]', _make_logging_header(self.path))
super(LoggingSocketIONamespace, self).on_disconnect()
def on_event(self, event, *args):
@ -197,12 +200,13 @@ class LoggingSocketIONamespace(SocketIONamespace):
if callback:
arguments.append('callback(*args)')
self._info(
'%s[event] %s(%s)', _make_logging_header(self.path), event,
', '.join(arguments))
'%s[socket.io event] %s(%s)', _make_logging_header(self.path),
event, ', '.join(arguments))
super(LoggingSocketIONamespace, self).on_event(event, *args)
def on_error(self, data):
self._debug('%s[error] %s', _make_logging_header(self.path), data)
self._debug(
'%s[socket.io error] %s', _make_logging_header(self.path), data)
super(LoggingSocketIONamespace, self).on_error()

View file

@ -17,10 +17,11 @@ class BaseMixin(object):
def setUp(self):
super(BaseMixin, self).setUp()
self.called_on_response = False
self.wait_time_in_seconds = 1
def tearDown(self):
super(BaseMixin, self).tearDown()
del self.socketIO
self.socketIO.disconnect()
def test_disconnect(self):
'Disconnect'
@ -161,7 +162,6 @@ class Test_XHR_PollingTransport(BaseMixin, TestCase):
self.socketIO = SocketIO(HOST, PORT, LoggingNamespace, transports=[
'xhr-polling'], verify=False)
self.assertEqual(self.socketIO.transport_name, 'xhr-polling')
self.wait_time_in_seconds = 1
class Test_WebsocketTransport(BaseMixin, TestCase):
@ -171,7 +171,6 @@ class Test_WebsocketTransport(BaseMixin, TestCase):
self.socketIO = SocketIO(HOST, PORT, LoggingNamespace, transports=[
'xhr-polling', 'websocket'], verify=False)
self.assertEqual(self.socketIO.transport_name, 'websocket')
self.wait_time_in_seconds = 1
class Namespace(LoggingNamespace):

View file

@ -3,6 +3,7 @@ import six
import socket
import ssl
import sys
import threading
import time
import websocket
@ -54,10 +55,11 @@ class XHR_PollingTransport(AbstractTransport):
'EIO': ENGINEIO_PROTOCOL, 'transport': 'polling'}
if engineIO_session:
self._request_index = 1
self._kw_get = dict(timeout=engineIO_session.ping_timeout)
self._kw_post = dict(headers={
'content-type': 'application/octet-stream',
})
self._kw_get = dict(
timeout=engineIO_session.ping_timeout)
self._kw_post = dict(
timeout=engineIO_session.ping_timeout,
headers={'content-type': 'application/octet-stream'})
self._params['sid'] = engineIO_session.id
else:
self._request_index = 0
@ -65,6 +67,8 @@ class XHR_PollingTransport(AbstractTransport):
self._kw_post = {}
http_scheme = 'https' if is_secure else 'http'
self._http_url = '%s://%s/' % (http_scheme, url)
self._request_index_lock = threading.Lock()
self._send_packet_lock = threading.Lock()
def recv_packet(self):
params = dict(self._params)
@ -79,21 +83,24 @@ class XHR_PollingTransport(AbstractTransport):
yield engineIO_packet_type, engineIO_packet_data
def send_packet(self, engineIO_packet_type, engineIO_packet_data=''):
params = dict(self._params)
params['t'] = self._get_timestamp()
response = get_response(
self.http_session.post,
self._http_url,
params=params,
data=encode_engineIO_content([
(engineIO_packet_type, engineIO_packet_data),
]),
**self._kw_post)
assert response.content == b'ok'
with self._send_packet_lock:
params = dict(self._params)
params['t'] = self._get_timestamp()
response = get_response(
self.http_session.post,
self._http_url,
params=params,
data=encode_engineIO_content([
(engineIO_packet_type, engineIO_packet_data),
]),
**self._kw_post)
assert response.content == b'ok'
def _get_timestamp(self):
timestamp = '%s-%s' % (int(time.time() * 1000), self._request_index)
self._request_index += 1
with self._request_index_lock:
timestamp = '%s-%s' % (
int(time.time() * 1000), self._request_index)
self._request_index += 1
return timestamp
@ -164,7 +171,7 @@ class WebsocketTransport(AbstractTransport):
def get_response(request, *args, **kw):
try:
response = request(*args, **kw)
response = request(*args, stream=True, **kw)
except requests.exceptions.Timeout as e:
raise TimeoutError(e)
except requests.exceptions.ConnectionError as e: