core: Implement numpy and scipy functions

This commit is contained in:
David Mak 2023-10-06 17:48:31 +08:00
parent c28ad78b65
commit e2d7a54d0d
3 changed files with 649 additions and 64 deletions

View File

@ -917,70 +917,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function { create_fn_by_intrinsic(
name: "floor".into(), primitives,
simple_name: "floor".into(), &var_map,
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { "floor",
args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }], float,
ret: int32, &[(float, "x")],
vars: Default::default(), "llvm.floor.f64",
})), ),
var_id: Default::default(), create_fn_by_intrinsic(
instance_to_symbol: Default::default(), primitives,
instance_to_stmt: Default::default(), &var_map,
resolver: None, "ceil",
codegen_callback: Some(Arc::new(GenCall::new(Box::new( float,
|ctx, _, _, args, generator| { &[(float, "x")],
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, ctx.primitives.float)?; "llvm.ceil.f64",
let floor_intrinsic = ),
ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let float = ctx.ctx.f64_type();
let fn_type = float.fn_type(&[float.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(floor_intrinsic, &[arg.into()], "floor")
.try_as_basic_value()
.left()
.unwrap();
Ok(val.into())
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ceil".into(),
simple_name: "ceil".into(),
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }],
ret: int32,
vars: Default::default(),
})),
var_id: Default::default(),
instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let ceil_intrinsic =
ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
let float = ctx.ctx.f64_type();
let fn_type = float.fn_type(&[float.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(ceil_intrinsic, &[arg.into()], "ceil")
.try_as_basic_value()
.left()
.unwrap();
Ok(val.into())
},
)))),
loc: None,
})),
Arc::new(RwLock::new({ Arc::new(RwLock::new({
let list_var = primitives.1.get_fresh_var(Some("L".into()), None); let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 }); let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 });
@ -1226,6 +1178,353 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
})), })),
create_fn_by_intrinsic(
primitives,
&var_map,
"sin",
float,
&[(float, "x")],
"llvm.sin.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"cos",
float,
&[(float, "x")],
"llvm.cos.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"exp",
float,
&[(float, "x")],
"llvm.exp.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"exp2",
float,
&[(float, "x")],
"llvm.exp2.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"log",
float,
&[(float, "x")],
"llvm.log.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"log10",
float,
&[(float, "x")],
"llvm.log10.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"log2",
float,
&[(float, "x")],
"llvm.log2.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"fabs",
float,
&[(float, "x")],
"llvm.fabs.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"trunc",
float,
&[(float, "x")],
"llvm.trunc.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"sqrt",
float,
&[(float, "x")],
"llvm.sqrt.f64",
),
create_fn_by_codegen(
primitives,
&var_map,
"rint",
float,
&[(float, "x")],
Box::new(|ctx, _, fun, args, generator| {
let float = ctx.primitives.float;
let llvm_f64 = ctx.ctx.f64_type();
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
// rint(x) == round(x * 0.5) * 2.0
// %0 = fmul f64 %x, 0.5
let x_half = ctx.builder
.build_float_mul(x_val.into_float_value(), llvm_f64.const_float(0.5), "");
// %1 = call f64 @llvm.round.f64(f64 %0)
let round = ctx.builder
.build_call(
intrinsic_fn,
&vec![x_half.into()],
"",
)
.try_as_basic_value()
.left()
.unwrap();
// %2 = fmul f64 %1, 2.0
let val = ctx.builder
.build_float_mul(round.into_float_value(), llvm_f64.const_float(2.0).into(), "rint");
Ok(Some(val.into()))
}),
),
create_fn_by_extern(
primitives,
&var_map,
"tan",
float,
&[(float, "x")],
"tan",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arcsin",
float,
&[(float, "x")],
"asin",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arccos",
float,
&[(float, "x")],
"acos",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arctan",
float,
&[(float, "x")],
"atan",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"sinh",
float,
&[(float, "x")],
"sinh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"cosh",
float,
&[(float, "x")],
"cosh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"tanh",
float,
&[(float, "x")],
"tanh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arcsinh",
float,
&[(float, "x")],
"asinh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arccosh",
float,
&[(float, "x")],
"acosh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arctanh",
float,
&[(float, "x")],
"atanh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"expm1",
float,
&[(float, "x")],
"expm1",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"cbrt",
float,
&[(float, "x")],
"cbrt",
&["readnone", "willreturn"],
),
create_fn_by_extern(
primitives,
&var_map,
"erf",
float,
&[(float, "z")],
"erf",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"erfc",
float,
&[(float, "x")],
"erfc",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"gamma",
float,
&[(float, "z")],
"tgamma",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"gammaln",
float,
&[(float, "x")],
"lgamma",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"j0",
float,
&[(float, "x")],
"j0",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"j1",
float,
&[(float, "x")],
"j1",
&[],
),
// Not mapped: jv/yv, libm only supports integer orders.
create_fn_by_extern(
primitives,
&var_map,
"arctan2",
float,
&[(float, "x1"), (float, "x2")],
"atan2",
&[],
),
create_fn_by_intrinsic(
primitives,
&var_map,
"copysign",
float,
&[(float, "x1"), (float, "x2")],
"llvm.copysign.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"fmax",
float,
&[(float, "x1"), (float, "x2")],
"llvm.maxnum.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"fmin",
float,
&[(float, "x1"), (float, "x2")],
"llvm.minnum.f64",
),
create_fn_by_extern(
primitives,
&var_map,
"ldexp",
float,
&[(float, "x1"), (int32, "x2")],
"ldexp",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"hypot",
float,
&[(float, "x1"), (float, "x2")],
"hypot",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"nextafter",
float,
&[(float, "x1"), (float, "x2")],
"nextafter",
&[],
),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
name: "Some".into(), name: "Some".into(),
simple_name: "Some".into(), simple_name: "Some".into(),
@ -1270,6 +1569,42 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"min", "min",
"max", "max",
"abs", "abs",
"sin",
"cos",
"exp",
"exp2",
"log",
"log10",
"log2",
"fabs",
"trunc",
"sqrt",
"rint",
"tan",
"arcsin",
"arccos",
"arctan",
"sinh",
"cosh",
"tanh",
"arcsinh",
"arccosh",
"arctanh",
"expm1",
"cbrt",
"erf",
"erfc",
"gamma",
"gammaln",
"j0",
"j1",
"arctan2",
"copysign",
"fmax",
"fmin",
"ldexp",
"hypot",
"nextafter",
"Some", "Some",
], ],
) )

