compiler: allow specifying per-function "fast-math" flags.

Fixes #351.
This commit is contained in:
whitequark 2016-03-28 21:25:40 +00:00 committed by Sebastien Bourdeauducq
parent 5fafcc1341
commit ee7e648cb0
15 changed files with 114 additions and 46 deletions

View File

@ -30,7 +30,9 @@ class ClassDefT(ast.ClassDef):
class FunctionDefT(ast.FunctionDef, scoped): class FunctionDefT(ast.FunctionDef, scoped):
_types = ("signature_type",) _types = ("signature_type",)
class QuotedFunctionDefT(FunctionDefT): class QuotedFunctionDefT(FunctionDefT):
pass """
:ivar flags: (set of str) Code generation flags (see :class:`ir.Function`).
"""
class ModuleT(ast.Module, scoped): class ModuleT(ast.Module, scoped):
pass pass

View File

@ -546,7 +546,7 @@ class Stitcher:
value_map=self.value_map, value_map=self.value_map,
quote_function=self._quote_function) quote_function=self._quote_function)
def _quote_embedded_function(self, function): def _quote_embedded_function(self, function, flags):
if not hasattr(function, "artiq_embedded"): if not hasattr(function, "artiq_embedded"):
raise ValueError("{} is not an embedded function".format(repr(function))) raise ValueError("{} is not an embedded function".format(repr(function)))
@ -596,6 +596,7 @@ class Stitcher:
globals=self.globals, host_environment=host_environment, globals=self.globals, host_environment=host_environment,
quote=self._quote) quote=self._quote)
function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function) function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function)
function_node.flags = flags
# Add it into our typedtree so that it gets inferenced and codegen'd. # Add it into our typedtree so that it gets inferenced and codegen'd.
self._inject(function_node) self._inject(function_node)
@ -774,7 +775,8 @@ class Stitcher:
notes=[note]) notes=[note])
self.engine.process(diag) self.engine.process(diag)
self._quote_embedded_function(function) self._quote_embedded_function(function,
flags=function.artiq_embedded.flags)
elif function.artiq_embedded.syscall is not None: elif function.artiq_embedded.syscall is not None:
# Insert a storage-less global whose type instructs the compiler # Insert a storage-less global whose type instructs the compiler
# to perform a system call instead of a regular call. # to perform a system call instead of a regular call.

View File

@ -425,6 +425,8 @@ class Function:
the module it is contained in the module it is contained in
:ivar is_cold: :ivar is_cold:
(bool) if True, the function should be considered rarely called (bool) if True, the function should be considered rarely called
:ivar flags: (set of str) Code generation flags.
Flag ``fast-math`` is the equivalent of gcc's ``-ffast-math``.
""" """
def __init__(self, typ, name, arguments, loc=None): def __init__(self, typ, name, arguments, loc=None):
@ -434,6 +436,7 @@ class Function:
self.set_arguments(arguments) self.set_arguments(arguments)
self.is_internal = False self.is_internal = False
self.is_cold = False self.is_cold = False
self.flags = {}
def _remove_name(self, name): def _remove_name(self, name):
self.names.remove(name) self.names.remove(name)

View File

@ -224,7 +224,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
finally: finally:
self.current_class = old_class self.current_class = old_class
def visit_function(self, node, is_lambda=False, is_internal=False, is_quoted=False): def visit_function(self, node, is_lambda=False, is_internal=False, is_quoted=False,
flags={}):
if is_lambda: if is_lambda:
name = "lambda@{}:{}".format(node.loc.line(), node.loc.column()) name = "lambda@{}:{}".format(node.loc.line(), node.loc.column())
typ = node.type.find() typ = node.type.find()
@ -270,6 +271,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
func = ir.Function(typ, ".".join(self.name), [env_arg] + args + optargs, func = ir.Function(typ, ".".join(self.name), [env_arg] + args + optargs,
loc=node.lambda_loc if is_lambda else node.keyword_loc) loc=node.lambda_loc if is_lambda else node.keyword_loc)
func.is_internal = is_internal func.is_internal = is_internal
func.flags = flags
self.functions.append(func) self.functions.append(func)
old_func, self.current_function = self.current_function, func old_func, self.current_function = self.current_function, func
@ -336,7 +338,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.append(ir.SetAttr(self.current_class, node.name, func)) self.append(ir.SetAttr(self.current_class, node.name, func))
def visit_QuotedFunctionDefT(self, node): def visit_QuotedFunctionDefT(self, node):
self.visit_function(node, is_internal=True, is_quoted=True) self.visit_function(node, is_internal=True, is_quoted=True, flags=node.flags)
def visit_Return(self, node): def visit_Return(self, node):
if node.value is None: if node.value is None:

