This commit is contained in:
Roy Hyunjin Han 2013-04-26 09:34:47 -07:00
commit 77a8e72c1f
6 changed files with 286 additions and 159 deletions

View file

@ -16,25 +16,39 @@ class BaseNamespace(object): # pragma: no cover
self._socketIO = _socketIO
self._path = path
self._callbackByEvent = {}
self.initialize()
def initialize(self):
'Initialize custom variables here; you can override this method'
pass
def on_connect(self):
'Called when socket is connecting; you can override this method'
pass
def on_disconnect(self):
'Called when socket is disconnecting; you can override this method'
pass
def on_error(self, reason, advice):
'Called when server sends an error; you can override this method'
print '[Error] %s' % advice
def on_message(self, data):
'Called when server sends a message; you can override this method'
print '[Message] %s' % data
def on_default(self, event, *args):
def on_event(self, event, *args):
"""
Called when server emits an event; you can override this method.
Called only if the program cannot find a more specific event handler,
such as one defined by namespace.on('my_event', my_function).
"""
callback, args = find_callback(args)
arguments = [str(_) for _ in args]
arguments = [repr(_) for _ in args]
if callback:
arguments.append('callback(*args)')
callback()
callback(*args)
print '[Event] %s(%s)' % (event, ', '.join(arguments))
def on_open(self, *args):
@ -49,7 +63,7 @@ class BaseNamespace(object): # pragma: no cover
def on_reconnect(self, *args):
print '[Reconnect]', args
def message(self, data, callback=None):
def message(self, data='', callback=None):
self._socketIO.message(data, callback, path=self._path)
def emit(self, event, *args, **kw):
@ -57,6 +71,7 @@ class BaseNamespace(object): # pragma: no cover
self._socketIO.emit(event, *args, **kw)
def on(self, event, callback):
'Define a callback to handle a custom event emitted by the server'
self._callbackByEvent[event] = callback
def _get_eventCallback(self, event):
@ -65,8 +80,8 @@ class BaseNamespace(object): # pragma: no cover
return self._callbackByEvent[event]
except KeyError:
pass
# Check callbacks defined explicitly or use on_default()
callback = lambda *args: self.on_default(event, *args)
# Check callbacks defined explicitly or use on_event()
callback = lambda *args: self.on_event(event, *args)
return getattr(self, 'on_' + event.replace(' ', '_'), callback)
@ -117,7 +132,8 @@ class SocketIO(object):
self._listenerThread.cancel()
def define(self, Namespace, path=''):
self._socketIO.connect(path)
if path:
self._socketIO.connect(path)
namespace = Namespace(self._socketIO, path)
self._namespaceByPath[path] = namespace
return namespace
@ -128,17 +144,15 @@ class SocketIO(object):
def on(self, event, callback, path=''):
return self.get_namespace(path).on(event, callback)
def message(self, data, callback=None, path=''):
def message(self, data='', callback=None, path=''):
self._socketIO.message(data, callback, path)
def emit(self, event, *args, **kw):
self._socketIO.emit(event, *args, **kw)
def wait(self, seconds=None, forCallbacks=False):
if forCallbacks:
self._listenerThread.wait_for_callbacks(seconds)
elif seconds:
sleep(seconds)
def wait(self, seconds=None):
if seconds:
self._listenerThread.wait(seconds)
else:
try:
while self.connected:
@ -146,6 +160,9 @@ class SocketIO(object):
except KeyboardInterrupt:
pass
def wait_for_callbacks(self, seconds=None):
self._listenerThread.wait_for_callbacks(seconds)
class _RhythmicThread(Thread):
'Execute call every few seconds'
@ -179,15 +196,18 @@ class _ListenerThread(Thread):
self._socketIO = _socketIO
self._namespaceByPath = _namespaceByPath
self.done = Event()
self.waiting = Event()
self.ready = Event()
self.ready.set()
def cancel(self):
self.done.set()
def wait(self, seconds):
self.done.wait(seconds)
def wait_for_callbacks(self, seconds):
self.waiting.set()
# Block callingThread until listenerThread terminates
self.join(seconds)
self.ready.clear()
self.ready.wait(seconds)
def get_ackCallback(self, packetID):
return lambda *args: self._socketIO.ack(packetID, *args)
@ -215,7 +235,7 @@ class _ListenerThread(Thread):
'3': self.on_message,
'4': self.on_json,
'5': self.on_event,
'6': self.on_acknowledgment,
'6': self.on_ack,
'7': self.on_error,
}[code]
except KeyError:
@ -236,13 +256,13 @@ class _ListenerThread(Thread):
args = [data]
if packetID:
args.append(self.get_ackCallback(packetID))
get_eventCallback('message')(args)
get_eventCallback('message')(*args)
def on_json(self, packetID, get_eventCallback, data):
args = [loads(data)]
if packetID:
args.append(self.get_ackCallback(packetID))
get_eventCallback('message')(args)
get_eventCallback('message')(*args)
def on_event(self, packetID, get_eventCallback, data):
valueByName = loads(data)
@ -252,16 +272,16 @@ class _ListenerThread(Thread):
args.append(self.get_ackCallback(packetID))
get_eventCallback(event)(*args)
def on_acknowledgment(self, packetID, get_eventCallback, data):
def on_ack(self, packetID, get_eventCallback, data):
dataParts = data.split('+', 1)
messageID = int(dataParts[0])
args = loads(dataParts[1]) or []
args = loads(dataParts[1]) if len(dataParts) > 1 else []
callback = self._socketIO.get_messageCallback(messageID)
if not callback:
return
callback(*args)
if self.waiting.is_set() and not self._socketIO.has_messageCallback:
self.cancel()
if not self._socketIO.has_messageCallback:
self.ready.set()
def on_error(self, packetID, get_eventCallback, data):
reason, advice = data.split('+', 1)
@ -289,7 +309,7 @@ class _SocketIO(object):
# connectionTimeout = int(responseParts[2])
supportedTransports = responseParts[3].split(',')
if 'websocket' not in supportedTransports:
raise SocketIOError('Could not parse handshake') # pragma: no cover
raise SocketIOError('Could not parse handshake')
socketScheme = 'wss' if secure else 'ws'
socketURL = '%s://%s/websocket/%s' % (socketScheme, baseURL, sessionID)
self.connection = create_connection(socketURL)
@ -334,7 +354,10 @@ class _SocketIO(object):
def ack(self, packetID, *args):
packetID = packetID.rstrip('+')
packetData = '%s+%s' % (packetID, dumps(args, ensure_ascii=False)) if args else packetID
packetData = '%s+%s' % (
packetID,
dumps(args, ensure_ascii=False),
) if args else packetID
self.send_packet(6, data=packetData)
def set_messageCallback(self, callback):
@ -359,11 +382,14 @@ class _SocketIO(object):
try:
packet = self.connection.recv()
except WebSocketConnectionClosedException:
raise SocketIOConnectionError('Lost connection (Connection closed)')
text = 'Lost connection (Connection closed)'
raise SocketIOConnectionError(text)
except socket.timeout:
raise SocketIOConnectionError('Lost connection (Connection timed out)')
text = 'Lost connection (Connection timed out)'
raise SocketIOConnectionError(text)
except socket.error:
raise SocketIOConnectionError('Lost connection')
text = 'Lost connection'
raise SocketIOConnectionError(text)
try:
packetParts = packet.split(':', 3)
except AttributeError:
@ -382,7 +408,8 @@ class _SocketIO(object):
packetID = self.set_messageCallback(callback) if callback else ''
packetParts = [str(code), packetID, path, data]
try:
self.connection.send(':'.join(packetParts))
packet = ':'.join(packetParts)
self.connection.send(packet)
except socket.error:
raise SocketIOPacketError('Could not send packet')

