From 4d484703205e33d85d9dd9ade4f1ddc15c0a5217 Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Sun, 2 Aug 2020 15:27:02 +0100 Subject: [PATCH] compiler: Support common numpy.* math functions Relies on the runtime to provide the necessary (libm-compatible) functions. The test is nifty, but a bit brittle; if this breaks in the future because of optimizer changes, do not hesitate to convert this into a more pedestrian test case. --- artiq/compiler/embedding.py | 11 ++++++-- artiq/compiler/math_fns.py | 42 ++++++++++++++++++++++++++++ artiq/test/lit/embedding/math_fns.py | 30 ++++++++++++++++++++ 3 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 artiq/compiler/math_fns.py create mode 100644 artiq/test/lit/embedding/math_fns.py diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index bdac020bb..ddacf947f 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -14,7 +14,7 @@ from pythonparser import lexer as source_lexer, parser as source_parser from Levenshtein import ratio as similarity, jaro_winkler from ..language import core as language_core -from . import types, builtins, asttyped, prelude +from . import types, builtins, asttyped, math_fns, prelude from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, TypedtreePrinter from .transforms.asttyped_rewriter import LocalExtractor @@ -246,7 +246,8 @@ class ASTSynthesizer: loc=begin_loc.join(end_loc)) elif inspect.isfunction(value) or inspect.ismethod(value) or \ isinstance(value, pytypes.BuiltinFunctionType) or \ - isinstance(value, SpecializedFunction): + isinstance(value, SpecializedFunction) or \ + isinstance(value, numpy.ufunc): if inspect.ismethod(value): quoted_self = self.quote(value.__self__) function_type = self.quote_function(value.__func__, self.expanded_from) @@ -1057,7 +1058,11 @@ class Stitcher: host_function = function if function in self.functions: - pass + return self.functions[function] + + math_type = math_fns.match(function) + if math_type is not None: + self.functions[function] = math_type elif not hasattr(host_function, "artiq_embedded") or \ (host_function.artiq_embedded.core_name is None and host_function.artiq_embedded.portable is False and diff --git a/artiq/compiler/math_fns.py b/artiq/compiler/math_fns.py new file mode 100644 index 000000000..83c1f82b4 --- /dev/null +++ b/artiq/compiler/math_fns.py @@ -0,0 +1,42 @@ +from collections import OrderedDict +import numpy +from . import builtins, types + +#: float -> float numpy.* math functions for which llvm.* intrinsics exist. +unary_fp_intrinsics = [(name, "llvm." + name + ".f64") for name in [ + "sin", + "cos", + "exp", + "exp2", + "log", + "log10", + "log2", + "fabs", + "floor", + "ceil", + "trunc", + "rint", +]] + +#: float -> float numpy.* math functions lowered to runtime calls. +unary_fp_runtime_calls = [ + ("tan", "tan"), + ("arcsin", "asin"), + ("arccos", "acos"), + ("arctan", "atan"), +] + + +def unary_fp_type(name): + return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]), + builtins.TFloat(), name) + + +numpy_map = { + getattr(numpy, symbol): unary_fp_type(mangle) + for symbol, mangle in (unary_fp_intrinsics + unary_fp_runtime_calls) +} + + +def match(obj): + return numpy_map.get(obj, None) diff --git a/artiq/test/lit/embedding/math_fns.py b/artiq/test/lit/embedding/math_fns.py new file mode 100644 index 000000000..40ac18e42 --- /dev/null +++ b/artiq/test/lit/embedding/math_fns.py @@ -0,0 +1,30 @@ +# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding %s +# RUN: OutputCheck %s --file-to-check=%t.ll + +from artiq.language.core import * +from artiq.language.types import * +import numpy + +@kernel +def entrypoint(): + # LLVM's constant folding for transcendental functions is good enough that + # we can do a basic smoke test by just making sure the module compiles and + # all assertions are statically eliminated. + + # CHECK-NOT: assert + assert numpy.sin(0.0) == 0.0 + assert numpy.cos(0.0) == 1.0 + assert numpy.exp(0.0) == 1.0 + assert numpy.exp2(1.0) == 2.0 + assert numpy.log(numpy.exp(1.0)) == 1.0 + assert numpy.log10(10.0) == 1.0 + assert numpy.log2(2.0) == 1.0 + assert numpy.fabs(-1.0) == 1.0 + assert numpy.floor(42.5) == 42.0 + assert numpy.ceil(42.5) == 43.0 + assert numpy.trunc(41.5) == 41.0 + assert numpy.rint(41.5) == 42.0 + assert numpy.tan(0.0) == 0.0 + assert numpy.arcsin(0.0) == 0.0 + assert numpy.arccos(1.0) == 0.0 + assert numpy.arctan(0.0) == 0.0