View File

@ -846,6 +846,10 @@ class Inferencer(algorithm.Visitor):
# An user-defined class. # An user-defined class.
self._unify(node.type, typ.find().instance, self._unify(node.type, typ.find().instance,
node.loc, None) node.loc, None)
elif types.is_builtin(typ, "kernel"):
# Ignored.
self._unify(node.type, builtins.TNone(),
node.loc, None)
else: else:
assert False assert False
@ -1188,7 +1192,9 @@ class Inferencer(algorithm.Visitor):
def visit_FunctionDefT(self, node): def visit_FunctionDefT(self, node):
for index, decorator in enumerate(node.decorator_list): for index, decorator in enumerate(node.decorator_list):
if types.is_builtin(decorator.type, "kernel"): if types.is_builtin(decorator.type, "kernel") or \
isinstance(decorator, asttyped.CallT) and \
types.is_builtin(decorator.func.type, "kernel"):
continue continue
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",

View File

@ -175,6 +175,7 @@ class LLVMIRGenerator:
self.llmodule = ll.Module(context=self.llcontext, name=module_name) self.llmodule = ll.Module(context=self.llcontext, name=module_name)
self.llmodule.triple = target.triple self.llmodule.triple = target.triple
self.llmodule.data_layout = target.data_layout self.llmodule.data_layout = target.data_layout
self.function_flags = None
self.llfunction = None self.llfunction = None
self.llmap = {} self.llmap = {}
self.llobject_map = {} self.llobject_map = {}
@ -562,6 +563,7 @@ class LLVMIRGenerator:
def process_function(self, func): def process_function(self, func):
try: try:
self.function_flags = func.flags
self.llfunction = self.map(func) self.llfunction = self.map(func)
if func.is_internal: if func.is_internal:
@ -617,6 +619,7 @@ class LLVMIRGenerator:
for value, block in phi.incoming(): for value, block in phi.incoming():
llphi.add_incoming(self.map(value), llblock_map[block]) llphi.add_incoming(self.map(value), llblock_map[block])
finally: finally:
self.function_flags = None
self.llfunction = None self.llfunction = None
self.llmap = {} self.llmap = {}
self.phis = [] self.phis = []
@ -863,40 +866,55 @@ class LLVMIRGenerator:
else: else:
assert False assert False
def add_fast_math_flags(self, llvalue):
if 'fast-math' in self.function_flags:
llvalue.opname = llvalue.opname + ' fast'
def process_Arith(self, insn): def process_Arith(self, insn):
if isinstance(insn.op, ast.Add): if isinstance(insn.op, ast.Add):
if builtins.is_float(insn.type): if builtins.is_float(insn.type):
return self.llbuilder.fadd(self.map(insn.lhs()), self.map(insn.rhs()), llvalue = self.llbuilder.fadd(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name) name=insn.name)
self.add_fast_math_flags(llvalue)
return llvalue
else: else:
return self.llbuilder.add(self.map(insn.lhs()), self.map(insn.rhs()), return self.llbuilder.add(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name) name=insn.name)
elif isinstance(insn.op, ast.Sub): elif isinstance(insn.op, ast.Sub):
if builtins.is_float(insn.type): if builtins.is_float(insn.type):
return self.llbuilder.fsub(self.map(insn.lhs()), self.map(insn.rhs()), llvalue = self.llbuilder.fsub(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name) name=insn.name)
self.add_fast_math_flags(llvalue)
return llvalue
else: else:
return self.llbuilder.sub(self.map(insn.lhs()), self.map(insn.rhs()), return self.llbuilder.sub(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name) name=insn.name)
elif isinstance(insn.op, ast.Mult): elif isinstance(insn.op, ast.Mult):
if builtins.is_float(insn.type): if builtins.is_float(insn.type):
return self.llbuilder.fmul(self.map(insn.lhs()), self.map(insn.rhs()), llvalue = self.llbuilder.fmul(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name) name=insn.name)
self.add_fast_math_flags(llvalue)
return llvalue
else: else:
return self.llbuilder.mul(self.map(insn.lhs()), self.map(insn.rhs()), return self.llbuilder.mul(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name) name=insn.name)
elif isinstance(insn.op, ast.Div): elif isinstance(insn.op, ast.Div):
if builtins.is_float(insn.lhs().type): if builtins.is_float(insn.lhs().type):
return self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()), llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()),
name=insn.name) name=insn.name)
self.add_fast_math_flags(llvalue)
return llvalue
else: else:
lllhs = self.llbuilder.sitofp(self.map(insn.lhs()), self.llty_of_type(insn.type)) lllhs = self.llbuilder.sitofp(self.map(insn.lhs()), self.llty_of_type(insn.type))
llrhs = self.llbuilder.sitofp(self.map(insn.rhs()), self.llty_of_type(insn.type)) llrhs = self.llbuilder.sitofp(self.map(insn.rhs()), self.llty_of_type(insn.type))
return self.llbuilder.fdiv(lllhs, llrhs, llvalue = self.llbuilder.fdiv(lllhs, llrhs,
name=insn.name) name=insn.name)
self.add_fast_math_flags(llvalue)
return llvalue
elif isinstance(insn.op, ast.FloorDiv): elif isinstance(insn.op, ast.FloorDiv):
if builtins.is_float(insn.type): if builtins.is_float(insn.type):
llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs())) llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()))
self.add_fast_math_flags(llvalue)
return self.llbuilder.call(self.llbuiltin("llvm.floor.f64"), [llvalue], return self.llbuilder.call(self.llbuiltin("llvm.floor.f64"), [llvalue],
name=insn.name) name=insn.name)
else: else:
@ -906,6 +924,7 @@ class LLVMIRGenerator:
# Python only has the modulo operator, LLVM only has the remainder # Python only has the modulo operator, LLVM only has the remainder
if builtins.is_float(insn.type): if builtins.is_float(insn.type):
llvalue = self.llbuilder.frem(self.map(insn.lhs()), self.map(insn.rhs())) llvalue = self.llbuilder.frem(self.map(insn.lhs()), self.map(insn.rhs()))
self.add_fast_math_flags(llvalue)
return self.llbuilder.call(self.llbuiltin("llvm.copysign.f64"), return self.llbuilder.call(self.llbuiltin("llvm.copysign.f64"),
[llvalue, self.map(insn.rhs())], [llvalue, self.map(insn.rhs())],
name=insn.name) name=insn.name)

