diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index f360c78..474052a 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -30,7 +30,7 @@ use crate::{ }, typecheck::{ typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, - magic_methods::{binop_name, binop_assign_name}, + magic_methods::{binop_name, binop_assign_name, unaryop_name}, }, }; use inkwell::{ @@ -1306,18 +1306,27 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( Ok(Some(if ty == ctx.primitives.bool { let val = val.into_int_value(); - match op { - ast::Unaryop::Invert | ast::Unaryop::Not => { - let not = ctx.builder.build_not(val, "not").unwrap(); - let not_bool = ctx.builder.build_and( - not, - not.get_type().const_int(1, false), - "", - ).unwrap(); + if *op == ast::Unaryop::Not { + let not = ctx.builder.build_not(val, "not").unwrap(); + let not_bool = ctx.builder.build_and( + not, + not.get_type().const_int(1, false), + "", + ).unwrap(); + + not_bool.into() + } else { + let llvm_i32 = ctx.ctx.i32_type(); - not_bool.into() - } - _ => val.into(), + gen_unaryop_expr_with_values( + generator, + ctx, + op, + ( + &Some(ctx.primitives.int32), + ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap() + ), + )?.unwrap() } } else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) { let val = val.into_int_value(); @@ -1353,6 +1362,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( None, ); + // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before + // passing it to the elementwise codegen function + let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { + if *op == ast::Unaryop::Invert { + &ast::Unaryop::Not + } else { + unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op)) + } + } else { + op + }; + let res = numpy::ndarray_elementwise_unaryop_impl( generator, ctx, @@ -1364,7 +1385,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( generator, ctx, op, - (&Some(ndarray_dtype), val) + (&Some(ndarray_dtype), val), )?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype) }, )?; diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 59067b8..dc7afe5 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -472,23 +472,47 @@ pub fn typeof_unaryop( op: &Unaryop, operand: Type, ) -> Result, String> { - if *op == Unaryop::Not && operand.obj_id(unifier).is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) { + let operand_obj_id = operand.obj_id(unifier); + + if *op == Unaryop::Not && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) { return Err("The truth value of an array with more than one element is ambiguous".to_string()) } Ok(match *op { Unaryop::Not => { - match operand.obj_id(unifier) { + match operand_obj_id { Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand), Some(_) => Some(primitives.bool), _ => None } } - Unaryop::Invert - | Unaryop::UAdd + Unaryop::Invert => { + if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { + Some(primitives.int32) + } else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) { + Some(operand) + } else { + None + } + } + + Unaryop::UAdd | Unaryop::USub => { - if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) { + if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let (dtype, _) = unpack_ndarray_var_tys(unifier, operand); + if dtype.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { + return Err(if *op == Unaryop::UAdd { + "The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string() + } else { + "The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string() + }) + } + + Some(operand) + } else if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { + Some(primitives.int32) + } else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) { Some(operand) } else { None @@ -571,7 +595,9 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie /* bool ======== */ let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None); + impl_invert(unifier, store, bool_t, Some(int32_t)); impl_not(unifier, store, bool_t, Some(bool_t)); + impl_sign(unifier, store, bool_t, Some(int32_t)); impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); /* ndarray ===== */ diff --git a/nac3standalone/demo/src/operators.py b/nac3standalone/demo/src/operators.py index 8933cbe..06e7dee 100644 --- a/nac3standalone/demo/src/operators.py +++ b/nac3standalone/demo/src/operators.py @@ -1,5 +1,9 @@ from __future__ import annotations +@extern +def output_bool(x: bool): + ... + @extern def output_int32(x: int32): ... @@ -17,6 +21,7 @@ def output_float64(x: float): ... def run() -> int32: + test_bool() test_int32() test_uint32() test_int64() @@ -25,6 +30,18 @@ def run() -> int32: # test_B() return 0 +def test_bool(): + t = True + f = False + output_bool(not t) + output_bool(not f) + output_int32(~t) + output_int32(~f) + output_int32(+t) + output_int32(+f) + output_int32(-t) + output_int32(-f) + def test_int32(): a = 17 b = 3