diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3b8b85d5d..8c5429a49 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1202,11 +1202,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( { let llvm_usize = generator.get_size_type(ctx.ctx); - if is_aug_assign { + if op.variant == BinopVariant::AugAssign { todo!("Augmented assignment operators not implemented for lists") } - match op { + match op.base { Operator::Add => { debug_assert_eq!(ty1.obj_id(&ctx.unifier), Some(PrimDef::List.id())); debug_assert_eq!(ty2.obj_id(&ctx.unifier), Some(PrimDef::List.id())); diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 929abb315..8547193e7 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -486,18 +486,20 @@ pub fn typeof_binop( lhs: Type, rhs: Type, ) -> Result, String> { + let op = Binop { base: op, variant: BinopVariant::Normal }; + let is_left_list = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id()); let is_right_list = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id()); let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - Ok(Some(match op { + Ok(Some(match op.base { Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => { if is_left_list || is_right_list { - if ![Operator::Add, Operator::Mult].contains(&op) { + if ![Operator::Add, Operator::Mult].contains(&op.base) { return Err(format!( "Binary operator {} not supported for list", - binop_name(op) + op.op_info().symbol )); } diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs index d03b3f1c1..0d84b87ea 100644 --- a/nac3core/src/typecheck/type_error.rs +++ b/nac3core/src/typecheck/type_error.rs @@ -1,11 +1,14 @@ use std::collections::HashMap; use std::fmt::Display; -use crate::typecheck::typedef::TypeEnum; +use crate::typecheck::{magic_methods::HasOpInfo, typedef::TypeEnum}; -use super::typedef::{RecordKey, Type, Unifier}; +use super::{ + magic_methods::Binop, + typedef::{RecordKey, Type, Unifier}, +}; use itertools::Itertools; -use nac3parser::ast::{Location, StrRef}; +use nac3parser::ast::{Cmpop, Location, StrRef}; #[derive(Debug, Clone)] pub enum TypeErrorKind { @@ -26,6 +29,18 @@ pub enum TypeErrorKind { expected: Type, got: Type, }, + UnsupportedBinaryOpTypes { + operator: Binop, + lhs_type: Type, + rhs_type: Type, + expected_rhs_type: Type, + }, + UnsupportedComparsionOpTypes { + operator: Cmpop, + lhs_type: Type, + rhs_type: Type, + expected_rhs_type: Type, + }, FieldUnificationError { field: RecordKey, types: (Type, Type), @@ -101,6 +116,26 @@ impl<'a> Display for DisplayTypeError<'a> { let args = missing_arg_names.iter().join(", "); write!(f, "Missing arguments: {args}") } + UnsupportedBinaryOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => { + let op_symbol = operator.op_info().symbol; + + let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes); + let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes); + let expected_rhs_type_str = + self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes); + + write!(f, "Unsupported operand type(s) for {op_symbol}: '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})") + } + UnsupportedComparsionOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => { + let op_symbol = operator.op_info().symbol; + + let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes); + let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes); + let expected_rhs_type_str = + self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes); + + write!(f, "'{op_symbol}' not supported between instances of '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})") + } UnknownArgName(name) => { write!(f, "Unknown argument name: {name}") } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 8dc48fb71..c7ec7784b 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -4,6 +4,7 @@ use std::iter::once; use std::ops::Not; use std::{cell::RefCell, sync::Arc}; +use super::typedef::OperatorInfo; use super::{ magic_methods::*, type_error::TypeError, @@ -641,6 +642,7 @@ impl<'a> Inferencer<'a> { obj: Type, params: Vec, ret: Option, + operator_info: Option, ) -> InferenceResult { if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) { if class_params.is_empty() { @@ -654,6 +656,7 @@ impl<'a> Inferencer<'a> { ret: sign.ret, fun: RefCell::new(None), loc: Some(location), + operator_info, }; if let Some(ret) = ret { self.unifier @@ -688,6 +691,7 @@ impl<'a> Inferencer<'a> { ret, fun: RefCell::new(None), loc: Some(location), + operator_info, }); self.calls.insert(location.into(), call); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call])); @@ -1523,6 +1527,7 @@ impl<'a> Inferencer<'a> { fun: RefCell::new(None), ret: sign.ret, loc: Some(location), + operator_info: None, }; self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| { HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]) @@ -1545,6 +1550,7 @@ impl<'a> Inferencer<'a> { fun: RefCell::new(None), ret, loc: Some(location), + operator_info: None, }); self.calls.insert(location.into(), call); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call])); @@ -1765,7 +1771,14 @@ impl<'a> Inferencer<'a> { } }; - self.build_method_call(location, method.into(), left_ty, vec![right_ty], ret) + self.build_method_call( + location, + method.into(), + left_ty, + vec![right_ty], + ret, + Some(OperatorInfo::IsBinaryOp { self_type: left.custom.unwrap(), operator: op }), + ) } fn infer_unary_ops( @@ -1779,7 +1792,14 @@ impl<'a> Inferencer<'a> { let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap()) .map_err(|e| HashSet::from([format!("{e} (at {location})")]))?; - self.build_method_call(location, method, operand.custom.unwrap(), vec![], ret) + self.build_method_call( + location, + method, + operand.custom.unwrap(), + vec![], + ret, + Some(OperatorInfo::IsUnaryOp { self_type: operand.custom.unwrap(), operator: op }), + ) } fn infer_compare( @@ -1825,6 +1845,10 @@ impl<'a> Inferencer<'a> { a.custom.unwrap(), vec![b.custom.unwrap()], ret, + Some(OperatorInfo::IsComparisonOp { + self_type: left.custom.unwrap(), + operator: *c, + }), )?); } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 79ca48f2e..845e84067 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -8,12 +8,15 @@ use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::{borrow::Cow, collections::HashSet}; -use nac3parser::ast::{Location, StrRef}; +use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop}; +use super::magic_methods::Binop; use super::type_error::{TypeError, TypeErrorKind}; use super::unification_table::{UnificationKey, UnificationTable}; use crate::symbol_resolver::SymbolValue; -use crate::toplevel::{helper::PrimDef, DefinitionId, TopLevelContext, TopLevelDef}; +use crate::toplevel::helper::PrimDef; +use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef}; +use crate::typecheck::magic_methods::OpInfo; use crate::typecheck::type_inferencer::PrimitiveStore; #[cfg(test)] @@ -73,6 +76,28 @@ pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator + '_ { var_map.iter().map(|(&id, &ty)| TypeVar { id, ty }) } +#[derive(Debug, Clone)] +pub enum OperatorInfo { + /// The call was written as an unary operation, e.g., `~a` or `not a`. + IsUnaryOp { + /// The [`Type`] of the `self` object + self_type: Type, + operator: Unaryop, + }, + /// The call was written as a binary operation, e.g., `a + b` or `a += b`. + IsBinaryOp { + /// The [`Type`] of the `self` object + self_type: Type, + operator: Binop, + }, + /// The call was written as a binary comparison operation, e.g., `a < b`. + IsComparisonOp { + /// The [`Type`] of the `self` object + self_type: Type, + operator: Cmpop, + }, +} + #[derive(Clone)] pub struct Call { pub posargs: Vec, @@ -80,6 +105,9 @@ pub struct Call { pub ret: Type, pub fun: RefCell>, pub loc: Option, + + /// Details about the associated Python user operator expression of this call, if any. + pub operator_info: Option, } #[derive(Debug, Clone)] @@ -627,111 +655,178 @@ impl Unifier { let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() }; // Get details about the input arguments - let Call { posargs, kwargs, ret, fun, loc } = call; + let Call { posargs, kwargs, ret, fun, loc, operator_info } = call; let num_args = posargs.len() + kwargs.len(); - // Now we check the arguments against the parameters + // Now we check the arguments against the parameters, + // and depending on what `call_info` is, we might change how the behavior `unify_call()` + // in hopes to improve user error messages when type checking fails. + match operator_info { + Some(OperatorInfo::IsBinaryOp { self_type, operator }) => { + // The call is written in the form of (say) `a + b`. + // Technically, it is `a.__add__(b)`, and they have the following constraints: + assert_eq!(posargs.len(), 1); + assert_eq!(kwargs.len(), 0); + assert_eq!(num_params, 1); - // Helper lambdas - let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| { - let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok(); - if ok { - Ok(()) - } else { - // Typecheck failed, throw an error. - self.restore_snapshot(); - Err(TypeError::new( - TypeErrorKind::IncorrectArgType { - name: param_name, - expected: expected_arg_ty, - got: arg_ty, - }, - *loc, - )) + let other_type = posargs[0]; // the second operand + let expected_other_type = signature.args[0].ty; + + let ok = self.unify_impl(expected_other_type, other_type, false).is_ok(); + if !ok { + self.restore_snapshot(); + return Err(TypeError::new( + TypeErrorKind::UnsupportedBinaryOpTypes { + operator: *operator, + lhs_type: *self_type, + rhs_type: other_type, + expected_rhs_type: expected_other_type, + }, + *loc, + )); + } } - }; + Some(OperatorInfo::IsComparisonOp { self_type, operator }) + if OpInfo::supports_cmpop(*operator) // Otherwise that comparison operator is not supported. + => + { + // The call is written in the form of (say) `a <= b`. + // Technically, it is `a.__le__(b)`, and they have the following constraints: + assert_eq!(posargs.len(), 1); + assert_eq!(kwargs.len(), 0); + assert_eq!(num_params, 1); - // Check for "too many arguments" - if num_params < posargs.len() { - let expected_min_count = - signature.args.iter().filter(|param| param.is_required()).count(); - let expected_max_count = num_params; + let other_type = posargs[0]; // the second operand + let expected_other_type = signature.args[0].ty; - self.restore_snapshot(); - return Err(TypeError::new( - TypeErrorKind::TooManyArguments { - expected_min_count, - expected_max_count, - got_count: num_args, - }, - *loc, - )); - } - - // NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap - let mut param_info_by_name: IndexMap = signature - .args - .iter() - .map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg })) - .collect(); - - // Now consume all positional arguments and typecheck them. - for (&arg_ty, param) in zip(posargs, signature.args.iter()) { - // We will also use this opportunity to mark the corresponding `param_info` as having been supplied. - let param_info = param_info_by_name.get_mut(¶m.name).unwrap(); - param_info.has_been_supplied = true; - - // Typecheck - type_check_arg(param.name, param.ty, arg_ty)?; - } - - // Now consume all keyword arguments and typecheck them. - for (¶m_name, &arg_ty) in kwargs { - // We will also use this opportunity to check if this keyword argument is "legal". - - let Some(param_info) = param_info_by_name.get_mut(¶m_name) else { - self.restore_snapshot(); - return Err(TypeError::new(TypeErrorKind::UnknownArgName(param_name), *loc)); - }; - - if param_info.has_been_supplied { - // NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`) - // is IMPOSSIBLE as the parser would have already failed. - // We only have to care about "got multiple values for XYZ" - - self.restore_snapshot(); - return Err(TypeError::new( - TypeErrorKind::GotMultipleValues { name: param_name }, - *loc, - )); + let ok = self.unify_impl(expected_other_type, other_type, false).is_ok(); + if !ok { + self.restore_snapshot(); + return Err(TypeError::new( + TypeErrorKind::UnsupportedComparsionOpTypes { + operator: *operator, + lhs_type: *self_type, + rhs_type: other_type, + expected_rhs_type: expected_other_type, + }, + *loc, + )); + } } + _ => { + // Handle [`CallInfo::IsNormalFunctionCall`] and other uninteresting variants + // of [`CallInfo`] (e.g, `CallInfo::IsUnaryOp` and unsupported comparison operators) - param_info.has_been_supplied = true; + // Helper lambdas + let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| { + let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok(); + if ok { + Ok(()) + } else { + // Typecheck failed, throw an error. + self.restore_snapshot(); + Err(TypeError::new( + TypeErrorKind::IncorrectArgType { + name: param_name, + expected: expected_arg_ty, + got: arg_ty, + }, + *loc, + )) + } + }; - // Typecheck - type_check_arg(param_name, param_info.param.ty, arg_ty)?; - } + // Check for "too many arguments" + if num_params < posargs.len() { + let expected_min_count = + signature.args.iter().filter(|param| param.is_required()).count(); + let expected_max_count = num_params; - // After checking posargs and kwargs, check if there are any - // unsupplied required parameters, and throw an error if they exist. - let missing_arg_names = param_info_by_name - .values() - .filter(|param_info| param_info.param.is_required() && !param_info.has_been_supplied) - .map(|param_info| param_info.param.name) - .collect_vec(); - if !missing_arg_names.is_empty() { - self.restore_snapshot(); - return Err(TypeError::new(TypeErrorKind::MissingArgs { missing_arg_names }, *loc)); - } + self.restore_snapshot(); + return Err(TypeError::new( + TypeErrorKind::TooManyArguments { + expected_min_count, + expected_max_count, + got_count: num_args, + }, + *loc, + )); + } - // Finally, check the Call's return type - self.unify_impl(*ret, signature.ret, false).map_err(|mut err| { - self.restore_snapshot(); - if err.loc.is_none() { - err.loc = *loc; + // NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap + let mut param_info_by_name: IndexMap = signature + .args + .iter() + .map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg })) + .collect(); + + // Now consume all positional arguments and typecheck them. + for (&arg_ty, param) in zip(posargs, signature.args.iter()) { + // We will also use this opportunity to mark the corresponding `param_info` as having been supplied. + let param_info = param_info_by_name.get_mut(¶m.name).unwrap(); + param_info.has_been_supplied = true; + + // Typecheck + type_check_arg(param.name, param.ty, arg_ty)?; + } + + // Now consume all keyword arguments and typecheck them. + for (¶m_name, &arg_ty) in kwargs { + // We will also use this opportunity to check if this keyword argument is "legal". + + let Some(param_info) = param_info_by_name.get_mut(¶m_name) else { + self.restore_snapshot(); + return Err(TypeError::new( + TypeErrorKind::UnknownArgName(param_name), + *loc, + )); + }; + + if param_info.has_been_supplied { + // NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`) + // is IMPOSSIBLE as the parser would have already failed. + // We only have to care about "got multiple values for XYZ" + + self.restore_snapshot(); + return Err(TypeError::new( + TypeErrorKind::GotMultipleValues { name: param_name }, + *loc, + )); + } + + param_info.has_been_supplied = true; + + // Typecheck + type_check_arg(param_name, param_info.param.ty, arg_ty)?; + } + + // After checking posargs and kwargs, check if there are any + // unsupplied required parameters, and throw an error if they exist. + let missing_arg_names = param_info_by_name + .values() + .filter(|param_info| { + param_info.param.is_required() && !param_info.has_been_supplied + }) + .map(|param_info| param_info.param.name) + .collect_vec(); + if !missing_arg_names.is_empty() { + self.restore_snapshot(); + return Err(TypeError::new( + TypeErrorKind::MissingArgs { missing_arg_names }, + *loc, + )); + } + + // Finally, check the Call's return type + self.unify_impl(*ret, signature.ret, false).map_err(|mut err| { + self.restore_snapshot(); + if err.loc.is_none() { + err.loc = *loc; + } + err + })?; } - err - })?; + } *fun.borrow_mut() = Some(b);