View File

@ -351,6 +351,7 @@ class TCFunction(TFunction):
attributes = OrderedDict() attributes = OrderedDict()
def __init__(self, args, ret, name, flags={}): def __init__(self, args, ret, name, flags={}):
assert isinstance(flags, set)
for flag in flags: for flag in flags:
assert flag in {'nounwind', 'nowrite'} assert flag in {'nounwind', 'nowrite'}
super().__init__(args, OrderedDict(), ret) super().__init__(args, OrderedDict(), ret)

View File

@ -2,11 +2,11 @@ from artiq.language.core import *
from artiq.language.types import * from artiq.language.types import *
@syscall("cache_get", flags={"nounwind", "nowrite"}) @syscall(flags={"nounwind", "nowrite"})
def cache_get(key: TStr) -> TList(TInt32): def cache_get(key: TStr) -> TList(TInt32):
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall("cache_put", flags={"nowrite"}) @syscall(flags={"nowrite"})
def cache_put(key: TStr, value: TList(TInt32)) -> TNone: def cache_put(key: TStr, value: TList(TInt32)) -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")

View File

@ -37,7 +37,7 @@ class CompileError(Exception):
return "\n" + _render_diagnostic(self.diagnostic, colored=colors_supported) return "\n" + _render_diagnostic(self.diagnostic, colored=colors_supported)
@syscall("rtio_get_counter", flags={"nounwind", "nowrite"}) @syscall(flags={"nounwind", "nowrite"})
def rtio_get_counter() -> TInt64: def rtio_get_counter() -> TInt64:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")

View File

