diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3ca72835..98972a81 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -10,7 +10,10 @@ use crate::{ }, symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, - typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, + typecheck::{ + typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, + magic_methods::{binop_name, binop_assign_name}, + }, }; use inkwell::{ AddressSpace, @@ -927,21 +930,29 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( left: &Expr>, op: &Operator, right: &Expr>, -) -> Result, String> { + loc: Location, + is_aug_assign: bool, +) -> Result>, String> { let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); - let left = generator.gen_expr(ctx, left)?.unwrap().to_basic_value_enum(ctx, generator, left.custom.unwrap())?; - let right = generator.gen_expr(ctx, right)?.unwrap().to_basic_value_enum(ctx, generator, right.custom.unwrap())?; + let left_val = generator + .gen_expr(ctx, left)? + .unwrap() + .to_basic_value_enum(ctx, generator, left.custom.unwrap())?; + let right_val = generator + .gen_expr(ctx, right)? + .unwrap() + .to_basic_value_enum(ctx, generator, right.custom.unwrap())?; // we can directly compare the types, because we've got their representatives // which would be unchanged until further unification, which we would never do // when doing code generation for function instances - Ok(if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - ctx.gen_int_ops(generator, op, left, right, true) + if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { + Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, true).into())) } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { - ctx.gen_int_ops(generator, op, left, right, false) + Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, false).into())) } else if ty1 == ty2 && ctx.primitives.float == ty1 { - ctx.gen_float_ops(op, left, right) + Ok(Some(ctx.gen_float_ops(op, left_val, right_val).into())) } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { // Pow is the only operator that would pass typecheck between float and int assert!(*op == Operator::Pow); @@ -951,14 +962,68 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( let ty = f64_t.fn_type(&[f64_t.into(), i32_t.into()], false); ctx.module.add_function("llvm.powi.f64.i32", ty, None) }); - ctx.builder - .build_call(pow_intr, &[left.into(), right.into()], "f_pow_i") + let res = ctx.builder + .build_call(pow_intr, &[left_val.into(), right_val.into()], "f_pow_i") .try_as_basic_value() - .unwrap_left() + .unwrap_left(); + Ok(Some(res.into())) } else { - unimplemented!() + let (op_name, id) = if let TypeEnum::TObj { fields, obj_id, .. } = + ctx.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() + { + let (binop_name, binop_assign_name) = ( + binop_name(op).into(), + binop_assign_name(op).into() + ); + // if is aug_assign, try aug_assign operator first + if is_aug_assign && fields.contains_key(&binop_assign_name) { + (binop_assign_name, *obj_id) + } else { + (binop_name, *obj_id) + } + } else { + unreachable!("must be tobj") + }; + let signature = match ctx.calls.get(&loc.into()) { + Some(call) => ctx.unifier.get_call_signature(*call).unwrap(), + None => { + if let TypeEnum::TObj { fields, .. } = + ctx.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() + { + let fn_ty = fields.get(&op_name).unwrap().0; + if let TypeEnum::TFunc(sig) = ctx.unifier.get_ty_immutable(fn_ty).as_ref() { + sig.clone() + } else { + unreachable!("must be func sig") + } + } else { + unreachable!("must be tobj") + } + }, + }; + let fun_id = { + let defs = ctx.top_level.definitions.read(); + let obj_def = defs.get(id.0).unwrap().read(); + if let TopLevelDef::Class { methods, .. } = &*obj_def { + let mut fun_id = None; + for (name, _, id) in methods.iter() { + if name == &op_name { + fun_id = Some(*id); + } + } + fun_id.unwrap() + } else { + unreachable!() + } + }; + generator + .gen_call( + ctx, + Some((left.custom.unwrap(), left_val.into())), + (&signature, fun_id), + vec![(None, right_val.into())], + ).map(|f| f.map(|f| f.into())) } - .into()) } pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( @@ -1125,7 +1190,9 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]); phi.as_basic_value().into() } - ExprKind::BinOp { op, left, right } => gen_binop_expr(generator, ctx, left, op, right)?, + ExprKind::BinOp { op, left, right } => { + return gen_binop_expr(generator, ctx, left, op, right, expr.location, false); + } ExprKind::UnaryOp { op, operand } => { let ty = ctx.unifier.get_representative(operand.custom.unwrap()); let val = diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 9209440d..5c1a7842 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1020,8 +1020,8 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator>( StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::AugAssign { target, op, value, .. } => { - let value = gen_binop_expr(generator, ctx, target, op, value)?; - generator.gen_assign(ctx, target, value)?; + let value = gen_binop_expr(generator, ctx, target, op, value, stmt.location, true)?; + generator.gen_assign(ctx, target, value.unwrap())?; } StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Raise { exc, .. } => { diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 36d46a6d..0a023b19 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -83,7 +83,7 @@ where pub fn impl_binop( unifier: &mut Unifier, - store: &PrimitiveStore, + _store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type, @@ -120,7 +120,7 @@ pub fn impl_binop( fields.insert(binop_assign_name(op).into(), { ( unifier.add_ty(TypeEnum::TFunc(FunSignature { - ret: store.none, + ret: ret_ty, vars: function_vars.clone(), args: vec![FuncArg { ty: other_ty, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 510a0f09..c72e1998 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -423,7 +423,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { (None, None) => {} }, ast::StmtKind::AugAssign { target, op, value, .. } => { - let res_ty = self.infer_bin_ops(stmt.location, target, op, value)?; + let res_ty = self.infer_bin_ops(stmt.location, target, op, value, true)?; self.unify(res_ty, target.custom.unwrap(), &stmt.location)?; } ast::StmtKind::Assert { test, msg, .. } => { @@ -505,7 +505,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ast::ExprKind::BinOp { left, op, right } => { - Some(self.infer_bin_ops(expr.location, left, op, right)?) + Some(self.infer_bin_ops(expr.location, left, op, right, false)?) } ast::ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?), ast::ExprKind::Compare { left, ops, comparators } => { @@ -1028,8 +1028,24 @@ impl<'a> Inferencer<'a> { left: &ast::Expr>, op: &ast::Operator, right: &ast::Expr>, + is_aug_assign: bool, ) -> InferenceResult { - let method = binop_name(op).into(); + let method = if let TypeEnum::TObj { fields, .. } = + self.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() + { + let (binop_name, binop_assign_name) = ( + binop_name(op).into(), + binop_assign_name(op).into() + ); + // if is aug_assign, try aug_assign operator first + if is_aug_assign && fields.contains_key(&binop_assign_name) { + binop_assign_name + } else { + binop_name + } + } else { + binop_name(op).into() + }; self.build_method_call( location, method, diff --git a/nac3standalone/demo/src/operators.py b/nac3standalone/demo/src/operators.py new file mode 100644 index 00000000..0470b969 --- /dev/null +++ b/nac3standalone/demo/src/operators.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +@extern +def output_int32(x: int32): + ... +@extern +def output_uint32(x: uint32): + ... +@extern +def output_int64(x: int64): + ... +@extern +def output_uint64(x: uint64): + ... +@extern +def output_float64(x: float): + ... + +def run() -> int32: + test_int32() + test_uint32() + test_int64() + test_uint64() + test_A() + test_B() + return 0 + +def test_int32(): + a = 17 + b = 3 + output_int32(a + b) + output_int32(a - b) + output_int32(a * b) + output_int32(a // b) + output_int32(a % b) + output_int32(a | b) + output_int32(a ^ b) + output_int32(a & b) + output_int32(a << b) + output_int32(a >> b) + output_float64(a / b) + a += b + output_int32(a) + a -= b + output_int32(a) + a *= b + output_int32(a) + a //= b + output_int32(a) + a %= b + output_int32(a) + a |= b + output_int32(a) + a ^= b + output_int32(a) + a &= b + output_int32(a) + a <<= b + output_int32(a) + a >>= b + output_int32(a) + # fail because (a / b) is float + # a /= b + +def test_uint32(): + a = uint32(17) + b = uint32(3) + output_uint32(a + b) + output_uint32(a - b) + output_uint32(a * b) + output_uint32(a // b) + output_uint32(a % b) + output_uint32(a | b) + output_uint32(a ^ b) + output_uint32(a & b) + output_uint32(a << b) + output_uint32(a >> b) + output_float64(a / b) + a += b + output_uint32(a) + a -= b + output_uint32(a) + a *= b + output_uint32(a) + a //= b + output_uint32(a) + a %= b + output_uint32(a) + a |= b + output_uint32(a) + a ^= b + output_uint32(a) + a &= b + output_uint32(a) + a <<= b + output_uint32(a) + a >>= b + output_uint32(a) + +def test_int64(): + a = int64(17) + b = int64(3) + output_int64(a + b) + output_int64(a - b) + output_int64(a * b) + output_int64(a // b) + output_int64(a % b) + output_int64(a | b) + output_int64(a ^ b) + output_int64(a & b) + output_int64(a << b) + output_int64(a >> b) + output_float64(a / b) + a += b + output_int64(a) + a -= b + output_int64(a) + a *= b + output_int64(a) + a //= b + output_int64(a) + a %= b + output_int64(a) + a |= b + output_int64(a) + a ^= b + output_int64(a) + a &= b + output_int64(a) + a <<= b + output_int64(a) + a >>= b + output_int64(a) + +def test_uint64(): + a = uint64(17) + b = uint64(3) + output_uint64(a + b) + output_uint64(a - b) + output_uint64(a * b) + output_uint64(a // b) + output_uint64(a % b) + output_uint64(a | b) + output_uint64(a ^ b) + output_uint64(a & b) + output_uint64(a << b) + output_uint64(a >> b) + output_float64(a / b) + a += b + output_uint64(a) + a -= b + output_uint64(a) + a *= b + output_uint64(a) + a //= b + output_uint64(a) + a %= b + output_uint64(a) + a |= b + output_uint64(a) + a ^= b + output_uint64(a) + a &= b + output_uint64(a) + a <<= b + output_uint64(a) + a >>= b + output_uint64(a) + +class A: + a: int32 + def __init__(self, a: int32): + self.a = a + + def __add__(self, other: A) -> A: + output_int32(self.a + other.a) + return A(self.a + other.a) + + def __sub__(self, other: A) -> A: + output_int32(self.a - other.a) + return A(self.a - other.a) + +def test_A(): + a = A(17) + b = A(3) + + c = a + b + # fail due to alloca in __add__ function + # output_int32(c.a) + + a += b + # fail due to alloca in __add__ function + # output_int32(a.a) + + a = A(17) + b = A(3) + d = a - b + # fail due to alloca in __add__ function + # output_int32(c.a) + + a -= b + # fail due to alloca in __add__ function + # output_int32(a.a) + + a = A(17) + b = A(3) + a.__add__(b) + a.__sub__(b) + + +class B: + a: int32 + def __init__(self, a: int32): + self.a = a + + def __add__(self, other: B) -> B: + output_int32(self.a + other.a) + return B(self.a + other.a) + + def __sub__(self, other: B) -> B: + output_int32(self.a - other.a) + return B(self.a - other.a) + + def __iadd__(self, other: B) -> B: + output_int32(self.a + other.a + 24) + return B(self.a + other.a + 24) + + def __isub__(self, other: B) -> B: + output_int32(self.a - other.a - 24) + return B(self.a - other.a - 24) + +def test_B(): + a = B(17) + b = B(3) + + c = a + b + # fail due to alloca in __add__ function + # output_int32(c.a) + + a += b + # fail due to alloca in __add__ function + # output_int32(a.a) + + a = B(17) + b = B(3) + d = a - b + # fail due to alloca in __add__ function + # output_int32(c.a) + + a -= b + # fail due to alloca in __add__ function + # output_int32(a.a) + + a = B(17) + b = B(3) + a.__add__(b) + a.__sub__(b) \ No newline at end of file diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 37925157..6b624aa6 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -205,6 +205,14 @@ fn main() { continue; } + // still needs to skip this `from __future__ import annotations` because this seems to be + // magic in python and there seems no way to patch it from another module.. + if matches!( + &stmt.node, + StmtKind::ImportFrom { module, names, .. } + if module == &Some("__future__".into()) && names[0].name == "annotations".into() + ) { continue; } + let (name, def_id, ty) = composer.register_top_level(stmt, Some(resolver.clone()), "__main__".into()).unwrap();