From 364054331c39498111bbc09a46ec9cadc7eaddef Mon Sep 17 00:00:00 2001 From: ychenfo Date: Mon, 23 Aug 2021 02:52:54 +0800 Subject: [PATCH] handle class fields and methods --- nac3core/src/symbol_resolver.rs | 2 +- nac3core/src/top_level.rs | 967 +++++++++++++++++++++----------- 2 files changed, 632 insertions(+), 337 deletions(-) diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 91296ebf4..994b9de4e 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -83,7 +83,7 @@ pub fn parse_type_annotation( // it could be a type variable let ty = resolver .get_symbol_type(unifier, primitives, x) - .ok_or_else(|| "Cannot use function name as type".to_owned())?; + .ok_or_else(|| "unknown type variable name".to_owned())?; if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { Ok(ty) } else { diff --git a/nac3core/src/top_level.rs b/nac3core/src/top_level.rs index 5772593f8..5e6a63605 100644 --- a/nac3core/src/top_level.rs +++ b/nac3core/src/top_level.rs @@ -2,17 +2,300 @@ use std::borrow::BorrowMut; use std::ops::{Deref, DerefMut}; use std::{collections::HashMap, collections::HashSet, sync::Arc}; +use self::top_level_type_annotation_info::*; use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier}; use crate::typecheck::{typedef::{FunSignature, FuncArg}}; -use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Mapping}; +use crate::symbol_resolver::SymbolResolver; use itertools::{Itertools, izip}; use parking_lot::{Mutex, RwLock}; use rustpython_parser::ast::{self, Stmt}; -#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] pub struct DefinitionId(pub usize); +pub mod top_level_type_annotation_info { + use super::*; + + #[derive(Clone)] + pub enum TypeAnnotation { + PrimitiveKind(Type), + ConcretizedCustomClassKind { + id: DefinitionId, + // can not be type var, others are all fine + params: Vec + }, + // can only be ConcretizedCustomClassKind + VirtualKind(Box), + TypeVarKind(Type), + SelfTypeKind(DefinitionId), + } + + pub fn parse_ast_to_type_annotation_kinds( + resolver: &dyn SymbolResolver, + top_level_defs: &[Arc>], + unifier: &mut Unifier, + primitives: &PrimitiveStore, + expr: &ast::Expr, + ) -> Result { + let results = vec![ + parse_ast_to_concrete_primitive_kind(resolver, top_level_defs, unifier, primitives, expr), + parse_ast_to_concretized_custom_class_kind(resolver, top_level_defs, unifier, primitives, expr), + parse_ast_to_type_variable_kind(resolver, top_level_defs, unifier, primitives, expr), + parse_ast_to_virtual_kind(resolver, top_level_defs, unifier, primitives, expr) + ]; + let results = results.iter().filter(|x| x.is_ok()).collect_vec(); + + if results.len() == 1 { + results[0].clone() + } else { + Err("cannot be parsed the type annotation without ambiguity".into()) + } + + } + + pub fn get_type_from_type_annotation_kinds( + top_level_defs: &[Arc>], + unifier: &mut Unifier, + primitives: &PrimitiveStore, + ann: &TypeAnnotation + ) -> Result { + match ann { + TypeAnnotation::ConcretizedCustomClassKind { id, params } => { + let class_def = top_level_defs[id.0].read(); + if let TopLevelDef::Class {fields, methods, type_vars, .. } = &*class_def { + if type_vars.len() != params.len() { + Err(format!( + "unexpected number of type parameters: expected {} but got {}", + type_vars.len(), + params.len() + )) + } else { + let param_ty = params + .iter() + .map(|x| get_type_from_type_annotation_kinds( + top_level_defs, + unifier, + primitives, + x + )) + .collect::, _>>()?; + + let subst = type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = unifier.get_ty(*x).as_ref() { + *id + } else { + unreachable!() + } + }) + .zip(param_ty.into_iter()) + .collect::>(); + + let mut tobj_fields = methods + .iter() + .map(|(name, ty, _)| { + let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (name.clone(), subst_ty) + }) + .collect::>(); + + tobj_fields.extend( + fields + .iter() + .map(|(name, ty)| { + let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (name.clone(), subst_ty) + }) + ); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id: *id, + fields: tobj_fields.into(), + params: subst.into() + })) + } + } else { + unreachable!("should be class def here") + } + } + TypeAnnotation::SelfTypeKind(obj_id) => { + let class_def = top_level_defs[obj_id.0].read(); + if let TopLevelDef::Class {fields, methods, type_vars, .. } = &*class_def { + let subst = type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = unifier.get_ty(*x).as_ref() { + (*id, *x) + } else { + unreachable!() + } + }) + .collect::>(); + + let mut tobj_fields = methods + .iter() + .map(|(name, ty, _)| { + (name.clone(), *ty) + }) + .collect::>(); + + tobj_fields.extend(fields.clone().into_iter()); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id: *obj_id, + fields: tobj_fields.into(), + params: subst.into() + })) + } else { + unreachable!("should be class def here") + } + } + TypeAnnotation::PrimitiveKind(ty) => Ok(*ty), + TypeAnnotation::TypeVarKind(ty) => Ok(*ty), + TypeAnnotation::VirtualKind(ty) => { + let ty = get_type_from_type_annotation_kinds( + top_level_defs, + unifier, + primitives, + ty.as_ref() + )?; + Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) + } + } + } + + fn parse_ast_to_concrete_primitive_kind( + _resolver: &dyn SymbolResolver, + _top_level_defs: &[Arc>], + _unifier: &mut Unifier, + primitives: &PrimitiveStore, + expr: &ast::Expr, + ) -> Result { + match &expr.node { + ast::ExprKind::Name { id, .. } => match id.as_str() { + "int32" => Ok(TypeAnnotation::PrimitiveKind(primitives.int32)), + "int64" => Ok(TypeAnnotation::PrimitiveKind(primitives.int64)), + "float" => Ok(TypeAnnotation::PrimitiveKind(primitives.float)), + "bool" => Ok(TypeAnnotation::PrimitiveKind(primitives.bool)), + "None" => Ok(TypeAnnotation::PrimitiveKind(primitives.none)), + _ => Err("not primitive".into()) + } + + _ => Err("not primitive".into()) + } + } + + pub fn parse_ast_to_concretized_custom_class_kind( + resolver: &dyn SymbolResolver, + top_level_defs: &[Arc>], + unifier: &mut Unifier, + primitives: &PrimitiveStore, + expr: &ast::Expr, + ) -> Result { + match &expr.node { + ast::ExprKind::Name { id, .. } => match id.as_str() { + "int32" | "int64" | "float" | "bool" | "None" => + Err("expect custom class instead of primitives here".into()), + x => { + let obj_id = resolver + .get_identifier_def(x) + .ok_or_else(|| "unknown class name".to_string())?; + let def = top_level_defs[obj_id.0].read(); + if let TopLevelDef::Class { .. } = &*def { + Ok(TypeAnnotation::ConcretizedCustomClassKind { id: obj_id, params: vec![]}) + } else { + Err("function cannot be used as a type".into()) + } + } + }, + + ast::ExprKind::Subscript { value, slice, .. } => { + if let ast::ExprKind::Name { id, .. } = &value.node { + if vec!["virtual", "Generic"].contains(&id.as_str()) { return Err("keywords cannot be class name".into()) } + let obj_id = resolver + .get_identifier_def(id) + .ok_or_else(|| "unknown class name".to_string())?; + let def = top_level_defs[obj_id.0].read(); + if let TopLevelDef::Class { .. } = &*def { + let param_type_infos = + if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + elts.iter() + .map(|v| { + parse_ast_to_type_annotation_kinds( + resolver, + top_level_defs, + unifier, + primitives, + v + ) + }) + .collect::, _>>()? + } else { + vec![parse_ast_to_type_annotation_kinds( + resolver, top_level_defs, unifier, primitives, slice, + )?] + }; + if param_type_infos.iter().any(|x| matches!(x, TypeAnnotation::TypeVarKind( .. ))) { + return Err("cannot apply type variable to class generic parameters".into()) + } + Ok(TypeAnnotation::ConcretizedCustomClassKind { id: obj_id, params: param_type_infos }) + } else { + Err("function cannot be used as a type".into()) + } + } else { + Err("unsupported expression type".into()) + } + }, + + _ => Err("unsupported expression type".into()) + } + } + + pub fn parse_ast_to_virtual_kind( + resolver: &dyn SymbolResolver, + top_level_defs: &[Arc>], + unifier: &mut Unifier, + primitives: &PrimitiveStore, + expr: &ast::Expr, + ) -> Result { + match &expr.node { + ast::ExprKind::Subscript { value, slice, .. } + if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "virtual") => { + let def = parse_ast_to_concretized_custom_class_kind( + resolver, + top_level_defs, + unifier, + primitives, + slice.as_ref() + )?; + if !matches!(def, TypeAnnotation::ConcretizedCustomClassKind { .. }) { + unreachable!("should must be concretized custom class kind") + } + Ok(TypeAnnotation::VirtualKind(def.into())) + } + + _ => Err("virtual type annotation must be like `virtual[ .. ]`".into()) + } + } + + pub fn parse_ast_to_type_variable_kind( + resolver: &dyn SymbolResolver, + _top_level_defs: &[Arc>], + unifier: &mut Unifier, + primitives: &PrimitiveStore, + expr: &ast::Expr, + ) -> Result { + if let ast::ExprKind::Name { id, .. } = &expr.node { + let ty = resolver + .get_symbol_type(unifier, primitives, id) + .ok_or_else(|| "unknown type variable name".to_string())?; + Ok(TypeAnnotation::TypeVarKind(ty)) + } else { + Err("unsupported expression for type variable".into()) + } + } +} + pub enum TopLevelDef { Class { // name for error messages and symbols @@ -26,7 +309,7 @@ pub enum TopLevelDef { // class methods, pointing to the corresponding function definition. methods: Vec<(String, Type, DefinitionId)>, // ancestor classes, including itself. - ancestors: Vec, + ancestors: Vec, // symbol resolver of the module defined the class, none if it is built-in type resolver: Option>>, }, @@ -60,27 +343,28 @@ pub struct TopLevelContext { pub unifiers: Arc>>, } -impl TopLevelContext { - pub fn read_top_level_def_list(&self) -> &[Arc>] { - self.definitions.as_slice() - } -} - pub struct TopLevelComposer { // list of top level definitions, same as top level context pub definition_ast_list: Vec<(Arc>, Option>)>, // start as a primitive unifier, will add more top_level defs inside pub unifier: Unifier, // primitive store - pub primitives: PrimitiveStore, + pub primitives_ty: PrimitiveStore, // mangled class method name to def_id // pub class_method_to_def_id: HashMap, // record the def id of the classes whoses fields and methods are to be analyzed // pub to_be_analyzed_class: Vec, + pub keyword_list: Vec, +} + +impl TopLevelContext { + pub fn read_top_level_def_list(&self) -> &[Arc>] { + self.definitions.as_slice() + } } impl TopLevelComposer { - pub fn to_top_level_context(self) -> TopLevelContext { + pub fn make_top_level_context(self) -> TopLevelContext { TopLevelContext { definitions: self .definition_ast_list @@ -127,6 +411,7 @@ impl TopLevelComposer { /// return a composer and things to make a "primitive" symbol resolver, so that the symbol /// resolver can later figure out primitive type definitions when passed a primitive type name + // TODO: add list and tuples? pub fn new() -> (Vec<(String, DefinitionId, Type)>, Self) { let primitives = Self::make_primitives(); @@ -142,18 +427,30 @@ impl TopLevelComposer { let composer = TopLevelComposer { definition_ast_list: izip!(top_level_def_list, ast_list).collect_vec(), - primitives: primitives.0, + primitives_ty: primitives.0, unifier: primitives.1, // class_method_to_def_id: Default::default(), // to_be_analyzed_class: Default::default(), + keyword_list: vec![ + "Generic".into(), + "virtual".into(), + "list".into(), + "tuple".into(), + "int32".into(), + "int64".into(), + "float".into(), + "bool".into(), + "none".into(), + "None".into(), + ] }; ( vec![ - ("int32".into(), DefinitionId(0), composer.primitives.int32), - ("int64".into(), DefinitionId(1), composer.primitives.int64), - ("float".into(), DefinitionId(2), composer.primitives.float), - ("bool".into(), DefinitionId(3), composer.primitives.bool), - ("none".into(), DefinitionId(4), composer.primitives.none), + ("int32".into(), DefinitionId(0), composer.primitives_ty.int32), + ("int64".into(), DefinitionId(1), composer.primitives_ty.int64), + ("float".into(), DefinitionId(2), composer.primitives_ty.float), + ("bool".into(), DefinitionId(3), composer.primitives_ty.bool), + ("none".into(), DefinitionId(4), composer.primitives_ty.none), ], composer, ) @@ -172,7 +469,7 @@ impl TopLevelComposer { type_vars: Default::default(), fields: Default::default(), methods: Default::default(), - ancestors: vec![DefinitionId(index)], + ancestors: vec![TypeAnnotation::SelfTypeKind(DefinitionId(index))], resolver, } } @@ -192,12 +489,7 @@ impl TopLevelComposer { } } - // fn get_class_method_def_id(class_name: &str, method_name: &str, resolver: &dyn SymbolResolver) -> Result { - // let class_def = resolver.get_identifier_def(class_name).ok_or_else(|| "no such class".to_string())?; - - // } - - fn name_mangling(class_name: String, method_name: &str) -> String { + fn make_class_method_name(mut class_name: String, method_name: &str) -> String { class_name.push_str(method_name); class_name } @@ -218,6 +510,10 @@ impl TopLevelComposer { ) -> Result<(String, DefinitionId), String> { match &ast.node { ast::StmtKind::ClassDef { name, body, .. } => { + if self.keyword_list.contains(name) { + return Err("cannot use keyword as a class name".into()) + } + let class_name = name.to_string(); let class_def_id = self.definition_ast_list.len(); @@ -236,7 +532,9 @@ impl TopLevelComposer { // module's symbol resolver would not know the name of the class methods, // thus cannot return their definition_id let mut class_method_name_def_ids: Vec<( + // the simple method name without class name String, + // in this top level def, method name is prefixed with the class name Arc>, DefinitionId, Type @@ -254,7 +552,7 @@ impl TopLevelComposer { class_method_name_def_ids.push(( method_name.clone(), RwLock::new(Self::make_top_level_function_def( - Self::name_mangling(class_name, method_name), + Self::make_class_method_name(class_name.clone(), method_name), // later unify with parsed type dummy_method_type.0, resolver.clone(), @@ -273,9 +571,10 @@ impl TopLevelComposer { // move the ast to the entry of the class in the ast_list class_def_ast.1 = Some(ast); // get the methods into the class_def - for (name, _, id, ty) in class_method_name_def_ids { - if let TopLevelDef::Class { methods, .. } = class_def_ast.0.get_mut() { - methods.push((name, ty, id)) + for (name, _, id, ty) in &class_method_name_def_ids { + let mut class_def = class_def_ast.0.write(); + if let TopLevelDef::Class { methods, .. } = class_def.deref_mut() { + methods.push((name.clone(), *ty, *id)) } else { unreachable!() } } // now class_def_ast and class_method_def_ast_ids are ok, put them into actual def list in correct order @@ -291,9 +590,6 @@ impl TopLevelComposer { None, )); - // class, put its def_id into the to be analyzed set - // self.to_be_analyzed_class.push(DefinitionId(class_def_id)); - Ok((class_name, DefinitionId(class_def_id))) } @@ -322,7 +618,12 @@ impl TopLevelComposer { /// step 1, analyze the type vars associated with top level class fn analyze_top_level_class_type_var(&mut self) -> Result<(), String> { - for (class_def, class_ast) in self.definition_ast_list { + let def_list = &self.definition_ast_list; + let temp_def_list = self.extract_def_list(); + let unifier = self.unifier.borrow_mut(); + let primitives_store = &self.primitives_ty; + + for (class_def, class_ast) in def_list { // only deal with class def here let mut class_def = class_def.write(); let (class_bases_ast, class_def_type_vars, class_resolver) = { @@ -341,6 +642,7 @@ impl TopLevelComposer { }; let class_resolver = class_resolver.as_ref().unwrap().lock(); let class_resolver = class_resolver.deref(); + let mut is_generic = false; for b in class_bases_ast { match &b.node { @@ -359,7 +661,7 @@ impl TopLevelComposer { return Err("Only single Generic[...] can be in bases".into()); } - let type_var_list: Vec<&ast::Expr<()>> = vec![]; + let mut type_var_list: Vec<&ast::Expr<()>> = vec![]; // if `class A(Generic[T, V, G])` if let ast::ExprKind::Tuple { elts, .. } = &slice.node { type_var_list.extend(elts.iter()); @@ -370,28 +672,42 @@ impl TopLevelComposer { // parse the type vars let type_vars = type_var_list - .into_iter() - .map(|e| { - let temp_def_list = self.extract_def_list(); - class_resolver.parse_type_annotation( - &temp_def_list, - self.unifier.borrow_mut(), - &self.primitives, - e - ) - }) - .collect::, _>>()?; + .into_iter() + .map(|e| { + class_resolver.parse_type_annotation( + &temp_def_list, + unifier, + primitives_store, + e + ) + }) + .collect::, _>>()?; // check if all are unique type vars let mut occured_type_var_id: HashSet = HashSet::new(); let all_unique_type_var = type_vars.iter().all(|x| { - let ty = self.unifier.get_ty(*x); + let ty = unifier.get_ty(*x); if let TypeEnum::TVar { id, .. } = ty.as_ref() { occured_type_var_id.insert(*id) } else { false } }); + + // NOTE: create a copy of all type vars for the type vars associated with class + let type_vars = type_vars + .into_iter() + .map(|x| { + let range = unifier.get_ty(x); + if let TypeEnum::TVar { range, .. } = range.as_ref() { + let range = &*range.borrow(); + let range = range.as_slice(); + unifier.get_fresh_var_with_range(range).0 + } else { + unreachable!("must be type var here"); + } + }) + .collect_vec(); if !all_unique_type_var { return Err("expect unique type variables".into()); @@ -409,12 +725,9 @@ impl TopLevelComposer { Ok(()) } - /// step 2, base classes. Need to separate step1 and step2 for this reason: - /// `class B(Generic[T, V]); - /// class A(B[int, bool])` - /// if the type var associated with class `B` has not been handled properly, - /// the parse of type annotation of `B[int, bool]` will fail + /// step 2, base classes. fn analyze_top_level_class_bases(&mut self) -> Result<(), String> { + let temp_def_list = self.extract_def_list(); for (class_def, class_ast) in self.definition_ast_list.iter_mut() { let mut class_def = class_def.write(); let (class_bases, class_ancestors, class_resolver) = { @@ -445,31 +758,27 @@ impl TopLevelComposer { ast::ExprKind::Name { id, .. } if id == "Generic" ) ) { continue } - has_base = true; + if has_base { return Err("a class def can only have at most one base class \ - declaration and one generic declaration".into()) + declaration and one generic declaration".into()) } + has_base = true; - let temp_def_list = self.extract_def_list(); - let base_ty = class_resolver.parse_type_annotation( + let base_ty = parse_ast_to_type_annotation_kinds( + class_resolver, &temp_def_list, self.unifier.borrow_mut(), - &self.primitives, + &self.primitives_ty, b )?; - let base_id = - if let TypeEnum::TObj { obj_id, .. } = self.unifier.get_ty(base_ty).as_ref() { - *obj_id - } else { - return Err("expect concrete class/type to be base class".into()); - }; - - - // TODO: when base class is generic, record the generic type parameter - // TODO: check to prevent cyclic base class - class_ancestors.push(base_id); + if let TypeAnnotation::ConcretizedCustomClassKind { .. } = base_ty { + // TODO: check to prevent cyclic base class + class_ancestors.push(base_ty); + } else { + return Err("class base declaration can only be concretized custom class".into()) + } } } Ok(()) @@ -477,284 +786,270 @@ impl TopLevelComposer { /// step 3, class fields and methods fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> { - let mut max_iter = to_be_analyzed_class.len() * 4; - 'class: loop { - if to_be_analyzed_class.is_empty() && { - max_iter -= 1; - max_iter > 0 - } { - break; - } - - let class_ind = to_be_analyzed_class.remove(0).0; - let (class_name, class_body_ast, class_bases_ast, class_resolver, class_ancestors) = { - let (class_def, class_ast) = &mut def_ast_list[class_ind]; - if let Some(ast::Located { - node: ast::StmtKind::ClassDef { name, body, bases, .. }, - .. - }) = class_ast.as_ref() - { - if let TopLevelDef::Class { resolver, ancestors, .. } = - class_def.write().deref() - { - (name, body, bases, resolver.as_ref().unwrap().clone(), ancestors.clone()) - } else { - unreachable!() - } - } else { - unreachable!("should be class def ast") - } - }; - let class_resolver = class_resolver.as_ref().lock(); - let class_resolver = class_resolver.deref(); - - let all_base_class_analyzed = { - let not_yet_analyzed = - to_be_analyzed_class.clone().into_iter().collect::>(); - let base = class_ancestors.clone().into_iter().collect::>(); - let intersection = not_yet_analyzed.intersection(&base).collect_vec(); - intersection.is_empty() - }; - if !all_base_class_analyzed { - to_be_analyzed_class.push(DefinitionId(class_ind)); - continue 'class; - } - - // get the bases type, can directly do this since it - // already pass the check in the previous stages - let class_bases_ty = class_bases_ast - .iter() - .filter_map(|x| { - self.parse_type_annotation(class_resolver, x).ok() - }) - .collect_vec(); - - // need these vectors to check re-defining methods, class fields - // and store the parsed result in case some method cannot be typed for now - let mut class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![]; - let mut class_fields_parsing_result: Vec<(String, Type)> = vec![]; - for b in class_body_ast { - if let ast::StmtKind::FunctionDef { - args: method_args_ast, - body: method_body_ast, - name: method_name, - returns: method_returns_ast, - .. - } = &b.node - { - let arg_name_tys: Vec<(String, Type)> = { - let mut result = vec![]; - for a in &method_args_ast.args { - if a.node.arg != "self" { - let annotation = a - .node - .annotation - .as_ref() - .ok_or_else(|| { - "type annotation for function parameter is needed" - .to_string() - })? - .as_ref(); - - let ty = self.parse_type_annotation(class_resolver, annotation)?; - if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) { - to_be_analyzed_class.push(DefinitionId(class_ind)); - continue 'class; - } - result.push((a.node.arg.to_string(), ty)); - } else { - // TODO: handle self, how - unimplemented!() - } - } - result - }; - - let method_type_var = arg_name_tys - .iter() - .filter_map(|(_, ty)| { - let ty_enum = unifier.get_ty(*ty); - if let TypeEnum::TVar { id, .. } = ty_enum.as_ref() { - Some((*id, *ty)) - } else { - None - } - }) - .collect::>(); - - let ret_ty = { - if method_name != "__init__" { - let ty = method_returns_ast - .as_ref() - .map(|x| { - self.parse_type_annotation(class_resolver, x) - }) - .ok_or_else(|| "return type annotation error".to_string())??; - if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) { - to_be_analyzed_class.push(DefinitionId(class_ind)); - continue 'class; - } else { - ty - } - } else { - // TODO: __init__ function, self type, how - unimplemented!() - } - }; - - // handle fields - let class_field_name_tys: Option> = if method_name - == "__init__" - { - let mut result: Vec<(String, Type)> = vec![]; - for body in method_body_ast { - match &body.node { - ast::StmtKind::AnnAssign { target, annotation, .. } - if { - if let ast::ExprKind::Attribute { value, .. } = &target.node - { - matches!( - &value.node, - ast::ExprKind::Name { id, .. } if id == "self") - } else { - false - } - } => - { - let field_ty = - self.parse_type_annotation(class_resolver, annotation)?; - if !Self::check_ty_analyzed( - field_ty, - unifier, - to_be_analyzed_class, - ) { - to_be_analyzed_class.push(DefinitionId(class_ind)); - continue 'class; - } else { - result.push(( - if let ast::ExprKind::Attribute { attr, .. } = - &target.node - { - attr.to_string() - } else { - unreachable!() - }, - field_ty, - )) - } - } - - // exclude those without type annotation - ast::StmtKind::Assign { targets, .. } - if { - if let ast::ExprKind::Attribute { value, .. } = - &targets[0].node - { - matches!( - &value.node, - ast::ExprKind::Name {id, ..} if id == "self") - } else { - false - } - } => - { - return Err("class fields type annotation needed".into()) - } - - // do nothing - _ => {} - } - } - Some(result) - } else { - None - }; - - // current method all type ok, put the current method into the list - if class_methods_parsing_result.iter().any(|(name, _, _)| name == method_name) { - return Err("duplicate method definition".into()); - } else { - class_methods_parsing_result.push(( - method_name.clone(), - unifier.add_ty(TypeEnum::TFunc( - FunSignature { - ret: ret_ty, - args: arg_name_tys - .into_iter() - .map(|(name, ty)| FuncArg { name, ty, default_value: None }) - .collect_vec(), - vars: method_type_var, - } - .into(), - )), - *self - .class_method_to_def_id - .get(&Self::name_mangling(class_name.clone(), method_name)) - .unwrap(), - )) - } - - // put the fiedlds inside - if let Some(class_field_name_tys) = class_field_name_tys { - assert!(class_fields_parsing_result.is_empty()); - class_fields_parsing_result.extend(class_field_name_tys); - } - } else { - // what should we do with `class A: a = 3`? - // do nothing, continue the for loop to iterate class ast - continue; - } - } - - // now it should be confirmed that every - // methods and fields of the class can be correctly typed, put the results - // into the actual class def method and fields field - let (class_def, _) = &def_ast_list[class_ind]; - let mut class_def = class_def.write(); - if let TopLevelDef::Class { fields, methods, .. } = class_def.deref_mut() { - for (ref n, ref t) in class_fields_parsing_result { - fields.push((n.clone(), *t)); - } - for (n, t, id) in &class_methods_parsing_result { - methods.push((n.clone(), *t, *id)); - } - } else { - unreachable!() - } - - // change the signature field of the class methods - for (_, ty, id) in &class_methods_parsing_result { - let (method_def, _) = &def_ast_list[id.0]; - let mut method_def = method_def.write(); - if let TopLevelDef::Function { signature, .. } = method_def.deref_mut() { - *signature = *ty; - } - } + let temp_def_list = self.extract_def_list(); + let unifier = self.unifier.borrow_mut(); + let primitives = &self.primitives_ty; + let def_ast_list = &self.definition_ast_list; + + let mut type_var_to_concrete_def: HashMap = HashMap::new(); + for (class_def, class_ast) in def_ast_list { + Self::analyze_single_class( + class_def.clone(), + &class_ast.as_ref().unwrap().node, + &temp_def_list, + unifier, + primitives, + &mut type_var_to_concrete_def + )? } + + // base class methods add and check + // TODO: + + // unification of previously assigned typevar + for (ty, def) in type_var_to_concrete_def { + let target_ty = get_type_from_type_annotation_kinds(&temp_def_list, unifier, primitives, &def)?; + unifier.unify(ty, target_ty)?; + } + Ok(()) } + /// step 4, after class methods are done fn analyze_top_level_function(&mut self) -> Result<(), String> { - unimplemented!() - } + let def_list = &self.definition_ast_list; + let temp_def_list = self.extract_def_list(); + let unifier = self.unifier.borrow_mut(); + let primitives_store = &self.primitives_ty; - fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> { - unimplemented!() - } + for (function_def, function_ast) in def_list { + let function_def = function_def.read(); + let function_def = function_def.deref(); + let function_ast = if let Some(function_ast) = function_ast { + function_ast + } else { + continue; + }; + if let TopLevelDef::Function { signature: dummy_ty, resolver, .. } = function_def { + if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node { + let resolver = resolver.as_ref(); + let resolver = resolver.unwrap(); + let resolver = resolver.deref().lock(); + let function_resolver = resolver.deref(); + + let arg_types = { + args + .args + .iter() + .map(|x| -> Result { + let annotation = x + .node + .annotation + .as_ref() + .ok_or_else(|| "function parameter type annotation needed".to_string())? + .as_ref(); + Ok(FuncArg { + name: x.node.arg.clone(), + ty: function_resolver.parse_type_annotation( + temp_def_list.as_slice(), + unifier, + primitives_store, + annotation + )?, + // TODO: function type var + default_value: Default::default() + }) + }) + .collect::, _>>()? + }; - fn check_ty_analyzed(ty: Type, unifier: &mut Unifier, to_be_analyzed: &[DefinitionId]) -> bool { - let type_enum = unifier.get_ty(ty); - match type_enum.as_ref() { - TypeEnum::TObj { obj_id, .. } => !to_be_analyzed.contains(obj_id), - TypeEnum::TVirtual { ty } => { - if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(*ty).as_ref() { - !to_be_analyzed.contains(obj_id) + let return_ty = { + let return_annotation = returns + .as_ref() + .ok_or_else(|| "function return type needed".to_string())? + .as_ref(); + function_resolver.parse_type_annotation( + temp_def_list.as_slice(), + unifier, + primitives_store, + return_annotation + )? + }; + + let function_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: arg_types, + ret: return_ty, + // TODO: handle var map + vars: Default::default() + }.into())); + unifier.unify(*dummy_ty, function_ty)?; } else { - unreachable!() + unreachable!("must be both function"); } + } else { + continue; } - TypeEnum::TVar { .. } => true, - _ => unreachable!(), - } + }; + Ok(()) } -} + + /// step 5, field instantiation? + fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> { + // TODO: + unimplemented!() + } + + fn analyze_single_class( + class_def: Arc>, + class_ast: &ast::StmtKind<()>, + temp_def_list: &[Arc>], + unifier: &mut Unifier, + primitives: &PrimitiveStore, + type_var_to_concrete_def: &mut HashMap + ) -> Result<(), String> { + let mut class_def = class_def.write(); + let ( + _class_id, + _class_name, + _class_bases_ast, + class_body_ast, + _class_ancestor_def, + class_fields_def, + class_methods_def, + _class_type_vars_def, + class_resolver, + ) = if let TopLevelDef::Class { + object_id, + ancestors, + fields, + methods, + resolver, + type_vars + } = class_def.deref_mut() { + if let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast { + (object_id, name.clone(), bases, body, ancestors, fields, methods, type_vars, resolver) + } else { + unreachable!("here must be class def ast"); + } + } else { + unreachable!("here must be class def ast"); + }; + let class_resolver = class_resolver.as_ref().unwrap(); + let mut class_resolver = class_resolver.lock(); + let class_resolver = class_resolver.deref_mut(); + + for b in class_body_ast { + if let ast::StmtKind::FunctionDef { args, returns, name, body, .. } = &b.node { + let (method_dummy_ty, ..) = Self::get_class_method_def_info(class_methods_def, name)?; + // TODO: handle self arg + // TODO: handle parameter with same name + let arg_type: Vec = { + let mut result = Vec::new(); + for x in &args.args{ + let name = x.node.arg.clone(); + let type_ann = { + let annotation_expr = x + .node + .annotation + .as_ref() + .ok_or_else(|| "type annotation needed".to_string())? + .as_ref(); + parse_ast_to_type_annotation_kinds( + class_resolver, + temp_def_list, + unifier, + primitives, + annotation_expr + )? + }; + if let TypeAnnotation::TypeVarKind(_ty) = &type_ann { + // TODO: need to handle to different type vars that are + // asscosiated with the class and that are not + } + let dummy_func_arg = FuncArg { + name, + ty: unifier.get_fresh_var().0, + // TODO: symbol default value? + default_value: None + }; + // push the dummy type and the type annotation + // into the list for later unification + type_var_to_concrete_def.insert(dummy_func_arg.ty, type_ann.clone()); + result.push(dummy_func_arg) + } + result + }; + let ret_type = { + let result = returns + .as_ref() + .ok_or_else(|| "method return type annotation needed".to_string())? + .as_ref(); + let annotation = parse_ast_to_type_annotation_kinds( + class_resolver, + temp_def_list, + unifier, + primitives, + result + )?; + let dummy_return_type = unifier.get_fresh_var().0; + type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); + dummy_return_type + }; + // TODO: handle var map, to create a new copy of type var + // while tracking the type var associated with class + let method_var_map: HashMap = HashMap::new(); + let method_type = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: arg_type, + ret: ret_type, + vars: method_var_map + }.into())); + // unify now since function type is not in type annotation define + // which is fine since type within method_type will be subst later + unifier.unify(method_dummy_ty, method_type)?; + + // class fields + if name == "__init__" { + for b in body { + let mut defined_fields: HashSet = HashSet::new(); + // TODO: check the type of value, field instantiation check + if let ast::StmtKind::AnnAssign { annotation, target, value: _, .. } = &b.node { + if let ast::ExprKind::Attribute { value, attr, .. } = &target.node { + if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "self") { + if defined_fields.insert(attr.to_string()) { + let dummy_field_type = unifier.get_fresh_var().0; + class_fields_def.push((attr.to_string(), dummy_field_type)); + let annotation = parse_ast_to_type_annotation_kinds( + class_resolver, + &temp_def_list, + unifier, + primitives, + annotation.as_ref() + )?; + type_var_to_concrete_def.insert(dummy_field_type, annotation); + } else { + return Err("same class fields defined twice".into()); + } + } + } + } + } + } + } else { + continue; + } + }; + Ok(()) + } + + fn get_class_method_def_info( + class_methods_def: &[(String, Type, DefinitionId)], + method_name: &str + ) -> Result<(Type, DefinitionId), String> { + for (name, ty, def_id) in class_methods_def { + if name == method_name { + return Ok((*ty, *def_id)); + } + } + Err(format!("no method {} in the current class", method_name)) + } +} \ No newline at end of file