forked from M-Labs/artiq
9ff47bacab
Tests hard-depend on SciPy to make sure this is exercised during CI.
133 lines
3.9 KiB
Python
133 lines
3.9 KiB
Python
r"""
|
|
The :mod:`math_fns` module lists math-related functions from NumPy recognized
|
|
by the ARTIQ compiler so host function objects can be :func:`match`\ ed to
|
|
the compiler type metadata describing their core device analogue.
|
|
"""
|
|
|
|
from collections import OrderedDict
|
|
import numpy
|
|
from . import builtins, types
|
|
|
|
# SciPy is an optional dependency: a handful of special mathematical functions
# are exposed through their scipy.special equivalents, but the rest of the
# ARTIQ core does not need SciPy. When it cannot be imported, scipy_special is
# left as None and those functions are simply not made available.
try:
    from scipy import special as scipy_special
except ImportError:
    scipy_special = None
|
|
|
|
#: float -> float numpy.* math functions for which llvm.* intrinsics exist,
#: as (NumPy function name, LLVM intrinsic name) pairs.
unary_fp_intrinsics = [
    (fn, "llvm.{}.f64".format(fn))
    for fn in (
        "sin",
        "cos",
        "exp",
        "exp2",
        "log",
        "log10",
        "log2",
        "fabs",
        "floor",
        "ceil",
        "trunc",
        "sqrt",
    )
]
# numpy.rint() seems to (NumPy 1.19.0, Python 3.8.5, Linux x86_64) implement
# round-to-even, but unfortunately, rust-lang/libm only provides round(),
# which always rounds away from zero.
#
# As there is no equivalent of the latter in NumPy (nor any other basic
# rounding function), expose round() as numpy.rint anyway, even if the
# rounding modes don't match up, so there is some way to do rounding on the
# core device. (numpy.round() has entirely different semantics; it rounds to
# a configurable number of decimals.)
unary_fp_intrinsics.append(("rint", "llvm.round.f64"))
|
|
|
|
#: float -> float numpy.* math functions lowered to runtime calls, as
#: (NumPy function name, runtime function name) pairs.
#:
#: This table is defined exactly once; a second, identical copy would only
#: overwrite the first and invite the two from drifting apart.
unary_fp_runtime_calls = [
    ("tan", "tan"),
    ("arcsin", "asin"),
    ("arccos", "acos"),
    ("arctan", "atan"),
    ("sinh", "sinh"),
    ("cosh", "cosh"),
    ("tanh", "tanh"),
    ("arcsinh", "asinh"),
    ("arccosh", "acosh"),
    ("arctanh", "atanh"),
    ("expm1", "expm1"),
    ("cbrt", "cbrt"),
]
|
|
|
|
#: float -> float scipy.special.* functions lowered to runtime calls, as
#: (scipy.special function name, runtime function name) pairs.
#:
#: Not mapped: jv/yv, libm only supports integer orders.
scipy_special_unary_runtime_calls = [
    # Error functions.
    ("erf", "erf"),
    ("erfc", "erfc"),
    # Gamma function and its logarithm; the runtime names differ here.
    ("gamma", "tgamma"),
    ("gammaln", "lgamma"),
    # Bessel functions of orders 0 and 1.
    ("j0", "j0"),
    ("j1", "j1"),
    ("y0", "y0"),
    ("y1", "y1"),
]
|
|
|
|
#: (float, float) -> float numpy.* math functions lowered to runtime calls,
#: as (NumPy function name, runtime function name) pairs.
#:
#: numpy.ldexp is deliberately left out: one argument is an int, which would
#: need a bit more plumbing.
binary_fp_runtime_calls = [
    ("arctan2", "atan2"),
    ("copysign", "copysign"),
    ("fmax", "fmax"),
    ("fmin", "fmin"),
    ("hypot", "hypot"),
    ("nextafter", "nextafter"),
]
|
|
|
|
#: Array handling builtins (special treatment due to allocations). Each name
#: listed here is looked up on the numpy module and registered as the
#: builtin function "numpy.<name>".
numpy_builtins = [
    "transpose",
]
|
|
|
|
|
|
def fp_runtime_type(name, arity):
    """Build the external function type for a float-valued math function.

    :param name: Name passed on to ``types.TExternalFunction`` for the
        function to be called.
    :param arity: Number of arguments; each one is typed as
        ``builtins.TFloat()`` and named ``arg0``, ``arg1``, ...
    :return: A ``types.TExternalFunction`` taking ``arity`` floats and
        returning a float, created with ``broadcast_across_arrays=True``.
    """
    params = OrderedDict()
    for index in range(arity):
        params["arg{}".format(index)] = builtins.TFloat()

    result_type = builtins.TFloat()
    return types.TExternalFunction(
        params,
        result_type,
        name,
        # errno isn't observable from ARTIQ Python.
        flags={"nounwind", "nowrite"},
        broadcast_across_arrays=True)
|
|
|
|
|
|
#: Maps host function objects (numpy.* and, when available, scipy.special.*)
#: to the compiler type describing their core device analogue.
math_fn_map = {}

# Unary float functions, whether backed by an LLVM intrinsic or a runtime call.
for symbol, mangle in unary_fp_intrinsics + unary_fp_runtime_calls:
    math_fn_map[getattr(numpy, symbol)] = fp_runtime_type(mangle, arity=1)

# Binary float functions.
for symbol, mangle in binary_fp_runtime_calls:
    math_fn_map[getattr(numpy, symbol)] = fp_runtime_type(mangle, arity=2)

# Array handling functions are registered as compiler builtins instead.
for name in numpy_builtins:
    math_fn_map[getattr(numpy, name)] = types.TBuiltinFunction("numpy." + name)

# The scipy.special functions are only registered when SciPy could be imported.
if scipy_special is not None:
    for symbol, mangle in scipy_special_unary_runtime_calls:
        math_fn_map[getattr(scipy_special, symbol)] = fp_runtime_type(
            mangle, arity=1)
|
|
|
|
|
|
def match(obj):
    """Look up the compiler type registered for the host object ``obj``.

    :param obj: Object to look up as a key in ``math_fn_map``.
    :return: The ``math_fn_map`` entry for ``obj``, or ``None`` if there is
        no entry for it.
    """
    return math_fn_map.get(obj)
|