diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 197bc2cdf..7bf6037f8 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -1359,11 +1359,24 @@ class LLVMIRGenerator: else: llfun = self.map(insn.static_target_function) llenv = self.llbuilder.extract_value(llclosure, 0, name="env.fun") - return llfun, [llenv] + list(llargs), {} + return llfun, [llenv] + list(llargs), {}, None def _prepare_ffi_call(self, insn): llargs = [] llarg_attrs = {} + + stack_save_needed = False + for i, arg in enumerate(insn.arguments()): + llarg = self.map(arg) + if isinstance(llarg.type, (ll.LiteralStructType, ll.IdentifiedStructType)): + stack_save_needed = True + break + + if stack_save_needed: + llcallstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), []) + else: + llcallstackptr = None + for i, arg in enumerate(insn.arguments()): llarg = self.map(arg) if isinstance(llarg.type, (ll.LiteralStructType, ll.IdentifiedStructType)): @@ -1399,7 +1412,7 @@ class LLVMIRGenerator: if 'nowrite' in insn.target_function().type.flags: llfun.attributes.add('inaccessiblememonly') - return llfun, list(llargs), llarg_attrs + return llfun, list(llargs), llarg_attrs, llcallstackptr def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock): llservice = ll.Constant(lli32, fun_type.service) @@ -1535,9 +1548,9 @@ class LLVMIRGenerator: insn.arguments(), llnormalblock=None, llunwindblock=None) elif types.is_external_function(functiontyp): - llfun, llargs, llarg_attrs = self._prepare_ffi_call(insn) + llfun, llargs, llarg_attrs, llcallstackptr = self._prepare_ffi_call(insn) else: - llfun, llargs, llarg_attrs = self._prepare_closure_call(insn) + llfun, llargs, llarg_attrs, llcallstackptr = self._prepare_closure_call(insn) if self.has_sret(functiontyp): llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), []) @@ -1556,6 +1569,9 @@ class LLVMIRGenerator: # {} elsewhere. llresult = ll.Constant(llunit, []) + if llcallstackptr != None: + self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llcallstackptr]) + return llresult def process_Invoke(self, insn):