forked from M-Labs/nac3
core/builtins: Add np_min/np_max
This commit is contained in:
parent
7627acea41
commit
73e81259f3
|
@ -1,11 +1,12 @@
|
||||||
use inkwell::{FloatPredicate, IntPredicate};
|
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||||
use inkwell::types::BasicTypeEnum;
|
use inkwell::types::BasicTypeEnum;
|
||||||
use inkwell::values::BasicValueEnum;
|
use inkwell::values::BasicValueEnum;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy};
|
use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy};
|
||||||
use crate::codegen::classes::NDArrayValue;
|
use crate::codegen::classes::{NDArrayValue, UntypedArrayLikeAccessor};
|
||||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||||
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
|
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
|
||||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||||
use crate::typecheck::typedef::Type;
|
use crate::typecheck::typedef::Type;
|
||||||
|
@ -705,6 +706,92 @@ pub fn call_min<'ctx>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_min` builtin function.
|
||||||
|
pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_min";
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let (a_ty, a) = a;
|
||||||
|
|
||||||
|
Ok(match a {
|
||||||
|
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
||||||
|
debug_assert!([
|
||||||
|
ctx.primitives.bool,
|
||||||
|
ctx.primitives.int32,
|
||||||
|
ctx.primitives.uint32,
|
||||||
|
ctx.primitives.int64,
|
||||||
|
ctx.primitives.uint64,
|
||||||
|
ctx.primitives.float,
|
||||||
|
].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty)));
|
||||||
|
|
||||||
|
a
|
||||||
|
}
|
||||||
|
|
||||||
|
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
|
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
|
||||||
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
|
let n_sz_eqz = ctx.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::NE,
|
||||||
|
n_sz,
|
||||||
|
n_sz.get_type().const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
n_sz_eqz,
|
||||||
|
"0:ValueError",
|
||||||
|
"zero-size array to reduction operation minimum which has no identity",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
||||||
|
unsafe {
|
||||||
|
let identity = n.data()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
||||||
|
ctx.builder.build_store(accumulator_addr, identity).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
(n_sz, false),
|
||||||
|
|generator, ctx, idx| {
|
||||||
|
let elem = unsafe {
|
||||||
|
n.data().get_unchecked(ctx, generator, &idx, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||||
|
let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
|
||||||
|
ctx.builder.build_store(accumulator_addr, result).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||||
|
accumulator
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[a_ty])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Invokes the `max` builtin function.
|
/// Invokes the `max` builtin function.
|
||||||
pub fn call_max<'ctx>(
|
pub fn call_max<'ctx>(
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -752,6 +839,92 @@ pub fn call_max<'ctx>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_max` builtin function.
|
||||||
|
pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_max";
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let (a_ty, a) = a;
|
||||||
|
|
||||||
|
Ok(match a {
|
||||||
|
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
||||||
|
debug_assert!([
|
||||||
|
ctx.primitives.bool,
|
||||||
|
ctx.primitives.int32,
|
||||||
|
ctx.primitives.uint32,
|
||||||
|
ctx.primitives.int64,
|
||||||
|
ctx.primitives.uint64,
|
||||||
|
ctx.primitives.float,
|
||||||
|
].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty)));
|
||||||
|
|
||||||
|
a
|
||||||
|
}
|
||||||
|
|
||||||
|
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
|
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
|
||||||
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
|
let n_sz_eqz = ctx.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::NE,
|
||||||
|
n_sz,
|
||||||
|
n_sz.get_type().const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
n_sz_eqz,
|
||||||
|
"0:ValueError",
|
||||||
|
"zero-size array to reduction operation minimum which has no identity",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
||||||
|
unsafe {
|
||||||
|
let identity = n.data()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
||||||
|
ctx.builder.build_store(accumulator_addr, identity).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
(n_sz, false),
|
||||||
|
|generator, ctx, idx| {
|
||||||
|
let elem = unsafe {
|
||||||
|
n.data().get_unchecked(ctx, generator, &idx, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||||
|
let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem));
|
||||||
|
ctx.builder.build_store(accumulator_addr, result).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||||
|
accumulator
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[a_ty])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Invokes the `abs` builtin function.
|
/// Invokes the `abs` builtin function.
|
||||||
pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
|
|
@ -1,4 +1,14 @@
|
||||||
use super::*;
|
use std::iter::once;
|
||||||
|
|
||||||
|
use indexmap::IndexMap;
|
||||||
|
use inkwell::{
|
||||||
|
attributes::{Attribute, AttributeLoc},
|
||||||
|
IntPredicate,
|
||||||
|
types::{BasicMetadataTypeEnum, BasicType},
|
||||||
|
values::{BasicMetadataValueEnum, BasicValue, CallSiteValue}
|
||||||
|
};
|
||||||
|
use itertools::Either;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
builtin_fns,
|
builtin_fns,
|
||||||
|
@ -15,13 +25,8 @@ use crate::{
|
||||||
},
|
},
|
||||||
typecheck::typedef::VarMap,
|
typecheck::typedef::VarMap,
|
||||||
};
|
};
|
||||||
use inkwell::{
|
|
||||||
attributes::{Attribute, AttributeLoc},
|
use super::*;
|
||||||
types::{BasicType, BasicMetadataTypeEnum},
|
|
||||||
values::{BasicValue, BasicMetadataValueEnum, CallSiteValue},
|
|
||||||
IntPredicate
|
|
||||||
};
|
|
||||||
use itertools::Either;
|
|
||||||
|
|
||||||
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
|
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
|
||||||
|
|
||||||
|
@ -1378,6 +1383,28 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
})),
|
})),
|
||||||
|
{
|
||||||
|
let ret_ty = unifier.get_fresh_var(Some("R".into()), None);
|
||||||
|
let var_map = num_or_ndarray_var_map.clone()
|
||||||
|
.into_iter()
|
||||||
|
.chain(once((ret_ty.1, ret_ty.0)))
|
||||||
|
.collect::<IndexMap<_, _>>();
|
||||||
|
|
||||||
|
create_fn_by_codegen(
|
||||||
|
unifier,
|
||||||
|
&var_map,
|
||||||
|
"np_min",
|
||||||
|
ret_ty.0,
|
||||||
|
&[(float_or_ndarray_ty.0, "a")],
|
||||||
|
Box::new(|ctx, _, fun, args, generator| {
|
||||||
|
let a_ty = fun.0.args[0].ty;
|
||||||
|
let a = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, a_ty)?;
|
||||||
|
|
||||||
|
Ok(Some(builtin_fns::call_numpy_min(generator, ctx, (a_ty, a))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
},
|
||||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
name: "max".into(),
|
name: "max".into(),
|
||||||
simple_name: "max".into(),
|
simple_name: "max".into(),
|
||||||
|
@ -1405,6 +1432,28 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
})),
|
})),
|
||||||
|
{
|
||||||
|
let ret_ty = unifier.get_fresh_var(Some("R".into()), None);
|
||||||
|
let var_map = num_or_ndarray_var_map.clone()
|
||||||
|
.into_iter()
|
||||||
|
.chain(once((ret_ty.1, ret_ty.0)))
|
||||||
|
.collect::<IndexMap<_, _>>();
|
||||||
|
|
||||||
|
create_fn_by_codegen(
|
||||||
|
unifier,
|
||||||
|
&var_map,
|
||||||
|
"np_max",
|
||||||
|
ret_ty.0,
|
||||||
|
&[(float_or_ndarray_ty.0, "a")],
|
||||||
|
Box::new(|ctx, _, fun, args, generator| {
|
||||||
|
let a_ty = fun.0.args[0].ty;
|
||||||
|
let a = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, a_ty)?;
|
||||||
|
|
||||||
|
Ok(Some(builtin_fns::call_numpy_max(generator, ctx, (a_ty, a))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
},
|
||||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
name: "abs".into(),
|
name: "abs".into(),
|
||||||
simple_name: "abs".into(),
|
simple_name: "abs".into(),
|
||||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [222]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [224]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar211]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar211\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar213]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar213\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [224]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [226]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [229]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [231]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar210, typevar211]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar210\", \"typevar211\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar212, typevar213]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar212\", \"typevar213\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [230]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [232]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [238]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [240]\n}\n",
|
||||||
]
|
]
|
||||||
|
|
|
@ -908,6 +908,48 @@ impl<'a> Inferencer<'a> {
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if [
|
||||||
|
"np_min",
|
||||||
|
"np_max",
|
||||||
|
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
|
|
||||||
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
|
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||||
|
|
||||||
|
ndarray_dtype
|
||||||
|
} else {
|
||||||
|
arg0_ty
|
||||||
|
};
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg {
|
||||||
|
name: "a".into(),
|
||||||
|
ty: arg0.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||||
|
}),
|
||||||
|
args: vec![arg0],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
if [
|
if [
|
||||||
"np_arctan2",
|
"np_arctan2",
|
||||||
"np_copysign",
|
"np_copysign",
|
||||||
|
|
|
@ -174,6 +174,8 @@ def patch(module):
|
||||||
# NumPy Math functions
|
# NumPy Math functions
|
||||||
module.np_isnan = np.isnan
|
module.np_isnan = np.isnan
|
||||||
module.np_isinf = np.isinf
|
module.np_isinf = np.isinf
|
||||||
|
module.np_min = np.min
|
||||||
|
module.np_max = np.max
|
||||||
module.np_sin = np.sin
|
module.np_sin = np.sin
|
||||||
module.np_cos = np.cos
|
module.np_cos = np.cos
|
||||||
module.np_exp = np.exp
|
module.np_exp = np.exp
|
||||||
|
|
|
@ -759,6 +759,20 @@ def test_ndarray_ceil():
|
||||||
output_ndarray_int64_2(xf64)
|
output_ndarray_int64_2(xf64)
|
||||||
output_ndarray_float_2(xff)
|
output_ndarray_float_2(xff)
|
||||||
|
|
||||||
|
def test_ndarray_min():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_min(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_float64(y)
|
||||||
|
|
||||||
|
def test_ndarray_max():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_max(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_float64(y)
|
||||||
|
|
||||||
def test_ndarray_abs():
|
def test_ndarray_abs():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
y = abs(x)
|
y = abs(x)
|
||||||
|
@ -1363,6 +1377,8 @@ def run() -> int32:
|
||||||
|
|
||||||
test_ndarray_round()
|
test_ndarray_round()
|
||||||
test_ndarray_floor()
|
test_ndarray_floor()
|
||||||
|
test_ndarray_min()
|
||||||
|
test_ndarray_max()
|
||||||
test_ndarray_abs()
|
test_ndarray_abs()
|
||||||
test_ndarray_isnan()
|
test_ndarray_isnan()
|
||||||
test_ndarray_isinf()
|
test_ndarray_isinf()
|
||||||
|
|
Loading…
Reference in New Issue