add parallel_vectorize_from_func
This commit is contained in:
parent
e35b8ca636
commit
b23d7909fb
1 changed files with 84 additions and 0 deletions
|
|
@ -16,6 +16,8 @@ from llvm.passes import *
|
|||
from llvm_cbuilder import *
|
||||
import llvm_cbuilder.shortnames as C
|
||||
|
||||
import sys
|
||||
|
||||
class WorkQueue(CStruct):
|
||||
'''structure for workqueue for parallel-ufunc.
|
||||
'''
|
||||
|
|
@ -365,3 +367,85 @@ class PThreadAPI(CExternal):
|
|||
pthread_join = Type.function(C.int, [C.void_p, C.void_p])
|
||||
|
||||
|
||||
class UFuncCoreGeneric(UFuncCore):
|
||||
'''A generic ufunc core worker from LLVM function type
|
||||
'''
|
||||
def _do_work(self, common, item, tid):
|
||||
ufunc_type = Type.function(self.RETTY, self.ARGTYS)
|
||||
ufunc_ptr = CFunc(self, common.func.cast(C.pointer(ufunc_type)).value)
|
||||
|
||||
get_offset = lambda B, S, T: B[item * S].reference().cast(C.pointer(T))
|
||||
|
||||
indata = []
|
||||
for i, argty in enumerate(self.ARGTYS):
|
||||
ptr = get_offset(common.args[i], common.steps[i], argty)
|
||||
indata.append(ptr.load())
|
||||
|
||||
out_index = len(self.ARGTYS)
|
||||
outptr = get_offset(common.args[out_index], common.steps[out_index],
|
||||
self.RETTY)
|
||||
|
||||
res = ufunc_ptr(*indata)
|
||||
outptr.store(res)
|
||||
|
||||
@classmethod
|
||||
def specialize(cls, fntype):
|
||||
'''specialize to a LLVM function type
|
||||
|
||||
fntype : a LLVM function type (llvm.core.FunctionType)
|
||||
'''
|
||||
cls._name_ = '.'.join([cls._name_] +
|
||||
map(str, [fntype.return_type] + fntype.args))
|
||||
|
||||
cls.RETTY = fntype.return_type
|
||||
cls.ARGTYS = tuple(fntype.args)
|
||||
|
||||
|
||||
if sys.platform not in ['win32']:
|
||||
class ParallelUFuncPlatform(ParallelUFunc, ParallelUFuncPosixMixin):
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError("Threading for %s" % sys.platform)
|
||||
|
||||
|
||||
|
||||
def parallel_vectorize_from_func(lfunc, engine=None):
|
||||
fntype = lfunc.type.pointee
|
||||
def_spuf = SpecializedParallelUFunc(ParallelUFuncPlatform(num_thread=2),
|
||||
UFuncCoreGeneric(fntype),
|
||||
CFuncRef(lfunc))
|
||||
spuf = def_spuf(lfunc.module)
|
||||
if engine is None:
|
||||
return spuf
|
||||
else:
|
||||
import numpy as np
|
||||
|
||||
fptr = engine.get_pointer_to_function(spuf)
|
||||
inct = len(fntype.args)
|
||||
outct = 1
|
||||
|
||||
# TODO refactor
|
||||
typemap = {
|
||||
'i8' : np.uint8,
|
||||
'i16' : np.uint16,
|
||||
'i32' : np.uint32,
|
||||
'i64' : np.uint64,
|
||||
'float' : np.float32,
|
||||
'double' : np.float64,
|
||||
}
|
||||
try:
|
||||
ptr_t = long
|
||||
except:
|
||||
ptr_t = int
|
||||
assert False, "Having check this yet"
|
||||
|
||||
get_typenum = lambda T:np.dtype(typemap[str(T)]).num
|
||||
assert fntype.return_type != C.void
|
||||
tys = list(map(get_typenum, list(fntype.args) + [fntype.return_type]))
|
||||
# Becareful that fromfunc does not provide full error checking yet.
|
||||
# If typenum is out-of-bound, we have nasty memory corruptions.
|
||||
# For instance, -1 for typenum will cause segfault.
|
||||
# If elements of type-list (2nd arg) is tuple instead,
|
||||
# there will also memory corruption. (Seems like code rewrite.)
|
||||
return np.fromfunc([ptr_t(fptr)], [tys], inct, outct, [None])
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue