compiler: Support common numpy.* math functions

Relies on the runtime to provide the necessary
(libm-compatible) functions.

The test is nifty, but a bit brittle; if this breaks in the
future because of optimizer changes, do not hesitate to convert
this into a more pedestrian test case.
pull/1508/head
David Nadlinger 2020-08-02 15:27:02 +01:00
parent d37503f21d
commit 4d48470320
3 changed files with 80 additions and 3 deletions

View File

@ -14,7 +14,7 @@ from pythonparser import lexer as source_lexer, parser as source_parser
from Levenshtein import ratio as similarity, jaro_winkler
from ..language import core as language_core
from . import types, builtins, asttyped, prelude
from . import types, builtins, asttyped, math_fns, prelude
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, TypedtreePrinter
from .transforms.asttyped_rewriter import LocalExtractor
@ -246,7 +246,8 @@ class ASTSynthesizer:
loc=begin_loc.join(end_loc))
elif inspect.isfunction(value) or inspect.ismethod(value) or \
isinstance(value, pytypes.BuiltinFunctionType) or \
isinstance(value, SpecializedFunction):
isinstance(value, SpecializedFunction) or \
isinstance(value, numpy.ufunc):
if inspect.ismethod(value):
quoted_self = self.quote(value.__self__)
function_type = self.quote_function(value.__func__, self.expanded_from)
@ -1057,7 +1058,11 @@ class Stitcher:
host_function = function
if function in self.functions:
pass
return self.functions[function]
math_type = math_fns.match(function)
if math_type is not None:
self.functions[function] = math_type
elif not hasattr(host_function, "artiq_embedded") or \
(host_function.artiq_embedded.core_name is None and
host_function.artiq_embedded.portable is False and

View File

@ -0,0 +1,42 @@
from collections import OrderedDict
import numpy
from . import builtins, types
#: float -> float numpy.* math functions for which llvm.* intrinsics exist.
unary_fp_intrinsics = [(name, "llvm." + name + ".f64") for name in [
"sin",
"cos",
"exp",
"exp2",
"log",
"log10",
"log2",
"fabs",
"floor",
"ceil",
"trunc",
"rint",
]]
#: float -> float numpy.* math functions lowered to runtime calls.
unary_fp_runtime_calls = [
("tan", "tan"),
("arcsin", "asin"),
("arccos", "acos"),
("arctan", "atan"),
]
def unary_fp_type(name):
return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]),
builtins.TFloat(), name)
numpy_map = {
getattr(numpy, symbol): unary_fp_type(mangle)
for symbol, mangle in (unary_fp_intrinsics + unary_fp_runtime_calls)
}
def match(obj):
return numpy_map.get(obj, None)

View File

@ -0,0 +1,30 @@
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding %s
# RUN: OutputCheck %s --file-to-check=%t.ll
from artiq.language.core import *
from artiq.language.types import *
import numpy
@kernel
def entrypoint():
# LLVM's constant folding for transcendental functions is good enough that
# we can do a basic smoke test by just making sure the module compiles and
# all assertions are statically eliminated.
# CHECK-NOT: assert
assert numpy.sin(0.0) == 0.0
assert numpy.cos(0.0) == 1.0
assert numpy.exp(0.0) == 1.0
assert numpy.exp2(1.0) == 2.0
assert numpy.log(numpy.exp(1.0)) == 1.0
assert numpy.log10(10.0) == 1.0
assert numpy.log2(2.0) == 1.0
assert numpy.fabs(-1.0) == 1.0
assert numpy.floor(42.5) == 42.0
assert numpy.ceil(42.5) == 43.0
assert numpy.trunc(41.5) == 41.0
assert numpy.rint(41.5) == 42.0
assert numpy.tan(0.0) == 0.0
assert numpy.arcsin(0.0) == 0.0
assert numpy.arccos(1.0) == 0.0
assert numpy.arctan(0.0) == 0.0