core: Implement numpy and scipy functions

This commit is contained in:
David Mak 2023-10-06 17:48:31 +08:00
parent c28ad78b65
commit e2d7a54d0d
3 changed files with 649 additions and 64 deletions

View File

@ -917,70 +917,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "floor".into(),
simple_name: "floor".into(),
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }],
ret: int32,
vars: Default::default(),
})),
var_id: Default::default(),
instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let floor_intrinsic =
ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let float = ctx.ctx.f64_type();
let fn_type = float.fn_type(&[float.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(floor_intrinsic, &[arg.into()], "floor")
.try_as_basic_value()
.left()
.unwrap();
Ok(val.into())
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ceil".into(),
simple_name: "ceil".into(),
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }],
ret: int32,
vars: Default::default(),
})),
var_id: Default::default(),
instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let ceil_intrinsic =
ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
let float = ctx.ctx.f64_type();
let fn_type = float.fn_type(&[float.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(ceil_intrinsic, &[arg.into()], "ceil")
.try_as_basic_value()
.left()
.unwrap();
Ok(val.into())
},
)))),
loc: None,
})),
create_fn_by_intrinsic(
primitives,
&var_map,
"floor",
float,
&[(float, "x")],
"llvm.floor.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"ceil",
float,
&[(float, "x")],
"llvm.ceil.f64",
),
Arc::new(RwLock::new({
let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 });
@ -1226,6 +1178,353 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))),
loc: None,
})),
create_fn_by_intrinsic(
primitives,
&var_map,
"sin",
float,
&[(float, "x")],
"llvm.sin.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"cos",
float,
&[(float, "x")],
"llvm.cos.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"exp",
float,
&[(float, "x")],
"llvm.exp.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"exp2",
float,
&[(float, "x")],
"llvm.exp2.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"log",
float,
&[(float, "x")],
"llvm.log.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"log10",
float,
&[(float, "x")],
"llvm.log10.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"log2",
float,
&[(float, "x")],
"llvm.log2.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"fabs",
float,
&[(float, "x")],
"llvm.fabs.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"trunc",
float,
&[(float, "x")],
"llvm.trunc.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"sqrt",
float,
&[(float, "x")],
"llvm.sqrt.f64",
),
create_fn_by_codegen(
primitives,
&var_map,
"rint",
float,
&[(float, "x")],
Box::new(|ctx, _, fun, args, generator| {
let float = ctx.primitives.float;
let llvm_f64 = ctx.ctx.f64_type();
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
// rint(x) == round(x * 0.5) * 2.0
// %0 = fmul f64 %x, 0.5
let x_half = ctx.builder
.build_float_mul(x_val.into_float_value(), llvm_f64.const_float(0.5), "");
// %1 = call f64 @llvm.round.f64(f64 %0)
let round = ctx.builder
.build_call(
intrinsic_fn,
&vec![x_half.into()],
"",
)
.try_as_basic_value()
.left()
.unwrap();
// %2 = fmul f64 %1, 2.0
let val = ctx.builder
.build_float_mul(round.into_float_value(), llvm_f64.const_float(2.0).into(), "rint");
Ok(Some(val.into()))
}),
),
create_fn_by_extern(
primitives,
&var_map,
"tan",
float,
&[(float, "x")],
"tan",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arcsin",
float,
&[(float, "x")],
"asin",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arccos",
float,
&[(float, "x")],
"acos",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arctan",
float,
&[(float, "x")],
"atan",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"sinh",
float,
&[(float, "x")],
"sinh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"cosh",
float,
&[(float, "x")],
"cosh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"tanh",
float,
&[(float, "x")],
"tanh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arcsinh",
float,
&[(float, "x")],
"asinh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arccosh",
float,
&[(float, "x")],
"acosh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"arctanh",
float,
&[(float, "x")],
"atanh",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"expm1",
float,
&[(float, "x")],
"expm1",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"cbrt",
float,
&[(float, "x")],
"cbrt",
&["readnone", "willreturn"],
),
create_fn_by_extern(
primitives,
&var_map,
"erf",
float,
&[(float, "z")],
"erf",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"erfc",
float,
&[(float, "x")],
"erfc",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"gamma",
float,
&[(float, "z")],
"tgamma",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"gammaln",
float,
&[(float, "x")],
"lgamma",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"j0",
float,
&[(float, "x")],
"j0",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"j1",
float,
&[(float, "x")],
"j1",
&[],
),
// Not mapped: jv/yv, libm only supports integer orders.
create_fn_by_extern(
primitives,
&var_map,
"arctan2",
float,
&[(float, "x1"), (float, "x2")],
"atan2",
&[],
),
create_fn_by_intrinsic(
primitives,
&var_map,
"copysign",
float,
&[(float, "x1"), (float, "x2")],
"llvm.copysign.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"fmax",
float,
&[(float, "x1"), (float, "x2")],
"llvm.maxnum.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"fmin",
float,
&[(float, "x1"), (float, "x2")],
"llvm.minnum.f64",
),
create_fn_by_extern(
primitives,
&var_map,
"ldexp",
float,
&[(float, "x1"), (int32, "x2")],
"ldexp",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"hypot",
float,
&[(float, "x1"), (float, "x2")],
"hypot",
&[],
),
create_fn_by_extern(
primitives,
&var_map,
"nextafter",
float,
&[(float, "x1"), (float, "x2")],
"nextafter",
&[],
),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "Some".into(),
simple_name: "Some".into(),
@ -1270,6 +1569,42 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"min",
"max",
"abs",
"sin",
"cos",
"exp",
"exp2",
"log",
"log10",
"log2",
"fabs",
"trunc",
"sqrt",
"rint",
"tan",
"arcsin",
"arccos",
"arctan",
"sinh",
"cosh",
"tanh",
"arcsinh",
"arccosh",
"arctanh",
"expm1",
"cbrt",
"erf",
"erfc",
"gamma",
"gammaln",
"j0",
"j1",
"arctan2",
"copysign",
"fmax",
"fmin",
"ldexp",
"hypot",
"nextafter",
"Some",
],
)

