2
0
mirror of https://github.com/m-labs/artiq.git synced 2025-01-19 07:06:42 +08:00

compiler: Provide libm special functions (erf, Bessel functions, …)

Tests hard-depend on SciPy to make sure this is exercised
during CI.
This commit is contained in:
David Nadlinger 2020-11-11 16:35:33 +01:00
parent a5dcd86fb8
commit 9ff47bacab
2 changed files with 71 additions and 20 deletions

View File

@ -8,6 +8,15 @@ from collections import OrderedDict
import numpy
from . import builtins, types
# Some special mathematical functions are exposed via their scipy.special
# equivalents. Since the rest of the ARTIQ core does not depend on SciPy,
# gracefully handle it not being present, making the functions simply not
# available.
try:
import scipy.special as scipy_special
except ImportError:
scipy_special = None
#: float -> float numpy.* math functions for which llvm.* intrinsics exist.
unary_fp_intrinsics = [(name, "llvm." + name + ".f64") for name in [
"sin",
@ -36,6 +45,21 @@ unary_fp_intrinsics = [(name, "llvm." + name + ".f64") for name in [
("rint", "llvm.round.f64"),
]
#: float -> float numpy.* math functions lowered to runtime calls.
unary_fp_runtime_calls = [
("tan", "tan"),
("arcsin", "asin"),
("arccos", "acos"),
("arctan", "atan"),
("sinh", "sinh"),
("cosh", "cosh"),
("tanh", "tanh"),
("arcsinh", "asinh"),
("arccosh", "acosh"),
("arctanh", "atanh"),
("expm1", "expm1"),
("cbrt", "cbrt"),
]
#: float -> float numpy.* math functions lowered to runtime calls.
unary_fp_runtime_calls = [
@ -53,6 +77,18 @@ unary_fp_runtime_calls = [
("cbrt", "cbrt"),
]
scipy_special_unary_runtime_calls = [
("erf", "erf"),
("erfc", "erfc"),
("gamma", "tgamma"),
("gammaln", "lgamma"),
("j0", "j0"),
("j1", "j1"),
("y0", "y0"),
("y1", "y1"),
]
# Not mapped: jv/yv, libm only supports integer orders.
#: (float, float) -> float numpy.* math functions lowered to runtime calls.
binary_fp_runtime_calls = [
("arctan2", "atan2"),
@ -70,22 +106,27 @@ numpy_builtins = ["transpose"]
def fp_runtime_type(name, arity):
args = [("arg{}".format(i), builtins.TFloat()) for i in range(arity)]
return types.TExternalFunction(OrderedDict(args),
builtins.TFloat(),
name,
# errno isn't observable from ARTIQ Python.
flags={"nounwind", "nowrite"},
broadcast_across_arrays=True)
return types.TExternalFunction(
OrderedDict(args),
builtins.TFloat(),
name,
# errno isn't observable from ARTIQ Python.
flags={"nounwind", "nowrite"},
broadcast_across_arrays=True)
numpy_map = {
math_fn_map = {
getattr(numpy, symbol): fp_runtime_type(mangle, arity=1)
for symbol, mangle in (unary_fp_intrinsics + unary_fp_runtime_calls)
}
for symbol, mangle in binary_fp_runtime_calls:
numpy_map[getattr(numpy, symbol)] = fp_runtime_type(mangle, arity=2)
math_fn_map[getattr(numpy, symbol)] = fp_runtime_type(mangle, arity=2)
for name in numpy_builtins:
numpy_map[getattr(numpy, name)] = types.TBuiltinFunction("numpy." + name)
math_fn_map[getattr(numpy, name)] = types.TBuiltinFunction("numpy." + name)
if scipy_special is not None:
for symbol, mangle in scipy_special_unary_runtime_calls:
math_fn_map[getattr(scipy_special, symbol)] = fp_runtime_type(mangle, arity=1)
def match(obj):
return numpy_map.get(obj, None)
return math_fn_map.get(obj, None)

View File

@ -1,5 +1,6 @@
from artiq.experiment import *
import numpy
import scipy.special
from artiq.test.hardware_testbench import ExperimentCase
from artiq.compiler.targets import CortexA9Target
from artiq.compiler import math_fns
@ -10,12 +11,12 @@ class _RunOnDevice(EnvExperiment):
self.setattr_device("core")
@kernel
def run_on_kernel_unary(self, a, callback, numpy):
self.run(a, callback, numpy)
def run_on_kernel_unary(self, a, callback, numpy, scipy):
self.run(a, callback, numpy, scipy)
@kernel
def run_on_kernel_binary(self, a, b, callback, numpy):
self.run(a, b, callback, numpy)
def run_on_kernel_binary(self, a, b, callback, numpy, scipy):
self.run(a, b, callback, numpy, scipy)
# Binary operations supported for scalars and arrays of any dimension, including
@ -26,7 +27,7 @@ ELEM_WISE_BINOPS = ["+", "*", "//", "%", "**", "-", "/"]
class CompareHostDeviceTest(ExperimentCase):
def _test_binop(self, op, a, b):
exp = self.create(_RunOnDevice)
exp.run = kernel_from_string(["a", "b", "callback", "numpy"],
exp.run = kernel_from_string(["a", "b", "callback", "numpy", "scipy"],
"callback(" + op + ")",
decorator=portable)
checked = False
@ -40,14 +41,14 @@ class CompareHostDeviceTest(ExperimentCase):
"Discrepancy in binop test for '{}': Expexcted ({}, {}) -> {}, got {}"
.format(op, a, b, host, device))
exp.run_on_kernel_binary(a, b, with_both_results, numpy)
exp.run_on_kernel_binary(a, b, with_both_results, numpy, scipy)
exp.run(a, b, with_host_result, numpy)
exp.run(a, b, with_host_result, numpy, scipy)
self.assertTrue(checked, "Test did not run")
def _test_unaryop(self, op, a):
exp = self.create(_RunOnDevice)
exp.run = kernel_from_string(["a", "callback", "numpy"],
exp.run = kernel_from_string(["a", "callback", "numpy", "scipy"],
"callback(" + op + ")",
decorator=portable)
checked = False
@ -61,9 +62,9 @@ class CompareHostDeviceTest(ExperimentCase):
"Discrepancy in unaryop test for '{}': Expexcted {} -> {}, got {}"
.format(op, a, host, device))
exp.run_on_kernel_unary(a, with_both_results, numpy)
exp.run_on_kernel_unary(a, with_both_results, numpy, scipy)
exp.run(a, with_host_result, numpy)
exp.run(a, with_host_result, numpy, scipy)
self.assertTrue(checked, "Test did not run")
def test_scalar_scalar_binops(self):
@ -101,6 +102,15 @@ class CompareHostDeviceTest(ExperimentCase):
self._test_unaryop(op, 0.51)
self._test_unaryop(op, numpy.array([[0.3, 0.4], [0.51, 0.6]]))
def test_unary_scipy_fns(self):
names = [name for name, _ in math_fns.scipy_special_unary_runtime_calls]
if self.create(_RunOnDevice).core.target_cls != CortexA9Target:
names.remove("gamma")
for name in names:
op = "scipy.special.{}(a)".format(name)
self._test_unaryop(op, 0.5)
self._test_unaryop(op, numpy.array([[0.3, 0.4], [0.5, 0.6]]))
def test_binary_math_fns(self):
names = [name for name, _ in math_fns.binary_fp_runtime_calls]
exp = self.create(_RunOnDevice)