diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 660cf6750..2956020c8 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -171,6 +171,21 @@ class LLVMIRGenerator: self.phis = [] self.debug_info_emitter = DebugInfoEmitter(self.llmodule) + def needs_sret(self, lltyp, may_be_large=True): + if isinstance(lltyp, ll.VoidType): + return False + elif isinstance(lltyp, ll.IntType) and lltyp.width <= 32: + return False + elif isinstance(lltyp, ll.PointerType): + return False + elif may_be_large and isinstance(lltyp, ll.DoubleType): + return False + elif may_be_large and isinstance(lltyp, ll.LiteralStructType) \ + and len(lltyp.elements) <= 2: + return not any([self.needs_sret(elt, may_be_large=False) for elt in lltyp.elements]) + else: + return True + def llty_of_type(self, typ, bare=False, for_return=False): typ = typ.find() if types.is_tuple(typ): @@ -183,13 +198,28 @@ class LLVMIRGenerator: elif types._is_pointer(typ): return llptr elif types.is_function(typ): + sretarg = [] + llretty = self.llty_of_type(typ.ret, for_return=True) + if self.needs_sret(llretty): + sretarg = [llretty.as_pointer()] + llretty = llvoid + envarg = llptr - llty = ll.FunctionType(args=[envarg] + + llty = ll.FunctionType(args=sretarg + [envarg] + [self.llty_of_type(typ.args[arg]) for arg in typ.args] + [self.llty_of_type(ir.TOption(typ.optargs[arg])) for arg in typ.optargs], - return_type=self.llty_of_type(typ.ret, for_return=True)) + return_type=llretty) + + # TODO: actually mark the first argument as sret (also noalias nocapture). + # llvmlite currently does not have support for this; + # https://github.com/numba/llvmlite/issues/91. + if sretarg: + llty.__has_sret = True + else: + llty.__has_sret = False + if bare: return llty else: @@ -896,8 +926,22 @@ class LLVMIRGenerator: name=insn.name) else: llfun, llargs = self._prepare_closure_call(insn) - return self.llbuilder.call(llfun, llargs, - name=insn.name) + + if llfun.type.pointee.__has_sret: + llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), []) + + llresultslot = self.llbuilder.alloca(llfun.type.pointee.args[0].pointee) + print(llfun) + print(llresultslot) + self.llbuilder.call(llfun, [llresultslot] + llargs) + llresult = self.llbuilder.load(llresultslot) + + self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) + + return llresult + else: + return self.llbuilder.call(llfun, llargs, + name=insn.name) def process_Invoke(self, insn): llnormalblock = self.map(insn.normal_target()) @@ -937,7 +981,11 @@ class LLVMIRGenerator: if builtins.is_none(insn.value().type): return self.llbuilder.ret_void() else: - return self.llbuilder.ret(self.map(insn.value())) + if self.llfunction.type.pointee.__has_sret: + self.llbuilder.store(self.map(insn.value()), self.llfunction.args[0]) + return self.llbuilder.ret_void() + else: + return self.llbuilder.ret(self.map(insn.value())) def process_Unreachable(self, insn): return self.llbuilder.unreachable()