View File

@ -3,7 +3,9 @@
import sys
import importlib.util
import importlib.machinery
import numpy
import pathlib
import scipy
from numpy import int32, int64, uint32, uint64
from typing import TypeVar, Generic
@ -86,6 +88,46 @@ def patch(module):
module.Some = Some
module.none = none
# NumPy math functions
module.sin = numpy.sin
module.cos = numpy.cos
module.exp = numpy.exp
module.exp2 = numpy.exp2
module.log = numpy.log
module.log10 = numpy.log10
module.log2 = numpy.log2
module.fabs = numpy.fabs
module.floor = numpy.floor
module.ceil = numpy.ceil
module.trunc = numpy.trunc
module.sqrt = numpy.sqrt
module.rint = numpy.rint
module.tan = numpy.tan
module.arcsin = numpy.arcsin
module.arccos = numpy.arccos
module.arctan = numpy.arctan
module.sinh = numpy.sinh
module.cosh = numpy.cosh
module.tanh = numpy.tanh
module.arcsinh = numpy.arcsinh
module.arccosh = numpy.arccosh
module.arctanh = numpy.arctanh
module.expm1 = numpy.expm1
module.cbrt = numpy.cbrt
module.erf = scipy.special.erf
module.erfc = scipy.special.erfc
module.gamma = scipy.special.gamma
module.gammaln = scipy.special.gammaln
module.j0 = scipy.special.j0
module.j1 = scipy.special.j1
module.arctan2 = numpy.arctan2
module.copysign = numpy.copysign
module.fmax = numpy.fmax
module.fmin = numpy.fmin
module.ldexp = numpy.ldexp
module.hypot = numpy.hypot
module.nextafter = numpy.nextafter
def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename)

View File

@ -0,0 +1,208 @@
@extern
def output_float64(x: float):
...
def test_sin():
pi = 3.1415926535897932384626433
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]:
output_float64(sin(x))
def test_cos():
pi = 3.1415926535897932384626433
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]:
output_float64(cos(x))
def test_exp():
for x in [0.0, 1.0]:
output_float64(exp(x))
def test_exp2():
for x in [0.0, 1.0]:
output_float64(exp2(x))
def test_log():
e = 2.71828182845904523536028747135266249775724709369995
for x in [1.0, e]:
output_float64(log(x))
def test_log10():
for x in [1.0, 10.0]:
output_float64(log10(x))
def test_log2():
for x in [1.0, 2.0]:
output_float64(log2(x))
def test_fabs():
for x in [-1.0, 0.0, 1.0]:
output_float64(fabs(x))
def test_floor():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(floor(x))
def test_ceil():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(ceil(x))
def test_trunc():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(trunc(x))
def test_sqrt():
for x in [1.0, 2.0, 4.0]:
output_float64(sqrt(x))
def test_rint():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_float64(rint(x))
def test_tan():
pi = 3.1415926535897932384626433
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]:
output_float64(tan(x))
def test_arcsin():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arcsin(x))
def test_arccos():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arccos(x))
def test_arctan():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arctan(x))
def test_sinh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(sinh(x))
def test_cosh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(cosh(x))
def test_tanh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(tanh(x))
def test_arcsinh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arcsinh(x))
def test_arccosh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arccosh(x))
def test_arctanh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arctanh(x))
def test_expm1():
for x in [0.0, 1.0]:
output_float64(expm1(x))
def test_cbrt():
for x in [1.0, 8.0, 27.0]:
output_float64(expm1(x))
def test_erf():
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]:
output_float64(erf(x))
def test_erfc():
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]:
output_float64(erfc(x))
def test_gamma():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(gamma(x))
def test_gammaln():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(gammaln(x))
def test_j0():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(j0(x))
def test_j1():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(j1(x))
def test_arctan2():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(arctan2(x1, x2))
def test_copysign():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(copysign(x1, x2))
def test_fmax():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(fmax(x1, x2))
def test_fmin():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]:
output_float64(fmin(x1, x2))
def test_ldexp():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
for x2 in [-2, -1, 0, 1, 2]:
output_float64(ldexp(x1, x2))
def test_hypot():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(hypot(x1, x2))
def test_nextafter():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
output_float64(nextafter(x1, x2))
def run() -> int32:
test_sin()
test_cos()
test_exp()
test_exp2()
test_log()
test_log10()
test_log2()
test_fabs()
test_floor()
test_ceil()
test_trunc()
test_sqrt()
test_rint()
test_tan()
test_arcsin()
test_arccos()
test_arctan()
test_sinh()
test_cosh()
test_tanh()
test_arcsinh()
test_arccosh()
test_arctanh()
test_expm1()
test_cbrt()
test_erf()
test_erfc()
test_gamma()
test_gammaln()
test_j0()
test_j1()
test_arctan2()
test_copysign()
test_fmax()
test_fmin()
test_ldexp()
test_hypot()
test_nextafter()
return 0