forked from M-Labs/artiq
1
0
Fork 0

compiler: allow specifying per-function "fast-math" flags.

Fixes #351.
This commit is contained in:
whitequark 2016-03-28 21:25:40 +00:00
parent f31249ad1c
commit 1038f1321f
15 changed files with 114 additions and 46 deletions

View File

@ -30,7 +30,9 @@ class ClassDefT(ast.ClassDef):
class FunctionDefT(ast.FunctionDef, scoped):
_types = ("signature_type",)
class QuotedFunctionDefT(FunctionDefT):
pass
"""
:ivar flags: (set of str) Code generation flags (see :class:`ir.Function`).
"""
class ModuleT(ast.Module, scoped):
pass

View File

@ -546,7 +546,7 @@ class Stitcher:
value_map=self.value_map,
quote_function=self._quote_function)
def _quote_embedded_function(self, function):
def _quote_embedded_function(self, function, flags):
if not hasattr(function, "artiq_embedded"):
raise ValueError("{} is not an embedded function".format(repr(function)))
@ -596,6 +596,7 @@ class Stitcher:
globals=self.globals, host_environment=host_environment,
quote=self._quote)
function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function)
function_node.flags = flags
# Add it into our typedtree so that it gets inferenced and codegen'd.
self._inject(function_node)
@ -774,7 +775,8 @@ class Stitcher:
notes=[note])
self.engine.process(diag)
self._quote_embedded_function(function)
self._quote_embedded_function(function,
flags=function.artiq_embedded.flags)
elif function.artiq_embedded.syscall is not None:
# Insert a storage-less global whose type instructs the compiler
# to perform a system call instead of a regular call.

View File

