core/magic_methods: Add typeof_*op
Used to determine the expected type of the binary operator with primitive operands.
This commit is contained in:
parent
1a09ea126d
commit
42c482f897
|
@ -1,3 +1,7 @@
|
||||||
|
use std::cmp::max;
|
||||||
|
use crate::symbol_resolver::SymbolValue;
|
||||||
|
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
|
||||||
|
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
||||||
use crate::typecheck::{
|
use crate::typecheck::{
|
||||||
type_inferencer::*,
|
type_inferencer::*,
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||||
|
@ -6,6 +10,7 @@ use nac3parser::ast::StrRef;
|
||||||
use nac3parser::ast::{Cmpop, Operator, Unaryop};
|
use nac3parser::ast::{Cmpop, Operator, Unaryop};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn binop_name(op: &Operator) -> &'static str {
|
pub fn binop_name(op: &Operator) -> &'static str {
|
||||||
|
@ -293,6 +298,137 @@ pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||||
impl_cmpop(unifier, store, ty, ty, &[Cmpop::Eq, Cmpop::NotEq]);
|
impl_cmpop(unifier, store, ty, ty, &[Cmpop::Eq, Cmpop::NotEq]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the expected return type of binary operations with at least one `ndarray` operand.
|
||||||
|
pub fn typeof_ndarray_broadcast(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
left: Type,
|
||||||
|
right: Type,
|
||||||
|
) -> Result<Type, String> {
|
||||||
|
let is_left_ndarray = left.get_obj_id(unifier) == PRIMITIVE_DEF_IDS.ndarray;
|
||||||
|
let is_right_ndarray = right.get_obj_id(unifier) == PRIMITIVE_DEF_IDS.ndarray;
|
||||||
|
|
||||||
|
assert!(is_left_ndarray || is_right_ndarray);
|
||||||
|
|
||||||
|
if is_left_ndarray && is_right_ndarray {
|
||||||
|
// Perform broadcasting on two ndarray operands.
|
||||||
|
|
||||||
|
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
|
||||||
|
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
|
||||||
|
|
||||||
|
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
|
||||||
|
|
||||||
|
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
|
||||||
|
TypeEnum::TLiteral { values, .. } => values.clone(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
|
||||||
|
TypeEnum::TLiteral { values, .. } => values.clone(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let res_ndims = left_ty_ndims.into_iter()
|
||||||
|
.cartesian_product(right_ty_ndims.into_iter())
|
||||||
|
.map(|(left, right)| {
|
||||||
|
let left_val = u64::try_from(left).unwrap();
|
||||||
|
let right_val = u64::try_from(right).unwrap();
|
||||||
|
|
||||||
|
max(left_val, right_val)
|
||||||
|
})
|
||||||
|
.unique()
|
||||||
|
.map(|ndim| SymbolValue::U64(ndim))
|
||||||
|
.collect_vec();
|
||||||
|
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
|
||||||
|
|
||||||
|
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
|
||||||
|
} else {
|
||||||
|
let (ndarray_ty, scalar_ty) = if is_left_ndarray {
|
||||||
|
(left, right)
|
||||||
|
} else {
|
||||||
|
(right, left)
|
||||||
|
};
|
||||||
|
|
||||||
|
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
|
||||||
|
|
||||||
|
if !unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
||||||
|
let (expected_ty, actual_ty) = if is_left_ndarray {
|
||||||
|
(ndarray_ty_dtype, scalar_ty)
|
||||||
|
} else {
|
||||||
|
(scalar_ty, ndarray_ty_dtype)
|
||||||
|
};
|
||||||
|
|
||||||
|
Err(format!(
|
||||||
|
"Expected right-hand side operand to be {}, got {}",
|
||||||
|
unifier.stringify(expected_ty),
|
||||||
|
unifier.stringify(actual_ty),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(ndarray_ty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the return type given a binary operator and its primitive operands.
|
||||||
|
pub fn typeof_binop(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
op: &Operator,
|
||||||
|
lhs: Type,
|
||||||
|
rhs: Type,
|
||||||
|
) -> Result<Option<Type>, String> {
|
||||||
|
let is_left_ndarray = lhs.get_obj_id(unifier) == PRIMITIVE_DEF_IDS.ndarray;
|
||||||
|
let is_right_ndarray = rhs.get_obj_id(unifier) == PRIMITIVE_DEF_IDS.ndarray;
|
||||||
|
|
||||||
|
Ok(Some(match op {
|
||||||
|
Operator::Add
|
||||||
|
| Operator::Sub
|
||||||
|
| Operator::Mult
|
||||||
|
| Operator::Mod
|
||||||
|
| Operator::FloorDiv => {
|
||||||
|
if is_left_ndarray || is_right_ndarray {
|
||||||
|
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||||
|
} else if unifier.unioned(lhs, rhs) {
|
||||||
|
lhs
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Operator::MatMult => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
||||||
|
Operator::Div => {
|
||||||
|
if is_left_ndarray || is_right_ndarray {
|
||||||
|
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||||
|
} else if unifier.unioned(lhs, rhs) {
|
||||||
|
primitives.float
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Operator::Pow => {
|
||||||
|
if is_left_ndarray || is_right_ndarray {
|
||||||
|
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||||
|
} else if [primitives.int32, primitives.int64, primitives.uint32, primitives.uint64, primitives.float].into_iter().any(|ty| unifier.unioned(lhs, ty)) {
|
||||||
|
lhs
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Operator::LShift
|
||||||
|
| Operator::RShift
|
||||||
|
| Operator::BitOr
|
||||||
|
| Operator::BitXor
|
||||||
|
| Operator::BitAnd => {
|
||||||
|
if unifier.unioned(lhs, rhs) {
|
||||||
|
lhs
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
||||||
let PrimitiveStore {
|
let PrimitiveStore {
|
||||||
int32: int32_t,
|
int32: int32_t,
|
||||||
|
|
Loading…
Reference in New Issue