@ -10,20 +10,20 @@ PHASE_MODE_ABSOLUTE = 1
PHASE_MODE_TRACKING = 2 PHASE_MODE_TRACKING = 2
@syscall("dds_init", flags={"nowrite"}) @syscall(flags={"nowrite"})
def dds_init(time_mu: TInt64, bus_channel: TInt32, channel: TInt32) -> TNone: def dds_init(time_mu: TInt64, bus_channel: TInt32, channel: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall("dds_set", flags={"nowrite"}) @syscall(flags={"nowrite"})
def dds_set(time_mu: TInt64, bus_channel: TInt32, channel: TInt32, ftw: TInt32, def dds_set(time_mu: TInt64, bus_channel: TInt32, channel: TInt32, ftw: TInt32,
pow: TInt32, phase_mode: TInt32, amplitude: TInt32) -> TNone: pow: TInt32, phase_mode: TInt32, amplitude: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall("dds_batch_enter", flags={"nowrite"}) @syscall(flags={"nowrite"})
def dds_batch_enter(time_mu: TInt64) -> TNone: def dds_batch_enter(time_mu: TInt64) -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall("dds_batch_exit", flags={"nowrite"}) @syscall(flags={"nowrite"})
def dds_batch_exit() -> TNone: def dds_batch_exit() -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@ -99,27 +99,27 @@ class _DDSGeneric:
self.channel = channel self.channel = channel
self.phase_mode = PHASE_MODE_CONTINUOUS self.phase_mode = PHASE_MODE_CONTINUOUS
@portable @portable(flags=["fast-math"])
def frequency_to_ftw(self, frequency): def frequency_to_ftw(self, frequency):
"""Returns the frequency tuning word corresponding to the given """Returns the frequency tuning word corresponding to the given
frequency. frequency.
""" """
return round(int(2, width=64)**32*frequency/self.core_dds.sysclk) return round(int(2, width=64)**32*frequency/self.core_dds.sysclk)
@portable @portable(flags=["fast-math"])
def ftw_to_frequency(self, ftw): def ftw_to_frequency(self, ftw):
"""Returns the frequency corresponding to the given frequency tuning """Returns the frequency corresponding to the given frequency tuning
word. word.
""" """
return ftw*self.core_dds.sysclk/int(2, width=64)**32 return ftw*self.core_dds.sysclk/int(2, width=64)**32
@portable @portable(flags=["fast-math"])
def turns_to_pow(self, turns): def turns_to_pow(self, turns):
"""Returns the phase offset word corresponding to the given phase """Returns the phase offset word corresponding to the given phase
in turns.""" in turns."""
return round(turns*2**self.pow_width) return round(turns*2**self.pow_width)
@portable @portable(flags=["fast-math"])
def pow_to_turns(self, pow): def pow_to_turns(self, pow):
"""Returns the phase in turns corresponding to the given phase offset """Returns the phase in turns corresponding to the given phase offset
word.""" word."""

View File

@ -3,27 +3,27 @@ from artiq.language.types import TBool, TInt32, TNone
from artiq.coredevice.exceptions import I2CError from artiq.coredevice.exceptions import I2CError
@syscall("i2c_init", flags={"nowrite"}) @syscall(flags={"nowrite"})
def i2c_init(busno: TInt32) -> TNone: def i2c_init(busno: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall("i2c_start", flags={"nounwind", "nowrite"}) @syscall(flags={"nounwind", "nowrite"})
def i2c_start(busno: TInt32) -> TNone: def i2c_start(busno: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall("i2c_stop", flags={"nounwind", "nowrite"}) @syscall(flags={"nounwind", "nowrite"})
def i2c_stop(busno: TInt32) -> TNone: def i2c_stop(busno: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall("i2c_write", flags={"nounwind", "nowrite"}) @syscall(flags={"nounwind", "nowrite"})
def i2c_write(busno: TInt32, b: TInt32) -> TBool: def i2c_write(busno: TInt32, b: TInt32) -> TBool:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall("i2c_read", flags={"nounwind", "nowrite"}) @syscall(flags={"nounwind", "nowrite"})
def i2c_read(busno: TInt32, ack: TBool) -> TInt32: def i2c_read(busno: TInt32, ack: TBool) -> TInt32:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")

View File

@ -2,17 +2,17 @@ from artiq.language.core import syscall
from artiq.language.types import TInt64, TInt32, TNone from artiq.language.types import TInt64, TInt32, TNone
@syscall("rtio_output", flags={"nowrite"}) @syscall(flags={"nowrite"})
def rtio_output(time_mu: TInt64, channel: TInt32, addr: TInt32, data: TInt32 def rtio_output(time_mu: TInt64, channel: TInt32, addr: TInt32, data: TInt32
) -> TNone: ) -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall("rtio_input_timestamp", flags={"nowrite"}) @syscall(flags={"nowrite"})
def rtio_input_timestamp(timeout_mu: TInt64, channel: TInt32) -> TInt64: def rtio_input_timestamp(timeout_mu: TInt64, channel: TInt32) -> TInt64:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall("rtio_input_data", flags={"nowrite"}) @syscall(flags={"nowrite"})
def rtio_input_data(channel: TInt32) -> TInt32: def rtio_input_data(channel: TInt32) -> TInt32:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")

View File

@ -165,7 +165,7 @@ def round(value, width=32):
_ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo", _ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo",
"core_name function syscall forbidden flags") "core_name function syscall forbidden flags")
def kernel(arg): def kernel(arg=None, flags={}):
""" """
This decorator marks an object's method for execution on the core This decorator marks an object's method for execution on the core
device. device.
@ -192,13 +192,17 @@ def kernel(arg):
return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs) return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs)
run_on_core.artiq_embedded = _ARTIQEmbeddedInfo( run_on_core.artiq_embedded = _ARTIQEmbeddedInfo(
core_name=arg, function=function, syscall=None, core_name=arg, function=function, syscall=None,
forbidden=False, flags={}) forbidden=False, flags=set(flags))
return run_on_core return run_on_core
return inner_decorator return inner_decorator
elif arg is None:
def inner_decorator(function):
return kernel(function, flags)
return inner_decorator
else: else:
return kernel("core")(arg) return kernel("core", flags)(arg)
def portable(function): def portable(arg=None, flags={}):
""" """
This decorator marks a function for execution on the same device as its This decorator marks a function for execution on the same device as its
caller. caller.
@ -208,12 +212,17 @@ def portable(function):
core device). A decorated function called from a kernel will be executed core device). A decorated function called from a kernel will be executed
on the core device (no RPC). on the core device (no RPC).
""" """
function.artiq_embedded = \ if arg is None:
_ARTIQEmbeddedInfo(core_name=None, function=function, syscall=None, def inner_decorator(function):
forbidden=False, flags={}) return portable(function, flags)
return function return inner_decorator
else:
arg.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, function=arg, syscall=None,
forbidden=False, flags=set(flags))
return arg
def syscall(arg, flags={}): def syscall(arg=None, flags={}):
""" """
This decorator marks a function as a system call. When executed on a core This decorator marks a function as a system call. When executed on a core
device, a C function with the provided name (or the same name as device, a C function with the provided name (or the same name as
@ -229,9 +238,13 @@ def syscall(arg, flags={}):
function.artiq_embedded = \ function.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, function=None, _ARTIQEmbeddedInfo(core_name=None, function=None,
syscall=function.__name__, forbidden=False, syscall=function.__name__, forbidden=False,
flags=flags) flags=set(flags))
return function return function
return inner_decorator return inner_decorator
elif arg is None:
def inner_decorator(function):
return syscall(function.__name__, flags)(function)
return inner_decorator
else: else:
return syscall(arg.__name__)(arg) return syscall(arg.__name__)(arg)

View File

@ -0,0 +1,20 @@
# RUN: env ARTIQ_DUMP_UNOPT_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
# RUN: OutputCheck %s --file-to-check=%t_unopt.ll

# Regression test for per-function code generation flags: a function decorated
# with flags=["fast-math"] must have its floating-point arithmetic emitted with
# the LLVM "fast" flag. The first RUN line compiles this file and dumps the
# unoptimized LLVM IR under the %t prefix; the second matches the check
# directives below against the resulting %t_unopt.ll.

from artiq.language.core import *
from artiq.language.types import *

# Flags passed through @kernel(...): the float multiply must carry "fast".
# CHECK-L: fmul fast double 1.000000e+00, 0.000000e+00
@kernel(flags=["fast-math"])
def foo():
    core_log(1.0 * 0.0)

# Flags passed through @portable(...): same expectation, different constant
# so the two multiplies can be told apart in the IR.
# CHECK-L: fmul fast double 2.000000e+00, 0.000000e+00
@portable(flags=["fast-math"])
def bar():
    core_log(2.0 * 0.0)

# Plain @kernel without flags; calls both functions above.
# NOTE(review): presumably this is what pulls foo and bar into the compiled
# module -- confirm against the embedding testbench.
@kernel
def entrypoint():
    foo()
    bar()

View File

@ -9,7 +9,7 @@ from artiq.language.types import *
# CHECK-L: ; Function Attrs: nounwind # CHECK-L: ; Function Attrs: nounwind
# CHECK-NEXT-L: declare void @foo() # CHECK-NEXT-L: declare void @foo()
@syscall("foo", flags={"nounwind", "nowrite"}) @syscall(flags={"nounwind", "nowrite"})
def foo() -> TNone: def foo() -> TNone:
pass pass