Added test to check that child threads die when parent dies

This commit is contained in:
Roy Hyunjin Han 2013-02-09 19:12:21 -08:00
commit 0a5b069cdd
7 changed files with 193 additions and 137 deletions

View file

@ -1,3 +1,6 @@
0.4
---
0.3
---
- Added support for secure connections

View file

@ -35,10 +35,9 @@ Activate isolated environment. ::
Emit. ::
from socketIO_client import SocketIO
socketIO = SocketIO('localhost', 8000)
socketIO.emit('aaa', {'bbb': 'ccc'})
socketIO.wait(seconds=1) # Exit after one second
with SocketIO('localhost', 8000) as socketIO:
socketIO.emit('aaa')
socketIO.wait(1) # Wait a second
Emit with callback. ::
@ -47,9 +46,9 @@ Emit with callback. ::
def on_response(*args):
print args
socketIO = SocketIO('localhost', 8000)
socketIO.emit('aaa', {'bbb': 'ccc'}, on_response)
socketIO.wait(forCallbacks=True) # Exit after callbacks run
with SocketIO('localhost', 8000) as socketIO:
socketIO.emit('aaa', {'bbb': 'ccc'}, on_response)
socketIO.wait(seconds=1, forCallbacks=True) # Wait for callback
Define events. ::

View file

@ -1,5 +1,8 @@
Let user define a proxy #5
Let user emit without arguments #5
+ Fix unittests
+ Fix exceptions when websocket server disappears
Fix thread exceptions
Integrate Zac's fork #6
Integrate Sajal's fork #7
Integrate Francis's fork #10

17
serve_tests.py Executable file → Normal file
View file

@ -1,7 +1,14 @@
'Launch this server in another terminal window before running tests'
from socketio import socketio_manage
from socketio.namespace import BaseNamespace
from socketio.server import SocketIOServer
import sys
try:
from socketio import socketio_manage
from socketio.namespace import BaseNamespace
from socketio.server import SocketIOServer
except ImportError:
from setuptools.command import easy_install
easy_install.main(['-U', 'gevent-socketio'])
print('\nPlease run the script again to launch the test server.')
sys.exit(1)
class Namespace(BaseNamespace):
@ -25,5 +32,7 @@ class Application(object):
if __name__ == '__main__':
socketIOServer = SocketIOServer(('0.0.0.0', 8000), Application())
port = 8000
print 'Starting server at port %s' % port
socketIOServer = SocketIOServer(('0.0.0.0', port), Application())
socketIOServer.serve_forever()

1
setup.py Executable file → Normal file
View file

