Trying to refactor the damn code to remove cyclic references

This commit is contained in:
Roy Hyunjin Han 2013-02-13 08:27:58 -08:00
commit 01bfefdd8b
3 changed files with 233 additions and 208 deletions

View file

@ -70,7 +70,8 @@ Define events in a namespace. ::
def on_ddd(self, *args):
self.socketIO.emit('eee', {'fff': 'ggg'})
socketIO = SocketIO('localhost', 8000, Namespace)
socketIO = SocketIO('localhost', 8000)
socketIO.define(Namespace)
socketIO.wait() # Loop until CTRL-C
Define standard events. ::
@ -79,7 +80,7 @@ Define standard events. ::
class Namespace(BaseNamespace):
def on_connect(self, socketIO):
def on_connect(self):
print '[Connected]'
def on_disconnect(self):
@ -91,32 +92,36 @@ Define standard events. ::
def on_message(self, id, message):
print '[Message] %s: %s' % (id, message)
socketIO = SocketIO('localhost', 8000, Namespace)
socketIO = SocketIO('localhost', 8000)
socketIO.define(Namespace)
socketIO.wait() # Loop until CTRL-C
Define different behavior for different channels on a single socket. ::
Define different namespaces on a single socket. ::
from socketIO_client import SocketIO, BaseNamespace
class MainNamespace(BaseNamespace):
class MainNamespace(Channel):
def on_aaa(self, *args):
print 'aaa', args
class ChatNamespace(BaseNamespace):
class ChatNamespace(Channel):
def on_bbb(self, *args):
print 'bbb', args
class NewsNamespace(BaseNamespace):
class NewsNamespace(Channel):
def on_ccc(self, *args):
print 'ccc', args
mainSocket = SocketIO('localhost', 8000, MainNamespace)
chatSocket = mainSocket.connect('/chat', ChatNamespace)
newsSocket = mainSocket.connect('/news', NewsNamespace)
mainSocket.wait() # Loop until CTRL-C
socketIO = SocketIO('localhost', 8000)
socketIO.define(MainNamespace)
chatSocket = socketIO.define(ChatNamespace, '/chat')
chatSocket.emit('bbb')
newsSocket = socketIO.define(NewsNamespace, '/news')
newsSocket.emit('ccc')
socketIO.wait() # Loop until CTRL-C
Open secure websockets (HTTPS / WSS) behind a proxy. ::

View file