View File

@ -3,7 +3,9 @@
import sys import sys
import importlib.util import importlib.util
import importlib.machinery import importlib.machinery
import numpy
import pathlib import pathlib
import scipy
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
from typing import TypeVar, Generic from typing import TypeVar, Generic
@ -86,6 +88,46 @@ def patch(module):
module.Some = Some module.Some = Some
module.none = none module.none = none
# NumPy math functions
module.sin = numpy.sin
module.cos = numpy.cos
module.exp = numpy.exp
module.exp2 = numpy.exp2
module.log = numpy.log
module.log10 = numpy.log10
module.log2 = numpy.log2
module.fabs = numpy.fabs
module.floor = numpy.floor
module.ceil = numpy.ceil
module.trunc = numpy.trunc
module.sqrt = numpy.sqrt
module.rint = numpy.rint
module.tan = numpy.tan
module.arcsin = numpy.arcsin
module.arccos = numpy.arccos
module.arctan = numpy.arctan
module.sinh = numpy.sinh
module.cosh = numpy.cosh
module.tanh = numpy.tanh
module.arcsinh = numpy.arcsinh
module.arccosh = numpy.arccosh
module.arctanh = numpy.arctanh
module.expm1 = numpy.expm1
module.cbrt = numpy.cbrt
module.erf = scipy.special.erf
module.erfc = scipy.special.erfc
module.gamma = scipy.special.gamma
module.gammaln = scipy.special.gammaln
module.j0 = scipy.special.j0
module.j1 = scipy.special.j1
module.arctan2 = numpy.arctan2
module.copysign = numpy.copysign
module.fmax = numpy.fmax
module.fmin = numpy.fmin
module.ldexp = numpy.ldexp
module.hypot = numpy.hypot
module.nextafter = numpy.nextafter
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

@ -0,0 +1,208 @@
@extern
def output_float64(x: float):
...
def test_sin():
pi = 3.1415926535897932384626433
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]:
output_float64(sin(x))
def test_cos():
pi = 3.1415926535897932384626433
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]:
output_float64(cos(x))
def test_exp():
for x in [0.0, 1.0]:
output_float64(exp(x))
def test_exp2():
for x in [0.0, 1.0]:
output_float64(exp2(x))
def test_log():
e = 2.71828182845904523536028747135266249775724709369995
for x in [1.0, e]:
output_float64(log(x))
def test_log10():
for x in [1.0, 10.0]:
output_float64(log10(x))
def test_log2():
for x in [1.0, 2.0]:
output_float64(log2(x))
def test_fabs():
for x in [-1.0, 0.0, 1.0]:
output_float64(fabs(x))
def test_floor():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(floor(x))
def test_ceil():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(ceil(x))
def test_trunc():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(trunc(x))
def test_sqrt():
for x in [1.0, 2.0, 4.0]:
output_float64(sqrt(x))
def test_rint():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(rint(x))
def test_tan():
pi = 3.1415926535897932384626433
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]:
output_float64(tan(x))
def test_arcsin():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arcsin(x))
def test_arccos():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arccos(x))
def test_arctan():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arctan(x))
def test_sinh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(sinh(x))
def test_cosh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(cosh(x))
def test_tanh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(tanh(x))
def test_arcsinh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arcsinh(x))
def test_arccosh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arccosh(x))
def test_arctanh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arctanh(x))
def test_expm1():
for x in [0.0, 1.0]:
output_float64(expm1(x))
def test_cbrt():
for x in [1.0, 8.0, 27.0]:
output_float64(expm1(x))
def test_erf():
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]:
output_float64(erf(x))
def test_erfc():
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]:
output_float64(erfc(x))
def test_gamma():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(gamma(x))
def test_gammaln():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(gammaln(x))
def test_j0():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(j0(x))
def test_j1():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(j1(x))
def test_arctan2():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arctan2(x1, x2))
def test_copysign():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(copysign(x1, x2))
def test_fmax():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(fmax(x1, x2))
def test_fmin():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(fmin(x1, x2))
def test_ldexp():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
for x2 in [-2, -1, 0, 1, 2]:
output_float64(ldexp(x1, x2))
def test_hypot():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(hypot(x1, x2))
def test_nextafter():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(nextafter(x1, x2))
def run() -> int32:
test_sin()
test_cos()
test_exp()
test_exp2()
test_log()
test_log10()
test_log2()
test_fabs()
test_floor()
test_ceil()
test_trunc()
test_sqrt()
test_rint()
test_tan()
test_arcsin()
test_arccos()
test_arctan()
test_sinh()
test_cosh()
test_tanh()
test_arcsinh()
test_arccosh()
test_arctanh()
test_expm1()
test_cbrt()
test_erf()
test_erfc()
test_gamma()
test_gammaln()
test_j0()
test_j1()
test_arctan2()
test_copysign()
test_fmax()
test_fmin()
test_ldexp()
test_hypot()
test_nextafter()
return 0