Compare commits
3 Commits
d658d9b00e
...
45e9360c4d
Author | SHA1 | Date | |
---|---|---|---|
45e9360c4d | |||
2e01b77fc8 | |||
cea7cade51 |
@ -661,90 +661,6 @@ pub fn call_min<'ctx>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_min` builtin function.
|
|
||||||
pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
a: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_min";
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let (a_ty, a) = a;
|
|
||||||
|
|
||||||
Ok(match a {
|
|
||||||
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
|
||||||
debug_assert!([
|
|
||||||
ctx.primitives.bool,
|
|
||||||
ctx.primitives.int32,
|
|
||||||
ctx.primitives.uint32,
|
|
||||||
ctx.primitives.int64,
|
|
||||||
ctx.primitives.uint64,
|
|
||||||
ctx.primitives.float,
|
|
||||||
]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
|
|
||||||
|
|
||||||
a
|
|
||||||
}
|
|
||||||
|
|
||||||
BasicValueEnum::PointerValue(n)
|
|
||||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
|
||||||
{
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
|
||||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
||||||
let n_sz_eqz = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
n_sz_eqz,
|
|
||||||
"0:ValueError",
|
|
||||||
"zero-size array to reduction operation minimum which has no identity",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
|
||||||
unsafe {
|
|
||||||
let identity =
|
|
||||||
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
|
||||||
ctx.builder.build_store(accumulator_addr, identity).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
(n_sz, false),
|
|
||||||
|generator, ctx, _, idx| {
|
|
||||||
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
|
|
||||||
|
|
||||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
|
||||||
let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
|
|
||||||
ctx.builder.build_store(accumulator_addr, result).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
|
||||||
accumulator
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invokes the `np_minimum` builtin function.
|
/// Invokes the `np_minimum` builtin function.
|
||||||
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
@ -877,19 +793,21 @@ pub fn call_max<'ctx>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_max` builtin function.
|
/// Invokes the np_max, np_min, np_argmax, np_argmin functions
|
||||||
pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
/// * `fn_name`: Can be one of "np_argmin", "np_argmax", "np_max", "np_min"
|
||||||
|
pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
a: (Type, BasicValueEnum<'ctx>),
|
a: (Type, BasicValueEnum<'ctx>),
|
||||||
|
fn_name: &str,
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "np_max";
|
debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name));
|
||||||
|
|
||||||
|
let llvm_int64 = ctx.ctx.i64_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let (a_ty, a) = a;
|
let (a_ty, a) = a;
|
||||||
|
Ok( match a {
|
||||||
Ok(match a {
|
|
||||||
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
||||||
debug_assert!([
|
debug_assert!([
|
||||||
ctx.primitives.bool,
|
ctx.primitives.bool,
|
||||||
@ -901,14 +819,17 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
]
|
]
|
||||||
.iter()
|
.iter()
|
||||||
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
|
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
|
||||||
|
|
||||||
a
|
match fn_name {
|
||||||
|
"np_argmin" | "np_argmax" => llvm_int64.const_zero().into(),
|
||||||
|
"np_max" | "np_min" => a,
|
||||||
|
_ => unreachable!()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
@ -923,41 +844,71 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
generator,
|
generator,
|
||||||
n_sz_eqz,
|
n_sz_eqz,
|
||||||
"0:ValueError",
|
"0:ValueError",
|
||||||
"zero-size array to reduction operation minimum which has no identity",
|
format!("zero-size array to reduction operation {}", fn_name).as_str(),
|
||||||
[None, None, None],
|
[None, None, None],
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
||||||
|
let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?;
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let identity =
|
let identity =
|
||||||
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
||||||
ctx.builder.build_store(accumulator_addr, identity).unwrap();
|
ctx.builder.build_store(accumulator_addr, identity).unwrap();
|
||||||
|
ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
gen_for_callback_incrementing(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
llvm_usize.const_int(1, false),
|
llvm_int64.const_int(1, false),
|
||||||
(n_sz, false),
|
(n_sz, false),
|
||||||
|generator, ctx, _, idx| {
|
|generator, ctx, _, idx,| {
|
||||||
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
|
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
|
||||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||||
let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem));
|
let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();
|
||||||
|
|
||||||
|
let result = match fn_name {
|
||||||
|
"np_argmin" | "np_min" => call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)),
|
||||||
|
"np_argmax" | "np_max" => call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)),
|
||||||
|
_ => unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
let updated_idx = match (accumulator, result){
|
||||||
|
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
|
||||||
|
ctx.builder.build_select(
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::NE,m, n, "").unwrap(),
|
||||||
|
idx.into(),
|
||||||
|
cur_idx,
|
||||||
|
"").unwrap()
|
||||||
|
},
|
||||||
|
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
|
||||||
|
ctx.builder.build_select(
|
||||||
|
ctx.builder.build_float_compare(FloatPredicate::ONE,m, n, "").unwrap(),
|
||||||
|
idx.into(),
|
||||||
|
cur_idx,
|
||||||
|
"").unwrap()
|
||||||
|
},
|
||||||
|
_ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]),
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(res_idx, updated_idx).unwrap();
|
||||||
ctx.builder.build_store(accumulator_addr, result).unwrap();
|
ctx.builder.build_store(accumulator_addr, result).unwrap();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
llvm_usize.const_int(1, false),
|
llvm_int64.const_int(1, false),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
match fn_name {
|
||||||
accumulator
|
"np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(),
|
||||||
|
"np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(),
|
||||||
|
_ => unreachable!()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
_ => unsupported_type(ctx, fn_name, &[a_ty])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -510,7 +510,10 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
|
|
||||||
PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim),
|
PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpMin | PrimDef::FunNpMax => self.build_np_min_max_function(prim),
|
PrimDef::FunNpArgmin
|
||||||
|
| PrimDef::FunNpArgmax
|
||||||
|
| PrimDef::FunNpMin
|
||||||
|
| PrimDef::FunNpMax => self.build_np_max_min_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => {
|
PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => {
|
||||||
self.build_np_minimum_maximum_function(prim)
|
self.build_np_minimum_maximum_function(prim)
|
||||||
@ -1555,39 +1558,42 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the functions `np_min()` and `np_max()`.
|
/// Build the functions `np_max()`, `np_min()`, `np_argmax()` and `np_argmin()`
|
||||||
fn build_np_min_max_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
/// Calls `call_numpy_max_min` with the function name
|
||||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMin, PrimDef::FunNpMax]);
|
fn build_np_max_min_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax]);
|
||||||
|
|
||||||
let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None);
|
let (var_map, ret_ty) = match prim {
|
||||||
let var_map = self
|
PrimDef::FunNpArgmax | PrimDef::FunNpArgmin => {
|
||||||
.num_or_ndarray_var_map
|
(self.num_or_ndarray_var_map.clone(), self.primitives.int64)
|
||||||
.clone()
|
},
|
||||||
.into_iter()
|
PrimDef::FunNpMax | PrimDef::FunNpMin => {
|
||||||
.chain(once((ret_ty.id, ret_ty.ty)))
|
let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None);
|
||||||
.collect::<IndexMap<_, _>>();
|
let var_map = self
|
||||||
|
.num_or_ndarray_var_map
|
||||||
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.chain(once((ret_ty.id, ret_ty.ty)))
|
||||||
|
.collect::<IndexMap<_, _>>();
|
||||||
|
(var_map, ret_ty.ty)
|
||||||
|
},
|
||||||
|
_ => unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
create_fn_by_codegen(
|
create_fn_by_codegen(
|
||||||
self.unifier,
|
self.unifier,
|
||||||
&var_map,
|
&var_map,
|
||||||
prim.name(),
|
prim.name(),
|
||||||
ret_ty.ty,
|
ret_ty,
|
||||||
&[(self.float_or_ndarray_ty.ty, "a")],
|
&[(self.num_or_ndarray_ty.ty, "a")],
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let a_ty = fun.0.args[0].ty;
|
let a_ty = fun.0.args[0].ty;
|
||||||
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
|
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
|
||||||
|
|
||||||
let func = match prim {
|
Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), &prim.name())?))
|
||||||
PrimDef::FunNpMin => builtin_fns::call_numpy_min,
|
|
||||||
PrimDef::FunNpMax => builtin_fns::call_numpy_max,
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Some(func(generator, ctx, (a_ty, a))?))
|
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the functions `np_minimum()` and `np_maximum()`.
|
/// Build the functions `np_minimum()` and `np_maximum()`.
|
||||||
fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]);
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]);
|
||||||
|
@ -62,9 +62,11 @@ pub enum PrimDef {
|
|||||||
FunMin,
|
FunMin,
|
||||||
FunNpMin,
|
FunNpMin,
|
||||||
FunNpMinimum,
|
FunNpMinimum,
|
||||||
|
FunNpArgmin,
|
||||||
FunMax,
|
FunMax,
|
||||||
FunNpMax,
|
FunNpMax,
|
||||||
FunNpMaximum,
|
FunNpMaximum,
|
||||||
|
FunNpArgmax,
|
||||||
FunAbs,
|
FunAbs,
|
||||||
FunNpIsNan,
|
FunNpIsNan,
|
||||||
FunNpIsInf,
|
FunNpIsInf,
|
||||||
@ -216,9 +218,11 @@ impl PrimDef {
|
|||||||
PrimDef::FunMin => fun("min", None),
|
PrimDef::FunMin => fun("min", None),
|
||||||
PrimDef::FunNpMin => fun("np_min", None),
|
PrimDef::FunNpMin => fun("np_min", None),
|
||||||
PrimDef::FunNpMinimum => fun("np_minimum", None),
|
PrimDef::FunNpMinimum => fun("np_minimum", None),
|
||||||
|
PrimDef::FunNpArgmin => fun("np_argmin", None),
|
||||||
PrimDef::FunMax => fun("max", None),
|
PrimDef::FunMax => fun("max", None),
|
||||||
PrimDef::FunNpMax => fun("np_max", None),
|
PrimDef::FunNpMax => fun("np_max", None),
|
||||||
PrimDef::FunNpMaximum => fun("np_maximum", None),
|
PrimDef::FunNpMaximum => fun("np_maximum", None),
|
||||||
|
PrimDef::FunNpArgmax => fun("np_argmax", None),
|
||||||
PrimDef::FunAbs => fun("abs", None),
|
PrimDef::FunAbs => fun("abs", None),
|
||||||
PrimDef::FunNpIsNan => fun("np_isnan", None),
|
PrimDef::FunNpIsNan => fun("np_isnan", None),
|
||||||
PrimDef::FunNpIsInf => fun("np_isinf", None),
|
PrimDef::FunNpIsInf => fun("np_isinf", None),
|
||||||
|
@ -867,6 +867,13 @@ def test_ndarray_minimum_broadcast_rhs_scalar():
|
|||||||
output_ndarray_float_2(min_x_zeros)
|
output_ndarray_float_2(min_x_zeros)
|
||||||
output_ndarray_float_2(min_x_ones)
|
output_ndarray_float_2(min_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_argmin():
|
||||||
|
x = np_array([[1., 2.], [3., 4.]])
|
||||||
|
y = np_argmin(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_int64(y)
|
||||||
|
|
||||||
def test_ndarray_max():
|
def test_ndarray_max():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
y = np_max(x)
|
y = np_max(x)
|
||||||
@ -910,6 +917,13 @@ def test_ndarray_maximum_broadcast_rhs_scalar():
|
|||||||
output_ndarray_float_2(max_x_zeros)
|
output_ndarray_float_2(max_x_zeros)
|
||||||
output_ndarray_float_2(max_x_ones)
|
output_ndarray_float_2(max_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_argmax():
|
||||||
|
x = np_array([[1., 2.], [3., 4.]])
|
||||||
|
y = np_argmax(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_int64(y)
|
||||||
|
|
||||||
def test_ndarray_abs():
|
def test_ndarray_abs():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
y = abs(x)
|
y = abs(x)
|
||||||
@ -1524,11 +1538,13 @@ def run() -> int32:
|
|||||||
test_ndarray_minimum_broadcast()
|
test_ndarray_minimum_broadcast()
|
||||||
test_ndarray_minimum_broadcast_lhs_scalar()
|
test_ndarray_minimum_broadcast_lhs_scalar()
|
||||||
test_ndarray_minimum_broadcast_rhs_scalar()
|
test_ndarray_minimum_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_argmin()
|
||||||
test_ndarray_max()
|
test_ndarray_max()
|
||||||
test_ndarray_maximum()
|
test_ndarray_maximum()
|
||||||
test_ndarray_maximum_broadcast()
|
test_ndarray_maximum_broadcast()
|
||||||
test_ndarray_maximum_broadcast_lhs_scalar()
|
test_ndarray_maximum_broadcast_lhs_scalar()
|
||||||
test_ndarray_maximum_broadcast_rhs_scalar()
|
test_ndarray_maximum_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_argmax()
|
||||||
test_ndarray_abs()
|
test_ndarray_abs()
|
||||||
test_ndarray_isnan()
|
test_ndarray_isnan()
|
||||||
test_ndarray_isinf()
|
test_ndarray_isinf()
|
||||||
|
Loading…
Reference in New Issue
Block a user