diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 4c26b43..7c203df 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -789,17 +789,19 @@ pub fn call_max<'ctx>( } /// Invokes the `abs` builtin function. -pub fn call_abs<'ctx>( +pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> BasicValueEnum<'ctx> { +) -> Result, String> { const FN_NAME: &str = "abs"; let llvm_i1 = ctx.ctx.bool_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - match n.get_type() { + Ok(match n.get_type() { BasicTypeEnum::IntType(_) => { debug_assert!([ ctx.primitives.bool, @@ -830,8 +832,25 @@ pub fn call_abs<'ctx>( llvm_intrinsics::call_float_fabs(ctx, n.into_float_value(), Some(FN_NAME)).into() } + BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_abs(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + _ => unsupported_type(ctx, FN_NAME, &[n_ty]) - } + }) } /// Invokes the `np_isnan` builtin function. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index c64a196..f9049d9 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1366,8 +1366,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "abs".into(), simple_name: "abs".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: num_ty.0, + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, vars: var_map.clone(), })), var_id: Vec::default(), @@ -1379,7 +1379,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let n_ty = fun.0.args[0].ty; let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; - Ok(Some(builtin_fns::call_abs(ctx, (n_ty, n_val)))) + Ok(Some(builtin_fns::call_abs(generator, ctx, (n_ty, n_val))?)) }, )))), loc: None, diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 9af48f1..d803c80 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -751,6 +751,13 @@ def test_ndarray_ceil(): output_ndarray_int64_2(xf64) output_ndarray_float_2(xff) +def test_ndarray_abs(): + x = np_identity(2) + y = abs(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -850,5 +857,6 @@ def run() -> int32: test_ndarray_round() test_ndarray_floor() + test_ndarray_abs() return 0