Compare commits

...

4 Commits

11 changed files with 673 additions and 58 deletions

76
Cargo.lock generated
View File

@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "cc"
version = "1.0.96"
version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "065a29261d53ba54260972629f9ca6bffa69bac13cd1fed61420f7fa68b9f8bd"
checksum = "099a5357d84c4c61eb35fc8eafa9a79a902c2f76911e5747ced4e032edd8d9b4"
[[package]]
name = "cfg-if"
@ -158,7 +158,7 @@ dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.60",
"syn 2.0.61",
]
[[package]]
@ -276,9 +276,9 @@ checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2"
[[package]]
name = "ena"
version = "0.14.2"
version = "0.14.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c533630cf40e9caa44bd91aadc88a75d75a4c3a12b4cfde353cbed41daa1e1f1"
checksum = "3d248bdd43ce613d87415282f69b9bb99d947d290b10962dd6c56233312c2ad5"
dependencies = [
"log",
]
@ -337,9 +337,9 @@ dependencies = [
[[package]]
name = "getrandom"
version = "0.2.14"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"libc",
@ -421,7 +421,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
"syn 2.0.61",
]
[[package]]
@ -747,7 +747,7 @@ dependencies = [
"phf_shared 0.11.2",
"proc-macro2",
"quote",
"syn 2.0.60",
"syn 2.0.61",
]
[[package]]
@ -794,9 +794,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]]
name = "proc-macro2"
version = "1.0.81"
version = "1.0.82"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba"
checksum = "8ad3d49ab951a01fbaafe34f2ec74122942fe18a3f9814c3268f1bb72042131b"
dependencies = [
"unicode-ident",
]
@ -848,7 +848,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.60",
"syn 2.0.61",
]
[[package]]
@ -861,7 +861,7 @@ dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.60",
"syn 2.0.61",
]
[[package]]
@ -994,15 +994,15 @@ dependencies = [
[[package]]
name = "rustversion"
version = "1.0.15"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47"
checksum = "092474d1a01ea8278f69e6a358998405fae5b8b963ddaeb2b0b04a128bf1dfb0"
[[package]]
name = "ryu"
version = "1.0.17"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "same-file"
@ -1021,35 +1021,35 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "semver"
version = "1.0.22"
version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca"
checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]]
name = "serde"
version = "1.0.200"
version = "1.0.201"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f"
checksum = "780f1cebed1629e4753a1a38a3c72d30b97ec044f0aef68cb26650a3c5cf363c"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.200"
version = "1.0.201"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb"
checksum = "c5e405930b9796f1c00bee880d03fc7e0bb4b9a11afc776885ffe84320da2865"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
"syn 2.0.61",
]
[[package]]
name = "serde_json"
version = "1.0.116"
version = "1.0.117"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813"
checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3"
dependencies = [
"itoa",
"ryu",
@ -1129,9 +1129,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.60"
version = "2.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3"
checksum = "c993ed8ccba56ae856363b1845da7266a7cb78e1d146c8a32d54b45a8b831fc9"
dependencies = [
"proc-macro2",
"quote",
@ -1182,22 +1182,22 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.59"
version = "1.0.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa"
checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.59"
version = "1.0.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66"
checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
"syn 2.0.61",
]
[[package]]
@ -1450,20 +1450,20 @@ dependencies = [
[[package]]
name = "zerocopy"
version = "0.7.33"
version = "0.7.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "087eca3c1eaf8c47b94d02790dd086cd594b912d2043d4de4bfdd466b3befb7c"
checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.7.33"
version = "0.7.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f4b6c273f496d8fd4eaf18853e6b448760225dc030ff2c485a786859aea6393"
checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
"syn 2.0.61",
]

View File

@ -1,11 +1,12 @@
use inkwell::{FloatPredicate, IntPredicate};
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use inkwell::types::BasicTypeEnum;
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy};
use crate::codegen::classes::NDArrayValue;
use crate::codegen::classes::{NDArrayValue, UntypedArrayLikeAccessor};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::Type;
@ -705,6 +706,177 @@ pub fn call_min<'ctx>(
}
}
/// Invokes the `np_min` builtin function.
pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_min";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a;
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a
}
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx.builder
.build_int_compare(
IntPredicate::NE,
n_sz,
n_sz.get_type().const_zero(),
"",
)
.unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
"zero-size array to reduction operation minimum which has no identity",
[None, None, None],
ctx.current_loc,
);
}
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe {
let identity = n.data()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
}
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(1, false),
(n_sz, false),
|generator, ctx, idx| {
let elem = unsafe {
n.data().get_unchecked(ctx, generator, &idx, None)
};
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty])
})
}
/// Invokes the `np_minimum` builtin function.
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_minimum";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) {
Some(x1_ty)
} else {
None
};
Ok(match (x1, x2) {
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
(x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => {
let is_ndarray1 = x1_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_ndarray2 = x2_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { unreachable!() };
let x1_scalar_ty = if is_ndarray1 {
dtype
} else {
x1_ty
};
let x2_scalar_ty = if is_ndarray2 {
dtype
} else {
x2_ty
};
numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
dtype,
None,
(x1, !is_ndarray1),
(x2, !is_ndarray2),
|generator, ctx, (lhs, rhs)| {
call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
})
}
/// Invokes the `max` builtin function.
pub fn call_max<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
@ -752,6 +924,177 @@ pub fn call_max<'ctx>(
}
}
/// Invokes the `np_max` builtin function.
pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_max";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a;
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a
}
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx.builder
.build_int_compare(
IntPredicate::NE,
n_sz,
n_sz.get_type().const_zero(),
"",
)
.unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
"zero-size array to reduction operation minimum which has no identity",
[None, None, None],
ctx.current_loc,
);
}
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe {
let identity = n.data()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
}
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(1, false),
(n_sz, false),
|generator, ctx, idx| {
let elem = unsafe {
n.data().get_unchecked(ctx, generator, &idx, None)
};
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty])
})
}
/// Invokes the `np_maximum` builtin function.
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_maximum";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) {
Some(x1_ty)
} else {
None
};
Ok(match (x1, x2) {
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
(x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => {
let is_ndarray1 = x1_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_ndarray2 = x2_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { unreachable!() };
let x1_scalar_ty = if is_ndarray1 {
dtype
} else {
x1_ty
};
let x2_scalar_ty = if is_ndarray2 {
dtype
} else {
x2_ty
};
numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
dtype,
None,
(x1, !is_ndarray1),
(x2, !is_ndarray2),
|generator, ctx, (lhs, rhs)| {
call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
})
}
/// Invokes the `abs` builtin function.
pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,

View File

@ -1,4 +1,14 @@
use super::*;
use std::iter::once;
use indexmap::IndexMap;
use inkwell::{
attributes::{Attribute, AttributeLoc},
IntPredicate,
types::{BasicMetadataTypeEnum, BasicType},
values::{BasicMetadataValueEnum, BasicValue, CallSiteValue}
};
use itertools::Either;
use crate::{
codegen::{
builtin_fns,
@ -15,13 +25,8 @@ use crate::{
},
typecheck::typedef::VarMap,
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
types::{BasicType, BasicMetadataTypeEnum},
values::{BasicValue, BasicMetadataValueEnum, CallSiteValue},
IntPredicate
};
use itertools::Either;
use super::*;
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
@ -1378,6 +1383,65 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
)))),
loc: None,
})),
{
let ret_ty = unifier.get_fresh_var(Some("R".into()), None);
let var_map = num_or_ndarray_var_map.clone()
.into_iter()
.chain(once((ret_ty.1, ret_ty.0)))
.collect::<IndexMap<_, _>>();
create_fn_by_codegen(
unifier,
&var_map,
"np_min",
ret_ty.0,
&[(float_or_ndarray_ty.0, "a")],
Box::new(|ctx, _, fun, args, generator| {
let a_ty = fun.0.args[0].ty;
let a = args[0].1.clone()
.to_basic_value_enum(ctx, generator, a_ty)?;
Ok(Some(builtin_fns::call_numpy_min(generator, ctx, (a_ty, a))?))
}),
)
},
{
let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")];
let ret_ty = unifier.get_fresh_var(None, None);
Arc::new(RwLock::new(TopLevelDef::Function {
name: "np_minimum".into(),
simple_name: "np_minimum".into(),
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty.iter().map(|p| FuncArg {
name: p.1.into(),
ty: p.0,
default_value: None,
}).collect(),
ret: ret_ty.0,
vars: [
(x1_ty.1, x1_ty.0),
(x2_ty.1, x2_ty.0),
(ret_ty.1, ret_ty.0),
].into_iter().collect(),
})),
var_id: vec![x1_ty.1, x2_ty.1],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x2_ty = fun.0.args[1].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_numpy_minimum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
})))),
loc: None,
}))
},
Arc::new(RwLock::new(TopLevelDef::Function {
name: "max".into(),
simple_name: "max".into(),
@ -1405,6 +1469,65 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
)))),
loc: None,
})),
{
let ret_ty = unifier.get_fresh_var(Some("R".into()), None);
let var_map = num_or_ndarray_var_map.clone()
.into_iter()
.chain(once((ret_ty.1, ret_ty.0)))
.collect::<IndexMap<_, _>>();
create_fn_by_codegen(
unifier,
&var_map,
"np_max",
ret_ty.0,
&[(float_or_ndarray_ty.0, "a")],
Box::new(|ctx, _, fun, args, generator| {
let a_ty = fun.0.args[0].ty;
let a = args[0].1.clone()
.to_basic_value_enum(ctx, generator, a_ty)?;
Ok(Some(builtin_fns::call_numpy_max(generator, ctx, (a_ty, a))?))
}),
)
},
{
let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")];
let ret_ty = unifier.get_fresh_var(None, None);
Arc::new(RwLock::new(TopLevelDef::Function {
name: "np_maximum".into(),
simple_name: "np_maximum".into(),
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty.iter().map(|p| FuncArg {
name: p.1.into(),
ty: p.0,
default_value: None,
}).collect(),
ret: ret_ty.0,
vars: [
(x1_ty.1, x1_ty.0),
(x2_ty.1, x2_ty.0),
(ret_ty.1, ret_ty.0),
].into_iter().collect(),
})),
var_id: vec![x1_ty.1, x2_ty.1],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x2_ty = fun.0.args[1].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_numpy_maximum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
})))),
loc: None,
}))
},
Arc::new(RwLock::new(TopLevelDef::Function {
name: "abs".into(),
simple_name: "abs".into(),

View File

@ -5,7 +5,7 @@ expression: res_vec
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [222]\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [238]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar211]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar211\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar227]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar227\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [224]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [229]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [240]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [245]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar210, typevar211]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar210\", \"typevar211\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar226, typevar227]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar226\", \"typevar227\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [230]\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [246]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [238]\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [254]\n}\n",
]

