Add locks to fix concurrency issues
This commit is contained in:
parent
f5b157014d
commit
a6f9260964
5 changed files with 128 additions and 95 deletions
|
|
@ -40,10 +40,12 @@ class EngineIO(LoggingMixin):
|
|||
self._wait_for_connection = wait_for_connection
|
||||
self._client_transports = transports
|
||||
self._hurry_interval_in_seconds = hurry_interval_in_seconds
|
||||
self._kw = kw
|
||||
self._http_session = prepare_http_session(kw)
|
||||
|
||||
self._log_name = self._url
|
||||
self._wants_to_close = False
|
||||
self._opened = False
|
||||
|
||||
if Namespace:
|
||||
self.define(Namespace)
|
||||
self._transport
|
||||
|
|
@ -63,7 +65,6 @@ class EngineIO(LoggingMixin):
|
|||
|
||||
def _get_engineIO_session(self):
|
||||
warning_screen = self._yield_warning_screen()
|
||||
self._http_session = prepare_http_session(self._kw)
|
||||
for elapsed_time in warning_screen:
|
||||
transport = XHR_PollingTransport(
|
||||
self._http_session, self._is_secure, self._url)
|
||||
|
|
@ -76,7 +77,7 @@ class EngineIO(LoggingMixin):
|
|||
raise
|
||||
warning = Exception('[waiting for connection] %s' % e)
|
||||
warning_screen.throw(warning)
|
||||
assert engineIO_packet_type == 0
|
||||
assert engineIO_packet_type == 0 # engineIO_packet_type == open
|
||||
return parse_engineIO_session(engineIO_packet_data)
|
||||
|
||||
def _negotiate_transport(self):
|
||||
|
|
@ -93,6 +94,8 @@ class EngineIO(LoggingMixin):
|
|||
transport.send_packet(5, '')
|
||||
self._transport_instance = transport
|
||||
self.transport_name = 'websocket'
|
||||
else:
|
||||
self._warn('unexpected engine.io packet')
|
||||
except Exception:
|
||||
pass
|
||||
self._debug('[transport selected] %s', self.transport_name)
|
||||
|
|
@ -100,8 +103,9 @@ class EngineIO(LoggingMixin):
|
|||
def _reset_heartbeat(self):
|
||||
try:
|
||||
self._heartbeat_thread.halt()
|
||||
hurried = self._heartbeat_thread.hurried
|
||||
except AttributeError:
|
||||
pass
|
||||
hurried = False
|
||||
ping_interval = self._engineIO_session.ping_interval
|
||||
if self.transport_name.endswith('-polling'):
|
||||
# Use ping/pong to unblock recv for polling transport
|
||||
|
|
@ -114,6 +118,8 @@ class EngineIO(LoggingMixin):
|
|||
relax_interval_in_seconds=ping_interval,
|
||||
hurry_interval_in_seconds=hurry_interval_in_seconds)
|
||||
self._heartbeat_thread.start()
|
||||
if hurried:
|
||||
self._heartbeat_thread.hurry()
|
||||
self._debug('[heartbeat reset]')
|
||||
|
||||
def _connect_namespaces(self):
|
||||
|
|
@ -161,7 +167,6 @@ class EngineIO(LoggingMixin):
|
|||
def send(self, engineIO_packet_data):
|
||||
self._message(engineIO_packet_data)
|
||||
|
||||
@retry
|
||||
def _open(self):
|
||||
engineIO_packet_type = 0
|
||||
self._transport_instance.send_packet(engineIO_packet_type)
|
||||
|
|
@ -172,34 +177,36 @@ class EngineIO(LoggingMixin):
|
|||
if not self._opened:
|
||||
return
|
||||
engineIO_packet_type = 1
|
||||
self._transport_instance.send_packet(engineIO_packet_type)
|
||||
try:
|
||||
self._transport_instance.send_packet(engineIO_packet_type)
|
||||
except (TimeoutError, ConnectionError):
|
||||
pass
|
||||
self._opened = False
|
||||
|
||||
@retry
|
||||
def _ping(self, engineIO_packet_data=''):
|
||||
engineIO_packet_type = 2
|
||||
self._transport_instance.send_packet(
|
||||
engineIO_packet_type, engineIO_packet_data)
|
||||
|
||||
@retry
|
||||
def _pong(self, engineIO_packet_data=''):
|
||||
engineIO_packet_type = 3
|
||||
self._transport_instance.send_packet(
|
||||
engineIO_packet_type, engineIO_packet_data)
|
||||
|
||||
@retry
|
||||
def _message(self, engineIO_packet_data):
|
||||
def _message(self, engineIO_packet_data, with_transport_instance=False):
|
||||
engineIO_packet_type = 4
|
||||
self._transport_instance.send_packet(
|
||||
engineIO_packet_type, engineIO_packet_data)
|
||||
if with_transport_instance:
|
||||
transport = self._transport_instance
|
||||
else:
|
||||
transport = self._transport
|
||||
transport.send_packet(engineIO_packet_type, engineIO_packet_data)
|
||||
self._debug('[socket.io packet sent] %s', engineIO_packet_data)
|
||||
|
||||
@retry
|
||||
def _upgrade(self):
|
||||
engineIO_packet_type = 5
|
||||
self._transport_instance.send_packet(engineIO_packet_type)
|
||||
|
||||
@retry
|
||||
def _noop(self):
|
||||
engineIO_packet_type = 6
|
||||
self._transport_instance.send_packet(engineIO_packet_type)
|
||||
|
|
@ -223,6 +230,7 @@ class EngineIO(LoggingMixin):
|
|||
except TimeoutError:
|
||||
pass
|
||||
except ConnectionError as e:
|
||||
self._opened = False
|
||||
try:
|
||||
warning = Exception('[connection error] %s' % e)
|
||||
warning_screen.throw(warning)
|
||||
|
|
@ -263,31 +271,31 @@ class EngineIO(LoggingMixin):
|
|||
except KeyError:
|
||||
raise PacketError(
|
||||
'unexpected engine.io packet type (%s)' % engineIO_packet_type)
|
||||
delegate(engineIO_packet_data, namespace._find_packet_callback)
|
||||
delegate(engineIO_packet_data, namespace)
|
||||
if engineIO_packet_type is 4:
|
||||
return engineIO_packet_data
|
||||
|
||||
def _on_open(self, data, find_packet_callback):
|
||||
find_packet_callback('open')()
|
||||
def _on_open(self, data, namespace):
|
||||
namespace._find_packet_callback('open')()
|
||||
|
||||
def _on_close(self, data, find_packet_callback):
|
||||
find_packet_callback('close')()
|
||||
def _on_close(self, data, namespace):
|
||||
namespace._find_packet_callback('close')()
|
||||
|
||||
def _on_ping(self, data, find_packet_callback):
|
||||
def _on_ping(self, data, namespace):
|
||||
self._pong(data)
|
||||
find_packet_callback('ping')(data)
|
||||
namespace._find_packet_callback('ping')(data)
|
||||
|
||||
def _on_pong(self, data, find_packet_callback):
|
||||
find_packet_callback('pong')(data)
|
||||
def _on_pong(self, data, namespace):
|
||||
namespace._find_packet_callback('pong')(data)
|
||||
|
||||
def _on_message(self, data, find_packet_callback):
|
||||
find_packet_callback('message')(data)
|
||||
def _on_message(self, data, namespace):
|
||||
namespace._find_packet_callback('message')(data)
|
||||
|
||||
def _on_upgrade(self, data, find_packet_callback):
|
||||
find_packet_callback('upgrade')()
|
||||
def _on_upgrade(self, data, namespace):
|
||||
namespace._find_packet_callback('upgrade')()
|
||||
|
||||
def _on_noop(self, data, find_packet_callback):
|
||||
find_packet_callback('noop')()
|
||||
def _on_noop(self, data, namespace):
|
||||
namespace._find_packet_callback('noop')()
|
||||
|
||||
|
||||
class SocketIO(EngineIO):
|
||||
|
|
@ -329,7 +337,7 @@ class SocketIO(EngineIO):
|
|||
for path, namespace in self._namespace_by_path.items():
|
||||
namespace._transport = self._transport_instance
|
||||
if path:
|
||||
self.connect(path)
|
||||
self.connect(path, with_transport_instance=True)
|
||||
|
||||
def __exit__(self, *exception_pack):
|
||||
self.disconnect()
|
||||
|
|
@ -342,9 +350,10 @@ class SocketIO(EngineIO):
|
|||
# Define
|
||||
|
||||
def define(self, Namespace, path=''):
|
||||
self._namespace_by_path[path] = namespace = Namespace(self, path)
|
||||
if path:
|
||||
self.connect(path)
|
||||
self._namespace_by_path[path] = namespace = Namespace(self, path)
|
||||
self.wait(for_connect=True)
|
||||
return namespace
|
||||
|
||||
def on(self, event, callback, path=''):
|
||||
|
|
@ -362,20 +371,23 @@ class SocketIO(EngineIO):
|
|||
|
||||
# Act
|
||||
|
||||
def connect(self, path):
|
||||
def connect(self, path, with_transport_instance=False):
|
||||
socketIO_packet_type = 0
|
||||
socketIO_packet_data = format_socketIO_packet_data(path)
|
||||
self._message(str(socketIO_packet_type) + socketIO_packet_data)
|
||||
self._message(
|
||||
str(socketIO_packet_type) + socketIO_packet_data,
|
||||
with_transport_instance)
|
||||
|
||||
def disconnect(self, path=''):
|
||||
if not self._opened:
|
||||
return
|
||||
if path:
|
||||
if not path or not self._opened:
|
||||
self._close()
|
||||
elif path:
|
||||
socketIO_packet_type = 1
|
||||
socketIO_packet_data = format_socketIO_packet_data(path)
|
||||
self._message(str(socketIO_packet_type) + socketIO_packet_data)
|
||||
else:
|
||||
self._close()
|
||||
try:
|
||||
self._message(str(socketIO_packet_type) + socketIO_packet_data)
|
||||
except (TimeoutError, ConnectionError):
|
||||
pass
|
||||
try:
|
||||
namespace = self._namespace_by_path.pop(path)
|
||||
namespace.on_disconnect()
|
||||
|
|
@ -405,13 +417,17 @@ class SocketIO(EngineIO):
|
|||
|
||||
# React
|
||||
|
||||
def wait(self, seconds=None, for_callbacks=False):
|
||||
super(SocketIO, self).wait(seconds, for_callbacks=for_callbacks)
|
||||
|
||||
def wait_for_callbacks(self, seconds=None):
|
||||
self.wait(seconds, for_callbacks=True)
|
||||
|
||||
def _should_stop_waiting(self, for_callbacks):
|
||||
def _should_stop_waiting(self, for_connect=False, for_callbacks=False):
|
||||
if for_connect:
|
||||
for namespace in self._namespace_by_path.values():
|
||||
is_namespace_connected = getattr(
|
||||
namespace, '_connected', False)
|
||||
if not is_namespace_connected:
|
||||
return False
|
||||
return True
|
||||
if for_callbacks and not self._has_ack_callback:
|
||||
return True
|
||||
return super(SocketIO, self)._should_stop_waiting()
|
||||
|
|
@ -439,16 +455,18 @@ class SocketIO(EngineIO):
|
|||
except KeyError:
|
||||
raise PacketError(
|
||||
'unexpected socket.io packet type (%s)' % socketIO_packet_type)
|
||||
delegate(socketIO_packet_data, namespace._find_packet_callback)
|
||||
delegate(socketIO_packet_data, namespace)
|
||||
return socketIO_packet_data
|
||||
|
||||
def _on_connect(self, data, find_packet_callback):
|
||||
find_packet_callback('connect')()
|
||||
def _on_connect(self, data, namespace):
|
||||
namespace._connected = True
|
||||
namespace._find_packet_callback('connect')()
|
||||
|
||||
def _on_disconnect(self, data, find_packet_callback):
|
||||
find_packet_callback('disconnect')()
|
||||
def _on_disconnect(self, data, namespace):
|
||||
namespace._connected = False
|
||||
namespace._find_packet_callback('disconnect')()
|
||||
|
||||
def _on_event(self, data, find_packet_callback):
|
||||
def _on_event(self, data, namespace):
|
||||
data_parsed = parse_socketIO_packet_data(data)
|
||||
args = data_parsed.args
|
||||
try:
|
||||
|
|
@ -458,9 +476,9 @@ class SocketIO(EngineIO):
|
|||
if data_parsed.ack_id is not None:
|
||||
args.append(self._prepare_to_send_ack(
|
||||
data_parsed.path, data_parsed.ack_id))
|
||||
find_packet_callback(event)(*args)
|
||||
namespace._find_packet_callback(event)(*args)
|
||||
|
||||
def _on_ack(self, data, find_packet_callback):
|
||||
def _on_ack(self, data, namespace):
|
||||
data_parsed = parse_socketIO_packet_data(data)
|
||||
try:
|
||||
ack_callback = self._get_ack_callback(data_parsed.ack_id)
|
||||
|
|
@ -468,13 +486,13 @@ class SocketIO(EngineIO):
|
|||
return
|
||||
ack_callback(*data_parsed.args)
|
||||
|
||||
def _on_error(self, data, find_packet_callback):
|
||||
find_packet_callback('error')(data)
|
||||
def _on_error(self, data, namespace):
|
||||
namespace._find_packet_callback('error')(data)
|
||||
|
||||
def _on_binary_event(self, data, find_packet_callback):
|
||||
def _on_binary_event(self, data, namespace):
|
||||
self._warn('[not implemented] binary event')
|
||||
|
||||
def _on_binary_ack(self, data, find_packet_callback):
|
||||
def _on_binary_ack(self, data, namespace):
|
||||
self._warn('[not implemented] binary ack')
|
||||
|
||||
def _prepare_to_send_ack(self, path, ack_id):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
from threading import Thread, Event
|
||||
|
||||
from .exceptions import ConnectionError, TimeoutError
|
||||
|
|
@ -22,17 +23,17 @@ class HeartbeatThread(Thread):
|
|||
def run(self):
|
||||
try:
|
||||
while not self._halt.is_set():
|
||||
try:
|
||||
self._send_heartbeat()
|
||||
except TimeoutError:
|
||||
pass
|
||||
if self._adrenaline.is_set():
|
||||
interval_in_seconds = self._hurry_interval_in_seconds
|
||||
else:
|
||||
interval_in_seconds = self._relax_interval_in_seconds
|
||||
self._rest.wait(interval_in_seconds)
|
||||
try:
|
||||
self._send_heartbeat()
|
||||
except TimeoutError:
|
||||
pass
|
||||
except ConnectionError:
|
||||
pass
|
||||
logging.debug('[heartbeat connection error]')
|
||||
|
||||
def relax(self):
|
||||
self._adrenaline.clear()
|
||||
|
|
@ -42,6 +43,10 @@ class HeartbeatThread(Thread):
|
|||
self._rest.set()
|
||||
self._rest.clear()
|
||||
|
||||
@property
|
||||
def hurried(self):
|
||||
return self._adrenaline.is_set()
|
||||
|
||||
def halt(self):
|
||||
self._rest.set()
|
||||
self._halt.set()
|
||||
|
|
|
|||
|
|
@ -141,31 +141,31 @@ class SocketIONamespace(EngineIONamespace):
|
|||
class LoggingEngineIONamespace(EngineIONamespace):
|
||||
|
||||
def on_open(self):
|
||||
self._debug('[open]')
|
||||
self._debug('[engine.io open]')
|
||||
super(LoggingEngineIONamespace, self).on_open()
|
||||
|
||||
def on_close(self):
|
||||
self._debug('[close]')
|
||||
self._debug('[engine.io close]')
|
||||
super(LoggingEngineIONamespace, self).on_close()
|
||||
|
||||
def on_ping(self, data):
|
||||
self._debug('[ping] %s', data)
|
||||
self._debug('[engine.io ping] %s', data)
|
||||
super(LoggingEngineIONamespace, self).on_ping(data)
|
||||
|
||||
def on_pong(self, data):
|
||||
self._debug('[pong] %s', data)
|
||||
self._debug('[engine.io pong] %s', data)
|
||||
super(LoggingEngineIONamespace, self).on_pong(data)
|
||||
|
||||
def on_message(self, data):
|
||||
self._debug('[message] %s', data)
|
||||
self._debug('[engine.io message] %s', data)
|
||||
super(LoggingEngineIONamespace, self).on_message(data)
|
||||
|
||||
def on_upgrade(self):
|
||||
self._debug('[upgrade]')
|
||||
self._debug('[engine.io upgrade]')
|
||||
super(LoggingEngineIONamespace, self).on_upgrade()
|
||||
|
||||
def on_noop(self):
|
||||
self._debug('[noop]')
|
||||
self._debug('[engine.io noop]')
|
||||
super(LoggingEngineIONamespace, self).on_noop()
|
||||
|
||||
def on_event(self, event, *args):
|
||||
|
|
@ -173,22 +173,25 @@ class LoggingEngineIONamespace(EngineIONamespace):
|
|||
arguments = [repr(_) for _ in args]
|
||||
if callback:
|
||||
arguments.append('callback(*args)')
|
||||
self._info('[event] %s(%s)', event, ', '.join(arguments))
|
||||
self._info('[engine.io event] %s(%s)', event, ', '.join(arguments))
|
||||
super(LoggingEngineIONamespace, self).on_event(event, *args)
|
||||
|
||||
|
||||
class LoggingSocketIONamespace(SocketIONamespace):
|
||||
class LoggingSocketIONamespace(SocketIONamespace, LoggingEngineIONamespace):
|
||||
|
||||
def on_connect(self):
|
||||
self._debug('%s[connect]', _make_logging_header(self.path))
|
||||
self._debug(
|
||||
'%s[socket.io connect]', _make_logging_header(self.path))
|
||||
super(LoggingSocketIONamespace, self).on_connect()
|
||||
|
||||
def on_reconnect(self):
|
||||
self._debug('%s[reconnect]', _make_logging_header(self.path))
|
||||
self._debug(
|
||||
'%s[socket.io reconnect]', _make_logging_header(self.path))
|
||||
super(LoggingSocketIONamespace, self).on_reconnect()
|
||||
|
||||
def on_disconnect(self):
|
||||
self._debug('%s[disconnect]', _make_logging_header(self.path))
|
||||
self._debug(
|
||||
'%s[socket.io disconnect]', _make_logging_header(self.path))
|
||||
super(LoggingSocketIONamespace, self).on_disconnect()
|
||||
|
||||
def on_event(self, event, *args):
|
||||
|
|
@ -197,12 +200,13 @@ class LoggingSocketIONamespace(SocketIONamespace):
|
|||
if callback:
|
||||
arguments.append('callback(*args)')
|
||||
self._info(
|
||||
'%s[event] %s(%s)', _make_logging_header(self.path), event,
|
||||
', '.join(arguments))
|
||||
'%s[socket.io event] %s(%s)', _make_logging_header(self.path),
|
||||
event, ', '.join(arguments))
|
||||
super(LoggingSocketIONamespace, self).on_event(event, *args)
|
||||
|
||||
def on_error(self, data):
|
||||
self._debug('%s[error] %s', _make_logging_header(self.path), data)
|
||||
self._debug(
|
||||
'%s[socket.io error] %s', _make_logging_header(self.path), data)
|
||||
super(LoggingSocketIONamespace, self).on_error()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,10 +17,11 @@ class BaseMixin(object):
|
|||
def setUp(self):
|
||||
super(BaseMixin, self).setUp()
|
||||
self.called_on_response = False
|
||||
self.wait_time_in_seconds = 1
|
||||
|
||||
def tearDown(self):
|
||||
super(BaseMixin, self).tearDown()
|
||||
del self.socketIO
|
||||
self.socketIO.disconnect()
|
||||
|
||||
def test_disconnect(self):
|
||||
'Disconnect'
|
||||
|
|
@ -161,7 +162,6 @@ class Test_XHR_PollingTransport(BaseMixin, TestCase):
|
|||
self.socketIO = SocketIO(HOST, PORT, LoggingNamespace, transports=[
|
||||
'xhr-polling'], verify=False)
|
||||
self.assertEqual(self.socketIO.transport_name, 'xhr-polling')
|
||||
self.wait_time_in_seconds = 1
|
||||
|
||||
|
||||
class Test_WebsocketTransport(BaseMixin, TestCase):
|
||||
|
|
@ -171,7 +171,6 @@ class Test_WebsocketTransport(BaseMixin, TestCase):
|
|||
self.socketIO = SocketIO(HOST, PORT, LoggingNamespace, transports=[
|
||||
'xhr-polling', 'websocket'], verify=False)
|
||||
self.assertEqual(self.socketIO.transport_name, 'websocket')
|
||||
self.wait_time_in_seconds = 1
|
||||
|
||||
|
||||
class Namespace(LoggingNamespace):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import six
|
|||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import websocket
|
||||
|
||||
|
|
@ -54,10 +55,11 @@ class XHR_PollingTransport(AbstractTransport):
|
|||
'EIO': ENGINEIO_PROTOCOL, 'transport': 'polling'}
|
||||
if engineIO_session:
|
||||
self._request_index = 1
|
||||
self._kw_get = dict(timeout=engineIO_session.ping_timeout)
|
||||
self._kw_post = dict(headers={
|
||||
'content-type': 'application/octet-stream',
|
||||
})
|
||||
self._kw_get = dict(
|
||||
timeout=engineIO_session.ping_timeout)
|
||||
self._kw_post = dict(
|
||||
timeout=engineIO_session.ping_timeout,
|
||||
headers={'content-type': 'application/octet-stream'})
|
||||
self._params['sid'] = engineIO_session.id
|
||||
else:
|
||||
self._request_index = 0
|
||||
|
|
@ -65,6 +67,8 @@ class XHR_PollingTransport(AbstractTransport):
|
|||
self._kw_post = {}
|
||||
http_scheme = 'https' if is_secure else 'http'
|
||||
self._http_url = '%s://%s/' % (http_scheme, url)
|
||||
self._request_index_lock = threading.Lock()
|
||||
self._send_packet_lock = threading.Lock()
|
||||
|
||||
def recv_packet(self):
|
||||
params = dict(self._params)
|
||||
|
|
@ -79,21 +83,24 @@ class XHR_PollingTransport(AbstractTransport):
|
|||
yield engineIO_packet_type, engineIO_packet_data
|
||||
|
||||
def send_packet(self, engineIO_packet_type, engineIO_packet_data=''):
|
||||
params = dict(self._params)
|
||||
params['t'] = self._get_timestamp()
|
||||
response = get_response(
|
||||
self.http_session.post,
|
||||
self._http_url,
|
||||
params=params,
|
||||
data=encode_engineIO_content([
|
||||
(engineIO_packet_type, engineIO_packet_data),
|
||||
]),
|
||||
**self._kw_post)
|
||||
assert response.content == b'ok'
|
||||
with self._send_packet_lock:
|
||||
params = dict(self._params)
|
||||
params['t'] = self._get_timestamp()
|
||||
response = get_response(
|
||||
self.http_session.post,
|
||||
self._http_url,
|
||||
params=params,
|
||||
data=encode_engineIO_content([
|
||||
(engineIO_packet_type, engineIO_packet_data),
|
||||
]),
|
||||
**self._kw_post)
|
||||
assert response.content == b'ok'
|
||||
|
||||
def _get_timestamp(self):
|
||||
timestamp = '%s-%s' % (int(time.time() * 1000), self._request_index)
|
||||
self._request_index += 1
|
||||
with self._request_index_lock:
|
||||
timestamp = '%s-%s' % (
|
||||
int(time.time() * 1000), self._request_index)
|
||||
self._request_index += 1
|
||||
return timestamp
|
||||
|
||||
|
||||
|
|
@ -164,7 +171,7 @@ class WebsocketTransport(AbstractTransport):
|
|||
|
||||
def get_response(request, *args, **kw):
|
||||
try:
|
||||
response = request(*args, **kw)
|
||||
response = request(*args, stream=True, **kw)
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise TimeoutError(e)
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue