LLVMIRGenerator: use sret when returning large structures.

This commit is contained in:
whitequark 2015-08-19 15:06:03 -07:00
parent 673512f356
commit 27a697920a
1 changed files with 53 additions and 5 deletions

View File

@ -171,6 +171,21 @@ class LLVMIRGenerator:
self.phis = []
self.debug_info_emitter = DebugInfoEmitter(self.llmodule)
def needs_sret(self, lltyp, may_be_large=True):
if isinstance(lltyp, ll.VoidType):
return False
elif isinstance(lltyp, ll.IntType) and lltyp.width <= 32:
return False
elif isinstance(lltyp, ll.PointerType):
return False
elif may_be_large and isinstance(lltyp, ll.DoubleType):
return False
elif may_be_large and isinstance(lltyp, ll.LiteralStructType) \
and len(lltyp.elements) <= 2:
return not any([self.needs_sret(elt, may_be_large=False) for elt in lltyp.elements])
else:
return True
def llty_of_type(self, typ, bare=False, for_return=False):
typ = typ.find()
if types.is_tuple(typ):
@ -183,13 +198,28 @@ class LLVMIRGenerator:
elif types._is_pointer(typ):
return llptr
elif types.is_function(typ):
sretarg = []
llretty = self.llty_of_type(typ.ret, for_return=True)
if self.needs_sret(llretty):
sretarg = [llretty.as_pointer()]
llretty = llvoid
envarg = llptr
llty = ll.FunctionType(args=[envarg] +
llty = ll.FunctionType(args=sretarg + [envarg] +
[self.llty_of_type(typ.args[arg])
for arg in typ.args] +
[self.llty_of_type(ir.TOption(typ.optargs[arg]))
for arg in typ.optargs],
return_type=self.llty_of_type(typ.ret, for_return=True))
return_type=llretty)
# TODO: actually mark the first argument as sret (also noalias nocapture).
# llvmlite currently does not have support for this;
# https://github.com/numba/llvmlite/issues/91.
if sretarg:
llty.__has_sret = True
else:
llty.__has_sret = False
if bare:
return llty
else:
@ -896,6 +926,20 @@ class LLVMIRGenerator:
name=insn.name)
else:
llfun, llargs = self._prepare_closure_call(insn)
if llfun.type.pointee.__has_sret:
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [])
llresultslot = self.llbuilder.alloca(llfun.type.pointee.args[0].pointee)
print(llfun)
print(llresultslot)
self.llbuilder.call(llfun, [llresultslot] + llargs)
llresult = self.llbuilder.load(llresultslot)
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
return llresult
else:
return self.llbuilder.call(llfun, llargs,
name=insn.name)
@ -936,6 +980,10 @@ class LLVMIRGenerator:
def process_Return(self, insn):
if builtins.is_none(insn.value().type):
return self.llbuilder.ret_void()
else:
if self.llfunction.type.pointee.__has_sret:
self.llbuilder.store(self.map(insn.value()), self.llfunction.args[0])
return self.llbuilder.ret_void()
else:
return self.llbuilder.ret(self.map(insn.value()))