From 1038f1321fd73ed8866c018331d6b0e33d931f5d Mon Sep 17 00:00:00 2001 From: whitequark Date: Mon, 28 Mar 2016 21:25:40 +0000 Subject: [PATCH] compiler: allow specifying per-function "fast-math" flags. Fixes #351. --- artiq/compiler/asttyped.py | 4 +- artiq/compiler/embedding.py | 6 ++- artiq/compiler/ir.py | 3 ++ .../compiler/transforms/artiq_ir_generator.py | 6 ++- artiq/compiler/transforms/inferencer.py | 8 +++- .../compiler/transforms/llvm_ir_generator.py | 39 ++++++++++++++----- artiq/compiler/types.py | 1 + artiq/coredevice/cache.py | 4 +- artiq/coredevice/core.py | 2 +- artiq/coredevice/dds.py | 16 ++++---- artiq/coredevice/i2c.py | 10 ++--- artiq/coredevice/rtio.py | 6 +-- artiq/language/core.py | 33 +++++++++++----- artiq/test/lit/embedding/fast_math_flags.py | 20 ++++++++++ artiq/test/lit/embedding/syscall_flags.py | 2 +- 15 files changed, 114 insertions(+), 46 deletions(-) create mode 100644 artiq/test/lit/embedding/fast_math_flags.py diff --git a/artiq/compiler/asttyped.py b/artiq/compiler/asttyped.py index 77a5349ae..10b197fa4 100644 --- a/artiq/compiler/asttyped.py +++ b/artiq/compiler/asttyped.py @@ -30,7 +30,9 @@ class ClassDefT(ast.ClassDef): class FunctionDefT(ast.FunctionDef, scoped): _types = ("signature_type",) class QuotedFunctionDefT(FunctionDefT): - pass + """ + :ivar flags: (set of str) Code generation flags (see :class:`ir.Function`). + """ class ModuleT(ast.Module, scoped): pass diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 6680576ca..ffe9f3426 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -546,7 +546,7 @@ class Stitcher: value_map=self.value_map, quote_function=self._quote_function) - def _quote_embedded_function(self, function): + def _quote_embedded_function(self, function, flags): if not hasattr(function, "artiq_embedded"): raise ValueError("{} is not an embedded function".format(repr(function))) @@ -596,6 +596,7 @@ class Stitcher: globals=self.globals, host_environment=host_environment, quote=self._quote) function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function) + function_node.flags = flags # Add it into our typedtree so that it gets inferenced and codegen'd. self._inject(function_node) @@ -774,7 +775,8 @@ class Stitcher: notes=[note]) self.engine.process(diag) - self._quote_embedded_function(function) + self._quote_embedded_function(function, + flags=function.artiq_embedded.flags) elif function.artiq_embedded.syscall is not None: # Insert a storage-less global whose type instructs the compiler # to perform a system call instead of a regular call. diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 91cc88988..7dab3e3c3 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -425,6 +425,8 @@ class Function: the module it is contained in :ivar is_cold: (bool) if True, the function should be considered rarely called + :ivar flags: (set of str) Code generation flags. + Flag ``fast-math`` is the equivalent of gcc's ``-ffast-math``. """ def __init__(self, typ, name, arguments, loc=None): @@ -434,6 +436,7 @@ class Function: self.set_arguments(arguments) self.is_internal = False self.is_cold = False + self.flags = {} def _remove_name(self, name): self.names.remove(name) diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 1875c4989..5e5fd81b6 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -224,7 +224,8 @@ class ARTIQIRGenerator(algorithm.Visitor): finally: self.current_class = old_class - def visit_function(self, node, is_lambda=False, is_internal=False, is_quoted=False): + def visit_function(self, node, is_lambda=False, is_internal=False, is_quoted=False, + flags={}): if is_lambda: name = "lambda@{}:{}".format(node.loc.line(), node.loc.column()) typ = node.type.find() @@ -270,6 +271,7 @@ class ARTIQIRGenerator(algorithm.Visitor): func = ir.Function(typ, ".".join(self.name), [env_arg] + args + optargs, loc=node.lambda_loc if is_lambda else node.keyword_loc) func.is_internal = is_internal + func.flags = flags self.functions.append(func) old_func, self.current_function = self.current_function, func @@ -336,7 +338,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.append(ir.SetAttr(self.current_class, node.name, func)) def visit_QuotedFunctionDefT(self, node): - self.visit_function(node, is_internal=True, is_quoted=True) + self.visit_function(node, is_internal=True, is_quoted=True, flags=node.flags) def visit_Return(self, node): if node.value is None: diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index fc8f6212e..460be6767 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -846,6 +846,10 @@ class Inferencer(algorithm.Visitor): # An user-defined class. self._unify(node.type, typ.find().instance, node.loc, None) + elif types.is_builtin(typ, "kernel"): + # Ignored. + self._unify(node.type, builtins.TNone(), + node.loc, None) else: assert False @@ -1188,7 +1192,9 @@ class Inferencer(algorithm.Visitor): def visit_FunctionDefT(self, node): for index, decorator in enumerate(node.decorator_list): - if types.is_builtin(decorator.type, "kernel"): + if types.is_builtin(decorator.type, "kernel") or \ + isinstance(decorator, asttyped.CallT) and \ + types.is_builtin(decorator.func.type, "kernel"): continue diag = diagnostic.Diagnostic("error", diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 65dd7a47d..f25a5a9c6 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -175,6 +175,7 @@ class LLVMIRGenerator: self.llmodule = ll.Module(context=self.llcontext, name=module_name) self.llmodule.triple = target.triple self.llmodule.data_layout = target.data_layout + self.function_flags = None self.llfunction = None self.llmap = {} self.llobject_map = {} @@ -562,6 +563,7 @@ class LLVMIRGenerator: def process_function(self, func): try: + self.function_flags = func.flags self.llfunction = self.map(func) if func.is_internal: @@ -617,6 +619,7 @@ class LLVMIRGenerator: for value, block in phi.incoming(): llphi.add_incoming(self.map(value), llblock_map[block]) finally: + self.function_flags = None self.llfunction = None self.llmap = {} self.phis = [] @@ -863,40 +866,55 @@ class LLVMIRGenerator: else: assert False + def add_fast_math_flags(self, llvalue): + if 'fast-math' in self.function_flags: + llvalue.opname = llvalue.opname + ' fast' + def process_Arith(self, insn): if isinstance(insn.op, ast.Add): if builtins.is_float(insn.type): - return self.llbuilder.fadd(self.map(insn.lhs()), self.map(insn.rhs()), - name=insn.name) + llvalue = self.llbuilder.fadd(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + self.add_fast_math_flags(llvalue) + return llvalue else: return self.llbuilder.add(self.map(insn.lhs()), self.map(insn.rhs()), name=insn.name) elif isinstance(insn.op, ast.Sub): if builtins.is_float(insn.type): - return self.llbuilder.fsub(self.map(insn.lhs()), self.map(insn.rhs()), - name=insn.name) + llvalue = self.llbuilder.fsub(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + self.add_fast_math_flags(llvalue) + return llvalue else: return self.llbuilder.sub(self.map(insn.lhs()), self.map(insn.rhs()), name=insn.name) elif isinstance(insn.op, ast.Mult): if builtins.is_float(insn.type): - return self.llbuilder.fmul(self.map(insn.lhs()), self.map(insn.rhs()), - name=insn.name) + llvalue = self.llbuilder.fmul(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + self.add_fast_math_flags(llvalue) + return llvalue else: return self.llbuilder.mul(self.map(insn.lhs()), self.map(insn.rhs()), name=insn.name) elif isinstance(insn.op, ast.Div): if builtins.is_float(insn.lhs().type): - return self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()), - name=insn.name) + llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + self.add_fast_math_flags(llvalue) + return llvalue else: lllhs = self.llbuilder.sitofp(self.map(insn.lhs()), self.llty_of_type(insn.type)) llrhs = self.llbuilder.sitofp(self.map(insn.rhs()), self.llty_of_type(insn.type)) - return self.llbuilder.fdiv(lllhs, llrhs, - name=insn.name) + llvalue = self.llbuilder.fdiv(lllhs, llrhs, + name=insn.name) + self.add_fast_math_flags(llvalue) + return llvalue elif isinstance(insn.op, ast.FloorDiv): if builtins.is_float(insn.type): llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs())) + self.add_fast_math_flags(llvalue) return self.llbuilder.call(self.llbuiltin("llvm.floor.f64"), [llvalue], name=insn.name) else: @@ -906,6 +924,7 @@ class LLVMIRGenerator: # Python only has the modulo operator, LLVM only has the remainder if builtins.is_float(insn.type): llvalue = self.llbuilder.frem(self.map(insn.lhs()), self.map(insn.rhs())) + self.add_fast_math_flags(llvalue) return self.llbuilder.call(self.llbuiltin("llvm.copysign.f64"), [llvalue, self.map(insn.rhs())], name=insn.name) diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py index 42bd5b6a7..26fc24d95 100644 --- a/artiq/compiler/types.py +++ b/artiq/compiler/types.py @@ -351,6 +351,7 @@ class TCFunction(TFunction): attributes = OrderedDict() def __init__(self, args, ret, name, flags={}): + assert isinstance(flags, set) for flag in flags: assert flag in {'nounwind', 'nowrite'} super().__init__(args, OrderedDict(), ret) diff --git a/artiq/coredevice/cache.py b/artiq/coredevice/cache.py index e0ca77a13..7b01e96b9 100644 --- a/artiq/coredevice/cache.py +++ b/artiq/coredevice/cache.py @@ -2,11 +2,11 @@ from artiq.language.core import * from artiq.language.types import * -@syscall("cache_get", flags={"nounwind", "nowrite"}) +@syscall(flags={"nounwind", "nowrite"}) def cache_get(key: TStr) -> TList(TInt32): raise NotImplementedError("syscall not simulated") -@syscall("cache_put", flags={"nowrite"}) +@syscall(flags={"nowrite"}) def cache_put(key: TStr, value: TList(TInt32)) -> TNone: raise NotImplementedError("syscall not simulated") diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index c9e69f593..1b3e10fdb 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -37,7 +37,7 @@ class CompileError(Exception): return "\n" + _render_diagnostic(self.diagnostic, colored=colors_supported) -@syscall("rtio_get_counter", flags={"nounwind", "nowrite"}) +@syscall(flags={"nounwind", "nowrite"}) def rtio_get_counter() -> TInt64: raise NotImplementedError("syscall not simulated") diff --git a/artiq/coredevice/dds.py b/artiq/coredevice/dds.py index 2af15175b..c548ab28e 100644 --- a/artiq/coredevice/dds.py +++ b/artiq/coredevice/dds.py @@ -10,20 +10,20 @@ PHASE_MODE_ABSOLUTE = 1 PHASE_MODE_TRACKING = 2 -@syscall("dds_init", flags={"nowrite"}) +@syscall(flags={"nowrite"}) def dds_init(time_mu: TInt64, bus_channel: TInt32, channel: TInt32) -> TNone: raise NotImplementedError("syscall not simulated") -@syscall("dds_set", flags={"nowrite"}) +@syscall(flags={"nowrite"}) def dds_set(time_mu: TInt64, bus_channel: TInt32, channel: TInt32, ftw: TInt32, pow: TInt32, phase_mode: TInt32, amplitude: TInt32) -> TNone: raise NotImplementedError("syscall not simulated") -@syscall("dds_batch_enter", flags={"nowrite"}) +@syscall(flags={"nowrite"}) def dds_batch_enter(time_mu: TInt64) -> TNone: raise NotImplementedError("syscall not simulated") -@syscall("dds_batch_exit", flags={"nowrite"}) +@syscall(flags={"nowrite"}) def dds_batch_exit() -> TNone: raise NotImplementedError("syscall not simulated") @@ -99,27 +99,27 @@ class _DDSGeneric: self.channel = channel self.phase_mode = PHASE_MODE_CONTINUOUS - @portable + @portable(flags=["fast-math"]) def frequency_to_ftw(self, frequency): """Returns the frequency tuning word corresponding to the given frequency. """ return round(int(2, width=64)**32*frequency/self.core_dds.sysclk) - @portable + @portable(flags=["fast-math"]) def ftw_to_frequency(self, ftw): """Returns the frequency corresponding to the given frequency tuning word. """ return ftw*self.core_dds.sysclk/int(2, width=64)**32 - @portable + @portable(flags=["fast-math"]) def turns_to_pow(self, turns): """Returns the phase offset word corresponding to the given phase in turns.""" return round(turns*2**self.pow_width) - @portable + @portable(flags=["fast-math"]) def pow_to_turns(self, pow): """Returns the phase in turns corresponding to the given phase offset word.""" diff --git a/artiq/coredevice/i2c.py b/artiq/coredevice/i2c.py index 0123b21c4..fb37a2aa6 100644 --- a/artiq/coredevice/i2c.py +++ b/artiq/coredevice/i2c.py @@ -3,27 +3,27 @@ from artiq.language.types import TBool, TInt32, TNone from artiq.coredevice.exceptions import I2CError -@syscall("i2c_init", flags={"nowrite"}) +@syscall(flags={"nowrite"}) def i2c_init(busno: TInt32) -> TNone: raise NotImplementedError("syscall not simulated") -@syscall("i2c_start", flags={"nounwind", "nowrite"}) +@syscall(flags={"nounwind", "nowrite"}) def i2c_start(busno: TInt32) -> TNone: raise NotImplementedError("syscall not simulated") -@syscall("i2c_stop", flags={"nounwind", "nowrite"}) +@syscall(flags={"nounwind", "nowrite"}) def i2c_stop(busno: TInt32) -> TNone: raise NotImplementedError("syscall not simulated") -@syscall("i2c_write", flags={"nounwind", "nowrite"}) +@syscall(flags={"nounwind", "nowrite"}) def i2c_write(busno: TInt32, b: TInt32) -> TBool: raise NotImplementedError("syscall not simulated") -@syscall("i2c_read", flags={"nounwind", "nowrite"}) +@syscall(flags={"nounwind", "nowrite"}) def i2c_read(busno: TInt32, ack: TBool) -> TInt32: raise NotImplementedError("syscall not simulated") diff --git a/artiq/coredevice/rtio.py b/artiq/coredevice/rtio.py index 5ad71e5e3..f4d0e2c82 100644 --- a/artiq/coredevice/rtio.py +++ b/artiq/coredevice/rtio.py @@ -2,17 +2,17 @@ from artiq.language.core import syscall from artiq.language.types import TInt64, TInt32, TNone -@syscall("rtio_output", flags={"nowrite"}) +@syscall(flags={"nowrite"}) def rtio_output(time_mu: TInt64, channel: TInt32, addr: TInt32, data: TInt32 ) -> TNone: raise NotImplementedError("syscall not simulated") -@syscall("rtio_input_timestamp", flags={"nowrite"}) +@syscall(flags={"nowrite"}) def rtio_input_timestamp(timeout_mu: TInt64, channel: TInt32) -> TInt64: raise NotImplementedError("syscall not simulated") -@syscall("rtio_input_data", flags={"nowrite"}) +@syscall(flags={"nowrite"}) def rtio_input_data(channel: TInt32) -> TInt32: raise NotImplementedError("syscall not simulated") diff --git a/artiq/language/core.py b/artiq/language/core.py index fac508676..c8a15c00f 100644 --- a/artiq/language/core.py +++ b/artiq/language/core.py @@ -165,7 +165,7 @@ def round(value, width=32): _ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo", "core_name function syscall forbidden flags") -def kernel(arg): +def kernel(arg=None, flags={}): """ This decorator marks an object's method for execution on the core device. @@ -192,13 +192,17 @@ def kernel(arg): return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs) run_on_core.artiq_embedded = _ARTIQEmbeddedInfo( core_name=arg, function=function, syscall=None, - forbidden=False, flags={}) + forbidden=False, flags=set(flags)) return run_on_core return inner_decorator + elif arg is None: + def inner_decorator(function): + return kernel(function, flags) + return inner_decorator else: - return kernel("core")(arg) + return kernel("core", flags)(arg) -def portable(function): +def portable(arg=None, flags={}): """ This decorator marks a function for execution on the same device as its caller. @@ -208,12 +212,17 @@ def portable(function): core device). A decorated function called from a kernel will be executed on the core device (no RPC). """ - function.artiq_embedded = \ - _ARTIQEmbeddedInfo(core_name=None, function=function, syscall=None, - forbidden=False, flags={}) - return function + if arg is None: + def inner_decorator(function): + return portable(function, flags) + return inner_decorator + else: + arg.artiq_embedded = \ + _ARTIQEmbeddedInfo(core_name=None, function=arg, syscall=None, + forbidden=False, flags=set(flags)) + return arg -def syscall(arg, flags={}): +def syscall(arg=None, flags={}): """ This decorator marks a function as a system call. When executed on a core device, a C function with the provided name (or the same name as @@ -229,9 +238,13 @@ def syscall(arg, flags={}): function.artiq_embedded = \ _ARTIQEmbeddedInfo(core_name=None, function=None, syscall=function.__name__, forbidden=False, - flags=flags) + flags=set(flags)) return function return inner_decorator + elif arg is None: + def inner_decorator(function): + return syscall(function.__name__, flags)(function) + return inner_decorator else: return syscall(arg.__name__)(arg) diff --git a/artiq/test/lit/embedding/fast_math_flags.py b/artiq/test/lit/embedding/fast_math_flags.py new file mode 100644 index 000000000..3ed40f70d --- /dev/null +++ b/artiq/test/lit/embedding/fast_math_flags.py @@ -0,0 +1,20 @@ +# RUN: env ARTIQ_DUMP_UNOPT_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s +# RUN: OutputCheck %s --file-to-check=%t_unopt.ll + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK-L: fmul fast double 1.000000e+00, 0.000000e+00 +@kernel(flags=["fast-math"]) +def foo(): + core_log(1.0 * 0.0) + +# CHECK-L: fmul fast double 2.000000e+00, 0.000000e+00 +@portable(flags=["fast-math"]) +def bar(): + core_log(2.0 * 0.0) + +@kernel +def entrypoint(): + foo() + bar() diff --git a/artiq/test/lit/embedding/syscall_flags.py b/artiq/test/lit/embedding/syscall_flags.py index aa4398fc7..15636507f 100644 --- a/artiq/test/lit/embedding/syscall_flags.py +++ b/artiq/test/lit/embedding/syscall_flags.py @@ -9,7 +9,7 @@ from artiq.language.types import * # CHECK-L: ; Function Attrs: nounwind # CHECK-NEXT-L: declare void @foo() -@syscall("foo", flags={"nounwind", "nowrite"}) +@syscall(flags={"nounwind", "nowrite"}) def foo() -> TNone: pass