1
0
forked from M-Labs/nac3

core/builtins: Add np_min/np_max

This commit is contained in:
David Mak 2024-05-08 17:42:19 +08:00
parent 7627acea41
commit 73e81259f3
10 changed files with 299 additions and 17 deletions

View File

@ -1,11 +1,12 @@
use inkwell::{FloatPredicate, IntPredicate};
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use inkwell::types::BasicTypeEnum;
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy};
use crate::codegen::classes::NDArrayValue;
use crate::codegen::classes::{NDArrayValue, UntypedArrayLikeAccessor};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::Type;
@ -705,6 +706,92 @@ pub fn call_min<'ctx>(
}
}
/// Invokes the `np_min` builtin function.
pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_min";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a;
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a
}
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx.builder
.build_int_compare(
IntPredicate::NE,
n_sz,
n_sz.get_type().const_zero(),
"",
)
.unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
"zero-size array to reduction operation minimum which has no identity",
[None, None, None],
ctx.current_loc,
);
}
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe {
let identity = n.data()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
}
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(1, false),
(n_sz, false),
|generator, ctx, idx| {
let elem = unsafe {
n.data().get_unchecked(ctx, generator, &idx, None)
};
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty])
})
}
/// Invokes the `max` builtin function.
pub fn call_max<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
@ -752,6 +839,92 @@ pub fn call_max<'ctx>(
}
}
/// Invokes the `np_max` builtin function.
pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_max";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a;
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a
}
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx.builder
.build_int_compare(
IntPredicate::NE,
n_sz,
n_sz.get_type().const_zero(),
"",
)
.unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
"zero-size array to reduction operation minimum which has no identity",
[None, None, None],
ctx.current_loc,
);
}
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe {
let identity = n.data()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
}
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(1, false),
(n_sz, false),
|generator, ctx, idx| {
let elem = unsafe {
n.data().get_unchecked(ctx, generator, &idx, None)
};
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty])
})
}
/// Invokes the `abs` builtin function.
pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,

View File

@ -1,4 +1,14 @@
use super::*;
use std::iter::once;
use indexmap::IndexMap;
use inkwell::{
attributes::{Attribute, AttributeLoc},
IntPredicate,
types::{BasicMetadataTypeEnum, BasicType},
values::{BasicMetadataValueEnum, BasicValue, CallSiteValue}
};
use itertools::Either;
use crate::{
codegen::{
builtin_fns,
@ -15,13 +25,8 @@ use crate::{
},
typecheck::typedef::VarMap,
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
types::{BasicType, BasicMetadataTypeEnum},
values::{BasicValue, BasicMetadataValueEnum, CallSiteValue},
IntPredicate
};
use itertools::Either;
use super::*;
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
@ -1378,6 +1383,28 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
)))),
loc: None,
})),
{
let ret_ty = unifier.get_fresh_var(Some("R".into()), None);
let var_map = num_or_ndarray_var_map.clone()
.into_iter()
.chain(once((ret_ty.1, ret_ty.0)))
.collect::<IndexMap<_, _>>();
create_fn_by_codegen(
unifier,
&var_map,
"np_min",
ret_ty.0,
&[(float_or_ndarray_ty.0, "a")],
Box::new(|ctx, _, fun, args, generator| {
let a_ty = fun.0.args[0].ty;
let a = args[0].1.clone()
.to_basic_value_enum(ctx, generator, a_ty)?;
Ok(Some(builtin_fns::call_numpy_min(generator, ctx, (a_ty, a))?))
}),
)
},
Arc::new(RwLock::new(TopLevelDef::Function {
name: "max".into(),
simple_name: "max".into(),
@ -1405,6 +1432,28 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
)))),
loc: None,
})),
{
let ret_ty = unifier.get_fresh_var(Some("R".into()), None);
let var_map = num_or_ndarray_var_map.clone()
.into_iter()
.chain(once((ret_ty.1, ret_ty.0)))
.collect::<IndexMap<_, _>>();
create_fn_by_codegen(
unifier,
&var_map,
"np_max",
ret_ty.0,
&[(float_or_ndarray_ty.0, "a")],
Box::new(|ctx, _, fun, args, generator| {
let a_ty = fun.0.args[0].ty;
let a = args[0].1.clone()
.to_basic_value_enum(ctx, generator, a_ty)?;
Ok(Some(builtin_fns::call_numpy_max(generator, ctx, (a_ty, a))?))
}),
)
},
Arc::new(RwLock::new(TopLevelDef::Function {
name: "abs".into(),
simple_name: "abs".into(),

View File

@ -5,7 +5,7 @@ expression: res_vec
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [222]\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [224]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar211]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar211\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar213]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar213\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [224]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [229]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [226]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [231]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar210, typevar211]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar210\", \"typevar211\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar212, typevar213]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar212\", \"typevar213\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [230]\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [232]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [238]\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [240]\n}\n",
]

View File

@ -908,6 +908,48 @@ impl<'a> Inferencer<'a> {
}))
}
if [
"np_min",
"np_max",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
ndarray_dtype
} else {
arg0_ty
};
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "a".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0],
keywords: vec![],
},
}))
}
if [
"np_arctan2",
"np_copysign",

View File

@ -174,6 +174,8 @@ def patch(module):
# NumPy Math functions
module.np_isnan = np.isnan
module.np_isinf = np.isinf
module.np_min = np.min
module.np_max = np.max
module.np_sin = np.sin
module.np_cos = np.cos
module.np_exp = np.exp

View File

@ -759,6 +759,20 @@ def test_ndarray_ceil():
output_ndarray_int64_2(xf64)
output_ndarray_float_2(xff)
def test_ndarray_min():
x = np_identity(2)
y = np_min(x)
output_ndarray_float_2(x)
output_float64(y)
def test_ndarray_max():
x = np_identity(2)
y = np_max(x)
output_ndarray_float_2(x)
output_float64(y)
def test_ndarray_abs():
x = np_identity(2)
y = abs(x)
@ -1363,6 +1377,8 @@ def run() -> int32:
test_ndarray_round()
test_ndarray_floor()
test_ndarray_min()
test_ndarray_max()
test_ndarray_abs()
test_ndarray_isnan()
test_ndarray_isinf()