[core] codegen: Reimplement ndarray unary op

Based on bb992704: core/ndstrides: implement unary op
This commit is contained in:
David Mak 2024-12-19 10:37:17 +08:00
parent 59f19e29df
commit a2f1b25fd8
2 changed files with 8 additions and 69 deletions

View File

@ -1777,10 +1777,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
_ => val.into(), _ => val.into(),
} }
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, ty);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let val = llvm_ndarray_ty.map_value(val.into_pointer_value(), None); let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty)
.map_value(val.into_pointer_value(), None);
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
// passing it to the elementwise codegen function // passing it to the elementwise codegen function
@ -1798,20 +1798,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
op op
}; };
let res = numpy::ndarray_elementwise_unaryop_impl( let mapped_ndarray = ndarray.map(
generator, generator,
ctx, ctx,
ndarray_dtype, NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() },
None, |generator, ctx, scalar| {
val, gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))?
|generator, ctx, val| { .map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype))
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, ndarray_dtype)
}, },
)?; )?;
res.as_base_value().into() mapped_ndarray.as_base_value().into()
} else { } else {
unimplemented!() unimplemented!()
})) }))

View File

@ -195,28 +195,6 @@ where
}) })
} }
fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
src: NDArrayValue<'ctx>,
dest: NDArrayValue<'ctx>,
map_fn: MapFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
MapFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BasicValueEnum<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| {
let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) };
map_fn(generator, ctx, elem)
})
}
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of /// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
/// the target `ndarray`. /// the target `ndarray`.
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
@ -614,43 +592,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
} }
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
operand: NDArrayValue<'ctx>,
map_fn: MapFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
MapFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BasicValueEnum<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
let res = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&operand,
|_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe {
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
},
)
.unwrap()
});
ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| {
map_fn(generator, ctx, elem)
})?;
Ok(res)
}
/// LLVM-typed implementation for computing elementwise binary operations on two input operands. /// LLVM-typed implementation for computing elementwise binary operations on two input operands.
/// ///
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output /// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output