Compare commits
4 Commits
b0b804051a
...
8dfd7f2a7d
Author | SHA1 | Date |
---|---|---|
lyken | 8dfd7f2a7d | |
lyken | 4468943736 | |
lyken | 352a7f3de5 | |
lyken | 54a794ae44 |
|
@ -21,7 +21,7 @@ use crate::{
|
|||
DefinitionId, TopLevelDef,
|
||||
},
|
||||
typecheck::{
|
||||
magic_methods::{binop_assign_name, binop_name, unaryop_name},
|
||||
magic_methods::{BinOpVariant, OpInfo},
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||
},
|
||||
};
|
||||
|
@ -1167,7 +1167,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
op: Operator,
|
||||
right: (&Option<Type>, BasicValueEnum<'ctx>),
|
||||
loc: Location,
|
||||
is_aug_assign: bool,
|
||||
variant: BinOpVariant,
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
let (left_ty, left_val) = left;
|
||||
let (right_ty, right_val) = right;
|
||||
|
@ -1222,7 +1222,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
generator,
|
||||
ctx,
|
||||
ndarray_dtype1,
|
||||
if is_aug_assign { Some(left_val) } else { None },
|
||||
match variant {
|
||||
BinOpVariant::Normal => None,
|
||||
BinOpVariant::AugAssign => Some(left_val),
|
||||
},
|
||||
left_val,
|
||||
right_val,
|
||||
)?
|
||||
|
@ -1231,7 +1234,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
generator,
|
||||
ctx,
|
||||
ndarray_dtype1,
|
||||
if is_aug_assign { Some(left_val) } else { None },
|
||||
match variant {
|
||||
BinOpVariant::Normal => None,
|
||||
BinOpVariant::AugAssign => Some(left_val),
|
||||
},
|
||||
(left_val.as_base_value().into(), false),
|
||||
(right_val.as_base_value().into(), false),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
|
@ -1242,7 +1248,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
op,
|
||||
(&Some(ndarray_dtype2), rhs),
|
||||
ctx.current_loc,
|
||||
is_aug_assign,
|
||||
variant,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(
|
||||
|
@ -1267,7 +1273,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
generator,
|
||||
ctx,
|
||||
ndarray_dtype,
|
||||
if is_aug_assign { Some(ndarray_val) } else { None },
|
||||
match variant {
|
||||
BinOpVariant::Normal => None,
|
||||
BinOpVariant::AugAssign => Some(ndarray_val),
|
||||
},
|
||||
(left_val, !is_ndarray1),
|
||||
(right_val, !is_ndarray2),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
|
@ -1278,7 +1287,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
op,
|
||||
(&Some(ndarray_dtype), rhs),
|
||||
ctx.current_loc,
|
||||
is_aug_assign,
|
||||
variant,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||
|
@ -1293,13 +1302,15 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
unreachable!("must be tobj")
|
||||
};
|
||||
let (op_name, id) = {
|
||||
let (binop_name, binop_assign_name) =
|
||||
(binop_name(op).into(), binop_assign_name(op).into());
|
||||
let normal_method_name = OpInfo::from_binop(op, BinOpVariant::Normal).method_name;
|
||||
let assign_method_name = OpInfo::from_binop(op, BinOpVariant::AugAssign).method_name;
|
||||
|
||||
// if is aug_assign, try aug_assign operator first
|
||||
if is_aug_assign && fields.contains_key(&binop_assign_name) {
|
||||
(binop_assign_name, *obj_id)
|
||||
if variant == BinOpVariant::AugAssign && fields.contains_key(&assign_method_name.into())
|
||||
{
|
||||
(assign_method_name.into(), *obj_id)
|
||||
} else {
|
||||
(binop_name, *obj_id)
|
||||
(normal_method_name.into(), *obj_id)
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1349,7 +1360,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
|||
op: Operator,
|
||||
right: &Expr<Option<Type>>,
|
||||
loc: Location,
|
||||
is_aug_assign: bool,
|
||||
variant: BinOpVariant,
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
let left_val = if let Some(v) = generator.gen_expr(ctx, left)? {
|
||||
v.to_basic_value_enum(ctx, generator, left.custom.unwrap())?
|
||||
|
@ -1369,7 +1380,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
|||
op,
|
||||
(&right.custom, right_val),
|
||||
loc,
|
||||
is_aug_assign,
|
||||
variant,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -1453,7 +1464,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
if op == ast::Unaryop::Invert {
|
||||
ast::Unaryop::Not
|
||||
} else {
|
||||
unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op))
|
||||
unreachable!(
|
||||
"ufunc {} not supported for ndarray[bool, N]",
|
||||
OpInfo::from_unaryop(op).method_name
|
||||
)
|
||||
}
|
||||
} else {
|
||||
op
|
||||
|
@ -2343,7 +2357,15 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||
}
|
||||
}
|
||||
ExprKind::BinOp { op, left, right } => {
|
||||
return gen_binop_expr(generator, ctx, left, *op, right, expr.location, false);
|
||||
return gen_binop_expr(
|
||||
generator,
|
||||
ctx,
|
||||
left,
|
||||
*op,
|
||||
right,
|
||||
expr.location,
|
||||
BinOpVariant::Normal,
|
||||
);
|
||||
}
|
||||
ExprKind::UnaryOp { op, operand } => return gen_unaryop_expr(generator, ctx, *op, operand),
|
||||
ExprKind::Compare { left, ops, comparators } => {
|
||||
|
|
|
@ -11,8 +11,7 @@ use crate::{
|
|||
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
|
||||
call_ndarray_calc_size,
|
||||
},
|
||||
llvm_intrinsics,
|
||||
llvm_intrinsics::call_memcpy_generic,
|
||||
llvm_intrinsics::{self, call_memcpy_generic},
|
||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
|
@ -22,7 +21,10 @@ use crate::{
|
|||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
DefinitionId,
|
||||
},
|
||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||
typecheck::{
|
||||
magic_methods::BinOpVariant,
|
||||
typedef::{FunSignature, Type, TypeEnum},
|
||||
},
|
||||
};
|
||||
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
||||
use inkwell::{
|
||||
|
@ -1632,7 +1634,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|||
Operator::Mult,
|
||||
(&Some(elem_ty), b),
|
||||
ctx.current_loc,
|
||||
false,
|
||||
BinOpVariant::Normal,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||
|
@ -1645,7 +1647,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|||
Operator::Add,
|
||||
(&Some(elem_ty), a_mul_b),
|
||||
ctx.current_loc,
|
||||
false,
|
||||
BinOpVariant::Normal,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||
|
|
|
@ -11,7 +11,10 @@ use crate::{
|
|||
gen_in_range_check,
|
||||
},
|
||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||
typecheck::{
|
||||
magic_methods::BinOpVariant,
|
||||
typedef::{FunSignature, Type, TypeEnum},
|
||||
},
|
||||
};
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
|
@ -1574,7 +1577,15 @@ pub fn gen_stmt<G: CodeGenerator>(
|
|||
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
|
||||
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
|
||||
StmtKind::AugAssign { target, op, value, .. } => {
|
||||
let value = gen_binop_expr(generator, ctx, target, *op, value, stmt.location, true)?;
|
||||
let value = gen_binop_expr(
|
||||
generator,
|
||||
ctx,
|
||||
target,
|
||||
*op,
|
||||
value,
|
||||
stmt.location,
|
||||
BinOpVariant::AugAssign,
|
||||
)?;
|
||||
generator.gen_assign(ctx, target, value.unwrap())?;
|
||||
}
|
||||
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::typecheck::{
|
|||
type_inferencer::*,
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use itertools::{iproduct, Itertools};
|
||||
use nac3parser::ast::StrRef;
|
||||
use nac3parser::ast::{Cmpop, Operator, Unaryop};
|
||||
use std::cmp::max;
|
||||
|
@ -13,64 +13,93 @@ use std::collections::HashMap;
|
|||
use std::rc::Rc;
|
||||
use strum::IntoEnumIterator;
|
||||
|
||||
#[must_use]
|
||||
pub fn binop_name(op: Operator) -> &'static str {
|
||||
match op {
|
||||
Operator::Add => "__add__",
|
||||
Operator::Sub => "__sub__",
|
||||
Operator::Div => "__truediv__",
|
||||
Operator::Mod => "__mod__",
|
||||
Operator::Mult => "__mul__",
|
||||
Operator::Pow => "__pow__",
|
||||
Operator::BitOr => "__or__",
|
||||
Operator::BitXor => "__xor__",
|
||||
Operator::BitAnd => "__and__",
|
||||
Operator::LShift => "__lshift__",
|
||||
Operator::RShift => "__rshift__",
|
||||
Operator::FloorDiv => "__floordiv__",
|
||||
Operator::MatMult => "__matmul__",
|
||||
}
|
||||
/// Details about an operator (unary, binary, etc...) in Python
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct OpInfo {
|
||||
/// The method name of the binary operator.
|
||||
/// For addition, this would be `__add__`, and `__iadd__` if
|
||||
/// it is the augmented assigning variant.
|
||||
pub method_name: &'static str,
|
||||
/// The symbol of the binary operator.
|
||||
/// For addition, this would be `+`, and `+=` if
|
||||
/// it is the augmented assigning variant.
|
||||
pub symbol: &'static str,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn binop_assign_name(op: Operator) -> &'static str {
|
||||
match op {
|
||||
Operator::Add => "__iadd__",
|
||||
Operator::Sub => "__isub__",
|
||||
Operator::Div => "__itruediv__",
|
||||
Operator::Mod => "__imod__",
|
||||
Operator::Mult => "__imul__",
|
||||
Operator::Pow => "__ipow__",
|
||||
Operator::BitOr => "__ior__",
|
||||
Operator::BitXor => "__ixor__",
|
||||
Operator::BitAnd => "__iand__",
|
||||
Operator::LShift => "__ilshift__",
|
||||
Operator::RShift => "__irshift__",
|
||||
Operator::FloorDiv => "__ifloordiv__",
|
||||
Operator::MatMult => "__imatmul__",
|
||||
}
|
||||
/// Helper macro to conveniently build an [`OpInfo`].
|
||||
///
|
||||
/// Example usage: `make_info("add", "+")` generates `OpInfo { name: "__add__", symbol: "+" }`
|
||||
macro_rules! make_info {
|
||||
($name:expr, $symbol:expr) => {
|
||||
OpInfo { method_name: concat!("__", $name, "__"), symbol: $symbol }
|
||||
};
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn unaryop_name(op: Unaryop) -> &'static str {
|
||||
match op {
|
||||
Unaryop::UAdd => "__pos__",
|
||||
Unaryop::USub => "__neg__",
|
||||
Unaryop::Not => "__not__",
|
||||
Unaryop::Invert => "__inv__",
|
||||
}
|
||||
/// The variant of a binary operator.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum BinOpVariant {
|
||||
/// The normal variant.
|
||||
/// For addition, it would be `+`.
|
||||
Normal,
|
||||
/// The "Augmented Assigning Operator" variant.
|
||||
/// For addition, it would be `+=`.
|
||||
AugAssign,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn comparison_name(op: Cmpop) -> Option<&'static str> {
|
||||
match op {
|
||||
Cmpop::Lt => Some("__lt__"),
|
||||
Cmpop::LtE => Some("__le__"),
|
||||
Cmpop::Gt => Some("__gt__"),
|
||||
Cmpop::GtE => Some("__ge__"),
|
||||
Cmpop::Eq => Some("__eq__"),
|
||||
Cmpop::NotEq => Some("__ne__"),
|
||||
_ => None,
|
||||
impl OpInfo {
|
||||
#[must_use]
|
||||
pub fn from_binop(op: Operator, variant: BinOpVariant) -> Self {
|
||||
// Helper macro to generate both the normal variant [`OpInfo`] and the
|
||||
// augmented assigning variant [`OpInfo`] for a binary operator conveniently.
|
||||
macro_rules! info {
|
||||
($name:literal, $symbol:literal) => {
|
||||
(make_info!($name, $symbol), make_info!(concat!("i", $name), concat!($symbol, "=")))
|
||||
};
|
||||
}
|
||||
|
||||
let (normal_variant, aug_assign_variant) = match op {
|
||||
Operator::Add => info!("add", "+"),
|
||||
Operator::Sub => info!("sub", "-"),
|
||||
Operator::Div => info!("truediv", "/"),
|
||||
Operator::Mod => info!("mod", "%"),
|
||||
Operator::Mult => info!("mul", "*"),
|
||||
Operator::Pow => info!("pow", "**"),
|
||||
Operator::BitOr => info!("or", "|"),
|
||||
Operator::BitXor => info!("xor", "^"),
|
||||
Operator::BitAnd => info!("and", "&"),
|
||||
Operator::LShift => info!("lshift", "<<"),
|
||||
Operator::RShift => info!("rshift", ">>"),
|
||||
Operator::FloorDiv => info!("floordiv", "//"),
|
||||
Operator::MatMult => info!("matmul", "@"),
|
||||
};
|
||||
|
||||
match variant {
|
||||
BinOpVariant::Normal => normal_variant,
|
||||
BinOpVariant::AugAssign => aug_assign_variant,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn from_unaryop(op: Unaryop) -> Self {
|
||||
match op {
|
||||
Unaryop::UAdd => make_info!("pos", "+"),
|
||||
Unaryop::USub => make_info!("neg", "-"),
|
||||
Unaryop::Not => make_info!("not", "not"), // i.e., `not False`, so the symbol is just `not`.
|
||||
Unaryop::Invert => make_info!("inv", "~"),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn from_cmpop(op: Cmpop) -> Option<Self> {
|
||||
match op {
|
||||
Cmpop::Lt => Some(make_info!("lt", "<")),
|
||||
Cmpop::LtE => Some(make_info!("le", "<=")),
|
||||
Cmpop::Gt => Some(make_info!("gt", ">")),
|
||||
Cmpop::GtE => Some(make_info!("ge", ">=")),
|
||||
Cmpop::Eq => Some(make_info!("eq", "==")),
|
||||
Cmpop::NotEq => Some(make_info!("ne", "!=")),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,23 +144,8 @@ pub fn impl_binop(
|
|||
|
||||
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
|
||||
|
||||
for op in ops {
|
||||
fields.insert(binop_name(*op).into(), {
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
ret: ret_ty,
|
||||
vars: function_vars.clone(),
|
||||
args: vec![FuncArg {
|
||||
ty: other_ty,
|
||||
default_value: None,
|
||||
name: "other".into(),
|
||||
}],
|
||||
})),
|
||||
false,
|
||||
)
|
||||
});
|
||||
|
||||
fields.insert(binop_assign_name(*op).into(), {
|
||||
for (op, variant) in iproduct!(ops, [BinOpVariant::Normal, BinOpVariant::AugAssign]) {
|
||||
fields.insert(OpInfo::from_binop(*op, variant).method_name.into(), {
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
ret: ret_ty,
|
||||
|
@ -155,7 +169,7 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops:
|
|||
|
||||
for op in ops {
|
||||
fields.insert(
|
||||
unaryop_name(*op).into(),
|
||||
OpInfo::from_unaryop(*op).method_name.into(),
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
ret: ret_ty,
|
||||
|
@ -195,7 +209,7 @@ pub fn impl_cmpop(
|
|||
|
||||
for op in ops {
|
||||
fields.insert(
|
||||
comparison_name(*op).unwrap().into(),
|
||||
OpInfo::from_cmpop(*op).unwrap().method_name.into(),
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
ret: ret_ty,
|
||||
|
|
|
@ -1,24 +1,47 @@
|
|||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
|
||||
use crate::typecheck::typedef::TypeEnum;
|
||||
use crate::typecheck::{magic_methods::OpInfo, typedef::TypeEnum};
|
||||
|
||||
use super::typedef::{RecordKey, Type, Unifier};
|
||||
use nac3parser::ast::{Location, StrRef};
|
||||
use super::{
|
||||
magic_methods::BinOpVariant,
|
||||
typedef::{RecordKey, Type, Unifier},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use nac3parser::ast::{Cmpop, Location, Operator, StrRef};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TypeErrorKind {
|
||||
TooManyArguments {
|
||||
expected: usize,
|
||||
got: usize,
|
||||
GotMultipleValues {
|
||||
name: StrRef,
|
||||
},
|
||||
TooManyArguments {
|
||||
expected_min_count: usize,
|
||||
expected_max_count: usize,
|
||||
got_count: usize,
|
||||
},
|
||||
MissingArgs {
|
||||
missing_arg_names: Vec<StrRef>,
|
||||
},
|
||||
MissingArgs(String),
|
||||
UnknownArgName(StrRef),
|
||||
IncorrectArgType {
|
||||
name: StrRef,
|
||||
expected: Type,
|
||||
got: Type,
|
||||
},
|
||||
UnsupportedBinaryOpTypes {
|
||||
operator: Operator,
|
||||
variant: BinOpVariant,
|
||||
lhs_type: Type,
|
||||
rhs_type: Type,
|
||||
expected_rhs_type: Type,
|
||||
},
|
||||
UnsupportedComparsionOpTypes {
|
||||
operator: Cmpop,
|
||||
lhs_type: Type,
|
||||
rhs_type: Type,
|
||||
expected_rhs_type: Type,
|
||||
},
|
||||
FieldUnificationError {
|
||||
field: RecordKey,
|
||||
types: (Type, Type),
|
||||
|
@ -78,19 +101,55 @@ impl<'a> Display for DisplayTypeError<'a> {
|
|||
use TypeErrorKind::*;
|
||||
let mut notes = Some(HashMap::new());
|
||||
match &self.err.kind {
|
||||
TooManyArguments { expected, got } => {
|
||||
write!(f, "Too many arguments. Expected {expected} but got {got}")
|
||||
GotMultipleValues { name } => {
|
||||
write!(f, "For multiple values for parameter {name}")
|
||||
}
|
||||
MissingArgs(args) => {
|
||||
TooManyArguments { expected_min_count, expected_max_count, got_count } => {
|
||||
debug_assert!(expected_min_count <= expected_max_count);
|
||||
if expected_min_count == expected_max_count {
|
||||
let expected_count = expected_min_count; // or expected_max_count
|
||||
write!(f, "Too many arguments. Expected {expected_count} but got {got_count}")
|
||||
} else {
|
||||
write!(f, "Too many arguments. Expected {expected_min_count} to {expected_max_count} arguments but got {got_count}")
|
||||
}
|
||||
}
|
||||
MissingArgs { missing_arg_names } => {
|
||||
let args = missing_arg_names.iter().join(", ");
|
||||
write!(f, "Missing arguments: {args}")
|
||||
}
|
||||
UnsupportedBinaryOpTypes {
|
||||
operator,
|
||||
variant,
|
||||
lhs_type,
|
||||
rhs_type,
|
||||
expected_rhs_type,
|
||||
} => {
|
||||
let op_symbol = OpInfo::from_binop(*operator, *variant).symbol;
|
||||
|
||||
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
|
||||
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
|
||||
let expected_rhs_type_str =
|
||||
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
|
||||
|
||||
write!(f, "Unsupported operand type(s) for {op_symbol}: '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
|
||||
}
|
||||
UnsupportedComparsionOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
|
||||
let op_symbol = OpInfo::from_cmpop(*operator).unwrap().symbol;
|
||||
|
||||
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
|
||||
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
|
||||
let expected_rhs_type_str =
|
||||
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
|
||||
|
||||
write!(f, "'{op_symbol}' not supported between instances of '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
|
||||
}
|
||||
UnknownArgName(name) => {
|
||||
write!(f, "Unknown argument name: {name}")
|
||||
}
|
||||
IncorrectArgType { name, expected, got } => {
|
||||
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
|
||||
let got = self.unifier.stringify_with_notes(*got, &mut notes);
|
||||
write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}")
|
||||
write!(f, "Incorrect argument type for parameter {name}. Expected {expected}, but got {got}")
|
||||
}
|
||||
FieldUnificationError { field, types, loc } => {
|
||||
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
|
||||
|
|
|
@ -4,7 +4,9 @@ use std::iter::once;
|
|||
use std::ops::Not;
|
||||
use std::{cell::RefCell, sync::Arc};
|
||||
|
||||
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
|
||||
use super::typedef::{
|
||||
Call, CallInfo, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap,
|
||||
};
|
||||
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
||||
use crate::toplevel::TopLevelDef;
|
||||
use crate::{
|
||||
|
@ -466,7 +468,8 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
|||
(None, None) => {}
|
||||
},
|
||||
ast::StmtKind::AugAssign { target, op, value, .. } => {
|
||||
let res_ty = self.infer_bin_ops(stmt.location, target, *op, value, true)?;
|
||||
let res_ty =
|
||||
self.infer_bin_ops(stmt.location, target, *op, value, BinOpVariant::AugAssign)?;
|
||||
self.unify(res_ty, target.custom.unwrap(), &stmt.location)?;
|
||||
}
|
||||
ast::StmtKind::Assert { test, msg, .. } => {
|
||||
|
@ -548,7 +551,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
|||
}
|
||||
ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
|
||||
ExprKind::BinOp { left, op, right } => {
|
||||
Some(self.infer_bin_ops(expr.location, left, *op, right, false)?)
|
||||
Some(self.infer_bin_ops(expr.location, left, *op, right, BinOpVariant::Normal)?)
|
||||
}
|
||||
ExprKind::UnaryOp { op, operand } => {
|
||||
Some(self.infer_unary_ops(expr.location, *op, operand)?)
|
||||
|
@ -615,6 +618,7 @@ impl<'a> Inferencer<'a> {
|
|||
obj: Type,
|
||||
params: Vec<Type>,
|
||||
ret: Option<Type>,
|
||||
call_info: CallInfo,
|
||||
) -> InferenceResult {
|
||||
if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) {
|
||||
if class_params.is_empty() {
|
||||
|
@ -628,6 +632,7 @@ impl<'a> Inferencer<'a> {
|
|||
ret: sign.ret,
|
||||
fun: RefCell::new(None),
|
||||
loc: Some(location),
|
||||
info: call_info,
|
||||
};
|
||||
if let Some(ret) = ret {
|
||||
self.unifier
|
||||
|
@ -642,14 +647,7 @@ impl<'a> Inferencer<'a> {
|
|||
})
|
||||
.unwrap();
|
||||
}
|
||||
let required: Vec<_> = sign
|
||||
.args
|
||||
.iter()
|
||||
.filter(|v| v.default_value.is_none())
|
||||
.map(|v| v.name)
|
||||
.rev()
|
||||
.collect();
|
||||
self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| {
|
||||
self.unifier.unify_call(&call, ty, sign).map_err(|e| {
|
||||
HashSet::from([e
|
||||
.at(Some(location))
|
||||
.to_display(self.unifier)
|
||||
|
@ -669,6 +667,7 @@ impl<'a> Inferencer<'a> {
|
|||
ret,
|
||||
fun: RefCell::new(None),
|
||||
loc: Some(location),
|
||||
info: call_info,
|
||||
});
|
||||
self.calls.insert(location.into(), call);
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
|
||||
|
@ -1346,17 +1345,11 @@ impl<'a> Inferencer<'a> {
|
|||
fun: RefCell::new(None),
|
||||
ret: sign.ret,
|
||||
loc: Some(location),
|
||||
info: CallInfo::IsNormalFunctionCall,
|
||||
};
|
||||
let required: Vec<_> = sign
|
||||
.args
|
||||
.iter()
|
||||
.filter(|v| v.default_value.is_none())
|
||||
.map(|v| v.name)
|
||||
.rev()
|
||||
.collect();
|
||||
self.unifier.unify_call(&call, func.custom.unwrap(), sign, &required).map_err(
|
||||
|e| HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]),
|
||||
)?;
|
||||
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
||||
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
||||
})?;
|
||||
return Ok(Located {
|
||||
location,
|
||||
custom: Some(sign.ret),
|
||||
|
@ -1375,6 +1368,7 @@ impl<'a> Inferencer<'a> {
|
|||
fun: RefCell::new(None),
|
||||
ret,
|
||||
loc: Some(location),
|
||||
info: CallInfo::IsNormalFunctionCall,
|
||||
});
|
||||
self.calls.insert(location.into(), call);
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
|
||||
|
@ -1550,7 +1544,7 @@ impl<'a> Inferencer<'a> {
|
|||
left: &ast::Expr<Option<Type>>,
|
||||
op: ast::Operator,
|
||||
right: &ast::Expr<Option<Type>>,
|
||||
is_aug_assign: bool,
|
||||
variant: BinOpVariant,
|
||||
) -> InferenceResult {
|
||||
let left_ty = left.custom.unwrap();
|
||||
let right_ty = right.custom.unwrap();
|
||||
|
@ -1558,27 +1552,39 @@ impl<'a> Inferencer<'a> {
|
|||
let method = if let TypeEnum::TObj { fields, .. } =
|
||||
self.unifier.get_ty_immutable(left_ty).as_ref()
|
||||
{
|
||||
let (binop_name, binop_assign_name) =
|
||||
(binop_name(op).into(), binop_assign_name(op).into());
|
||||
let normal_method_name = OpInfo::from_binop(op, BinOpVariant::Normal).method_name;
|
||||
let assign_method_name = OpInfo::from_binop(op, BinOpVariant::AugAssign).method_name;
|
||||
|
||||
// if is aug_assign, try aug_assign operator first
|
||||
if is_aug_assign && fields.contains_key(&binop_assign_name) {
|
||||
binop_assign_name
|
||||
if variant == BinOpVariant::AugAssign && fields.contains_key(&assign_method_name.into())
|
||||
{
|
||||
assign_method_name
|
||||
} else {
|
||||
binop_name
|
||||
normal_method_name
|
||||
}
|
||||
} else {
|
||||
binop_name(op).into()
|
||||
OpInfo::from_binop(op, variant).method_name
|
||||
};
|
||||
|
||||
let ret = if is_aug_assign {
|
||||
// The type of augmented assignment operator should never change
|
||||
Some(left_ty)
|
||||
} else {
|
||||
typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty)
|
||||
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
|
||||
let ret = match variant {
|
||||
BinOpVariant::Normal => {
|
||||
typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty)
|
||||
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
|
||||
}
|
||||
BinOpVariant::AugAssign => {
|
||||
// The type of augmented assignment operator should never change
|
||||
Some(left_ty)
|
||||
}
|
||||
};
|
||||
|
||||
self.build_method_call(location, method, left_ty, vec![right_ty], ret)
|
||||
self.build_method_call(
|
||||
location,
|
||||
method.into(),
|
||||
left_ty,
|
||||
vec![right_ty],
|
||||
ret,
|
||||
CallInfo::IsBinaryOp { self_type: left.custom.unwrap(), operator: op, variant },
|
||||
)
|
||||
}
|
||||
|
||||
fn infer_unary_ops(
|
||||
|
@ -1587,12 +1593,19 @@ impl<'a> Inferencer<'a> {
|
|||
op: ast::Unaryop,
|
||||
operand: &ast::Expr<Option<Type>>,
|
||||
) -> InferenceResult {
|
||||
let method = unaryop_name(op).into();
|
||||
let method = OpInfo::from_unaryop(op).method_name.into();
|
||||
|
||||
let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap())
|
||||
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
|
||||
|
||||
self.build_method_call(location, method, operand.custom.unwrap(), vec![], ret)
|
||||
self.build_method_call(
|
||||
location,
|
||||
method,
|
||||
operand.custom.unwrap(),
|
||||
vec![],
|
||||
ret,
|
||||
CallInfo::IsUnaryOp { self_type: operand.custom.unwrap(), operator: op },
|
||||
)
|
||||
}
|
||||
|
||||
fn infer_compare(
|
||||
|
@ -1617,8 +1630,9 @@ impl<'a> Inferencer<'a> {
|
|||
|
||||
let mut res = None;
|
||||
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
||||
let method = comparison_name(*c)
|
||||
let method = OpInfo::from_cmpop(*c)
|
||||
.ok_or_else(|| HashSet::from(["unsupported comparator".to_string()]))?
|
||||
.method_name
|
||||
.into();
|
||||
|
||||
let ret = typeof_cmpop(
|
||||
|
@ -1636,6 +1650,7 @@ impl<'a> Inferencer<'a> {
|
|||
a.custom.unwrap(),
|
||||
vec![b.custom.unwrap()],
|
||||
ret,
|
||||
CallInfo::IsComparisonOp { self_type: left.custom.unwrap(), operator: *c },
|
||||
)?);
|
||||
}
|
||||
|
||||
|
|
|
@ -8,12 +8,14 @@ use std::rc::Rc;
|
|||
use std::sync::{Arc, Mutex};
|
||||
use std::{borrow::Cow, collections::HashSet};
|
||||
|
||||
use nac3parser::ast::{Location, StrRef};
|
||||
use nac3parser::ast::{Cmpop, Location, Operator, StrRef, Unaryop};
|
||||
|
||||
use super::magic_methods::BinOpVariant;
|
||||
use super::type_error::{TypeError, TypeErrorKind};
|
||||
use super::unification_table::{UnificationKey, UnificationTable};
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
|
||||
use crate::typecheck::magic_methods::OpInfo;
|
||||
use crate::typecheck::type_inferencer::PrimitiveStore;
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -73,6 +75,32 @@ pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator<Item = TypeVar> + '_ {
|
|||
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty })
|
||||
}
|
||||
|
||||
/// Extra details about how a [`Call`] was written by the user.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CallInfo {
|
||||
/// The call was written as an unary operation, e.g., `~a` or `not a`.
|
||||
IsUnaryOp {
|
||||
/// The [`Type`] of the `self` object
|
||||
self_type: Type,
|
||||
operator: Unaryop,
|
||||
},
|
||||
/// The call was written as a binary operation, e.g., `a + b` or `a += b`.
|
||||
IsBinaryOp {
|
||||
/// The [`Type`] of the `self` object
|
||||
self_type: Type,
|
||||
operator: Operator,
|
||||
variant: BinOpVariant,
|
||||
},
|
||||
/// The call was written as a binary comparison operation, e.g., `a < b`.
|
||||
IsComparisonOp {
|
||||
/// The [`Type`] of the `self` object
|
||||
self_type: Type,
|
||||
operator: Cmpop,
|
||||
},
|
||||
/// "Normal" function calls that looks like `func(1, 2, 3)`.
|
||||
IsNormalFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Call {
|
||||
pub posargs: Vec<Type>,
|
||||
|
@ -80,6 +108,7 @@ pub struct Call {
|
|||
pub ret: Type,
|
||||
pub fun: RefCell<Option<Type>>,
|
||||
pub loc: Option<Location>,
|
||||
pub info: CallInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -89,6 +118,13 @@ pub struct FuncArg {
|
|||
pub default_value: Option<SymbolValue>,
|
||||
}
|
||||
|
||||
impl FuncArg {
|
||||
#[must_use]
|
||||
pub fn is_required(&self) -> bool {
|
||||
self.default_value.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FunSignature {
|
||||
pub args: Vec<FuncArg>,
|
||||
|
@ -562,69 +598,230 @@ impl Unifier {
|
|||
call: &Call,
|
||||
b: Type,
|
||||
signature: &FunSignature,
|
||||
required: &[StrRef],
|
||||
) -> Result<(), TypeError> {
|
||||
/*
|
||||
NOTE: scenarios to consider:
|
||||
|
||||
```python
|
||||
def func1(x: int32, y: int32, z: int32 = 5): pass
|
||||
|
||||
# Normal scenarios
|
||||
func1(23, 45) # OK, z has default
|
||||
func1(23, 45, 67) # OK, z's default is overwritten
|
||||
func1(x = 23, y = 45) # OK, user is using kwargs to set positional args
|
||||
func1(y = 45, x = 23) # OK, kwargs order doesn't matter
|
||||
|
||||
# Error scenarios
|
||||
func1() # ERROR: Missing arguments: x, y
|
||||
func1(23) # ERROR: Missing arguments: y
|
||||
func1(z = 23) # ERROR: Missing arguments: x, y
|
||||
func1(x = 23) # ERROR: Missing arguments: y
|
||||
func1(23, 45, x = 5) # ERROR: Got multiple values for x
|
||||
func1(23, 45, x = 5, y = 6) # ERROR: Got multiple values for x (y too but Python does not report it)
|
||||
func1(23, 45, 67, z = 89) # ERROR: Got multiple values for z
|
||||
func1(23, 45, 67, 89) # ERROR: Function only takes from 2 to 3 positional arguments but 4 were given.
|
||||
func1(23, 45, 67, w = 3) # ERROR: Got an unexpected keyword argument 'w'
|
||||
|
||||
# Error scenarios that do not need to be handled here.
|
||||
func1(23, 45, z = 67, z = 89) # ERROR: Keyword argument repeated: z, the parser panics on this.
|
||||
```
|
||||
*/
|
||||
|
||||
struct ParamInfo<'a> {
|
||||
/// Has this parameter been supplied with an argument already?
|
||||
has_been_supplied: bool,
|
||||
/// The corresponding [`FuncArg`] instance of this parameter (for fast table lookups)
|
||||
param: &'a FuncArg,
|
||||
}
|
||||
|
||||
let snapshot = self.unification_table.get_snapshot();
|
||||
if self.snapshot.is_none() {
|
||||
self.snapshot = Some(snapshot);
|
||||
}
|
||||
|
||||
let Call { posargs, kwargs, ret, fun, loc } = call;
|
||||
let instantiated = self.instantiate_fun(b, signature);
|
||||
let r = self.get_ty(instantiated);
|
||||
let r = r.as_ref();
|
||||
let TypeEnum::TFunc(signature) = r else { unreachable!() };
|
||||
// we check to make sure that all required arguments (those without default
|
||||
// arguments) are provided, and do not provide the same argument twice.
|
||||
let mut required = required.to_vec();
|
||||
let mut all_names: Vec<_> = signature.args.iter().map(|v| (v.name, v.ty)).rev().collect();
|
||||
for (i, t) in posargs.iter().enumerate() {
|
||||
if signature.args.len() <= i {
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::TooManyArguments {
|
||||
expected: signature.args.len(),
|
||||
got: posargs.len() + kwargs.len(),
|
||||
},
|
||||
*loc,
|
||||
));
|
||||
// Get details about the function signature/parameters.
|
||||
let num_params = signature.args.len();
|
||||
|
||||
// Force the type vars in `b` and `signature' to be up-to-date.
|
||||
let b = self.instantiate_fun(b, signature);
|
||||
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
|
||||
|
||||
// Get details about the input arguments
|
||||
let Call { posargs, kwargs, ret, fun, loc, info: call_info } = call;
|
||||
let num_args = posargs.len() + kwargs.len();
|
||||
|
||||
// Now we check the arguments against the parameters,
|
||||
// and depending on what `call_info` is, we might change how the behavior `unify_call()`
|
||||
// in hopes to improve user error messages when type checking fails.
|
||||
match call_info {
|
||||
CallInfo::IsBinaryOp { self_type, operator, variant } => {
|
||||
// The call is written in the form of (say) `a + b`.
|
||||
// Technically, it is `a.__add__(b)`, and they have the following constraints:
|
||||
assert_eq!(posargs.len(), 1);
|
||||
assert_eq!(kwargs.len(), 0);
|
||||
assert_eq!(num_params, 1);
|
||||
|
||||
let other_type = posargs[0]; // the second operand
|
||||
let expected_other_type = signature.args[0].ty;
|
||||
|
||||
let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
|
||||
if !ok {
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::UnsupportedBinaryOpTypes {
|
||||
operator: *operator,
|
||||
variant: *variant,
|
||||
lhs_type: *self_type,
|
||||
rhs_type: other_type,
|
||||
expected_rhs_type: expected_other_type,
|
||||
},
|
||||
*loc,
|
||||
));
|
||||
}
|
||||
}
|
||||
required.pop();
|
||||
let (name, expected) = all_names.pop().unwrap();
|
||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
||||
self.restore_snapshot();
|
||||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
||||
})?;
|
||||
}
|
||||
for (k, t) in kwargs {
|
||||
if let Some(i) = required.iter().position(|v| v == k) {
|
||||
required.remove(i);
|
||||
CallInfo::IsComparisonOp { self_type, operator }
|
||||
if OpInfo::from_cmpop(*operator).is_some() // Otherwise that comparison operator is not supported.
|
||||
=>
|
||||
{
|
||||
// The call is written in the form of (say) `a <= b`.
|
||||
// Technically, it is `a.__le__(b)`, and they have the following constraints:
|
||||
assert_eq!(posargs.len(), 1);
|
||||
assert_eq!(kwargs.len(), 0);
|
||||
assert_eq!(num_params, 1);
|
||||
|
||||
let other_type = posargs[0]; // the second operand
|
||||
let expected_other_type = signature.args[0].ty;
|
||||
|
||||
let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
|
||||
if !ok {
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::UnsupportedComparsionOpTypes {
|
||||
operator: *operator,
|
||||
lhs_type: *self_type,
|
||||
rhs_type: other_type,
|
||||
expected_rhs_type: expected_other_type,
|
||||
},
|
||||
*loc,
|
||||
));
|
||||
}
|
||||
}
|
||||
let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| {
|
||||
self.restore_snapshot();
|
||||
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
|
||||
})?;
|
||||
let (name, expected) = all_names.remove(i);
|
||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
||||
self.restore_snapshot();
|
||||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
||||
})?;
|
||||
}
|
||||
if !required.is_empty() {
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::MissingArgs(required.iter().join(", ")),
|
||||
*loc,
|
||||
));
|
||||
}
|
||||
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
|
||||
self.restore_snapshot();
|
||||
if err.loc.is_none() {
|
||||
err.loc = *loc;
|
||||
_ => {
|
||||
// Handle [`CallInfo::IsNormalFunctionCall`] and other uninteresting variants
|
||||
// of [`CallInfo`] (e.g, `CallInfo::IsUnaryOp` and unsupported comparison operators)
|
||||
|
||||
// Helper lambdas
|
||||
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
|
||||
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
|
||||
if ok {
|
||||
Ok(())
|
||||
} else {
|
||||
// Typecheck failed, throw an error.
|
||||
self.restore_snapshot();
|
||||
Err(TypeError::new(
|
||||
TypeErrorKind::IncorrectArgType {
|
||||
name: param_name,
|
||||
expected: expected_arg_ty,
|
||||
got: arg_ty,
|
||||
},
|
||||
*loc,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Check for "too many arguments"
|
||||
if num_params < posargs.len() {
|
||||
let expected_min_count =
|
||||
signature.args.iter().filter(|param| param.is_required()).count();
|
||||
let expected_max_count = num_params;
|
||||
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::TooManyArguments {
|
||||
expected_min_count,
|
||||
expected_max_count,
|
||||
got_count: num_args,
|
||||
},
|
||||
*loc,
|
||||
));
|
||||
}
|
||||
|
||||
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
|
||||
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
|
||||
.collect();
|
||||
|
||||
// Now consume all positional arguments and typecheck them.
|
||||
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
|
||||
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
|
||||
let param_info = param_info_by_name.get_mut(¶m.name).unwrap();
|
||||
param_info.has_been_supplied = true;
|
||||
|
||||
// Typecheck
|
||||
type_check_arg(param.name, param.ty, arg_ty)?;
|
||||
}
|
||||
|
||||
// Now consume all keyword arguments and typecheck them.
|
||||
for (¶m_name, &arg_ty) in kwargs {
|
||||
// We will also use this opportunity to check if this keyword argument is "legal".
|
||||
|
||||
let Some(param_info) = param_info_by_name.get_mut(¶m_name) else {
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::UnknownArgName(param_name),
|
||||
*loc,
|
||||
));
|
||||
};
|
||||
|
||||
if param_info.has_been_supplied {
|
||||
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
|
||||
// is IMPOSSIBLE as the parser would have already failed.
|
||||
// We only have to care about "got multiple values for XYZ"
|
||||
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::GotMultipleValues { name: param_name },
|
||||
*loc,
|
||||
));
|
||||
}
|
||||
|
||||
param_info.has_been_supplied = true;
|
||||
|
||||
// Typecheck
|
||||
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
|
||||
}
|
||||
|
||||
// After checking posargs and kwargs, check if there are any
|
||||
// unsupplied required parameters, and throw an error if they exist.
|
||||
let missing_arg_names = param_info_by_name
|
||||
.values()
|
||||
.filter(|param_info| {
|
||||
param_info.param.is_required() && !param_info.has_been_supplied
|
||||
})
|
||||
.map(|param_info| param_info.param.name)
|
||||
.collect_vec();
|
||||
if !missing_arg_names.is_empty() {
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::MissingArgs { missing_arg_names },
|
||||
*loc,
|
||||
));
|
||||
}
|
||||
|
||||
// Finally, check the Call's return type
|
||||
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
|
||||
self.restore_snapshot();
|
||||
if err.loc.is_none() {
|
||||
err.loc = *loc;
|
||||
}
|
||||
err
|
||||
})?;
|
||||
}
|
||||
err
|
||||
})?;
|
||||
*fun.borrow_mut() = Some(instantiated);
|
||||
}
|
||||
|
||||
*fun.borrow_mut() = Some(b);
|
||||
|
||||
self.discard_snapshot(snapshot);
|
||||
Ok(())
|
||||
|
@ -990,17 +1187,10 @@ impl Unifier {
|
|||
self.unification_table.set_value(b, Rc::new(TCall(calls)));
|
||||
}
|
||||
(TCall(calls), TFunc(signature)) => {
|
||||
let required: Vec<StrRef> = signature
|
||||
.args
|
||||
.iter()
|
||||
.filter(|v| v.default_value.is_none())
|
||||
.map(|v| v.name)
|
||||
.rev()
|
||||
.collect();
|
||||
// we unify every calls to the function signature.
|
||||
for c in calls {
|
||||
let call = self.calls[c.0].clone();
|
||||
self.unify_call(&call, b, signature, &required)?;
|
||||
self.unify_call(&call, b, signature)?;
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue