forked from M-Labs/nac3
core/magic_methods: Add typeof_*op
Used to determine the expected type of the binary operator with primitive operands.
This commit is contained in:
parent
3a6c53d760
commit
3540d0ab29
nac3core/src
@ -170,13 +170,13 @@ impl SymbolValue {
|
|||||||
/// Returns the [`TypeAnnotation`] representing the data type of this value.
|
/// Returns the [`TypeAnnotation`] representing the data type of this value.
|
||||||
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation {
|
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation {
|
||||||
match self {
|
match self {
|
||||||
SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool),
|
SymbolValue::Bool(..)
|
||||||
SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float),
|
| SymbolValue::Double(..)
|
||||||
SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32),
|
| SymbolValue::I32(..)
|
||||||
SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64),
|
| SymbolValue::I64(..)
|
||||||
SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32),
|
| SymbolValue::U32(..)
|
||||||
SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64),
|
| SymbolValue::U64(..)
|
||||||
SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str),
|
| SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
|
||||||
SymbolValue::Tuple(vs) => {
|
SymbolValue::Tuple(vs) => {
|
||||||
let vs_tys = vs
|
let vs_tys = vs
|
||||||
.iter()
|
.iter()
|
||||||
@ -230,6 +230,38 @@ impl Display for SymbolValue {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl TryFrom<SymbolValue> for u64 {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
/// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not
|
||||||
|
/// numeric or if the value cannot be converted into a `u64` without overflow.
|
||||||
|
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
|
||||||
|
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
|
||||||
|
SymbolValue::U32(v) => Ok(v as u64),
|
||||||
|
SymbolValue::U64(v) => Ok(v),
|
||||||
|
_ => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<SymbolValue> for i128 {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
/// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not
|
||||||
|
/// numeric.
|
||||||
|
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
SymbolValue::I32(v) => Ok(v as i128),
|
||||||
|
SymbolValue::I64(v) => Ok(v as i128),
|
||||||
|
SymbolValue::U32(v) => Ok(v as i128),
|
||||||
|
SymbolValue::U64(v) => Ok(v as i128),
|
||||||
|
_ => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait StaticValue {
|
pub trait StaticValue {
|
||||||
/// Returns a unique identifier for this value.
|
/// Returns a unique identifier for this value.
|
||||||
fn get_unique_identifier(&self) -> u64;
|
fn get_unique_identifier(&self) -> u64;
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
use std::cmp::max;
|
||||||
|
use crate::symbol_resolver::SymbolValue;
|
||||||
|
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
|
||||||
|
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
||||||
use crate::typecheck::{
|
use crate::typecheck::{
|
||||||
type_inferencer::*,
|
type_inferencer::*,
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||||
@ -6,6 +10,7 @@ use nac3parser::ast::StrRef;
|
|||||||
use nac3parser::ast::{Cmpop, Operator, Unaryop};
|
use nac3parser::ast::{Cmpop, Operator, Unaryop};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn binop_name(op: &Operator) -> &'static str {
|
pub fn binop_name(op: &Operator) -> &'static str {
|
||||||
@ -330,6 +335,137 @@ pub fn impl_eq(
|
|||||||
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
|
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the expected return type of binary operations with at least one `ndarray` operand.
|
||||||
|
pub fn typeof_ndarray_broadcast(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
left: Type,
|
||||||
|
right: Type,
|
||||||
|
) -> Result<Type, String> {
|
||||||
|
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
|
||||||
|
assert!(is_left_ndarray || is_right_ndarray);
|
||||||
|
|
||||||
|
if is_left_ndarray && is_right_ndarray {
|
||||||
|
// Perform broadcasting on two ndarray operands.
|
||||||
|
|
||||||
|
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
|
||||||
|
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
|
||||||
|
|
||||||
|
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
|
||||||
|
|
||||||
|
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
|
||||||
|
TypeEnum::TLiteral { values, .. } => values.clone(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
|
||||||
|
TypeEnum::TLiteral { values, .. } => values.clone(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let res_ndims = left_ty_ndims.into_iter()
|
||||||
|
.cartesian_product(right_ty_ndims)
|
||||||
|
.map(|(left, right)| {
|
||||||
|
let left_val = u64::try_from(left).unwrap();
|
||||||
|
let right_val = u64::try_from(right).unwrap();
|
||||||
|
|
||||||
|
max(left_val, right_val)
|
||||||
|
})
|
||||||
|
.unique()
|
||||||
|
.map(SymbolValue::U64)
|
||||||
|
.collect_vec();
|
||||||
|
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
|
||||||
|
|
||||||
|
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
|
||||||
|
} else {
|
||||||
|
let (ndarray_ty, scalar_ty) = if is_left_ndarray {
|
||||||
|
(left, right)
|
||||||
|
} else {
|
||||||
|
(right, left)
|
||||||
|
};
|
||||||
|
|
||||||
|
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
|
||||||
|
|
||||||
|
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
||||||
|
Ok(ndarray_ty)
|
||||||
|
} else {
|
||||||
|
let (expected_ty, actual_ty) = if is_left_ndarray {
|
||||||
|
(ndarray_ty_dtype, scalar_ty)
|
||||||
|
} else {
|
||||||
|
(scalar_ty, ndarray_ty_dtype)
|
||||||
|
};
|
||||||
|
|
||||||
|
Err(format!(
|
||||||
|
"Expected right-hand side operand to be {}, got {}",
|
||||||
|
unifier.stringify(expected_ty),
|
||||||
|
unifier.stringify(actual_ty),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the return type given a binary operator and its primitive operands.
|
||||||
|
pub fn typeof_binop(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
op: &Operator,
|
||||||
|
lhs: Type,
|
||||||
|
rhs: Type,
|
||||||
|
) -> Result<Option<Type>, String> {
|
||||||
|
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
|
||||||
|
Ok(Some(match op {
|
||||||
|
Operator::Add
|
||||||
|
| Operator::Sub
|
||||||
|
| Operator::Mult
|
||||||
|
| Operator::Mod
|
||||||
|
| Operator::FloorDiv => {
|
||||||
|
if is_left_ndarray || is_right_ndarray {
|
||||||
|
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||||
|
} else if unifier.unioned(lhs, rhs) {
|
||||||
|
lhs
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Operator::MatMult => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
||||||
|
Operator::Div => {
|
||||||
|
if is_left_ndarray || is_right_ndarray {
|
||||||
|
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||||
|
} else if unifier.unioned(lhs, rhs) {
|
||||||
|
primitives.float
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Operator::Pow => {
|
||||||
|
if is_left_ndarray || is_right_ndarray {
|
||||||
|
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||||
|
} else if [primitives.int32, primitives.int64, primitives.uint32, primitives.uint64, primitives.float].into_iter().any(|ty| unifier.unioned(lhs, ty)) {
|
||||||
|
lhs
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Operator::LShift
|
||||||
|
| Operator::RShift
|
||||||
|
| Operator::BitOr
|
||||||
|
| Operator::BitXor
|
||||||
|
| Operator::BitAnd => {
|
||||||
|
if unifier.unioned(lhs, rhs) {
|
||||||
|
lhs
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
||||||
let PrimitiveStore {
|
let PrimitiveStore {
|
||||||
int32: int32_t,
|
int32: int32_t,
|
||||||
|
Loading…
Reference in New Issue
Block a user