forked from M-Labs/nac3
core: Implement Not/UAdd/USub for booleans
Not sure if this is deliberate or an oversight, but we implement it anyway for consistency with other Python implementations.
This commit is contained in:
parent
00d1b9be9b
commit
52c731c312
@ -30,7 +30,7 @@ use crate::{
|
|||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||||
magic_methods::{binop_name, binop_assign_name},
|
magic_methods::{binop_name, binop_assign_name, unaryop_name},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
@ -1306,8 +1306,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
|
|
||||||
Ok(Some(if ty == ctx.primitives.bool {
|
Ok(Some(if ty == ctx.primitives.bool {
|
||||||
let val = val.into_int_value();
|
let val = val.into_int_value();
|
||||||
match op {
|
if *op == ast::Unaryop::Not {
|
||||||
ast::Unaryop::Invert | ast::Unaryop::Not => {
|
|
||||||
let not = ctx.builder.build_not(val, "not").unwrap();
|
let not = ctx.builder.build_not(val, "not").unwrap();
|
||||||
let not_bool = ctx.builder.build_and(
|
let not_bool = ctx.builder.build_and(
|
||||||
not,
|
not,
|
||||||
@ -1316,8 +1315,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
).unwrap();
|
).unwrap();
|
||||||
|
|
||||||
not_bool.into()
|
not_bool.into()
|
||||||
}
|
} else {
|
||||||
_ => val.into(),
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
|
gen_unaryop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
op,
|
||||||
|
(
|
||||||
|
&Some(ctx.primitives.int32),
|
||||||
|
ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap()
|
||||||
|
),
|
||||||
|
)?.unwrap()
|
||||||
}
|
}
|
||||||
} else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) {
|
} else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) {
|
||||||
let val = val.into_int_value();
|
let val = val.into_int_value();
|
||||||
@ -1353,6 +1362,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
||||||
|
// passing it to the elementwise codegen function
|
||||||
|
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
|
||||||
|
if *op == ast::Unaryop::Invert {
|
||||||
|
&ast::Unaryop::Not
|
||||||
|
} else {
|
||||||
|
unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
op
|
||||||
|
};
|
||||||
|
|
||||||
let res = numpy::ndarray_elementwise_unaryop_impl(
|
let res = numpy::ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
@ -1364,7 +1385,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
op,
|
op,
|
||||||
(&Some(ndarray_dtype), val)
|
(&Some(ndarray_dtype), val),
|
||||||
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype)
|
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
@ -472,23 +472,47 @@ pub fn typeof_unaryop(
|
|||||||
op: &Unaryop,
|
op: &Unaryop,
|
||||||
operand: Type,
|
operand: Type,
|
||||||
) -> Result<Option<Type>, String> {
|
) -> Result<Option<Type>, String> {
|
||||||
if *op == Unaryop::Not && operand.obj_id(unifier).is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
|
let operand_obj_id = operand.obj_id(unifier);
|
||||||
|
|
||||||
|
if *op == Unaryop::Not && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
|
||||||
return Err("The truth value of an array with more than one element is ambiguous".to_string())
|
return Err("The truth value of an array with more than one element is ambiguous".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(match *op {
|
Ok(match *op {
|
||||||
Unaryop::Not => {
|
Unaryop::Not => {
|
||||||
match operand.obj_id(unifier) {
|
match operand_obj_id {
|
||||||
Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand),
|
Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand),
|
||||||
Some(_) => Some(primitives.bool),
|
Some(_) => Some(primitives.bool),
|
||||||
_ => None
|
_ => None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Unaryop::Invert
|
Unaryop::Invert => {
|
||||||
| Unaryop::UAdd
|
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
|
||||||
|
Some(primitives.int32)
|
||||||
|
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
|
||||||
|
Some(operand)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Unaryop::UAdd
|
||||||
| Unaryop::USub => {
|
| Unaryop::USub => {
|
||||||
if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
|
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
|
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
|
||||||
|
if dtype.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
|
||||||
|
return Err(if *op == Unaryop::UAdd {
|
||||||
|
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
|
||||||
|
} else {
|
||||||
|
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(operand)
|
||||||
|
} else if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
|
||||||
|
Some(primitives.int32)
|
||||||
|
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
|
||||||
Some(operand)
|
Some(operand)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@ -571,7 +595,9 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||||||
|
|
||||||
/* bool ======== */
|
/* bool ======== */
|
||||||
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
|
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
|
||||||
|
impl_invert(unifier, store, bool_t, Some(int32_t));
|
||||||
impl_not(unifier, store, bool_t, Some(bool_t));
|
impl_not(unifier, store, bool_t, Some(bool_t));
|
||||||
|
impl_sign(unifier, store, bool_t, Some(int32_t));
|
||||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||||
|
|
||||||
/* ndarray ===== */
|
/* ndarray ===== */
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
@extern
|
||||||
|
def output_bool(x: bool):
|
||||||
|
...
|
||||||
|
|
||||||
@extern
|
@extern
|
||||||
def output_int32(x: int32):
|
def output_int32(x: int32):
|
||||||
...
|
...
|
||||||
@ -17,6 +21,7 @@ def output_float64(x: float):
|
|||||||
...
|
...
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
|
test_bool()
|
||||||
test_int32()
|
test_int32()
|
||||||
test_uint32()
|
test_uint32()
|
||||||
test_int64()
|
test_int64()
|
||||||
@ -25,6 +30,18 @@ def run() -> int32:
|
|||||||
# test_B()
|
# test_B()
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def test_bool():
|
||||||
|
t = True
|
||||||
|
f = False
|
||||||
|
output_bool(not t)
|
||||||
|
output_bool(not f)
|
||||||
|
output_int32(~t)
|
||||||
|
output_int32(~f)
|
||||||
|
output_int32(+t)
|
||||||
|
output_int32(+f)
|
||||||
|
output_int32(-t)
|
||||||
|
output_int32(-f)
|
||||||
|
|
||||||
def test_int32():
|
def test_int32():
|
||||||
a = 17
|
a = 17
|
||||||
b = 3
|
b = 3
|
||||||
|
Loading…
Reference in New Issue
Block a user