1
0
forked from M-Labs/nac3

core: Implement numpy and scipy functions

This commit is contained in:
David Mak 2023-10-06 17:48:31 +08:00
parent 60ad100fbb
commit 2b635a0b97
3 changed files with 646 additions and 64 deletions

View File

@ -919,70 +919,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function { create_fn_by_intrinsic(
name: "floor".into(), primitives,
simple_name: "floor".into(), &var_map,
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { "floor",
args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }], float,
ret: int32, &[(float, "x")],
vars: Default::default(), "llvm.floor.f64",
})), ),
var_id: Default::default(), create_fn_by_intrinsic(
instance_to_symbol: Default::default(), primitives,
instance_to_stmt: Default::default(), &var_map,
resolver: None, "ceil",
codegen_callback: Some(Arc::new(GenCall::new(Box::new( float,
|ctx, _, _, args, generator| { &[(float, "x")],
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, ctx.primitives.float)?; "llvm.ceil.f64",
let floor_intrinsic = ),
ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let float = ctx.ctx.f64_type();
let fn_type = float.fn_type(&[float.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(floor_intrinsic, &[arg.into()], "floor")
.try_as_basic_value()
.left()
.unwrap();
Ok(val.into())
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ceil".into(),
simple_name: "ceil".into(),
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }],
ret: int32,
vars: Default::default(),
})),
var_id: Default::default(),
instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let ceil_intrinsic =
ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
let float = ctx.ctx.f64_type();
let fn_type = float.fn_type(&[float.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(ceil_intrinsic, &[arg.into()], "ceil")
.try_as_basic_value()
.left()
.unwrap();
Ok(val.into())
},
)))),
loc: None,
})),
Arc::new(RwLock::new({ Arc::new(RwLock::new({
let list_var = primitives.1.get_fresh_var(Some("L".into()), None); let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 }); let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 });
@ -1268,6 +1220,353 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Ok(Some(val.into())) Ok(Some(val.into()))
}), }),
), ),
create_fn_by_intrinsic(
primitives,
&var_map,
"sin",
float,
&[(float, "x")],
"llvm.sin.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"cos",
float,
&[(float, "x")],
"llvm.cos.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"exp",
float,
&[(float, "x")],
"llvm.exp.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"exp2",
float,
&[(float, "x")],
"llvm.exp2.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"log",
float,
&[(float, "x")],
"llvm.log.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"log10",
float,
&[(float, "x")],
"llvm.log10.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"log2",
float,
&[(float, "x")],
"llvm.log2.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"fabs",
float,
&[(float, "x")],
"llvm.fabs.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"trunc",
float,
&[(float, "x")],
"llvm.trunc.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"sqrt",
float,
&[(float, "x")],
"llvm.sqrt.f64",
),
create_fn_by_codegen(
primitives,
&var_map,
"rint",
float,
&[(float, "x")],
Box::new(|ctx, _, fun, args, generator| {
let float = ctx.primitives.float;
let llvm_f64 = ctx.ctx.f64_type();
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
// rint(x) == round(x * 0.5) * 2.0
// %0 = fmul f64 %x, 0.5
let x_half = ctx.builder
.build_float_mul(x_val.into_float_value(), llvm_f64.const_float(0.5), "");
// %1 = call f64 @llvm.round.f64(f64 %0)
let round = ctx.builder
.build_call(
intrinsic_fn,
&vec![x_half.into()],
"",
)
.try_as_basic_value()
.left()
.unwrap();
// %2 = fmul f64 %1, 2.0
let val = ctx.builder
.build_float_mul(round.into_float_value(), llvm_f64.const_float(2.0).into(), "rint");
Ok(Some(val.into()))
}),
),
create_fn_by_extern(
primitives,
&var_map,
"tan",
float,
&[(float, "x")],
"tan",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arcsin",
float,
&[(float, "x")],
"asin",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arccos",
float,
&[(float, "x")],
"acos",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arctan",
float,
&[(float, "x")],
"atan",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"sinh",
float,
&[(float, "x")],
"sinh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"cosh",
float,
&[(float, "x")],
"cosh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"tanh",
float,
&[(float, "x")],
"tanh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arcsinh",
float,
&[(float, "x")],
"asinh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arccosh",
float,
&[(float, "x")],
"acosh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arctanh",
float,
&[(float, "x")],
"atanh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"expm1",
float,
&[(float, "x")],
"expm1",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"cbrt",
float,
&[(float, "x")],
"cbrt",
&["readnone", "willreturn"],
),
create_fn_by_extern(
primitives,
&var_map,
"erf",
float,
&[(float, "z")],
"erf",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"erfc",
float,
&[(float, "x")],
"erfc",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"gamma",
float,
&[(float, "z")],
"tgamma",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"gammaln",
float,
&[(float, "x")],
"lgamma",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"j0",
float,
&[(float, "x")],
"j0",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"j1",
float,
&[(float, "x")],
"j1",
&[],
),
// Not mapped: jv/yv, libm only supports integer orders.
create_fn_by_extern(
primitives,
&var_map,
"arctan2",
float,
&[(float, "x1"), (float, "x2")],
"atan2",
&[],
),
create_fn_by_intrinsic(
primitives,
&var_map,
"copysign",
float,
&[(float, "x1"), (float, "x2")],
"llvm.copysign.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"fmax",
float,
&[(float, "x1"), (float, "x2")],
"llvm.maxnum.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"fmin",
float,
&[(float, "x1"), (float, "x2")],
"llvm.minnum.f64",
),
create_fn_by_extern(
primitives,
&var_map,
"ldexp",
float,
&[(float, "x1"), (int32, "x2")],
"ldexp",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"hypot",
float,
&[(float, "x1"), (float, "x2")],
"hypot",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"nextafter",
float,
&[(float, "x1"), (float, "x2")],
"nextafter",
&[],
),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
name: "Some".into(), name: "Some".into(),
simple_name: "Some".into(), simple_name: "Some".into(),
@ -1314,6 +1613,42 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"abs", "abs",
"isnan", "isnan",
"isinf", "isinf",
"sin",
"cos",
"exp",
"exp2",
"log",
"log10",
"log2",
"fabs",
"trunc",
"sqrt",
"rint",
"tan",
"arcsin",
"arccos",
"arctan",
"sinh",
"cosh",
"tanh",
"arcsinh",
"arccosh",
"arctanh",
"expm1",
"cbrt",
"erf",
"erfc",
"gamma",
"gammaln",
"j0",
"j1",
"arctan2",
"copysign",
"fmax",
"fmin",
"ldexp",
"hypot",
"nextafter",
"Some", "Some",
], ],
) )

