diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 25b572c2d..685918981 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -193,6 +193,10 @@ class LLVMIRGenerator: else: return True + def has_sret(self, functy): + llretty = self.llty_of_type(functy.ret, for_return=True) + return self.needs_sret(llretty) + def llty_of_type(self, typ, bare=False, for_return=False): typ = typ.find() if types.is_tuple(typ): @@ -219,22 +223,14 @@ class LLVMIRGenerator: for arg in typ.optargs], return_type=llretty) - # TODO: actually mark the first argument as sret (also noalias nocapture). - # llvmlite currently does not have support for this; - # https://github.com/numba/llvmlite/issues/91. - if sretarg: - llty.__has_sret = True - else: - llty.__has_sret = False - if bare: return llty else: return ll.LiteralStructType([envarg, llty.as_pointer()]) elif types.is_method(typ): - llfuncty = self.llty_of_type(types.get_method_function(typ)) + llfunty = self.llty_of_type(types.get_method_function(typ)) llselfty = self.llty_of_type(types.get_method_self(typ)) - return ll.LiteralStructType([llfuncty, llselfty]) + return ll.LiteralStructType([llfunty, llselfty]) elif builtins.is_none(typ): if for_return: return llvoid @@ -400,8 +396,13 @@ class LLVMIRGenerator: elif isinstance(value, ir.Function): llfun = self.llmodule.get_global(value.name) if llfun is None: - llfun = ll.Function(self.llmodule, self.llty_of_type(value.type, bare=True), - value.name) + llfunty = self.llty_of_type(value.type, bare=True) + llfun = ll.Function(self.llmodule, llfunty, value.name) + + llretty = self.llty_of_type(value.type.ret, for_return=True) + if self.needs_sret(llretty): + llfun.args[0].add_attribute('sret') + return llfun else: assert False @@ -516,11 +517,7 @@ class LLVMIRGenerator: def process_function(self, func): try: - self.llfunction = self.llmodule.get_global(func.name) - - if self.llfunction is None: - llfunty = self.llty_of_type(func.type, bare=True) - self.llfunction = ll.Function(self.llmodule, llfunty, func.name) + self.llfunction = self.map(func) if func.is_internal: self.llfunction.linkage = 'internal' @@ -533,7 +530,7 @@ class LLVMIRGenerator: disubprogram = self.debug_info_emitter.emit_subprogram(func, self.llfunction) # First, map arguments. - if self.llfunction.type.pointee.__has_sret: + if self.has_sret(func.type): llactualargs = self.llfunction.args[1:] else: llactualargs = self.llfunction.args @@ -1095,7 +1092,7 @@ class LLVMIRGenerator: else: llfun, llargs = self._prepare_closure_call(insn) - if llfun.type.pointee.__has_sret: + if self.has_sret(insn.target_function().type): llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), []) llresultslot = self.llbuilder.alloca(llfun.type.pointee.args[0].pointee) @@ -1221,11 +1218,12 @@ class LLVMIRGenerator: if builtins.is_none(insn.value().type): return self.llbuilder.ret_void() else: - if self.llfunction.type.pointee.__has_sret: - self.llbuilder.store(self.map(insn.value()), self.llfunction.args[0]) + llvalue = self.map(insn.value()) + if self.needs_sret(llvalue): + self.llbuilder.store(llvalue, self.llfunction.args[0]) return self.llbuilder.ret_void() else: - return self.llbuilder.ret(self.map(insn.value())) + return self.llbuilder.ret(llvalue) def process_Unreachable(self, insn): return self.llbuilder.unreachable()