From fb92b6d364b5a6a4c5bb31bbbf46ad173b46789e Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sat, 23 Oct 2021 23:53:36 +0800 Subject: [PATCH] nac3core: supports range iterator --- nac3core/src/codegen/generator.rs | 11 ++ nac3core/src/codegen/mod.rs | 1 + nac3core/src/codegen/stmt.rs | 109 +++++++++++++++++- nac3core/src/toplevel/composer.rs | 91 ++++++++++++++- nac3core/src/toplevel/helper.rs | 7 +- ...el__test__test_analyze__generic_class.snap | 12 +- ...t__test_analyze__inheritance_override.snap | 18 +-- ...est__test_analyze__list_tuple_generic.snap | 14 +-- ...__toplevel__test__test_analyze__self1.snap | 14 +-- ...t__test_analyze__simple_class_compose.snap | 22 ++-- ...t__test_analyze__simple_pass_in_class.snap | 2 +- nac3core/src/typecheck/type_inferencer/mod.rs | 17 ++- .../src/typecheck/type_inferencer/test.rs | 41 ++++--- 13 files changed, 295 insertions(+), 64 deletions(-) diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index 023c7476..53049795 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -104,6 +104,17 @@ pub trait CodeGenerator { false } + /// Generate code for a while expression. + /// Return true if the while loop must early return + fn gen_for<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, + ) -> bool { + gen_for(self, ctx, stmt); + false + } + /// Generate code for an if expression. /// Return true if the statement must early return fn gen_if<'ctx, 'a>( diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 5db7c6f6..17396417 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -290,6 +290,7 @@ pub fn gen_func<'ctx, G: CodeGenerator + ?Sized>( float: unifier.get_representative(primitives.float), bool: unifier.get_representative(primitives.bool), none: unifier.get_representative(primitives.none), + range: unifier.get_representative(primitives.range), }; let mut type_cache: HashMap<_, _> = [ diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 806472b9..bdb2ebab 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -96,6 +96,112 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator + ?Sized>( } } +pub fn gen_for<'ctx, 'a, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + stmt: &Stmt>, +) { + if let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node { + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let test_bb = ctx.ctx.append_basic_block(current, "test"); + let body_bb = ctx.ctx.append_basic_block(current, "body"); + let cont_bb = ctx.ctx.append_basic_block(current, "cont"); + // if there is no orelse, we just go to cont_bb + let orelse_bb = + if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "orelse") }; + // store loop bb information and restore it later + let loop_bb = ctx.loop_bb.replace((test_bb, cont_bb)); + + let iter_val = generator.gen_expr(ctx, iter).unwrap(); + if ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range) { + // setup + let iter_val = iter_val.into_pointer_value(); + let i = generator.gen_store_target(ctx, target); + let int32 = ctx.ctx.i32_type(); + let start; + let end; + let step; + unsafe { + start = ctx + .builder + .build_load( + ctx.builder.build_in_bounds_gep( + iter_val, + &[int32.const_zero(), int32.const_int(0, false)], + "start_ptr", + ), + "start", + ) + .into_int_value(); + end = ctx + .builder + .build_load( + ctx.builder.build_in_bounds_gep( + iter_val, + &[int32.const_zero(), int32.const_int(1, false)], + "end_ptr", + ), + "end", + ) + .into_int_value(); + step = ctx + .builder + .build_load( + ctx.builder.build_in_bounds_gep( + iter_val, + &[int32.const_zero(), int32.const_int(2, false)], + "step_ptr", + ), + "step", + ) + .into_int_value(); + } + ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); + ctx.builder.build_unconditional_branch(test_bb); + ctx.builder.position_at_end(test_bb); + let sign = ctx.builder.build_int_compare( + inkwell::IntPredicate::SGT, + step, + int32.const_zero(), + "sign", + ); + // add and test + let tmp = ctx.builder.build_int_add(ctx.builder.build_load(i, "i").into_int_value(), step, "start_loop"); + ctx.builder.build_store(i, tmp); + // // if step > 0, continue when i < end + let cmp1 = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, end, "cmp1"); + // if step < 0, continue when i > end + let cmp2 = ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, tmp, end, "cmp2"); + let pos = ctx.builder.build_and(sign, cmp1, "pos"); + let neg = ctx.builder.build_and(ctx.builder.build_not(sign, "inv"), cmp2, "neg"); + ctx.builder.build_conditional_branch( + ctx.builder.build_or(pos, neg, "or"), + body_bb, + orelse_bb, + ); + } else { + unimplemented!() + } + + ctx.builder.position_at_end(body_bb); + for stmt in body.iter() { + generator.gen_stmt(ctx, stmt); + } + ctx.builder.build_unconditional_branch(test_bb); + if !orelse.is_empty() { + ctx.builder.position_at_end(orelse_bb); + for stmt in orelse.iter() { + generator.gen_stmt(ctx, stmt); + } + ctx.builder.build_unconditional_branch(cont_bb); + } + ctx.builder.position_at_end(cont_bb); + ctx.loop_bb = loop_bb; + } else { + unreachable!() + } +} + pub fn gen_while<'ctx, 'a, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, @@ -244,7 +350,8 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>( } StmtKind::If { .. } => return generator.gen_if(ctx, stmt), StmtKind::While { .. } => return generator.gen_while(ctx, stmt), - _ => unimplemented!() + StmtKind::For { .. } => return generator.gen_for(ctx, stmt), + _ => unimplemented!(), }; false } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index d471e44b..c6d3688b 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -2,7 +2,10 @@ use std::cell::RefCell; use rustpython_parser::ast::fold::Fold; -use crate::typecheck::type_inferencer::{FunctionData, Inferencer}; +use crate::{ + symbol_resolver::SymbolValue, + typecheck::type_inferencer::{FunctionData, Inferencer}, +}; use super::*; @@ -42,6 +45,7 @@ impl TopLevelComposer { let int64 = primitives.0.int64; let float = primitives.0.float; let boolean = primitives.0.bool; + let range = primitives.0.range; let num_ty = primitives.1.get_fresh_var_with_range(&[int32, int64, float, boolean]); let var_map: HashMap<_, _> = vec![(num_ty.1, num_ty.0)].into_iter().collect(); @@ -67,6 +71,12 @@ impl TopLevelComposer { ))), Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool".into(), None))), Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none".into(), None))), + Arc::new(RwLock::new(Self::make_top_level_class_def( + 5, + None, + "range".into(), + None, + ))), Arc::new(RwLock::new(TopLevelDef::Function { name: "int32".into(), simple_name: "int32".into(), @@ -287,6 +297,81 @@ impl TopLevelComposer { ) })))), })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "range".into(), + simple_name: "range".into(), + signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { + args: vec![ + FuncArg { name: "start".into(), ty: int32, default_value: None }, + FuncArg { + name: "stop".into(), + ty: int32, + // placeholder + default_value: Some(SymbolValue::I32(0)), + }, + FuncArg { + name: "step".into(), + ty: int32, + default_value: Some(SymbolValue::I32(1)), + }, + ], + ret: range, + vars: Default::default(), + }))), + var_id: Default::default(), + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, _, args| { + let mut start = None; + let mut stop = None; + let mut step = None; + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + for (i, arg) in args.iter().enumerate() { + if arg.0 == Some("start".into()) { + start = Some(arg.1); + } else if arg.0 == Some("stop".into()) { + stop = Some(arg.1); + } else if arg.0 == Some("step".into()) { + step = Some(arg.1); + } else if i == 0 { + start = Some(arg.1); + } else if i == 1 { + stop = Some(arg.1); + } else if i == 2 { + step = Some(arg.1); + } + } + // TODO: error when step == 0 + let step = step.unwrap_or_else(|| int32.const_int(1, false).into()); + let stop = stop.unwrap_or_else(|| { + let v = start.unwrap(); + start = None; + v + }); + let start = start.unwrap_or_else(|| int32.const_zero().into()); + let ty = int32.array_type(3); + let ptr = ctx.builder.build_alloca(ty, "range"); + unsafe { + let a = ctx.builder.build_in_bounds_gep(ptr, &[zero, zero], "start"); + let b = ctx.builder.build_in_bounds_gep( + ptr, + &[zero, int32.const_int(1, false)], + "end", + ); + let c = ctx.builder.build_in_bounds_gep( + ptr, + &[zero, int32.const_int(2, false)], + "step", + ); + ctx.builder.build_store(a, start); + ctx.builder.build_store(b, stop); + ctx.builder.build_store(c, step); + } + Some(ptr.into()) + })))), + })), ]; let ast_list: Vec>> = (0..top_level_def_list.len()).map(|_| None).collect(); @@ -315,7 +400,9 @@ impl TopLevelComposer { let mut built_in_id: HashMap = Default::default(); let mut built_in_ty: HashMap = Default::default(); - for (id, name) in ["int32", "int64", "float", "round", "round64"].iter().rev().enumerate() { + for (id, name) in + ["int32", "int64", "float", "round", "round64", "range"].iter().rev().enumerate() + { let name = (**name).into(); let id = definition_ast_list.len() - id - 1; let def = definition_ast_list[id].0.read(); diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index afd802cf..b3d76004 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -82,7 +82,12 @@ impl TopLevelComposer { fields: HashMap::new().into(), params: HashMap::new().into(), }); - let primitives = PrimitiveStore { int32, int64, float, bool, none }; + let range = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(5), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let primitives = PrimitiveStore { int32, int64, float, bool, none, range }; crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); (primitives, unifier) } diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 1ce3a7bd..063c296f 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -4,10 +4,10 @@ expression: res_vec --- [ - "10: Class {\nname: \"Generic_A\",\ndef_id: DefinitionId(10),\nancestors: [CustomClassKind { id: DefinitionId(10), params: [TypeVarKind(UnificationKey(107))] }, CustomClassKind { id: DefinitionId(13), params: [] }],\nfields: [(\"aa\", \"class3\"), (\"a\", \"class1\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(11)), (\"foo\", \"fn[[b=tvar3], class4]\", DefinitionId(15)), (\"fun\", \"fn[[a=class0], tvar4]\", DefinitionId(12))],\ntype_vars: [UnificationKey(107)]\n}\n", - "11: Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: [4]\n}\n", - "12: Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a=class0], tvar4]\",\nvar_id: [4]\n}\n", - "13: Class {\nname: \"B\",\ndef_id: DefinitionId(13),\nancestors: [CustomClassKind { id: DefinitionId(13), params: [] }],\nfields: [(\"aa\", \"class3\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(14)), (\"foo\", \"fn[[b=tvar3], class4]\", DefinitionId(15))],\ntype_vars: []\n}\n", - "14: Function {\nname: \"B.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", - "15: Function {\nname: \"B.foo\",\nsig: \"fn[[b=tvar3], class4]\",\nvar_id: [3]\n}\n", + "12: Class {\nname: \"Generic_A\",\ndef_id: DefinitionId(12),\nancestors: [CustomClassKind { id: DefinitionId(12), params: [TypeVarKind(UnificationKey(109))] }, CustomClassKind { id: DefinitionId(15), params: [] }],\nfields: [(\"aa\", \"class3\"), (\"a\", \"class1\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(13)), (\"foo\", \"fn[[b=tvar3], class4]\", DefinitionId(17)), (\"fun\", \"fn[[a=class0], tvar4]\", DefinitionId(14))],\ntype_vars: [UnificationKey(109)]\n}\n", + "13: Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: [4]\n}\n", + "14: Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a=class0], tvar4]\",\nvar_id: [4]\n}\n", + "15: Class {\nname: \"B\",\ndef_id: DefinitionId(15),\nancestors: [CustomClassKind { id: DefinitionId(15), params: [] }],\nfields: [(\"aa\", \"class3\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(16)), (\"foo\", \"fn[[b=tvar3], class4]\", DefinitionId(17))],\ntype_vars: []\n}\n", + "16: Function {\nname: \"B.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", + "17: Function {\nname: \"B.foo\",\nsig: \"fn[[b=tvar3], class4]\",\nvar_id: [3]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index ca43daed..6fdb6a72 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -4,13 +4,13 @@ expression: res_vec --- [ - "10: Class {\nname: \"A\",\ndef_id: DefinitionId(10),\nancestors: [CustomClassKind { id: DefinitionId(10), params: [TypeVarKind(UnificationKey(106))] }],\nfields: [(\"a\", \"class0\"), (\"b\", \"tvar3\"), (\"c\", \"class10[3->class1]\")],\nmethods: [(\"__init__\", \"fn[[t=tvar3], class4]\", DefinitionId(11)), (\"fun\", \"fn[[a=class0, b=tvar3], list[virtual[class14[4->class3]]]]\", DefinitionId(12)), (\"foo\", \"fn[[c=class17], class4]\", DefinitionId(13))],\ntype_vars: [UnificationKey(106)]\n}\n", - "11: Function {\nname: \"A.__init__\",\nsig: \"fn[[t=tvar3], class4]\",\nvar_id: [3]\n}\n", - "12: Function {\nname: \"A.fun\",\nsig: \"fn[[a=class0, b=tvar3], list[virtual[class14[4->class3]]]]\",\nvar_id: [3]\n}\n", - "13: Function {\nname: \"A.foo\",\nsig: \"fn[[c=class17], class4]\",\nvar_id: [3]\n}\n", - "14: Class {\nname: \"B\",\ndef_id: DefinitionId(14),\nancestors: [CustomClassKind { id: DefinitionId(14), params: [TypeVarKind(UnificationKey(107))] }, CustomClassKind { id: DefinitionId(10), params: [PrimitiveKind(UnificationKey(2))] }],\nfields: [(\"a\", \"class0\"), (\"b\", \"tvar3\"), (\"c\", \"class10[3->class1]\"), (\"d\", \"class17\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(15)), (\"fun\", \"fn[[a=class0, b=tvar3], list[virtual[class14[4->class3]]]]\", DefinitionId(16)), (\"foo\", \"fn[[c=class17], class4]\", DefinitionId(13))],\ntype_vars: [UnificationKey(107)]\n}\n", - "15: Function {\nname: \"B.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: [4]\n}\n", - "16: Function {\nname: \"B.fun\",\nsig: \"fn[[a=class0, b=tvar3], list[virtual[class14[4->class3]]]]\",\nvar_id: [3, 4]\n}\n", - "17: Class {\nname: \"C\",\ndef_id: DefinitionId(17),\nancestors: [CustomClassKind { id: DefinitionId(17), params: [] }, CustomClassKind { id: DefinitionId(14), params: [PrimitiveKind(UnificationKey(3))] }, CustomClassKind { id: DefinitionId(10), params: [PrimitiveKind(UnificationKey(2))] }],\nfields: [(\"a\", \"class0\"), (\"b\", \"tvar3\"), (\"c\", \"class10[3->class1]\"), (\"d\", \"class17\"), (\"e\", \"class1\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(18)), (\"fun\", \"fn[[a=class0, b=tvar3], list[virtual[class14[4->class3]]]]\", DefinitionId(16)), (\"foo\", \"fn[[c=class17], class4]\", DefinitionId(13))],\ntype_vars: []\n}\n", - "18: Function {\nname: \"C.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", + "12: Class {\nname: \"A\",\ndef_id: DefinitionId(12),\nancestors: [CustomClassKind { id: DefinitionId(12), params: [TypeVarKind(UnificationKey(108))] }],\nfields: [(\"a\", \"class0\"), (\"b\", \"tvar3\"), (\"c\", \"class12[3->class1]\")],\nmethods: [(\"__init__\", \"fn[[t=tvar3], class4]\", DefinitionId(13)), (\"fun\", \"fn[[a=class0, b=tvar3], list[virtual[class16[4->class3]]]]\", DefinitionId(14)), (\"foo\", \"fn[[c=class19], class4]\", DefinitionId(15))],\ntype_vars: [UnificationKey(108)]\n}\n", + "13: Function {\nname: \"A.__init__\",\nsig: \"fn[[t=tvar3], class4]\",\nvar_id: [3]\n}\n", + "14: Function {\nname: \"A.fun\",\nsig: \"fn[[a=class0, b=tvar3], list[virtual[class16[4->class3]]]]\",\nvar_id: [3]\n}\n", + "15: Function {\nname: \"A.foo\",\nsig: \"fn[[c=class19], class4]\",\nvar_id: [3]\n}\n", + "16: Class {\nname: \"B\",\ndef_id: DefinitionId(16),\nancestors: [CustomClassKind { id: DefinitionId(16), params: [TypeVarKind(UnificationKey(109))] }, CustomClassKind { id: DefinitionId(12), params: [PrimitiveKind(UnificationKey(2))] }],\nfields: [(\"a\", \"class0\"), (\"b\", \"tvar3\"), (\"c\", \"class12[3->class1]\"), (\"d\", \"class19\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(17)), (\"fun\", \"fn[[a=class0, b=tvar3], list[virtual[class16[4->class3]]]]\", DefinitionId(18)), (\"foo\", \"fn[[c=class19], class4]\", DefinitionId(15))],\ntype_vars: [UnificationKey(109)]\n}\n", + "17: Function {\nname: \"B.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: [4]\n}\n", + "18: Function {\nname: \"B.fun\",\nsig: \"fn[[a=class0, b=tvar3], list[virtual[class16[4->class3]]]]\",\nvar_id: [3, 4]\n}\n", + "19: Class {\nname: \"C\",\ndef_id: DefinitionId(19),\nancestors: [CustomClassKind { id: DefinitionId(19), params: [] }, CustomClassKind { id: DefinitionId(16), params: [PrimitiveKind(UnificationKey(3))] }, CustomClassKind { id: DefinitionId(12), params: [PrimitiveKind(UnificationKey(2))] }],\nfields: [(\"a\", \"class0\"), (\"b\", \"tvar3\"), (\"c\", \"class12[3->class1]\"), (\"d\", \"class19\"), (\"e\", \"class1\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(20)), (\"fun\", \"fn[[a=class0, b=tvar3], list[virtual[class16[4->class3]]]]\", DefinitionId(18)), (\"foo\", \"fn[[c=class19], class4]\", DefinitionId(15))],\ntype_vars: []\n}\n", + "20: Function {\nname: \"C.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 510136c2..48d49a91 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -4,11 +4,11 @@ expression: res_vec --- [ - "10: Function {\nname: \"foo\",\nsig: \"fn[[a=list[class0], b=tuple[tvar3, class2]], class11[3->class15, 4->class3]]\",\nvar_id: [3]\n}\n", - "11: Class {\nname: \"A\",\ndef_id: DefinitionId(11),\nancestors: [CustomClassKind { id: DefinitionId(11), params: [TypeVarKind(UnificationKey(106)), TypeVarKind(UnificationKey(107))] }],\nfields: [(\"a\", \"tvar3\"), (\"b\", \"tvar4\")],\nmethods: [(\"__init__\", \"fn[[v=tvar4], class4]\", DefinitionId(12)), (\"fun\", \"fn[[a=tvar3], tvar4]\", DefinitionId(13))],\ntype_vars: [UnificationKey(106), UnificationKey(107)]\n}\n", - "12: Function {\nname: \"A.__init__\",\nsig: \"fn[[v=tvar4], class4]\",\nvar_id: [3, 4]\n}\n", - "13: Function {\nname: \"A.fun\",\nsig: \"fn[[a=tvar3], tvar4]\",\nvar_id: [3, 4]\n}\n", - "14: Function {\nname: \"gfun\",\nsig: \"fn[[a=class11[3->list[class2], 4->class0]], class4]\",\nvar_id: []\n}\n", - "15: Class {\nname: \"B\",\ndef_id: DefinitionId(15),\nancestors: [CustomClassKind { id: DefinitionId(15), params: [] }],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(16))],\ntype_vars: []\n}\n", - "16: Function {\nname: \"B.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", + "12: Function {\nname: \"foo\",\nsig: \"fn[[a=list[class0], b=tuple[tvar3, class2]], class13[3->class17, 4->class3]]\",\nvar_id: [3]\n}\n", + "13: Class {\nname: \"A\",\ndef_id: DefinitionId(13),\nancestors: [CustomClassKind { id: DefinitionId(13), params: [TypeVarKind(UnificationKey(108)), TypeVarKind(UnificationKey(109))] }],\nfields: [(\"a\", \"tvar3\"), (\"b\", \"tvar4\")],\nmethods: [(\"__init__\", \"fn[[v=tvar4], class4]\", DefinitionId(14)), (\"fun\", \"fn[[a=tvar3], tvar4]\", DefinitionId(15))],\ntype_vars: [UnificationKey(108), UnificationKey(109)]\n}\n", + "14: Function {\nname: \"A.__init__\",\nsig: \"fn[[v=tvar4], class4]\",\nvar_id: [3, 4]\n}\n", + "15: Function {\nname: \"A.fun\",\nsig: \"fn[[a=tvar3], tvar4]\",\nvar_id: [3, 4]\n}\n", + "16: Function {\nname: \"gfun\",\nsig: \"fn[[a=class13[3->list[class2], 4->class0]], class4]\",\nvar_id: []\n}\n", + "17: Class {\nname: \"B\",\ndef_id: DefinitionId(17),\nancestors: [CustomClassKind { id: DefinitionId(17), params: [] }],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(18))],\ntype_vars: []\n}\n", + "18: Function {\nname: \"B.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 91ebc409..60832b44 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -4,11 +4,11 @@ expression: res_vec --- [ - "10: Class {\nname: \"A\",\ndef_id: DefinitionId(10),\nancestors: [CustomClassKind { id: DefinitionId(10), params: [TypeVarKind(UnificationKey(106)), TypeVarKind(UnificationKey(107))] }],\nfields: [(\"a\", \"class10[3->class2, 4->class3]\"), (\"b\", \"class13\")],\nmethods: [(\"__init__\", \"fn[[a=class10[3->class2, 4->class3], b=class13], class4]\", DefinitionId(11)), (\"fun\", \"fn[[a=class10[3->class2, 4->class3]], class10[3->class3, 4->class0]]\", DefinitionId(12))],\ntype_vars: [UnificationKey(106), UnificationKey(107)]\n}\n", - "11: Function {\nname: \"A.__init__\",\nsig: \"fn[[a=class10[3->class2, 4->class3], b=class13], class4]\",\nvar_id: [3, 4]\n}\n", - "12: Function {\nname: \"A.fun\",\nsig: \"fn[[a=class10[3->class2, 4->class3]], class10[3->class3, 4->class0]]\",\nvar_id: [3, 4]\n}\n", - "13: Class {\nname: \"B\",\ndef_id: DefinitionId(13),\nancestors: [CustomClassKind { id: DefinitionId(13), params: [] }, CustomClassKind { id: DefinitionId(10), params: [PrimitiveKind(UnificationKey(1)), PrimitiveKind(UnificationKey(3))] }],\nfields: [(\"a\", \"class10[3->class2, 4->class3]\"), (\"b\", \"class13\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(14)), (\"fun\", \"fn[[a=class10[3->class2, 4->class3]], class10[3->class3, 4->class0]]\", DefinitionId(12)), (\"foo\", \"fn[[b=class13], class13]\", DefinitionId(15)), (\"bar\", \"fn[[a=class10[3->list[class13], 4->class0]], tuple[class10[3->virtual[class10[3->class13, 4->class0]], 4->class3], class13]]\", DefinitionId(16))],\ntype_vars: []\n}\n", - "14: Function {\nname: \"B.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", - "15: Function {\nname: \"B.foo\",\nsig: \"fn[[b=class13], class13]\",\nvar_id: []\n}\n", - "16: Function {\nname: \"B.bar\",\nsig: \"fn[[a=class10[3->list[class13], 4->class0]], tuple[class10[3->virtual[class10[3->class13, 4->class0]], 4->class3], class13]]\",\nvar_id: []\n}\n", + "12: Class {\nname: \"A\",\ndef_id: DefinitionId(12),\nancestors: [CustomClassKind { id: DefinitionId(12), params: [TypeVarKind(UnificationKey(108)), TypeVarKind(UnificationKey(109))] }],\nfields: [(\"a\", \"class12[3->class2, 4->class3]\"), (\"b\", \"class15\")],\nmethods: [(\"__init__\", \"fn[[a=class12[3->class2, 4->class3], b=class15], class4]\", DefinitionId(13)), (\"fun\", \"fn[[a=class12[3->class2, 4->class3]], class12[3->class3, 4->class0]]\", DefinitionId(14))],\ntype_vars: [UnificationKey(108), UnificationKey(109)]\n}\n", + "13: Function {\nname: \"A.__init__\",\nsig: \"fn[[a=class12[3->class2, 4->class3], b=class15], class4]\",\nvar_id: [3, 4]\n}\n", + "14: Function {\nname: \"A.fun\",\nsig: \"fn[[a=class12[3->class2, 4->class3]], class12[3->class3, 4->class0]]\",\nvar_id: [3, 4]\n}\n", + "15: Class {\nname: \"B\",\ndef_id: DefinitionId(15),\nancestors: [CustomClassKind { id: DefinitionId(15), params: [] }, CustomClassKind { id: DefinitionId(12), params: [PrimitiveKind(UnificationKey(1)), PrimitiveKind(UnificationKey(3))] }],\nfields: [(\"a\", \"class12[3->class2, 4->class3]\"), (\"b\", \"class15\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(16)), (\"fun\", \"fn[[a=class12[3->class2, 4->class3]], class12[3->class3, 4->class0]]\", DefinitionId(14)), (\"foo\", \"fn[[b=class15], class15]\", DefinitionId(17)), (\"bar\", \"fn[[a=class12[3->list[class15], 4->class0]], tuple[class12[3->virtual[class12[3->class15, 4->class0]], 4->class3], class15]]\", DefinitionId(18))],\ntype_vars: []\n}\n", + "16: Function {\nname: \"B.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", + "17: Function {\nname: \"B.foo\",\nsig: \"fn[[b=class15], class15]\",\nvar_id: []\n}\n", + "18: Function {\nname: \"B.bar\",\nsig: \"fn[[a=class12[3->list[class15], 4->class0]], tuple[class12[3->virtual[class12[3->class15, 4->class0]], 4->class3], class15]]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index ae9f1fb9..a10d819c 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -4,15 +4,15 @@ expression: res_vec --- [ - "10: Class {\nname: \"A\",\ndef_id: DefinitionId(10),\nancestors: [CustomClassKind { id: DefinitionId(10), params: [] }],\nfields: [(\"a\", \"class0\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(11)), (\"fun\", \"fn[[b=class14], class4]\", DefinitionId(12)), (\"foo\", \"fn[[a=tvar3, b=tvar4], class4]\", DefinitionId(13))],\ntype_vars: []\n}\n", - "11: Function {\nname: \"A.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", - "12: Function {\nname: \"A.fun\",\nsig: \"fn[[b=class14], class4]\",\nvar_id: []\n}\n", - "13: Function {\nname: \"A.foo\",\nsig: \"fn[[a=tvar3, b=tvar4], class4]\",\nvar_id: [3, 4]\n}\n", - "14: Class {\nname: \"B\",\ndef_id: DefinitionId(14),\nancestors: [CustomClassKind { id: DefinitionId(14), params: [] }, CustomClassKind { id: DefinitionId(16), params: [] }, CustomClassKind { id: DefinitionId(10), params: [] }],\nfields: [(\"a\", \"class0\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(15)), (\"fun\", \"fn[[b=class14], class4]\", DefinitionId(18)), (\"foo\", \"fn[[a=tvar3, b=tvar4], class4]\", DefinitionId(13))],\ntype_vars: []\n}\n", - "15: Function {\nname: \"B.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", - "16: Class {\nname: \"C\",\ndef_id: DefinitionId(16),\nancestors: [CustomClassKind { id: DefinitionId(16), params: [] }, CustomClassKind { id: DefinitionId(10), params: [] }],\nfields: [(\"a\", \"class0\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(17)), (\"fun\", \"fn[[b=class14], class4]\", DefinitionId(18)), (\"foo\", \"fn[[a=tvar3, b=tvar4], class4]\", DefinitionId(13))],\ntype_vars: []\n}\n", - "17: Function {\nname: \"C.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", - "18: Function {\nname: \"C.fun\",\nsig: \"fn[[b=class14], class4]\",\nvar_id: []\n}\n", - "19: Function {\nname: \"foo\",\nsig: \"fn[[a=class10], class4]\",\nvar_id: []\n}\n", - "20: Function {\nname: \"ff\",\nsig: \"fn[[a=tvar3], tvar4]\",\nvar_id: [3, 4]\n}\n", + "12: Class {\nname: \"A\",\ndef_id: DefinitionId(12),\nancestors: [CustomClassKind { id: DefinitionId(12), params: [] }],\nfields: [(\"a\", \"class0\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(13)), (\"fun\", \"fn[[b=class16], class4]\", DefinitionId(14)), (\"foo\", \"fn[[a=tvar3, b=tvar4], class4]\", DefinitionId(15))],\ntype_vars: []\n}\n", + "13: Function {\nname: \"A.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", + "14: Function {\nname: \"A.fun\",\nsig: \"fn[[b=class16], class4]\",\nvar_id: []\n}\n", + "15: Function {\nname: \"A.foo\",\nsig: \"fn[[a=tvar3, b=tvar4], class4]\",\nvar_id: [3, 4]\n}\n", + "16: Class {\nname: \"B\",\ndef_id: DefinitionId(16),\nancestors: [CustomClassKind { id: DefinitionId(16), params: [] }, CustomClassKind { id: DefinitionId(18), params: [] }, CustomClassKind { id: DefinitionId(12), params: [] }],\nfields: [(\"a\", \"class0\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(17)), (\"fun\", \"fn[[b=class16], class4]\", DefinitionId(20)), (\"foo\", \"fn[[a=tvar3, b=tvar4], class4]\", DefinitionId(15))],\ntype_vars: []\n}\n", + "17: Function {\nname: \"B.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", + "18: Class {\nname: \"C\",\ndef_id: DefinitionId(18),\nancestors: [CustomClassKind { id: DefinitionId(18), params: [] }, CustomClassKind { id: DefinitionId(12), params: [] }],\nfields: [(\"a\", \"class0\")],\nmethods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(19)), (\"fun\", \"fn[[b=class16], class4]\", DefinitionId(20)), (\"foo\", \"fn[[a=tvar3, b=tvar4], class4]\", DefinitionId(15))],\ntype_vars: []\n}\n", + "19: Function {\nname: \"C.__init__\",\nsig: \"fn[[], class4]\",\nvar_id: []\n}\n", + "20: Function {\nname: \"C.fun\",\nsig: \"fn[[b=class16], class4]\",\nvar_id: []\n}\n", + "21: Function {\nname: \"foo\",\nsig: \"fn[[a=class12], class4]\",\nvar_id: []\n}\n", + "22: Function {\nname: \"ff\",\nsig: \"fn[[a=tvar3], tvar4]\",\nvar_id: [3, 4]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap index 3d444393..606663aa 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap @@ -4,5 +4,5 @@ expression: res_vec --- [ - "10: Class {\nname: \"A\",\ndef_id: DefinitionId(10),\nancestors: [CustomClassKind { id: DefinitionId(10), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}\n", + "12: Class {\nname: \"A\",\ndef_id: DefinitionId(12),\nancestors: [CustomClassKind { id: DefinitionId(12), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index fd4a4e8e..7f42d454 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -35,6 +35,7 @@ pub struct PrimitiveStore { pub float: Type, pub bool: Type, pub none: Type, + pub range: Type, } pub struct FunctionData { @@ -168,8 +169,12 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { }; match &stmt.node { ast::StmtKind::For { target, iter, .. } => { - let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); - self.unify(list, iter.custom.unwrap(), &iter.location)?; + if self.unifier.unioned(iter.custom.unwrap(), self.primitives.range) { + self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?; + } else { + let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); + self.unify(list, iter.custom.unwrap(), &iter.location)?; + } } ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; @@ -445,8 +450,12 @@ impl<'a> Inferencer<'a> { new_context.infer_pattern(&generator.target)?; let target = new_context.fold_expr(*generator.target)?; let iter = new_context.fold_expr(*generator.iter)?; - let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); - new_context.unify(iter.custom.unwrap(), list, &iter.location)?; + if new_context.unifier.unioned(iter.custom.unwrap(), new_context.primitives.range) { + new_context.unify(target.custom.unwrap(), new_context.primitives.int32, &target.location)?; + } else { + let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); + new_context.unify(iter.custom.unwrap(), list, &iter.location)?; + } let ifs: Vec<_> = generator .ifs .into_iter() diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 180313e5..0e096f10 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -97,7 +97,12 @@ impl TestEnvironment { fields: HashMap::new().into(), params: HashMap::new().into(), }); - let primitives = PrimitiveStore { int32, int64, float, bool, none }; + let range = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(5), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let primitives = PrimitiveStore { int32, int64, float, bool, none, range }; set_primitives_magic_methods(&primitives, &mut unifier); let id_to_name = [ @@ -180,8 +185,13 @@ impl TestEnvironment { fields: HashMap::new().into(), params: HashMap::new().into(), }); + let range = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(5), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); identifier_mapping.insert("None".into(), none); - for (i, name) in ["int32", "int64", "float", "bool", "none"].iter().enumerate() { + for (i, name) in ["int32", "int64", "float", "bool", "none", "range"].iter().enumerate() { top_level_defs.push( RwLock::new(TopLevelDef::Class { name: (*name).into(), @@ -197,19 +207,19 @@ impl TestEnvironment { ); } - let primitives = PrimitiveStore { int32, int64, float, bool, none }; + let primitives = PrimitiveStore { int32, int64, float, bool, none, range }; let (v0, id) = unifier.get_fresh_var(); let foo_ty = unifier.add_ty(TypeEnum::TObj { - obj_id: DefinitionId(5), + obj_id: DefinitionId(6), fields: [("a".into(), v0)].iter().cloned().collect::>().into(), params: [(id, v0)].iter().cloned().collect::>().into(), }); top_level_defs.push( RwLock::new(TopLevelDef::Class { name: "Foo".into(), - object_id: DefinitionId(5), + object_id: DefinitionId(6), type_vars: vec![v0], fields: [("a".into(), v0)].into(), methods: Default::default(), @@ -236,7 +246,7 @@ impl TestEnvironment { FunSignature { args: vec![], ret: int32, vars: Default::default() }.into(), )); let bar = unifier.add_ty(TypeEnum::TObj { - obj_id: DefinitionId(6), + obj_id: DefinitionId(7), fields: [("a".into(), int32), ("b".into(), fun)] .iter() .cloned() @@ -247,7 +257,7 @@ impl TestEnvironment { top_level_defs.push( RwLock::new(TopLevelDef::Class { name: "Bar".into(), - object_id: DefinitionId(6), + object_id: DefinitionId(7), type_vars: Default::default(), fields: [("a".into(), int32), ("b".into(), fun)].into(), methods: Default::default(), @@ -265,7 +275,7 @@ impl TestEnvironment { ); let bar2 = unifier.add_ty(TypeEnum::TObj { - obj_id: DefinitionId(7), + obj_id: DefinitionId(8), fields: [("a".into(), bool), ("b".into(), fun)] .iter() .cloned() @@ -276,7 +286,7 @@ impl TestEnvironment { top_level_defs.push( RwLock::new(TopLevelDef::Class { name: "Bar2".into(), - object_id: DefinitionId(7), + object_id: DefinitionId(8), type_vars: Default::default(), fields: [("a".into(), bool), ("b".into(), fun)].into(), methods: Default::default(), @@ -300,9 +310,10 @@ impl TestEnvironment { (2, "float".into()), (3, "bool".into()), (4, "none".into()), - (5, "Foo".into()), - (6, "Bar".into()), - (7, "Bar2".into()), + (5, "range".into()), + (6, "Foo".into()), + (7, "Bar".into()), + (8, "Bar2".into()), ] .iter() .cloned() @@ -317,9 +328,9 @@ impl TestEnvironment { let resolver = Arc::new(Resolver { id_to_type: identifier_mapping.clone(), id_to_def: [ - ("Foo".into(), DefinitionId(5)), - ("Bar".into(), DefinitionId(6)), - ("Bar2".into(), DefinitionId(7)), + ("Foo".into(), DefinitionId(6)), + ("Bar".into(), DefinitionId(7)), + ("Bar2".into(), DefinitionId(8)), ] .iter() .cloned()