@ -1,7 +1,5 @@
import socket
import weakref
from anyjson import dumps, loads
from functools import partial
from threading import Event, Thread
from time import sleep
from urllib import urlopen
@ -13,8 +11,8 @@ PROTOCOL = 1 # SocketIO protocol version
class BaseNamespace(object): # pragma: no cover
def __init__(self, socketIO):
self.socketIO = socketIO
def __init__(self, _socketIO):
self._socketIO = _socketIO
def on_connect(self):
pass
@ -44,40 +42,21 @@ class BaseNamespace(object): # pragma: no cover
print '[Reconnect]', args
class Channel(object):
def __init__(self, socketIO, channelName, Namespace):
self._socketIO = weakref.proxy(socketIO)
self._channelName = channelName
self._namespace = Namespace(self)
self._callbackByEvent = {}
def disconnect(self):
self._socketIO.disconnect(self._channelName)
def emit(self, eventName, *eventArguments):
self._socketIO.emit(eventName, *eventArguments, channelName=self._channelName)
def message(self, messageData, callback=None):
self._socketIO.message(messageData, callback, channelName=self._channelName)
def on(self, eventName, eventCallback):
self._callbackByEvent[eventName] = eventCallback
class SocketIO(object):
def __init__(self, host, port, Namespace=BaseNamespace, secure=False, proxies=None):
def __init__(self, host, port, secure=False, proxies=None):
self._socketIO = _SocketIO(host, port, secure, proxies)
self._channelByPath = {}
self._heartbeatThread = _RhythmicThread(
self._rhythmicThread = _RhythmicThread(
self._socketIO.heartbeatTimeout,
self._socketIO.send_heartbeat)
self._heartbeatThread.start()
self._rhythmicThread.start()
self._namespace = Namespace(self._socketIO)
self._namespaceThread = _ListenerThread(self._socketIO)
self._namespaceThread.start()
self._listenerThread = _ListenerThread(
self._socketIO,
self._channelByPath)
self._listenerThread.start()
def __enter__(self):
return self
@ -88,20 +67,30 @@ class SocketIO(object):
def __del__(self):
self.disconnect()
def disconnect(self, channelName=''):
self._send_packet(0, channelName)
if channelName:
del self.channelByName[channelName]
else:
self._heartbeatThread.cancel()
self._namespaceThread.cancel()
@property
def connected(self):
return self._socketIO.connection.connected
def connect(self, channelName, Namespace=BaseNamespace):
channel = Channel(self, channelName, Namespace)
self.channelByName[channelName] = channel
self.send_packet(1, channelName)
def disconnect(self, channelPath=''):
self._socketIO.disconnect(channelPath)
if channelPath:
del self._channelByPath[channelPath]
else:
self._rhythmicThread.cancel()
self._listenerThread.cancel()
def define(self, Namespace, channelPath=''):
self._socketIO.connect(channelPath)
channel = Channel(self._socketIO, Namespace, channelPath)
self._channelByPath[channelPath] = channel
return channel
def get_namespace(self, channelPath=''):
return self._channelByPath[channelPath].get_namespace()
def on(self, eventName, eventCallback):
self._callbackByEvent[eventName] = callback
def message(self, messageData, callback=None, channelName=''):
if isinstance(messageData, basestring):
code = 3
@ -111,7 +100,8 @@ class SocketIO(object):
data = dumps(messageData)
self._send_packet(code, channelName, data, callback)
def emit(self, eventName, *eventArguments, **eventKeywords):
def emit(self, eventName, *eventArguments):
self._socketIO.emit(eventName, *eventArguments)
code = 5
callback = None
if eventArguments and callable(eventArguments[-1]):
@ -121,12 +111,9 @@ class SocketIO(object):
data = dumps(dict(name=eventName, args=eventArguments))
self._send_packet(code, channelName, data, callback)
def on(self, eventName, callback):
self._callbackByEvent[eventName] = callback
def wait(self, seconds=None, forCallbacks=False):
if forCallbacks:
self._namespaceThread.wait_for_callbacks(seconds)
self.__listenerThread.wait_for_callbacks(seconds)
elif seconds:
sleep(seconds)
else:
@ -137,115 +124,6 @@ class SocketIO(object):
pass
class _SocketIO(object):
'Low-level interface to remove cyclic references in child threads'
messageID = 0
def __init__(self, host, port, secure=False, proxies=None):
self.connect(host, port, secure, proxies)
self.callbackByMessageID = {}
self.callbackByEvent = {}
self.channelByName = {}
def __del__(self):
self.connection.close()
def connect(self, host, port, secure, proxies):
baseURL = '%s:%d/socket.io/%s' % (host, port, PROTOCOL)
targetScheme = 'https' if secure else 'http'
targetURL = '%s://%s/' % (targetScheme, baseURL)
try:
response = urlopen(targetURL, proxies=proxies)
except IOError: # pragma: no cover
raise SocketIOError('Could not start connection')
if 200 != response.getcode(): # pragma: no cover
raise SocketIOError('Could not establish connection')
responseParts = response.readline().split(':')
sessionID = responseParts[0]
heartbeatTimeout = int(responseParts[1])
# connectionTimeout = int(responseParts[2])
supportedTransports = responseParts[3].split(',')
if 'websocket' not in supportedTransports:
raise SocketIOError('Could not parse handshake') # pragma: no cover
socketScheme = 'wss' if secure else 'ws'
socketURL = '%s://%s/websocket/%s' % (socketScheme, baseURL, sessionID)
self.connection = create_connection(socketURL)
self.heartbeatInterval = heartbeatTimeout - 2
def recv_packet(self):
code, packetID, channelName, data = -1, None, None, None
try:
packet = self.connection.recv()
except WebSocketConnectionClosedException:
raise SocketIOConnectionError('Lost connection (Connection closed)')
except socket.timeout:
raise SocketIOConnectionError('Lost connection (Connection timed out)')
try:
packetParts = packet.split(':', 3)
except AttributeError:
raise SocketIOPacketError('Received invalid packet (%s)' % packet)
packetCount = len(packetParts)
if 4 == packetCount:
code, packetID, channelName, data = packetParts
elif 3 == packetCount:
code, packetID, channelName = packetParts
elif 1 == packetCount: # pragma: no cover
code = packetParts[0]
return int(code), packetID, channelName, data
def send_packet(self, code, channelName='', data='', callback=None):
callbackNumber = self.set_messageID_callback(callback) if callback else ''
packetParts = [str(code), callbackNumber, channelName, data]
try:
self.connection.send(':'.join(packetParts))
except socket.error:
raise SocketIOPacketError('Could not send packet')
def set_messageID_callback(self, callback):
'Set callback that will be called after receiving an acknowledgment'
self.messageID += 1
self.callbackByMessageID[self.messageID] = callback
return '%s+' % self.messageID
def get_messageID_callback(self, messageID):
'Get callback associated with messageID'
try:
callback = self.callbackByMessageID[messageID]
del self.callbackByMessageID[messageID]
return callback
except KeyError:
return
@property
def has_messageID_callback(self):
return True if self.callbackByMessageID else False
def get_event_callback(self, channelName, eventName):
'Get callback associated with channelName and eventName'
_socketIO = self.channelByName[channelName] if channelName else self
try:
return _socketIO.callbackByEvent[eventName]
except KeyError:
pass
def callback_(*eventArguments):
return _socketIO.namespace.on_(eventName, *eventArguments)
callbackName = 'on_' + eventName.replace(' ', '_')
return getattr(_socketIO.namespace, callbackName, callback_)
@property
def connected(self):
return self.connection.connected
def send_heartbeat(self):
try:
self.send_packet(2)
except SocketIOPacketError:
print 'Could not send heartbeat'
pass
class _RhythmicThread(Thread):
'Execute call every few seconds'
@ -273,23 +151,32 @@ class _ListenerThread(Thread):
daemon = True
def __init__(self, _socketIO, get_event_callback, ):
def __init__(self, _socketIO, _channelByPath):
super(_ListenerThread, self).__init__()
self._socketIO = _socketIO
self._channelByPath = _channelByPath
self.done = Event()
self.waitingForCallbacks = Event()
self.waiting = Event()
def cancel(self):
self.done.set()
def wait_for_callbacks(self, seconds):
self.waiting.set()
# Block callingThread until listenerThread terminates
self.join(seconds)
def run(self):
while not self.done.is_set():
try:
code, packetID, channelName, data = self._socketIO.recv_packet()
code, packetID, channelPath, data = self._socketIO.recv_packet()
except SocketIOConnectionError, error:
print error
return
except SocketIOPacketError, error:
print error
continue
get_channel_callback = partial(self._socketIO.get_event_callback, channelName)
channel = self._channelByPath[channelPath]
try:
delegate = {
0: self.on_disconnect,
@ -303,50 +190,174 @@ class _ListenerThread(Thread):
}[code]
except KeyError:
continue
delegate(packetID, get_channel_callback, data)
delegate(packetID, channel._get_eventCallback, data)
def cancel(self):
self.done.set()
def on_disconnect(self, packetID, get_eventCallback, data):
get_eventCallback('disconnect')()
def wait_for_callbacks(self, seconds):
self.waitingForCallbacks.set()
self.join(seconds)
def on_connect(self, packetID, get_eventCallback, data):
get_eventCallback('connect')()
def on_disconnect(self, packetID, get_channel_callback, data):
get_channel_callback('disconnect')()
def on_heartbeat(self, packetID, get_eventCallback, data):
get_eventCallback('heartbeat')()
def on_connect(self, packetID, get_channel_callback, data):
get_channel_callback('connect')()
def on_message(self, packetID, get_eventCallback, data):
get_eventCallback('message')(data)
def on_heartbeat(self, packetID, get_channel_callback, data):
pass
def on_json(self, packetID, get_eventCallback, data):
get_eventCallback('message')(loads(data))
def on_message(self, packetID, get_channel_callback, data):
get_channel_callback('message')(data)
def on_json(self, packetID, get_channel_callback, data):
get_channel_callback('message')(loads(data))
def on_event(self, packetID, get_channel_callback, data):
def on_event(self, packetID, get_eventCallback, data):
valueByName = loads(data)
eventName = valueByName['name']
eventArguments = valueByName['args']
get_channel_callback(eventName)(*eventArguments)
get_eventCallback(eventName)(*eventArguments)
def on_acknowledgment(self, packetID, get_channel_callback, data):
def on_acknowledgment(self, packetID, get_eventCallback, data):
dataParts = data.split('+', 1)
messageID = int(dataParts[0])
arguments = loads(dataParts[1]) or []
callback = self._socketIO.get_messageID_callback(messageID)
callback = self._socketIO.get_messageCallback(messageID)
if not callback:
return
callback(*arguments)
if self.waitingForCallbacks.is_set() and not self._socketIO.has_messageID_callback:
if self.waiting.is_set() and not self._socketIO.has_messageCallback:
self.cancel()
def on_error(self, packetID, get_channel_callback, data):
def on_error(self, packetID, get_eventCallback, data):
reason, advice = data.split('+', 1)
get_channel_callback('error')(reason, advice)
get_eventCallback('error')(reason, advice)
class _SocketIO(object):
'Low-level interface to remove cyclic references in child threads'
messageID = 0
self.callbackByMessageID = {}
self.callbackByEvent = {}
def __init__(self, host, port, secure, proxies):
baseURL = '%s:%d/socket.io/%s' % (host, port, PROTOCOL)
targetScheme = 'https' if secure else 'http'
targetURL = '%s://%s/' % (targetScheme, baseURL)
try:
response = urlopen(targetURL, proxies=proxies)
except IOError: # pragma: no cover
raise SocketIOError('Could not start connection')
if 200 != response.getcode(): # pragma: no cover
raise SocketIOError('Could not establish connection')
responseParts = response.readline().split(':')
sessionID = responseParts[0]
heartbeatTimeout = int(responseParts[1])
# connectionTimeout = int(responseParts[2])
supportedTransports = responseParts[3].split(',')
if 'websocket' not in supportedTransports:
raise SocketIOError('Could not parse handshake') # pragma: no cover
socketScheme = 'wss' if secure else 'ws'
socketURL = '%s://%s/websocket/%s' % (socketScheme, baseURL, sessionID)
self.connection = create_connection(socketURL)
self.heartbeatInterval = heartbeatTimeout - 2
def __del__(self):
self.connection.close()
def get_channel(self, channelPath):
def get_channel(self, channelName):
return self.channelByName[channelName]
pass
def recv_packet(self):
code, packetID, channelName, data = -1, None, None, None
try:
packet = self.connection.recv()
except WebSocketConnectionClosedException:
raise SocketIOConnectionError('Lost connection (Connection closed)')
except socket.timeout:
raise SocketIOConnectionError('Lost connection (Connection timed out)')
try:
packetParts = packet.split(':', 3)
except AttributeError:
raise SocketIOPacketError('Received invalid packet (%s)' % packet)
packetCount = len(packetParts)
if 4 == packetCount:
code, packetID, channelName, data = packetParts
elif 3 == packetCount:
code, packetID, channelName = packetParts
elif 1 == packetCount: # pragma: no cover
code = packetParts[0]
return int(code), packetID, channelName, data
def send_packet(self, code, channelName='', data='', callback=None):
callbackNumber = self.set_messageCallback(callback) if callback else ''
packetParts = [str(code), callbackNumber, channelName, data]
try:
self.connection.send(':'.join(packetParts))
except socket.error:
raise SocketIOPacketError('Could not send packet')
def disconnect(self, channelPath):
self.send_packet(0, channelPath)
def connect(self, channelPath):
self.send_packet(1, channelPath)
def send_heartbeat(self):
try:
self.send_packet(2)
except SocketIOPacketError:
print 'Could not send heartbeat'
pass
def set_messageCallback(self, callback):
'Set callback that will be called after receiving an acknowledgment'
self.messageID += 1
self.callbackByMessageID[self.messageID] = callback
return '%s+' % self.messageID
def get_messageCallback(self, messageID):
try:
callback = self.callbackByMessageID[messageID]
del self.callbackByMessageID[messageID]
return callback
except KeyError:
return
@property
def has_messageCallback(self):
return True if self.callbackByMessageID else False
class Channel(object):
def __init__(self, _socketIO, Namespace, channelPath):
self._socketIO = _socketIO
self._namespace = Namespace(_socketIO)
self._channelPath = channelPath
self._callbackByEvent = {}
def on(self, eventName, eventCallback):
self._callbackByEvent[eventName] = eventCallback
def message(self, messageData, messageCallback=None):
self._socketIO.message(messageData, messageCallback, channelPath=self._channelPath)
def emit(self, eventName, *eventArguments):
self._socketIO.emit(eventName, *eventArguments, channelPath=self._channelPath)
def get_namespace(self):
return self._namespace
def _get_eventCallback(self, eventName):
# Check callbacks defined by on()
try:
return self._callbackByEvent[eventName]
except KeyError:
pass
# Check callbacks defined explicitly or use on_()
defaultCallback = lambda *eventArguments: self.get_namespace().on_(eventName, *eventArguments)
return getattr(self, 'on_' + eventName.replace(' ', '_'), defaultCallback)
class SocketIOError(Exception):

View file

@ -13,18 +13,26 @@ class TestSocketIO(TestCase):
socketIO = SocketIO('localhost', 8000)
socketIO.disconnect()
self.assertEqual(socketIO.connected, False)
childThreads = [
socketIO._rhythmicThread,
socketIO._listenerThread,
]
for childThread in childThreads:
self.assertEqual(True, childThread.done.is_set())
def test_emit(self):
socketIO = SocketIO('localhost', 8000, Namespace)
socketIO = SocketIO('localhost', 8000)
socketIO.define(Namespace)
socketIO.emit('aaa')
sleep(0.5)
self.assertEqual(socketIO._namespace.payload, '')
self.assertEqual(socketIO.get_namespace().payload, '')
def test_emit_with_payload(self):
socketIO = SocketIO('localhost', 8000, Namespace)
socketIO = SocketIO('localhost', 8000)
socketIO.define(Namespace)
socketIO.emit('aaa', PAYLOAD)
sleep(0.5)
self.assertEqual(socketIO._namespace.payload, PAYLOAD)
self.assertEqual(socketIO.get_namespace().payload, PAYLOAD)
def test_emit_with_callback(self):
global ON_RESPONSE_CALLED
@ -44,20 +52,21 @@ class TestSocketIO(TestCase):
self.assertEqual(ON_RESPONSE_CALLED, True)
def test_channels(self):
mainSocket = SocketIO('localhost', 8000, Namespace)
chatSocket = mainSocket.connect('/chat', Namespace)
newsSocket = mainSocket.connect('/news', Namespace)
socketIO = SocketIO('localhost', 8000)
mainSocket = socketIO.define(Namespace)
chatSocket = socketIO.define(Namespace, '/chat')
newsSocket = socketIO.define(Namespace, '/news')
newsSocket.emit('aaa', PAYLOAD)
sleep(0.5)
self.assertNotEqual(mainSocket._namespace.payload, PAYLOAD)
self.assertNotEqual(chatSocket._namespace.payload, PAYLOAD)
self.assertEqual(newsSocket._namespace.payload, PAYLOAD)
self.assertNotEqual(mainSocket.get_namespace().payload, PAYLOAD)
self.assertNotEqual(chatSocket.get_namespace().payload, PAYLOAD)
self.assertEqual(newsSocket.get_namespace().payload, PAYLOAD)
def test_delete(self):
socketIO = SocketIO('localhost', 8000)
childThreads = [
socketIO._heartbeatThread,
socketIO._namespaceThread,
socketIO._rhythmicThread,
socketIO._listenerThread,
]
del socketIO
for childThread in childThreads: