diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3ca72835..98972a81 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -10,7 +10,10 @@ use crate::{ }, symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, - typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, + typecheck::{ + typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, + magic_methods::{binop_name, binop_assign_name}, + }, }; use inkwell::{ AddressSpace, @@ -927,21 +930,29 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( left: &Expr>, op: &Operator, right: &Expr>, -) -> Result, String> { + loc: Location, + is_aug_assign: bool, +) -> Result>, String> { let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); - let left = generator.gen_expr(ctx, left)?.unwrap().to_basic_value_enum(ctx, generator, left.custom.unwrap())?; - let right = generator.gen_expr(ctx, right)?.unwrap().to_basic_value_enum(ctx, generator, right.custom.unwrap())?; + let left_val = generator + .gen_expr(ctx, left)? + .unwrap() + .to_basic_value_enum(ctx, generator, left.custom.unwrap())?; + let right_val = generator + .gen_expr(ctx, right)? + .unwrap() + .to_basic_value_enum(ctx, generator, right.custom.unwrap())?; // we can directly compare the types, because we've got their representatives // which would be unchanged until further unification, which we would never do // when doing code generation for function instances - Ok(if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - ctx.gen_int_ops(generator, op, left, right, true) + if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { + Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, true).into())) } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { - ctx.gen_int_ops(generator, op, left, right, false) + Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, false).into())) } else if ty1 == ty2 && ctx.primitives.float == ty1 { - ctx.gen_float_ops(op, left, right) + Ok(Some(ctx.gen_float_ops(op, left_val, right_val).into())) } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { // Pow is the only operator that would pass typecheck between float and int assert!(*op == Operator::Pow); @@ -951,14 +962,68 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( let ty = f64_t.fn_type(&[f64_t.into(), i32_t.into()], false); ctx.module.add_function("llvm.powi.f64.i32", ty, None) }); - ctx.builder - .build_call(pow_intr, &[left.into(), right.into()], "f_pow_i") + let res = ctx.builder + .build_call(pow_intr, &[left_val.into(), right_val.into()], "f_pow_i") .try_as_basic_value() - .unwrap_left() + .unwrap_left(); + Ok(Some(res.into())) } else { - unimplemented!() + let (op_name, id) = if let TypeEnum::TObj { fields, obj_id, .. } = + ctx.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() + { + let (binop_name, binop_assign_name) = ( + binop_name(op).into(), + binop_assign_name(op).into() + ); + // if is aug_assign, try aug_assign operator first + if is_aug_assign && fields.contains_key(&binop_assign_name) { + (binop_assign_name, *obj_id) + } else { + (binop_name, *obj_id) + } + } else { + unreachable!("must be tobj") + }; + let signature = match ctx.calls.get(&loc.into()) { + Some(call) => ctx.unifier.get_call_signature(*call).unwrap(), + None => { + if let TypeEnum::TObj { fields, .. } = + ctx.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() + { + let fn_ty = fields.get(&op_name).unwrap().0; + if let TypeEnum::TFunc(sig) = ctx.unifier.get_ty_immutable(fn_ty).as_ref() { + sig.clone() + } else { + unreachable!("must be func sig") + } + } else { + unreachable!("must be tobj") + } + }, + }; + let fun_id = { + let defs = ctx.top_level.definitions.read(); + let obj_def = defs.get(id.0).unwrap().read(); + if let TopLevelDef::Class { methods, .. } = &*obj_def { + let mut fun_id = None; + for (name, _, id) in methods.iter() { + if name == &op_name { + fun_id = Some(*id); + } + } + fun_id.unwrap() + } else { + unreachable!() + } + }; + generator + .gen_call( + ctx, + Some((left.custom.unwrap(), left_val.into())), + (&signature, fun_id), + vec![(None, right_val.into())], + ).map(|f| f.map(|f| f.into())) } - .into()) } pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( @@ -1125,7 +1190,9 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]); phi.as_basic_value().into() } - ExprKind::BinOp { op, left, right } => gen_binop_expr(generator, ctx, left, op, right)?, + ExprKind::BinOp { op, left, right } => { + return gen_binop_expr(generator, ctx, left, op, right, expr.location, false); + } ExprKind::UnaryOp { op, operand } => { let ty = ctx.unifier.get_representative(operand.custom.unwrap()); let val = diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 9209440d..5c1a7842 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1020,8 +1020,8 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator>( StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::AugAssign { target, op, value, .. } => { - let value = gen_binop_expr(generator, ctx, target, op, value)?; - generator.gen_assign(ctx, target, value)?; + let value = gen_binop_expr(generator, ctx, target, op, value, stmt.location, true)?; + generator.gen_assign(ctx, target, value.unwrap())?; } StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Raise { exc, .. } => { diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 36d46a6d..0a023b19 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -83,7 +83,7 @@ where pub fn impl_binop( unifier: &mut Unifier, - store: &PrimitiveStore, + _store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type, @@ -120,7 +120,7 @@ pub fn impl_binop( fields.insert(binop_assign_name(op).into(), { ( unifier.add_ty(TypeEnum::TFunc(FunSignature { - ret: store.none, + ret: ret_ty, vars: function_vars.clone(), args: vec![FuncArg { ty: other_ty, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 510a0f09..c72e1998 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -423,7 +423,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { (None, None) => {} }, ast::StmtKind::AugAssign { target, op, value, .. } => { - let res_ty = self.infer_bin_ops(stmt.location, target, op, value)?; + let res_ty = self.infer_bin_ops(stmt.location, target, op, value, true)?; self.unify(res_ty, target.custom.unwrap(), &stmt.location)?; } ast::StmtKind::Assert { test, msg, .. } => { @@ -505,7 +505,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ast::ExprKind::BinOp { left, op, right } => { - Some(self.infer_bin_ops(expr.location, left, op, right)?) + Some(self.infer_bin_ops(expr.location, left, op, right, false)?) } ast::ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?), ast::ExprKind::Compare { left, ops, comparators } => { @@ -1028,8 +1028,24 @@ impl<'a> Inferencer<'a> { left: &ast::Expr>, op: &ast::Operator, right: &ast::Expr>, + is_aug_assign: bool, ) -> InferenceResult { - let method = binop_name(op).into(); + let method = if let TypeEnum::TObj { fields, .. } = + self.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() + { + let (binop_name, binop_assign_name) = ( + binop_name(op).into(), + binop_assign_name(op).into() + ); + // if is aug_assign, try aug_assign operator first + if is_aug_assign && fields.contains_key(&binop_assign_name) { + binop_assign_name + } else { + binop_name + } + } else { + binop_name(op).into() + }; self.build_method_call( location, method,