@ -425,6 +425,8 @@ class Function:
the module it is contained in
:ivar is_cold:
(bool) if True, the function should be considered rarely called
:ivar flags: (set of str) Code generation flags.
Flag ``fast-math`` is the equivalent of gcc's ``-ffast-math``.
"""
def __init__(self, typ, name, arguments, loc=None):
@ -434,6 +436,7 @@ class Function:
self.set_arguments(arguments)
self.is_internal = False
self.is_cold = False
self.flags = {}
def _remove_name(self, name):
self.names.remove(name)

View File

@ -224,7 +224,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
finally:
self.current_class = old_class
def visit_function(self, node, is_lambda=False, is_internal=False, is_quoted=False):
def visit_function(self, node, is_lambda=False, is_internal=False, is_quoted=False,
flags={}):
if is_lambda:
name = "lambda@{}:{}".format(node.loc.line(), node.loc.column())
typ = node.type.find()
@ -270,6 +271,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
func = ir.Function(typ, ".".join(self.name), [env_arg] + args + optargs,
loc=node.lambda_loc if is_lambda else node.keyword_loc)
func.is_internal = is_internal
func.flags = flags
self.functions.append(func)
old_func, self.current_function = self.current_function, func
@ -336,7 +338,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.append(ir.SetAttr(self.current_class, node.name, func))
def visit_QuotedFunctionDefT(self, node):
self.visit_function(node, is_internal=True, is_quoted=True)
self.visit_function(node, is_internal=True, is_quoted=True, flags=node.flags)
def visit_Return(self, node):
if node.value is None:

View File

@ -846,6 +846,10 @@ class Inferencer(algorithm.Visitor):
# An user-defined class.
self._unify(node.type, typ.find().instance,
node.loc, None)
elif types.is_builtin(typ, "kernel"):
# Ignored.
self._unify(node.type, builtins.TNone(),
node.loc, None)
else:
assert False
@ -1188,7 +1192,9 @@ class Inferencer(algorithm.Visitor):
def visit_FunctionDefT(self, node):
for index, decorator in enumerate(node.decorator_list):
if types.is_builtin(decorator.type, "kernel"):
if types.is_builtin(decorator.type, "kernel") or \
isinstance(decorator, asttyped.CallT) and \
types.is_builtin(decorator.func.type, "kernel"):
continue
diag = diagnostic.Diagnostic("error",

View File

@ -175,6 +175,7 @@ class LLVMIRGenerator:
self.llmodule = ll.Module(context=self.llcontext, name=module_name)
self.llmodule.triple = target.triple
self.llmodule.data_layout = target.data_layout
self.function_flags = None
self.llfunction = None
self.llmap = {}
self.llobject_map = {}
@ -562,6 +563,7 @@ class LLVMIRGenerator:
def process_function(self, func):
try:
self.function_flags = func.flags
self.llfunction = self.map(func)
if func.is_internal:
@ -617,6 +619,7 @@ class LLVMIRGenerator:
for value, block in phi.incoming():
llphi.add_incoming(self.map(value), llblock_map[block])
finally:
self.function_flags = None
self.llfunction = None
self.llmap = {}
self.phis = []
@ -863,40 +866,55 @@ class LLVMIRGenerator:
else:
assert False
def add_fast_math_flags(self, llvalue):
if 'fast-math' in self.function_flags:
llvalue.opname = llvalue.opname + ' fast'
def process_Arith(self, insn):
if isinstance(insn.op, ast.Add):
if builtins.is_float(insn.type):
return self.llbuilder.fadd(self.map(insn.lhs()), self.map(insn.rhs()),
llvalue = self.llbuilder.fadd(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name)
self.add_fast_math_flags(llvalue)
return llvalue
else:
return self.llbuilder.add(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name)
elif isinstance(insn.op, ast.Sub):
if builtins.is_float(insn.type):
return self.llbuilder.fsub(self.map(insn.lhs()), self.map(insn.rhs()),
llvalue = self.llbuilder.fsub(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name)
self.add_fast_math_flags(llvalue)
return llvalue
else:
return self.llbuilder.sub(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name)
elif isinstance(insn.op, ast.Mult):
if builtins.is_float(insn.type):
return self.llbuilder.fmul(self.map(insn.lhs()), self.map(insn.rhs()),
llvalue = self.llbuilder.fmul(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name)
self.add_fast_math_flags(llvalue)
return llvalue
else:
return self.llbuilder.mul(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name)
elif isinstance(insn.op, ast.Div):
if builtins.is_float(insn.lhs().type):
return self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()),
llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name)
self.add_fast_math_flags(llvalue)
return llvalue
else:
lllhs = self.llbuilder.sitofp(self.map(insn.lhs()), self.llty_of_type(insn.type))
llrhs = self.llbuilder.sitofp(self.map(insn.rhs()), self.llty_of_type(insn.type))
return self.llbuilder.fdiv(lllhs, llrhs,
llvalue = self.llbuilder.fdiv(lllhs, llrhs,
name=insn.name)
self.add_fast_math_flags(llvalue)
return llvalue
elif isinstance(insn.op, ast.FloorDiv):
if builtins.is_float(insn.type):
llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()))
self.add_fast_math_flags(llvalue)
return self.llbuilder.call(self.llbuiltin("llvm.floor.f64"), [llvalue],
name=insn.name)
else:
@ -906,6 +924,7 @@ class LLVMIRGenerator:
# Python only has the modulo operator, LLVM only has the remainder
if builtins.is_float(insn.type):
llvalue = self.llbuilder.frem(self.map(insn.lhs()), self.map(insn.rhs()))
self.add_fast_math_flags(llvalue)
return self.llbuilder.call(self.llbuiltin("llvm.copysign.f64"),
[llvalue, self.map(insn.rhs())],
name=insn.name)

View File

@ -351,6 +351,7 @@ class TCFunction(TFunction):
attributes = OrderedDict()
def __init__(self, args, ret, name, flags={}):
assert isinstance(flags, set)
for flag in flags:
assert flag in {'nounwind', 'nowrite'}
super().__init__(args, OrderedDict(), ret)

View File

@ -2,11 +2,11 @@ from artiq.language.core import *
from artiq.language.types import *
@syscall("cache_get", flags={"nounwind", "nowrite"})
@syscall(flags={"nounwind", "nowrite"})
def cache_get(key: TStr) -> TList(TInt32):
raise NotImplementedError("syscall not simulated")
@syscall("cache_put", flags={"nowrite"})
@syscall(flags={"nowrite"})
def cache_put(key: TStr, value: TList(TInt32)) -> TNone:
raise NotImplementedError("syscall not simulated")

View File

@ -37,7 +37,7 @@ class CompileError(Exception):
return "\n" + _render_diagnostic(self.diagnostic, colored=colors_supported)
@syscall("rtio_get_counter", flags={"nounwind", "nowrite"})
@syscall(flags={"nounwind", "nowrite"})
def rtio_get_counter() -> TInt64:
raise NotImplementedError("syscall not simulated")

View File

@ -10,20 +10,20 @@ PHASE_MODE_ABSOLUTE = 1
PHASE_MODE_TRACKING = 2
@syscall("dds_init", flags={"nowrite"})
@syscall(flags={"nowrite"})
def dds_init(time_mu: TInt64, bus_channel: TInt32, channel: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated")
@syscall("dds_set", flags={"nowrite"})
@syscall(flags={"nowrite"})
def dds_set(time_mu: TInt64, bus_channel: TInt32, channel: TInt32, ftw: TInt32,
pow: TInt32, phase_mode: TInt32, amplitude: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated")
@syscall("dds_batch_enter", flags={"nowrite"})
@syscall(flags={"nowrite"})
def dds_batch_enter(time_mu: TInt64) -> TNone:
raise NotImplementedError("syscall not simulated")
@syscall("dds_batch_exit", flags={"nowrite"})
@syscall(flags={"nowrite"})
def dds_batch_exit() -> TNone:
raise NotImplementedError("syscall not simulated")
@ -99,27 +99,27 @@ class _DDSGeneric:
self.channel = channel
self.phase_mode = PHASE_MODE_CONTINUOUS
@portable
@portable(flags=["fast-math"])
def frequency_to_ftw(self, frequency):
"""Returns the frequency tuning word corresponding to the given
frequency.
"""
return round(int(2, width=64)**32*frequency/self.core_dds.sysclk)
@portable
@portable(flags=["fast-math"])
def ftw_to_frequency(self, ftw):
"""Returns the frequency corresponding to the given frequency tuning
word.
"""
return ftw*self.core_dds.sysclk/int(2, width=64)**32
@portable
@portable(flags=["fast-math"])
def turns_to_pow(self, turns):
"""Returns the phase offset word corresponding to the given phase
in turns."""
return round(turns*2**self.pow_width)
@portable
@portable(flags=["fast-math"])
def pow_to_turns(self, pow):
"""Returns the phase in turns corresponding to the given phase offset
word."""

View File

@ -3,27 +3,27 @@ from artiq.language.types import TBool, TInt32, TNone
from artiq.coredevice.exceptions import I2CError
@syscall("i2c_init", flags={"nowrite"})
@syscall(flags={"nowrite"})
def i2c_init(busno: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated")
@syscall("i2c_start", flags={"nounwind", "nowrite"})
@syscall(flags={"nounwind", "nowrite"})
def i2c_start(busno: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated")
@syscall("i2c_stop", flags={"nounwind", "nowrite"})
@syscall(flags={"nounwind", "nowrite"})
def i2c_stop(busno: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated")
@syscall("i2c_write", flags={"nounwind", "nowrite"})
@syscall(flags={"nounwind", "nowrite"})
def i2c_write(busno: TInt32, b: TInt32) -> TBool:
raise NotImplementedError("syscall not simulated")
@syscall("i2c_read", flags={"nounwind", "nowrite"})
@syscall(flags={"nounwind", "nowrite"})
def i2c_read(busno: TInt32, ack: TBool) -> TInt32:
raise NotImplementedError("syscall not simulated")

View File

@ -2,17 +2,17 @@ from artiq.language.core import syscall
from artiq.language.types import TInt64, TInt32, TNone
@syscall("rtio_output", flags={"nowrite"})
@syscall(flags={"nowrite"})
def rtio_output(time_mu: TInt64, channel: TInt32, addr: TInt32, data: TInt32
) -> TNone:
raise NotImplementedError("syscall not simulated")
@syscall("rtio_input_timestamp", flags={"nowrite"})
@syscall(flags={"nowrite"})
def rtio_input_timestamp(timeout_mu: TInt64, channel: TInt32) -> TInt64:
raise NotImplementedError("syscall not simulated")
@syscall("rtio_input_data", flags={"nowrite"})
@syscall(flags={"nowrite"})
def rtio_input_data(channel: TInt32) -> TInt32:
raise NotImplementedError("syscall not simulated")

View File

@ -165,7 +165,7 @@ def round(value, width=32):
_ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo",
"core_name function syscall forbidden flags")
def kernel(arg):
def kernel(arg=None, flags={}):
"""
This decorator marks an object's method for execution on the core
device.
@ -192,13 +192,17 @@ def kernel(arg):
return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs)
run_on_core.artiq_embedded = _ARTIQEmbeddedInfo(
core_name=arg, function=function, syscall=None,
forbidden=False, flags={})
forbidden=False, flags=set(flags))
return run_on_core
return inner_decorator
elif arg is None:
def inner_decorator(function):
return kernel(function, flags)
return inner_decorator
else:
return kernel("core")(arg)
return kernel("core", flags)(arg)
def portable(function):
def portable(arg=None, flags={}):
"""
This decorator marks a function for execution on the same device as its
caller.
@ -208,12 +212,17 @@ def portable(function):
core device). A decorated function called from a kernel will be executed
on the core device (no RPC).
"""
function.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, function=function, syscall=None,
forbidden=False, flags={})
return function
if arg is None:
def inner_decorator(function):
return portable(function, flags)
return inner_decorator
else:
arg.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, function=arg, syscall=None,
forbidden=False, flags=set(flags))
return arg
def syscall(arg, flags={}):
def syscall(arg=None, flags={}):
"""
This decorator marks a function as a system call. When executed on a core
device, a C function with the provided name (or the same name as
@ -229,9 +238,13 @@ def syscall(arg, flags={}):
function.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, function=None,
syscall=function.__name__, forbidden=False,
flags=flags)
flags=set(flags))
return function
return inner_decorator
elif arg is None:
def inner_decorator(function):
return syscall(function.__name__, flags)(function)
return inner_decorator
else:
return syscall(arg.__name__)(arg)

View File

@ -0,0 +1,20 @@
# RUN: env ARTIQ_DUMP_UNOPT_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
# RUN: OutputCheck %s --file-to-check=%t_unopt.ll
from artiq.language.core import *
from artiq.language.types import *
# CHECK-L: fmul fast double 1.000000e+00, 0.000000e+00
@kernel(flags=["fast-math"])
def foo():
core_log(1.0 * 0.0)
# CHECK-L: fmul fast double 2.000000e+00, 0.000000e+00
@portable(flags=["fast-math"])
def bar():
core_log(2.0 * 0.0)
@kernel
def entrypoint():
foo()
bar()

View File

@ -9,7 +9,7 @@ from artiq.language.types import *
# CHECK-L: ; Function Attrs: nounwind
# CHECK-NEXT-L: declare void @foo()
@syscall("foo", flags={"nounwind", "nowrite"})
@syscall(flags={"nounwind", "nowrite"})
def foo() -> TNone:
pass