View File

@ -905,6 +905,50 @@ impl<'a> Inferencer<'a> {
}
if [
"np_min",
"np_max",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
ndarray_dtype
} else {
arg0_ty
};
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "a".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0],
keywords: vec![],
},
}))
}
if [
"np_minimum",
"np_maximum",
"np_arctan2",
"np_copysign",
"np_fmax",
@ -913,8 +957,6 @@ impl<'a> Inferencer<'a> {
"np_hypot",
"np_nextafter",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 {
let target_ty = self.primitives.float;
let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let arg1 = self.fold_expr(args.remove(0))?;
@ -931,6 +973,7 @@ impl<'a> Inferencer<'a> {
} else {
arg1_ty
};
let expected_arg1_dtype = if id == &"np_ldexp".into() {
self.primitives.int32
} else {
@ -939,7 +982,7 @@ impl<'a> Inferencer<'a> {
if !self.unifier.unioned(arg1_dtype, expected_arg1_dtype) {
return report_error(
format!(
"Expected {} for second argument of {id}, got {}",
"Expected broadcast-compatible type of ndarray[{}, N] for second argument of {id}, got {}",
self.unifier.stringify(expected_arg1_dtype),
self.unifier.stringify(arg1_dtype),
).as_str(),
@ -947,6 +990,12 @@ impl<'a> Inferencer<'a> {
)
}
let target_ty = if id == &"np_minimum".into() || id == &"np_maximum".into() {
arg0_dtype
} else {
self.primitives.float
};
let ret = if [
&arg0_ty,
&arg1_ty,

View File

@ -174,6 +174,10 @@ def patch(module):
# NumPy Math functions
module.np_isnan = np.isnan
module.np_isinf = np.isinf
module.np_min = np.min
module.np_minimum = np.minimum
module.np_max = np.max
module.np_maximum = np.maximum
module.np_sin = np.sin
module.np_cos = np.cos
module.np_exp = np.exp

View File

@ -759,6 +759,92 @@ def test_ndarray_ceil():
output_ndarray_int64_2(xf64)
output_ndarray_float_2(xff)
def test_ndarray_min():
x = np_identity(2)
y = np_min(x)
output_ndarray_float_2(x)
output_float64(y)
def test_ndarray_minimum():
x = np_identity(2)
min_x_zeros = np_minimum(x, np_zeros([2]))
min_x_ones = np_minimum(x, np_zeros([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones)
def test_ndarray_minimum_broadcast():
x = np_identity(2)
min_x_zeros = np_minimum(x, np_zeros([2]))
min_x_ones = np_minimum(x, np_zeros([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones)
def test_ndarray_minimum_broadcast_lhs_scalar():
x = np_identity(2)
min_x_zeros = np_minimum(0.0, x)
min_x_ones = np_minimum(1.0, x)
output_ndarray_float_2(x)
output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones)
def test_ndarray_minimum_broadcast_rhs_scalar():
x = np_identity(2)
min_x_zeros = np_minimum(x, 0.0)
min_x_ones = np_minimum(x, 1.0)
output_ndarray_float_2(x)
output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones)
def test_ndarray_max():
x = np_identity(2)
y = np_max(x)
output_ndarray_float_2(x)
output_float64(y)
def test_ndarray_maximum():
x = np_identity(2)
max_x_zeros = np_maximum(x, np_zeros([2]))
max_x_ones = np_maximum(x, np_zeros([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones)
def test_ndarray_maximum_broadcast():
x = np_identity(2)
max_x_zeros = np_maximum(x, np_zeros([2]))
max_x_ones = np_maximum(x, np_zeros([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones)
def test_ndarray_maximum_broadcast_lhs_scalar():
x = np_identity(2)
max_x_zeros = np_maximum(0.0, x)
max_x_ones = np_maximum(1.0, x)
output_ndarray_float_2(x)
output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones)
def test_ndarray_maximum_broadcast_rhs_scalar():
x = np_identity(2)
max_x_zeros = np_maximum(x, 0.0)
max_x_ones = np_maximum(x, 1.0)
output_ndarray_float_2(x)
output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones)
def test_ndarray_abs():
x = np_identity(2)
y = abs(x)
@ -1363,6 +1449,16 @@ def run() -> int32:
test_ndarray_round()
test_ndarray_floor()
test_ndarray_min()
test_ndarray_minimum()
test_ndarray_minimum_broadcast()
test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_max()
test_ndarray_maximum()
test_ndarray_maximum_broadcast()
test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_abs()
test_ndarray_isnan()
test_ndarray_isinf()