Compare commits

...

2 Commits

7 changed files with 459 additions and 223 deletions

View File

@ -21,7 +21,7 @@ use crate::{
DefinitionId, TopLevelDef, DefinitionId, TopLevelDef,
}, },
typecheck::{ typecheck::{
magic_methods::{binop_assign_name, binop_name, unaryop_name}, magic_methods::{Binop, BinopVariant, HasOpInfo},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
}, },
}; };
@ -1164,10 +1164,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
left: (&Option<Type>, BasicValueEnum<'ctx>), left: (&Option<Type>, BasicValueEnum<'ctx>),
op: Operator, op: Binop,
right: (&Option<Type>, BasicValueEnum<'ctx>), right: (&Option<Type>, BasicValueEnum<'ctx>),
loc: Location, loc: Location,
is_aug_assign: bool,
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
let (left_ty, left_val) = left; let (left_ty, left_val) = left;
let (right_ty, right_val) = right; let (right_ty, right_val) = right;
@ -1179,17 +1178,17 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
// which would be unchanged until further unification, which we would never do // which would be unchanged until further unification, which we would never do
// when doing code generation for function instances // when doing code generation for function instances
if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) {
Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, true).into())) Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, true).into()))
} else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) {
Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, false).into())) Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, false).into()))
} else if [Operator::LShift, Operator::RShift].contains(&op) { } else if [Operator::LShift, Operator::RShift].contains(&op.base) {
let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1); let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1);
Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, signed).into())) Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed).into()))
} else if ty1 == ty2 && ctx.primitives.float == ty1 { } else if ty1 == ty2 && ctx.primitives.float == ty1 {
Ok(Some(ctx.gen_float_ops(op, left_val, right_val).into())) Ok(Some(ctx.gen_float_ops(op.base, left_val, right_val).into()))
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
// Pow is the only operator that would pass typecheck between float and int // Pow is the only operator that would pass typecheck between float and int
assert_eq!(op, Operator::Pow); assert_eq!(op.base, Operator::Pow);
let res = call_float_powi( let res = call_float_powi(
ctx, ctx,
left_val.into_float_value(), left_val.into_float_value(),
@ -1216,13 +1215,16 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let right_val = let right_val =
NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
let res = if op == Operator::MatMult { let res = if op.base == Operator::MatMult {
// MatMult is the only binop which is not an elementwise op // MatMult is the only binop which is not an elementwise op
numpy::ndarray_matmul_2d( numpy::ndarray_matmul_2d(
generator, generator,
ctx, ctx,
ndarray_dtype1, ndarray_dtype1,
if is_aug_assign { Some(left_val) } else { None }, match op.variant {
BinopVariant::Normal => None,
BinopVariant::AugAssign => Some(left_val),
},
left_val, left_val,
right_val, right_val,
)? )?
@ -1231,7 +1233,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
generator, generator,
ctx, ctx,
ndarray_dtype1, ndarray_dtype1,
if is_aug_assign { Some(left_val) } else { None }, match op.variant {
BinopVariant::Normal => None,
BinopVariant::AugAssign => Some(left_val),
},
(left_val.as_base_value().into(), false), (left_val.as_base_value().into(), false),
(right_val.as_base_value().into(), false), (right_val.as_base_value().into(), false),
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
@ -1242,7 +1247,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
op, op,
(&Some(ndarray_dtype2), rhs), (&Some(ndarray_dtype2), rhs),
ctx.current_loc, ctx.current_loc,
is_aug_assign,
)? )?
.unwrap() .unwrap()
.to_basic_value_enum( .to_basic_value_enum(
@ -1267,7 +1271,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
generator, generator,
ctx, ctx,
ndarray_dtype, ndarray_dtype,
if is_aug_assign { Some(ndarray_val) } else { None }, match op.variant {
BinopVariant::Normal => None,
BinopVariant::AugAssign => Some(ndarray_val),
},
(left_val, !is_ndarray1), (left_val, !is_ndarray1),
(right_val, !is_ndarray2), (right_val, !is_ndarray2),
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
@ -1278,7 +1285,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
op, op,
(&Some(ndarray_dtype), rhs), (&Some(ndarray_dtype), rhs),
ctx.current_loc, ctx.current_loc,
is_aug_assign,
)? )?
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, ndarray_dtype) .to_basic_value_enum(ctx, generator, ndarray_dtype)
@ -1293,13 +1299,16 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
unreachable!("must be tobj") unreachable!("must be tobj")
}; };
let (op_name, id) = { let (op_name, id) = {
let (binop_name, binop_assign_name) = let normal_method_name = Binop::normal(op.base).op_info().method_name;
(binop_name(op).into(), binop_assign_name(op).into()); let assign_method_name = Binop::aug_assign(op.base).op_info().method_name;
// if is aug_assign, try aug_assign operator first // if is aug_assign, try aug_assign operator first
if is_aug_assign && fields.contains_key(&binop_assign_name) { if op.variant == BinopVariant::AugAssign
(binop_assign_name, *obj_id) && fields.contains_key(&assign_method_name.into())
{
(assign_method_name.into(), *obj_id)
} else { } else {
(binop_name, *obj_id) (normal_method_name.into(), *obj_id)
} }
}; };
@ -1346,10 +1355,9 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
left: &Expr<Option<Type>>, left: &Expr<Option<Type>>,
op: Operator, op: Binop,
right: &Expr<Option<Type>>, right: &Expr<Option<Type>>,
loc: Location, loc: Location,
is_aug_assign: bool,
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
let left_val = if let Some(v) = generator.gen_expr(ctx, left)? { let left_val = if let Some(v) = generator.gen_expr(ctx, left)? {
v.to_basic_value_enum(ctx, generator, left.custom.unwrap())? v.to_basic_value_enum(ctx, generator, left.custom.unwrap())?
@ -1369,7 +1377,6 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
op, op,
(&right.custom, right_val), (&right.custom, right_val),
loc, loc,
is_aug_assign,
) )
} }
@ -1453,7 +1460,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
if op == ast::Unaryop::Invert { if op == ast::Unaryop::Invert {
ast::Unaryop::Not ast::Unaryop::Not
} else { } else {
unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op)) unreachable!(
"ufunc {} not supported for ndarray[bool, N]",
op.op_info().method_name,
)
} }
} else { } else {
op op
@ -2343,7 +2353,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
} }
ExprKind::BinOp { op, left, right } => { ExprKind::BinOp { op, left, right } => {
return gen_binop_expr(generator, ctx, left, *op, right, expr.location, false); return gen_binop_expr(generator, ctx, left, Binop::normal(*op), right, expr.location);
} }
ExprKind::UnaryOp { op, operand } => return gen_unaryop_expr(generator, ctx, *op, operand), ExprKind::UnaryOp { op, operand } => return gen_unaryop_expr(generator, ctx, *op, operand),
ExprKind::Compare { left, ops, comparators } => { ExprKind::Compare { left, ops, comparators } => {

View File

@ -11,8 +11,7 @@ use crate::{
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
call_ndarray_calc_size, call_ndarray_calc_size,
}, },
llvm_intrinsics, llvm_intrinsics::{self, call_memcpy_generic},
llvm_intrinsics::call_memcpy_generic,
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
@ -22,7 +21,10 @@ use crate::{
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, DefinitionId,
}, },
typecheck::typedef::{FunSignature, Type, TypeEnum}, typecheck::{
magic_methods::Binop,
typedef::{FunSignature, Type, TypeEnum},
},
}; };
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{ use inkwell::{
@ -1677,10 +1679,9 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
generator, generator,
ctx, ctx,
(&Some(elem_ty), a), (&Some(elem_ty), a),
Operator::Mult, Binop::normal(Operator::Mult),
(&Some(elem_ty), b), (&Some(elem_ty), b),
ctx.current_loc, ctx.current_loc,
false,
)? )?
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?; .to_basic_value_enum(ctx, generator, elem_ty)?;
@ -1690,10 +1691,9 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
generator, generator,
ctx, ctx,
(&Some(elem_ty), result), (&Some(elem_ty), result),
Operator::Add, Binop::normal(Operator::Add),
(&Some(elem_ty), a_mul_b), (&Some(elem_ty), a_mul_b),
ctx.current_loc, ctx.current_loc,
false,
)? )?
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?; .to_basic_value_enum(ctx, generator, elem_ty)?;

