# # Copyright (c) 2008-10, Mahadevan R All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # * Redistributions of source code must retain the above copyright notice, # this list of conditions and the following disclaimer. # # * Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # * Neither the name of this software, nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # from io import BytesIO try: from StringIO import StringIO except ImportError: try: from cStringIO import StringIO except ImportError: from io import StringIO import contextlib, weakref import llvm from llvm._intrinsic_ids import * from llvmpy import api #===----------------------------------------------------------------------=== # Enumerations #===----------------------------------------------------------------------=== class Enum(int): '''Overload integer to print the name of the enum. ''' def __repr__(self): return '%s(%d)' % (type(self).__name__, self) @classmethod def declare(cls): declared = cls._declared_ = {} scope = globals() for name in filter(lambda s: s.startswith(cls.prefix), dir(cls)): n = getattr(cls, name) typ = type(name, (cls,), {}) obj = typ(n) declared[n] = obj scope[name] = obj @classmethod def get(cls, num): return cls._declared_[num] # type id (llvm::Type::TypeID) class TypeEnum(Enum): prefix = 'TYPE_' TypeID = api.llvm.Type.TypeID TYPE_VOID = TypeID.VoidTyID TYPE_HALF = TypeID.HalfTyID TYPE_FLOAT = TypeID.FloatTyID TYPE_DOUBLE = TypeID.DoubleTyID TYPE_X86_FP80 = TypeID.X86_FP80TyID TYPE_FP128 = TypeID.FP128TyID TYPE_PPC_FP128 = TypeID.PPC_FP128TyID TYPE_LABEL = TypeID.LabelTyID TYPE_INTEGER = TypeID.IntegerTyID TYPE_FUNCTION = TypeID.FunctionTyID TYPE_STRUCT = TypeID.StructTyID TYPE_ARRAY = TypeID.ArrayTyID TYPE_POINTER = TypeID.PointerTyID TYPE_VECTOR = TypeID.VectorTyID TYPE_METADATA = TypeID.MetadataTyID TYPE_X86_MMX = TypeID.X86_MMXTyID TypeEnum.declare() # value IDs (llvm::Value::ValueTy enum) # According to the doxygen docs, it is not a good idea to use these enums. # There are more values than those declared. class ValueEnum(Enum): prefix = 'VALUE_' ValueTy = api.llvm.Value.ValueTy VALUE_ARGUMENT = ValueTy.ArgumentVal VALUE_BASIC_BLOCK = ValueTy.BasicBlockVal VALUE_FUNCTION = ValueTy.FunctionVal VALUE_GLOBAL_ALIAS = ValueTy.GlobalAliasVal VALUE_GLOBAL_VARIABLE = ValueTy.GlobalVariableVal VALUE_UNDEF_VALUE = ValueTy.UndefValueVal VALUE_BLOCK_ADDRESS = ValueTy.BlockAddressVal VALUE_CONSTANT_EXPR = ValueTy.ConstantExprVal VALUE_CONSTANT_AGGREGATE_ZERO = ValueTy.ConstantAggregateZeroVal VALUE_CONSTANT_DATA_ARRAY = ValueTy.ConstantDataArrayVal VALUE_CONSTANT_DATA_VECTOR = ValueTy.ConstantDataVectorVal VALUE_CONSTANT_INT = ValueTy.ConstantIntVal VALUE_CONSTANT_FP = ValueTy.ConstantFPVal VALUE_CONSTANT_ARRAY = ValueTy.ConstantArrayVal VALUE_CONSTANT_STRUCT = ValueTy.ConstantStructVal VALUE_CONSTANT_VECTOR = ValueTy.ConstantVectorVal VALUE_CONSTANT_POINTER_NULL = ValueTy.ConstantPointerNullVal VALUE_MD_NODE = ValueTy.MDNodeVal VALUE_MD_STRING = ValueTy.MDStringVal VALUE_INLINE_ASM = ValueTy.InlineAsmVal VALUE_PSEUDO_SOURCE_VALUE = ValueTy.PseudoSourceValueVal VALUE_FIXED_STACK_PSEUDO_SOURCE_VALUE = ValueTy.FixedStackPseudoSourceValueVal VALUE_INSTRUCTION = ValueTy.InstructionVal ValueEnum.declare() # instruction opcodes (from include/llvm/Instruction.def) class OpcodeEnum(Enum): prefix = 'OPCODE_' OPCODE_RET = 1 OPCODE_BR = 2 OPCODE_SWITCH = 3 OPCODE_INDIRECT_BR = 4 OPCODE_INVOKE = 5 OPCODE_RESUME = 6 OPCODE_UNREACHABLE = 7 OPCODE_ADD = 8 OPCODE_FADD = 9 OPCODE_SUB = 10 OPCODE_FSUB = 11 OPCODE_MUL = 12 OPCODE_FMUL = 13 OPCODE_UDIV = 14 OPCODE_SDIV = 15 OPCODE_FDIV = 16 OPCODE_UREM = 17 OPCODE_SREM = 18 OPCODE_FREM = 19 OPCODE_SHL = 20 OPCODE_LSHR = 21 OPCODE_ASHR = 22 OPCODE_AND = 23 OPCODE_OR = 24 OPCODE_XOR = 25 OPCODE_ALLOCA = 26 OPCODE_LOAD = 27 OPCODE_STORE = 28 OPCODE_GETELEMENTPTR = 29 OPCODE_FENCE = 30 OPCODE_ATOMICCMPXCHG = 31 OPCODE_ATOMICRMW = 32 OPCODE_TRUNC = 33 OPCODE_ZEXT = 34 OPCODE_SEXT = 35 OPCODE_FPTOUI = 36 OPCODE_FPTOSI = 37 OPCODE_UITOFP = 38 OPCODE_SITOFP = 39 OPCODE_FPTRUNC = 40 OPCODE_FPEXT = 41 OPCODE_PTRTOINT = 42 OPCODE_INTTOPTR = 43 OPCODE_BITCAST = 44 OPCODE_ICMP = 45 OPCODE_FCMP = 46 OPCODE_PHI = 47 OPCODE_CALL = 48 OPCODE_SELECT = 49 OPCODE_USEROP1 = 50 OPCODE_USEROP2 = 51 OPCODE_VAARG = 52 OPCODE_EXTRACTELEMENT = 53 OPCODE_INSERTELEMENT = 54 OPCODE_SHUFFLEVECTOR = 55 OPCODE_EXTRACTVALUE = 56 OPCODE_INSERTVALUE = 57 OPCODE_LANDINGPAD = 58 OpcodeEnum.declare() # calling conventions class CCEnum(Enum): prefix = 'CC_' ID = api.llvm.CallingConv.ID CC_C = ID.C CC_FASTCALL = ID.Fast CC_COLDCALL = ID.Cold CC_GHC = ID.GHC CC_X86_STDCALL = ID.X86_StdCall CC_X86_FASTCALL = ID.X86_FastCall CC_ARM_APCS = ID.ARM_APCS CC_ARM_AAPCS = ID.ARM_AAPCS CC_ARM_AAPCS_VFP = ID.ARM_AAPCS_VFP CC_MSP430_INTR = ID.MSP430_INTR CC_X86_THISCALL = ID.X86_ThisCall CC_PTX_KERNEL = ID.PTX_Kernel CC_PTX_DEVICE = ID.PTX_Device CC_MBLAZE_INTR = ID.MBLAZE_INTR CC_MBLAZE_SVOL = ID.MBLAZE_SVOL CCEnum.declare() # int predicates class ICMPEnum(Enum): prefix = 'ICMP_' Predicate = api.llvm.CmpInst.Predicate ICMP_EQ = Predicate.ICMP_EQ ICMP_NE = Predicate.ICMP_NE ICMP_UGT = Predicate.ICMP_UGT ICMP_UGE = Predicate.ICMP_UGE ICMP_ULT = Predicate.ICMP_ULT ICMP_ULE = Predicate.ICMP_ULE ICMP_SGT = Predicate.ICMP_SGT ICMP_SGE = Predicate.ICMP_SGE ICMP_SLT = Predicate.ICMP_SLT ICMP_SLE = Predicate.ICMP_SLE ICMPEnum.declare() # same as ICMP_xx, for backward compatibility IPRED_EQ = ICMP_EQ IPRED_NE = ICMP_NE IPRED_UGT = ICMP_UGT IPRED_UGE = ICMP_UGE IPRED_ULT = ICMP_ULT IPRED_ULE = ICMP_ULE IPRED_SGT = ICMP_SGT IPRED_SGE = ICMP_SGE IPRED_SLT = ICMP_SLT IPRED_SLE = ICMP_SLE # real predicates class FCMPEnum(Enum): prefix = 'FCMP_' Predicate = api.llvm.CmpInst.Predicate FCMP_FALSE = Predicate.FCMP_FALSE FCMP_OEQ = Predicate.FCMP_OEQ FCMP_OGT = Predicate.FCMP_OGT FCMP_OGE = Predicate.FCMP_OGE FCMP_OLT = Predicate.FCMP_OLT FCMP_OLE = Predicate.FCMP_OLE FCMP_ONE = Predicate.FCMP_ONE FCMP_ORD = Predicate.FCMP_ORD FCMP_UNO = Predicate.FCMP_UNO FCMP_UEQ = Predicate.FCMP_UEQ FCMP_UGT = Predicate.FCMP_UGT FCMP_UGE = Predicate.FCMP_UGE FCMP_ULT = Predicate.FCMP_ULT FCMP_ULE = Predicate.FCMP_ULE FCMP_UNE = Predicate.FCMP_UNE FCMP_TRUE = Predicate.FCMP_TRUE FCMPEnum.declare() # real predicates RPRED_FALSE = FCMP_FALSE RPRED_OEQ = FCMP_OEQ RPRED_OGT = FCMP_OGT RPRED_OGE = FCMP_OGE RPRED_OLT = FCMP_OLT RPRED_OLE = FCMP_OLE RPRED_ONE = FCMP_ONE RPRED_ORD = FCMP_ORD RPRED_UNO = FCMP_UNO RPRED_UEQ = FCMP_UEQ RPRED_UGT = FCMP_UGT RPRED_UGE = FCMP_UGE RPRED_ULT = FCMP_ULT RPRED_ULE = FCMP_ULE RPRED_UNE = FCMP_UNE RPRED_TRUE = FCMP_TRUE # linkages (see llvm::GlobalValue::LinkageTypes) class LinkageEnum(Enum): prefix = 'LINKAGE_' LinkageTypes = api.llvm.GlobalValue.LinkageTypes LINKAGE_EXTERNAL = LinkageTypes.ExternalLinkage LINKAGE_AVAILABLE_EXTERNALLY = LinkageTypes.AvailableExternallyLinkage LINKAGE_LINKONCE_ANY = LinkageTypes.LinkOnceAnyLinkage LINKAGE_LINKONCE_ODR = LinkageTypes.LinkOnceODRLinkage LINKAGE_WEAK_ANY = LinkageTypes.WeakAnyLinkage LINKAGE_WEAK_ODR = LinkageTypes.WeakODRLinkage LINKAGE_APPENDING = LinkageTypes.AppendingLinkage LINKAGE_INTERNAL = LinkageTypes.InternalLinkage LINKAGE_PRIVATE = LinkageTypes.PrivateLinkage LINKAGE_DLLIMPORT = LinkageTypes.DLLImportLinkage LINKAGE_DLLEXPORT = LinkageTypes.DLLExportLinkage LINKAGE_EXTERNAL_WEAK = LinkageTypes.ExternalWeakLinkage LINKAGE_COMMON = LinkageTypes.CommonLinkage LINKAGE_LINKER_PRIVATE = LinkageTypes.LinkerPrivateLinkage LINKAGE_LINKER_PRIVATE_WEAK = LinkageTypes.LinkerPrivateWeakLinkage LinkageEnum.declare() # visibility (see llvm/GlobalValue.h) class VisibilityEnum(Enum): prefix = 'VISIBILITY_' VISIBILITY_DEFAULT = api.llvm.GlobalValue.VisibilityTypes.DefaultVisibility VISIBILITY_HIDDEN = api.llvm.GlobalValue.VisibilityTypes.HiddenVisibility VISIBILITY_PROTECTED = api.llvm.GlobalValue.VisibilityTypes.ProtectedVisibility VisibilityEnum.declare() # parameter attributes # LLVM 3.2 llvm::Attributes::AttrVal (see llvm/Attributes.h) # LLVM 3.3 llvm::Attribute::AttrKind (see llvm/Attributes.h) class AttrEnum(Enum): prefix = 'ATTR_' if llvm.version >= (3, 3): AttrVal = api.llvm.Attribute.AttrKind else: AttrVal = api.llvm.Attributes.AttrVal ATTR_NONE = AttrVal.None_ ATTR_ZEXT = AttrVal.ZExt ATTR_SEXT = AttrVal.SExt ATTR_NO_RETURN = AttrVal.NoReturn ATTR_IN_REG = AttrVal.InReg ATTR_STRUCT_RET = AttrVal.StructRet ATTR_NO_UNWIND = AttrVal.NoUnwind ATTR_NO_ALIAS = AttrVal.NoAlias ATTR_BY_VAL = AttrVal.ByVal ATTR_NEST = AttrVal.Nest ATTR_READ_NONE = AttrVal.ReadNone ATTR_READONLY = AttrVal.ReadOnly ATTR_NO_INLINE = AttrVal.NoInline ATTR_ALWAYS_INLINE = AttrVal.AlwaysInline ATTR_OPTIMIZE_FOR_SIZE = AttrVal.OptimizeForSize ATTR_STACK_PROTECT = AttrVal.StackProtect ATTR_STACK_PROTECT_REQ = AttrVal.StackProtectReq ATTR_ALIGNMENT = AttrVal.Alignment ATTR_NO_CAPTURE = AttrVal.NoCapture ATTR_NO_REDZONE = AttrVal.NoRedZone ATTR_NO_IMPLICIT_FLOAT = AttrVal.NoImplicitFloat ATTR_NAKED = AttrVal.Naked ATTR_INLINE_HINT = AttrVal.InlineHint ATTR_STACK_ALIGNMENT = AttrVal.StackAlignment AttrEnum.declare() class Module(llvm.Wrapper): """A Module instance stores all the information related to an LLVM module. Modules are the top level container of all other LLVM Intermediate Representation (IR) objects. Each module directly contains a list of globals variables, a list of functions, a list of libraries (or other modules) this module depends on, a symbol table, and various data about the target's characteristics. Construct a Module only using the static methods defined below, *NOT* using the constructor. A correct usage is: module_obj = Module.new('my_module') """ __cache = weakref.WeakValueDictionary() def __new__(cls, ptr): cached = cls.__cache.get(ptr) if cached: return cached obj = object.__new__(cls) cls.__cache[ptr] = obj return obj @staticmethod def new(id): """Create a new Module instance. Creates an instance of Module, having the id `id'. """ context = api.llvm.getGlobalContext() m = api.llvm.Module.new(id, context) return Module(m) @staticmethod def from_bitcode(fileobj_or_str): """Create a Module instance from the contents of a bitcode file. fileobj_or_str -- takes a file-like object or string that contains a module represented in bitcode. """ if isinstance(fileobj_or_str, str): bc = fileobj_or_str else: bc = fileobj_or_str.read() errbuf = BytesIO() context = api.llvm.getGlobalContext() m = api.llvm.ParseBitCodeFile(bc, context, errbuf) if not m: raise Exception(errbuf.getvalue()) errbuf.close() return Module(m) @staticmethod def from_assembly(fileobj_or_str): """Create a Module instance from the contents of an LLVM assembly (.ll) file. fileobj_or_str -- takes a file-like object or string that contains a module represented in llvm-ir assembly. """ if isinstance(fileobj_or_str, str): ir = fileobj_or_str else: ir = fileobj_or_str.read() errbuf = BytesIO() context = api.llvm.getGlobalContext() m = api.llvm.ParseAssemblyString(ir, None, api.llvm.SMDiagnostic.new(), context) errbuf.close() return Module(m) def __str__(self): """Text representation of a module. Returns the textual representation (`llvm assembly') of the module. Use it like this: ll = str(module_obj) print module_obj # same as `print ll' """ return str(self._ptr) def __eq__(self, rhs): assert isinstance(rhs, Module), type(rhs) if isinstance(rhs, Module): return str(self) == str(rhs) else: return False def __ne__(self, rhs): return not (self == rhs) def _get_target(self): return self._ptr.getTargetTriple() def _set_target(self, value): return self._ptr.setTargetTriple(value) target = property(_get_target, _set_target, doc="The target triple string describing the target host.") def _get_data_layout(self): return self._ptr.getDataLayout() def _set_data_layout(self, value): return self._ptr.setDataLayout(value) data_layout = property(_get_data_layout, _set_data_layout, doc = """The data layout string for the module's target platform. The data layout strings is an encoded representation of the type sizes and alignments expected by this module. """ ) @property def pointer_size(self): return self._ptr.getPointerSize() def link_in(self, other, preserve=False): """Link the `other' module into this one. The `other' module is linked into this one such that types, global variables, function, etc. are matched and resolved. The `other' module is no longer valid after this method is invoked, all refs to it should be dropped. In the future, this API might be replaced with a full-fledged Linker class. """ assert isinstance(other, Module) enum_mode = api.llvm.Linker.LinkerMode mode = enum_mode.PreserveSource if preserve else enum_mode.DestroySource with contextlib.closing(BytesIO()) as errmsg: failed = api.llvm.Linker.LinkModules(self._ptr, other._ptr, mode, errmsg) if failed: raise llvm.LLVMException(errmsg.getvalue()) def get_type_named(self, name): typ = self._ptr.getTypeByName(name) if typ: return StructType(typ) def add_global_variable(self, ty, name, addrspace=0): """Add a global variable of given type with given name.""" external = api.llvm.GlobalVariable.LinkageTypes.ExternalLinkage notthreadlocal = api.llvm.GlobalVariable.ThreadLocalMode.NotThreadLocal init = None insertbefore = None ptr = api.llvm.GlobalVariable.new(self._ptr, ty._ptr, False, external, init, name, insertbefore, notthreadlocal, addrspace) return _make_value(ptr) def get_global_variable_named(self, name): """Return a GlobalVariable object for the given name.""" ptr = self._ptr.getNamedGlobal(name) if ptr is None: raise llvm.LLVMException("No global named: %s" % name) return _make_value(ptr) @property def global_variables(self): return list(map(_make_value, self._ptr.list_globals())) def add_function(self, ty, name): """Add a function of given type with given name.""" return Function.new(self, ty, name) # fn = self.get_function_named(name) # if fn is not None: # raise llvm.LLVMException("Duplicated function %s" % name) # return self.get_or_insert_function(ty, name) def get_function_named(self, name): """Return a Function object representing function with given name.""" return Function.get(self, name) # fn = self._ptr.getFunction(name) # if fn is not None: # return _make_value(fn) def get_or_insert_function(self, ty, name): """Like get_function_named(), but does add_function() first, if function is not present.""" return Function.get_or_insert(self, ty, name) # constant = self._ptr.getOrInsertFunction(name, ty._ptr) # try: # fn = constant._downcast(api.llvm.Function) # except ValueError: # # bitcasted to function type # return _make_value(constant) # else: # return _make_value(fn) @property def functions(self): """All functions in this module.""" return list(map(_make_value, self._ptr.list_functions())) def verify(self): """Verify module. Checks module for errors. Raises `llvm.LLVMException' on any error.""" action = api.llvm.VerifierFailureAction.ReturnStatusAction errio = BytesIO() broken = api.llvm.verifyModule(self._ptr, action, errio) if broken: raise llvm.LLVMException(errio.getvalue()) def to_bitcode(self, fileobj=None): """Write bitcode representation of module to given file-like object. fileobj -- A file-like object to where the bitcode is written. If it is None, the bitcode is returned. Return value -- Returns None if fileobj is not None. Otherwise, return the bitcode as a bytestring. """ ret = False if fileobj is None: ret = True fileobj = BytesIO() api.llvm.WriteBitcodeToFile(self._ptr, fileobj) if ret: return fileobj.getvalue() def _get_id(self): return self._ptr.getModuleIdentifier() def _set_id(self, string): self._ptr.setModuleIdentifier(string) id = property(_get_id, _set_id) def _to_native_something(self, fileobj, cgft): cgft = api.llvm.TargetMachine.CodeGenFileType.CGFT_AssemblyFile cgft = api.llvm.TargetMachine.CodeGenFileType.CGFT_ObjectFile from llvm.ee import TargetMachine from llvm.passes import PassManager from llvmpy import extra tm = TargetMachine.new()._ptr pm = PassManager.new()._ptr formatted failed = tm.addPassesToEmitFile(pm, fileobj, cgft, False) if failed: raise llvm.LLVMException("Failed to write native object file") if ret: return fileobj.getvalue() def to_native_object(self, fileobj=None): '''Outputs the byte string of the module as native object code If a fileobj is given, the output is written to it; Otherwise, the output is returned ''' ret = False if fileobj is None: ret = True fileobj = BytesIO() from llvm.ee import TargetMachine tm = TargetMachine.new() fileobj.write(tm.emit_object(self)) if ret: return fileobj.getvalue() def to_native_assembly(self, fileobj=None): '''Outputs the byte string of the module as native assembly code If a fileobj is given, the output is written to it; Otherwise, the output is returned ''' ret = False if fileobj is None: ret = True fileobj = StringIO() from llvm.ee import TargetMachine tm = TargetMachine.new() asm = tm.emit_assembly(self) fileobj.write(asm) if ret: return fileobj.getvalue() def get_or_insert_named_metadata(self, name): return NamedMetaData(self._ptr.getOrInsertNamedMetadata(name)) def get_named_metadata(self, name): md = self._ptr.getNamedMetadata(name) if md: return NamedMetaData(md) def clone(self): return Module(api.llvm.CloneModule(self._ptr)) class Type(llvm.Wrapper): """Represents a type, like a 32-bit integer or an 80-bit x86 float. Use one of the static methods to create an instance. Example: ty = Type.double() """ _type_ = api.llvm.Type def __init__(self, ptr): ptr = ptr._downcast(type(self)._type_) super(Type, self).__init__(ptr) @property def kind(self): return self._ptr.getTypeID() @staticmethod def int(bits=32): """Create an integer type having the given bit width.""" context = api.llvm.getGlobalContext() ptr = api.llvm.Type.getIntNTy(context, bits) return Type(ptr) @staticmethod def float(): """Create a 32-bit floating point type.""" context = api.llvm.getGlobalContext() ptr = api.llvm.Type.getFloatTy(context) return Type(ptr) @staticmethod def double(): """Create a 64-bit floating point type.""" context = api.llvm.getGlobalContext() ptr = api.llvm.Type.getDoubleTy(context) return Type(ptr) @staticmethod def x86_fp80(): """Create a 80-bit x86 floating point type.""" context = api.llvm.getGlobalContext() ptr = api.llvm.Type.getX86_FP80Ty(context) return Type(ptr) @staticmethod def fp128(): """Create a 128-bit floating point type (with 112-bit mantissa).""" context = api.llvm.getGlobalContext() ptr = api.llvm.Type.getFP128Ty(context) return Type(ptr) @staticmethod def ppc_fp128(): """Create a 128-bit floating point type (two 64-bits).""" context = api.llvm.getGlobalContext() ptr = api.llvm.Type.getPPC_FP128Ty(context) return Type(ptr) @staticmethod def function(return_ty, param_tys, var_arg=False): """Create a function type. Creates a function type that returns a value of type `return_ty', takes arguments of types as given in the iterable `param_tys'. Set `var_arg' to True (default is False) for a variadic function.""" ptr = api.llvm.FunctionType.get(return_ty._ptr, llvm._extract_ptrs(param_tys), var_arg) return FunctionType(ptr) @staticmethod def opaque(name): """Create a opaque StructType""" context = api.llvm.getGlobalContext() if not name: raise llvm.LLVMException("Opaque type must have a non-empty name") ptr = api.llvm.StructType.create(context, name) return StructType(ptr) @staticmethod def struct(element_tys, name=''): # not packed """Create a (unpacked) structure type. Creates a structure type with elements of types as given in the iterable `element_tys'. This method creates a unpacked structure. For a packed one, use the packed_struct() method. If name is not '', creates a identified type; otherwise, creates a literal type.""" context = api.llvm.getGlobalContext() is_packed = False if name: ptr = api.llvm.StructType.create(context, name) ptr.setBody(llvm._extract_ptrs(element_tys), is_packed) else: ptr = api.llvm.StructType.get(context, llvm._extract_ptrs(element_tys), is_packed) return StructType(ptr) @staticmethod def packed_struct(element_tys, name=''): """Create a (packed) structure type. Creates a structure type with elements of types as given in the iterable `element_tys'. This method creates a packed structure. For an unpacked one, use the struct() method. If name is not '', creates a identified type; otherwise, creates a literal type.""" context = api.llvm.getGlobalContext() is_packed = True ptr = api.llvm.StructType.create(context, name) ptr.setBody(llvm._extract_ptrs(element_tys), is_packed) return StructType(ptr) @staticmethod def array(element_ty, count): """Create an array type. Creates a type for an array of elements of type `element_ty', having 'count' elements.""" ptr = api.llvm.ArrayType.get(element_ty._ptr, count) return ArrayType(ptr) @staticmethod def pointer(pointee_ty, addr_space=0): """Create a pointer type. Creates a pointer type, which can point to values of type `pointee_ty', in the address space `addr_space'.""" ptr = api.llvm.PointerType.get(pointee_ty._ptr, addr_space) return PointerType(ptr) @staticmethod def vector(element_ty, count): """Create a vector type. Creates a type for a vector of elements of type `element_ty', having `count' elements.""" ptr = api.llvm.VectorType.get(element_ty._ptr, count) return VectorType(ptr) @staticmethod def void(): """Create a void type. Represents the `void' type.""" context = api.llvm.getGlobalContext() ptr = api.llvm.Type.getVoidTy(context) return Type(ptr) @staticmethod def label(): """Create a label type.""" context = api.llvm.getGlobalContext() ptr = api.llvm.Type.getLabelTy(context) return Type(ptr) def __new__(cls, ptr): tyid = ptr.getTypeID() idmap = { TYPE_HALF: IntegerType, TYPE_INTEGER: IntegerType, TYPE_FUNCTION: FunctionType, TYPE_STRUCT: StructType, TYPE_ARRAY: ArrayType, TYPE_POINTER: PointerType, TYPE_VECTOR: VectorType, } try: newcls = idmap[tyid] except KeyError: newcls = Type obj = llvm.Wrapper.__new__(newcls) return obj def __str__(self): return str(self._ptr) def __hash__(self): return hash(self._ptr) def __eq__(self, rhs): return self._ptr is rhs._ptr def __ne__(self, rhs): return not (self == rhs) class IntegerType(Type): """Represents an integer type.""" _type_ = api.llvm.IntegerType @property def width(self): """The width of the integer type, in bits.""" return self._ptr.getIntegerBitWidth() class FunctionType(Type): """Represents a function type.""" _type_ = api.llvm.FunctionType @property def return_type(self): """The type of the value returned by this function.""" return Type(self._ptr.getReturnType()) @property def vararg(self): """True if this function is variadic.""" return self._ptr.isVarArg() @property def args(self): """An iterable that yields Type objects, representing the types of the arguments accepted by this function, in order.""" return [Type(self._ptr.getParamType(i)) for i in range(self.arg_count)] @property def arg_count(self): """Number of arguments accepted by this function. Same as len(obj.args), but faster.""" return self._ptr.getNumParams() class StructType(Type): """Represents a structure type.""" _type_ = api.llvm.StructType @property def element_count(self): """Number of elements (members) in the structure. Same as len(obj.elements), but faster.""" return self._ptr.getNumElements() @property def elements(self): """An iterable that yields Type objects, representing the types of the elements (members) of the structure, in order.""" return [Type(self._ptr.getElementType(i)) for i in range(self._ptr.getNumElements())] def set_body(self, elems, packed=False): """Filled the body of a opaque type. """ # check if not self.is_opaque: raise llvm.LLVMException("Body is already defined.") self._ptr.setBody(llvm._extract_ptrs(elems), packed) @property def packed(self): """True if the structure is packed, False otherwise.""" return self._ptr.isPacked() def _set_name(self, name): self._ptr.setName(name) def _get_name(self): if self._ptr.isLiteral(): return "" else: return self._ptr.getName() name = property(_get_name, _set_name) @property def is_literal(self): return self._ptr.isLiteral() @property def is_identified(self): return not self.is_literal @property def is_opaque(self): return self._ptr.isOpaque() def is_layout_identical(self, other): return self._ptr.isLayoutIdentical(other._ptr) class ArrayType(Type): """Represents an array type.""" _type_ = api.llvm.ArrayType @property def element(self): return Type(self._ptr.getArrayElementType()) @property def count(self): return self._ptr.getNumElements() class PointerType(Type): _type_ = api.llvm.PointerType @property def pointee(self): return Type(self._ptr.getPointerElementType()) @property def address_space(self): return self._ptr.getAddressSpace() class VectorType(Type): _type_ = api.llvm.VectorType @property def element(self): return Type(self._ptr.getVectorElementType()) @property def count(self): return self._ptr.getNumElements() class Value(llvm.Wrapper): _type_ = api.llvm.Value def __init__(self, builder, ptr): assert builder is _ValueFactory if type(self._type_) is type: if isinstance(ptr, self._type_): # is not downcast casted = ptr else: casted = ptr._downcast(self._type_) else: try: for ty in self._type_: if isinstance(ptr, ty): # is not downcast casted = ptr else: try: casted = ptr._downcast(ty) except ValueError: pass else: break else: casted = ptr except TypeError: casted = ptr super(Value, self).__init__(casted) def __str__(self): return str(self._ptr) def __hash__(self): return hash(self._ptr) def __eq__(self, rhs): if isinstance(rhs, Value): return str(self) == str(rhs) else: return False def __ne__(self, rhs): return not self == rhs def _get_name(self): return self._ptr.getName() def _set_name(self, value): return self._ptr.setName(value) name = property(_get_name, _set_name) @property def value_id(self): return self._ptr.getValueID() @property def type(self): return Type(self._ptr.getType()) @property def use_count(self): return self._ptr.getNumUses() @property def uses(self): return list(map(_make_value, self._ptr.list_use())) class User(Value): _type_ = api.llvm.User @property def operand_count(self): return self._ptr.getNumOperands() @property def operands(self): """Yields operands of this instruction.""" return [_make_value(self._ptr.getOperand(i)) for i in range(self.operand_count)] class Constant(User): _type_ = api.llvm.Constant @staticmethod def null(ty): return _make_value(api.llvm.Constant.getNullValue(ty._ptr)) @staticmethod def all_ones(ty): return _make_value(api.llvm.Constant.getAllOnesValue(ty._ptr)) @staticmethod def undef(ty): return _make_value(api.llvm.UndefValue.get(ty._ptr)) @staticmethod def int(ty, value): return _make_value(api.llvm.ConstantInt.get(ty._ptr, int(value), False)) @staticmethod def int_signextend(ty, value): return _make_value(api.llvm.ConstantInt.get(ty._ptr, int(value), True)) @staticmethod def real(ty, value): return _make_value(api.llvm.ConstantFP.get(ty._ptr, float(value))) @staticmethod def string(strval): # dont_null_terminate=True cxt = api.llvm.getGlobalContext() return _make_value(api.llvm.ConstantDataArray.getString(cxt, strval, False)) @staticmethod def stringz(strval): # dont_null_terminate=False cxt = api.llvm.getGlobalContext() return _make_value(api.llvm.ConstantDataArray.getString(cxt, strval, True)) @staticmethod def array(ty, consts): aryty = Type.array(ty, len(consts)) return _make_value(api.llvm.ConstantArray.get(aryty._ptr, llvm._extract_ptrs(consts))) @staticmethod def struct(consts): # not packed return _make_value(api.llvm.ConstantStruct.getAnon(llvm._extract_ptrs(consts), False)) @staticmethod def packed_struct(consts): return _make_value(api.llvm.ConstantStruct.getAnon(llvm._extract_ptrs(consts), False)) @staticmethod def vector(consts): return _make_value(api.llvm.ConstantVector.get(llvm._extract_ptrs(consts))) @staticmethod def sizeof(ty): return _make_value(api.llvm.ConstantExpr.getSizeOf(ty._ptr)) def neg(self): return _make_value(api.llvm.ConstantExpr.getNeg(self._ptr)) def not_(self): return _make_value(api.llvm.ConstantExpr.getNot(self._ptr)) def add(self, rhs): return _make_value(api.llvm.ConstantExpr.getAdd(self._ptr, rhs._ptr)) def fadd(self, rhs): return _make_value(api.llvm.ConstantExpr.getFAdd(self._ptr, rhs._ptr)) def sub(self, rhs): return _make_value(api.llvm.ConstantExpr.getSub(self._ptr, rhs._ptr)) def fsub(self, rhs): return _make_value(api.llvm.ConstantExpr.getFSub(self._ptr, rhs._ptr)) def mul(self, rhs): return _make_value(api.llvm.ConstantExpr.getMul(self._ptr, rhs._ptr)) def fmul(self, rhs): return _make_value(api.llvm.ConstantExpr.getFMul(self._ptr, rhs._ptr)) def udiv(self, rhs): return _make_value(api.llvm.ConstantExpr.getUDiv(self._ptr, rhs._ptr)) def sdiv(self, rhs): return _make_value(api.llvm.ConstantExpr.getSDiv(self._ptr, rhs._ptr)) def fdiv(self, rhs): return _make_value(api.llvm.ConstantExpr.getFDiv(self._ptr, rhs._ptr)) def urem(self, rhs): return _make_value(api.llvm.ConstantExpr.getURem(self._ptr, rhs._ptr)) def srem(self, rhs): return _make_value(api.llvm.ConstantExpr.getSRem(self._ptr, rhs._ptr)) def frem(self, rhs): return _make_value(api.llvm.ConstantExpr.getFRem(self._ptr, rhs._ptr)) def and_(self, rhs): return _make_value(api.llvm.ConstantExpr.getAnd(self._ptr, rhs._ptr)) def or_(self, rhs): return _make_value(api.llvm.ConstantExpr.getOr(self._ptr, rhs._ptr)) def xor(self, rhs): return _make_value(api.llvm.ConstantExpr.getXor(self._ptr, rhs._ptr)) def icmp(self, int_pred, rhs): return _make_value(api.llvm.ConstantExpr.getICmp(int_pred, self._ptr, rhs._ptr)) def fcmp(self, real_pred, rhs): return _make_value(api.llvm.ConstantExpr.getFCmp(real_pred, self._ptr, rhs._ptr)) def shl(self, rhs): return _make_value(api.llvm.ConstantExpr.getShl(self._ptr, rhs._ptr)) def lshr(self, rhs): return _make_value(api.llvm.ConstantExpr.getLShr(self._ptr, rhs._ptr)) def ashr(self, rhs): return _make_value(api.llvm.ConstantExpr.getAShr(self._ptr, rhs._ptr)) def gep(self, indices): indices = llvm._extract_ptrs(indices) return _make_value(api.llvm.ConstantExpr.getGetElementPtr(self._ptr, indices)) def trunc(self, ty): return _make_value(api.llvm.ConstantExpr.getTrunc(self._ptr, ty._ptr)) def sext(self, ty): return _make_value(api.llvm.ConstantExpr.getSExt(self._ptr, ty._ptr)) def zext(self, ty): return _make_value(api.llvm.ConstantExpr.getZExt(self._ptr, ty._ptr)) def fptrunc(self, ty): return _make_value(api.llvm.ConstantExpr.getFPTrunc(self._ptr, ty._ptr)) def fpext(self, ty): return _make_value(api.llvm.ConstantExpr.getFPExtend(self._ptr, ty._ptr)) def uitofp(self, ty): return _make_value(api.llvm.ConstantExpr.getUIToFP(self._ptr, ty._ptr)) def sitofp(self, ty): return _make_value(api.llvm.ConstantExpr.getSIToFP(self._ptr, ty._ptr)) def fptoui(self, ty): return _make_value(api.llvm.ConstantExpr.getFPToUI(self._ptr, ty._ptr)) def fptosi(self, ty): return _make_value(api.llvm.ConstantExpr.getFPToSI(self._ptr, ty._ptr)) def ptrtoint(self, ty): return _make_value(api.llvm.ConstantExpr.getPtrToInt(self._ptr, ty._ptr)) def inttoptr(self, ty): return _make_value(api.llvm.ConstantExpr.getIntToPtr(self._ptr, ty._ptr)) def bitcast(self, ty): return _make_value(api.llvm.ConstantExpr.getBitCast(self._ptr, ty._ptr)) def select(self, true_const, false_const): return _make_value(api.llvm.ConstantExpr.getSelect(self._ptr, true_const._ptr, false_const._ptr)) def extract_element(self, index): # note: self must be a _vector_ constant return _make_value(api.llvm.ConstantExpr.getExtractElement(self._ptr, index._ptr)) def insert_element(self, value, index): return _make_value(api.llvm.ConstantExpr.getExtractElement(self._ptr, value._ptr, index._ptr)) def shuffle_vector(self, vector_b, mask): return _make_value(api.llvm.ConstantExpr.getShuffleVector(self._ptr, vector_b._ptr, mask._ptr)) class ConstantExpr(Constant): _type_ = api.llvm.ConstantExpr @property def opcode(self): return self._ptr.getOpcode() @property def opcode_name(self): return self._ptr.getOpcodeName() class ConstantAggregateZero(Constant): pass class ConstantDataArray(Constant): pass class ConstantDataVector(Constant): pass class ConstantInt(Constant): _type_ = api.llvm.ConstantInt @property def z_ext_value(self): '''Obtain the zero extended value for an integer constant value.''' # Warning: assertion failure when value does not fit in 64 bits return self._ptr.getZExtValue() @property def s_ext_value(self): '''Obtain the sign extended value for an integer constant value.''' # Warning: assertion failure when value does not fit in 64 bits return self._ptr.getSExtValue() class ConstantFP(Constant): pass class ConstantArray(Constant): pass class ConstantStruct(Constant): pass class ConstantVector(Constant): pass class ConstantPointerNull(Constant): pass class UndefValue(Constant): pass class GlobalValue(Constant): _type_ = api.llvm.GlobalValue def _get_linkage(self): return self._ptr.getLinkage() def _set_linkage(self, value): self._ptr.setLinkage(value) linkage = property(_get_linkage, _set_linkage) def _get_section(self): return self._ptr.getSection() def _set_section(self, value): return self._ptr.setSection(value) section = property(_get_section, _set_section) def _get_visibility(self): return self._ptr.getVisibility() def _set_visibility(self, value): return self._ptr.setVisibility(value) visibility = property(_get_visibility, _set_visibility) def _get_alignment(self): return self._ptr.getAlignment() def _set_alignment(self, value): return self._ptr.setAlignment(value) alignment = property(_get_alignment, _set_alignment) @property def is_declaration(self): return self._ptr.isDeclaration() @property def module(self): return Module(self._ptr.getParent()) class GlobalVariable(GlobalValue): _type_ = api.llvm.GlobalVariable @staticmethod def new(module, ty, name, addrspace=0): linkage = api.llvm.GlobalValue.LinkageTypes external_linkage = linkage.ExternalLinkage tlmode = api.llvm.GlobalVariable.ThreadLocalMode not_threadlocal = tlmode.NotThreadLocal gv = api.llvm.GlobalVariable.new(module._ptr, ty._ptr, False, # is constant external_linkage, None, # initializer name, None, # insert before not_threadlocal, addrspace) return _make_value(gv) @staticmethod def get(module, name): gv = _make_value(module._ptr.getNamedGlobal(name)) if not gv: llvm.LLVMException("no global named `%s`" % name) return gv def delete(self): _ValueFactory.delete(self._ptr) self._ptr.eraseFromParent() def _get_initializer(self): if not self._ptr.hasInitializer(): return None return _make_value(self._ptr.getInitializer()) def _set_initializer(self, const): self._ptr.setInitializer(const._ptr) def _del_initializer(self): self._ptr.setInitializer(None) initializer = property(_get_initializer, _set_initializer) def _get_is_global_constant(self): return self._ptr.isConstant() def _set_is_global_constant(self, value): self._ptr.setConstant(value) global_constant = property(_get_is_global_constant, _set_is_global_constant) def _get_thread_local(self): return self._ptr.isThreadLocal() def _set_thread_local(self, value): return self._ptr.setThreadLocal(value) thread_local = property(_get_thread_local, _set_thread_local) class Argument(Value): _type_ = api.llvm.Argument _valid_attrs = frozenset([ATTR_BY_VAL, ATTR_NEST, ATTR_NO_ALIAS, ATTR_NO_CAPTURE, ATTR_STRUCT_RET]) if llvm.version >= (3, 3): def add_attribute(self, attr): context = api.llvm.getGlobalContext() attrbldr = api.llvm.AttrBuilder.new() attrbldr.addAttribute(attr) attrs = api.llvm.AttributeSet.get(context, 0, attrbldr) self._ptr.addAttr(attrs) if attr not in self: raise ValueError("Attribute %r is not valid for arg %s" % (attr, self)) def remove_attribute(self, attr): context = api.llvm.getGlobalContext() attrbldr = api.llvm.AttrBuilder.new() attrbldr.addAttribute(attr) attrs = api.llvm.AttributeSet.get(context, 0, attrbldr) self._ptr.removeAttr(attrs) def _set_alignment(self, align): context = api.llvm.getGlobalContext() attrbldr = api.llvm.AttrBuilder.new() attrbldr.addAlignmentAttr(align) attrs = api.llvm.AttributeSet.get(context, 0, attrbldr) self._ptr.addAttr(attrs) else: def add_attribute(self, attr): context = api.llvm.getGlobalContext() attrbldr = api.llvm.AttrBuilder.new() attrbldr.addAttribute(attr) attrs = api.llvm.Attributes.get(context, attrbldr) self._ptr.addAttr(attrs) if attr not in self: raise ValueError("Attribute %r is not valid for arg %s" % (attr, self)) def remove_attribute(self, attr): context = api.llvm.getGlobalContext() attrbldr = api.llvm.AttrBuilder.new() attrbldr.addAttribute(attr) attrs = api.llvm.Attributes.get(context, attrbldr) self._ptr.removeAttr(attrs) def _set_alignment(self, align): context = api.llvm.getGlobalContext() attrbldr = api.llvm.AttrBuilder.new() attrbldr.addAlignmentAttr(align) attrs = api.llvm.Attributes.get(context, attrbldr) self._ptr.addAttr(attrs) def _get_alignment(self): return self._ptr.getParamAlignment() alignment = property(_get_alignment, _set_alignment) @property def attributes(self): '''Returns a set of defined attributes. ''' return set(attr for attr in self._valid_attrs if attr in self) def __contains__(self, attr): if attr == ATTR_BY_VAL: return self.has_by_val() elif attr == ATTR_NEST: return self.has_nest() elif attr == ATTR_NO_ALIAS: return self.has_no_alias() elif attr == ATTR_NO_CAPTURE: return self.has_no_capture() elif attr == ATTR_STRUCT_RET: return self.has_struct_ret() else: raise ValueError('invalid attribute for argument') @property def arg_no(self): return self._ptr.getArgNo() def has_by_val(self): return self._ptr.hasByValAttr() def has_nest(self): return self._ptr.hasNestAttr() def has_no_alias(self): return self._ptr.hasNoAliasAttr() def has_no_capture(self): return self._ptr.hasNoCaptureAttr() def has_struct_ret(self): return self._ptr.hasStructRetAttr() class Function(GlobalValue): _type_ = api.llvm.Function @staticmethod def new(module, func_ty, name): try: fn = Function.get(module, name) except llvm.LLVMException: return Function.get_or_insert(module, func_ty, name) else: raise llvm.LLVMException("Duplicated function %s" % name) @staticmethod def get_or_insert(module, func_ty, name): constant = module._ptr.getOrInsertFunction(name, func_ty._ptr) try: fn = constant._downcast(api.llvm.Function) except ValueError: # bitcasted to function type return _make_value(constant) else: return _make_value(fn) @staticmethod def get(module, name): fn = module._ptr.getFunction(name) if fn is None: raise llvm.LLVMException("no function named `%s`" % name) else: return _make_value(fn) @staticmethod def intrinsic(module, intrinsic_id, types): fn = api.llvm.Intrinsic.getDeclaration(module._ptr, intrinsic_id, llvm._extract_ptrs(types)) return _make_value(fn) def delete(self): _ValueFactory.delete(self._ptr) self._ptr.eraseFromParent() @property def intrinsic_id(self): self._ptr.getIntrinsicID() def _get_cc(self): return self._ptr.getCallingConv() def _set_cc(self, value): self._ptr.setCallingConv(value) calling_convention = property(_get_cc, _set_cc) def _get_coll(self): return self._ptr.getGC() def _set_coll(self, value): return self._ptr.setGC(value) collector = property(_get_coll, _set_coll) # the nounwind attribute: def _get_does_not_throw(self): return self._ptr.doesNotThrow() def _set_does_not_throw(self,value): assert value self._ptr.setDoesNotThrow() does_not_throw = property(_get_does_not_throw, _set_does_not_throw) @property def args(self): args = self._ptr.getArgumentList() return list(map(_make_value, args)) @property def basic_block_count(self): return len(self.basic_blocks) @property def entry_basic_block(self): assert self.basic_block_count return _make_value(self._ptr.getEntryBlock()) def get_entry_basic_block(self): "Deprecated. Use entry_basic_block instead" return self.entry_basic_block def append_basic_block(self, name): context = api.llvm.getGlobalContext() bb = api.llvm.BasicBlock.Create(context, name, self._ptr, None) return _make_value(bb) @property def basic_blocks(self): return list(map(_make_value, self._ptr.getBasicBlockList())) def viewCFG(self): return self._ptr.viewCFG() def add_attribute(self, attr): self._ptr.addFnAttr(attr) def remove_attribute(self, attr): context = api.llvm.getGlobalContext() attrbldr = api.llvm.AttrBuilder.new() attrbldr.addAttribute(attr) attrs = api.llvm.Attributes.get(context, attrbldr) self._ptr.removeFnAttr(attrs) def viewCFGOnly(self): return self._ptr.viewCFGOnly() def verify(self): # Although we're just asking LLVM to return the success or # failure, it appears to print result to stderr and abort. # Note: LLVM has a bug in preverifier that will always abort # the process upon failure. actions = api.llvm.VerifierFailureAction broken = api.llvm.verifyFunction(self._ptr, actions.ReturnStatusAction) if broken: # If broken, then re-run to print the message api.llvm.verifyFunction(self._ptr, actions.PrintMessageAction) raise llvm.LLVMException("Function %s failed verification" % self.name) #===----------------------------------------------------------------------=== # InlineAsm #===----------------------------------------------------------------------=== class InlineAsm(Value): _type_ = api.llvm.InlineAsm @staticmethod def get(functype, asm, constrains, side_effect=False, align_stack=False, dialect=api.llvm.InlineAsm.AsmDialect.AD_ATT): ilasm = api.llvm.InlineAsm.get(functype._ptr, asm, constrains, side_effect, align_stack, dialect) return _make_value(ilasm) #===----------------------------------------------------------------------=== # MetaData #===----------------------------------------------------------------------=== class MetaData(Value): _type_ = api.llvm.MDNode @staticmethod def get(module, values): ''' values -- must be an iterable of Constant or None. None is treated as "null". ''' context = api.llvm.getGlobalContext() ptr = api.llvm.MDNode.get(context, llvm._extract_ptrs(values)) return _make_value(ptr) @staticmethod def get_named_operands(module, name): namedmd = module.get_named_metadata(name) if not namedmd: return [] return [_make_value(namedmd._ptr.getOperand(i)) for i in range(namedmd._ptr.getNumOperands())] @staticmethod def add_named_operand(module, name, operand): namedmd = module.get_or_insert_named_metadata(name)._ptr namedmd.addOperand(operand._ptr) @property def operand_count(self): return self._ptr.getNumOperands() @property def operands(self): """Yields operands of this metadata.""" res = [] for i in range(self.operand_count): op = self._ptr.getOperand(i) if op is None: res.append(None) else: res.append(_make_value(op)) return res class MetaDataString(Value): _type_ = api.llvm.MDString @staticmethod def get(module, s): context = api.llvm.getGlobalContext() ptr = api.llvm.MDString.get(context, s) return _make_value(ptr) @property def string(self): '''Same as MDString::getString''' return self._ptr.getString() class NamedMetaData(llvm.Wrapper): @staticmethod def get_or_insert(mod, name): return mod.get_or_insert_named_metadata(name) @staticmethod def get(mod, name): return mod.get_named_metadata(name) def delete(self): _ValueFactory.delete(self._ptr) self._ptr.eraseFromParent() @property def name(self): return self._ptr.getName() def __str__(self): return str(self._ptr) def add(self, operand): self._ptr.addOperand(operand._ptr) #===----------------------------------------------------------------------=== # Instruction #===----------------------------------------------------------------------=== class Instruction(User): _type_ = api.llvm.Instruction @property def basic_block(self): return _make_value(self._ptr.getParent()) @property def is_terminator(self): return self._ptr.isTerminator() @property def is_binary_op(self): return self._ptr.isBinaryOp() @property def is_shift(self): return self._ptr.isShift() @property def is_cast(self): return self._ptr.isCast() @property def is_logical_shift(self): return self._ptr.isLogicalShift() @property def is_arithmetic_shift(self): return self._ptr.isArithmeticShift() @property def is_associative(self): return self._ptr.isAssociative() @property def is_commutative(self): return self._ptr.isCommutative() @property def is_volatile(self): """True if this is a volatile load or store.""" if api.llvm.LoadInst.classof(self._ptr): return self._ptr._downcast(api.llvm.LoadInst).isVolatile() elif api.llvm.StoreInst.classof(self._ptr): return self._ptr._downcast(api.llvm.StoreInst).isVolatile() else: return False def set_volatile(self, flag): if api.llvm.LoadInst.classof(self._ptr): return self._ptr._downcast(api.llvm.LoadInst).setVolatile(flag) elif api.llvm.StoreInst.classof(self._ptr): return self._ptr._downcast(api.llvm.StoreInst).setVolatile(flag) else: return False def set_metadata(self, kind, metadata): self._ptr.setMetadata(kind, metadata._ptr) def has_metadata(self): return self._ptr.hasMetadata() def get_metadata(self, kind): return self._ptr.getMetadata(kind) @property def opcode(self): return self._ptr.getOpcode() @property def opcode_name(self): return self._ptr.getOpcodeName() def erase_from_parent(self): return self._ptr.eraseFromParent() def replace_all_uses_with(self, inst): self._ptr.replaceAllUsesWith(inst) class CallOrInvokeInstruction(Instruction): _type_ = api.llvm.CallInst, api.llvm.InvokeInst def _get_cc(self): return self._ptr.getCallingConv() def _set_cc(self, value): return self._ptr.setCallingConv(value) calling_convention = property(_get_cc, _set_cc) def add_parameter_attribute(self, idx, attr): context = api.llvm.getGlobalContext() attrbldr = api.llvm.AttrBuilder.new() attrbldr.addAttribute(attr) attrs = api.llvm.Attributes.get(context, attrbldr) self._ptr.addAttribute(idx, attrs) def remove_parameter_attribute(self, idx, attr): context = api.llvm.getGlobalContext() attrbldr = api.llvm.AttrBuilder.new() attrbldr.addAttribute(attr) attrs = api.llvm.Attributes.get(context, attrbldr) self._ptr.removeAttribute(idx, attrs) def set_parameter_alignment(self, idx, align): context = api.llvm.getGlobalContext() attrbldr = api.llvm.AttrBuilder.new() attrbldr.addAlignmentAttr(align) attrs = api.llvm.Attributes.get(context, attrbldr) self._ptr.addAttribute(idx, attrs) def _get_called_function(self): function = self._ptr.getCalledFunction() if function: # Return value can be None on indirect call/invoke return _make_value(function) def _set_called_function(self, function): self._ptr.setCalledFunction(function._ptr) called_function = property(_get_called_function, _set_called_function) class PHINode(Instruction): _type_ = api.llvm.PHINode @property def incoming_count(self): return self._ptr.getNumIncomingValues() def add_incoming(self, value, block): self._ptr.addIncoming(value._ptr, block._ptr) def get_incoming_value(self, idx): return _make_value(self._ptr.getIncomingValue(idx)) def get_incoming_block(self, idx): return _make_value(self._ptr.getIncomingBlock(idx)) class SwitchInstruction(Instruction): _type_ = api.llvm.SwitchInst def add_case(self, const, bblk): self._ptr.addCase(const._ptr, bblk._ptr) class CompareInstruction(Instruction): _type_ = api.llvm.CmpInst @property def predicate(self): n = self._ptr.getPredicate() try: return ICMPEnum.get(n) except KeyError: return FCMPEnum.get(n) #===----------------------------------------------------------------------=== # Basic block #===----------------------------------------------------------------------=== class BasicBlock(Value): _type_ = api.llvm.BasicBlock def insert_before(self, name): context = api.llvm.getGlobalContext() ptr = api.llvm.BasicBlock.Create(context, name, self.function._ptr, self._ptr) return _make_value(ptr) def delete(self): _ValueFactory.delete(self._ptr) self._ptr.eraseFromParent() @property def function(self): return _make_value(self._ptr.getParent()) @property def instructions(self): return list(map(_make_value, self._ptr.getInstList())) #===----------------------------------------------------------------------=== # Value factory method #===----------------------------------------------------------------------=== class _ValueFactory(object): cache = weakref.WeakValueDictionary() # value ID -> class map class_for_valueid = { VALUE_ARGUMENT : Argument, VALUE_BASIC_BLOCK : BasicBlock, VALUE_FUNCTION : Function, VALUE_GLOBAL_ALIAS : GlobalValue, VALUE_GLOBAL_VARIABLE : GlobalVariable, VALUE_UNDEF_VALUE : UndefValue, VALUE_CONSTANT_EXPR : ConstantExpr, VALUE_CONSTANT_AGGREGATE_ZERO : ConstantAggregateZero, VALUE_CONSTANT_DATA_ARRAY : ConstantDataArray, VALUE_CONSTANT_DATA_VECTOR : ConstantDataVector, VALUE_CONSTANT_INT : ConstantInt, VALUE_CONSTANT_FP : ConstantFP, VALUE_CONSTANT_ARRAY : ConstantArray, VALUE_CONSTANT_STRUCT : ConstantStruct, VALUE_CONSTANT_VECTOR : ConstantVector, VALUE_CONSTANT_POINTER_NULL : ConstantPointerNull, VALUE_MD_NODE : MetaData, VALUE_MD_STRING : MetaDataString, VALUE_INLINE_ASM : InlineAsm, VALUE_INSTRUCTION + OPCODE_PHI : PHINode, VALUE_INSTRUCTION + OPCODE_CALL : CallOrInvokeInstruction, VALUE_INSTRUCTION + OPCODE_INVOKE : CallOrInvokeInstruction, VALUE_INSTRUCTION + OPCODE_SWITCH : SwitchInstruction, VALUE_INSTRUCTION + OPCODE_ICMP : CompareInstruction, VALUE_INSTRUCTION + OPCODE_FCMP : CompareInstruction } @classmethod def build(cls, ptr): # try to look in the cache addr = ptr._capsule.pointer id = ptr.getValueID() key = id, addr try: obj = cls.cache[key] return obj except KeyError: pass # find class by value id ctorcls = cls.class_for_valueid.get(id) if not ctorcls: if id > VALUE_INSTRUCTION: # "generic" instruction ctorcls = Instruction else: # "generic" value ctorcls = Value # cache the obj obj = ctorcls(_ValueFactory, ptr) cls.cache[key] = obj return obj @classmethod def delete(cls, ptr): del cls.cache[(ptr.getValueID(), ptr._capsule.pointer)] def _make_value(ptr): return _ValueFactory.build(ptr) #===----------------------------------------------------------------------=== # Builder #===----------------------------------------------------------------------=== _atomic_orderings = { 'unordered' : api.llvm.AtomicOrdering.Unordered, 'monotonic' : api.llvm.AtomicOrdering.Monotonic, 'acquire' : api.llvm.AtomicOrdering.Acquire, 'release' : api.llvm.AtomicOrdering.Release, 'acq_rel' : api.llvm.AtomicOrdering.AcquireRelease, 'seq_cst' : api.llvm.AtomicOrdering.SequentiallyConsistent } class Builder(llvm.Wrapper): @staticmethod def new(basic_block): context = api.llvm.getGlobalContext() ptr = api.llvm.IRBuilder.new(context) ptr.SetInsertPoint(basic_block._ptr) return Builder(ptr) def position_at_beginning(self, bblk): """Position the builder at the beginning of the given block. Next instruction inserted will be first one in the block.""" # Instruction list won't be long anyway, # Does not matter much to build a list of all instructions instrs = bblk.instructions if instrs: self.position_before(instrs[0]) else: self.position_at_end(bblk) def position_at_end(self, bblk): """Position the builder at the end of the given block. Next instruction inserted will be last one in the block.""" self._ptr.SetInsertPoint(bblk._ptr) def position_before(self, instr): """Position the builder before the given instruction. The instruction can belong to a basic block other than the current one.""" self._ptr.SetInsertPoint(instr._ptr) @property def basic_block(self): """The basic block where the builder is positioned.""" return _make_value(self._ptr.GetInsertBlock()) # terminator instructions def _guard_terminators(self): if __debug__: import warnings for instr in self.basic_block.instructions: if instr.is_terminator: warnings.warn("BasicBlock can only have one terminator") def ret_void(self): self._guard_terminators() return _make_value(self._ptr.CreateRetVoid()) def ret(self, value): self._guard_terminators() return _make_value(self._ptr.CreateRet(value._ptr)) def ret_many(self, values): self._guard_terminators() values = llvm._extract_ptrs(values) return _make_value(self._ptr.CreateAggregateRet(values, len(values))) def branch(self, bblk): self._guard_terminators() return _make_value(self._ptr.CreateBr(bblk._ptr)) def cbranch(self, if_value, then_blk, else_blk): self._guard_terminators() return _make_value(self._ptr.CreateCondBr(if_value._ptr, then_blk._ptr, else_blk._ptr)) def switch(self, value, else_blk, n=10): self._guard_terminators() return _make_value(self._ptr.CreateSwitch(value._ptr, else_blk._ptr, n)) def invoke(self, func, args, then_blk, catch_blk, name=""): self._guard_terminators() return _make_value(self._ptr.CreateInvoke(func._ptr, then_blk._ptr, catch_blk._ptr, llvm._extract_ptrs(args))) def unreachable(self): self._guard_terminators() return _make_value(self._ptr.CreateUnreachable()) # arithmethic, bitwise and logical def add(self, lhs, rhs, name="", nuw=False, nsw=False): return _make_value(self._ptr.CreateAdd(lhs._ptr, rhs._ptr, name, nuw, nsw)) def fadd(self, lhs, rhs, name=""): return _make_value(self._ptr.CreateFAdd(lhs._ptr, rhs._ptr, name)) def sub(self, lhs, rhs, name="", nuw=False, nsw=False): return _make_value(self._ptr.CreateSub(lhs._ptr, rhs._ptr, name, nuw, nsw)) def fsub(self, lhs, rhs, name=""): return _make_value(self._ptr.CreateFSub(lhs._ptr, rhs._ptr, name)) def mul(self, lhs, rhs, name="", nuw=False, nsw=False): return _make_value(self._ptr.CreateMul(lhs._ptr, rhs._ptr, name, nuw, nsw)) def fmul(self, lhs, rhs, name=""): return _make_value(self._ptr.CreateFMul(lhs._ptr, rhs._ptr, name)) def udiv(self, lhs, rhs, name="", exact=False): return _make_value(self._ptr.CreateUDiv(lhs._ptr, rhs._ptr, name, exact)) def sdiv(self, lhs, rhs, name="", exact=False): return _make_value(self._ptr.CreateSDiv(lhs._ptr, rhs._ptr, name, exact)) def fdiv(self, lhs, rhs, name=""): return _make_value(self._ptr.CreateFDiv(lhs._ptr, rhs._ptr, name)) def urem(self, lhs, rhs, name=""): return _make_value(self._ptr.CreateURem(lhs._ptr, rhs._ptr, name)) def srem(self, lhs, rhs, name=""): return _make_value(self._ptr.CreateSRem(lhs._ptr, rhs._ptr, name)) def frem(self, lhs, rhs, name=""): return _make_value(self._ptr.CreateFRem(lhs._ptr, rhs._ptr, name)) def shl(self, lhs, rhs, name="", nuw=False, nsw=False): return _make_value(self._ptr.CreateShl(lhs._ptr, rhs._ptr, name, nuw, nsw)) def lshr(self, lhs, rhs, name="", exact=False): return _make_value(self._ptr.CreateLShr(lhs._ptr, rhs._ptr, name, exact)) def ashr(self, lhs, rhs, name="", exact=False): return _make_value(self._ptr.CreateAShr(lhs._ptr, rhs._ptr, name, exact)) def and_(self, lhs, rhs, name=""): return _make_value(self._ptr.CreateAnd(lhs._ptr, rhs._ptr, name)) def or_(self, lhs, rhs, name=""): return _make_value(self._ptr.CreateOr(lhs._ptr, rhs._ptr, name)) def xor(self, lhs, rhs, name=""): return _make_value(self._ptr.CreateXor(lhs._ptr, rhs._ptr, name)) def neg(self, val, name="", nuw=False, nsw=False): return _make_value(self._ptr.CreateNeg(val._ptr, name, nuw, nsw)) def not_(self, val, name=""): return _make_value(self._ptr.CreateNot(val._ptr, name)) # memory def malloc(self, ty, name=""): context = api.llvm.getGlobalContext() allocsz = api.llvm.ConstantExpr.getSizeOf(ty._ptr) ity = allocsz.getType() malloc = api.llvm.CallInst.CreateMalloc(self.basic_block._ptr, ity, ty._ptr, allocsz, None, None, "") inst = self._ptr.Insert(malloc, name) return _make_value(inst) def malloc_array(self, ty, size, name=""): context = api.llvm.getGlobalContext() allocsz = api.llvm.ConstantExpr.getSizeOf(ty._ptr) ity = allocsz.getType() malloc = api.llvm.CallInst.CreateMalloc(self.basic_block._ptr, ity, ty._ptr, allocsz, size._ptr, None, "") inst = self._ptr.Insert(malloc, name) return _make_value(inst) def alloca(self, ty, name=""): intty = Type.int() return _make_value(self._ptr.CreateAlloca(ty._ptr, None, name)) def alloca_array(self, ty, size, name=""): return _make_value(self._ptr.CreateAlloca(ty._ptr, size._ptr, name)) def free(self, ptr): free = api.llvm.CallInst.CreateFree(ptr._ptr, self.basic_block._ptr) inst = self._ptr.Insert(free) return _make_value(inst) def load(self, ptr, name="", align=0, volatile=False, invariant=False): inst = _make_value(self._ptr.CreateLoad(ptr._ptr, name)) if align: inst._ptr.setAlignment(align) if volatile: inst.set_volatile(volatile) if invariant: mod = self.basic_block.function.module md = MetaData.get(mod, []) # empty metadata node inst.set_metadata('invariant.load', md) return inst def store(self, value, ptr, align=0, volatile=False): inst = _make_value(self._ptr.CreateStore(value._ptr, ptr._ptr)) if align: inst._ptr.setAlignment(align) if volatile: inst.set_volatile(volatile) return inst def gep(self, ptr, indices, name="", inbounds=False): if inbounds: ret = self._ptr.CreateInBoundsGEP(ptr._ptr, llvm._extract_ptrs(indices), name) else: ret = self._ptr.CreateGEP(ptr._ptr, llvm._extract_ptrs(indices), name) return _make_value(ret) # casts and extensions def trunc(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateTrunc(value._ptr, dest_ty._ptr, name)) def zext(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateZExt(value._ptr, dest_ty._ptr, name)) def sext(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateSExt(value._ptr, dest_ty._ptr, name)) def fptoui(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateFPToUI(value._ptr, dest_ty._ptr, name)) def fptosi(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateFPToSI(value._ptr, dest_ty._ptr, name)) def uitofp(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateUIToFP(value._ptr, dest_ty._ptr, name)) def sitofp(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateSIToFP(value._ptr, dest_ty._ptr, name)) def fptrunc(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateFPTrunc(value._ptr, dest_ty._ptr, name)) def fpext(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateFPExt(value._ptr, dest_ty._ptr, name)) def ptrtoint(self, value, dest_ty, name=""): return _make_value(self._ptr.CreatePtrToInt(value._ptr, dest_ty._ptr, name)) def inttoptr(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateIntToPtr(value._ptr, dest_ty._ptr, name)) def bitcast(self, value, dest_ty, name=""): return _make_value(self._ptr.CreateBitCast(value._ptr, dest_ty._ptr, name)) # comparisons def icmp(self, ipred, lhs, rhs, name=""): return _make_value(self._ptr.CreateICmp(ipred, lhs._ptr, rhs._ptr, name)) def fcmp(self, rpred, lhs, rhs, name=""): return _make_value(self._ptr.CreateFCmp(rpred, lhs._ptr, rhs._ptr, name)) # misc def extract_value(self, retval, idx, name=""): return _make_value(self._ptr.CreateExtractValue(retval._ptr, [idx], name)) # obsolete synonym for extract_value getresult = extract_value def insert_value(self, retval, rhs, idx, name=""): return _make_value(self._ptr.CreateInsertValue(retval._ptr, rhs._ptr, [idx], name)) def phi(self, ty, name=""): return _make_value(self._ptr.CreatePHI(ty._ptr, 2, name)) def call(self, fn, args, name=""): err_template = 'Argument type mismatch: expected %s but got %s' for i, (t, v) in enumerate(zip(fn.type.pointee.args, args)): if t != v.type: raise TypeError(err_template % (t, v.type)) arg_ptrs = llvm._extract_ptrs(args) return _make_value(self._ptr.CreateCall(fn._ptr, arg_ptrs, name)) def select(self, cond, then_value, else_value, name=""): return _make_value(self._ptr.CreateSelect(cond._ptr, then_value._ptr, else_value._ptr, name)) def vaarg(self, list_val, ty, name=""): return _make_value(self._ptr.CreateVAArg(list_val._ptr, ty._ptr, name)) def extract_element(self, vec_val, idx_val, name=""): return _make_value(self._ptr.CreateExtractElement(vec_val._ptr, idx_val._ptr, name)) def insert_element(self, vec_val, elt_val, idx_val, name=""): return _make_value(self._ptr.CreateInsertElement(vec_val._ptr, elt_val._ptr, idx_val._ptr, name)) def shuffle_vector(self, vecA, vecB, mask, name=""): return _make_value(self._ptr.CreateShuffleVector(vecA._ptr, vecB._ptr, mask._ptr, name)) # atomics def atomic_cmpxchg(self, ptr, old, new, ordering, crossthread=True): return _make_value(self._ptr.CreateAtomicCmpXchg(ptr._ptr, old._ptr, new._ptr, _atomic_orderings[ordering], _sync_scope(crossthread))) def atomic_rmw(self, op, ptr, val, ordering, crossthread=True): op_dict = dict((k.lower(), v) for k, v in vars(api.llvm.AtomicRMWInst.BinOp).items()) op = op_dict[op] return _make_value(self._ptr.CreateAtomicRMW(op, ptr._ptr, val._ptr, _atomic_orderings[ordering], _sync_scope(crossthread))) def atomic_xchg(self, *args, **kwargs): return self.atomic_rmw('xchg', *args, **kwargs) def atomic_add(self, *args, **kwargs): return self.atomic_rmw('add', *args, **kwargs) def atomic_sub(self, *args, **kwargs): return self.atomic_rmw('sub', *args, **kwargs) def atomic_and(self, *args, **kwargs): return self.atomic_rmw('and', *args, **kwargs) def atomic_nand(self, *args, **kwargs): return self.atomic_rmw('nand', *args, **kwargs) def atomic_or(self, *args, **kwargs): return self.atomic_rmw('or', *args, **kwargs) def atomic_xor(self, *args, **kwargs): return self.atomic_rmw('xor', *args, **kwargs) def atomic_max(self, *args, **kwargs): return self.atomic_rmw('max', *args, **kwargs) def atomic_min(self, *args, **kwargs): return self.atomic_rmw('min', *args, **kwargs) def atomic_umax(self, *args, **kwargs): return self.atomic_rmw('umax', *args, **kwargs) def atomic_umin(self, *args, **kwargs): return self.atomic_rmw('umin', *args, **kwargs) def atomic_load(self, ptr, ordering, align=1, crossthread=True, volatile=False, name=""): inst = self.load(ptr, align=align, volatile=volatile, name=name) inst._ptr.setAtomic(_atomic_orderings[ordering], _sync_scope(crossthread)) return inst def atomic_store(self, value, ptr, ordering, align=1, crossthread=True, volatile=False): inst = self.store(value, ptr, align=align, volatile=volatile) inst._ptr.setAtomic(_atomic_orderings[ordering], _sync_scope(crossthread)) return inst def fence(self, ordering, crossthread=True): return _make_value(self._ptr.CreateFence(_atomic_orderings[ordering], _sync_scope(crossthread))) def _sync_scope(crossthread): if crossthread: scope = api.llvm.SynchronizationScope.CrossThread else: scope = api.llvm.SynchronizationScope.SingleThread return scope def load_library_permanently(filename): """Load a shared library. Load the given shared library (filename argument specifies the full path of the .so file) using LLVM. Symbols from these are available from the execution engine thereafter.""" with contextlib.closing(BytesIO()) as errmsg: failed = api.llvm.sys.DynamicLibrary.LoadPermanentLibrary(filename, errmsg) if failed: raise llvm.LLVMException(errmsg.getvalue()) def inline_function(call): info = api.llvm.InlineFunctionInfo.new() return api.llvm.InlineFunction(call._ptr, info) def parse_environment_options(progname, envname): api.llvm.cl.ParseEnvironmentOptions(progname, envname) if api.llvm.InitializeNativeTarget(): raise llvm.LLVMException("No native target!?") if api.llvm.InitializeNativeTargetAsmPrinter(): # should this be an optional feature? # should user trigger the initialization? raise llvm.LLVMException("No native asm printer!?") if api.llvm.InitializeNativeTargetAsmParser(): # required by MCJIT? # should this be an optional feature? # should user trigger the initialization? raise llvm.LLVMException("No native asm parser!?")