@ -24,7 +24,6 @@ setup(
url='https://github.com/invisibleroads/socketIO-client',
install_requires=[
'anyjson',
'gevent-socketio',
'websocket-client',
],
packages=find_packages(),

View file

@ -1,11 +1,16 @@
import websocket
import sys
import traceback
import socket
from anyjson import dumps, loads
from functools import partial
from threading import Thread, Event
from time import sleep
from urllib import urlopen
from websocket import WebSocketConnectionClosedException, create_connection
__version__ = '0.3'
__version__ = '0.4'
PROTOCOL = 1 # SocketIO protocol version
@ -16,7 +21,7 @@ class BaseNamespace(object): # pragma: no cover
def __init__(self, socketIO):
self.socketIO = socketIO
def on_connect(self, socketIO):
def on_connect(self):
pass
def on_disconnect(self):
@ -46,54 +51,69 @@ class BaseNamespace(object): # pragma: no cover
class SocketIO(object):
messageID = 0
_messageID = 0
def __init__(self, host, port, Namespace=BaseNamespace, secure=False):
self.host = host
self.port = int(port)
self.namespace = Namespace(self)
self.secure = secure
self.__connect()
def __init__(self, host, port, Namespace=BaseNamespace, secure=False, proxies=None):
self._host = host
self._port = int(port)
self._namespace = Namespace(self)
self._secure = secure
self._proxies = proxies
self._connect()
heartbeatInterval = self.heartbeatTimeout - 2
self.heartbeatThread = RhythmicThread(heartbeatInterval,
self._send_heartbeat)
self.heartbeatThread.start()
heartbeatInterval = self._heartbeatTimeout - 2
self._heartbeatThread = RhythmicThread(heartbeatInterval, self._send_heartbeat)
self._heartbeatThread.start()
self.channelByName = {}
self.callbackByEvent = {}
self.namespaceThread = ListenerThread(self)
self.namespaceThread.start()
self._channelByName = {}
self._callbackByEvent = {}
self._namespaceThread = ListenerThread(self._recv_packet, self._get_callback)
self._namespaceThread.start()
def __del__(self): # pragma: no cover
self.heartbeatThread.cancel()
self.namespaceThread.cancel()
self.connection.close()
def __enter__(self):
return self
def __connect(self):
baseURL = '%s:%d/socket.io/%s' % (self.host, self.port, PROTOCOL)
def __exit__(self, exc_type, exc_value, traceback):
self.__del__()
def __del__(self):
self._heartbeatThread.cancel()
self._namespaceThread.cancel()
self._connection.close()
def _connect(self):
baseURL = '%s:%d/socket.io/%s' % (self._host, self._port, PROTOCOL)
try:
response = urlopen('%s://%s/' % (
'https' if self.secure else 'http', baseURL))
'https' if self._secure else 'http', baseURL),
proxies=self._proxies)
except IOError: # pragma: no cover
raise SocketIOError('Could not start connection')
if 200 != response.getcode(): # pragma: no cover
raise SocketIOError('Could not establish connection')
responseParts = response.readline().split(':')
self.sessionID = responseParts[0]
self.heartbeatTimeout = int(responseParts[1])
self.connectionTimeout = int(responseParts[2])
self.supportedTransports = responseParts[3].split(',')
if 'websocket' not in self.supportedTransports:
self._sessionID = responseParts[0]
self._heartbeatTimeout = int(responseParts[1])
self._connectionTimeout = int(responseParts[2])
self._supportedTransports = responseParts[3].split(',')
if 'websocket' not in self._supportedTransports:
raise SocketIOError('Could not parse handshake') # pragma: no cover
socketURL = '%s://%s/websocket/%s' % (
'wss' if self.secure else 'ws', baseURL, self.sessionID)
self.connection = websocket.create_connection(socketURL)
'wss' if self._secure else 'ws', baseURL, self._sessionID)
self._connection = create_connection(socketURL)
def _recv_packet(self):
code, packetID, channelName, data = -1, None, None, None
packet = self.connection.recv()
packetParts = packet.split(':', 3)
try:
packet = self._connection.recv()
except WebSocketConnectionClosedException:
raise SocketIOConnectionError('Lost connection (Connection closed)')
except socket.timeout:
raise SocketIOConnectionError('Lost connection (Connection timed out)')
try:
packetParts = packet.split(':', 3)
except AttributeError:
raise SocketIOPacketError('Received invalid packet (%s)' % packet)
packetCount = len(packetParts)
if 4 == packetCount:
code, packetID, channelName, data = packetParts
@ -104,34 +124,36 @@ class SocketIO(object):
return int(code), packetID, channelName, data
def _send_packet(self, code, channelName='', data='', callback=None):
self.connection.send(':'.join([
str(code),
self.set_callback(callback) if callback else '',
channelName,
data]))
callbackNumber = self._set_callback(callback) if callback else ''
packetParts = [str(code), callbackNumber, channelName, data]
try:
self._connection.send(':'.join(packetParts))
except socket.error:
raise SocketIOPacketError('Could not send packet')
def disconnect(self, channelName=''):
self._send_packet(0, channelName)
if channelName:
del self.channelByName[channelName]
del self._channelByName[channelName]
else:
self.__del__()
@property
def connected(self):
return self.connection.connected
return self._connection.connected
def connect(self, channelName, Namespace=BaseNamespace):
channel = Channel(self, channelName, Namespace)
self.channelByName[channelName] = channel
self._channelByName[channelName] = channel
self._send_packet(1, channelName)
return channel
def _send_heartbeat(self):
try:
self._send_packet(2)
except:
self.__del__()
except SocketIOPacketError:
print 'Could not send heartbeat'
pass
def message(self, messageData, callback=None, channelName=''):
if isinstance(messageData, basestring):
@ -144,40 +166,39 @@ class SocketIO(object):
def emit(self, eventName, *eventArguments, **eventKeywords):
code = 5
if callable(eventArguments[-1]):
callback = None
if eventArguments and callable(eventArguments[-1]):
callback = eventArguments[-1]
eventArguments = eventArguments[:-1]
else:
callback = None
channelName = eventKeywords.get('channelName', '')
data = dumps(dict(name=eventName, args=eventArguments))
self._send_packet(code, channelName, data, callback)
def get_callback(self, channelName, eventName):
def _get_callback(self, channelName, eventName):
'Get callback associated with channelName and eventName'
socketIO = self.channelByName[channelName] if channelName else self
socketIO = self._channelByName[channelName] if channelName else self
try:
return socketIO.callbackByEvent[eventName]
return socketIO._callbackByEvent[eventName]
except KeyError:
pass
namespace = socketIO.namespace
def callback_(*eventArguments):
return namespace.on_(eventName, *eventArguments)
return getattr(namespace, name_callback(eventName), callback_)
return socketIO._namespace.on_(eventName, *eventArguments)
callbackName = 'on_' + eventName.replace(' ', '_')
return getattr(socketIO._namespace, callbackName, callback_)
def set_callback(self, callback):
def _set_callback(self, callback):
'Set callback that will be called after receiving an acknowledgment'
self.messageID += 1
self.namespaceThread.set_callback(self.messageID, callback)
return '%s+' % self.messageID
self._messageID += 1
self._namespaceThread.set_callback(self._messageID, callback)
return '%s+' % self._messageID
def on(self, eventName, callback):
self.callbackByEvent[eventName] = callback
self._callbackByEvent[eventName] = callback
def wait(self, seconds=None, forCallbacks=False):
if forCallbacks:
self.namespaceThread.wait_for_callbacks(seconds)
self._namespaceThread.wait_for_callbacks(seconds)
elif seconds:
sleep(seconds)
else:
@ -191,24 +212,22 @@ class SocketIO(object):
class Channel(object):
def __init__(self, socketIO, channelName, Namespace):
self.socketIO = socketIO
self.channelName = channelName
self.namespace = Namespace(self)
self.callbackByEvent = {}
self._socketIO = socketIO
self._channelName = channelName
self._namespace = Namespace(self)
self._callbackByEvent = {}
def disconnect(self):
self.socketIO.disconnect(self.channelName)
self._socketIO.disconnect(self._channelName)
def emit(self, eventName, *eventArguments):
self.socketIO.emit(eventName, *eventArguments,
channelName=self.channelName)
self._socketIO.emit(eventName, *eventArguments, channelName=self._channelName)
def message(self, messageData, callback=None):
self.socketIO.message(messageData, callback,
channelName=self.channelName)
self._socketIO.message(messageData, callback, channelName=self._channelName)
def on(self, eventName, eventCallback):
self.callbackByEvent[eventName] = eventCallback
self._callbackByEvent[eventName] = eventCallback
class ListenerThread(Thread):
@ -216,34 +235,43 @@ class ListenerThread(Thread):
daemon = True
def __init__(self, socketIO):
def __init__(self, recv_packet, get_callback):
super(ListenerThread, self).__init__()
self.socketIO = socketIO
self.done = Event()
self.waitingForCallbacks = Event()
self.callbackByMessageID = {}
self.get_callback = self.socketIO.get_callback
self.recv_packet = recv_packet
self.get_callback = get_callback
def run(self):
while not self.done.is_set():
try:
code, packetID, channelName, data = self.socketIO._recv_packet()
except:
continue
try:
delegate = {
0: self.on_disconnect,
1: self.on_connect,
2: self.on_heartbeat,
3: self.on_message,
4: self.on_json,
5: self.on_event,
6: self.on_acknowledgment,
7: self.on_error,
}[code]
except KeyError:
continue
delegate(packetID, channelName, data)
try:
while not self.done.is_set():
try:
code, packetID, channelName, data = self.recv_packet()
except SocketIOConnectionError, error:
print error
return
except SocketIOPacketError, error:
print error
continue
get_channel_callback = partial(self.get_callback, channelName)
try:
delegate = {
0: self.on_disconnect,
1: self.on_connect,
2: self.on_heartbeat,
3: self.on_message,
4: self.on_json,
5: self.on_event,
6: self.on_acknowledgment,
7: self.on_error,
}[code]
except KeyError:
continue
delegate(packetID, get_channel_callback, data)
except:
exc_type, exc_value, exc_traceback = sys.exc_info()
open('tracebacks.log', 'a+t').write('\n'.join(traceback.format_tb(exc_traceback)))
def cancel(self):
self.done.set()
@ -255,33 +283,28 @@ class ListenerThread(Thread):
def set_callback(self, messageID, callback):
self.callbackByMessageID[messageID] = callback
def on_disconnect(self, packetID, channelName, data):
callback = self.get_callback(channelName, 'disconnect')
callback()
def on_disconnect(self, packetID, get_channel_callback, data):
get_channel_callback('disconnect')()
def on_connect(self, packetID, channelName, data):
callback = self.get_callback(channelName, 'connect')
callback(self.socketIO)
def on_connect(self, packetID, get_channel_callback, data):
get_channel_callback('connect')()
def on_heartbeat(self, packetID, channelName, data):
def on_heartbeat(self, packetID, get_channel_callback, data):
pass
def on_message(self, packetID, channelName, data):
callback = self.get_callback(channelName, 'message')
callback(data)
def on_message(self, packetID, get_channel_callback, data):
get_channel_callback('message')(data)
def on_json(self, packetID, channelName, data):
callback = self.get_callback(channelName, 'message')
callback(loads(data))
def on_json(self, packetID, get_channel_callback, data):
get_channel_callback('message')(loads(data))
def on_event(self, packetID, channelName, data):
def on_event(self, packetID, get_channel_callback, data):
valueByName = loads(data)
eventName = valueByName['name']
eventArguments = valueByName['args']
callback = self.get_callback(channelName, eventName)
callback(*eventArguments)
get_channel_callback(eventName)(*eventArguments)
def on_acknowledgment(self, packetID, channelName, data):
def on_acknowledgment(self, packetID, get_channel_callback, data):
dataParts = data.split('+', 1)
messageID = int(dataParts[0])
arguments = loads(dataParts[1]) or []
@ -296,21 +319,20 @@ class ListenerThread(Thread):
if self.waitingForCallbacks.is_set() and not callbackCount:
self.cancel()
def on_error(self, packetID, channelName, data):
def on_error(self, packetID, get_channel_callback, data):
reason, advice = data.split('+', 1)
callback = self.get_callback(channelName, 'error')
callback(reason, advice)
get_channel_callback('error')(reason, advice)
class RhythmicThread(Thread):
'Execute rhythmicFunction every few seconds'
'Execute call every few seconds'
daemon = True
def __init__(self, intervalInSeconds, rhythmicFunction, *args, **kw):
def __init__(self, intervalInSeconds, call, *args, **kw):
super(RhythmicThread, self).__init__()
self.intervalInSeconds = intervalInSeconds
self.rhythmicFunction = rhythmicFunction
self.call = call
self.args = args
self.kw = kw
self.done = Event()
@ -318,10 +340,11 @@ class RhythmicThread(Thread):
def run(self):
try:
while not self.done.is_set():
self.rhythmicFunction(*self.args, **self.kw)
self.call(*self.args, **self.kw)
self.done.wait(self.intervalInSeconds)
except:
pass
exc_type, exc_value, exc_traceback = sys.exc_info()
open('tracebacks.log', 'a+t').write('\n'.join(traceback.format_tb(exc_traceback)))
def cancel(self):
self.done.set()
@ -331,5 +354,9 @@ class SocketIOError(Exception):
pass
def name_callback(eventName):
return 'on_' + eventName.replace(' ', '_')
class SocketIOConnectionError(SocketIOError):
pass
class SocketIOPacketError(SocketIOError):
pass

