transforms.llvm_ir_generator: use sret attribute.

This commit is contained in:
whitequark 2016-01-10 17:31:59 +00:00
parent edb7423a4f
commit 7f914a057c
1 changed files with 20 additions and 22 deletions

View File

@ -193,6 +193,10 @@ class LLVMIRGenerator:
else: else:
return True return True
def has_sret(self, functy):
llretty = self.llty_of_type(functy.ret, for_return=True)
return self.needs_sret(llretty)
def llty_of_type(self, typ, bare=False, for_return=False): def llty_of_type(self, typ, bare=False, for_return=False):
typ = typ.find() typ = typ.find()
if types.is_tuple(typ): if types.is_tuple(typ):
@ -219,22 +223,14 @@ class LLVMIRGenerator:
for arg in typ.optargs], for arg in typ.optargs],
return_type=llretty) return_type=llretty)
# TODO: actually mark the first argument as sret (also noalias nocapture).
# llvmlite currently does not have support for this;
# https://github.com/numba/llvmlite/issues/91.
if sretarg:
llty.__has_sret = True
else:
llty.__has_sret = False
if bare: if bare:
return llty return llty
else: else:
return ll.LiteralStructType([envarg, llty.as_pointer()]) return ll.LiteralStructType([envarg, llty.as_pointer()])
elif types.is_method(typ): elif types.is_method(typ):
llfuncty = self.llty_of_type(types.get_method_function(typ)) llfunty = self.llty_of_type(types.get_method_function(typ))
llselfty = self.llty_of_type(types.get_method_self(typ)) llselfty = self.llty_of_type(types.get_method_self(typ))
return ll.LiteralStructType([llfuncty, llselfty]) return ll.LiteralStructType([llfunty, llselfty])
elif builtins.is_none(typ): elif builtins.is_none(typ):
if for_return: if for_return:
return llvoid return llvoid
@ -400,8 +396,13 @@ class LLVMIRGenerator:
elif isinstance(value, ir.Function): elif isinstance(value, ir.Function):
llfun = self.llmodule.get_global(value.name) llfun = self.llmodule.get_global(value.name)
if llfun is None: if llfun is None:
llfun = ll.Function(self.llmodule, self.llty_of_type(value.type, bare=True), llfunty = self.llty_of_type(value.type, bare=True)
value.name) llfun = ll.Function(self.llmodule, llfunty, value.name)
llretty = self.llty_of_type(value.type.ret, for_return=True)
if self.needs_sret(llretty):
llfun.args[0].add_attribute('sret')
return llfun return llfun
else: else:
assert False assert False
@ -516,11 +517,7 @@ class LLVMIRGenerator:
def process_function(self, func): def process_function(self, func):
try: try:
self.llfunction = self.llmodule.get_global(func.name) self.llfunction = self.map(func)
if self.llfunction is None:
llfunty = self.llty_of_type(func.type, bare=True)
self.llfunction = ll.Function(self.llmodule, llfunty, func.name)
if func.is_internal: if func.is_internal:
self.llfunction.linkage = 'internal' self.llfunction.linkage = 'internal'
@ -533,7 +530,7 @@ class LLVMIRGenerator:
disubprogram = self.debug_info_emitter.emit_subprogram(func, self.llfunction) disubprogram = self.debug_info_emitter.emit_subprogram(func, self.llfunction)
# First, map arguments. # First, map arguments.
if self.llfunction.type.pointee.__has_sret: if self.has_sret(func.type):
llactualargs = self.llfunction.args[1:] llactualargs = self.llfunction.args[1:]
else: else:
llactualargs = self.llfunction.args llactualargs = self.llfunction.args
@ -1095,7 +1092,7 @@ class LLVMIRGenerator:
else: else:
llfun, llargs = self._prepare_closure_call(insn) llfun, llargs = self._prepare_closure_call(insn)
if llfun.type.pointee.__has_sret: if self.has_sret(insn.target_function().type):
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), []) llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [])
llresultslot = self.llbuilder.alloca(llfun.type.pointee.args[0].pointee) llresultslot = self.llbuilder.alloca(llfun.type.pointee.args[0].pointee)
@ -1221,11 +1218,12 @@ class LLVMIRGenerator:
if builtins.is_none(insn.value().type): if builtins.is_none(insn.value().type):
return self.llbuilder.ret_void() return self.llbuilder.ret_void()
else: else:
if self.llfunction.type.pointee.__has_sret: llvalue = self.map(insn.value())
self.llbuilder.store(self.map(insn.value()), self.llfunction.args[0]) if self.needs_sret(llvalue):
self.llbuilder.store(llvalue, self.llfunction.args[0])
return self.llbuilder.ret_void() return self.llbuilder.ret_void()
else: else:
return self.llbuilder.ret(self.map(insn.value())) return self.llbuilder.ret(llvalue)
def process_Unreachable(self, insn): def process_Unreachable(self, insn):
return self.llbuilder.unreachable() return self.llbuilder.unreachable()