From 3540d0ab29fc2660225b80de451d50d751713826 Mon Sep 17 00:00:00 2001
From: David Mak <chmakac@connect.ust.hk>
Date: Tue, 26 Mar 2024 19:05:44 +0800
Subject: [PATCH] core/magic_methods: Add typeof_*op

Used to determine the expected type of the binary operator with
primitive operands.
---
 nac3core/src/symbol_resolver.rs         |  46 ++++++--
 nac3core/src/typecheck/magic_methods.rs | 136 ++++++++++++++++++++++++
 2 files changed, 175 insertions(+), 7 deletions(-)

diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs
index 2696075..5e890e6 100644
--- a/nac3core/src/symbol_resolver.rs
+++ b/nac3core/src/symbol_resolver.rs
@@ -170,13 +170,13 @@ impl SymbolValue {
     /// Returns the [`TypeAnnotation`] representing the data type of this value.
     pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation {
         match self {
-            SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool),
-            SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float),
-            SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32),
-            SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64),
-            SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32),
-            SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64),
-            SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str),
+            SymbolValue::Bool(..)
+            | SymbolValue::Double(..)
+            | SymbolValue::I32(..)
+            | SymbolValue::I64(..)
+            | SymbolValue::U32(..)
+            | SymbolValue::U64(..)
+            | SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
             SymbolValue::Tuple(vs) => {
                 let vs_tys = vs
                     .iter()
@@ -230,6 +230,38 @@ impl Display for SymbolValue {
     }
 }
 
+impl TryFrom<SymbolValue> for u64 {
+    type Error = ();
+
+    /// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not
+    /// numeric or if the value cannot be converted into a `u64` without overflow.
+    fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
+        match value {
+            SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
+            SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
+            SymbolValue::U32(v) => Ok(v as u64),
+            SymbolValue::U64(v) => Ok(v),
+            _ => Err(()),
+        }
+    }
+}
+
+impl TryFrom<SymbolValue> for i128 {
+    type Error = ();
+
+    /// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not
+    /// numeric.
+    fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
+        match value {
+            SymbolValue::I32(v) => Ok(v as i128),
+            SymbolValue::I64(v) => Ok(v as i128),
+            SymbolValue::U32(v) => Ok(v as i128),
+            SymbolValue::U64(v) => Ok(v as i128),
+            _ => Err(()),
+        }
+    }
+}
+
 pub trait StaticValue {
     /// Returns a unique identifier for this value.
     fn get_unique_identifier(&self) -> u64;
diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs
index a11705f..ec0c064 100644
--- a/nac3core/src/typecheck/magic_methods.rs
+++ b/nac3core/src/typecheck/magic_methods.rs
@@ -1,3 +1,7 @@
+use std::cmp::max;
+use crate::symbol_resolver::SymbolValue;
+use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
+use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
 use crate::typecheck::{
     type_inferencer::*,
     typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
@@ -6,6 +10,7 @@ use nac3parser::ast::StrRef;
 use nac3parser::ast::{Cmpop, Operator, Unaryop};
 use std::collections::HashMap;
 use std::rc::Rc;
+use itertools::Itertools;
 
 #[must_use]
 pub fn binop_name(op: &Operator) -> &'static str {
@@ -330,6 +335,137 @@ pub fn impl_eq(
     impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
 }
 
+/// Returns the expected return type of binary operations with at least one `ndarray` operand.
+pub fn typeof_ndarray_broadcast(
+    unifier: &mut Unifier,
+    primitives: &PrimitiveStore,
+    left: Type,
+    right: Type,
+) -> Result<Type, String> {
+    let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id ==  PRIMITIVE_DEF_IDS.ndarray);
+    let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
+
+    assert!(is_left_ndarray || is_right_ndarray);
+
+    if is_left_ndarray && is_right_ndarray {
+        // Perform broadcasting on two ndarray operands.
+
+        let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
+        let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
+
+        assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
+
+        let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
+            TypeEnum::TLiteral { values, .. } => values.clone(),
+            _ => unreachable!(),
+        };
+        let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
+            TypeEnum::TLiteral { values, .. } => values.clone(),
+            _ => unreachable!(),
+        };
+
+        let res_ndims = left_ty_ndims.into_iter()
+            .cartesian_product(right_ty_ndims)
+            .map(|(left, right)| {
+                let left_val = u64::try_from(left).unwrap();
+                let right_val = u64::try_from(right).unwrap();
+
+                max(left_val, right_val)
+            })
+            .unique()
+            .map(SymbolValue::U64)
+            .collect_vec();
+        let res_ndims = unifier.get_fresh_literal(res_ndims, None);
+
+        Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
+    } else {
+        let (ndarray_ty, scalar_ty) = if is_left_ndarray {
+            (left, right)
+        } else {
+            (right, left)
+        };
+
+        let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
+
+        if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
+            Ok(ndarray_ty)
+        } else {
+            let (expected_ty, actual_ty) = if is_left_ndarray {
+                (ndarray_ty_dtype, scalar_ty)
+            } else {
+                (scalar_ty, ndarray_ty_dtype)
+            };
+
+            Err(format!(
+                "Expected right-hand side operand to be {}, got {}",
+                unifier.stringify(expected_ty),
+                unifier.stringify(actual_ty),
+            ))
+        }
+    }
+}
+
+/// Returns the return type given a binary operator and its primitive operands.
+pub fn typeof_binop(
+    unifier: &mut Unifier,
+    primitives: &PrimitiveStore,
+    op: &Operator,
+    lhs: Type,
+    rhs: Type,
+) -> Result<Option<Type>, String> {
+    let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id ==  PRIMITIVE_DEF_IDS.ndarray);
+    let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
+
+    Ok(Some(match op {
+        Operator::Add
+        | Operator::Sub
+        | Operator::Mult
+        | Operator::Mod
+        | Operator::FloorDiv => {
+            if is_left_ndarray || is_right_ndarray {
+                typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
+            } else if unifier.unioned(lhs, rhs) {
+                lhs
+            } else {
+                return Ok(None)
+            }
+        }
+
+        Operator::MatMult => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
+        Operator::Div => {
+            if is_left_ndarray || is_right_ndarray {
+                typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
+            } else if unifier.unioned(lhs, rhs) {
+                primitives.float
+            } else {
+                return Ok(None)
+            }
+        }
+
+        Operator::Pow => {
+            if is_left_ndarray || is_right_ndarray {
+                typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
+            } else if [primitives.int32, primitives.int64, primitives.uint32, primitives.uint64, primitives.float].into_iter().any(|ty| unifier.unioned(lhs, ty)) {
+                lhs
+            } else {
+                return Ok(None)
+            }
+        }
+
+        Operator::LShift
+        | Operator::RShift
+        | Operator::BitOr
+        | Operator::BitXor
+        | Operator::BitAnd => {
+            if unifier.unioned(lhs, rhs) {
+                lhs
+            } else {
+                return Ok(None)
+            }
+        }
+    }))
+}
+
 pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
     let PrimitiveStore {
         int32: int32_t,