forked from M-Labs/nac3
core: Implement Not/UAdd/USub for booleans
Not sure if this is deliberate or an oversight, but we implement it anyway for consistency with other Python implementations.
This commit is contained in:
parent
00d1b9be9b
commit
52c731c312
|
@ -30,7 +30,7 @@ use crate::{
|
|||
},
|
||||
typecheck::{
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||
magic_methods::{binop_name, binop_assign_name},
|
||||
magic_methods::{binop_name, binop_assign_name, unaryop_name},
|
||||
},
|
||||
};
|
||||
use inkwell::{
|
||||
|
@ -1306,18 +1306,27 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
|
||||
Ok(Some(if ty == ctx.primitives.bool {
|
||||
let val = val.into_int_value();
|
||||
match op {
|
||||
ast::Unaryop::Invert | ast::Unaryop::Not => {
|
||||
let not = ctx.builder.build_not(val, "not").unwrap();
|
||||
let not_bool = ctx.builder.build_and(
|
||||
not,
|
||||
not.get_type().const_int(1, false),
|
||||
"",
|
||||
).unwrap();
|
||||
if *op == ast::Unaryop::Not {
|
||||
let not = ctx.builder.build_not(val, "not").unwrap();
|
||||
let not_bool = ctx.builder.build_and(
|
||||
not,
|
||||
not.get_type().const_int(1, false),
|
||||
"",
|
||||
).unwrap();
|
||||
|
||||
not_bool.into()
|
||||
}
|
||||
_ => val.into(),
|
||||
not_bool.into()
|
||||
} else {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
gen_unaryop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
op,
|
||||
(
|
||||
&Some(ctx.primitives.int32),
|
||||
ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap()
|
||||
),
|
||||
)?.unwrap()
|
||||
}
|
||||
} else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) {
|
||||
let val = val.into_int_value();
|
||||
|
@ -1353,6 +1362,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
None,
|
||||
);
|
||||
|
||||
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
||||
// passing it to the elementwise codegen function
|
||||
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
|
||||
if *op == ast::Unaryop::Invert {
|
||||
&ast::Unaryop::Not
|
||||
} else {
|
||||
unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op))
|
||||
}
|
||||
} else {
|
||||
op
|
||||
};
|
||||
|
||||
let res = numpy::ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
|
@ -1364,7 +1385,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
generator,
|
||||
ctx,
|
||||
op,
|
||||
(&Some(ndarray_dtype), val)
|
||||
(&Some(ndarray_dtype), val),
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||
},
|
||||
)?;
|
||||
|
|
|
@ -472,23 +472,47 @@ pub fn typeof_unaryop(
|
|||
op: &Unaryop,
|
||||
operand: Type,
|
||||
) -> Result<Option<Type>, String> {
|
||||
if *op == Unaryop::Not && operand.obj_id(unifier).is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
|
||||
let operand_obj_id = operand.obj_id(unifier);
|
||||
|
||||
if *op == Unaryop::Not && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
|
||||
return Err("The truth value of an array with more than one element is ambiguous".to_string())
|
||||
}
|
||||
|
||||
Ok(match *op {
|
||||
Unaryop::Not => {
|
||||
match operand.obj_id(unifier) {
|
||||
match operand_obj_id {
|
||||
Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand),
|
||||
Some(_) => Some(primitives.bool),
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
|
||||
Unaryop::Invert
|
||||
| Unaryop::UAdd
|
||||
Unaryop::Invert => {
|
||||
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
|
||||
Some(primitives.int32)
|
||||
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
|
||||
Some(operand)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
Unaryop::UAdd
|
||||
| Unaryop::USub => {
|
||||
if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
|
||||
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
|
||||
if dtype.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
|
||||
return Err(if *op == Unaryop::UAdd {
|
||||
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
|
||||
} else {
|
||||
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
|
||||
})
|
||||
}
|
||||
|
||||
Some(operand)
|
||||
} else if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
|
||||
Some(primitives.int32)
|
||||
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
|
||||
Some(operand)
|
||||
} else {
|
||||
None
|
||||
|
@ -571,7 +595,9 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
|
||||
/* bool ======== */
|
||||
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
|
||||
impl_invert(unifier, store, bool_t, Some(int32_t));
|
||||
impl_not(unifier, store, bool_t, Some(bool_t));
|
||||
impl_sign(unifier, store, bool_t, Some(int32_t));
|
||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||
|
||||
/* ndarray ===== */
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
@extern
|
||||
def output_bool(x: bool):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
@ -17,6 +21,7 @@ def output_float64(x: float):
|
|||
...
|
||||
|
||||
def run() -> int32:
|
||||
test_bool()
|
||||
test_int32()
|
||||
test_uint32()
|
||||
test_int64()
|
||||
|
@ -25,6 +30,18 @@ def run() -> int32:
|
|||
# test_B()
|
||||
return 0
|
||||
|
||||
def test_bool():
|
||||
t = True
|
||||
f = False
|
||||
output_bool(not t)
|
||||
output_bool(not f)
|
||||
output_int32(~t)
|
||||
output_int32(~f)
|
||||
output_int32(+t)
|
||||
output_int32(+f)
|
||||
output_int32(-t)
|
||||
output_int32(-f)
|
||||
|
||||
def test_int32():
|
||||
a = 17
|
||||
b = 3
|
||||
|
|
Loading…
Reference in New Issue