Compare commits
4 Commits
30c6cffbad
...
09b51d0e66
Author | SHA1 | Date |
---|---|---|
David Mak | 09b51d0e66 | |
David Mak | 0c94667cf6 | |
David Mak | 48a409c918 | |
David Mak | 876e6ea7b8 |
|
@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
|||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.96"
|
||||
version = "1.0.97"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "065a29261d53ba54260972629f9ca6bffa69bac13cd1fed61420f7fa68b9f8bd"
|
||||
checksum = "099a5357d84c4c61eb35fc8eafa9a79a902c2f76911e5747ced4e032edd8d9b4"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
|
@ -158,7 +158,7 @@ dependencies = [
|
|||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
"syn 2.0.61",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -276,9 +276,9 @@ checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2"
|
|||
|
||||
[[package]]
|
||||
name = "ena"
|
||||
version = "0.14.2"
|
||||
version = "0.14.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c533630cf40e9caa44bd91aadc88a75d75a4c3a12b4cfde353cbed41daa1e1f1"
|
||||
checksum = "3d248bdd43ce613d87415282f69b9bb99d947d290b10962dd6c56233312c2ad5"
|
||||
dependencies = [
|
||||
"log",
|
||||
]
|
||||
|
@ -337,9 +337,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.14"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c"
|
||||
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
|
@ -421,7 +421,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
"syn 2.0.61",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -747,7 +747,7 @@ dependencies = [
|
|||
"phf_shared 0.11.2",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
"syn 2.0.61",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -794,9 +794,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
|
|||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.81"
|
||||
version = "1.0.82"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba"
|
||||
checksum = "8ad3d49ab951a01fbaafe34f2ec74122942fe18a3f9814c3268f1bb72042131b"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
@ -848,7 +848,7 @@ dependencies = [
|
|||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
"syn 2.0.61",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -861,7 +861,7 @@ dependencies = [
|
|||
"proc-macro2",
|
||||
"pyo3-build-config",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
"syn 2.0.61",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -994,15 +994,15 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.15"
|
||||
version = "1.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47"
|
||||
checksum = "092474d1a01ea8278f69e6a358998405fae5b8b963ddaeb2b0b04a128bf1dfb0"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.17"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"
|
||||
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
|
@ -1021,35 +1021,35 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
|||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.22"
|
||||
version = "1.0.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca"
|
||||
checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.200"
|
||||
version = "1.0.201"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f"
|
||||
checksum = "780f1cebed1629e4753a1a38a3c72d30b97ec044f0aef68cb26650a3c5cf363c"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.200"
|
||||
version = "1.0.201"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb"
|
||||
checksum = "c5e405930b9796f1c00bee880d03fc7e0bb4b9a11afc776885ffe84320da2865"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
"syn 2.0.61",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.116"
|
||||
version = "1.0.117"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813"
|
||||
checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
|
@ -1129,9 +1129,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.60"
|
||||
version = "2.0.61"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3"
|
||||
checksum = "c993ed8ccba56ae856363b1845da7266a7cb78e1d146c8a32d54b45a8b831fc9"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -1182,22 +1182,22 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.59"
|
||||
version = "1.0.60"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa"
|
||||
checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.59"
|
||||
version = "1.0.60"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66"
|
||||
checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
"syn 2.0.61",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1450,20 +1450,20 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.7.33"
|
||||
version = "0.7.34"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "087eca3c1eaf8c47b94d02790dd086cd594b912d2043d4de4bfdd466b3befb7c"
|
||||
checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087"
|
||||
dependencies = [
|
||||
"zerocopy-derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.7.33"
|
||||
version = "0.7.34"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f4b6c273f496d8fd4eaf18853e6b448760225dc030ff2c485a786859aea6393"
|
||||
checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
"syn 2.0.61",
|
||||
]
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
use inkwell::{FloatPredicate, IntPredicate};
|
||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||
use inkwell::types::BasicTypeEnum;
|
||||
use inkwell::values::BasicValueEnum;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy};
|
||||
use crate::codegen::classes::NDArrayValue;
|
||||
use crate::codegen::classes::{NDArrayValue, UntypedArrayLikeAccessor};
|
||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
|
||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use crate::typecheck::typedef::Type;
|
||||
|
@ -705,6 +706,177 @@ pub fn call_min<'ctx>(
|
|||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_min` builtin function.
|
||||
pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
a: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_min";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (a_ty, a) = a;
|
||||
|
||||
Ok(match a {
|
||||
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty)));
|
||||
|
||||
a
|
||||
}
|
||||
|
||||
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let n_sz_eqz = ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::NE,
|
||||
n_sz,
|
||||
n_sz.get_type().const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
n_sz_eqz,
|
||||
"0:ValueError",
|
||||
"zero-size array to reduction operation minimum which has no identity",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
||||
unsafe {
|
||||
let identity = n.data()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
||||
ctx.builder.build_store(accumulator_addr, identity).unwrap();
|
||||
}
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
llvm_usize.const_int(1, false),
|
||||
(n_sz, false),
|
||||
|generator, ctx, idx| {
|
||||
let elem = unsafe {
|
||||
n.data().get_unchecked(ctx, generator, &idx, None)
|
||||
};
|
||||
|
||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||
let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
|
||||
ctx.builder.build_store(accumulator_addr, result).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||
accumulator
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[a_ty])
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `np_minimum` builtin function.
|
||||
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_minimum";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) {
|
||||
Some(x1_ty)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
|
||||
|
||||
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||
}
|
||||
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
|
||||
|
||||
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||
}
|
||||
|
||||
(x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => {
|
||||
let is_ndarray1 = x1_ty.obj_id(&ctx.unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
let is_ndarray2 = x2_ty.obj_id(&ctx.unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
|
||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
ndarray_dtype1
|
||||
} else if is_ndarray1 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
} else { unreachable!() };
|
||||
|
||||
let x1_scalar_ty = if is_ndarray1 {
|
||||
dtype
|
||||
} else {
|
||||
x1_ty
|
||||
};
|
||||
let x2_scalar_ty = if is_ndarray2 {
|
||||
dtype
|
||||
} else {
|
||||
x2_ty
|
||||
};
|
||||
|
||||
numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
dtype,
|
||||
None,
|
||||
(x1, !is_ndarray1),
|
||||
(x2, !is_ndarray2),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||
},
|
||||
)?.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `max` builtin function.
|
||||
pub fn call_max<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
|
@ -752,6 +924,177 @@ pub fn call_max<'ctx>(
|
|||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_max` builtin function.
|
||||
pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
a: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_max";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (a_ty, a) = a;
|
||||
|
||||
Ok(match a {
|
||||
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty)));
|
||||
|
||||
a
|
||||
}
|
||||
|
||||
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let n_sz_eqz = ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::NE,
|
||||
n_sz,
|
||||
n_sz.get_type().const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
n_sz_eqz,
|
||||
"0:ValueError",
|
||||
"zero-size array to reduction operation minimum which has no identity",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
||||
unsafe {
|
||||
let identity = n.data()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
||||
ctx.builder.build_store(accumulator_addr, identity).unwrap();
|
||||
}
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
llvm_usize.const_int(1, false),
|
||||
(n_sz, false),
|
||||
|generator, ctx, idx| {
|
||||
let elem = unsafe {
|
||||
n.data().get_unchecked(ctx, generator, &idx, None)
|
||||
};
|
||||
|
||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||
let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem));
|
||||
ctx.builder.build_store(accumulator_addr, result).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||
accumulator
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[a_ty])
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `np_maximum` builtin function.
|
||||
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_maximum";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) {
|
||||
Some(x1_ty)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
|
||||
|
||||
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||
}
|
||||
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
|
||||
|
||||
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||
}
|
||||
|
||||
(x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => {
|
||||
let is_ndarray1 = x1_ty.obj_id(&ctx.unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
let is_ndarray2 = x2_ty.obj_id(&ctx.unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
|
||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
ndarray_dtype1
|
||||
} else if is_ndarray1 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
} else { unreachable!() };
|
||||
|
||||
let x1_scalar_ty = if is_ndarray1 {
|
||||
dtype
|
||||
} else {
|
||||
x1_ty
|
||||
};
|
||||
let x2_scalar_ty = if is_ndarray2 {
|
||||
dtype
|
||||
} else {
|
||||
x2_ty
|
||||
};
|
||||
|
||||
numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
dtype,
|
||||
None,
|
||||
(x1, !is_ndarray1),
|
||||
(x2, !is_ndarray2),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||
},
|
||||
)?.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `abs` builtin function.
|
||||
pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
|
|
|
@ -1,4 +1,14 @@
|
|||
use super::*;
|
||||
use std::iter::once;
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
IntPredicate,
|
||||
types::{BasicMetadataTypeEnum, BasicType},
|
||||
values::{BasicMetadataValueEnum, BasicValue, CallSiteValue}
|
||||
};
|
||||
use itertools::Either;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
builtin_fns,
|
||||
|
@ -15,13 +25,8 @@ use crate::{
|
|||
},
|
||||
typecheck::typedef::VarMap,
|
||||
};
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
types::{BasicType, BasicMetadataTypeEnum},
|
||||
values::{BasicValue, BasicMetadataValueEnum, CallSiteValue},
|
||||
IntPredicate
|
||||
};
|
||||
use itertools::Either;
|
||||
|
||||
use super::*;
|
||||
|
||||
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
|
||||
|
||||
|
@ -1378,6 +1383,65 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
|||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
{
|
||||
let ret_ty = unifier.get_fresh_var(Some("R".into()), None);
|
||||
let var_map = num_or_ndarray_var_map.clone()
|
||||
.into_iter()
|
||||
.chain(once((ret_ty.1, ret_ty.0)))
|
||||
.collect::<IndexMap<_, _>>();
|
||||
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
&var_map,
|
||||
"np_min",
|
||||
ret_ty.0,
|
||||
&[(float_or_ndarray_ty.0, "a")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let a_ty = fun.0.args[0].ty;
|
||||
let a = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, a_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_numpy_min(generator, ctx, (a_ty, a))?))
|
||||
}),
|
||||
)
|
||||
},
|
||||
{
|
||||
let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
|
||||
let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
|
||||
let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")];
|
||||
let ret_ty = unifier.get_fresh_var(None, None);
|
||||
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "np_minimum".into(),
|
||||
simple_name: "np_minimum".into(),
|
||||
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: param_ty.iter().map(|p| FuncArg {
|
||||
name: p.1.into(),
|
||||
ty: p.0,
|
||||
default_value: None,
|
||||
}).collect(),
|
||||
ret: ret_ty.0,
|
||||
vars: [
|
||||
(x1_ty.1, x1_ty.0),
|
||||
(x2_ty.1, x2_ty.0),
|
||||
(ret_ty.1, ret_ty.0),
|
||||
].into_iter().collect(),
|
||||
})),
|
||||
var_id: vec![x1_ty.1, x2_ty.1],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_numpy_minimum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||
})))),
|
||||
loc: None,
|
||||
}))
|
||||
},
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "max".into(),
|
||||
simple_name: "max".into(),
|
||||
|
@ -1405,6 +1469,65 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
|||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
{
|
||||
let ret_ty = unifier.get_fresh_var(Some("R".into()), None);
|
||||
let var_map = num_or_ndarray_var_map.clone()
|
||||
.into_iter()
|
||||
.chain(once((ret_ty.1, ret_ty.0)))
|
||||
.collect::<IndexMap<_, _>>();
|
||||
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
&var_map,
|
||||
"np_max",
|
||||
ret_ty.0,
|
||||
&[(float_or_ndarray_ty.0, "a")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let a_ty = fun.0.args[0].ty;
|
||||
let a = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, a_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_numpy_max(generator, ctx, (a_ty, a))?))
|
||||
}),
|
||||
)
|
||||
},
|
||||
{
|
||||
let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
|
||||
let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
|
||||
let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")];
|
||||
let ret_ty = unifier.get_fresh_var(None, None);
|
||||
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "np_maximum".into(),
|
||||
simple_name: "np_maximum".into(),
|
||||
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: param_ty.iter().map(|p| FuncArg {
|
||||
name: p.1.into(),
|
||||
ty: p.0,
|
||||
default_value: None,
|
||||
}).collect(),
|
||||
ret: ret_ty.0,
|
||||
vars: [
|
||||
(x1_ty.1, x1_ty.0),
|
||||
(x2_ty.1, x2_ty.0),
|
||||
(ret_ty.1, ret_ty.0),
|
||||
].into_iter().collect(),
|
||||
})),
|
||||
var_id: vec![x1_ty.1, x2_ty.1],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_numpy_maximum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||
})))),
|
||||
loc: None,
|
||||
}))
|
||||
},
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "abs".into(),
|
||||
simple_name: "abs".into(),
|
||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
|||
[
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [222]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [238]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar211]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar211\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar227]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar227\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [224]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [229]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [240]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [245]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar210, typevar211]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar210\", \"typevar211\"]\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar226, typevar227]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar226\", \"typevar227\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [230]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [246]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [238]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [254]\n}\n",
|
||||
]
|
||||
|
|
|
@ -905,6 +905,50 @@ impl<'a> Inferencer<'a> {
|
|||
}
|
||||
|
||||
if [
|
||||
"np_min",
|
||||
"np_max",
|
||||
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let arg0_ty = arg0.custom.unwrap();
|
||||
|
||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||
|
||||
ndarray_dtype
|
||||
} else {
|
||||
arg0_ty
|
||||
};
|
||||
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
name: "a".into(),
|
||||
ty: arg0.custom.unwrap(),
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||
}),
|
||||
args: vec![arg0],
|
||||
keywords: vec![],
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
if [
|
||||
"np_minimum",
|
||||
"np_maximum",
|
||||
"np_arctan2",
|
||||
"np_copysign",
|
||||
"np_fmax",
|
||||
|
@ -913,8 +957,6 @@ impl<'a> Inferencer<'a> {
|
|||
"np_hypot",
|
||||
"np_nextafter",
|
||||
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 {
|
||||
let target_ty = self.primitives.float;
|
||||
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let arg0_ty = arg0.custom.unwrap();
|
||||
let arg1 = self.fold_expr(args.remove(0))?;
|
||||
|
@ -931,6 +973,7 @@ impl<'a> Inferencer<'a> {
|
|||
} else {
|
||||
arg1_ty
|
||||
};
|
||||
|
||||
let expected_arg1_dtype = if id == &"np_ldexp".into() {
|
||||
self.primitives.int32
|
||||
} else {
|
||||
|
@ -939,7 +982,7 @@ impl<'a> Inferencer<'a> {
|
|||
if !self.unifier.unioned(arg1_dtype, expected_arg1_dtype) {
|
||||
return report_error(
|
||||
format!(
|
||||
"Expected {} for second argument of {id}, got {}",
|
||||
"Expected broadcast-compatible type of ndarray[{}, N] for second argument of {id}, got {}",
|
||||
self.unifier.stringify(expected_arg1_dtype),
|
||||
self.unifier.stringify(arg1_dtype),
|
||||
).as_str(),
|
||||
|
@ -947,6 +990,12 @@ impl<'a> Inferencer<'a> {
|
|||
)
|
||||
}
|
||||
|
||||
let target_ty = if id == &"np_minimum".into() || id == &"np_maximum".into() {
|
||||
arg0_dtype
|
||||
} else {
|
||||
self.primitives.float
|
||||
};
|
||||
|
||||
let ret = if [
|
||||
&arg0_ty,
|
||||
&arg1_ty,
|
||||
|
|
|
@ -174,6 +174,10 @@ def patch(module):
|
|||
# NumPy Math functions
|
||||
module.np_isnan = np.isnan
|
||||
module.np_isinf = np.isinf
|
||||
module.np_min = np.min
|
||||
module.np_minimum = np.minimum
|
||||
module.np_max = np.max
|
||||
module.np_maximum = np.maximum
|
||||
module.np_sin = np.sin
|
||||
module.np_cos = np.cos
|
||||
module.np_exp = np.exp
|
||||
|
|
|
@ -759,6 +759,92 @@ def test_ndarray_ceil():
|
|||
output_ndarray_int64_2(xf64)
|
||||
output_ndarray_float_2(xff)
|
||||
|
||||
def test_ndarray_min():
|
||||
x = np_identity(2)
|
||||
y = np_min(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_float64(y)
|
||||
|
||||
def test_ndarray_minimum():
|
||||
x = np_identity(2)
|
||||
min_x_zeros = np_minimum(x, np_zeros([2]))
|
||||
min_x_ones = np_minimum(x, np_zeros([2]))
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(min_x_zeros)
|
||||
output_ndarray_float_2(min_x_ones)
|
||||
|
||||
def test_ndarray_minimum_broadcast():
|
||||
x = np_identity(2)
|
||||
min_x_zeros = np_minimum(x, np_zeros([2]))
|
||||
min_x_ones = np_minimum(x, np_zeros([2]))
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(min_x_zeros)
|
||||
output_ndarray_float_2(min_x_ones)
|
||||
|
||||
def test_ndarray_minimum_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
min_x_zeros = np_minimum(0.0, x)
|
||||
min_x_ones = np_minimum(1.0, x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(min_x_zeros)
|
||||
output_ndarray_float_2(min_x_ones)
|
||||
|
||||
def test_ndarray_minimum_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
min_x_zeros = np_minimum(x, 0.0)
|
||||
min_x_ones = np_minimum(x, 1.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(min_x_zeros)
|
||||
output_ndarray_float_2(min_x_ones)
|
||||
|
||||
def test_ndarray_max():
|
||||
x = np_identity(2)
|
||||
y = np_max(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_float64(y)
|
||||
|
||||
def test_ndarray_maximum():
|
||||
x = np_identity(2)
|
||||
max_x_zeros = np_maximum(x, np_zeros([2]))
|
||||
max_x_ones = np_maximum(x, np_zeros([2]))
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(max_x_zeros)
|
||||
output_ndarray_float_2(max_x_ones)
|
||||
|
||||
def test_ndarray_maximum_broadcast():
|
||||
x = np_identity(2)
|
||||
max_x_zeros = np_maximum(x, np_zeros([2]))
|
||||
max_x_ones = np_maximum(x, np_zeros([2]))
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(max_x_zeros)
|
||||
output_ndarray_float_2(max_x_ones)
|
||||
|
||||
def test_ndarray_maximum_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
max_x_zeros = np_maximum(0.0, x)
|
||||
max_x_ones = np_maximum(1.0, x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(max_x_zeros)
|
||||
output_ndarray_float_2(max_x_ones)
|
||||
|
||||
def test_ndarray_maximum_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
max_x_zeros = np_maximum(x, 0.0)
|
||||
max_x_ones = np_maximum(x, 1.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(max_x_zeros)
|
||||
output_ndarray_float_2(max_x_ones)
|
||||
|
||||
def test_ndarray_abs():
|
||||
x = np_identity(2)
|
||||
y = abs(x)
|
||||
|
@ -1363,6 +1449,16 @@ def run() -> int32:
|
|||
|
||||
test_ndarray_round()
|
||||
test_ndarray_floor()
|
||||
test_ndarray_min()
|
||||
test_ndarray_minimum()
|
||||
test_ndarray_minimum_broadcast()
|
||||
test_ndarray_minimum_broadcast_lhs_scalar()
|
||||
test_ndarray_minimum_broadcast_rhs_scalar()
|
||||
test_ndarray_max()
|
||||
test_ndarray_maximum()
|
||||
test_ndarray_maximum_broadcast()
|
||||
test_ndarray_maximum_broadcast_lhs_scalar()
|
||||
test_ndarray_maximum_broadcast_rhs_scalar()
|
||||
test_ndarray_abs()
|
||||
test_ndarray_isnan()
|
||||
test_ndarray_isinf()
|
||||
|
|
Loading…
Reference in New Issue