diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index d04035a..78500fe 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1018,73 +1018,203 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>( } /// Invokes the `np_exp` builtin function. -pub fn call_numpy_exp<'ctx>( +pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_exp"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_exp", &[x_ty]) - } + Ok(match x.get_type() { + BasicTypeEnum::FloatType(_) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_exp(ctx, x, None) + llvm_intrinsics::call_float_exp(ctx, x.into_float_value(), None).into() + } + + BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_numpy_exp(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_exp2` builtin function. -pub fn call_numpy_exp2<'ctx>( +pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_exp2"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_exp2", &[x_ty]) - } + Ok(match x.get_type() { + BasicTypeEnum::FloatType(_) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_exp2(ctx, x, None) + llvm_intrinsics::call_float_exp2(ctx, x.into_float_value(), None).into() + } + + BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_numpy_exp2(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_log` builtin function. -pub fn call_numpy_log<'ctx>( +pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_log"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_log", &[x_ty]) - } + Ok(match x.get_type() { + BasicTypeEnum::FloatType(_) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_log(ctx, x, None) + llvm_intrinsics::call_float_log(ctx, x.into_float_value(), None).into() + } + + BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_numpy_log(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_log10` builtin function. -pub fn call_numpy_log10<'ctx>( +pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_log10"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_log10", &[x_ty]) - } + Ok(match x.get_type() { + BasicTypeEnum::FloatType(_) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_log10(ctx, x, None) + llvm_intrinsics::call_float_log10(ctx, x.into_float_value(), None).into() + } + + BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_numpy_log10(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_log2` builtin function. -pub fn call_numpy_log2<'ctx>( +pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_log2"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_log2", &[x_ty]) - } + Ok(match x.get_type() { + BasicTypeEnum::FloatType(_) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_log2(ctx, x, None) + llvm_intrinsics::call_float_log2(ctx, x.into_float_value(), None).into() + } + + BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_numpy_log2(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_sqrt` builtin function. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 0b651b0..7c94d68 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1444,75 +1444,70 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built unifier, &var_map, "np_exp", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_exp(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_exp(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, &var_map, "np_exp2", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_exp2(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_exp2(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, &var_map, "np_log", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_log(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_log(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, &var_map, "np_log10", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_log10(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_log10(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, &var_map, "np_log2", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_log2(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_log2(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index a938ee1..d499470 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -802,6 +802,41 @@ def test_ndarray_cos(): output_ndarray_float_2(x) output_ndarray_float_2(y) +def test_ndarray_exp(): + x = np_identity(2) + y = np_exp(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_exp2(): + x = np_identity(2) + y = np_exp2(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_log(): + x = np_identity(2) + y = np_log(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_log10(): + x = np_identity(2) + y = np_log10(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_log2(): + x = np_identity(2) + y = np_log2(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -907,5 +942,10 @@ def run() -> int32: test_ndarray_sin() test_ndarray_cos() + test_ndarray_exp() + test_ndarray_exp2() + test_ndarray_log() + test_ndarray_log10() + test_ndarray_log2() return 0