Compare commits
2 Commits
cbff356d50
...
f52086b706
Author | SHA1 | Date | |
---|---|---|---|
f52086b706 | |||
0a732691c9 |
@ -27,7 +27,7 @@ use crate::{
|
|||||||
DefinitionId, TopLevelDef,
|
DefinitionId, TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
magic_methods::{binop_assign_name, binop_name, unaryop_name},
|
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@ -1165,10 +1165,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
left: (&Option<Type>, BasicValueEnum<'ctx>),
|
left: (&Option<Type>, BasicValueEnum<'ctx>),
|
||||||
op: Operator,
|
op: Binop,
|
||||||
right: (&Option<Type>, BasicValueEnum<'ctx>),
|
right: (&Option<Type>, BasicValueEnum<'ctx>),
|
||||||
loc: Location,
|
loc: Location,
|
||||||
is_aug_assign: bool,
|
|
||||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||||
let (left_ty, left_val) = left;
|
let (left_ty, left_val) = left;
|
||||||
let (right_ty, right_val) = right;
|
let (right_ty, right_val) = right;
|
||||||
@ -1180,17 +1179,17 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
// which would be unchanged until further unification, which we would never do
|
// which would be unchanged until further unification, which we would never do
|
||||||
// when doing code generation for function instances
|
// when doing code generation for function instances
|
||||||
if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) {
|
if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) {
|
||||||
Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, true).into()))
|
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, true).into()))
|
||||||
} else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) {
|
} else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) {
|
||||||
Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, false).into()))
|
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, false).into()))
|
||||||
} else if [Operator::LShift, Operator::RShift].contains(&op) {
|
} else if [Operator::LShift, Operator::RShift].contains(&op.base) {
|
||||||
let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1);
|
let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1);
|
||||||
Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, signed).into()))
|
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed).into()))
|
||||||
} else if ty1 == ty2 && ctx.primitives.float == ty1 {
|
} else if ty1 == ty2 && ctx.primitives.float == ty1 {
|
||||||
Ok(Some(ctx.gen_float_ops(op, left_val, right_val).into()))
|
Ok(Some(ctx.gen_float_ops(op.base, left_val, right_val).into()))
|
||||||
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
|
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
|
||||||
// Pow is the only operator that would pass typecheck between float and int
|
// Pow is the only operator that would pass typecheck between float and int
|
||||||
assert_eq!(op, Operator::Pow);
|
assert_eq!(op.base, Operator::Pow);
|
||||||
let res = call_float_powi(
|
let res = call_float_powi(
|
||||||
ctx,
|
ctx,
|
||||||
left_val.into_float_value(),
|
left_val.into_float_value(),
|
||||||
@ -1203,11 +1202,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if is_aug_assign {
|
if op.variant == BinopVariant::AugAssign {
|
||||||
todo!("Augmented assignment operators not implemented for lists")
|
todo!("Augmented assignment operators not implemented for lists")
|
||||||
}
|
}
|
||||||
|
|
||||||
match op {
|
match op.base {
|
||||||
Operator::Add => {
|
Operator::Add => {
|
||||||
debug_assert_eq!(ty1.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
|
debug_assert_eq!(ty1.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
|
||||||
debug_assert_eq!(ty2.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
|
debug_assert_eq!(ty2.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
|
||||||
@ -1379,13 +1378,16 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
let right_val =
|
let right_val =
|
||||||
NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
|
NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
|
||||||
|
|
||||||
let res = if op == Operator::MatMult {
|
let res = if op.base == Operator::MatMult {
|
||||||
// MatMult is the only binop which is not an elementwise op
|
// MatMult is the only binop which is not an elementwise op
|
||||||
numpy::ndarray_matmul_2d(
|
numpy::ndarray_matmul_2d(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ndarray_dtype1,
|
ndarray_dtype1,
|
||||||
if is_aug_assign { Some(left_val) } else { None },
|
match op.variant {
|
||||||
|
BinopVariant::Normal => None,
|
||||||
|
BinopVariant::AugAssign => Some(left_val),
|
||||||
|
},
|
||||||
left_val,
|
left_val,
|
||||||
right_val,
|
right_val,
|
||||||
)?
|
)?
|
||||||
@ -1394,7 +1396,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ndarray_dtype1,
|
ndarray_dtype1,
|
||||||
if is_aug_assign { Some(left_val) } else { None },
|
match op.variant {
|
||||||
|
BinopVariant::Normal => None,
|
||||||
|
BinopVariant::AugAssign => Some(left_val),
|
||||||
|
},
|
||||||
(left_val.as_base_value().into(), false),
|
(left_val.as_base_value().into(), false),
|
||||||
(right_val.as_base_value().into(), false),
|
(right_val.as_base_value().into(), false),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
@ -1405,7 +1410,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
op,
|
op,
|
||||||
(&Some(ndarray_dtype2), rhs),
|
(&Some(ndarray_dtype2), rhs),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
is_aug_assign,
|
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(
|
.to_basic_value_enum(
|
||||||
@ -1430,7 +1434,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ndarray_dtype,
|
ndarray_dtype,
|
||||||
if is_aug_assign { Some(ndarray_val) } else { None },
|
match op.variant {
|
||||||
|
BinopVariant::Normal => None,
|
||||||
|
BinopVariant::AugAssign => Some(ndarray_val),
|
||||||
|
},
|
||||||
(left_val, !is_ndarray1),
|
(left_val, !is_ndarray1),
|
||||||
(right_val, !is_ndarray2),
|
(right_val, !is_ndarray2),
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|generator, ctx, (lhs, rhs)| {
|
||||||
@ -1441,7 +1448,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
op,
|
op,
|
||||||
(&Some(ndarray_dtype), rhs),
|
(&Some(ndarray_dtype), rhs),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
is_aug_assign,
|
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||||
@ -1456,13 +1462,16 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
unreachable!("must be tobj")
|
unreachable!("must be tobj")
|
||||||
};
|
};
|
||||||
let (op_name, id) = {
|
let (op_name, id) = {
|
||||||
let (binop_name, binop_assign_name) =
|
let normal_method_name = Binop::normal(op.base).op_info().method_name;
|
||||||
(binop_name(op).into(), binop_assign_name(op).into());
|
let assign_method_name = Binop::aug_assign(op.base).op_info().method_name;
|
||||||
|
|
||||||
// if is aug_assign, try aug_assign operator first
|
// if is aug_assign, try aug_assign operator first
|
||||||
if is_aug_assign && fields.contains_key(&binop_assign_name) {
|
if op.variant == BinopVariant::AugAssign
|
||||||
(binop_assign_name, *obj_id)
|
&& fields.contains_key(&assign_method_name.into())
|
||||||
|
{
|
||||||
|
(assign_method_name.into(), *obj_id)
|
||||||
} else {
|
} else {
|
||||||
(binop_name, *obj_id)
|
(normal_method_name.into(), *obj_id)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1509,10 +1518,9 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
|||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
left: &Expr<Option<Type>>,
|
left: &Expr<Option<Type>>,
|
||||||
op: Operator,
|
op: Binop,
|
||||||
right: &Expr<Option<Type>>,
|
right: &Expr<Option<Type>>,
|
||||||
loc: Location,
|
loc: Location,
|
||||||
is_aug_assign: bool,
|
|
||||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||||
let left_val = if let Some(v) = generator.gen_expr(ctx, left)? {
|
let left_val = if let Some(v) = generator.gen_expr(ctx, left)? {
|
||||||
v.to_basic_value_enum(ctx, generator, left.custom.unwrap())?
|
v.to_basic_value_enum(ctx, generator, left.custom.unwrap())?
|
||||||
@ -1532,7 +1540,6 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
|||||||
op,
|
op,
|
||||||
(&right.custom, right_val),
|
(&right.custom, right_val),
|
||||||
loc,
|
loc,
|
||||||
is_aug_assign,
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1616,7 +1623,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
if op == ast::Unaryop::Invert {
|
if op == ast::Unaryop::Invert {
|
||||||
ast::Unaryop::Not
|
ast::Unaryop::Not
|
||||||
} else {
|
} else {
|
||||||
unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op))
|
unreachable!(
|
||||||
|
"ufunc {} not supported for ndarray[bool, N]",
|
||||||
|
op.op_info().method_name,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
op
|
op
|
||||||
@ -2698,7 +2708,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExprKind::BinOp { op, left, right } => {
|
ExprKind::BinOp { op, left, right } => {
|
||||||
return gen_binop_expr(generator, ctx, left, *op, right, expr.location, false);
|
return gen_binop_expr(generator, ctx, left, Binop::normal(*op), right, expr.location);
|
||||||
}
|
}
|
||||||
ExprKind::UnaryOp { op, operand } => return gen_unaryop_expr(generator, ctx, *op, operand),
|
ExprKind::UnaryOp { op, operand } => return gen_unaryop_expr(generator, ctx, *op, operand),
|
||||||
ExprKind::Compare { left, ops, comparators } => {
|
ExprKind::Compare { left, ops, comparators } => {
|
||||||
|
@ -11,8 +11,7 @@ use crate::{
|
|||||||
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
|
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
|
||||||
call_ndarray_calc_size,
|
call_ndarray_calc_size,
|
||||||
},
|
},
|
||||||
llvm_intrinsics,
|
llvm_intrinsics::{self, call_memcpy_generic},
|
||||||
llvm_intrinsics::call_memcpy_generic,
|
|
||||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
@ -22,7 +21,10 @@ use crate::{
|
|||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
},
|
},
|
||||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
typecheck::{
|
||||||
|
magic_methods::Binop,
|
||||||
|
typedef::{FunSignature, Type, TypeEnum},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
@ -1679,10 +1681,9 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(&Some(elem_ty), a),
|
(&Some(elem_ty), a),
|
||||||
Operator::Mult,
|
Binop::normal(Operator::Mult),
|
||||||
(&Some(elem_ty), b),
|
(&Some(elem_ty), b),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
false,
|
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||||
@ -1692,10 +1693,9 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(&Some(elem_ty), result),
|
(&Some(elem_ty), result),
|
||||||
Operator::Add,
|
Binop::normal(Operator::Add),
|
||||||
(&Some(elem_ty), a_mul_b),
|
(&Some(elem_ty), a_mul_b),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
false,
|
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||||
|
@ -11,7 +11,10 @@ use crate::{
|
|||||||
gen_in_range_check,
|
gen_in_range_check,
|
||||||
},
|
},
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
||||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
typecheck::{
|
||||||
|
magic_methods::Binop,
|
||||||
|
typedef::{FunSignature, Type, TypeEnum},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
@ -1593,7 +1596,14 @@ pub fn gen_stmt<G: CodeGenerator>(
|
|||||||
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
|
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
|
||||||
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
|
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
|
||||||
StmtKind::AugAssign { target, op, value, .. } => {
|
StmtKind::AugAssign { target, op, value, .. } => {
|
||||||
let value = gen_binop_expr(generator, ctx, target, *op, value, stmt.location, true)?;
|
let value = gen_binop_expr(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
target,
|
||||||
|
Binop::aug_assign(*op),
|
||||||
|
value,
|
||||||
|
stmt.location,
|
||||||
|
)?;
|
||||||
generator.gen_assign(ctx, target, value.unwrap())?;
|
generator.gen_assign(ctx, target, value.unwrap())?;
|
||||||
}
|
}
|
||||||
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
|
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
|
||||||
|
@ -5,7 +5,7 @@ use crate::typecheck::{
|
|||||||
type_inferencer::*,
|
type_inferencer::*,
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::{iproduct, Itertools};
|
||||||
use nac3parser::ast::StrRef;
|
use nac3parser::ast::StrRef;
|
||||||
use nac3parser::ast::{Cmpop, Operator, Unaryop};
|
use nac3parser::ast::{Cmpop, Operator, Unaryop};
|
||||||
use std::cmp::max;
|
use std::cmp::max;
|
||||||
@ -13,67 +13,138 @@ use std::collections::HashMap;
|
|||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
|
|
||||||
#[must_use]
|
/// The variant of a binary operator.
|
||||||
pub fn binop_name(op: Operator) -> &'static str {
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
match op {
|
pub enum BinopVariant {
|
||||||
Operator::Add => "__add__",
|
/// The normal variant.
|
||||||
Operator::Sub => "__sub__",
|
/// For addition, it would be `+`.
|
||||||
Operator::Div => "__truediv__",
|
Normal,
|
||||||
Operator::Mod => "__mod__",
|
/// The "Augmented Assigning Operator" variant.
|
||||||
Operator::Mult => "__mul__",
|
/// For addition, it would be `+=`.
|
||||||
Operator::Pow => "__pow__",
|
AugAssign,
|
||||||
Operator::BitOr => "__or__",
|
}
|
||||||
Operator::BitXor => "__xor__",
|
|
||||||
Operator::BitAnd => "__and__",
|
/// A binary operator with its variant.
|
||||||
Operator::LShift => "__lshift__",
|
#[derive(Debug, Clone, Copy)]
|
||||||
Operator::RShift => "__rshift__",
|
pub struct Binop {
|
||||||
Operator::FloorDiv => "__floordiv__",
|
/// The base [`Operator`] of this binary operator.
|
||||||
Operator::MatMult => "__matmul__",
|
pub base: Operator,
|
||||||
|
/// The variant of this binary operator.
|
||||||
|
pub variant: BinopVariant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Binop {
|
||||||
|
/// Make a [`Binop`] of the normal variant from an [`Operator`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn normal(base: Operator) -> Self {
|
||||||
|
Binop { base, variant: BinopVariant::Normal }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Make a [`Binop`] of the aug assign variant from an [`Operator`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn aug_assign(base: Operator) -> Self {
|
||||||
|
Binop { base, variant: BinopVariant::AugAssign }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
/// Details about an operator (unary, binary, etc...) in Python
|
||||||
pub fn binop_assign_name(op: Operator) -> &'static str {
|
#[derive(Debug, Clone, Copy)]
|
||||||
match op {
|
pub struct OpInfo {
|
||||||
Operator::Add => "__iadd__",
|
/// The method name of the binary operator.
|
||||||
Operator::Sub => "__isub__",
|
/// For addition, this would be `__add__`, and `__iadd__` if
|
||||||
Operator::Div => "__itruediv__",
|
/// it is the augmented assigning variant.
|
||||||
Operator::Mod => "__imod__",
|
pub method_name: &'static str,
|
||||||
Operator::Mult => "__imul__",
|
/// The symbol of the binary operator.
|
||||||
Operator::Pow => "__ipow__",
|
/// For addition, this would be `+`, and `+=` if
|
||||||
Operator::BitOr => "__ior__",
|
/// it is the augmented assigning variant.
|
||||||
Operator::BitXor => "__ixor__",
|
pub symbol: &'static str,
|
||||||
Operator::BitAnd => "__iand__",
|
|
||||||
Operator::LShift => "__ilshift__",
|
|
||||||
Operator::RShift => "__irshift__",
|
|
||||||
Operator::FloorDiv => "__ifloordiv__",
|
|
||||||
Operator::MatMult => "__imatmul__",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
/// Helper macro to conveniently build an [`OpInfo`].
|
||||||
pub fn unaryop_name(op: Unaryop) -> &'static str {
|
///
|
||||||
match op {
|
/// Example usage: `make_info("add", "+")` generates `OpInfo { name: "__add__", symbol: "+" }`
|
||||||
Unaryop::UAdd => "__pos__",
|
macro_rules! make_op_info {
|
||||||
Unaryop::USub => "__neg__",
|
($name:expr, $symbol:expr) => {
|
||||||
Unaryop::Not => "__not__",
|
OpInfo { method_name: concat!("__", $name, "__"), symbol: $symbol }
|
||||||
Unaryop::Invert => "__inv__",
|
};
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
pub trait HasOpInfo {
|
||||||
pub fn comparison_name(op: Cmpop) -> Option<&'static str> {
|
fn op_info(&self) -> OpInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_get_cmpop_info(op: Cmpop) -> Option<OpInfo> {
|
||||||
match op {
|
match op {
|
||||||
Cmpop::Lt => Some("__lt__"),
|
Cmpop::Lt => Some(make_op_info!("lt", "<")),
|
||||||
Cmpop::LtE => Some("__le__"),
|
Cmpop::LtE => Some(make_op_info!("le", "<=")),
|
||||||
Cmpop::Gt => Some("__gt__"),
|
Cmpop::Gt => Some(make_op_info!("gt", ">")),
|
||||||
Cmpop::GtE => Some("__ge__"),
|
Cmpop::GtE => Some(make_op_info!("ge", ">=")),
|
||||||
Cmpop::Eq => Some("__eq__"),
|
Cmpop::Eq => Some(make_op_info!("eq", "==")),
|
||||||
Cmpop::NotEq => Some("__ne__"),
|
Cmpop::NotEq => Some(make_op_info!("ne", "!=")),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl OpInfo {
|
||||||
|
#[must_use]
|
||||||
|
pub fn supports_cmpop(op: Cmpop) -> bool {
|
||||||
|
try_get_cmpop_info(op).is_some()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HasOpInfo for Cmpop {
|
||||||
|
fn op_info(&self) -> OpInfo {
|
||||||
|
try_get_cmpop_info(*self).expect("{self:?} is not supported")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HasOpInfo for Binop {
|
||||||
|
fn op_info(&self) -> OpInfo {
|
||||||
|
// Helper macro to generate both the normal variant [`OpInfo`] and the
|
||||||
|
// augmented assigning variant [`OpInfo`] for a binary operator conveniently.
|
||||||
|
macro_rules! info {
|
||||||
|
($name:literal, $symbol:literal) => {
|
||||||
|
(
|
||||||
|
make_op_info!($name, $symbol),
|
||||||
|
make_op_info!(concat!("i", $name), concat!($symbol, "=")),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let (normal_variant, aug_assign_variant) = match self.base {
|
||||||
|
Operator::Add => info!("add", "+"),
|
||||||
|
Operator::Sub => info!("sub", "-"),
|
||||||
|
Operator::Div => info!("truediv", "/"),
|
||||||
|
Operator::Mod => info!("mod", "%"),
|
||||||
|
Operator::Mult => info!("mul", "*"),
|
||||||
|
Operator::Pow => info!("pow", "**"),
|
||||||
|
Operator::BitOr => info!("or", "|"),
|
||||||
|
Operator::BitXor => info!("xor", "^"),
|
||||||
|
Operator::BitAnd => info!("and", "&"),
|
||||||
|
Operator::LShift => info!("lshift", "<<"),
|
||||||
|
Operator::RShift => info!("rshift", ">>"),
|
||||||
|
Operator::FloorDiv => info!("floordiv", "//"),
|
||||||
|
Operator::MatMult => info!("matmul", "@"),
|
||||||
|
};
|
||||||
|
|
||||||
|
match self.variant {
|
||||||
|
BinopVariant::Normal => normal_variant,
|
||||||
|
BinopVariant::AugAssign => aug_assign_variant,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HasOpInfo for Unaryop {
|
||||||
|
fn op_info(&self) -> OpInfo {
|
||||||
|
match self {
|
||||||
|
Unaryop::UAdd => make_op_info!("pos", "+"),
|
||||||
|
Unaryop::USub => make_op_info!("neg", "-"),
|
||||||
|
Unaryop::Not => make_op_info!("not", "not"), // i.e., `not False`, so the symbol is just `not`.
|
||||||
|
Unaryop::Invert => make_op_info!("inv", "~"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F)
|
pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F)
|
||||||
where
|
where
|
||||||
F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>),
|
F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>),
|
||||||
@ -115,23 +186,9 @@ pub fn impl_binop(
|
|||||||
|
|
||||||
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
|
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
|
||||||
|
|
||||||
for op in ops {
|
for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) {
|
||||||
fields.insert(binop_name(*op).into(), {
|
let op = Binop { base: *base_op, variant };
|
||||||
(
|
fields.insert(op.op_info().method_name.into(), {
|
||||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
|
||||||
ret: ret_ty,
|
|
||||||
vars: function_vars.clone(),
|
|
||||||
args: vec![FuncArg {
|
|
||||||
ty: other_ty,
|
|
||||||
default_value: None,
|
|
||||||
name: "other".into(),
|
|
||||||
}],
|
|
||||||
})),
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
fields.insert(binop_assign_name(*op).into(), {
|
|
||||||
(
|
(
|
||||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
ret: ret_ty,
|
ret: ret_ty,
|
||||||
@ -155,7 +212,7 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops:
|
|||||||
|
|
||||||
for op in ops {
|
for op in ops {
|
||||||
fields.insert(
|
fields.insert(
|
||||||
unaryop_name(*op).into(),
|
op.op_info().method_name.into(),
|
||||||
(
|
(
|
||||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
ret: ret_ty,
|
ret: ret_ty,
|
||||||
@ -195,7 +252,7 @@ pub fn impl_cmpop(
|
|||||||
|
|
||||||
for op in ops {
|
for op in ops {
|
||||||
fields.insert(
|
fields.insert(
|
||||||
comparison_name(*op).unwrap().into(),
|
op.op_info().method_name.into(),
|
||||||
(
|
(
|
||||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
ret: ret_ty,
|
ret: ret_ty,
|
||||||
@ -429,18 +486,20 @@ pub fn typeof_binop(
|
|||||||
lhs: Type,
|
lhs: Type,
|
||||||
rhs: Type,
|
rhs: Type,
|
||||||
) -> Result<Option<Type>, String> {
|
) -> Result<Option<Type>, String> {
|
||||||
|
let op = Binop { base: op, variant: BinopVariant::Normal };
|
||||||
|
|
||||||
let is_left_list = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
|
let is_left_list = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
|
||||||
let is_right_list = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
|
let is_right_list = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
|
||||||
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
Ok(Some(match op {
|
Ok(Some(match op.base {
|
||||||
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
|
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
|
||||||
if is_left_list || is_right_list {
|
if is_left_list || is_right_list {
|
||||||
if ![Operator::Add, Operator::Mult].contains(&op) {
|
if ![Operator::Add, Operator::Mult].contains(&op.base) {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"Binary operator {} not supported for list",
|
"Binary operator {} not supported for list",
|
||||||
binop_name(op)
|
op.op_info().symbol
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
||||||
use crate::typecheck::typedef::TypeEnum;
|
use crate::typecheck::{magic_methods::HasOpInfo, typedef::TypeEnum};
|
||||||
|
|
||||||
use super::typedef::{RecordKey, Type, Unifier};
|
use super::{
|
||||||
|
magic_methods::Binop,
|
||||||
|
typedef::{RecordKey, Type, Unifier},
|
||||||
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use nac3parser::ast::{Location, StrRef};
|
use nac3parser::ast::{Cmpop, Location, StrRef};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum TypeErrorKind {
|
pub enum TypeErrorKind {
|
||||||
@ -26,6 +29,18 @@ pub enum TypeErrorKind {
|
|||||||
expected: Type,
|
expected: Type,
|
||||||
got: Type,
|
got: Type,
|
||||||
},
|
},
|
||||||
|
UnsupportedBinaryOpTypes {
|
||||||
|
operator: Binop,
|
||||||
|
lhs_type: Type,
|
||||||
|
rhs_type: Type,
|
||||||
|
expected_rhs_type: Type,
|
||||||
|
},
|
||||||
|
UnsupportedComparsionOpTypes {
|
||||||
|
operator: Cmpop,
|
||||||
|
lhs_type: Type,
|
||||||
|
rhs_type: Type,
|
||||||
|
expected_rhs_type: Type,
|
||||||
|
},
|
||||||
FieldUnificationError {
|
FieldUnificationError {
|
||||||
field: RecordKey,
|
field: RecordKey,
|
||||||
types: (Type, Type),
|
types: (Type, Type),
|
||||||
@ -101,6 +116,26 @@ impl<'a> Display for DisplayTypeError<'a> {
|
|||||||
let args = missing_arg_names.iter().join(", ");
|
let args = missing_arg_names.iter().join(", ");
|
||||||
write!(f, "Missing arguments: {args}")
|
write!(f, "Missing arguments: {args}")
|
||||||
}
|
}
|
||||||
|
UnsupportedBinaryOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
|
||||||
|
let op_symbol = operator.op_info().symbol;
|
||||||
|
|
||||||
|
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
|
||||||
|
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
|
||||||
|
let expected_rhs_type_str =
|
||||||
|
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
|
||||||
|
|
||||||
|
write!(f, "Unsupported operand type(s) for {op_symbol}: '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
|
||||||
|
}
|
||||||
|
UnsupportedComparsionOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
|
||||||
|
let op_symbol = operator.op_info().symbol;
|
||||||
|
|
||||||
|
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
|
||||||
|
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
|
||||||
|
let expected_rhs_type_str =
|
||||||
|
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
|
||||||
|
|
||||||
|
write!(f, "'{op_symbol}' not supported between instances of '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
|
||||||
|
}
|
||||||
UnknownArgName(name) => {
|
UnknownArgName(name) => {
|
||||||
write!(f, "Unknown argument name: {name}")
|
write!(f, "Unknown argument name: {name}")
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ use std::iter::once;
|
|||||||
use std::ops::Not;
|
use std::ops::Not;
|
||||||
use std::{cell::RefCell, sync::Arc};
|
use std::{cell::RefCell, sync::Arc};
|
||||||
|
|
||||||
|
use super::typedef::OperatorInfo;
|
||||||
use super::{
|
use super::{
|
||||||
magic_methods::*,
|
magic_methods::*,
|
||||||
type_error::TypeError,
|
type_error::TypeError,
|
||||||
@ -491,7 +492,8 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
|||||||
(None, None) => {}
|
(None, None) => {}
|
||||||
},
|
},
|
||||||
ast::StmtKind::AugAssign { target, op, value, .. } => {
|
ast::StmtKind::AugAssign { target, op, value, .. } => {
|
||||||
let res_ty = self.infer_bin_ops(stmt.location, target, *op, value, true)?;
|
let res_ty =
|
||||||
|
self.infer_bin_ops(stmt.location, target, Binop::aug_assign(*op), value)?;
|
||||||
self.unify(res_ty, target.custom.unwrap(), &stmt.location)?;
|
self.unify(res_ty, target.custom.unwrap(), &stmt.location)?;
|
||||||
}
|
}
|
||||||
ast::StmtKind::Assert { test, msg, .. } => {
|
ast::StmtKind::Assert { test, msg, .. } => {
|
||||||
@ -573,7 +575,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
|||||||
}
|
}
|
||||||
ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
|
ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
|
||||||
ExprKind::BinOp { left, op, right } => {
|
ExprKind::BinOp { left, op, right } => {
|
||||||
Some(self.infer_bin_ops(expr.location, left, *op, right, false)?)
|
Some(self.infer_bin_ops(expr.location, left, Binop::normal(*op), right)?)
|
||||||
}
|
}
|
||||||
ExprKind::UnaryOp { op, operand } => {
|
ExprKind::UnaryOp { op, operand } => {
|
||||||
Some(self.infer_unary_ops(expr.location, *op, operand)?)
|
Some(self.infer_unary_ops(expr.location, *op, operand)?)
|
||||||
@ -640,6 +642,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
obj: Type,
|
obj: Type,
|
||||||
params: Vec<Type>,
|
params: Vec<Type>,
|
||||||
ret: Option<Type>,
|
ret: Option<Type>,
|
||||||
|
operator_info: Option<OperatorInfo>,
|
||||||
) -> InferenceResult {
|
) -> InferenceResult {
|
||||||
if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) {
|
if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) {
|
||||||
if class_params.is_empty() {
|
if class_params.is_empty() {
|
||||||
@ -653,6 +656,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
ret: sign.ret,
|
ret: sign.ret,
|
||||||
fun: RefCell::new(None),
|
fun: RefCell::new(None),
|
||||||
loc: Some(location),
|
loc: Some(location),
|
||||||
|
operator_info,
|
||||||
};
|
};
|
||||||
if let Some(ret) = ret {
|
if let Some(ret) = ret {
|
||||||
self.unifier
|
self.unifier
|
||||||
@ -687,6 +691,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
ret,
|
ret,
|
||||||
fun: RefCell::new(None),
|
fun: RefCell::new(None),
|
||||||
loc: Some(location),
|
loc: Some(location),
|
||||||
|
operator_info,
|
||||||
});
|
});
|
||||||
self.calls.insert(location.into(), call);
|
self.calls.insert(location.into(), call);
|
||||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
|
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
|
||||||
@ -1522,6 +1527,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
fun: RefCell::new(None),
|
fun: RefCell::new(None),
|
||||||
ret: sign.ret,
|
ret: sign.ret,
|
||||||
loc: Some(location),
|
loc: Some(location),
|
||||||
|
operator_info: None,
|
||||||
};
|
};
|
||||||
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
||||||
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
||||||
@ -1544,6 +1550,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
fun: RefCell::new(None),
|
fun: RefCell::new(None),
|
||||||
ret,
|
ret,
|
||||||
loc: Some(location),
|
loc: Some(location),
|
||||||
|
operator_info: None,
|
||||||
});
|
});
|
||||||
self.calls.insert(location.into(), call);
|
self.calls.insert(location.into(), call);
|
||||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
|
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
|
||||||
@ -1729,9 +1736,8 @@ impl<'a> Inferencer<'a> {
|
|||||||
&mut self,
|
&mut self,
|
||||||
location: Location,
|
location: Location,
|
||||||
left: &ast::Expr<Option<Type>>,
|
left: &ast::Expr<Option<Type>>,
|
||||||
op: ast::Operator,
|
op: Binop,
|
||||||
right: &ast::Expr<Option<Type>>,
|
right: &ast::Expr<Option<Type>>,
|
||||||
is_aug_assign: bool,
|
|
||||||
) -> InferenceResult {
|
) -> InferenceResult {
|
||||||
let left_ty = left.custom.unwrap();
|
let left_ty = left.custom.unwrap();
|
||||||
let right_ty = right.custom.unwrap();
|
let right_ty = right.custom.unwrap();
|
||||||
@ -1739,27 +1745,40 @@ impl<'a> Inferencer<'a> {
|
|||||||
let method = if let TypeEnum::TObj { fields, .. } =
|
let method = if let TypeEnum::TObj { fields, .. } =
|
||||||
self.unifier.get_ty_immutable(left_ty).as_ref()
|
self.unifier.get_ty_immutable(left_ty).as_ref()
|
||||||
{
|
{
|
||||||
let (binop_name, binop_assign_name) =
|
let normal_method_name = Binop::normal(op.base).op_info().method_name;
|
||||||
(binop_name(op).into(), binop_assign_name(op).into());
|
let assign_method_name = Binop::aug_assign(op.base).op_info().method_name;
|
||||||
|
|
||||||
// if is aug_assign, try aug_assign operator first
|
// if is aug_assign, try aug_assign operator first
|
||||||
if is_aug_assign && fields.contains_key(&binop_assign_name) {
|
if op.variant == BinopVariant::AugAssign
|
||||||
binop_assign_name
|
&& fields.contains_key(&assign_method_name.into())
|
||||||
|
{
|
||||||
|
assign_method_name
|
||||||
} else {
|
} else {
|
||||||
binop_name
|
normal_method_name
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
binop_name(op).into()
|
op.op_info().method_name
|
||||||
};
|
};
|
||||||
|
|
||||||
let ret = if is_aug_assign {
|
let ret = match op.variant {
|
||||||
// The type of augmented assignment operator should never change
|
BinopVariant::Normal => {
|
||||||
Some(left_ty)
|
typeof_binop(self.unifier, self.primitives, op.base, left_ty, right_ty)
|
||||||
} else {
|
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
|
||||||
typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty)
|
}
|
||||||
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
|
BinopVariant::AugAssign => {
|
||||||
|
// The type of augmented assignment operator should never change
|
||||||
|
Some(left_ty)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self.build_method_call(location, method, left_ty, vec![right_ty], ret)
|
self.build_method_call(
|
||||||
|
location,
|
||||||
|
method.into(),
|
||||||
|
left_ty,
|
||||||
|
vec![right_ty],
|
||||||
|
ret,
|
||||||
|
Some(OperatorInfo::IsBinaryOp { self_type: left.custom.unwrap(), operator: op }),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_unary_ops(
|
fn infer_unary_ops(
|
||||||
@ -1768,12 +1787,19 @@ impl<'a> Inferencer<'a> {
|
|||||||
op: ast::Unaryop,
|
op: ast::Unaryop,
|
||||||
operand: &ast::Expr<Option<Type>>,
|
operand: &ast::Expr<Option<Type>>,
|
||||||
) -> InferenceResult {
|
) -> InferenceResult {
|
||||||
let method = unaryop_name(op).into();
|
let method = op.op_info().method_name.into();
|
||||||
|
|
||||||
let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap())
|
let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap())
|
||||||
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
|
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
|
||||||
|
|
||||||
self.build_method_call(location, method, operand.custom.unwrap(), vec![], ret)
|
self.build_method_call(
|
||||||
|
location,
|
||||||
|
method,
|
||||||
|
operand.custom.unwrap(),
|
||||||
|
vec![],
|
||||||
|
ret,
|
||||||
|
Some(OperatorInfo::IsUnaryOp { self_type: operand.custom.unwrap(), operator: op }),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_compare(
|
fn infer_compare(
|
||||||
@ -1798,9 +1824,11 @@ impl<'a> Inferencer<'a> {
|
|||||||
|
|
||||||
let mut res = None;
|
let mut res = None;
|
||||||
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
||||||
let method = comparison_name(*c)
|
if !OpInfo::supports_cmpop(*c) {
|
||||||
.ok_or_else(|| HashSet::from(["unsupported comparator".to_string()]))?
|
return Err(HashSet::from(["unsupported comparator".to_string()]));
|
||||||
.into();
|
}
|
||||||
|
|
||||||
|
let method = c.op_info().method_name.into();
|
||||||
|
|
||||||
let ret = typeof_cmpop(
|
let ret = typeof_cmpop(
|
||||||
self.unifier,
|
self.unifier,
|
||||||
@ -1817,6 +1845,10 @@ impl<'a> Inferencer<'a> {
|
|||||||
a.custom.unwrap(),
|
a.custom.unwrap(),
|
||||||
vec![b.custom.unwrap()],
|
vec![b.custom.unwrap()],
|
||||||
ret,
|
ret,
|
||||||
|
Some(OperatorInfo::IsComparisonOp {
|
||||||
|
self_type: left.custom.unwrap(),
|
||||||
|
operator: *c,
|
||||||
|
}),
|
||||||
)?);
|
)?);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,12 +8,15 @@ use std::rc::Rc;
|
|||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use std::{borrow::Cow, collections::HashSet};
|
use std::{borrow::Cow, collections::HashSet};
|
||||||
|
|
||||||
use nac3parser::ast::{Location, StrRef};
|
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
|
||||||
|
|
||||||
|
use super::magic_methods::Binop;
|
||||||
use super::type_error::{TypeError, TypeErrorKind};
|
use super::type_error::{TypeError, TypeErrorKind};
|
||||||
use super::unification_table::{UnificationKey, UnificationTable};
|
use super::unification_table::{UnificationKey, UnificationTable};
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::{helper::PrimDef, DefinitionId, TopLevelContext, TopLevelDef};
|
use crate::toplevel::helper::PrimDef;
|
||||||
|
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
|
||||||
|
use crate::typecheck::magic_methods::OpInfo;
|
||||||
use crate::typecheck::type_inferencer::PrimitiveStore;
|
use crate::typecheck::type_inferencer::PrimitiveStore;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -73,6 +76,28 @@ pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator<Item = TypeVar> + '_ {
|
|||||||
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty })
|
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum OperatorInfo {
|
||||||
|
/// The call was written as an unary operation, e.g., `~a` or `not a`.
|
||||||
|
IsUnaryOp {
|
||||||
|
/// The [`Type`] of the `self` object
|
||||||
|
self_type: Type,
|
||||||
|
operator: Unaryop,
|
||||||
|
},
|
||||||
|
/// The call was written as a binary operation, e.g., `a + b` or `a += b`.
|
||||||
|
IsBinaryOp {
|
||||||
|
/// The [`Type`] of the `self` object
|
||||||
|
self_type: Type,
|
||||||
|
operator: Binop,
|
||||||
|
},
|
||||||
|
/// The call was written as a binary comparison operation, e.g., `a < b`.
|
||||||
|
IsComparisonOp {
|
||||||
|
/// The [`Type`] of the `self` object
|
||||||
|
self_type: Type,
|
||||||
|
operator: Cmpop,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Call {
|
pub struct Call {
|
||||||
pub posargs: Vec<Type>,
|
pub posargs: Vec<Type>,
|
||||||
@ -80,6 +105,9 @@ pub struct Call {
|
|||||||
pub ret: Type,
|
pub ret: Type,
|
||||||
pub fun: RefCell<Option<Type>>,
|
pub fun: RefCell<Option<Type>>,
|
||||||
pub loc: Option<Location>,
|
pub loc: Option<Location>,
|
||||||
|
|
||||||
|
/// Details about the associated Python user operator expression of this call, if any.
|
||||||
|
pub operator_info: Option<OperatorInfo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -627,111 +655,178 @@ impl Unifier {
|
|||||||
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
|
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
|
||||||
|
|
||||||
// Get details about the input arguments
|
// Get details about the input arguments
|
||||||
let Call { posargs, kwargs, ret, fun, loc } = call;
|
let Call { posargs, kwargs, ret, fun, loc, operator_info } = call;
|
||||||
let num_args = posargs.len() + kwargs.len();
|
let num_args = posargs.len() + kwargs.len();
|
||||||
|
|
||||||
// Now we check the arguments against the parameters
|
// Now we check the arguments against the parameters,
|
||||||
|
// and depending on what `call_info` is, we might change how the behavior `unify_call()`
|
||||||
|
// in hopes to improve user error messages when type checking fails.
|
||||||
|
match operator_info {
|
||||||
|
Some(OperatorInfo::IsBinaryOp { self_type, operator }) => {
|
||||||
|
// The call is written in the form of (say) `a + b`.
|
||||||
|
// Technically, it is `a.__add__(b)`, and they have the following constraints:
|
||||||
|
assert_eq!(posargs.len(), 1);
|
||||||
|
assert_eq!(kwargs.len(), 0);
|
||||||
|
assert_eq!(num_params, 1);
|
||||||
|
|
||||||
// Helper lambdas
|
let other_type = posargs[0]; // the second operand
|
||||||
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
|
let expected_other_type = signature.args[0].ty;
|
||||||
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
|
|
||||||
if ok {
|
let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
|
||||||
Ok(())
|
if !ok {
|
||||||
} else {
|
self.restore_snapshot();
|
||||||
// Typecheck failed, throw an error.
|
return Err(TypeError::new(
|
||||||
self.restore_snapshot();
|
TypeErrorKind::UnsupportedBinaryOpTypes {
|
||||||
Err(TypeError::new(
|
operator: *operator,
|
||||||
TypeErrorKind::IncorrectArgType {
|
lhs_type: *self_type,
|
||||||
name: param_name,
|
rhs_type: other_type,
|
||||||
expected: expected_arg_ty,
|
expected_rhs_type: expected_other_type,
|
||||||
got: arg_ty,
|
},
|
||||||
},
|
*loc,
|
||||||
*loc,
|
));
|
||||||
))
|
}
|
||||||
}
|
}
|
||||||
};
|
Some(OperatorInfo::IsComparisonOp { self_type, operator })
|
||||||
|
if OpInfo::supports_cmpop(*operator) // Otherwise that comparison operator is not supported.
|
||||||
|
=>
|
||||||
|
{
|
||||||
|
// The call is written in the form of (say) `a <= b`.
|
||||||
|
// Technically, it is `a.__le__(b)`, and they have the following constraints:
|
||||||
|
assert_eq!(posargs.len(), 1);
|
||||||
|
assert_eq!(kwargs.len(), 0);
|
||||||
|
assert_eq!(num_params, 1);
|
||||||
|
|
||||||
// Check for "too many arguments"
|
let other_type = posargs[0]; // the second operand
|
||||||
if num_params < posargs.len() {
|
let expected_other_type = signature.args[0].ty;
|
||||||
let expected_min_count =
|
|
||||||
signature.args.iter().filter(|param| param.is_required()).count();
|
|
||||||
let expected_max_count = num_params;
|
|
||||||
|
|
||||||
self.restore_snapshot();
|
let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
|
||||||
return Err(TypeError::new(
|
if !ok {
|
||||||
TypeErrorKind::TooManyArguments {
|
self.restore_snapshot();
|
||||||
expected_min_count,
|
return Err(TypeError::new(
|
||||||
expected_max_count,
|
TypeErrorKind::UnsupportedComparsionOpTypes {
|
||||||
got_count: num_args,
|
operator: *operator,
|
||||||
},
|
lhs_type: *self_type,
|
||||||
*loc,
|
rhs_type: other_type,
|
||||||
));
|
expected_rhs_type: expected_other_type,
|
||||||
}
|
},
|
||||||
|
*loc,
|
||||||
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
|
));
|
||||||
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
|
}
|
||||||
.args
|
|
||||||
.iter()
|
|
||||||
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Now consume all positional arguments and typecheck them.
|
|
||||||
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
|
|
||||||
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
|
|
||||||
let param_info = param_info_by_name.get_mut(¶m.name).unwrap();
|
|
||||||
param_info.has_been_supplied = true;
|
|
||||||
|
|
||||||
// Typecheck
|
|
||||||
type_check_arg(param.name, param.ty, arg_ty)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now consume all keyword arguments and typecheck them.
|
|
||||||
for (¶m_name, &arg_ty) in kwargs {
|
|
||||||
// We will also use this opportunity to check if this keyword argument is "legal".
|
|
||||||
|
|
||||||
let Some(param_info) = param_info_by_name.get_mut(¶m_name) else {
|
|
||||||
self.restore_snapshot();
|
|
||||||
return Err(TypeError::new(TypeErrorKind::UnknownArgName(param_name), *loc));
|
|
||||||
};
|
|
||||||
|
|
||||||
if param_info.has_been_supplied {
|
|
||||||
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
|
|
||||||
// is IMPOSSIBLE as the parser would have already failed.
|
|
||||||
// We only have to care about "got multiple values for XYZ"
|
|
||||||
|
|
||||||
self.restore_snapshot();
|
|
||||||
return Err(TypeError::new(
|
|
||||||
TypeErrorKind::GotMultipleValues { name: param_name },
|
|
||||||
*loc,
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
_ => {
|
||||||
|
// Handle [`CallInfo::IsNormalFunctionCall`] and other uninteresting variants
|
||||||
|
// of [`CallInfo`] (e.g, `CallInfo::IsUnaryOp` and unsupported comparison operators)
|
||||||
|
|
||||||
param_info.has_been_supplied = true;
|
// Helper lambdas
|
||||||
|
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
|
||||||
|
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
|
||||||
|
if ok {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
// Typecheck failed, throw an error.
|
||||||
|
self.restore_snapshot();
|
||||||
|
Err(TypeError::new(
|
||||||
|
TypeErrorKind::IncorrectArgType {
|
||||||
|
name: param_name,
|
||||||
|
expected: expected_arg_ty,
|
||||||
|
got: arg_ty,
|
||||||
|
},
|
||||||
|
*loc,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Typecheck
|
// Check for "too many arguments"
|
||||||
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
|
if num_params < posargs.len() {
|
||||||
}
|
let expected_min_count =
|
||||||
|
signature.args.iter().filter(|param| param.is_required()).count();
|
||||||
|
let expected_max_count = num_params;
|
||||||
|
|
||||||
// After checking posargs and kwargs, check if there are any
|
self.restore_snapshot();
|
||||||
// unsupplied required parameters, and throw an error if they exist.
|
return Err(TypeError::new(
|
||||||
let missing_arg_names = param_info_by_name
|
TypeErrorKind::TooManyArguments {
|
||||||
.values()
|
expected_min_count,
|
||||||
.filter(|param_info| param_info.param.is_required() && !param_info.has_been_supplied)
|
expected_max_count,
|
||||||
.map(|param_info| param_info.param.name)
|
got_count: num_args,
|
||||||
.collect_vec();
|
},
|
||||||
if !missing_arg_names.is_empty() {
|
*loc,
|
||||||
self.restore_snapshot();
|
));
|
||||||
return Err(TypeError::new(TypeErrorKind::MissingArgs { missing_arg_names }, *loc));
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, check the Call's return type
|
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
|
||||||
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
|
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
|
||||||
self.restore_snapshot();
|
.args
|
||||||
if err.loc.is_none() {
|
.iter()
|
||||||
err.loc = *loc;
|
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Now consume all positional arguments and typecheck them.
|
||||||
|
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
|
||||||
|
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
|
||||||
|
let param_info = param_info_by_name.get_mut(¶m.name).unwrap();
|
||||||
|
param_info.has_been_supplied = true;
|
||||||
|
|
||||||
|
// Typecheck
|
||||||
|
type_check_arg(param.name, param.ty, arg_ty)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now consume all keyword arguments and typecheck them.
|
||||||
|
for (¶m_name, &arg_ty) in kwargs {
|
||||||
|
// We will also use this opportunity to check if this keyword argument is "legal".
|
||||||
|
|
||||||
|
let Some(param_info) = param_info_by_name.get_mut(¶m_name) else {
|
||||||
|
self.restore_snapshot();
|
||||||
|
return Err(TypeError::new(
|
||||||
|
TypeErrorKind::UnknownArgName(param_name),
|
||||||
|
*loc,
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
if param_info.has_been_supplied {
|
||||||
|
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
|
||||||
|
// is IMPOSSIBLE as the parser would have already failed.
|
||||||
|
// We only have to care about "got multiple values for XYZ"
|
||||||
|
|
||||||
|
self.restore_snapshot();
|
||||||
|
return Err(TypeError::new(
|
||||||
|
TypeErrorKind::GotMultipleValues { name: param_name },
|
||||||
|
*loc,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
param_info.has_been_supplied = true;
|
||||||
|
|
||||||
|
// Typecheck
|
||||||
|
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// After checking posargs and kwargs, check if there are any
|
||||||
|
// unsupplied required parameters, and throw an error if they exist.
|
||||||
|
let missing_arg_names = param_info_by_name
|
||||||
|
.values()
|
||||||
|
.filter(|param_info| {
|
||||||
|
param_info.param.is_required() && !param_info.has_been_supplied
|
||||||
|
})
|
||||||
|
.map(|param_info| param_info.param.name)
|
||||||
|
.collect_vec();
|
||||||
|
if !missing_arg_names.is_empty() {
|
||||||
|
self.restore_snapshot();
|
||||||
|
return Err(TypeError::new(
|
||||||
|
TypeErrorKind::MissingArgs { missing_arg_names },
|
||||||
|
*loc,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, check the Call's return type
|
||||||
|
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
|
||||||
|
self.restore_snapshot();
|
||||||
|
if err.loc.is_none() {
|
||||||
|
err.loc = *loc;
|
||||||
|
}
|
||||||
|
err
|
||||||
|
})?;
|
||||||
}
|
}
|
||||||
err
|
}
|
||||||
})?;
|
|
||||||
|
|
||||||
*fun.borrow_mut() = Some(b);
|
*fun.borrow_mut() = Some(b);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user