View file

@ -15,10 +15,16 @@ class TestSocketIO(TestCase):
self.assertEqual(socketIO.connected, False)
def test_emit(self):
socketIO = SocketIO('localhost', 8000, Namespace)
socketIO.emit('aaa')
sleep(0.5)
self.assertEqual(socketIO._namespace.payload, '')
def test_emit_with_payload(self):
socketIO = SocketIO('localhost', 8000, Namespace)
socketIO.emit('aaa', PAYLOAD)
sleep(0.5)
self.assertEqual(socketIO.namespace.payload, PAYLOAD)
self.assertEqual(socketIO._namespace.payload, PAYLOAD)
def test_emit_with_callback(self):
global ON_RESPONSE_CALLED
@ -43,16 +49,26 @@ class TestSocketIO(TestCase):
newsSocket = mainSocket.connect('/news', Namespace)
newsSocket.emit('aaa', PAYLOAD)
sleep(0.5)
self.assertNotEqual(mainSocket.namespace.payload, PAYLOAD)
self.assertNotEqual(chatSocket.namespace.payload, PAYLOAD)
self.assertEqual(newsSocket.namespace.payload, PAYLOAD)
self.assertNotEqual(mainSocket._namespace.payload, PAYLOAD)
self.assertNotEqual(chatSocket._namespace.payload, PAYLOAD)
self.assertEqual(newsSocket._namespace.payload, PAYLOAD)
def test_delete(self):
socketIO = SocketIO('localhost', 8000)
childThreads = [
socketIO._heartbeatThread,
socketIO._namespaceThread,
]
del socketIO
for childThread in childThreads:
self.assertEqual(True, childThread.done.is_set())
class Namespace(BaseNamespace):
payload = None
def on_ddd(self, data):
def on_ddd(self, data=''):
self.payload = data