View file

@ -1,16 +1,17 @@
from socketIO_client import SocketIO, BaseNamespace, find_callback
from time import sleep
from unittest import TestCase
HOST = 'localhost'
PORT = 8000
DATA = 'xxx'
PAYLOAD = {'xxx': 'yyy'}
class TestSocketIO(TestCase):
def setUp(self):
self.socketIO = SocketIO('localhost', PORT)
self.socketIO = SocketIO(HOST, PORT)
self.called_on_response = False
def tearDown(self):
@ -18,76 +19,156 @@ class TestSocketIO(TestCase):
def on_response(self, *args):
self.called_on_response = True
callback, args = find_callback(args)
if callback:
callback(*args)
for arg in args:
if isinstance(arg, dict):
self.assertEqual(arg, PAYLOAD)
else:
self.assertEqual(arg, DATA)
def is_connected(self, socketIO, connected):
childThreads = [
socketIO._rhythmicThread,
socketIO._listenerThread,
]
for childThread in childThreads:
self.assertEqual(not connected, childThread.done.is_set())
self.assertEqual(connected, socketIO.connected)
def test_disconnect(self):
childThreads = [
self.socketIO._rhythmicThread,
self.socketIO._listenerThread,
]
'Terminate child threads after disconnect'
self.is_connected(self.socketIO, True)
self.socketIO.disconnect()
for childThread in childThreads:
self.assertEqual(True, childThread.done.is_set())
self.assertEqual(False, self.socketIO.connected)
self.is_connected(self.socketIO, False)
# Use context manager
with SocketIO(HOST, PORT) as self.socketIO:
self.is_connected(self.socketIO, True)
self.is_connected(self.socketIO, False)
def test_message(self):
'Message'
self.socketIO.define(Namespace)
self.socketIO.message()
self.socketIO.wait(0.1)
namespace = self.socketIO.get_namespace()
self.assertEqual(namespace.response, 'message_response')
def test_message_with_data(self):
'Message with data'
self.socketIO.define(Namespace)
self.socketIO.message(DATA)
self.socketIO.wait(0.1)
namespace = self.socketIO.get_namespace()
self.assertEqual(namespace.response, DATA)
def test_message_with_payload(self):
'Message with payload'
self.socketIO.define(Namespace)
self.socketIO.message(PAYLOAD)
self.socketIO.wait(0.1)
namespace = self.socketIO.get_namespace()
self.assertEqual(namespace.response, PAYLOAD)
def test_message_with_callback(self):
'Message with callback'
self.socketIO.message(callback=self.on_response)
self.socketIO.wait_for_callbacks(seconds=0.1)
self.assertEqual(self.called_on_response, True)
def test_message_with_callback_with_data(self):
'Message with callback with data'
self.socketIO.message(DATA, self.on_response)
self.socketIO.wait_for_callbacks(seconds=0.1)
self.assertEqual(self.called_on_response, True)
def test_emit(self):
'Emit'
self.socketIO.define(Namespace)
self.socketIO.emit('aaa')
sleep(0.1)
self.assertEqual(self.socketIO.get_namespace().payload, '')
self.socketIO.emit('emit')
self.socketIO.wait(0.1)
self.assertEqual(self.socketIO.get_namespace().argsByEvent, {
'emit_response': (),
})
def test_emit_with_payload(self):
'Emit with payload'
self.socketIO.define(Namespace)
self.socketIO.emit('aaa', PAYLOAD)
sleep(0.1)
self.assertEqual(self.socketIO.get_namespace().payload, PAYLOAD)
self.socketIO.emit('emit_with_payload', PAYLOAD)
self.socketIO.wait(0.1)
self.assertEqual(self.socketIO.get_namespace().argsByEvent, {
'emit_with_payload_response': (PAYLOAD,),
})
def test_emit_with_multiple_payloads(self):
'Emit with multiple payloads'
self.socketIO.define(Namespace)
self.socketIO.emit('emit_with_multiple_payloads', PAYLOAD, PAYLOAD)
self.socketIO.wait(0.1)
self.assertEqual(self.socketIO.get_namespace().argsByEvent, {
'emit_with_multiple_payloads_response': (PAYLOAD, PAYLOAD),
})
def test_emit_with_callback(self):
self.socketIO.emit('aaa', PAYLOAD, self.on_response)
self.socketIO.wait(seconds=0.1, forCallbacks=True)
'Emit with callback'
self.socketIO.emit('emit_with_callback', self.on_response)
self.socketIO.wait_for_callbacks(seconds=0.1)
self.assertEqual(self.called_on_response, True)
def test_emit_with_callback_with_payload(self):
'Emit with callback with payload'
self.socketIO.emit('emit_with_callback_with_payload',
self.on_response)
self.socketIO.wait_for_callbacks(seconds=0.1)
self.assertEqual(self.called_on_response, True)
def test_emit_with_callback_with_multiple_payloads(self):
'Emit with callback with multiple payloads'
self.socketIO.emit('emit_with_callback_with_multiple_payloads',
self.on_response)
self.socketIO.wait_for_callbacks(seconds=0.1)
self.assertEqual(self.called_on_response, True)
def test_emit_with_event(self):
self.socketIO.on('aaa_response', self.on_response)
self.socketIO.emit('aaa', PAYLOAD)
sleep(0.1)
self.assertEqual(self.called_on_response, True)
def test_message(self):
self.socketIO.message(PAYLOAD, self.on_response)
self.socketIO.wait(seconds=0.1, forCallbacks=True)
'Emit to trigger an event'
self.socketIO.on('emit_with_event_response', self.on_response)
self.socketIO.emit('emit_with_event', PAYLOAD)
self.socketIO.wait_for_callbacks(0.1)
self.assertEqual(self.called_on_response, True)
def test_ack(self):
self.socketIO.on('bbb_response', self.on_response)
self.socketIO.emit('bbb', PAYLOAD)
sleep(0.1)
self.assertEqual(self.called_on_response, True)
'Trigger server callback'
self.socketIO.define(Namespace)
self.socketIO.emit('ack', PAYLOAD)
self.socketIO.wait(0.1)
self.assertEqual(self.socketIO.get_namespace().argsByEvent, {
'ack_response': (PAYLOAD,),
'ack_callback_response': (PAYLOAD,),
})
def test_namespaces(self):
'Behave differently in different namespaces'
mainNamespace = self.socketIO.define(Namespace)
chatNamespace = self.socketIO.define(Namespace, '/chat')
newsNamespace = self.socketIO.define(Namespace, '/news')
self.assertNotEqual(mainNamespace.payload, PAYLOAD)
self.assertNotEqual(chatNamespace.payload, PAYLOAD)
self.assertNotEqual(newsNamespace.payload, PAYLOAD)
newsNamespace.emit('aaa', PAYLOAD)
sleep(0.1)
self.assertEqual(newsNamespace.payload, PAYLOAD)
def test_namespaces_with_callback(self):
mainNamespace = self.socketIO.get_namespace()
mainNamespace.message(PAYLOAD, self.on_response)
sleep(0.1)
self.assertEqual(self.called_on_response, True)
newsNamespace.emit('emit_with_payload', PAYLOAD)
self.socketIO.wait(0.1)
self.assertEqual(mainNamespace.argsByEvent, {})
self.assertEqual(chatNamespace.argsByEvent, {})
self.assertEqual(newsNamespace.argsByEvent, {
'emit_with_payload_response': (PAYLOAD,),
})
class Namespace(BaseNamespace):
payload = None
def initialize(self):
self.response = None
self.argsByEvent = {}
def on_aaa_response(self, data=''):
print '[Event] aaa_response(%s)' % data
self.payload = data
def on_message(self, data):
self.response = data
def on_event(self, event, *args):
callback, args = find_callback(args)
if callback:
callback(*args)
self.argsByEvent[event] = args