View File

@ -5,6 +5,7 @@ import importlib.util
import importlib.machinery import importlib.machinery
import numpy as np import numpy as np
import pathlib import pathlib
import scipy
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
from typing import TypeVar, Generic from typing import TypeVar, Generic
@ -97,8 +98,49 @@ def patch(module):
module.Some = Some module.Some = Some
module.none = none module.none = none
# NumPy Math functions
module.isnan = np.isnan module.isnan = np.isnan
module.isinf = np.isinf module.isinf = np.isinf
module.sin = np.sin
module.cos = np.cos
module.exp = np.exp
module.exp2 = np.exp2
module.log = np.log
module.log10 = np.log10
module.log2 = np.log2
module.fabs = np.fabs
module.floor = np.floor
module.ceil = np.ceil
module.trunc = np.trunc
module.sqrt = np.sqrt
module.rint = np.rint
module.tan = np.tan
module.arcsin = np.arcsin
module.arccos = np.arccos
module.arctan = np.arctan
module.sinh = np.sinh
module.cosh = np.cosh
module.tanh = np.tanh
module.arcsinh = np.arcsinh
module.arccosh = np.arccosh
module.arctanh = np.arctanh
module.expm1 = np.expm1
module.cbrt = np.cbrt
module.arctan2 = np.arctan2
module.copysign = np.copysign
module.fmax = np.fmax
module.fmin = np.fmin
module.ldexp = np.ldexp
module.hypot = np.hypot
module.nextafter = np.nextafter
# SciPy Math Functions
module.erf = scipy.special.erf
module.erfc = scipy.special.erfc
module.gamma = scipy.special.gamma
module.gammaln = scipy.special.gammaln
module.j0 = scipy.special.j0
module.j1 = scipy.special.j1
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):

