forked from M-Labs/nac3
core/magic_methods: Add typeof_*op
Used to determine the expected type of the binary operator with primitive operands.
This commit is contained in:
parent
3a6c53d760
commit
3540d0ab29
|
@ -170,13 +170,13 @@ impl SymbolValue {
|
|||
/// Returns the [`TypeAnnotation`] representing the data type of this value.
|
||||
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation {
|
||||
match self {
|
||||
SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool),
|
||||
SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float),
|
||||
SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32),
|
||||
SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64),
|
||||
SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32),
|
||||
SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64),
|
||||
SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str),
|
||||
SymbolValue::Bool(..)
|
||||
| SymbolValue::Double(..)
|
||||
| SymbolValue::I32(..)
|
||||
| SymbolValue::I64(..)
|
||||
| SymbolValue::U32(..)
|
||||
| SymbolValue::U64(..)
|
||||
| SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
|
||||
SymbolValue::Tuple(vs) => {
|
||||
let vs_tys = vs
|
||||
.iter()
|
||||
|
@ -230,6 +230,38 @@ impl Display for SymbolValue {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<SymbolValue> for u64 {
|
||||
type Error = ();
|
||||
|
||||
/// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not
|
||||
/// numeric or if the value cannot be converted into a `u64` without overflow.
|
||||
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
|
||||
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
|
||||
SymbolValue::U32(v) => Ok(v as u64),
|
||||
SymbolValue::U64(v) => Ok(v),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<SymbolValue> for i128 {
|
||||
type Error = ();
|
||||
|
||||
/// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not
|
||||
/// numeric.
|
||||
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
SymbolValue::I32(v) => Ok(v as i128),
|
||||
SymbolValue::I64(v) => Ok(v as i128),
|
||||
SymbolValue::U32(v) => Ok(v as i128),
|
||||
SymbolValue::U64(v) => Ok(v as i128),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait StaticValue {
|
||||
/// Returns a unique identifier for this value.
|
||||
fn get_unique_identifier(&self) -> u64;
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
use std::cmp::max;
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
|
||||
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
||||
use crate::typecheck::{
|
||||
type_inferencer::*,
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||
|
@ -6,6 +10,7 @@ use nac3parser::ast::StrRef;
|
|||
use nac3parser::ast::{Cmpop, Operator, Unaryop};
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use itertools::Itertools;
|
||||
|
||||
#[must_use]
|
||||
pub fn binop_name(op: &Operator) -> &'static str {
|
||||
|
@ -330,6 +335,137 @@ pub fn impl_eq(
|
|||
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
|
||||
}
|
||||
|
||||
/// Returns the expected return type of binary operations with at least one `ndarray` operand.
|
||||
pub fn typeof_ndarray_broadcast(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
left: Type,
|
||||
right: Type,
|
||||
) -> Result<Type, String> {
|
||||
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
|
||||
assert!(is_left_ndarray || is_right_ndarray);
|
||||
|
||||
if is_left_ndarray && is_right_ndarray {
|
||||
// Perform broadcasting on two ndarray operands.
|
||||
|
||||
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
|
||||
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
|
||||
|
||||
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
|
||||
|
||||
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => values.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => values.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let res_ndims = left_ty_ndims.into_iter()
|
||||
.cartesian_product(right_ty_ndims)
|
||||
.map(|(left, right)| {
|
||||
let left_val = u64::try_from(left).unwrap();
|
||||
let right_val = u64::try_from(right).unwrap();
|
||||
|
||||
max(left_val, right_val)
|
||||
})
|
||||
.unique()
|
||||
.map(SymbolValue::U64)
|
||||
.collect_vec();
|
||||
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
|
||||
|
||||
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
|
||||
} else {
|
||||
let (ndarray_ty, scalar_ty) = if is_left_ndarray {
|
||||
(left, right)
|
||||
} else {
|
||||
(right, left)
|
||||
};
|
||||
|
||||
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
|
||||
|
||||
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
||||
Ok(ndarray_ty)
|
||||
} else {
|
||||
let (expected_ty, actual_ty) = if is_left_ndarray {
|
||||
(ndarray_ty_dtype, scalar_ty)
|
||||
} else {
|
||||
(scalar_ty, ndarray_ty_dtype)
|
||||
};
|
||||
|
||||
Err(format!(
|
||||
"Expected right-hand side operand to be {}, got {}",
|
||||
unifier.stringify(expected_ty),
|
||||
unifier.stringify(actual_ty),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the return type given a binary operator and its primitive operands.
|
||||
pub fn typeof_binop(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
op: &Operator,
|
||||
lhs: Type,
|
||||
rhs: Type,
|
||||
) -> Result<Option<Type>, String> {
|
||||
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
|
||||
Ok(Some(match op {
|
||||
Operator::Add
|
||||
| Operator::Sub
|
||||
| Operator::Mult
|
||||
| Operator::Mod
|
||||
| Operator::FloorDiv => {
|
||||
if is_left_ndarray || is_right_ndarray {
|
||||
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||
} else if unifier.unioned(lhs, rhs) {
|
||||
lhs
|
||||
} else {
|
||||
return Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
Operator::MatMult => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
||||
Operator::Div => {
|
||||
if is_left_ndarray || is_right_ndarray {
|
||||
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||
} else if unifier.unioned(lhs, rhs) {
|
||||
primitives.float
|
||||
} else {
|
||||
return Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
Operator::Pow => {
|
||||
if is_left_ndarray || is_right_ndarray {
|
||||
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||
} else if [primitives.int32, primitives.int64, primitives.uint32, primitives.uint64, primitives.float].into_iter().any(|ty| unifier.unioned(lhs, ty)) {
|
||||
lhs
|
||||
} else {
|
||||
return Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
Operator::LShift
|
||||
| Operator::RShift
|
||||
| Operator::BitOr
|
||||
| Operator::BitXor
|
||||
| Operator::BitAnd => {
|
||||
if unifier.unioned(lhs, rhs) {
|
||||
lhs
|
||||
} else {
|
||||
return Ok(None)
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
||||
let PrimitiveStore {
|
||||
int32: int32_t,
|
||||
|
|
Loading…
Reference in New Issue