diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 2d18805..9b139fa 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -21,7 +21,7 @@ use crate::{ DefinitionId, TopLevelDef, }, typecheck::{ - magic_methods::{binop_assign_name, binop_name, unaryop_name}, + magic_methods::{BinOpVariant, OpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, }, }; @@ -1167,7 +1167,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op: Operator, right: (&Option, BasicValueEnum<'ctx>), loc: Location, - is_aug_assign: bool, + variant: BinOpVariant, ) -> Result>, String> { let (left_ty, left_val) = left; let (right_ty, right_val) = right; @@ -1222,7 +1222,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( generator, ctx, ndarray_dtype1, - if is_aug_assign { Some(left_val) } else { None }, + match variant { + BinOpVariant::Normal => None, + BinOpVariant::AugAssign => Some(left_val), + }, left_val, right_val, )? @@ -1231,7 +1234,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( generator, ctx, ndarray_dtype1, - if is_aug_assign { Some(left_val) } else { None }, + match variant { + BinOpVariant::Normal => None, + BinOpVariant::AugAssign => Some(left_val), + }, (left_val.as_base_value().into(), false), (right_val.as_base_value().into(), false), |generator, ctx, (lhs, rhs)| { @@ -1242,7 +1248,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op, (&Some(ndarray_dtype2), rhs), ctx.current_loc, - is_aug_assign, + variant, )? .unwrap() .to_basic_value_enum( @@ -1267,7 +1273,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( generator, ctx, ndarray_dtype, - if is_aug_assign { Some(ndarray_val) } else { None }, + match variant { + BinOpVariant::Normal => None, + BinOpVariant::AugAssign => Some(ndarray_val), + }, (left_val, !is_ndarray1), (right_val, !is_ndarray2), |generator, ctx, (lhs, rhs)| { @@ -1278,7 +1287,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op, (&Some(ndarray_dtype), rhs), ctx.current_loc, - is_aug_assign, + variant, )? .unwrap() .to_basic_value_enum(ctx, generator, ndarray_dtype) @@ -1293,13 +1302,15 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( unreachable!("must be tobj") }; let (op_name, id) = { - let (binop_name, binop_assign_name) = - (binop_name(op).into(), binop_assign_name(op).into()); + let normal_method_name = OpInfo::from_binop(op, BinOpVariant::Normal).method_name; + let assign_method_name = OpInfo::from_binop(op, BinOpVariant::AugAssign).method_name; + // if is aug_assign, try aug_assign operator first - if is_aug_assign && fields.contains_key(&binop_assign_name) { - (binop_assign_name, *obj_id) + if variant == BinOpVariant::AugAssign && fields.contains_key(&assign_method_name.into()) + { + (assign_method_name.into(), *obj_id) } else { - (binop_name, *obj_id) + (normal_method_name.into(), *obj_id) } }; @@ -1349,7 +1360,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( op: Operator, right: &Expr>, loc: Location, - is_aug_assign: bool, + variant: BinOpVariant, ) -> Result>, String> { let left_val = if let Some(v) = generator.gen_expr(ctx, left)? { v.to_basic_value_enum(ctx, generator, left.custom.unwrap())? @@ -1369,7 +1380,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( op, (&right.custom, right_val), loc, - is_aug_assign, + variant, ) } @@ -1453,7 +1464,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( if op == ast::Unaryop::Invert { ast::Unaryop::Not } else { - unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op)) + unreachable!( + "ufunc {} not supported for ndarray[bool, N]", + OpInfo::from_unaryop(op).method_name + ) } } else { op @@ -2343,7 +2357,15 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } } ExprKind::BinOp { op, left, right } => { - return gen_binop_expr(generator, ctx, left, *op, right, expr.location, false); + return gen_binop_expr( + generator, + ctx, + left, + *op, + right, + expr.location, + BinOpVariant::Normal, + ); } ExprKind::UnaryOp { op, operand } => return gen_unaryop_expr(generator, ctx, *op, operand), ExprKind::Compare { left, ops, comparators } => { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 3fab259..bd23cdd 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -11,8 +11,7 @@ use crate::{ call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, }, - llvm_intrinsics, - llvm_intrinsics::call_memcpy_generic, + llvm_intrinsics::{self, call_memcpy_generic}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, CodeGenContext, CodeGenerator, }, @@ -22,7 +21,10 @@ use crate::{ numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, - typecheck::typedef::{FunSignature, Type, TypeEnum}, + typecheck::{ + magic_methods::BinOpVariant, + typedef::{FunSignature, Type, TypeEnum}, + }, }; use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; use inkwell::{ @@ -1680,7 +1682,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( Operator::Mult, (&Some(elem_ty), b), ctx.current_loc, - false, + BinOpVariant::Normal, )? .unwrap() .to_basic_value_enum(ctx, generator, elem_ty)?; @@ -1693,7 +1695,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( Operator::Add, (&Some(elem_ty), a_mul_b), ctx.current_loc, - false, + BinOpVariant::Normal, )? .unwrap() .to_basic_value_enum(ctx, generator, elem_ty)?; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 5bae9a9..0950356 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -11,7 +11,10 @@ use crate::{ gen_in_range_check, }, toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, - typecheck::typedef::{FunSignature, Type, TypeEnum}, + typecheck::{ + magic_methods::BinOpVariant, + typedef::{FunSignature, Type, TypeEnum}, + }, }; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -1574,7 +1577,15 @@ pub fn gen_stmt( StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::AugAssign { target, op, value, .. } => { - let value = gen_binop_expr(generator, ctx, target, *op, value, stmt.location, true)?; + let value = gen_binop_expr( + generator, + ctx, + target, + *op, + value, + stmt.location, + BinOpVariant::AugAssign, + )?; generator.gen_assign(ctx, target, value.unwrap())?; } StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index f2b995e..13ef35d 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -5,7 +5,7 @@ use crate::typecheck::{ type_inferencer::*, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }; -use itertools::Itertools; +use itertools::{iproduct, Itertools}; use nac3parser::ast::StrRef; use nac3parser::ast::{Cmpop, Operator, Unaryop}; use std::cmp::max; @@ -13,64 +13,93 @@ use std::collections::HashMap; use std::rc::Rc; use strum::IntoEnumIterator; -#[must_use] -pub fn binop_name(op: Operator) -> &'static str { - match op { - Operator::Add => "__add__", - Operator::Sub => "__sub__", - Operator::Div => "__truediv__", - Operator::Mod => "__mod__", - Operator::Mult => "__mul__", - Operator::Pow => "__pow__", - Operator::BitOr => "__or__", - Operator::BitXor => "__xor__", - Operator::BitAnd => "__and__", - Operator::LShift => "__lshift__", - Operator::RShift => "__rshift__", - Operator::FloorDiv => "__floordiv__", - Operator::MatMult => "__matmul__", - } +/// Details about an operator (unary, binary, etc...) in Python +#[derive(Debug, Clone, Copy)] +pub struct OpInfo { + /// The method name of the binary operator. + /// For addition, this would be `__add__`, and `__iadd__` if + /// it is the augmented assigning variant. + pub method_name: &'static str, + /// The symbol of the binary operator. + /// For addition, this would be `+`, and `+=` if + /// it is the augmented assigning variant. + pub symbol: &'static str, } -#[must_use] -pub fn binop_assign_name(op: Operator) -> &'static str { - match op { - Operator::Add => "__iadd__", - Operator::Sub => "__isub__", - Operator::Div => "__itruediv__", - Operator::Mod => "__imod__", - Operator::Mult => "__imul__", - Operator::Pow => "__ipow__", - Operator::BitOr => "__ior__", - Operator::BitXor => "__ixor__", - Operator::BitAnd => "__iand__", - Operator::LShift => "__ilshift__", - Operator::RShift => "__irshift__", - Operator::FloorDiv => "__ifloordiv__", - Operator::MatMult => "__imatmul__", - } +/// Helper macro to conveniently build an [`OpInfo`]. +/// +/// Example usage: `make_info("add", "+")` generates `OpInfo { name: "__add__", symbol: "+" }` +macro_rules! make_info { + ($name:expr, $symbol:expr) => { + OpInfo { method_name: concat!("__", $name, "__"), symbol: $symbol } + }; } -#[must_use] -pub fn unaryop_name(op: Unaryop) -> &'static str { - match op { - Unaryop::UAdd => "__pos__", - Unaryop::USub => "__neg__", - Unaryop::Not => "__not__", - Unaryop::Invert => "__inv__", - } +/// The variant of a binary operator. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BinOpVariant { + /// The normal variant. + /// For addition, it would be `+`. + Normal, + /// The "Augmented Assigning Operator" variant. + /// For addition, it would be `+=`. + AugAssign, } -#[must_use] -pub fn comparison_name(op: Cmpop) -> Option<&'static str> { - match op { - Cmpop::Lt => Some("__lt__"), - Cmpop::LtE => Some("__le__"), - Cmpop::Gt => Some("__gt__"), - Cmpop::GtE => Some("__ge__"), - Cmpop::Eq => Some("__eq__"), - Cmpop::NotEq => Some("__ne__"), - _ => None, +impl OpInfo { + #[must_use] + pub fn from_binop(op: Operator, variant: BinOpVariant) -> Self { + // Helper macro to generate both the normal variant [`OpInfo`] and the + // augmented assigning variant [`OpInfo`] for a binary operator conveniently. + macro_rules! info { + ($name:literal, $symbol:literal) => { + (make_info!($name, $symbol), make_info!(concat!("i", $name), concat!($symbol, "="))) + }; + } + + let (normal_variant, aug_assign_variant) = match op { + Operator::Add => info!("add", "+"), + Operator::Sub => info!("sub", "-"), + Operator::Div => info!("truediv", "/"), + Operator::Mod => info!("mod", "%"), + Operator::Mult => info!("mul", "*"), + Operator::Pow => info!("pow", "**"), + Operator::BitOr => info!("or", "|"), + Operator::BitXor => info!("xor", "^"), + Operator::BitAnd => info!("and", "&"), + Operator::LShift => info!("lshift", "<<"), + Operator::RShift => info!("rshift", ">>"), + Operator::FloorDiv => info!("floordiv", "//"), + Operator::MatMult => info!("matmul", "@"), + }; + + match variant { + BinOpVariant::Normal => normal_variant, + BinOpVariant::AugAssign => aug_assign_variant, + } + } + + #[must_use] + pub fn from_unaryop(op: Unaryop) -> Self { + match op { + Unaryop::UAdd => make_info!("pos", "+"), + Unaryop::USub => make_info!("neg", "-"), + Unaryop::Not => make_info!("not", "not"), // i.e., `not False`, so the symbol is just `not`. + Unaryop::Invert => make_info!("inv", "~"), + } + } + + #[must_use] + pub fn from_cmpop(op: Cmpop) -> Option { + match op { + Cmpop::Lt => Some(make_info!("lt", "<")), + Cmpop::LtE => Some(make_info!("le", "<=")), + Cmpop::Gt => Some(make_info!("gt", ">")), + Cmpop::GtE => Some(make_info!("ge", ">=")), + Cmpop::Eq => Some(make_info!("eq", "==")), + Cmpop::NotEq => Some(make_info!("ne", "!=")), + _ => None, + } } } @@ -115,23 +144,8 @@ pub fn impl_binop( let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); - for op in ops { - fields.insert(binop_name(*op).into(), { - ( - unifier.add_ty(TypeEnum::TFunc(FunSignature { - ret: ret_ty, - vars: function_vars.clone(), - args: vec![FuncArg { - ty: other_ty, - default_value: None, - name: "other".into(), - }], - })), - false, - ) - }); - - fields.insert(binop_assign_name(*op).into(), { + for (op, variant) in iproduct!(ops, [BinOpVariant::Normal, BinOpVariant::AugAssign]) { + fields.insert(OpInfo::from_binop(*op, variant).method_name.into(), { ( unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ret_ty, @@ -155,7 +169,7 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option, ops: for op in ops { fields.insert( - unaryop_name(*op).into(), + OpInfo::from_unaryop(*op).method_name.into(), ( unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ret_ty, @@ -195,7 +209,7 @@ pub fn impl_cmpop( for op in ops { fields.insert( - comparison_name(*op).unwrap().into(), + OpInfo::from_cmpop(*op).unwrap().method_name.into(), ( unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ret_ty, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index af2fd8d..7d0eee6 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -466,7 +466,8 @@ impl<'a> Fold<()> for Inferencer<'a> { (None, None) => {} }, ast::StmtKind::AugAssign { target, op, value, .. } => { - let res_ty = self.infer_bin_ops(stmt.location, target, *op, value, true)?; + let res_ty = + self.infer_bin_ops(stmt.location, target, *op, value, BinOpVariant::AugAssign)?; self.unify(res_ty, target.custom.unwrap(), &stmt.location)?; } ast::StmtKind::Assert { test, msg, .. } => { @@ -548,7 +549,7 @@ impl<'a> Fold<()> for Inferencer<'a> { } ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ExprKind::BinOp { left, op, right } => { - Some(self.infer_bin_ops(expr.location, left, *op, right, false)?) + Some(self.infer_bin_ops(expr.location, left, *op, right, BinOpVariant::Normal)?) } ExprKind::UnaryOp { op, operand } => { Some(self.infer_unary_ops(expr.location, *op, operand)?) @@ -1670,7 +1671,7 @@ impl<'a> Inferencer<'a> { left: &ast::Expr>, op: ast::Operator, right: &ast::Expr>, - is_aug_assign: bool, + variant: BinOpVariant, ) -> InferenceResult { let left_ty = left.custom.unwrap(); let right_ty = right.custom.unwrap(); @@ -1678,27 +1679,32 @@ impl<'a> Inferencer<'a> { let method = if let TypeEnum::TObj { fields, .. } = self.unifier.get_ty_immutable(left_ty).as_ref() { - let (binop_name, binop_assign_name) = - (binop_name(op).into(), binop_assign_name(op).into()); + let normal_method_name = OpInfo::from_binop(op, BinOpVariant::Normal).method_name; + let assign_method_name = OpInfo::from_binop(op, BinOpVariant::AugAssign).method_name; + // if is aug_assign, try aug_assign operator first - if is_aug_assign && fields.contains_key(&binop_assign_name) { - binop_assign_name + if variant == BinOpVariant::AugAssign && fields.contains_key(&assign_method_name.into()) + { + assign_method_name } else { - binop_name + normal_method_name } } else { - binop_name(op).into() + OpInfo::from_binop(op, variant).method_name }; - let ret = if is_aug_assign { - // The type of augmented assignment operator should never change - Some(left_ty) - } else { - typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty) - .map_err(|e| HashSet::from([format!("{e} (at {location})")]))? + let ret = match variant { + BinOpVariant::Normal => { + typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty) + .map_err(|e| HashSet::from([format!("{e} (at {location})")]))? + } + BinOpVariant::AugAssign => { + // The type of augmented assignment operator should never change + Some(left_ty) + } }; - self.build_method_call(location, method, left_ty, vec![right_ty], ret) + self.build_method_call(location, method.into(), left_ty, vec![right_ty], ret) } fn infer_unary_ops( @@ -1707,7 +1713,7 @@ impl<'a> Inferencer<'a> { op: ast::Unaryop, operand: &ast::Expr>, ) -> InferenceResult { - let method = unaryop_name(op).into(); + let method = OpInfo::from_unaryop(op).method_name.into(); let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap()) .map_err(|e| HashSet::from([format!("{e} (at {location})")]))?; @@ -1737,8 +1743,9 @@ impl<'a> Inferencer<'a> { let mut res = None; for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { - let method = comparison_name(*c) + let method = OpInfo::from_cmpop(*c) .ok_or_else(|| HashSet::from(["unsupported comparator".to_string()]))? + .method_name .into(); let ret = typeof_cmpop(