diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 269607543..5e890e698 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -170,13 +170,13 @@ impl SymbolValue { /// Returns the [`TypeAnnotation`] representing the data type of this value. pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation { match self { - SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool), - SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float), - SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32), - SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64), - SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32), - SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64), - SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str), + SymbolValue::Bool(..) + | SymbolValue::Double(..) + | SymbolValue::I32(..) + | SymbolValue::I64(..) + | SymbolValue::U32(..) + | SymbolValue::U64(..) + | SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)), SymbolValue::Tuple(vs) => { let vs_tys = vs .iter() @@ -230,6 +230,38 @@ impl Display for SymbolValue { } } +impl TryFrom for u64 { + type Error = (); + + /// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not + /// numeric or if the value cannot be converted into a `u64` without overflow. + fn try_from(value: SymbolValue) -> Result { + match value { + SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()), + SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()), + SymbolValue::U32(v) => Ok(v as u64), + SymbolValue::U64(v) => Ok(v), + _ => Err(()), + } + } +} + +impl TryFrom for i128 { + type Error = (); + + /// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not + /// numeric. + fn try_from(value: SymbolValue) -> Result { + match value { + SymbolValue::I32(v) => Ok(v as i128), + SymbolValue::I64(v) => Ok(v as i128), + SymbolValue::U32(v) => Ok(v as i128), + SymbolValue::U64(v) => Ok(v as i128), + _ => Err(()), + } + } +} + pub trait StaticValue { /// Returns a unique identifier for this value. fn get_unique_identifier(&self) -> u64; diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index a11705fae..ec0c064cc 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,3 +1,7 @@ +use std::cmp::max; +use crate::symbol_resolver::SymbolValue; +use crate::toplevel::helper::PRIMITIVE_DEF_IDS; +use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::typecheck::{ type_inferencer::*, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, @@ -6,6 +10,7 @@ use nac3parser::ast::StrRef; use nac3parser::ast::{Cmpop, Operator, Unaryop}; use std::collections::HashMap; use std::rc::Rc; +use itertools::Itertools; #[must_use] pub fn binop_name(op: &Operator) -> &'static str { @@ -330,6 +335,137 @@ pub fn impl_eq( impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty); } +/// Returns the expected return type of binary operations with at least one `ndarray` operand. +pub fn typeof_ndarray_broadcast( + unifier: &mut Unifier, + primitives: &PrimitiveStore, + left: Type, + right: Type, +) -> Result { + let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + assert!(is_left_ndarray || is_right_ndarray); + + if is_left_ndarray && is_right_ndarray { + // Perform broadcasting on two ndarray operands. + + let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left); + let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right); + + assert!(unifier.unioned(left_ty_dtype, right_ty_dtype)); + + let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) { + TypeEnum::TLiteral { values, .. } => values.clone(), + _ => unreachable!(), + }; + let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) { + TypeEnum::TLiteral { values, .. } => values.clone(), + _ => unreachable!(), + }; + + let res_ndims = left_ty_ndims.into_iter() + .cartesian_product(right_ty_ndims) + .map(|(left, right)| { + let left_val = u64::try_from(left).unwrap(); + let right_val = u64::try_from(right).unwrap(); + + max(left_val, right_val) + }) + .unique() + .map(SymbolValue::U64) + .collect_vec(); + let res_ndims = unifier.get_fresh_literal(res_ndims, None); + + Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims))) + } else { + let (ndarray_ty, scalar_ty) = if is_left_ndarray { + (left, right) + } else { + (right, left) + }; + + let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty); + + if unifier.unioned(ndarray_ty_dtype, scalar_ty) { + Ok(ndarray_ty) + } else { + let (expected_ty, actual_ty) = if is_left_ndarray { + (ndarray_ty_dtype, scalar_ty) + } else { + (scalar_ty, ndarray_ty_dtype) + }; + + Err(format!( + "Expected right-hand side operand to be {}, got {}", + unifier.stringify(expected_ty), + unifier.stringify(actual_ty), + )) + } + } +} + +/// Returns the return type given a binary operator and its primitive operands. +pub fn typeof_binop( + unifier: &mut Unifier, + primitives: &PrimitiveStore, + op: &Operator, + lhs: Type, + rhs: Type, +) -> Result, String> { + let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + Ok(Some(match op { + Operator::Add + | Operator::Sub + | Operator::Mult + | Operator::Mod + | Operator::FloorDiv => { + if is_left_ndarray || is_right_ndarray { + typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)? + } else if unifier.unioned(lhs, rhs) { + lhs + } else { + return Ok(None) + } + } + + Operator::MatMult => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, + Operator::Div => { + if is_left_ndarray || is_right_ndarray { + typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)? + } else if unifier.unioned(lhs, rhs) { + primitives.float + } else { + return Ok(None) + } + } + + Operator::Pow => { + if is_left_ndarray || is_right_ndarray { + typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)? + } else if [primitives.int32, primitives.int64, primitives.uint32, primitives.uint64, primitives.float].into_iter().any(|ty| unifier.unioned(lhs, ty)) { + lhs + } else { + return Ok(None) + } + } + + Operator::LShift + | Operator::RShift + | Operator::BitOr + | Operator::BitXor + | Operator::BitAnd => { + if unifier.unioned(lhs, rhs) { + lhs + } else { + return Ok(None) + } + } + })) +} + pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { let PrimitiveStore { int32: int32_t,