View File

@ -11,7 +11,10 @@ use crate::{
gen_in_range_check, gen_in_range_check,
}, },
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum}, typecheck::{
magic_methods::Binop,
typedef::{FunSignature, Type, TypeEnum},
},
}; };
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -1574,7 +1577,14 @@ pub fn gen_stmt<G: CodeGenerator>(
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
StmtKind::AugAssign { target, op, value, .. } => { StmtKind::AugAssign { target, op, value, .. } => {
let value = gen_binop_expr(generator, ctx, target, *op, value, stmt.location, true)?; let value = gen_binop_expr(
generator,
ctx,
target,
Binop::aug_assign(*op),
value,
stmt.location,
)?;
generator.gen_assign(ctx, target, value.unwrap())?; generator.gen_assign(ctx, target, value.unwrap())?;
} }
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,

View File

@ -5,7 +5,7 @@ use crate::typecheck::{
type_inferencer::*, type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
}; };
use itertools::Itertools; use itertools::{iproduct, Itertools};
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
use nac3parser::ast::{Cmpop, Operator, Unaryop}; use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::cmp::max; use std::cmp::max;
@ -13,67 +13,135 @@ use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
#[must_use] /// The variant of a binary operator.
pub fn binop_name(op: Operator) -> &'static str { #[derive(Debug, Clone, Copy, PartialEq, Eq)]
match op { pub enum BinopVariant {
Operator::Add => "__add__", /// The normal variant.
Operator::Sub => "__sub__", /// For addition, it would be `+`.
Operator::Div => "__truediv__", Normal,
Operator::Mod => "__mod__", /// The "Augmented Assigning Operator" variant.
Operator::Mult => "__mul__", /// For addition, it would be `+=`.
Operator::Pow => "__pow__", AugAssign,
Operator::BitOr => "__or__", }
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__", /// A binary operator with its variant.
Operator::LShift => "__lshift__", #[derive(Debug, Clone, Copy)]
Operator::RShift => "__rshift__", pub struct Binop {
Operator::FloorDiv => "__floordiv__", /// The base [`Operator`] of this binary operator.
Operator::MatMult => "__matmul__", pub base: Operator,
/// The variant of this binary operator.
pub variant: BinopVariant,
}
impl Binop {
/// Make a [`Binop`] of the normal variant from an [`Operator`].
#[must_use]
pub fn normal(base: Operator) -> Self {
Binop { base, variant: BinopVariant::Normal }
}
/// Make a [`Binop`] of the aug assign variant from an [`Operator`].
#[must_use]
pub fn aug_assign(base: Operator) -> Self {
Binop { base, variant: BinopVariant::AugAssign }
} }
} }
#[must_use] /// Details about an operator (unary, binary, etc...) in Python
pub fn binop_assign_name(op: Operator) -> &'static str { #[derive(Debug, Clone, Copy)]
match op { pub struct OpInfo {
Operator::Add => "__iadd__", /// The method name of the binary operator.
Operator::Sub => "__isub__", /// For addition, this would be `__add__`, and `__iadd__` if
Operator::Div => "__itruediv__", /// it is the augmented assigning variant.
Operator::Mod => "__imod__", pub method_name: &'static str,
Operator::Mult => "__imul__", /// The symbol of the binary operator.
Operator::Pow => "__ipow__", /// For addition, this would be `+`, and `+=` if
Operator::BitOr => "__ior__", /// it is the augmented assigning variant.
Operator::BitXor => "__ixor__", pub symbol: &'static str,
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
} }
#[must_use] /// Helper macro to conveniently build an [`OpInfo`].
pub fn unaryop_name(op: Unaryop) -> &'static str { ///
match op { /// Example usage: `make_info("add", "+")` generates `OpInfo { name: "__add__", symbol: "+" }`
Unaryop::UAdd => "__pos__", macro_rules! make_info {
Unaryop::USub => "__neg__", ($name:expr, $symbol:expr) => {
Unaryop::Not => "__not__", OpInfo { method_name: concat!("__", $name, "__"), symbol: $symbol }
Unaryop::Invert => "__inv__", };
}
} }
#[must_use] pub trait HasOpInfo {
pub fn comparison_name(op: Cmpop) -> Option<&'static str> { fn op_info(&self) -> OpInfo;
}
fn try_get_cmpop_info(op: Cmpop) -> Option<OpInfo> {
match op { match op {
Cmpop::Lt => Some("__lt__"), Cmpop::Lt => Some(make_info!("lt", "<")),
Cmpop::LtE => Some("__le__"), Cmpop::LtE => Some(make_info!("le", "<=")),
Cmpop::Gt => Some("__gt__"), Cmpop::Gt => Some(make_info!("gt", ">")),
Cmpop::GtE => Some("__ge__"), Cmpop::GtE => Some(make_info!("ge", ">=")),
Cmpop::Eq => Some("__eq__"), Cmpop::Eq => Some(make_info!("eq", "==")),
Cmpop::NotEq => Some("__ne__"), Cmpop::NotEq => Some(make_info!("ne", "!=")),
_ => None, _ => None,
} }
} }
impl OpInfo {
#[must_use]
pub fn supports_cmpop(op: Cmpop) -> bool {
try_get_cmpop_info(op).is_some()
}
}
impl HasOpInfo for Cmpop {
fn op_info(&self) -> OpInfo {
try_get_cmpop_info(*self).expect("{self:?} is not supported")
}
}
impl HasOpInfo for Binop {
fn op_info(&self) -> OpInfo {
// Helper macro to generate both the normal variant [`OpInfo`] and the
// augmented assigning variant [`OpInfo`] for a binary operator conveniently.
macro_rules! info {
($name:literal, $symbol:literal) => {
(make_info!($name, $symbol), make_info!(concat!("i", $name), concat!($symbol, "=")))
};
}
let (normal_variant, aug_assign_variant) = match self.base {
Operator::Add => info!("add", "+"),
Operator::Sub => info!("sub", "-"),
Operator::Div => info!("truediv", "/"),
Operator::Mod => info!("mod", "%"),
Operator::Mult => info!("mul", "*"),
Operator::Pow => info!("pow", "**"),
Operator::BitOr => info!("or", "|"),
Operator::BitXor => info!("xor", "^"),
Operator::BitAnd => info!("and", "&"),
Operator::LShift => info!("lshift", "<<"),
Operator::RShift => info!("rshift", ">>"),
Operator::FloorDiv => info!("floordiv", "//"),
Operator::MatMult => info!("matmul", "@"),
};
match self.variant {
BinopVariant::Normal => normal_variant,
BinopVariant::AugAssign => aug_assign_variant,
}
}
}
impl HasOpInfo for Unaryop {
fn op_info(&self) -> OpInfo {
match self {
Unaryop::UAdd => make_info!("pos", "+"),
Unaryop::USub => make_info!("neg", "-"),
Unaryop::Not => make_info!("not", "not"), // i.e., `not False`, so the symbol is just `not`.
Unaryop::Invert => make_info!("inv", "~"),
}
}
}
pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F) pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F)
where where
F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>), F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>),
@ -115,23 +183,9 @@ pub fn impl_binop(
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops { for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) {
fields.insert(binop_name(*op).into(), { let op = Binop { base: *base_op, variant };
( fields.insert(op.op_info().method_name.into(), {
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
})),
false,
)
});
fields.insert(binop_assign_name(*op).into(), {
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
@ -155,7 +209,7 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops:
for op in ops { for op in ops {
fields.insert( fields.insert(
unaryop_name(*op).into(), op.op_info().method_name.into(),
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
@ -195,7 +249,7 @@ pub fn impl_cmpop(
for op in ops { for op in ops {
fields.insert( fields.insert(
comparison_name(*op).unwrap().into(), op.op_info().method_name.into(),
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,

View File

@ -1,11 +1,14 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Display; use std::fmt::Display;
use crate::typecheck::typedef::TypeEnum; use crate::typecheck::{magic_methods::HasOpInfo, typedef::TypeEnum};
use super::typedef::{RecordKey, Type, Unifier}; use super::{
magic_methods::Binop,
typedef::{RecordKey, Type, Unifier},
};
use itertools::Itertools; use itertools::Itertools;
use nac3parser::ast::{Location, StrRef}; use nac3parser::ast::{Cmpop, Location, StrRef};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum TypeErrorKind { pub enum TypeErrorKind {
@ -26,6 +29,18 @@ pub enum TypeErrorKind {
expected: Type, expected: Type,
got: Type, got: Type,
}, },
UnsupportedBinaryOpTypes {
operator: Binop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
UnsupportedComparsionOpTypes {
operator: Cmpop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
FieldUnificationError { FieldUnificationError {
field: RecordKey, field: RecordKey,
types: (Type, Type), types: (Type, Type),
@ -101,6 +116,26 @@ impl<'a> Display for DisplayTypeError<'a> {
let args = missing_arg_names.iter().join(", "); let args = missing_arg_names.iter().join(", ");
write!(f, "Missing arguments: {args}") write!(f, "Missing arguments: {args}")
} }
UnsupportedBinaryOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = operator.op_info().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "Unsupported operand type(s) for {op_symbol}: '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnsupportedComparsionOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = operator.op_info().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "'{op_symbol}' not supported between instances of '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnknownArgName(name) => { UnknownArgName(name) => {
write!(f, "Unknown argument name: {name}") write!(f, "Unknown argument name: {name}")
} }

View File

@ -4,7 +4,9 @@ use std::iter::once;
use std::ops::Not; use std::ops::Not;
use std::{cell::RefCell, sync::Arc}; use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::typedef::{
Call, FunSignature, FuncArg, OperatorInfo, RecordField, Type, TypeEnum, Unifier, VarMap,
};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::toplevel::TopLevelDef; use crate::toplevel::TopLevelDef;
use crate::{ use crate::{
@ -466,7 +468,8 @@ impl<'a> Fold<()> for Inferencer<'a> {
(None, None) => {} (None, None) => {}
}, },
ast::StmtKind::AugAssign { target, op, value, .. } => { ast::StmtKind::AugAssign { target, op, value, .. } => {
let res_ty = self.infer_bin_ops(stmt.location, target, *op, value, true)?; let res_ty =
self.infer_bin_ops(stmt.location, target, Binop::aug_assign(*op), value)?;
self.unify(res_ty, target.custom.unwrap(), &stmt.location)?; self.unify(res_ty, target.custom.unwrap(), &stmt.location)?;
} }
ast::StmtKind::Assert { test, msg, .. } => { ast::StmtKind::Assert { test, msg, .. } => {
@ -548,7 +551,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
} }
ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
ExprKind::BinOp { left, op, right } => { ExprKind::BinOp { left, op, right } => {
Some(self.infer_bin_ops(expr.location, left, *op, right, false)?) Some(self.infer_bin_ops(expr.location, left, Binop::normal(*op), right)?)
} }
ExprKind::UnaryOp { op, operand } => { ExprKind::UnaryOp { op, operand } => {
Some(self.infer_unary_ops(expr.location, *op, operand)?) Some(self.infer_unary_ops(expr.location, *op, operand)?)
@ -615,6 +618,7 @@ impl<'a> Inferencer<'a> {
obj: Type, obj: Type,
params: Vec<Type>, params: Vec<Type>,
ret: Option<Type>, ret: Option<Type>,
operator_info: Option<OperatorInfo>,
) -> InferenceResult { ) -> InferenceResult {
if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) { if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) {
if class_params.is_empty() { if class_params.is_empty() {
@ -628,6 +632,7 @@ impl<'a> Inferencer<'a> {
ret: sign.ret, ret: sign.ret,
fun: RefCell::new(None), fun: RefCell::new(None),
loc: Some(location), loc: Some(location),
operator_info,
}; };
if let Some(ret) = ret { if let Some(ret) = ret {
self.unifier self.unifier
@ -662,6 +667,7 @@ impl<'a> Inferencer<'a> {
ret, ret,
fun: RefCell::new(None), fun: RefCell::new(None),
loc: Some(location), loc: Some(location),
operator_info,
}); });
self.calls.insert(location.into(), call); self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call])); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
@ -1473,6 +1479,7 @@ impl<'a> Inferencer<'a> {
fun: RefCell::new(None), fun: RefCell::new(None),
ret: sign.ret, ret: sign.ret,
loc: Some(location), loc: Some(location),
operator_info: None,
}; };
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| { self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]) HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
@ -1495,6 +1502,7 @@ impl<'a> Inferencer<'a> {
fun: RefCell::new(None), fun: RefCell::new(None),
ret, ret,
loc: Some(location), loc: Some(location),
operator_info: None,
}); });
self.calls.insert(location.into(), call); self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call])); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
@ -1668,9 +1676,8 @@ impl<'a> Inferencer<'a> {
&mut self, &mut self,
location: Location, location: Location,
left: &ast::Expr<Option<Type>>, left: &ast::Expr<Option<Type>>,
op: ast::Operator, op: Binop,
right: &ast::Expr<Option<Type>>, right: &ast::Expr<Option<Type>>,
is_aug_assign: bool,
) -> InferenceResult { ) -> InferenceResult {
let left_ty = left.custom.unwrap(); let left_ty = left.custom.unwrap();
let right_ty = right.custom.unwrap(); let right_ty = right.custom.unwrap();
@ -1678,27 +1685,40 @@ impl<'a> Inferencer<'a> {
let method = if let TypeEnum::TObj { fields, .. } = let method = if let TypeEnum::TObj { fields, .. } =
self.unifier.get_ty_immutable(left_ty).as_ref() self.unifier.get_ty_immutable(left_ty).as_ref()
{ {
let (binop_name, binop_assign_name) = let normal_method_name = Binop::normal(op.base).op_info().method_name;
(binop_name(op).into(), binop_assign_name(op).into()); let assign_method_name = Binop::aug_assign(op.base).op_info().method_name;
// if is aug_assign, try aug_assign operator first // if is aug_assign, try aug_assign operator first
if is_aug_assign && fields.contains_key(&binop_assign_name) { if op.variant == BinopVariant::AugAssign
binop_assign_name && fields.contains_key(&assign_method_name.into())
{
assign_method_name
} else { } else {
binop_name normal_method_name
} }
} else { } else {
binop_name(op).into() op.op_info().method_name
}; };
let ret = if is_aug_assign { let ret = match op.variant {
// The type of augmented assignment operator should never change BinopVariant::Normal => {
Some(left_ty) typeof_binop(self.unifier, self.primitives, op.base, left_ty, right_ty)
} else { .map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty) }
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))? BinopVariant::AugAssign => {
// The type of augmented assignment operator should never change
Some(left_ty)
}
}; };
self.build_method_call(location, method, left_ty, vec![right_ty], ret) self.build_method_call(
location,
method.into(),
left_ty,
vec![right_ty],
ret,
Some(OperatorInfo::IsBinaryOp { self_type: left.custom.unwrap(), operator: op }),
)
} }
fn infer_unary_ops( fn infer_unary_ops(
@ -1707,12 +1727,19 @@ impl<'a> Inferencer<'a> {
op: ast::Unaryop, op: ast::Unaryop,
operand: &ast::Expr<Option<Type>>, operand: &ast::Expr<Option<Type>>,
) -> InferenceResult { ) -> InferenceResult {
let method = unaryop_name(op).into(); let method = op.op_info().method_name.into();
let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap()) let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap())
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?; .map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
self.build_method_call(location, method, operand.custom.unwrap(), vec![], ret) self.build_method_call(
location,
method,
operand.custom.unwrap(),
vec![],
ret,
Some(OperatorInfo::IsUnaryOp { self_type: operand.custom.unwrap(), operator: op }),
)
} }
fn infer_compare( fn infer_compare(
@ -1737,9 +1764,11 @@ impl<'a> Inferencer<'a> {
let mut res = None; let mut res = None;
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method = comparison_name(*c) if !OpInfo::supports_cmpop(*c) {
.ok_or_else(|| HashSet::from(["unsupported comparator".to_string()]))? return Err(HashSet::from(["unsupported comparator".to_string()]));
.into(); }
let method = c.op_info().method_name.into();
let ret = typeof_cmpop( let ret = typeof_cmpop(
self.unifier, self.unifier,
@ -1756,6 +1785,10 @@ impl<'a> Inferencer<'a> {
a.custom.unwrap(), a.custom.unwrap(),
vec![b.custom.unwrap()], vec![b.custom.unwrap()],
ret, ret,
Some(OperatorInfo::IsComparisonOp {
self_type: left.custom.unwrap(),
operator: *c,
}),
)?); )?);
} }

View File

@ -8,12 +8,14 @@ use std::rc::Rc;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet}; use std::{borrow::Cow, collections::HashSet};
use nac3parser::ast::{Location, StrRef}; use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
use super::magic_methods::Binop;
use super::type_error::{TypeError, TypeErrorKind}; use super::type_error::{TypeError, TypeErrorKind};
use super::unification_table::{UnificationKey, UnificationTable}; use super::unification_table::{UnificationKey, UnificationTable};
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef}; use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
use crate::typecheck::magic_methods::OpInfo;
use crate::typecheck::type_inferencer::PrimitiveStore; use crate::typecheck::type_inferencer::PrimitiveStore;
#[cfg(test)] #[cfg(test)]
@ -73,6 +75,28 @@ pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator<Item = TypeVar> + '_ {
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty }) var_map.iter().map(|(&id, &ty)| TypeVar { id, ty })
} }
#[derive(Debug, Clone)]
pub enum OperatorInfo {
/// The call was written as an unary operation, e.g., `~a` or `not a`.
IsUnaryOp {
/// The [`Type`] of the `self` object
self_type: Type,
operator: Unaryop,
},
/// The call was written as a binary operation, e.g., `a + b` or `a += b`.
IsBinaryOp {
/// The [`Type`] of the `self` object
self_type: Type,
operator: Binop,
},
/// The call was written as a binary comparison operation, e.g., `a < b`.
IsComparisonOp {
/// The [`Type`] of the `self` object
self_type: Type,
operator: Cmpop,
},
}
#[derive(Clone)] #[derive(Clone)]
pub struct Call { pub struct Call {
pub posargs: Vec<Type>, pub posargs: Vec<Type>,
@ -80,6 +104,9 @@ pub struct Call {
pub ret: Type, pub ret: Type,
pub fun: RefCell<Option<Type>>, pub fun: RefCell<Option<Type>>,
pub loc: Option<Location>, pub loc: Option<Location>,
/// Details about the associated Python user operator expression pf this call, if any.
pub operator_info: Option<OperatorInfo>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -618,111 +645,178 @@ impl Unifier {
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() }; let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
// Get details about the input arguments // Get details about the input arguments
let Call { posargs, kwargs, ret, fun, loc } = call; let Call { posargs, kwargs, ret, fun, loc, operator_info } = call;
let num_args = posargs.len() + kwargs.len(); let num_args = posargs.len() + kwargs.len();
// Now we check the arguments against the parameters // Now we check the arguments against the parameters,
// and depending on what `call_info` is, we might change how the behavior `unify_call()`
// in hopes to improve user error messages when type checking fails.
match operator_info {
Some(OperatorInfo::IsBinaryOp { self_type, operator }) => {
// The call is written in the form of (say) `a + b`.
// Technically, it is `a.__add__(b)`, and they have the following constraints:
assert_eq!(posargs.len(), 1);
assert_eq!(kwargs.len(), 0);
assert_eq!(num_params, 1);
// Helper lambdas let other_type = posargs[0]; // the second operand
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| { let expected_other_type = signature.args[0].ty;
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
if ok { let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
Ok(()) if !ok {
} else { self.restore_snapshot();
// Typecheck failed, throw an error. return Err(TypeError::new(
self.restore_snapshot(); TypeErrorKind::UnsupportedBinaryOpTypes {
Err(TypeError::new( operator: *operator,
TypeErrorKind::IncorrectArgType { lhs_type: *self_type,
name: param_name, rhs_type: other_type,
expected: expected_arg_ty, expected_rhs_type: expected_other_type,
got: arg_ty, },
}, *loc,
*loc, ));
)) }
} }
}; Some(OperatorInfo::IsComparisonOp { self_type, operator })
if OpInfo::supports_cmpop(*operator) // Otherwise that comparison operator is not supported.
=>
{
// The call is written in the form of (say) `a <= b`.
// Technically, it is `a.__le__(b)`, and they have the following constraints:
assert_eq!(posargs.len(), 1);
assert_eq!(kwargs.len(), 0);
assert_eq!(num_params, 1);
// Check for "too many arguments" let other_type = posargs[0]; // the second operand
if num_params < posargs.len() { let expected_other_type = signature.args[0].ty;
let expected_min_count =
signature.args.iter().filter(|param| param.is_required()).count();
let expected_max_count = num_params;
self.restore_snapshot(); let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
return Err(TypeError::new( if !ok {
TypeErrorKind::TooManyArguments { self.restore_snapshot();
expected_min_count, return Err(TypeError::new(
expected_max_count, TypeErrorKind::UnsupportedComparsionOpTypes {
got_count: num_args, operator: *operator,
}, lhs_type: *self_type,
*loc, rhs_type: other_type,
)); expected_rhs_type: expected_other_type,
} },
*loc,
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap ));
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature }
.args
.iter()
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
.collect();
// Now consume all positional arguments and typecheck them.
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
let param_info = param_info_by_name.get_mut(&param.name).unwrap();
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
// Now consume all keyword arguments and typecheck them.
for (&param_name, &arg_ty) in kwargs {
// We will also use this opportunity to check if this keyword argument is "legal".
let Some(param_info) = param_info_by_name.get_mut(&param_name) else {
self.restore_snapshot();
return Err(TypeError::new(TypeErrorKind::UnknownArgName(param_name), *loc));
};
if param_info.has_been_supplied {
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
// is IMPOSSIBLE as the parser would have already failed.
// We only have to care about "got multiple values for XYZ"
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::GotMultipleValues { name: param_name },
*loc,
));
} }
_ => {
// Handle [`CallInfo::IsNormalFunctionCall`] and other uninteresting variants
// of [`CallInfo`] (e.g, `CallInfo::IsUnaryOp` and unsupported comparison operators)
param_info.has_been_supplied = true; // Helper lambdas
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
if ok {
Ok(())
} else {
// Typecheck failed, throw an error.
self.restore_snapshot();
Err(TypeError::new(
TypeErrorKind::IncorrectArgType {
name: param_name,
expected: expected_arg_ty,
got: arg_ty,
},
*loc,
))
}
};
// Typecheck // Check for "too many arguments"
type_check_arg(param_name, param_info.param.ty, arg_ty)?; if num_params < posargs.len() {
} let expected_min_count =
signature.args.iter().filter(|param| param.is_required()).count();
let expected_max_count = num_params;
// After checking posargs and kwargs, check if there are any self.restore_snapshot();
// unsupplied required parameters, and throw an error if they exist. return Err(TypeError::new(
let missing_arg_names = param_info_by_name TypeErrorKind::TooManyArguments {
.values() expected_min_count,
.filter(|param_info| param_info.param.is_required() && !param_info.has_been_supplied) expected_max_count,
.map(|param_info| param_info.param.name) got_count: num_args,
.collect_vec(); },
if !missing_arg_names.is_empty() { *loc,
self.restore_snapshot(); ));
return Err(TypeError::new(TypeErrorKind::MissingArgs { missing_arg_names }, *loc)); }
}
// Finally, check the Call's return type // NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| { let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
self.restore_snapshot(); .args
if err.loc.is_none() { .iter()
err.loc = *loc; .map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
.collect();
// Now consume all positional arguments and typecheck them.
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
let param_info = param_info_by_name.get_mut(&param.name).unwrap();
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
// Now consume all keyword arguments and typecheck them.
for (&param_name, &arg_ty) in kwargs {
// We will also use this opportunity to check if this keyword argument is "legal".
let Some(param_info) = param_info_by_name.get_mut(&param_name) else {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::UnknownArgName(param_name),
*loc,
));
};
if param_info.has_been_supplied {
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
// is IMPOSSIBLE as the parser would have already failed.
// We only have to care about "got multiple values for XYZ"
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::GotMultipleValues { name: param_name },
*loc,
));
}
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
}
// After checking posargs and kwargs, check if there are any
// unsupplied required parameters, and throw an error if they exist.
let missing_arg_names = param_info_by_name
.values()
.filter(|param_info| {
param_info.param.is_required() && !param_info.has_been_supplied
})
.map(|param_info| param_info.param.name)
.collect_vec();
if !missing_arg_names.is_empty() {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::MissingArgs { missing_arg_names },
*loc,
));
}
// Finally, check the Call's return type
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
self.restore_snapshot();
if err.loc.is_none() {
err.loc = *loc;
}
err
})?;
} }
err }
})?;
*fun.borrow_mut() = Some(b); *fun.borrow_mut() = Some(b);