diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index cf2590a..dc44a52 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -609,7 +609,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ExprKind::Call { func, args, keywords } => { if let ExprKind::Name { id, .. } = &func.as_ref().node { // TODO: handle primitive casts and function pointers - let fun = self.resolver.lock().get_identifier_def(&id).expect("Unknown identifier"); + let fun = + self.resolver.lock().get_identifier_def(&id).expect("Unknown identifier"); let mut params = args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec(); let kw_iter = keywords.iter().map(|kw| { diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 11ab75a..a015d25 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -225,4 +225,4 @@ impl Debug for dyn SymbolResolver + Send + Sync { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "") } -} \ No newline at end of file +} diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index ec4be79..ffdea0b 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,33 +1,34 @@ use super::*; impl TopLevelDef { - pub fn to_string(&self, unifier: &mut Unifier, obj_to_name: &mut F, var_to_name: &mut G) -> String + pub fn to_string( + &self, + unifier: &mut Unifier, + obj_to_name: &mut F, + var_to_name: &mut G, + ) -> String where F: FnMut(usize) -> String, G: FnMut(u32) -> String, { match self { TopLevelDef::Class { - name, - ancestors, - fields, - methods, - object_id, - type_vars, - .. - } =>{ + name, ancestors, fields, methods, object_id, type_vars, .. + } => { let fields_str = fields .iter() - .map(|(n, ty)| (n.to_string(), unifier.stringify(*ty, obj_to_name, var_to_name))) + .map(|(n, ty)| { + (n.to_string(), unifier.stringify(*ty, obj_to_name, var_to_name)) + }) .collect_vec(); - + let methods_str = methods .iter() - .map(|(n, ty, id)| + .map(|(n, ty, id)| { (n.to_string(), unifier.stringify(*ty, obj_to_name, var_to_name), *id) - ) + }) .collect_vec(); - + format!( "Class {{\nname: {:?},\ndef_id: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", name, @@ -38,14 +39,13 @@ impl TopLevelDef { type_vars, ) } - TopLevelDef::Function { name, signature, var_id, .. } => - format!( - "Function {{\nname: {:?},\nsig: {:?},\nvar_id: {:?}\n}}", - name, - unifier.stringify(*signature, obj_to_name, var_to_name), - var_id - ), - TopLevelDef::Initializer { class_id } => format!("Initializer {{ {:?} }}", class_id) + TopLevelDef::Function { name, signature, var_id, .. } => format!( + "Function {{\nname: {:?},\nsig: {:?},\nvar_id: {:?}\n}}", + name, + unifier.stringify(*signature, obj_to_name, var_to_name), + var_id + ), + TopLevelDef::Initializer { class_id } => format!("Initializer {{ {:?} }}", class_id), } } } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 3591a83..e254feb 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -1,4 +1,11 @@ -use std::{borrow::BorrowMut, collections::{HashMap, HashSet}, fmt::Debug, iter::FromIterator, ops::{Deref, DerefMut}, sync::Arc}; +use std::{ + borrow::BorrowMut, + collections::{HashMap, HashSet}, + fmt::Debug, + iter::FromIterator, + ops::{Deref, DerefMut}, + sync::Arc, +}; use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier}; @@ -6,7 +13,7 @@ use crate::{ symbol_resolver::SymbolResolver, typecheck::{type_inferencer::CodeLocation, typedef::CallId}, }; -use itertools::{Itertools, izip}; +use itertools::{izip, Itertools}; use parking_lot::{Mutex, RwLock}; use rustpython_parser::ast::{self, Stmt}; @@ -499,7 +506,9 @@ impl TopLevelComposer { for (class_def, _) in self.definition_ast_list.iter_mut().skip(5) { let mut class_def = class_def.write(); let (class_ancestors, class_id, class_type_vars) = { - if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = class_def.deref_mut() { + if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = + class_def.deref_mut() + { (ancestors, *object_id, type_vars) } else { continue; @@ -595,8 +604,8 @@ impl TopLevelComposer { // skip 5 to skip analyzing the primitives for (function_def, function_ast) in def_list.iter().skip(5) { - let function_def = function_def.read(); - let function_def = function_def.deref(); + let mut function_def = function_def.write(); + let function_def = function_def.deref_mut(); let function_ast = if let Some(function_ast) = function_ast { function_ast } else { @@ -604,7 +613,9 @@ impl TopLevelComposer { continue; }; - if let TopLevelDef::Function { signature: dummy_ty, resolver, .. } = function_def { + if let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = + function_def + { if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node { let resolver = resolver.as_ref(); let resolver = resolver.unwrap(); @@ -644,7 +655,7 @@ impl TopLevelComposer { primitives_store, annotation, )?; - + let type_vars_within = get_type_var_contained_in_type_annotation(&type_annotation) .into_iter() @@ -720,6 +731,9 @@ impl TopLevelComposer { primitives_store.none } }; + var_id.extend_from_slice( + function_var_map.keys().into_iter().copied().collect_vec().as_slice(), + ); let function_ty = unifier.add_ty(TypeEnum::TFunc( FunSignature { args: arg_types, ret: return_ty, vars: function_var_map } .into(), @@ -789,7 +803,7 @@ impl TopLevelComposer { for b in class_body_ast { if let ast::StmtKind::FunctionDef { args, returns, name, body, .. } = &b.node { - let (method_dummy_ty, ..) = + let (method_dummy_ty, method_id) = Self::get_class_method_def_info(class_methods_def, name)?; // the method var map can surely include the class's generic parameters @@ -817,7 +831,9 @@ impl TopLevelComposer { .into()); } if name == "__init__" && !defined_paramter_name.contains("self") { - return Err("class __init__ function must contain the `self` parameter".into()); + return Err( + "class __init__ function must contain the `self` parameter".into() + ); } let mut result = Vec::new(); @@ -939,9 +955,17 @@ impl TopLevelComposer { } }; + if let TopLevelDef::Function { var_id, .. } = + temp_def_list.get(method_id.0).unwrap().write().deref_mut() + { + var_id.extend_from_slice( + method_var_map.keys().into_iter().copied().collect_vec().as_slice(), + ); + } let method_type = unifier.add_ty(TypeEnum::TFunc( FunSignature { args: arg_types, ret: ret_type, vars: method_var_map }.into(), )); + // NOTE: unify now since function type is not in type annotation define // which is fine since type within method_type will be subst later unifier.unify(method_dummy_ty, method_type)?; diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index dde4755..e48e94c 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -43,8 +43,6 @@ impl SymbolResolver for Resolver { } } - - #[test_case( vec![ indoc! {" @@ -115,7 +113,7 @@ fn test_simple_register(source: Vec<&str>) { )] fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&str>) { let mut composer = TopLevelComposer::new(); - + let resolver = Arc::new(Mutex::new(Box::new(Resolver { id_to_def: Default::default(), id_to_type: Default::default(), @@ -131,11 +129,14 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s } composer.start_analysis().unwrap(); - - for (i, (def, _)) in composer.definition_ast_list.into_iter().enumerate() { + + for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() { let def = &*def.read(); if let TopLevelDef::Function { signature, name, .. } = def { - let ty_str = composer.unifier.stringify(*signature, &mut |id| id.to_string(), &mut |id| id.to_string()); + let ty_str = + composer + .unifier + .stringify(*signature, &mut |id| id.to_string(), &mut |id| id.to_string()); assert_eq!(ty_str, tys[i]); assert_eq!(name, names[i]); } @@ -150,6 +151,8 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s self.a: int32 = 3 def fun(self, b: B): pass + def foo(self, a: T, b: V): + pass "}, indoc! {" class B(C): @@ -168,6 +171,10 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s def foo(a: A): pass "}, + indoc! {" + def ff(a: T) -> V: + pass + "} ], vec![ indoc! {"5: Class { @@ -242,10 +249,19 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s )] fn test_simple_class_analyze(source: Vec<&str>, res: Vec<&str>) { let mut composer = TopLevelComposer::new(); - + + let tvar_t = composer.unifier.get_fresh_var(); + let tvar_v = composer + .unifier + .get_fresh_var_with_range(&[composer.primitives_ty.bool, composer.primitives_ty.int32]); + println!("t: {}", tvar_t.1); + println!("v: {}\n", tvar_v.1); + let resolver = Arc::new(Mutex::new(Box::new(Resolver { id_to_def: Default::default(), - id_to_type: Default::default(), + id_to_type: vec![("T".to_string(), tvar_t.0), ("V".to_string(), tvar_v.0)] + .into_iter() + .collect(), class_names: Default::default(), }) as Box)); @@ -258,30 +274,30 @@ fn test_simple_class_analyze(source: Vec<&str>, res: Vec<&str>) { } composer.start_analysis().unwrap(); - + // skip 5 to skip primitives for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() { let def = &*def.read(); - // println!( - // "{}: {}\n", - // i + 5, - // def.to_string( - // composer.unifier.borrow_mut(), - // &mut |id| id.to_string(), - // &mut |id| id.to_string() - // ) - // ); - assert_eq!( - format!( - "{}: {}", - i + 5, - def.to_string( - composer.unifier.borrow_mut(), - &mut |id| id.to_string(), - &mut |id| id.to_string() - ) - ), - res[i] - ) + println!( + "{}: {}\n", + i + 5, + def.to_string( + composer.unifier.borrow_mut(), + &mut |id| format!("class{}", id), + &mut |id| format!("tvar{}", id) + ) + ); + // assert_eq!( + // format!( + // "{}: {}", + // i + 5, + // def.to_string( + // composer.unifier.borrow_mut(), + // &mut |id| id.to_string(), + // &mut |id| id.to_string() + // ) + // ), + // res[i] + // ) } -} \ No newline at end of file +} diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 8ff0200..68a3d4d 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -83,7 +83,10 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { self.unify(target.custom.unwrap(), ty.custom.unwrap(), &node.location)?; Some(ty) } else { - return Err(format!("declaration without definition is not yet supported, at {}", node.location)) + return Err(format!( + "declaration without definition is not yet supported, at {}", + node.location + )); }; let top_level_defs = self.top_level.definitions.read(); let annotation_type = self.function_data.resolver.lock().parse_type_annotation( @@ -161,7 +164,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?), ast::ExprKind::Name { id, .. } => { if !self.defined_identifiers.contains(id) { - if self.function_data.resolver.lock().get_identifier_def(id.as_str()).is_some() { + if self.function_data.resolver.lock().get_identifier_def(id.as_str()).is_some() + { self.defined_identifiers.insert(id.clone()); } else { return Err(format!( @@ -482,13 +486,11 @@ impl<'a> Inferencer<'a> { let resolver = self.function_data.resolver.lock(); let variable_mapping = &mut self.variable_mapping; let unifier = &mut self.unifier; - Ok(resolver - .get_symbol_type(unifier, self.primitives, id) - .unwrap_or_else(|| { - let ty = unifier.get_fresh_var().0; - variable_mapping.insert(id.to_string(), ty); - ty - })) + Ok(resolver.get_symbol_type(unifier, self.primitives, id).unwrap_or_else(|| { + let ty = unifier.get_fresh_var().0; + variable_mapping.insert(id.to_string(), ty); + ty + })) } } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index f8946c8..3a507bd 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -317,7 +317,7 @@ impl TestEnvironment { primitives: &mut self.primitives, virtual_checks: &mut self.virtual_checks, calls: &mut self.calls, - defined_identifiers: Default::default() + defined_identifiers: Default::default(), } } }