View File

@ -2,6 +2,10 @@
def output_bool(x: bool): def output_bool(x: bool):
... ...
@extern
def output_float64(x: float):
...
@extern @extern
def dbl_nan() -> float: def dbl_nan() -> float:
... ...
@ -18,8 +22,209 @@ def test_isinf():
for x in [dbl_inf(), 0.0, dbl_nan()]: for x in [dbl_inf(), 0.0, dbl_nan()]:
output_bool(isinf(x)) output_bool(isinf(x))
def test_sin():
pi = 3.1415926535897932384626433
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]:
output_float64(sin(x))
def test_cos():
pi = 3.1415926535897932384626433
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]:
output_float64(cos(x))
def test_exp():
for x in [0.0, 1.0]:
output_float64(exp(x))
def test_exp2():
for x in [0.0, 1.0]:
output_float64(exp2(x))
def test_log():
e = 2.71828182845904523536028747135266249775724709369995
for x in [1.0, e]:
output_float64(log(x))
def test_log10():
for x in [1.0, 10.0]:
output_float64(log10(x))
def test_log2():
for x in [1.0, 2.0]:
output_float64(log2(x))
def test_fabs():
for x in [-1.0, 0.0, 1.0]:
output_float64(fabs(x))
def test_floor():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(floor(x))
def test_ceil():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(ceil(x))
def test_trunc():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(trunc(x))
def test_sqrt():
for x in [1.0, 2.0, 4.0]:
output_float64(sqrt(x))
def test_rint():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(rint(x))
def test_tan():
pi = 3.1415926535897932384626433
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]:
output_float64(tan(x))
def test_arcsin():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arcsin(x))
def test_arccos():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arccos(x))
def test_arctan():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arctan(x))
def test_sinh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(sinh(x))
def test_cosh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(cosh(x))
def test_tanh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(tanh(x))
def test_arcsinh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arcsinh(x))
def test_arccosh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arccosh(x))
def test_arctanh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arctanh(x))
def test_expm1():
for x in [0.0, 1.0]:
output_float64(expm1(x))
def test_cbrt():
for x in [1.0, 8.0, 27.0]:
output_float64(expm1(x))
def test_erf():
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]:
output_float64(erf(x))
def test_erfc():
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]:
output_float64(erfc(x))
def test_gamma():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(gamma(x))
def test_gammaln():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(gammaln(x))
def test_j0():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(j0(x))
def test_j1():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(j1(x))
def test_arctan2():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arctan2(x1, x2))
def test_copysign():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(copysign(x1, x2))
def test_fmax():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(fmax(x1, x2))
def test_fmin():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(fmin(x1, x2))
def test_ldexp():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
for x2 in [-2, -1, 0, 1, 2]:
output_float64(ldexp(x1, x2))
def test_hypot():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(hypot(x1, x2))
def test_nextafter():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(nextafter(x1, x2))
def run() -> int32: def run() -> int32:
test_isnan() test_isnan()
test_isinf() test_isinf()
test_sin()
test_cos()
test_exp()
test_exp2()
test_log()
test_log10()
test_log2()
test_fabs()
test_floor()
test_ceil()
test_trunc()
test_sqrt()
test_rint()
test_tan()
test_arcsin()
test_arccos()
test_arctan()
test_sinh()
test_cosh()
test_tanh()
test_arcsinh()
test_arccosh()
test_arctanh()
test_expm1()
test_cbrt()
test_erf()
test_erfc()
test_gamma()
test_gammaln()
test_j0()
test_j1()
test_arctan2()
test_copysign()
test_fmax()
test_fmin()
test_ldexp()
test_hypot()
test_nextafter()
return 0 return 0