forked from M-Labs/artiq
compiler: Support common numpy.* math functions
Relies on the runtime to provide the necessary (libm-compatible) functions. The test is nifty, but a bit brittle; if this breaks in the future because of optimizer changes, do not hesitate to convert this into a more pedestrian test case.
This commit is contained in:
parent
d37503f21d
commit
4d48470320
@ -14,7 +14,7 @@ from pythonparser import lexer as source_lexer, parser as source_parser
|
|||||||
from Levenshtein import ratio as similarity, jaro_winkler
|
from Levenshtein import ratio as similarity, jaro_winkler
|
||||||
|
|
||||||
from ..language import core as language_core
|
from ..language import core as language_core
|
||||||
from . import types, builtins, asttyped, prelude
|
from . import types, builtins, asttyped, math_fns, prelude
|
||||||
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, TypedtreePrinter
|
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, TypedtreePrinter
|
||||||
from .transforms.asttyped_rewriter import LocalExtractor
|
from .transforms.asttyped_rewriter import LocalExtractor
|
||||||
|
|
||||||
@ -246,7 +246,8 @@ class ASTSynthesizer:
|
|||||||
loc=begin_loc.join(end_loc))
|
loc=begin_loc.join(end_loc))
|
||||||
elif inspect.isfunction(value) or inspect.ismethod(value) or \
|
elif inspect.isfunction(value) or inspect.ismethod(value) or \
|
||||||
isinstance(value, pytypes.BuiltinFunctionType) or \
|
isinstance(value, pytypes.BuiltinFunctionType) or \
|
||||||
isinstance(value, SpecializedFunction):
|
isinstance(value, SpecializedFunction) or \
|
||||||
|
isinstance(value, numpy.ufunc):
|
||||||
if inspect.ismethod(value):
|
if inspect.ismethod(value):
|
||||||
quoted_self = self.quote(value.__self__)
|
quoted_self = self.quote(value.__self__)
|
||||||
function_type = self.quote_function(value.__func__, self.expanded_from)
|
function_type = self.quote_function(value.__func__, self.expanded_from)
|
||||||
@ -1057,7 +1058,11 @@ class Stitcher:
|
|||||||
host_function = function
|
host_function = function
|
||||||
|
|
||||||
if function in self.functions:
|
if function in self.functions:
|
||||||
pass
|
return self.functions[function]
|
||||||
|
|
||||||
|
math_type = math_fns.match(function)
|
||||||
|
if math_type is not None:
|
||||||
|
self.functions[function] = math_type
|
||||||
elif not hasattr(host_function, "artiq_embedded") or \
|
elif not hasattr(host_function, "artiq_embedded") or \
|
||||||
(host_function.artiq_embedded.core_name is None and
|
(host_function.artiq_embedded.core_name is None and
|
||||||
host_function.artiq_embedded.portable is False and
|
host_function.artiq_embedded.portable is False and
|
||||||
|
42
artiq/compiler/math_fns.py
Normal file
42
artiq/compiler/math_fns.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
import numpy
|
||||||
|
from . import builtins, types
|
||||||
|
|
||||||
|
#: float -> float numpy.* math functions for which llvm.* intrinsics exist.
|
||||||
|
unary_fp_intrinsics = [(name, "llvm." + name + ".f64") for name in [
|
||||||
|
"sin",
|
||||||
|
"cos",
|
||||||
|
"exp",
|
||||||
|
"exp2",
|
||||||
|
"log",
|
||||||
|
"log10",
|
||||||
|
"log2",
|
||||||
|
"fabs",
|
||||||
|
"floor",
|
||||||
|
"ceil",
|
||||||
|
"trunc",
|
||||||
|
"rint",
|
||||||
|
]]
|
||||||
|
|
||||||
|
#: float -> float numpy.* math functions lowered to runtime calls.
|
||||||
|
unary_fp_runtime_calls = [
|
||||||
|
("tan", "tan"),
|
||||||
|
("arcsin", "asin"),
|
||||||
|
("arccos", "acos"),
|
||||||
|
("arctan", "atan"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def unary_fp_type(name):
|
||||||
|
return types.TExternalFunction(OrderedDict([("arg", builtins.TFloat())]),
|
||||||
|
builtins.TFloat(), name)
|
||||||
|
|
||||||
|
|
||||||
|
numpy_map = {
|
||||||
|
getattr(numpy, symbol): unary_fp_type(mangle)
|
||||||
|
for symbol, mangle in (unary_fp_intrinsics + unary_fp_runtime_calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def match(obj):
|
||||||
|
return numpy_map.get(obj, None)
|
30
artiq/test/lit/embedding/math_fns.py
Normal file
30
artiq/test/lit/embedding/math_fns.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding %s
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t.ll
|
||||||
|
|
||||||
|
from artiq.language.core import *
|
||||||
|
from artiq.language.types import *
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def entrypoint():
|
||||||
|
# LLVM's constant folding for transcendental functions is good enough that
|
||||||
|
# we can do a basic smoke test by just making sure the module compiles and
|
||||||
|
# all assertions are statically eliminated.
|
||||||
|
|
||||||
|
# CHECK-NOT: assert
|
||||||
|
assert numpy.sin(0.0) == 0.0
|
||||||
|
assert numpy.cos(0.0) == 1.0
|
||||||
|
assert numpy.exp(0.0) == 1.0
|
||||||
|
assert numpy.exp2(1.0) == 2.0
|
||||||
|
assert numpy.log(numpy.exp(1.0)) == 1.0
|
||||||
|
assert numpy.log10(10.0) == 1.0
|
||||||
|
assert numpy.log2(2.0) == 1.0
|
||||||
|
assert numpy.fabs(-1.0) == 1.0
|
||||||
|
assert numpy.floor(42.5) == 42.0
|
||||||
|
assert numpy.ceil(42.5) == 43.0
|
||||||
|
assert numpy.trunc(41.5) == 41.0
|
||||||
|
assert numpy.rint(41.5) == 42.0
|
||||||
|
assert numpy.tan(0.0) == 0.0
|
||||||
|
assert numpy.arcsin(0.0) == 0.0
|
||||||
|
assert numpy.arccos(1.0) == 0.0
|
||||||
|
assert numpy.arctan(0.0) == 0.0
|
Loading…
Reference in New Issue
Block a user