[core] codegen: Reimplement ndarray unary op

Based on bb992704: core/ndstrides: implement unary op
This commit is contained in:
David Mak 2024-12-19 10:37:17 +08:00
parent 303b20db15
commit f43523ec72
2 changed files with 8 additions and 69 deletions

View File

@ -1800,10 +1800,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
_ => val.into(),
}
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, ty);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let val = llvm_ndarray_ty.map_value(val.into_pointer_value(), None);
let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty)
.map_value(val.into_pointer_value(), None);
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
// passing it to the elementwise codegen function
@ -1821,20 +1821,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
op
};
let res = numpy::ndarray_elementwise_unaryop_impl(
let mapped_ndarray = ndarray.map(
generator,
ctx,
ndarray_dtype,
None,
val,
|generator, ctx, val| {
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() },
|generator, ctx, scalar| {
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))?
.map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype))
.unwrap()
.to_basic_value_enum(ctx, generator, ndarray_dtype)
},
)?;
res.as_base_value().into()
mapped_ndarray.as_base_value().into()
} else {
unimplemented!()
}))

View File

@ -195,28 +195,6 @@ where
})
}
fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
src: NDArrayValue<'ctx>,
dest: NDArrayValue<'ctx>,
map_fn: MapFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
MapFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BasicValueEnum<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| {
let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) };
map_fn(generator, ctx, elem)
})
}
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
/// the target `ndarray`.
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
@ -614,43 +592,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
}
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
operand: NDArrayValue<'ctx>,
map_fn: MapFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
MapFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BasicValueEnum<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
let res = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&operand,
|_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe {
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
},
)
.unwrap()
});
ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| {
map_fn(generator, ctx, elem)
})?;
Ok(res)
}
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
///
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output