diff --git a/nac3core/src/top_level.rs b/nac3core/src/top_level.rs index 6b573ec7..06507437 100644 --- a/nac3core/src/top_level.rs +++ b/nac3core/src/top_level.rs @@ -4,9 +4,9 @@ use std::{collections::HashMap, collections::HashSet, sync::Arc}; use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier}; -use crate::symbol_resolver::SymbolResolver; +use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Mapping}; use crate::typecheck::typedef::{FunSignature, FuncArg}; -use itertools::{Itertools, chain}; +use itertools::Itertools; use parking_lot::{Mutex, RwLock}; use rustpython_parser::ast::{self, Stmt}; @@ -154,7 +154,7 @@ impl TopLevelComposer { top_level_def_list.into_iter().zip(ast_list).collect_vec() ).into(), primitives: primitives.0, - unifier: primitives.1.into(), + unifier: primitives.1, class_method_to_def_id: Default::default(), to_be_analyzed_class: Default::default(), }; @@ -252,22 +252,11 @@ impl TopLevelComposer { // move the ast to the entry of the class in the ast_list class_def_ast.1 = Some(ast); - // put methods into the class def - { - let mut class_def = class_def_ast.0.write(); - let class_def_methods = - if let TopLevelDef::Class { methods, .. } = class_def.deref_mut() { - methods - } else { unimplemented!() }; - for (name, _, id) in &class_method_name_def_ids { - class_def_methods.push((name.into(), self.primitives.none, *id)); - } - } - // now class_def_ast and class_method_def_ast_ids are ok, put them into actual def list in correct order def_list.push(class_def_ast); - for (_, def, _) in class_method_name_def_ids { + for (name, def, id) in class_method_name_def_ids { def_list.push((def, None)); + self.class_method_to_def_id.insert(name, id); } // put the constructor into the def_list @@ -280,8 +269,7 @@ impl TopLevelComposer { )); // class, put its def_id into the to be analyzed set - let to_be_analyzed = &mut self.to_be_analyzed_class; - to_be_analyzed.push(DefinitionId(class_def_id)); + self.to_be_analyzed_class.push(DefinitionId(class_def_id)); Ok((class_name, DefinitionId(class_def_id))) } @@ -461,38 +449,50 @@ impl TopLevelComposer { return Err("expect concrete class/type to be base class".into()); }; - // write to the class ancestors - class_ancestors.push(base_id); + // write to the class ancestors, make sure the uniqueness + if !class_ancestors.contains(&base_id) { + class_ancestors.push(base_id); + } else { + return Err("cannot specify the same base class twice".into()) + } } } Ok(()) } /// step 3, class fields and methods + // FIXME: need analyze base classes here + // FIXME: how to deal with self type + // FIXME: how to prevent cycles fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> { - let mut def_list = self.definition_ast_list.write(); + let mut def_ast_list = self.definition_ast_list.write(); let converted_top_level = &self.to_top_level_context(); let primitives = &self.primitives; let to_be_analyzed_class = &mut self.to_be_analyzed_class; let unifier = &mut self.unifier; + + 'class: loop{ + if to_be_analyzed_class.is_empty() { break; } - while !to_be_analyzed_class.is_empty() { let class_ind = to_be_analyzed_class.remove(0).0; - let (class_name, class_body, classs_def) = { - let class_ast = def_list[class_ind].1.as_ref(); + let (class_name, class_body, class_resolver) = { + let (class_def, class_ast) = &mut def_ast_list[class_ind]; if let Some(ast::Located { node: ast::StmtKind::ClassDef { name, body, .. }, .. - }) = class_ast + }) = class_ast.as_ref() { - let class_def = def_list[class_ind].0; - (name, body, class_def) + if let TopLevelDef::Class { resolver, .. } = class_def.write().deref() { + (name, body, resolver.as_ref().unwrap().clone()) + } else { unreachable!() } } else { unreachable!("should be class def ast") } }; - let class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![]; - let class_fields_parsing_result: Vec<(String, Type)> = vec![]; + // need these vectors to check re-defining methods, class fields + // and store the parsed result in case some method cannot be typed for now + let mut class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![]; + let mut class_fields_parsing_result: Vec<(String, Type)> = vec![]; for b in class_body { if let ast::StmtKind::FunctionDef { args: method_args_ast, @@ -502,181 +502,193 @@ impl TopLevelComposer { .. } = &b.node { - let (class_def, method_def) = { - // unwrap should not fail - let method_ind = class_method_to_def_id - .get(&Self::name_mangling(class_name.into(), method_name)) - .unwrap() - .0; - - // split the def_list to two parts to get the - // mutable reference to both the method and the class - assert_ne!(method_ind, class_ind); - let min_ind = - (if method_ind > class_ind { class_ind } else { method_ind }) + 1; - let (head_slice, tail_slice) = def_list.split_at_mut(min_ind); - let (new_method_ind, new_class_ind) = ( - if method_ind >= min_ind { method_ind - min_ind } else { method_ind }, - if class_ind >= min_ind { class_ind - min_ind } else { class_ind }, - ); - if new_class_ind == class_ind { - (&mut head_slice[new_class_ind], &mut tail_slice[new_method_ind]) - } else { - (&mut tail_slice[new_class_ind], &mut head_slice[new_method_ind]) - } - }; - let (class_fields, class_methods, class_resolver) = { - if let TopLevelDef::Class { resolver, fields, methods, .. } = - class_def.0.get_mut() - { - (fields, methods, resolver) - } else { - unreachable!("must be class def here") - } - }; - - let arg_tys = method_args_ast - .args - .iter() - .map(|x| -> Result { - if x.node.arg != "self" { - let annotation = x + let arg_name_tys: Vec<(String, Type)> = { + let mut result = vec![]; + for a in &method_args_ast.args { + if a.node.arg != "self" { + let annotation = a .node .annotation .as_ref() .ok_or_else(|| { "type annotation for function parameter is needed".to_string() - })? - .as_ref(); + })?.as_ref(); let ty = - class_resolver.as_ref().unwrap().lock().parse_type_annotation( + class_resolver.as_ref().lock().parse_type_annotation( converted_top_level, unifier.borrow_mut(), primitives, annotation, )?; - Ok(ty) + if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) { + to_be_analyzed_class.push(DefinitionId(class_ind)); + continue 'class; + } + result.push((a.node.arg.to_string(), ty)); } else { // TODO: handle self, how unimplemented!() } + } + result + }; + + let method_type_var = + arg_name_tys + .iter() + .filter_map(|(_, ty)| { + let ty_enum = unifier.get_ty(*ty); + if let TypeEnum::TVar { id, .. } = ty_enum.as_ref() { + Some((*id, *ty)) + } else { None } }) - .collect::, _>>()?; + .collect::>(); - let ret_ty = if method_name != "__init__" { - method_returns_ast - .as_ref() - .map(|x| - class_resolver.as_ref().unwrap().lock().parse_type_annotation( - converted_top_level, - unifier.borrow_mut(), - primitives, - x.as_ref(), + let ret_ty = { + if method_name != "__init__" { + let ty = method_returns_ast + .as_ref() + .map(|x| + class_resolver.as_ref().lock().parse_type_annotation( + converted_top_level, + unifier.borrow_mut(), + primitives, + x.as_ref(), + ) ) - ) - .ok_or_else(|| "return type annotation needed".to_string())?? - } else { - // TODO: self type, how - unimplemented!() + .ok_or_else(|| "return type annotation error".to_string())??; + if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) { + to_be_analyzed_class.push(DefinitionId(class_ind)); + continue 'class; + } else { ty } + } else { + // TODO: __init__ function, self type, how + unimplemented!() + } }; // handle fields - if method_name == "__init__" { - for body in method_body_ast { - match &body.node { - ast::StmtKind::AnnAssign { - target, - annotation, - .. - } if { - if let ast::ExprKind::Attribute { - value, - attr, + let class_field_name_tys: Option> = + if method_name == "__init__" { + let mut result: Vec<(String, Type)> = vec![]; + for body in method_body_ast { + match &body.node { + ast::StmtKind::AnnAssign { + target, + annotation, .. - } = &target.node { - if let ast::ExprKind::Name {id, ..} = &value.node { - id == "self" + } if { + if let ast::ExprKind::Attribute { + value, .. + } = &target.node { + matches!( + &value.node, + ast::ExprKind::Name { id, .. } if id == "self") } else { false } - } else { false } - } => { - // TODO: record this field with its type - }, + } => { + let field_ty = class_resolver.as_ref().lock().parse_type_annotation( + converted_top_level, + unifier.borrow_mut(), + primitives, + annotation.as_ref())?; + if !Self::check_ty_analyzed(field_ty, unifier, to_be_analyzed_class) { + to_be_analyzed_class.push(DefinitionId(class_ind)); + continue 'class; + } else { + result.push(( + if let ast::ExprKind::Attribute { + attr, .. + } = &target.node { + attr.to_string() + } else { unreachable!() }, + field_ty + )) } + }, - // TODO: exclude those without type annotation - ast::StmtKind::Assign { - targets, - .. - } if { - if let ast::ExprKind::Attribute { - value, - attr, - .. - } = &targets[0].node { - if let ast::ExprKind::Name {id, ..} = &value.node { - id == "self" + // exclude those without type annotation + ast::StmtKind::Assign { + targets, .. + } if { + if let ast::ExprKind::Attribute { + value, .. + } = &targets[0].node { + matches!( + &value.node, + ast::ExprKind::Name {id, ..} if id == "self") } else { false } - } else { false } - } => { - unimplemented!() - }, + } => { + return Err("class fields type annotation needed".into()) + }, - // do nothing - _ => { } - } - } + // do nothing + _ => { } + } + }; + Some(result) + } else { None }; + + // current method all type ok, put the current method into the list + if class_methods_parsing_result + .iter() + .any(|(name, _, _)| name == method_name) { + return Err("duplicate method definition".into()) + } else { + class_methods_parsing_result.push(( + method_name.clone(), + unifier.add_ty(TypeEnum::TFunc(FunSignature { + ret: ret_ty, + args: arg_name_tys.into_iter().map(|(name, ty)| { + FuncArg { + name, + ty, + default_value: None + } + }).collect_vec(), + vars: method_type_var + }.into())), + *self.class_method_to_def_id.get(&Self::name_mangling(class_name.clone(), method_name)).unwrap() + )) } - let all_tys_ok = { - let ret_ty_iter = vec![ret_ty]; - let ret_ty_iter = ret_ty_iter.iter(); - let mut all_tys = chain!(arg_tys.iter(), ret_ty_iter); - all_tys.all(|x| { - let type_enum = unifier.get_ty(*x); - match type_enum.as_ref() { - TypeEnum::TObj { obj_id, .. } => { - !to_be_analyzed_class.contains(obj_id) - } - TypeEnum::TVirtual { ty } => { - if let TypeEnum::TObj { obj_id, .. } = - unifier.get_ty(*ty).as_ref() - { - !to_be_analyzed_class.contains(obj_id) - } else { - unreachable!() - } - } - TypeEnum::TVar { .. } => true, - _ => unreachable!(), - } - }) - }; - - if all_tys_ok { - // TODO: put related value to the `class_methods_parsing_result` - unimplemented!() - } else { - to_be_analyzed_class.push(DefinitionId(class_ind)); - // TODO: go to the next WHILE loop - unimplemented!() + // put the fiedlds inside + if let Some(class_field_name_tys) = class_field_name_tys { + assert!(class_fields_parsing_result.is_empty()); + class_fields_parsing_result.extend(class_field_name_tys); } } else { // what should we do with `class A: a = 3`? + // do nothing, continue the for loop to iterate class ast continue; } - } - - // TODO: now it should be confirmed that every + }; + + // now it should be confirmed that every // methods and fields of the class can be correctly typed, put the results - // into the actual def_list and the unifier - } + // into the actual class def method and fields field + let (class_def, _) = &def_ast_list[class_ind]; + let mut class_def = class_def.write(); + if let TopLevelDef::Class { fields, methods, .. } = class_def.deref_mut() { + for (ref n, ref t) in class_fields_parsing_result { + fields.push((n.clone(), *t)); + } + for (n, t, id) in &class_methods_parsing_result { + methods.push((n.clone(), *t, *id)); + } + } else { unreachable!() } + + // change the signature field of the class methods + for (_, ty, id) in &class_methods_parsing_result { + let (method_def, _) = &def_ast_list[id.0]; + let mut method_def = method_def.write(); + if let TopLevelDef::Function { signature, .. } = method_def.deref_mut() { + *signature = *ty; + } + } + }; Ok(()) } - fn analyze_top_level_inheritance(&mut self) -> Result<(), String> { - unimplemented!() - } - fn analyze_top_level_function(&mut self) -> Result<(), String> { unimplemented!() } @@ -684,4 +696,27 @@ impl TopLevelComposer { fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> { unimplemented!() } + + fn check_ty_analyzed(ty: Type, + unifier: &mut Unifier, + to_be_analyzed: &[DefinitionId]) -> bool + { + let type_enum = unifier.get_ty(ty); + match type_enum.as_ref() { + TypeEnum::TObj { obj_id, .. } => { + !to_be_analyzed.contains(obj_id) + } + TypeEnum::TVirtual { ty } => { + if let TypeEnum::TObj { obj_id, .. } = + unifier.get_ty(*ty).as_ref() + { + !to_be_analyzed.contains(obj_id) + } else { + unreachable!() + } + } + TypeEnum::TVar { .. } => true, + _ => unreachable!(), + } + } }