Improvements to unary operators #394

Merged
sb10q merged 1 commits from enhance/issue-149-ndarray/operators into master 2024-08-17 17:37:21 +08:00
3 changed files with 82 additions and 18 deletions

View File

@ -30,7 +30,7 @@ use crate::{
},
typecheck::{
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
magic_methods::{binop_name, binop_assign_name},
magic_methods::{binop_name, binop_assign_name, unaryop_name},
},
};
use inkwell::{
@ -1306,18 +1306,27 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(if ty == ctx.primitives.bool {
let val = val.into_int_value();
match op {
ast::Unaryop::Invert | ast::Unaryop::Not => {
let not = ctx.builder.build_not(val, "not").unwrap();
let not_bool = ctx.builder.build_and(
not,
not.get_type().const_int(1, false),
"",
).unwrap();
if *op == ast::Unaryop::Not {
let not = ctx.builder.build_not(val, "not").unwrap();
let not_bool = ctx.builder.build_and(
not,
not.get_type().const_int(1, false),
"",
).unwrap();
not_bool.into()
} else {
let llvm_i32 = ctx.ctx.i32_type();
not_bool.into()
}
_ => val.into(),
gen_unaryop_expr_with_values(
generator,
ctx,
op,
(
&Some(ctx.primitives.int32),
ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap()
),
)?.unwrap()
}
} else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) {
let val = val.into_int_value();
@ -1353,6 +1362,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
None,
);
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
// passing it to the elementwise codegen function
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
if *op == ast::Unaryop::Invert {
&ast::Unaryop::Not
} else {
unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op))
}
} else {
op
};
let res = numpy::ndarray_elementwise_unaryop_impl(
generator,
ctx,
@ -1364,7 +1385,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
generator,
ctx,
op,
(&Some(ndarray_dtype), val)
(&Some(ndarray_dtype), val),
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype)
},
)?;

View File

@ -472,23 +472,47 @@ pub fn typeof_unaryop(
op: &Unaryop,
operand: Type,
) -> Result<Option<Type>, String> {
if *op == Unaryop::Not && operand.obj_id(unifier).is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
let operand_obj_id = operand.obj_id(unifier);
if *op == Unaryop::Not && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
return Err("The truth value of an array with more than one element is ambiguous".to_string())
}
Ok(match *op {
Unaryop::Not => {
match operand.obj_id(unifier) {
match operand_obj_id {
Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand),
Some(_) => Some(primitives.bool),
_ => None
}
}
Unaryop::Invert
| Unaryop::UAdd
Unaryop::Invert => {
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
Some(operand)
} else {
None
}
}
Unaryop::UAdd
| Unaryop::USub => {
if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
if dtype.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
return Err(if *op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
} else {
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
})
}
Some(operand)
} else if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
Some(operand)
} else {
None
@ -571,7 +595,9 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* bool ======== */
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
impl_invert(unifier, store, bool_t, Some(int32_t));
impl_not(unifier, store, bool_t, Some(bool_t));
impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* ndarray ===== */

View File

@ -1,5 +1,9 @@
from __future__ import annotations
@extern
def output_bool(x: bool):
...
@extern
def output_int32(x: int32):
...
@ -17,6 +21,7 @@ def output_float64(x: float):
...
def run() -> int32:
test_bool()
test_int32()
test_uint32()
test_int64()
@ -25,6 +30,18 @@ def run() -> int32:
# test_B()
return 0
def test_bool():
t = True
f = False
output_bool(not t)
output_bool(not f)
output_int32(~t)
output_int32(~f)
output_int32(+t)
output_int32(+f)
output_int32(-t)
output_int32(-f)
def test_int32():
a = 17
b = 3