From 5f143d2f2f1ea618e0ad839cc3258e7c9cd7c2c2 Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 21 Aug 2024 10:08:30 +0800 Subject: [PATCH] core/ndstrides: implement unary op --- nac3core/src/codegen/expr.rs | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 486c833..cedea09 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1770,14 +1770,12 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( _ => val.into(), } } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); - let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - - let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); + let ndarray = AnyObject { value: val, ty }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function - let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) { + let op = if ndarray.dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) { if op == ast::Unaryop::Invert { ast::Unaryop::Not } else { @@ -1790,20 +1788,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( op }; - let res = numpy::ndarray_elementwise_unaryop_impl( + let mapped_ndarray = ndarray.map( generator, ctx, - ndarray_dtype, - None, - val, - |generator, ctx, val| { - gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))? + NDArrayOut::NewNDArray { dtype: ndarray.dtype }, + |generator, ctx, scalar| { + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray.dtype), scalar))? .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) + .to_basic_value_enum(ctx, generator, ndarray.dtype) }, )?; - res.as_base_value().into() + ValueEnum::Dynamic(mapped_ndarray.instance.value.as_basic_value_enum()) } else { unimplemented!() }))