Added support for acks thanks to zratic
This commit is contained in:
parent
61918597b7
commit
18d2a1ea8c
2 changed files with 138 additions and 109 deletions
|
|
@ -12,9 +12,9 @@ PROTOCOL = 1 # socket.io protocol version
|
|||
class BaseNamespace(object): # pragma: no cover
|
||||
'Define socket.io behavior'
|
||||
|
||||
def __init__(self, _socketIO, namespacePath):
|
||||
def __init__(self, _socketIO, path):
|
||||
self._socketIO = _socketIO
|
||||
self._namespacePath = namespacePath
|
||||
self._path = path
|
||||
self._callbackByEvent = {}
|
||||
|
||||
def on_connect(self):
|
||||
|
|
@ -26,11 +26,16 @@ class BaseNamespace(object): # pragma: no cover
|
|||
def on_error(self, reason, advice):
|
||||
print '[Error] %s' % advice
|
||||
|
||||
def on_message(self, messageData):
|
||||
print '[Message] %s' % messageData
|
||||
def on_message(self, data):
|
||||
print '[Message] %s' % data
|
||||
|
||||
def on_default(self, eventName, *eventArguments):
|
||||
print '[Event] %s%s' % (eventName, eventArguments)
|
||||
def on_default(self, event, *args):
|
||||
callback, args = find_callback(args)
|
||||
arguments = [str(_) for _ in args]
|
||||
if callback:
|
||||
arguments.append('callback(*args)')
|
||||
callback()
|
||||
print '[Event] %s(%s)' % (event, ', '.join(arguments))
|
||||
|
||||
def on_open(self, *args):
|
||||
print '[Open]', args
|
||||
|
|
@ -44,28 +49,25 @@ class BaseNamespace(object): # pragma: no cover
|
|||
def on_reconnect(self, *args):
|
||||
print '[Reconnect]', args
|
||||
|
||||
def message(self, messageData, messageCallback=None):
|
||||
self._socketIO.message(
|
||||
messageData, messageCallback, namespacePath=self._namespacePath)
|
||||
def message(self, data, callback=None):
|
||||
self._socketIO.message(data, callback, path=self._path)
|
||||
|
||||
def emit(self, eventName, *eventArguments):
|
||||
self._socketIO.emit(
|
||||
eventName, *eventArguments, namespacePath=self._namespacePath)
|
||||
def emit(self, event, *args, **kw):
|
||||
kw['path'] = self._path
|
||||
self._socketIO.emit(event, *args, **kw)
|
||||
|
||||
def on(self, eventName, eventCallback):
|
||||
self._callbackByEvent[eventName] = eventCallback
|
||||
def on(self, event, callback):
|
||||
self._callbackByEvent[event] = callback
|
||||
|
||||
def _get_eventCallback(self, eventName):
|
||||
def _get_eventCallback(self, event):
|
||||
# Check callbacks defined by on()
|
||||
try:
|
||||
return self._callbackByEvent[eventName]
|
||||
return self._callbackByEvent[event]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Check callbacks defined explicitly or use on_default()
|
||||
def callback(*eventArguments):
|
||||
return self.on_default(eventName, *eventArguments)
|
||||
return getattr(self, 'on_' + eventName.replace(' ', '_'), callback)
|
||||
callback = lambda *args: self.on_default(event, *args)
|
||||
return getattr(self, 'on_' + event.replace(' ', '_'), callback)
|
||||
|
||||
|
||||
class SocketIO(object):
|
||||
|
|
@ -99,38 +101,38 @@ class SocketIO(object):
|
|||
self.disconnect()
|
||||
|
||||
def __del__(self):
|
||||
self.disconnect(closeSocket=False)
|
||||
self.disconnect(close=False)
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
return self._socketIO.connected
|
||||
|
||||
def disconnect(self, namespacePath='', closeSocket=True):
|
||||
def disconnect(self, path='', close=True):
|
||||
if self.connected:
|
||||
self._socketIO.disconnect(namespacePath, closeSocket)
|
||||
if namespacePath:
|
||||
del self._namespaceByPath[namespacePath]
|
||||
self._socketIO.disconnect(path, close)
|
||||
if path:
|
||||
del self._namespaceByPath[path]
|
||||
else:
|
||||
self._rhythmicThread.cancel()
|
||||
self._listenerThread.cancel()
|
||||
|
||||
def define(self, Namespace, namespacePath=''):
|
||||
self._socketIO.connect(namespacePath)
|
||||
namespace = Namespace(self._socketIO, namespacePath)
|
||||
self._namespaceByPath[namespacePath] = namespace
|
||||
def define(self, Namespace, path=''):
|
||||
self._socketIO.connect(path)
|
||||
namespace = Namespace(self._socketIO, path)
|
||||
self._namespaceByPath[path] = namespace
|
||||
return namespace
|
||||
|
||||
def get_namespace(self, namespacePath=''):
|
||||
return self._namespaceByPath[namespacePath]
|
||||
def get_namespace(self, path=''):
|
||||
return self._namespaceByPath[path]
|
||||
|
||||
def on(self, eventName, eventCallback, namespacePath=''):
|
||||
return self.get_namespace(namespacePath).on(eventName, eventCallback)
|
||||
def on(self, event, callback, path=''):
|
||||
return self.get_namespace(path).on(event, callback)
|
||||
|
||||
def message(self, messageData, messageCallback=None, namespacePath=''):
|
||||
self._socketIO.message(messageData, messageCallback, namespacePath)
|
||||
def message(self, data, callback=None, path=''):
|
||||
self._socketIO.message(data, callback, path)
|
||||
|
||||
def emit(self, eventName, *eventArguments, **eventKeywords):
|
||||
self._socketIO.emit(eventName, *eventArguments, **eventKeywords)
|
||||
def emit(self, event, *args, **kw):
|
||||
self._socketIO.emit(event, *args, **kw)
|
||||
|
||||
def wait(self, seconds=None, forCallbacks=False):
|
||||
if forCallbacks:
|
||||
|
|
@ -187,10 +189,13 @@ class _ListenerThread(Thread):
|
|||
# Block callingThread until listenerThread terminates
|
||||
self.join(seconds)
|
||||
|
||||
def get_ackCallback(self, packetID):
|
||||
return lambda *args: self._socketIO.ack(packetID, *args)
|
||||
|
||||
def run(self):
|
||||
while not self.done.is_set():
|
||||
try:
|
||||
code, packetID, namespacePath, data = self._socketIO.recv_packet()
|
||||
code, packetID, path, data = self._socketIO.recv_packet()
|
||||
except SocketIOConnectionError, error:
|
||||
print error
|
||||
return
|
||||
|
|
@ -198,9 +203,9 @@ class _ListenerThread(Thread):
|
|||
print error
|
||||
continue
|
||||
try:
|
||||
namespace = self._namespaceByPath[namespacePath]
|
||||
namespace = self._namespaceByPath[path]
|
||||
except KeyError:
|
||||
print 'Received unexpected namespacePath (%s)' % namespacePath
|
||||
print 'Received unexpected path (%s)' % path
|
||||
continue
|
||||
try:
|
||||
delegate = {
|
||||
|
|
@ -228,25 +233,33 @@ class _ListenerThread(Thread):
|
|||
pass
|
||||
|
||||
def on_message(self, packetID, get_eventCallback, data):
|
||||
get_eventCallback('message')(data)
|
||||
args = [data]
|
||||
if packetID:
|
||||
args.append(self.get_ackCallback(packetID))
|
||||
get_eventCallback('message')(args)
|
||||
|
||||
def on_json(self, packetID, get_eventCallback, data):
|
||||
get_eventCallback('message')(loads(data))
|
||||
args = [loads(data)]
|
||||
if packetID:
|
||||
args.append(self.get_ackCallback(packetID))
|
||||
get_eventCallback('message')(args)
|
||||
|
||||
def on_event(self, packetID, get_eventCallback, data):
|
||||
valueByName = loads(data)
|
||||
eventName = valueByName['name']
|
||||
eventArguments = valueByName.get('args', [])
|
||||
get_eventCallback(eventName)(*eventArguments)
|
||||
event = valueByName['name']
|
||||
args = valueByName.get('args', [])
|
||||
if packetID:
|
||||
args.append(self.get_ackCallback(packetID))
|
||||
get_eventCallback(event)(*args)
|
||||
|
||||
def on_acknowledgment(self, packetID, get_eventCallback, data):
|
||||
dataParts = data.split('+', 1)
|
||||
messageID = int(dataParts[0])
|
||||
arguments = loads(dataParts[1]) or []
|
||||
messageCallback = self._socketIO.get_messageCallback(messageID)
|
||||
if not messageCallback:
|
||||
args = loads(dataParts[1]) or []
|
||||
callback = self._socketIO.get_messageCallback(messageID)
|
||||
if not callback:
|
||||
return
|
||||
messageCallback(*arguments)
|
||||
callback(*args)
|
||||
if self.waiting.is_set() and not self._socketIO.has_messageCallback:
|
||||
self.cancel()
|
||||
|
||||
|
|
@ -284,18 +297,18 @@ class _SocketIO(object):
|
|||
self.callbackByMessageID = {}
|
||||
|
||||
def __del__(self):
|
||||
self.disconnect(closeSocket=False)
|
||||
self.disconnect(close=False)
|
||||
|
||||
def disconnect(self, namespacePath='', closeSocket=True):
|
||||
def disconnect(self, path='', close=True):
|
||||
if not self.connected:
|
||||
return
|
||||
if namespacePath:
|
||||
self.send_packet(0, namespacePath)
|
||||
elif closeSocket:
|
||||
if path:
|
||||
self.send_packet(0, path)
|
||||
elif close:
|
||||
self.connection.close()
|
||||
|
||||
def connect(self, namespacePath):
|
||||
self.send_packet(1, namespacePath)
|
||||
def connect(self, path):
|
||||
self.send_packet(1, path)
|
||||
|
||||
def send_heartbeat(self):
|
||||
try:
|
||||
|
|
@ -304,24 +317,25 @@ class _SocketIO(object):
|
|||
print 'Could not send heartbeat'
|
||||
pass
|
||||
|
||||
def message(self, messageData, messageCallback, namespacePath):
|
||||
if isinstance(messageData, basestring):
|
||||
def message(self, data, callback, path):
|
||||
if isinstance(data, basestring):
|
||||
code = 3
|
||||
data = messageData
|
||||
packetData = data
|
||||
else:
|
||||
code = 4
|
||||
data = dumps(messageData, ensure_ascii=False)
|
||||
self.send_packet(code, namespacePath, data, messageCallback)
|
||||
packetData = dumps(data, ensure_ascii=False)
|
||||
self.send_packet(code, path, packetData, callback)
|
||||
|
||||
def emit(self, eventName, *eventArguments, **eventKeywords):
|
||||
if eventArguments and callable(eventArguments[-1]):
|
||||
messageCallback = eventArguments[-1]
|
||||
eventArguments = eventArguments[:-1]
|
||||
else:
|
||||
messageCallback = None
|
||||
namespacePath = eventKeywords.get('namespacePath', '')
|
||||
data = dumps(dict(name=eventName, args=eventArguments), ensure_ascii=False)
|
||||
self.send_packet(5, namespacePath, data, messageCallback)
|
||||
def emit(self, event, *args, **kw):
|
||||
callback, args = find_callback(args, kw)
|
||||
packetData = dumps(dict(name=event, args=args), ensure_ascii=False)
|
||||
path = kw.get('path', '')
|
||||
self.send_packet(5, path, packetData, callback)
|
||||
|
||||
def ack(self, packetID, *args):
|
||||
packetID = packetID.rstrip('+')
|
||||
packetData = '%s+%s' % (packetID, dumps(args, ensure_ascii=False)) if args else packetID
|
||||
self.send_packet(6, data=packetData)
|
||||
|
||||
def set_messageCallback(self, callback):
|
||||
'Set callback that will be called after receiving an acknowledgment'
|
||||
|
|
@ -355,18 +369,18 @@ class _SocketIO(object):
|
|||
except AttributeError:
|
||||
raise SocketIOPacketError('Received invalid packet (%s)' % packet)
|
||||
packetCount = len(packetParts)
|
||||
code, packetID, namespacePath, data = None, None, None, None
|
||||
code, packetID, path, data = None, None, None, None
|
||||
if 4 == packetCount:
|
||||
code, packetID, namespacePath, data = packetParts
|
||||
code, packetID, path, data = packetParts
|
||||
elif 3 == packetCount:
|
||||
code, packetID, namespacePath = packetParts
|
||||
code, packetID, path = packetParts
|
||||
elif 1 == packetCount:
|
||||
code = packetParts[0]
|
||||
return code, packetID, namespacePath, data
|
||||
return code, packetID, path, data
|
||||
|
||||
def send_packet(self, code, namespacePath='', data='', messageCallback=None):
|
||||
callbackNumber = self.set_messageCallback(messageCallback) if messageCallback else ''
|
||||
packetParts = [str(code), callbackNumber, namespacePath, data]
|
||||
def send_packet(self, code, path='', data='', callback=None):
|
||||
packetID = self.set_messageCallback(callback) if callback else ''
|
||||
packetParts = [str(code), packetID, path, data]
|
||||
try:
|
||||
self.connection.send(':'.join(packetParts))
|
||||
except socket.error:
|
||||
|
|
@ -387,3 +401,13 @@ class SocketIOConnectionError(SocketIOError):
|
|||
|
||||
class SocketIOPacketError(SocketIOError):
|
||||
pass
|
||||
|
||||
|
||||
def find_callback(args, kw=None):
|
||||
'Return callback whether passed as a last argument or as a keyword'
|
||||
if args and callable(args[-1]):
|
||||
return args[-1], args[:-1]
|
||||
try:
|
||||
return kw['callback'], args
|
||||
except (KeyError, TypeError):
|
||||
return None, args
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from socketIO_client import SocketIO, BaseNamespace
|
||||
from socketIO_client import SocketIO, BaseNamespace, find_callback
|
||||
from time import sleep
|
||||
from unittest import TestCase
|
||||
|
||||
|
||||
ON_RESPONSE_CALLED = False
|
||||
PORT = 8000
|
||||
PAYLOAD = {'xxx': 'yyy'}
|
||||
|
||||
|
|
@ -11,13 +10,28 @@ PAYLOAD = {'xxx': 'yyy'}
|
|||
class TestSocketIO(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
global ON_RESPONSE_CALLED
|
||||
ON_RESPONSE_CALLED = False
|
||||
self.socketIO = SocketIO('localhost', PORT)
|
||||
self.called_on_response = False
|
||||
|
||||
def tearDown(self):
|
||||
del self.socketIO
|
||||
|
||||
def on_response(self, *args):
|
||||
self.called_on_response = True
|
||||
callback, args = find_callback(args)
|
||||
if callback:
|
||||
callback(*args)
|
||||
|
||||
def test_disconnect(self):
|
||||
childThreads = [
|
||||
self.socketIO._rhythmicThread,
|
||||
self.socketIO._listenerThread,
|
||||
]
|
||||
self.socketIO.disconnect()
|
||||
for childThread in childThreads:
|
||||
self.assertEqual(True, childThread.done.is_set())
|
||||
self.assertEqual(False, self.socketIO.connected)
|
||||
|
||||
def test_emit(self):
|
||||
self.socketIO.define(Namespace)
|
||||
self.socketIO.emit('aaa')
|
||||
|
|
@ -31,20 +45,26 @@ class TestSocketIO(TestCase):
|
|||
self.assertEqual(self.socketIO.get_namespace().payload, PAYLOAD)
|
||||
|
||||
def test_emit_with_callback(self):
|
||||
self.socketIO.emit('aaa', PAYLOAD, on_response)
|
||||
self.socketIO.wait(forCallbacks=True)
|
||||
self.assertEqual(ON_RESPONSE_CALLED, True)
|
||||
self.socketIO.emit('aaa', PAYLOAD, self.on_response)
|
||||
self.socketIO.wait(seconds=0.1, forCallbacks=True)
|
||||
self.assertEqual(self.called_on_response, True)
|
||||
|
||||
def test_message(self):
|
||||
self.socketIO.message(PAYLOAD, on_response)
|
||||
self.socketIO.wait(forCallbacks=True)
|
||||
self.assertEqual(ON_RESPONSE_CALLED, True)
|
||||
|
||||
def test_events(self):
|
||||
self.socketIO.on('aaa_response', on_response)
|
||||
def test_emit_with_event(self):
|
||||
self.socketIO.on('aaa_response', self.on_response)
|
||||
self.socketIO.emit('aaa', PAYLOAD)
|
||||
sleep(0.1)
|
||||
self.assertEqual(ON_RESPONSE_CALLED, True)
|
||||
self.assertEqual(self.called_on_response, True)
|
||||
|
||||
def test_message(self):
|
||||
self.socketIO.message(PAYLOAD, self.on_response)
|
||||
self.socketIO.wait(seconds=0.1, forCallbacks=True)
|
||||
self.assertEqual(self.called_on_response, True)
|
||||
|
||||
def test_ack(self):
|
||||
self.socketIO.on('bbb_response', self.on_response)
|
||||
self.socketIO.emit('bbb', PAYLOAD)
|
||||
sleep(0.1)
|
||||
self.assertEqual(self.called_on_response, True)
|
||||
|
||||
def test_namespaces(self):
|
||||
mainNamespace = self.socketIO.define(Namespace)
|
||||
|
|
@ -59,19 +79,9 @@ class TestSocketIO(TestCase):
|
|||
|
||||
def test_namespaces_with_callback(self):
|
||||
mainNamespace = self.socketIO.get_namespace()
|
||||
mainNamespace.message(PAYLOAD, on_response)
|
||||
mainNamespace.message(PAYLOAD, self.on_response)
|
||||
sleep(0.1)
|
||||
self.assertEqual(ON_RESPONSE_CALLED, True)
|
||||
|
||||
def test_disconnect(self):
|
||||
childThreads = [
|
||||
self.socketIO._rhythmicThread,
|
||||
self.socketIO._listenerThread,
|
||||
]
|
||||
self.socketIO.disconnect()
|
||||
for childThread in childThreads:
|
||||
self.assertEqual(True, childThread.done.is_set())
|
||||
self.assertEqual(False, self.socketIO.connected)
|
||||
self.assertEqual(self.called_on_response, True)
|
||||
|
||||
|
||||
class Namespace(BaseNamespace):
|
||||
|
|
@ -81,8 +91,3 @@ class Namespace(BaseNamespace):
|
|||
def on_aaa_response(self, data=''):
|
||||
print '[Event] aaa_response(%s)' % data
|
||||
self.payload = data
|
||||
|
||||
|
||||
def on_response(*args):
|
||||
global ON_RESPONSE_CALLED
|
||||
ON_RESPONSE_CALLED = True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue