forked from M-Labs/artiq
compiler: Provide libm special functions (erf, Bessel functions, …)
Tests hard-depend on SciPy to make sure this is exercised during CI.
This commit is contained in:
parent
a5dcd86fb8
commit
9ff47bacab
|
@ -8,6 +8,15 @@ from collections import OrderedDict
|
||||||
import numpy
|
import numpy
|
||||||
from . import builtins, types
|
from . import builtins, types
|
||||||
|
|
||||||
|
# Some special mathematical functions are exposed via their scipy.special
|
||||||
|
# equivalents. Since the rest of the ARTIQ core does not depend on SciPy,
|
||||||
|
# gracefully handle it not being present, making the functions simply not
|
||||||
|
# available.
|
||||||
|
try:
|
||||||
|
import scipy.special as scipy_special
|
||||||
|
except ImportError:
|
||||||
|
scipy_special = None
|
||||||
|
|
||||||
#: float -> float numpy.* math functions for which llvm.* intrinsics exist.
|
#: float -> float numpy.* math functions for which llvm.* intrinsics exist.
|
||||||
unary_fp_intrinsics = [(name, "llvm." + name + ".f64") for name in [
|
unary_fp_intrinsics = [(name, "llvm." + name + ".f64") for name in [
|
||||||
"sin",
|
"sin",
|
||||||
|
@ -36,6 +45,21 @@ unary_fp_intrinsics = [(name, "llvm." + name + ".f64") for name in [
|
||||||
("rint", "llvm.round.f64"),
|
("rint", "llvm.round.f64"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
#: float -> float numpy.* math functions lowered to runtime calls.
|
||||||
|
unary_fp_runtime_calls = [
|
||||||
|
("tan", "tan"),
|
||||||
|
("arcsin", "asin"),
|
||||||
|
("arccos", "acos"),
|
||||||
|
("arctan", "atan"),
|
||||||
|
("sinh", "sinh"),
|
||||||
|
("cosh", "cosh"),
|
||||||
|
("tanh", "tanh"),
|
||||||
|
("arcsinh", "asinh"),
|
||||||
|
("arccosh", "acosh"),
|
||||||
|
("arctanh", "atanh"),
|
||||||
|
("expm1", "expm1"),
|
||||||
|
("cbrt", "cbrt"),
|
||||||
|
]
|
||||||
|
|
||||||
#: float -> float numpy.* math functions lowered to runtime calls.
|
#: float -> float numpy.* math functions lowered to runtime calls.
|
||||||
unary_fp_runtime_calls = [
|
unary_fp_runtime_calls = [
|
||||||
|
@ -53,6 +77,18 @@ unary_fp_runtime_calls = [
|
||||||
("cbrt", "cbrt"),
|
("cbrt", "cbrt"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
scipy_special_unary_runtime_calls = [
|
||||||
|
("erf", "erf"),
|
||||||
|
("erfc", "erfc"),
|
||||||
|
("gamma", "tgamma"),
|
||||||
|
("gammaln", "lgamma"),
|
||||||
|
("j0", "j0"),
|
||||||
|
("j1", "j1"),
|
||||||
|
("y0", "y0"),
|
||||||
|
("y1", "y1"),
|
||||||
|
]
|
||||||
|
# Not mapped: jv/yv, libm only supports integer orders.
|
||||||
|
|
||||||
#: (float, float) -> float numpy.* math functions lowered to runtime calls.
|
#: (float, float) -> float numpy.* math functions lowered to runtime calls.
|
||||||
binary_fp_runtime_calls = [
|
binary_fp_runtime_calls = [
|
||||||
("arctan2", "atan2"),
|
("arctan2", "atan2"),
|
||||||
|
@ -70,22 +106,27 @@ numpy_builtins = ["transpose"]
|
||||||
|
|
||||||
def fp_runtime_type(name, arity):
|
def fp_runtime_type(name, arity):
|
||||||
args = [("arg{}".format(i), builtins.TFloat()) for i in range(arity)]
|
args = [("arg{}".format(i), builtins.TFloat()) for i in range(arity)]
|
||||||
return types.TExternalFunction(OrderedDict(args),
|
return types.TExternalFunction(
|
||||||
builtins.TFloat(),
|
OrderedDict(args),
|
||||||
name,
|
builtins.TFloat(),
|
||||||
# errno isn't observable from ARTIQ Python.
|
name,
|
||||||
flags={"nounwind", "nowrite"},
|
# errno isn't observable from ARTIQ Python.
|
||||||
broadcast_across_arrays=True)
|
flags={"nounwind", "nowrite"},
|
||||||
|
broadcast_across_arrays=True)
|
||||||
|
|
||||||
numpy_map = {
|
|
||||||
|
math_fn_map = {
|
||||||
getattr(numpy, symbol): fp_runtime_type(mangle, arity=1)
|
getattr(numpy, symbol): fp_runtime_type(mangle, arity=1)
|
||||||
for symbol, mangle in (unary_fp_intrinsics + unary_fp_runtime_calls)
|
for symbol, mangle in (unary_fp_intrinsics + unary_fp_runtime_calls)
|
||||||
}
|
}
|
||||||
for symbol, mangle in binary_fp_runtime_calls:
|
for symbol, mangle in binary_fp_runtime_calls:
|
||||||
numpy_map[getattr(numpy, symbol)] = fp_runtime_type(mangle, arity=2)
|
math_fn_map[getattr(numpy, symbol)] = fp_runtime_type(mangle, arity=2)
|
||||||
for name in numpy_builtins:
|
for name in numpy_builtins:
|
||||||
numpy_map[getattr(numpy, name)] = types.TBuiltinFunction("numpy." + name)
|
math_fn_map[getattr(numpy, name)] = types.TBuiltinFunction("numpy." + name)
|
||||||
|
if scipy_special is not None:
|
||||||
|
for symbol, mangle in scipy_special_unary_runtime_calls:
|
||||||
|
math_fn_map[getattr(scipy_special, symbol)] = fp_runtime_type(mangle, arity=1)
|
||||||
|
|
||||||
|
|
||||||
def match(obj):
|
def match(obj):
|
||||||
return numpy_map.get(obj, None)
|
return math_fn_map.get(obj, None)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from artiq.experiment import *
|
from artiq.experiment import *
|
||||||
import numpy
|
import numpy
|
||||||
|
import scipy.special
|
||||||
from artiq.test.hardware_testbench import ExperimentCase
|
from artiq.test.hardware_testbench import ExperimentCase
|
||||||
from artiq.compiler.targets import CortexA9Target
|
from artiq.compiler.targets import CortexA9Target
|
||||||
from artiq.compiler import math_fns
|
from artiq.compiler import math_fns
|
||||||
|
@ -10,12 +11,12 @@ class _RunOnDevice(EnvExperiment):
|
||||||
self.setattr_device("core")
|
self.setattr_device("core")
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def run_on_kernel_unary(self, a, callback, numpy):
|
def run_on_kernel_unary(self, a, callback, numpy, scipy):
|
||||||
self.run(a, callback, numpy)
|
self.run(a, callback, numpy, scipy)
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def run_on_kernel_binary(self, a, b, callback, numpy):
|
def run_on_kernel_binary(self, a, b, callback, numpy, scipy):
|
||||||
self.run(a, b, callback, numpy)
|
self.run(a, b, callback, numpy, scipy)
|
||||||
|
|
||||||
|
|
||||||
# Binary operations supported for scalars and arrays of any dimension, including
|
# Binary operations supported for scalars and arrays of any dimension, including
|
||||||
|
@ -26,7 +27,7 @@ ELEM_WISE_BINOPS = ["+", "*", "//", "%", "**", "-", "/"]
|
||||||
class CompareHostDeviceTest(ExperimentCase):
|
class CompareHostDeviceTest(ExperimentCase):
|
||||||
def _test_binop(self, op, a, b):
|
def _test_binop(self, op, a, b):
|
||||||
exp = self.create(_RunOnDevice)
|
exp = self.create(_RunOnDevice)
|
||||||
exp.run = kernel_from_string(["a", "b", "callback", "numpy"],
|
exp.run = kernel_from_string(["a", "b", "callback", "numpy", "scipy"],
|
||||||
"callback(" + op + ")",
|
"callback(" + op + ")",
|
||||||
decorator=portable)
|
decorator=portable)
|
||||||
checked = False
|
checked = False
|
||||||
|
@ -40,14 +41,14 @@ class CompareHostDeviceTest(ExperimentCase):
|
||||||
"Discrepancy in binop test for '{}': Expexcted ({}, {}) -> {}, got {}"
|
"Discrepancy in binop test for '{}': Expexcted ({}, {}) -> {}, got {}"
|
||||||
.format(op, a, b, host, device))
|
.format(op, a, b, host, device))
|
||||||
|
|
||||||
exp.run_on_kernel_binary(a, b, with_both_results, numpy)
|
exp.run_on_kernel_binary(a, b, with_both_results, numpy, scipy)
|
||||||
|
|
||||||
exp.run(a, b, with_host_result, numpy)
|
exp.run(a, b, with_host_result, numpy, scipy)
|
||||||
self.assertTrue(checked, "Test did not run")
|
self.assertTrue(checked, "Test did not run")
|
||||||
|
|
||||||
def _test_unaryop(self, op, a):
|
def _test_unaryop(self, op, a):
|
||||||
exp = self.create(_RunOnDevice)
|
exp = self.create(_RunOnDevice)
|
||||||
exp.run = kernel_from_string(["a", "callback", "numpy"],
|
exp.run = kernel_from_string(["a", "callback", "numpy", "scipy"],
|
||||||
"callback(" + op + ")",
|
"callback(" + op + ")",
|
||||||
decorator=portable)
|
decorator=portable)
|
||||||
checked = False
|
checked = False
|
||||||
|
@ -61,9 +62,9 @@ class CompareHostDeviceTest(ExperimentCase):
|
||||||
"Discrepancy in unaryop test for '{}': Expexcted {} -> {}, got {}"
|
"Discrepancy in unaryop test for '{}': Expexcted {} -> {}, got {}"
|
||||||
.format(op, a, host, device))
|
.format(op, a, host, device))
|
||||||
|
|
||||||
exp.run_on_kernel_unary(a, with_both_results, numpy)
|
exp.run_on_kernel_unary(a, with_both_results, numpy, scipy)
|
||||||
|
|
||||||
exp.run(a, with_host_result, numpy)
|
exp.run(a, with_host_result, numpy, scipy)
|
||||||
self.assertTrue(checked, "Test did not run")
|
self.assertTrue(checked, "Test did not run")
|
||||||
|
|
||||||
def test_scalar_scalar_binops(self):
|
def test_scalar_scalar_binops(self):
|
||||||
|
@ -101,6 +102,15 @@ class CompareHostDeviceTest(ExperimentCase):
|
||||||
self._test_unaryop(op, 0.51)
|
self._test_unaryop(op, 0.51)
|
||||||
self._test_unaryop(op, numpy.array([[0.3, 0.4], [0.51, 0.6]]))
|
self._test_unaryop(op, numpy.array([[0.3, 0.4], [0.51, 0.6]]))
|
||||||
|
|
||||||
|
def test_unary_scipy_fns(self):
|
||||||
|
names = [name for name, _ in math_fns.scipy_special_unary_runtime_calls]
|
||||||
|
if self.create(_RunOnDevice).core.target_cls != CortexA9Target:
|
||||||
|
names.remove("gamma")
|
||||||
|
for name in names:
|
||||||
|
op = "scipy.special.{}(a)".format(name)
|
||||||
|
self._test_unaryop(op, 0.5)
|
||||||
|
self._test_unaryop(op, numpy.array([[0.3, 0.4], [0.5, 0.6]]))
|
||||||
|
|
||||||
def test_binary_math_fns(self):
|
def test_binary_math_fns(self):
|
||||||
names = [name for name, _ in math_fns.binary_fp_runtime_calls]
|
names = [name for name, _ in math_fns.binary_fp_runtime_calls]
|
||||||
exp = self.create(_RunOnDevice)
|
exp = self.create(_RunOnDevice)
|
||||||
|
|
Loading…
Reference in New Issue