[core] codegen: Reimplement ndarray unary op
Based on bb992704
: core/ndstrides: implement unary op
This commit is contained in:
parent
59f19e29df
commit
a2f1b25fd8
@ -1777,10 +1777,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
_ => val.into(),
|
||||
}
|
||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, ty);
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
|
||||
let val = llvm_ndarray_ty.map_value(val.into_pointer_value(), None);
|
||||
let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty)
|
||||
.map_value(val.into_pointer_value(), None);
|
||||
|
||||
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
||||
// passing it to the elementwise codegen function
|
||||
@ -1798,20 +1798,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
op
|
||||
};
|
||||
|
||||
let res = numpy::ndarray_elementwise_unaryop_impl(
|
||||
let mapped_ndarray = ndarray.map(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype,
|
||||
None,
|
||||
val,
|
||||
|generator, ctx, val| {
|
||||
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
|
||||
NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() },
|
||||
|generator, ctx, scalar| {
|
||||
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))?
|
||||
.map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype))
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||
},
|
||||
)?;
|
||||
|
||||
res.as_base_value().into()
|
||||
mapped_ndarray.as_base_value().into()
|
||||
} else {
|
||||
unimplemented!()
|
||||
}))
|
||||
|
@ -195,28 +195,6 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
src: NDArrayValue<'ctx>,
|
||||
dest: NDArrayValue<'ctx>,
|
||||
map_fn: MapFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
MapFn: Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BasicValueEnum<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| {
|
||||
let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) };
|
||||
|
||||
map_fn(generator, ctx, elem)
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
|
||||
/// the target `ndarray`.
|
||||
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
|
||||
@ -614,43 +592,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
||||
}
|
||||
|
||||
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
res: Option<NDArrayValue<'ctx>>,
|
||||
operand: NDArrayValue<'ctx>,
|
||||
map_fn: MapFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
MapFn: Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BasicValueEnum<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let res = res.unwrap_or_else(|| {
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&operand,
|
||||
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
||||
|generator, ctx, v, idx| unsafe {
|
||||
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
||||
},
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| {
|
||||
map_fn(generator, ctx, elem)
|
||||
})?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
|
||||
///
|
||||
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
|
||||
|
Loading…
Reference in New Issue
Block a user