core/ndstrides: implement unary op

This commit is contained in:
lyken 2024-08-21 10:08:30 +08:00 committed by David Mak
parent 9e40c83490
commit bb992704b2

View File

@ -1777,14 +1777,12 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
_ => val.into(), _ => val.into(),
} }
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_usize = generator.get_size_type(ctx.ctx); let ndarray = AnyObject { value: val, ty };
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
// passing it to the elementwise codegen function // passing it to the elementwise codegen function
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) { let op = if ndarray.dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
if op == ast::Unaryop::Invert { if op == ast::Unaryop::Invert {
ast::Unaryop::Not ast::Unaryop::Not
} else { } else {
@ -1798,20 +1796,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
op op
}; };
let res = numpy::ndarray_elementwise_unaryop_impl( let mapped_ndarray = ndarray.map(
generator, generator,
ctx, ctx,
ndarray_dtype, NDArrayOut::NewNDArray { dtype: ndarray.dtype },
None, |generator, ctx, scalar| {
val, gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray.dtype), scalar))?
|generator, ctx, val| {
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, ndarray_dtype) .to_basic_value_enum(ctx, generator, ndarray.dtype)
}, },
)?; )?;
res.as_base_value().into() ValueEnum::Dynamic(mapped_ndarray.instance.value.as_basic_value_enum())
} else { } else {
unimplemented!() unimplemented!()
})) }))