diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 9f7398dec4..3b8b85d5de 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -27,7 +27,7 @@ use crate::{ DefinitionId, TopLevelDef, }, typecheck::{ - magic_methods::{binop_assign_name, binop_name, unaryop_name}, + magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, }, }; @@ -1165,10 +1165,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, left: (&Option, BasicValueEnum<'ctx>), - op: Operator, + op: Binop, right: (&Option, BasicValueEnum<'ctx>), loc: Location, - is_aug_assign: bool, ) -> Result>, String> { let (left_ty, left_val) = left; let (right_ty, right_val) = right; @@ -1180,17 +1179,17 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( // which would be unchanged until further unification, which we would never do // when doing code generation for function instances if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, true).into())) + Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, true).into())) } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { - Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, false).into())) - } else if [Operator::LShift, Operator::RShift].contains(&op) { + Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, false).into())) + } else if [Operator::LShift, Operator::RShift].contains(&op.base) { let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1); - Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, signed).into())) + Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed).into())) } else if ty1 == ty2 && ctx.primitives.float == ty1 { - Ok(Some(ctx.gen_float_ops(op, left_val, right_val).into())) + Ok(Some(ctx.gen_float_ops(op.base, left_val, right_val).into())) } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { // Pow is the only operator that would pass typecheck between float and int - assert_eq!(op, Operator::Pow); + assert_eq!(op.base, Operator::Pow); let res = call_float_powi( ctx, left_val.into_float_value(), @@ -1379,13 +1378,16 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let right_val = NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); - let res = if op == Operator::MatMult { + let res = if op.base == Operator::MatMult { // MatMult is the only binop which is not an elementwise op numpy::ndarray_matmul_2d( generator, ctx, ndarray_dtype1, - if is_aug_assign { Some(left_val) } else { None }, + match op.variant { + BinopVariant::Normal => None, + BinopVariant::AugAssign => Some(left_val), + }, left_val, right_val, )? @@ -1394,7 +1396,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( generator, ctx, ndarray_dtype1, - if is_aug_assign { Some(left_val) } else { None }, + match op.variant { + BinopVariant::Normal => None, + BinopVariant::AugAssign => Some(left_val), + }, (left_val.as_base_value().into(), false), (right_val.as_base_value().into(), false), |generator, ctx, (lhs, rhs)| { @@ -1405,7 +1410,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op, (&Some(ndarray_dtype2), rhs), ctx.current_loc, - is_aug_assign, )? .unwrap() .to_basic_value_enum( @@ -1430,7 +1434,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( generator, ctx, ndarray_dtype, - if is_aug_assign { Some(ndarray_val) } else { None }, + match op.variant { + BinopVariant::Normal => None, + BinopVariant::AugAssign => Some(ndarray_val), + }, (left_val, !is_ndarray1), (right_val, !is_ndarray2), |generator, ctx, (lhs, rhs)| { @@ -1441,7 +1448,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op, (&Some(ndarray_dtype), rhs), ctx.current_loc, - is_aug_assign, )? .unwrap() .to_basic_value_enum(ctx, generator, ndarray_dtype) @@ -1456,13 +1462,16 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( unreachable!("must be tobj") }; let (op_name, id) = { - let (binop_name, binop_assign_name) = - (binop_name(op).into(), binop_assign_name(op).into()); + let normal_method_name = Binop::normal(op.base).op_info().method_name; + let assign_method_name = Binop::aug_assign(op.base).op_info().method_name; + // if is aug_assign, try aug_assign operator first - if is_aug_assign && fields.contains_key(&binop_assign_name) { - (binop_assign_name, *obj_id) + if op.variant == BinopVariant::AugAssign + && fields.contains_key(&assign_method_name.into()) + { + (assign_method_name.into(), *obj_id) } else { - (binop_name, *obj_id) + (normal_method_name.into(), *obj_id) } }; @@ -1509,10 +1518,9 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, left: &Expr>, - op: Operator, + op: Binop, right: &Expr>, loc: Location, - is_aug_assign: bool, ) -> Result>, String> { let left_val = if let Some(v) = generator.gen_expr(ctx, left)? { v.to_basic_value_enum(ctx, generator, left.custom.unwrap())? @@ -1532,7 +1540,6 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( op, (&right.custom, right_val), loc, - is_aug_assign, ) } @@ -1616,7 +1623,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( if op == ast::Unaryop::Invert { ast::Unaryop::Not } else { - unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op)) + unreachable!( + "ufunc {} not supported for ndarray[bool, N]", + op.op_info().method_name, + ) } } else { op @@ -2698,7 +2708,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } } ExprKind::BinOp { op, left, right } => { - return gen_binop_expr(generator, ctx, left, *op, right, expr.location, false); + return gen_binop_expr(generator, ctx, left, Binop::normal(*op), right, expr.location); } ExprKind::UnaryOp { op, operand } => return gen_unaryop_expr(generator, ctx, *op, operand), ExprKind::Compare { left, ops, comparators } => { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index fee016af58..7421c89466 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -11,8 +11,7 @@ use crate::{ call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, }, - llvm_intrinsics, - llvm_intrinsics::call_memcpy_generic, + llvm_intrinsics::{self, call_memcpy_generic}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, CodeGenContext, CodeGenerator, }, @@ -22,7 +21,10 @@ use crate::{ numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, - typecheck::typedef::{FunSignature, Type, TypeEnum}, + typecheck::{ + magic_methods::Binop, + typedef::{FunSignature, Type, TypeEnum}, + }, }; use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; use inkwell::{ @@ -1679,10 +1681,9 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( generator, ctx, (&Some(elem_ty), a), - Operator::Mult, + Binop::normal(Operator::Mult), (&Some(elem_ty), b), ctx.current_loc, - false, )? .unwrap() .to_basic_value_enum(ctx, generator, elem_ty)?; @@ -1692,10 +1693,9 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( generator, ctx, (&Some(elem_ty), result), - Operator::Add, + Binop::normal(Operator::Add), (&Some(elem_ty), a_mul_b), ctx.current_loc, - false, )? .unwrap() .to_basic_value_enum(ctx, generator, elem_ty)?; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 23ba4352e2..cf16d3e5c2 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -11,7 +11,10 @@ use crate::{ gen_in_range_check, }, toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, - typecheck::typedef::{FunSignature, Type, TypeEnum}, + typecheck::{ + magic_methods::Binop, + typedef::{FunSignature, Type, TypeEnum}, + }, }; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -1593,7 +1596,14 @@ pub fn gen_stmt( StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::AugAssign { target, op, value, .. } => { - let value = gen_binop_expr(generator, ctx, target, *op, value, stmt.location, true)?; + let value = gen_binop_expr( + generator, + ctx, + target, + Binop::aug_assign(*op), + value, + stmt.location, + )?; generator.gen_assign(ctx, target, value.unwrap())?; } StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 7f6b2f3591..929abb315e 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -5,7 +5,7 @@ use crate::typecheck::{ type_inferencer::*, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }; -use itertools::Itertools; +use itertools::{iproduct, Itertools}; use nac3parser::ast::StrRef; use nac3parser::ast::{Cmpop, Operator, Unaryop}; use std::cmp::max; @@ -13,67 +13,138 @@ use std::collections::HashMap; use std::rc::Rc; use strum::IntoEnumIterator; -#[must_use] -pub fn binop_name(op: Operator) -> &'static str { - match op { - Operator::Add => "__add__", - Operator::Sub => "__sub__", - Operator::Div => "__truediv__", - Operator::Mod => "__mod__", - Operator::Mult => "__mul__", - Operator::Pow => "__pow__", - Operator::BitOr => "__or__", - Operator::BitXor => "__xor__", - Operator::BitAnd => "__and__", - Operator::LShift => "__lshift__", - Operator::RShift => "__rshift__", - Operator::FloorDiv => "__floordiv__", - Operator::MatMult => "__matmul__", +/// The variant of a binary operator. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BinopVariant { + /// The normal variant. + /// For addition, it would be `+`. + Normal, + /// The "Augmented Assigning Operator" variant. + /// For addition, it would be `+=`. + AugAssign, +} + +/// A binary operator with its variant. +#[derive(Debug, Clone, Copy)] +pub struct Binop { + /// The base [`Operator`] of this binary operator. + pub base: Operator, + /// The variant of this binary operator. + pub variant: BinopVariant, +} + +impl Binop { + /// Make a [`Binop`] of the normal variant from an [`Operator`]. + #[must_use] + pub fn normal(base: Operator) -> Self { + Binop { base, variant: BinopVariant::Normal } + } + + /// Make a [`Binop`] of the aug assign variant from an [`Operator`]. + #[must_use] + pub fn aug_assign(base: Operator) -> Self { + Binop { base, variant: BinopVariant::AugAssign } } } -#[must_use] -pub fn binop_assign_name(op: Operator) -> &'static str { - match op { - Operator::Add => "__iadd__", - Operator::Sub => "__isub__", - Operator::Div => "__itruediv__", - Operator::Mod => "__imod__", - Operator::Mult => "__imul__", - Operator::Pow => "__ipow__", - Operator::BitOr => "__ior__", - Operator::BitXor => "__ixor__", - Operator::BitAnd => "__iand__", - Operator::LShift => "__ilshift__", - Operator::RShift => "__irshift__", - Operator::FloorDiv => "__ifloordiv__", - Operator::MatMult => "__imatmul__", - } +/// Details about an operator (unary, binary, etc...) in Python +#[derive(Debug, Clone, Copy)] +pub struct OpInfo { + /// The method name of the binary operator. + /// For addition, this would be `__add__`, and `__iadd__` if + /// it is the augmented assigning variant. + pub method_name: &'static str, + /// The symbol of the binary operator. + /// For addition, this would be `+`, and `+=` if + /// it is the augmented assigning variant. + pub symbol: &'static str, } -#[must_use] -pub fn unaryop_name(op: Unaryop) -> &'static str { - match op { - Unaryop::UAdd => "__pos__", - Unaryop::USub => "__neg__", - Unaryop::Not => "__not__", - Unaryop::Invert => "__inv__", - } +/// Helper macro to conveniently build an [`OpInfo`]. +/// +/// Example usage: `make_info("add", "+")` generates `OpInfo { name: "__add__", symbol: "+" }` +macro_rules! make_op_info { + ($name:expr, $symbol:expr) => { + OpInfo { method_name: concat!("__", $name, "__"), symbol: $symbol } + }; } -#[must_use] -pub fn comparison_name(op: Cmpop) -> Option<&'static str> { +pub trait HasOpInfo { + fn op_info(&self) -> OpInfo; +} + +fn try_get_cmpop_info(op: Cmpop) -> Option { match op { - Cmpop::Lt => Some("__lt__"), - Cmpop::LtE => Some("__le__"), - Cmpop::Gt => Some("__gt__"), - Cmpop::GtE => Some("__ge__"), - Cmpop::Eq => Some("__eq__"), - Cmpop::NotEq => Some("__ne__"), + Cmpop::Lt => Some(make_op_info!("lt", "<")), + Cmpop::LtE => Some(make_op_info!("le", "<=")), + Cmpop::Gt => Some(make_op_info!("gt", ">")), + Cmpop::GtE => Some(make_op_info!("ge", ">=")), + Cmpop::Eq => Some(make_op_info!("eq", "==")), + Cmpop::NotEq => Some(make_op_info!("ne", "!=")), _ => None, } } +impl OpInfo { + #[must_use] + pub fn supports_cmpop(op: Cmpop) -> bool { + try_get_cmpop_info(op).is_some() + } +} + +impl HasOpInfo for Cmpop { + fn op_info(&self) -> OpInfo { + try_get_cmpop_info(*self).expect("{self:?} is not supported") + } +} + +impl HasOpInfo for Binop { + fn op_info(&self) -> OpInfo { + // Helper macro to generate both the normal variant [`OpInfo`] and the + // augmented assigning variant [`OpInfo`] for a binary operator conveniently. + macro_rules! info { + ($name:literal, $symbol:literal) => { + ( + make_op_info!($name, $symbol), + make_op_info!(concat!("i", $name), concat!($symbol, "=")), + ) + }; + } + + let (normal_variant, aug_assign_variant) = match self.base { + Operator::Add => info!("add", "+"), + Operator::Sub => info!("sub", "-"), + Operator::Div => info!("truediv", "/"), + Operator::Mod => info!("mod", "%"), + Operator::Mult => info!("mul", "*"), + Operator::Pow => info!("pow", "**"), + Operator::BitOr => info!("or", "|"), + Operator::BitXor => info!("xor", "^"), + Operator::BitAnd => info!("and", "&"), + Operator::LShift => info!("lshift", "<<"), + Operator::RShift => info!("rshift", ">>"), + Operator::FloorDiv => info!("floordiv", "//"), + Operator::MatMult => info!("matmul", "@"), + }; + + match self.variant { + BinopVariant::Normal => normal_variant, + BinopVariant::AugAssign => aug_assign_variant, + } + } +} + +impl HasOpInfo for Unaryop { + fn op_info(&self) -> OpInfo { + match self { + Unaryop::UAdd => make_op_info!("pos", "+"), + Unaryop::USub => make_op_info!("neg", "-"), + Unaryop::Not => make_op_info!("not", "not"), // i.e., `not False`, so the symbol is just `not`. + Unaryop::Invert => make_op_info!("inv", "~"), + } + } +} + pub(super) fn with_fields(unifier: &mut Unifier, ty: Type, f: F) where F: FnOnce(&mut Unifier, &mut HashMap), @@ -115,23 +186,9 @@ pub fn impl_binop( let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); - for op in ops { - fields.insert(binop_name(*op).into(), { - ( - unifier.add_ty(TypeEnum::TFunc(FunSignature { - ret: ret_ty, - vars: function_vars.clone(), - args: vec![FuncArg { - ty: other_ty, - default_value: None, - name: "other".into(), - }], - })), - false, - ) - }); - - fields.insert(binop_assign_name(*op).into(), { + for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) { + let op = Binop { base: *base_op, variant }; + fields.insert(op.op_info().method_name.into(), { ( unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ret_ty, @@ -155,7 +212,7 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option, ops: for op in ops { fields.insert( - unaryop_name(*op).into(), + op.op_info().method_name.into(), ( unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ret_ty, @@ -195,7 +252,7 @@ pub fn impl_cmpop( for op in ops { fields.insert( - comparison_name(*op).unwrap().into(), + op.op_info().method_name.into(), ( unifier.add_ty(TypeEnum::TFunc(FunSignature { ret: ret_ty, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index d959c1ab21..8dc48fb71a 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -491,7 +491,8 @@ impl<'a> Fold<()> for Inferencer<'a> { (None, None) => {} }, ast::StmtKind::AugAssign { target, op, value, .. } => { - let res_ty = self.infer_bin_ops(stmt.location, target, *op, value, true)?; + let res_ty = + self.infer_bin_ops(stmt.location, target, Binop::aug_assign(*op), value)?; self.unify(res_ty, target.custom.unwrap(), &stmt.location)?; } ast::StmtKind::Assert { test, msg, .. } => { @@ -573,7 +574,7 @@ impl<'a> Fold<()> for Inferencer<'a> { } ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ExprKind::BinOp { left, op, right } => { - Some(self.infer_bin_ops(expr.location, left, *op, right, false)?) + Some(self.infer_bin_ops(expr.location, left, Binop::normal(*op), right)?) } ExprKind::UnaryOp { op, operand } => { Some(self.infer_unary_ops(expr.location, *op, operand)?) @@ -1729,9 +1730,8 @@ impl<'a> Inferencer<'a> { &mut self, location: Location, left: &ast::Expr>, - op: ast::Operator, + op: Binop, right: &ast::Expr>, - is_aug_assign: bool, ) -> InferenceResult { let left_ty = left.custom.unwrap(); let right_ty = right.custom.unwrap(); @@ -1739,27 +1739,33 @@ impl<'a> Inferencer<'a> { let method = if let TypeEnum::TObj { fields, .. } = self.unifier.get_ty_immutable(left_ty).as_ref() { - let (binop_name, binop_assign_name) = - (binop_name(op).into(), binop_assign_name(op).into()); + let normal_method_name = Binop::normal(op.base).op_info().method_name; + let assign_method_name = Binop::aug_assign(op.base).op_info().method_name; + // if is aug_assign, try aug_assign operator first - if is_aug_assign && fields.contains_key(&binop_assign_name) { - binop_assign_name + if op.variant == BinopVariant::AugAssign + && fields.contains_key(&assign_method_name.into()) + { + assign_method_name } else { - binop_name + normal_method_name } } else { - binop_name(op).into() + op.op_info().method_name }; - let ret = if is_aug_assign { - // The type of augmented assignment operator should never change - Some(left_ty) - } else { - typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty) - .map_err(|e| HashSet::from([format!("{e} (at {location})")]))? + let ret = match op.variant { + BinopVariant::Normal => { + typeof_binop(self.unifier, self.primitives, op.base, left_ty, right_ty) + .map_err(|e| HashSet::from([format!("{e} (at {location})")]))? + } + BinopVariant::AugAssign => { + // The type of augmented assignment operator should never change + Some(left_ty) + } }; - self.build_method_call(location, method, left_ty, vec![right_ty], ret) + self.build_method_call(location, method.into(), left_ty, vec![right_ty], ret) } fn infer_unary_ops( @@ -1768,7 +1774,7 @@ impl<'a> Inferencer<'a> { op: ast::Unaryop, operand: &ast::Expr>, ) -> InferenceResult { - let method = unaryop_name(op).into(); + let method = op.op_info().method_name.into(); let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap()) .map_err(|e| HashSet::from([format!("{e} (at {location})")]))?; @@ -1798,9 +1804,11 @@ impl<'a> Inferencer<'a> { let mut res = None; for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { - let method = comparison_name(*c) - .ok_or_else(|| HashSet::from(["unsupported comparator".to_string()]))? - .into(); + if !OpInfo::supports_cmpop(*c) { + return Err(HashSet::from(["unsupported comparator".to_string()])); + } + + let method = c.op_info().method_name.into(); let ret = typeof_cmpop( self.unifier,