diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 5b31a15dc..4ad363c7f 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -143,7 +143,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } } - fn gen_int_ops( + pub fn gen_int_ops( &mut self, op: &Operator, lhs: BasicValueEnum<'ctx>, @@ -178,7 +178,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } } - fn gen_float_ops( + pub fn gen_float_ops( &mut self, op: &Operator, lhs: BasicValueEnum<'ctx>, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 3688b7b0b..1b00f456d 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -340,6 +340,26 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>( StmtKind::While { .. } => return generator.gen_while(ctx, stmt), StmtKind::For { .. } => return generator.gen_for(ctx, stmt), StmtKind::With { .. } => return generator.gen_with(ctx, stmt), + StmtKind::AugAssign { target, op, value, .. } => { + let value = { + let ty1 = ctx.unifier.get_representative(target.custom.unwrap()); + let ty2 = ctx.unifier.get_representative(value.custom.unwrap()); + let left = generator.gen_expr(ctx, target).unwrap(); + let right = generator.gen_expr(ctx, value).unwrap(); + + // we can directly compare the types, because we've got their representatives + // which would be unchanged until further unification, which we would never do + // when doing code generation for function instances + if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { + ctx.gen_int_ops(op, left, right) + } else if ty1 == ty2 && ctx.primitives.float == ty1 { + ctx.gen_float_ops(op, left, right) + } else { + unimplemented!() + } + }; + generator.gen_assign(ctx, target, value); + } _ => unimplemented!(), }; false diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index c0c5c4043..4ab40e614 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -313,6 +313,10 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } (None, None) => {} }, + ast::StmtKind::AugAssign { target, op, value, .. } => { + let res_ty = self.infer_bin_ops(stmt.location, target, op, value)?; + self.unify(res_ty, target.custom.unwrap(), &stmt.location)?; + } _ => return report_error("Unsupported statement type", stmt.location), }; Ok(stmt)