diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index c9b9cbd..fa6c865 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -341,15 +341,17 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( } /// Invokes the `float` builtin function. -pub fn call_float<'ctx>( +pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> FloatValue<'ctx> { +) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - match n.get_type() { + Ok(match n.get_type() { BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8 | 32 | 64) => { debug_assert!([ ctx.primitives.bool, @@ -366,22 +368,41 @@ pub fn call_float<'ctx>( ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)) { ctx.builder .build_signed_int_to_float(n.into_int_value(), llvm_f64, "sitofp") + .map(Into::into) .unwrap() } else { ctx.builder .build_unsigned_int_to_float(n.into_int_value(), llvm_f64, "uitofp") - .unwrap() + .map(Into::into) + .unwrap() } } BasicTypeEnum::FloatType(_) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - n.into_float_value() + n + } + + BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.float, + None, + NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_float(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() } _ => unsupported_type(ctx, "float", &[n_ty]) - } + }) } /// Invokes the `round` builtin function. @@ -419,19 +440,22 @@ pub fn call_numpy_round<'ctx>( } /// Invokes the `bool` builtin function. -pub fn call_bool<'ctx>( +pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> IntValue<'ctx> { +) -> Result, String> { const FN_NAME: &str = "bool"; + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; - match n.get_type() { + Ok(match n.get_type() { BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); - n.into_int_value() + n } BasicTypeEnum::IntType(_) => { @@ -445,6 +469,7 @@ pub fn call_bool<'ctx>( let val = n.into_int_value(); ctx.builder .build_int_compare(IntPredicate::NE, val, val.get_type().const_zero(), FN_NAME) + .map(Into::into) .unwrap() } @@ -454,11 +479,39 @@ pub fn call_bool<'ctx>( let val = n.into_float_value(); ctx.builder .build_float_compare(FloatPredicate::UNE, val, val.get_type().const_zero(), FN_NAME) + .map(Into::into) .unwrap() } + BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.bool, + None, + NDArrayValue::from_ptr_val( + n.into_pointer_value(), + llvm_usize, + None, + ), + |generator, ctx, val| { + let elem = call_bool( + generator, + ctx, + (elem_ty, val), + )?; + + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + )?; + + ndarray.as_ptr_value().into() + } + _ => unsupported_type(ctx, FN_NAME, &[n_ty]) - } + }) } /// Invokes the `floor` builtin function. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 1b79efd..343b158 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -662,8 +662,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "float".into(), simple_name: "float".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: float, + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, vars: var_map.clone(), })), var_id: Vec::default(), @@ -675,7 +675,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_float(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_float(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, @@ -967,8 +967,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "bool".into(), simple_name: "bool".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: primitives.bool, + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, vars: var_map.clone(), })), var_id: Vec::default(), @@ -980,7 +980,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_bool(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_bool(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a144c01..7fa6c63 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -851,8 +851,18 @@ impl<'a> Inferencer<'a> { })) } - if id == &"int32".into() && args.len() == 1 { - let int32_ty = self.primitives.int32; + if [ + "int32", + "float", + "bool", + ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { + let target_ty = if id == &"int32".into() { + self.primitives.int32 + } else if id == &"float".into() { + self.primitives.float + } else if id == &"bool".into() { + self.primitives.bool + } else { unreachable!() }; let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); @@ -860,9 +870,9 @@ impl<'a> Inferencer<'a> { let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); - make_ndarray_ty(self.unifier, self.primitives, Some(int32_ty), Some(ndarray_ndims)) + make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) } else { - int32_ty + target_ty }; let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 03deff4..b3c9f69 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -58,6 +58,18 @@ class _NDArrayDummy(Generic[T, N]): # https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]] +def _bool(x): + if isinstance(x, np.ndarray): + return np.bool_(x) + else: + return bool(x) + +def _float(x): + if isinstance(x, np.ndarray): + return np.float_(x) + else: + return float(x) + def round_away_zero(x): if x >= 0.0: return math.floor(x + 0.5) @@ -112,6 +124,8 @@ def patch(module): module.int64 = int64 module.uint32 = uint32 module.uint64 = uint64 + module.bool = _bool + module.float = _float module.TypeVar = TypeVar module.ConstGeneric = ConstGeneric module.Generic = Generic diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 661272a..4c3e7d0 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -704,6 +704,20 @@ def test_ndarray_uint64(): output_ndarray_float_2(x) output_ndarray_uint64_2(y) +def test_ndarray_float(): + x = np_full([2, 2], 1) + y = float(x) + + output_ndarray_int32_2(x) + output_ndarray_float_2(y) + +def test_ndarray_bool(): + x = np_identity(2) + y = bool(x) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -798,5 +812,7 @@ def run() -> int32: test_ndarray_int64() test_ndarray_uint32() test_ndarray_uint64() + test_ndarray_float() + test_ndarray_bool() return 0