forked from M-Labs/artiq
compiler: Support common numpy.* math functions
Relies on the runtime to provide the necessary (libm-compatible) functions. The test is nifty, but a bit brittle; if this breaks in the future because of optimizer changes, do not hesitate to convert this into a more pedestrian test case.
This commit is contained in:
parent
d37503f21d
commit
4d48470320
@ -14,7 +14,7 @@ from pythonparser import lexer as source_lexer, parser as source_parser
|
||||
from Levenshtein import ratio as similarity, jaro_winkler
|
||||
|
||||
from ..language import core as language_core
|
||||
from . import types, builtins, asttyped, prelude
|
||||
from . import types, builtins, asttyped, math_fns, prelude
|
||||
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, TypedtreePrinter
|
||||
from .transforms.asttyped_rewriter import LocalExtractor
|
||||
|
||||
@ -246,7 +246,8 @@ class ASTSynthesizer:
|
||||
loc=begin_loc.join(end_loc))
|
||||
elif inspect.isfunction(value) or inspect.ismethod(value) or \
|
||||
isinstance(value, pytypes.BuiltinFunctionType) or \
|
||||
isinstance(value, SpecializedFunction):
|
||||
isinstance(value, SpecializedFunction) or \
|
||||
isinstance(value, numpy.ufunc):
|
||||
if inspect.ismethod(value):
|
||||
quoted_self = self.quote(value.__self__)
|
||||
function_type = self.quote_function(value.__func__, self.expanded_from)
|
||||
@ -1057,7 +1058,11 @@ class Stitcher:
|
||||
host_function = function
|
||||
|
||||
if function in self.functions:
|
||||
pass
|
||||
return self.functions[function]
|
||||
|
||||
math_type = math_fns.match(function)
|
||||
if math_type is not None:
|
||||
self.functions[function] = math_type
|
||||
elif not hasattr(host_function, "artiq_embedded") or \
|
||||
(host_function.artiq_embedded.core_name is None and
|
||||
host_function.artiq_embedded.portable is False and
|
||||
|
42
artiq/compiler/math_fns.py
Normal file
42
artiq/compiler/math_fns.py
Normal file
@ -0,0 +1,42 @@
|
||||
from collections import OrderedDict
|
||||
import numpy
|
||||
from . import builtins, types
|
||||
|
||||
#: float -> float numpy.* math functions for which llvm.* intrinsics exist.
|
||||
unary_fp_intrinsics = [(name, "llvm." + name + ".f64") for name in [
|
||||
"sin",
|
||||
"cos",
|
||||
"exp",
|
||||
"exp2",
|
||||
"log",
|
||||
"log10",
|
||||
"log2",
|
||||
"fabs",
|
||||
"floor",
|
||||
"ceil",
|
||||
"trunc",
|
||||
"rint",
|
||||
]]
|
||||
|
||||
#: float -> float numpy.* math functions lowered to runtime calls.
|
||||
unary_fp_runtime_calls = [
|
||||
("tan", "tan"),
|
||||
("arcsin", "asin"),
|
||||
("arccos", "acos"),
|
||||
("arctan", "atan"),
|
||||
]
|
||||
|
||||
|
||||
def unary_fp_type(name):
|
||||
return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]),
|
||||
builtins.TFloat(), name)
|
||||
|
||||
|
||||
numpy_map = {
|
||||
getattr(numpy, symbol): unary_fp_type(mangle)
|
||||
for symbol, mangle in (unary_fp_intrinsics + unary_fp_runtime_calls)
|
||||
}
|
||||
|
||||
|
||||
def match(obj):
|
||||
return numpy_map.get(obj, None)
|
30
artiq/test/lit/embedding/math_fns.py
Normal file
30
artiq/test/lit/embedding/math_fns.py
Normal file
@ -0,0 +1,30 @@
|
||||
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding %s
|
||||
# RUN: OutputCheck %s --file-to-check=%t.ll
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
import numpy
|
||||
|
||||
@kernel
|
||||
def entrypoint():
|
||||
# LLVM's constant folding for transcendental functions is good enough that
|
||||
# we can do a basic smoke test by just making sure the module compiles and
|
||||
# all assertions are statically eliminated.
|
||||
|
||||
# CHECK-NOT: assert
|
||||
assert numpy.sin(0.0) == 0.0
|
||||
assert numpy.cos(0.0) == 1.0
|
||||
assert numpy.exp(0.0) == 1.0
|
||||
assert numpy.exp2(1.0) == 2.0
|
||||
assert numpy.log(numpy.exp(1.0)) == 1.0
|
||||
assert numpy.log10(10.0) == 1.0
|
||||
assert numpy.log2(2.0) == 1.0
|
||||
assert numpy.fabs(-1.0) == 1.0
|
||||
assert numpy.floor(42.5) == 42.0
|
||||
assert numpy.ceil(42.5) == 43.0
|
||||
assert numpy.trunc(41.5) == 41.0
|
||||
assert numpy.rint(41.5) == 42.0
|
||||
assert numpy.tan(0.0) == 0.0
|
||||
assert numpy.arcsin(0.0) == 0.0
|
||||
assert numpy.arccos(1.0) == 0.0
|
||||
assert numpy.arctan(0.0) == 0.0
|
Loading…
Reference in New Issue
Block a user