Improvements to unary operators #394

Merged
sb10q merged 1 commits from enhance/issue-149-ndarray/operators into master 2024-08-17 17:37:21 +08:00
3 changed files with 82 additions and 18 deletions

View File

@ -30,7 +30,7 @@ use crate::{
}, },
typecheck::{ typecheck::{
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
magic_methods::{binop_name, binop_assign_name}, magic_methods::{binop_name, binop_assign_name, unaryop_name},
}, },
}; };
use inkwell::{ use inkwell::{
@ -1306,8 +1306,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(if ty == ctx.primitives.bool { Ok(Some(if ty == ctx.primitives.bool {
let val = val.into_int_value(); let val = val.into_int_value();
match op { if *op == ast::Unaryop::Not {
ast::Unaryop::Invert | ast::Unaryop::Not => {
let not = ctx.builder.build_not(val, "not").unwrap(); let not = ctx.builder.build_not(val, "not").unwrap();
let not_bool = ctx.builder.build_and( let not_bool = ctx.builder.build_and(
not, not,
@ -1316,8 +1315,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
).unwrap(); ).unwrap();
not_bool.into() not_bool.into()
} } else {
_ => val.into(), let llvm_i32 = ctx.ctx.i32_type();
gen_unaryop_expr_with_values(
generator,
ctx,
op,
(
&Some(ctx.primitives.int32),
ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap()
),
)?.unwrap()
} }
} else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) { } else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) {
let val = val.into_int_value(); let val = val.into_int_value();
@ -1353,6 +1362,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
None, None,
); );
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
// passing it to the elementwise codegen function
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
if *op == ast::Unaryop::Invert {
&ast::Unaryop::Not
} else {
unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op))
}
} else {
op
};
let res = numpy::ndarray_elementwise_unaryop_impl( let res = numpy::ndarray_elementwise_unaryop_impl(
generator, generator,
ctx, ctx,
@ -1364,7 +1385,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
generator, generator,
ctx, ctx,
op, op,
(&Some(ndarray_dtype), val) (&Some(ndarray_dtype), val),
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype) )?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype)
}, },
)?; )?;

View File

@ -472,23 +472,47 @@ pub fn typeof_unaryop(
op: &Unaryop, op: &Unaryop,
operand: Type, operand: Type,
) -> Result<Option<Type>, String> { ) -> Result<Option<Type>, String> {
if *op == Unaryop::Not && operand.obj_id(unifier).is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) { let operand_obj_id = operand.obj_id(unifier);
if *op == Unaryop::Not && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
return Err("The truth value of an array with more than one element is ambiguous".to_string()) return Err("The truth value of an array with more than one element is ambiguous".to_string())
} }
Ok(match *op { Ok(match *op {
Unaryop::Not => { Unaryop::Not => {
match operand.obj_id(unifier) { match operand_obj_id {
Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand), Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand),
Some(_) => Some(primitives.bool), Some(_) => Some(primitives.bool),
_ => None _ => None
} }
} }
Unaryop::Invert Unaryop::Invert => {
| Unaryop::UAdd if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
Some(operand)
} else {
None
}
}
Unaryop::UAdd
| Unaryop::USub => { | Unaryop::USub => {
if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) { if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
if dtype.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
return Err(if *op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
} else {
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
})
}
Some(operand)
} else if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
Some(operand) Some(operand)
} else { } else {
None None
@ -571,7 +595,9 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* bool ======== */ /* bool ======== */
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None); let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
impl_invert(unifier, store, bool_t, Some(int32_t));
impl_not(unifier, store, bool_t, Some(bool_t)); impl_not(unifier, store, bool_t, Some(bool_t));
impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* ndarray ===== */ /* ndarray ===== */

View File

@ -1,5 +1,9 @@
from __future__ import annotations from __future__ import annotations
@extern
def output_bool(x: bool):
...
@extern @extern
def output_int32(x: int32): def output_int32(x: int32):
... ...
@ -17,6 +21,7 @@ def output_float64(x: float):
... ...
def run() -> int32: def run() -> int32:
test_bool()
test_int32() test_int32()
test_uint32() test_uint32()
test_int64() test_int64()
@ -25,6 +30,18 @@ def run() -> int32:
# test_B() # test_B()
return 0 return 0
def test_bool():
t = True
f = False
output_bool(not t)
output_bool(not f)
output_int32(~t)
output_int32(~f)
output_int32(+t)
output_int32(+f)
output_int32(-t)
output_int32(-f)
def test_int32(): def test_int32():
a = 17 a = 17
b = 3 b = 3