diff --git a/parallel_vectorize.py b/parallel_vectorize.py index f26ad46..f70954b 100644 --- a/parallel_vectorize.py +++ b/parallel_vectorize.py @@ -16,6 +16,8 @@ from llvm.passes import * from llvm_cbuilder import * import llvm_cbuilder.shortnames as C +import sys + class WorkQueue(CStruct): '''structure for workqueue for parallel-ufunc. ''' @@ -365,3 +367,85 @@ class PThreadAPI(CExternal): pthread_join = Type.function(C.int, [C.void_p, C.void_p]) +class UFuncCoreGeneric(UFuncCore): + '''A generic ufunc core worker from LLVM function type + ''' + def _do_work(self, common, item, tid): + ufunc_type = Type.function(self.RETTY, self.ARGTYS) + ufunc_ptr = CFunc(self, common.func.cast(C.pointer(ufunc_type)).value) + + get_offset = lambda B, S, T: B[item * S].reference().cast(C.pointer(T)) + + indata = [] + for i, argty in enumerate(self.ARGTYS): + ptr = get_offset(common.args[i], common.steps[i], argty) + indata.append(ptr.load()) + + out_index = len(self.ARGTYS) + outptr = get_offset(common.args[out_index], common.steps[out_index], + self.RETTY) + + res = ufunc_ptr(*indata) + outptr.store(res) + + @classmethod + def specialize(cls, fntype): + '''specialize to a LLVM function type + + fntype : a LLVM function type (llvm.core.FunctionType) + ''' + cls._name_ = '.'.join([cls._name_] + + map(str, [fntype.return_type] + fntype.args)) + + cls.RETTY = fntype.return_type + cls.ARGTYS = tuple(fntype.args) + + +if sys.platform not in ['win32']: + class ParallelUFuncPlatform(ParallelUFunc, ParallelUFuncPosixMixin): + pass +else: + raise NotImplementedError("Threading for %s" % sys.platform) + + + +def parallel_vectorize_from_func(lfunc, engine=None): + fntype = lfunc.type.pointee + def_spuf = SpecializedParallelUFunc(ParallelUFuncPlatform(num_thread=2), + UFuncCoreGeneric(fntype), + CFuncRef(lfunc)) + spuf = def_spuf(lfunc.module) + if engine is None: + return spuf + else: + import numpy as np + + fptr = engine.get_pointer_to_function(spuf) + inct = len(fntype.args) + outct = 1 + + # TODO refactor + typemap = { + 'i8' : np.uint8, + 'i16' : np.uint16, + 'i32' : np.uint32, + 'i64' : np.uint64, + 'float' : np.float32, + 'double' : np.float64, + } + try: + ptr_t = long + except: + ptr_t = int + assert False, "Having check this yet" + + get_typenum = lambda T:np.dtype(typemap[str(T)]).num + assert fntype.return_type != C.void + tys = list(map(get_typenum, list(fntype.args) + [fntype.return_type])) + # Becareful that fromfunc does not provide full error checking yet. + # If typenum is out-of-bound, we have nasty memory corruptions. + # For instance, -1 for typenum will cause segfault. + # If elements of type-list (2nd arg) is tuple instead, + # there will also memory corruption. (Seems like code rewrite.) + return np.fromfunc([ptr_t(fptr)], [tys], inct, outct, [None]) +