add parallel_vectorize_from_func

This commit is contained in:
Siu Kwan Lam 2012-08-08 17:33:00 -07:00
commit b23d7909fb

View file

@ -16,6 +16,8 @@ from llvm.passes import *
from llvm_cbuilder import *
import llvm_cbuilder.shortnames as C
import sys
class WorkQueue(CStruct):
'''structure for workqueue for parallel-ufunc.
'''
@ -365,3 +367,85 @@ class PThreadAPI(CExternal):
pthread_join = Type.function(C.int, [C.void_p, C.void_p])
class UFuncCoreGeneric(UFuncCore):
    '''A generic ufunc core worker from LLVM function type.

    `specialize` records the return type (`RETTY`) and the argument types
    (`ARGTYS`) of the wrapped function on the class; `_do_work` reads them
    to process a single work item.
    '''

    def _do_work(self, common, item, tid):
        '''Apply the wrapped function to one work item.

        common : shared state; this method reads `common.func`,
                 `common.args` and `common.steps`.
        item   : index used to locate the element in every argument.
        tid    : not used by this implementation.
        '''
        # View common.func as a pointer to the concrete function signature
        # so that it can be called with typed arguments.
        ufunc_type = Type.function(self.RETTY, self.ARGTYS)
        ufunc_ptr = CFunc(self, common.func.cast(C.pointer(ufunc_type)).value)

        def get_offset(base, step, ty):
            # Address of the element at `item * step` in `base`, viewed as a
            # pointer to `ty`.
            # NOTE(review): presumably `step` is a byte stride -- confirm.
            return base[item * step].reference().cast(C.pointer(ty))

        # Load one input value per declared argument type, in order.
        indata = [
            get_offset(common.args[i], common.steps[i], argty).load()
            for i, argty in enumerate(self.ARGTYS)
        ]

        # The output slot follows the inputs in common.args / common.steps.
        out_index = len(self.ARGTYS)
        outptr = get_offset(common.args[out_index], common.steps[out_index],
                            self.RETTY)

        res = ufunc_ptr(*indata)
        outptr.store(res)

    @classmethod
    def specialize(cls, fntype):
        '''specialize to a LLVM function type

        fntype : a LLVM function type (llvm.core.FunctionType)

        Appends the string form of the return type and of every argument
        type to `cls._name_`, and stores the types as `cls.RETTY` and
        `cls.ARGTYS`.
        '''
        # Build real lists: `list + map(...)` and `list + tuple` both raise
        # TypeError when map() yields an iterator or args is not a list.
        signature = [fntype.return_type] + list(fntype.args)
        cls._name_ = '.'.join([cls._name_] + [str(ty) for ty in signature])
        cls.RETTY = fntype.return_type
        cls.ARGTYS = tuple(fntype.args)
# Only the POSIX threading mixin is wired up here; refuse to load on win32.
if sys.platform in ['win32']:
    raise NotImplementedError("Threading for %s" % sys.platform)


class ParallelUFuncPlatform(ParallelUFunc, ParallelUFuncPosixMixin):
    # Platform-specific parallel ufunc: ParallelUFunc combined with the
    # POSIX mixin. No additional members.
    pass
def parallel_vectorize_from_func(lfunc, engine=None):
    '''Build a parallel ufunc around a LLVM function.

    lfunc  : the LLVM function to wrap; its signature is taken from
             `lfunc.type.pointee` and the wrapper is defined in
             `lfunc.module`.
    engine : optional execution engine. When None, the generated wrapper
             function is returned as is. Otherwise the wrapper is resolved
             to a function pointer through
             `engine.get_pointer_to_function` and handed to `np.fromfunc`,
             whose result is returned.

    Raises KeyError when an argument or return type has no entry in the
    LLVM-to-NumPy type map (only reached when `engine` is given).
    '''
    fntype = lfunc.type.pointee
    def_spuf = SpecializedParallelUFunc(ParallelUFuncPlatform(num_thread=2),
                                        UFuncCoreGeneric(fntype),
                                        CFuncRef(lfunc))
    spuf = def_spuf(lfunc.module)

    if engine is None:
        return spuf

    # NumPy is only needed when a callable ufunc is requested.
    import numpy as np

    fptr = engine.get_pointer_to_function(spuf)
    inct = len(fntype.args)
    outct = 1

    # TODO refactor
    # Keys are the string form of the LLVM types.
    typemap = {
        'i8': np.uint8,
        'i16': np.uint16,
        'i32': np.uint32,
        'i64': np.uint64,
        'float': np.float32,
        'double': np.float64,
    }

    # The pointer is passed on as a Python integer. `long` exists only on
    # Python 2; where it is missing, `int` is already arbitrary precision.
    try:
        ptr_t = long
    except NameError:
        ptr_t = int

    assert fntype.return_type != C.void

    # NumPy type numbers for every input followed by the single output.
    tys = [np.dtype(typemap[str(ty)]).num
           for ty in list(fntype.args) + [fntype.return_type]]

    # Be careful: fromfunc does not provide full error checking yet.
    # If a typenum is out of bounds, we get nasty memory corruption.
    # For instance, -1 for a typenum will cause a segfault.
    # If the elements of the type list (2nd arg) are tuples instead of
    # lists, there will also be memory corruption.
    # NOTE(review): `fromfunc` is not part of the stock NumPy public API as
    # far as I know -- confirm which build/module provides it.
    return np.fromfunc([ptr_t(fptr)], [tys], inct, outct, [None])