From dcbba2fb2630c2ad2124dc09ceeedda27d00e2c3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 25 Apr 2024 22:46:54 +0800 Subject: [PATCH] core: WIP - float and bool works now --- nac3core/src/codegen/builtin_fns.rs | 83 ++++++++++++++++--- nac3core/src/toplevel/builtins.rs | 12 +-- nac3core/src/typecheck/type_inferencer/mod.rs | 18 +++- nac3standalone/demo/src/ndarray.py | 18 ++++ 4 files changed, 110 insertions(+), 21 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index a5300af..6506ea8 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -373,15 +373,17 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( } /// Invokes the `float` builtin function. -pub fn call_float<'ctx>( +pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> FloatValue<'ctx> { +) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - match n.get_type() { + Ok(match n.get_type() { BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8 | 32 | 64) => { debug_assert!([ ctx.primitives.bool, @@ -398,22 +400,49 @@ pub fn call_float<'ctx>( ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)) { ctx.builder .build_signed_int_to_float(n.into_int_value(), llvm_f64, "sitofp") + .map(Into::into) .unwrap() } else { ctx.builder .build_unsigned_int_to_float(n.into_int_value(), llvm_f64, "uitofp") - .unwrap() + .map(Into::into) + .unwrap() } } BasicTypeEnum::FloatType(_) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - n.into_float_value() + n + } + + BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.float, + None, + NDArrayValue::from_ptr_val( + n.into_pointer_value(), + llvm_usize, + None, + ), + |generator, ctx, val| { + call_float( + generator, + ctx, + (elem_ty, val), + ) + }, + )?; + + ndarray.as_ptr_value().into() } _ => unsupported_type(ctx, "float", &[n_ty]) - } + }) } /// Invokes the `round` builtin function. @@ -451,19 +480,22 @@ pub fn call_numpy_round<'ctx>( } /// Invokes the `bool` builtin function. -pub fn call_bool<'ctx>( +pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> IntValue<'ctx> { +) -> Result, String> { const FN_NAME: &str = "bool"; + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; - match n.get_type() { + Ok(match n.get_type() { BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); - n.into_int_value() + n } BasicTypeEnum::IntType(_) => { @@ -477,6 +509,7 @@ pub fn call_bool<'ctx>( let val = n.into_int_value(); ctx.builder .build_int_compare(IntPredicate::NE, val, val.get_type().const_zero(), FN_NAME) + .map(Into::into) .unwrap() } @@ -486,11 +519,39 @@ pub fn call_bool<'ctx>( let val = n.into_float_value(); ctx.builder .build_float_compare(FloatPredicate::UNE, val, val.get_type().const_zero(), FN_NAME) + .map(Into::into) .unwrap() } + BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.bool, + None, + NDArrayValue::from_ptr_val( + n.into_pointer_value(), + llvm_usize, + None, + ), + |generator, ctx, val| { + let elem = call_bool( + generator, + ctx, + (elem_ty, val), + )?; + + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + )?; + + ndarray.as_ptr_value().into() + } + _ => unsupported_type(ctx, FN_NAME, &[n_ty]) - } + }) } /// Invokes the `floor` builtin function. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index d6f5c68..00db8c1 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -655,8 +655,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { name: "float".into(), simple_name: "float".into(), signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: float, + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, vars: var_map.clone(), })), var_id: Vec::default(), @@ -668,7 +668,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_float(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_float(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, @@ -960,8 +960,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { name: "bool".into(), simple_name: "bool".into(), signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: primitives.0.bool, + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, vars: var_map.clone(), })), var_id: Vec::default(), @@ -973,7 +973,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_bool(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_bool(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a144c01..7fa6c63 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -851,8 +851,18 @@ impl<'a> Inferencer<'a> { })) } - if id == &"int32".into() && args.len() == 1 { - let int32_ty = self.primitives.int32; + if [ + "int32", + "float", + "bool", + ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { + let target_ty = if id == &"int32".into() { + self.primitives.int32 + } else if id == &"float".into() { + self.primitives.float + } else if id == &"bool".into() { + self.primitives.bool + } else { unreachable!() }; let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); @@ -860,9 +870,9 @@ impl<'a> Inferencer<'a> { let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); - make_ndarray_ty(self.unifier, self.primitives, Some(int32_ty), Some(ndarray_ndims)) + make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) } else { - int32_ty + target_ty }; let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 661272a..d52d520 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -704,6 +704,22 @@ def test_ndarray_uint64(): output_ndarray_float_2(x) output_ndarray_uint64_2(y) +# TODO: builtin function float() cannot cast ndarrays - numpy.float_ needed +# def test_ndarray_float(): +# x = np_full([2, 2], 1) +# y = float(x) +# +# output_ndarray_int32_2(x) +# output_ndarray_float_2(y) + +# TODO: builtin function bool() cannot cast ndarrays - numpy.bool_ needed +# def test_ndarray_bool(): +# x = np_identity(2) +# y = bool(x) +# +# output_ndarray_float_2(x) +# output_ndarray_bool_2(y) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -798,5 +814,7 @@ def run() -> int32: test_ndarray_int64() test_ndarray_uint32() test_ndarray_uint64() + # test_ndarray_float() + # test_ndarray_bool() return 0