diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 994b9de4..39a06a09 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -1,5 +1,5 @@ -use std::{cell::RefCell, sync::Arc}; use std::collections::HashMap; +use std::{cell::RefCell, sync::Arc}; use crate::top_level::{DefinitionId, TopLevelDef}; use crate::typecheck::{ @@ -95,19 +95,34 @@ pub fn parse_type_annotation( Subscript { value, slice, .. } => { if let Name { id, .. } = &value.node { if id == "virtual" { - let ty = - parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?; + let ty = parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + slice, + )?; Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) } else { let types = if let Tuple { elts, .. } = &slice.node { elts.iter() .map(|v| { - parse_type_annotation(resolver, top_level_defs, unifier, primitives, v) + parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + v, + ) }) .collect::, _>>()? } else { vec![parse_type_annotation( - resolver, top_level_defs, unifier, primitives, slice, + resolver, + top_level_defs, + unifier, + primitives, + slice, )?] }; diff --git a/nac3core/src/top_level.rs b/nac3core/src/top_level.rs index 052a7f93..ed612d52 100644 --- a/nac3core/src/top_level.rs +++ b/nac3core/src/top_level.rs @@ -5,9 +5,9 @@ use std::{collections::HashMap, collections::HashSet, sync::Arc}; use self::top_level_type_annotation_info::*; use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier}; -use crate::typecheck::{typedef::{FunSignature, FuncArg}}; use crate::symbol_resolver::SymbolResolver; -use itertools::{Itertools, izip}; +use crate::typecheck::typedef::{FunSignature, FuncArg}; +use itertools::{izip, Itertools}; use parking_lot::{Mutex, RwLock}; use rustpython_parser::ast::{self, Stmt}; @@ -23,7 +23,7 @@ pub mod top_level_type_annotation_info { ConcretizedCustomClassKind { id: DefinitionId, // can not be type var, others are all fine - params: Vec + params: Vec, }, // can only be ConcretizedCustomClassKind VirtualKind(Box), @@ -39,10 +39,22 @@ pub mod top_level_type_annotation_info { expr: &ast::Expr, ) -> Result { let results = vec![ - parse_ast_to_concrete_primitive_kind(resolver, top_level_defs, unifier, primitives, expr), - parse_ast_to_concretized_custom_class_kind(resolver, top_level_defs, unifier, primitives, expr), + parse_ast_to_concrete_primitive_kind( + resolver, + top_level_defs, + unifier, + primitives, + expr, + ), + parse_ast_to_concretized_custom_class_kind( + resolver, + top_level_defs, + unifier, + primitives, + expr, + ), parse_ast_to_type_variable_kind(resolver, top_level_defs, unifier, primitives, expr), - parse_ast_to_virtual_kind(resolver, top_level_defs, unifier, primitives, expr) + parse_ast_to_virtual_kind(resolver, top_level_defs, unifier, primitives, expr), ]; let results = results.iter().filter(|x| x.is_ok()).collect_vec(); @@ -51,19 +63,18 @@ pub mod top_level_type_annotation_info { } else { Err("cannot be parsed the type annotation without ambiguity".into()) } - } pub fn get_type_from_type_annotation_kinds( top_level_defs: &[Arc>], unifier: &mut Unifier, primitives: &PrimitiveStore, - ann: &TypeAnnotation + ann: &TypeAnnotation, ) -> Result { match ann { TypeAnnotation::ConcretizedCustomClassKind { id, params } => { let class_def = top_level_defs[id.0].read(); - if let TopLevelDef::Class {fields, methods, type_vars, .. } = &*class_def { + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*class_def { if type_vars.len() != params.len() { Err(format!( "unexpected number of type parameters: expected {} but got {}", @@ -73,12 +84,14 @@ pub mod top_level_type_annotation_info { } else { let param_ty = params .iter() - .map(|x| get_type_from_type_annotation_kinds( - top_level_defs, - unifier, - primitives, - x - )) + .map(|x| { + get_type_from_type_annotation_kinds( + top_level_defs, + unifier, + primitives, + x, + ) + }) .collect::, _>>()?; let subst = type_vars @@ -100,19 +113,15 @@ pub mod top_level_type_annotation_info { (name.clone(), subst_ty) }) .collect::>(); - - tobj_fields.extend( - fields - .iter() - .map(|(name, ty)| { - let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (name.clone(), subst_ty) - }) - ); + + tobj_fields.extend(fields.iter().map(|(name, ty)| { + let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (name.clone(), subst_ty) + })); Ok(unifier.add_ty(TypeEnum::TObj { obj_id: *id, fields: tobj_fields.into(), - params: subst.into() + params: subst.into(), })) } } else { @@ -121,7 +130,7 @@ pub mod top_level_type_annotation_info { } TypeAnnotation::SelfTypeKind(obj_id) => { let class_def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class {fields, methods, type_vars, .. } = &*class_def { + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*class_def { let subst = type_vars .iter() .map(|x| { @@ -135,16 +144,14 @@ pub mod top_level_type_annotation_info { let mut tobj_fields = methods .iter() - .map(|(name, ty, _)| { - (name.clone(), *ty) - }) + .map(|(name, ty, _)| (name.clone(), *ty)) .collect::>(); - + tobj_fields.extend(fields.clone().into_iter()); Ok(unifier.add_ty(TypeEnum::TObj { obj_id: *obj_id, fields: tobj_fields.into(), - params: subst.into() + params: subst.into(), })) } else { unreachable!("should be class def here") @@ -157,7 +164,7 @@ pub mod top_level_type_annotation_info { top_level_defs, unifier, primitives, - ty.as_ref() + ty.as_ref(), )?; Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) } @@ -176,12 +183,12 @@ pub mod top_level_type_annotation_info { "int32" => Ok(TypeAnnotation::PrimitiveKind(primitives.int32)), "int64" => Ok(TypeAnnotation::PrimitiveKind(primitives.int64)), "float" => Ok(TypeAnnotation::PrimitiveKind(primitives.float)), - "bool" => Ok(TypeAnnotation::PrimitiveKind(primitives.bool)), - "None" => Ok(TypeAnnotation::PrimitiveKind(primitives.none)), - _ => Err("not primitive".into()) - } + "bool" => Ok(TypeAnnotation::PrimitiveKind(primitives.bool)), + "None" => Ok(TypeAnnotation::PrimitiveKind(primitives.none)), + _ => Err("not primitive".into()), + }, - _ => Err("not primitive".into()) + _ => Err("not primitive".into()), } } @@ -194,15 +201,19 @@ pub mod top_level_type_annotation_info { ) -> Result { match &expr.node { ast::ExprKind::Name { id, .. } => match id.as_str() { - "int32" | "int64" | "float" | "bool" | "None" => - Err("expect custom class instead of primitives here".into()), + "int32" | "int64" | "float" | "bool" | "None" => { + Err("expect custom class instead of primitives here".into()) + } x => { let obj_id = resolver .get_identifier_def(x) .ok_or_else(|| "unknown class name".to_string())?; let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { .. } = &*def { - Ok(TypeAnnotation::ConcretizedCustomClassKind { id: obj_id, params: vec![]}) + Ok(TypeAnnotation::ConcretizedCustomClassKind { + id: obj_id, + params: vec![], + }) } else { Err("function cannot be used as a type".into()) } @@ -211,13 +222,15 @@ pub mod top_level_type_annotation_info { ast::ExprKind::Subscript { value, slice, .. } => { if let ast::ExprKind::Name { id, .. } = &value.node { - if vec!["virtual", "Generic"].contains(&id.as_str()) { return Err("keywords cannot be class name".into()) } + if vec!["virtual", "Generic"].contains(&id.as_str()) { + return Err("keywords cannot be class name".into()); + } let obj_id = resolver .get_identifier_def(id) .ok_or_else(|| "unknown class name".to_string())?; let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { .. } = &*def { - let param_type_infos = + let param_type_infos = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { elts.iter() .map(|v| { @@ -226,28 +239,40 @@ pub mod top_level_type_annotation_info { top_level_defs, unifier, primitives, - v + v, ) }) .collect::, _>>()? } else { vec![parse_ast_to_type_annotation_kinds( - resolver, top_level_defs, unifier, primitives, slice, + resolver, + top_level_defs, + unifier, + primitives, + slice, )?] }; - if param_type_infos.iter().any(|x| matches!(x, TypeAnnotation::TypeVarKind( .. ))) { - return Err("cannot apply type variable to class generic parameters".into()) + if param_type_infos + .iter() + .any(|x| matches!(x, TypeAnnotation::TypeVarKind(..))) + { + return Err( + "cannot apply type variable to class generic parameters".into() + ); } - Ok(TypeAnnotation::ConcretizedCustomClassKind { id: obj_id, params: param_type_infos }) + Ok(TypeAnnotation::ConcretizedCustomClassKind { + id: obj_id, + params: param_type_infos, + }) } else { Err("function cannot be used as a type".into()) } } else { Err("unsupported expression type".into()) } - }, + } - _ => Err("unsupported expression type".into()) + _ => Err("unsupported expression type".into()), } } @@ -260,13 +285,14 @@ pub mod top_level_type_annotation_info { ) -> Result { match &expr.node { ast::ExprKind::Subscript { value, slice, .. } - if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "virtual") => { + if { matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "virtual") } => + { let def = parse_ast_to_concretized_custom_class_kind( resolver, top_level_defs, unifier, primitives, - slice.as_ref() + slice.as_ref(), )?; if !matches!(def, TypeAnnotation::ConcretizedCustomClassKind { .. }) { unreachable!("should must be concretized custom class kind") @@ -274,7 +300,7 @@ pub mod top_level_type_annotation_info { Ok(TypeAnnotation::VirtualKind(def.into())) } - _ => Err("virtual type annotation must be like `virtual[ .. ]`".into()) + _ => Err("virtual type annotation must be like `virtual[ .. ]`".into()), } } @@ -360,12 +386,10 @@ pub struct TopLevelComposer { impl TopLevelComposer { pub fn make_top_level_context(self) -> TopLevelContext { TopLevelContext { - definitions: RwLock::new(self - .definition_ast_list - .into_iter() - .map(|(x, ..)| x) - .collect::>() - ).into(), + definitions: RwLock::new( + self.definition_ast_list.into_iter().map(|(x, ..)| x).collect::>(), + ) + .into(), // FIXME: all the big unifier or? unifiers: Default::default(), } @@ -436,7 +460,7 @@ impl TopLevelComposer { "bool".into(), "none".into(), "None".into(), - ] + ], }; ( vec![ @@ -489,11 +513,7 @@ impl TopLevelComposer { } fn extract_def_list(&self) -> Vec>> { - self - .definition_ast_list - .iter() - .map(|(def, ..)| def.clone()) - .collect_vec() + self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() } /// step 0, register, just remeber the names of top level classes/function @@ -505,9 +525,9 @@ impl TopLevelComposer { match &ast.node { ast::StmtKind::ClassDef { name, body, .. } => { if self.keyword_list.contains(name) { - return Err("cannot use keyword as a class name".into()) + return Err("cannot use keyword as a class name".into()); } - + let class_name = name.to_string(); let class_def_id = self.definition_ast_list.len(); @@ -531,7 +551,7 @@ impl TopLevelComposer { // in this top level def, method name is prefixed with the class name Arc>, DefinitionId, - Type + Type, )> = Vec::new(); let mut class_method_index_offset = 0; for b in body { @@ -553,15 +573,14 @@ impl TopLevelComposer { )) .into(), DefinitionId(method_def_id), - dummy_method_type.0 + dummy_method_type.0, )); - } else { // do nothing - continue + continue; } } - + // move the ast to the entry of the class in the ast_list class_def_ast.1 = Some(ast); // get the methods into the class_def @@ -569,7 +588,9 @@ impl TopLevelComposer { let mut class_def = class_def_ast.0.write(); if let TopLevelDef::Class { methods, .. } = class_def.deref_mut() { methods.push((name.clone(), *ty, *id)) - } else { unreachable!() } + } else { + unreachable!() + } } // now class_def_ast and class_method_def_ast_ids are ok, put them into actual def list in correct order self.definition_ast_list.push(class_def_ast); @@ -647,8 +668,9 @@ impl TopLevelComposer { // should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Generic") - } => { + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Generic") + } => + { if !is_generic { is_generic = true; } else { @@ -672,7 +694,7 @@ impl TopLevelComposer { &temp_def_list, unifier, primitives_store, - e + e, ) }) .collect::, _>>()?; @@ -702,11 +724,11 @@ impl TopLevelComposer { } }) .collect_vec(); - + if !all_unique_type_var { return Err("expect unique type variables".into()); } - + // add to TopLevelDef class_def_type_vars.extend(type_vars); } @@ -751,11 +773,14 @@ impl TopLevelComposer { &value.node, ast::ExprKind::Name { id, .. } if id == "Generic" ) - ) { continue } - + ) { + continue; + } + if has_base { return Err("a class def can only have at most one base class \ - declaration and one generic declaration".into()) + declaration and one generic declaration" + .into()); } has_base = true; @@ -764,14 +789,16 @@ impl TopLevelComposer { &temp_def_list, self.unifier.borrow_mut(), &self.primitives_ty, - b + b, )?; if let TypeAnnotation::ConcretizedCustomClassKind { .. } = base_ty { // TODO: check to prevent cyclic base class class_ancestors.push(base_ty); } else { - return Err("class base declaration can only be concretized custom class".into()) + return Err( + "class base declaration can only be concretized custom class".into() + ); } } } @@ -784,7 +811,7 @@ impl TopLevelComposer { let unifier = self.unifier.borrow_mut(); let primitives = &self.primitives_ty; let def_ast_list = &self.definition_ast_list; - + let mut type_var_to_concrete_def: HashMap = HashMap::new(); for (class_def, class_ast) in def_ast_list { Self::analyze_single_class( @@ -793,19 +820,20 @@ impl TopLevelComposer { &temp_def_list, unifier, primitives, - &mut type_var_to_concrete_def + &mut type_var_to_concrete_def, )? } // base class methods add and check - // TODO: + // TODO: // unification of previously assigned typevar for (ty, def) in type_var_to_concrete_def { - let target_ty = get_type_from_type_annotation_kinds(&temp_def_list, unifier, primitives, &def)?; + let target_ty = + get_type_from_type_annotation_kinds(&temp_def_list, unifier, primitives, &def)?; unifier.unify(ty, target_ty)?; } - + Ok(()) } @@ -830,17 +858,18 @@ impl TopLevelComposer { let resolver = resolver.unwrap(); let resolver = resolver.deref().lock(); let function_resolver = resolver.deref(); - + let arg_types = { - args - .args + args.args .iter() .map(|x| -> Result { let annotation = x .node .annotation .as_ref() - .ok_or_else(|| "function parameter type annotation needed".to_string())? + .ok_or_else(|| { + "function parameter type annotation needed".to_string() + })? .as_ref(); Ok(FuncArg { name: x.node.arg.clone(), @@ -848,10 +877,10 @@ impl TopLevelComposer { temp_def_list.as_slice(), unifier, primitives_store, - annotation + annotation, )?, // TODO: function type var - default_value: Default::default() + default_value: Default::default(), }) }) .collect::, _>>()? @@ -866,16 +895,19 @@ impl TopLevelComposer { temp_def_list.as_slice(), unifier, primitives_store, - return_annotation + return_annotation, )? }; - let function_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: arg_types, - ret: return_ty, - // TODO: handle var map - vars: Default::default() - }.into())); + let function_ty = unifier.add_ty(TypeEnum::TFunc( + FunSignature { + args: arg_types, + ret: return_ty, + // TODO: handle var map + vars: Default::default(), + } + .into(), + )); unifier.unify(*dummy_ty, function_ty)?; } else { unreachable!("must be both function"); @@ -883,7 +915,7 @@ impl TopLevelComposer { } else { continue; } - }; + } Ok(()) } @@ -899,7 +931,7 @@ impl TopLevelComposer { temp_def_list: &[Arc>], unifier: &mut Unifier, primitives: &PrimitiveStore, - type_var_to_concrete_def: &mut HashMap + type_var_to_concrete_def: &mut HashMap, ) -> Result<(), String> { let mut class_def = class_def.write(); let ( @@ -920,9 +952,20 @@ impl TopLevelComposer { resolver, type_vars, .. - } = class_def.deref_mut() { + } = class_def.deref_mut() + { if let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast { - (object_id, name.clone(), bases, body, ancestors, fields, methods, type_vars, resolver) + ( + object_id, + name.clone(), + bases, + body, + ancestors, + fields, + methods, + type_vars, + resolver, + ) } else { unreachable!("here must be class def ast"); } @@ -935,12 +978,13 @@ impl TopLevelComposer { for b in class_body_ast { if let ast::StmtKind::FunctionDef { args, returns, name, body, .. } = &b.node { - let (method_dummy_ty, ..) = Self::get_class_method_def_info(class_methods_def, name)?; + let (method_dummy_ty, ..) = + Self::get_class_method_def_info(class_methods_def, name)?; // TODO: handle self arg // TODO: handle parameter with same name let arg_type: Vec = { let mut result = Vec::new(); - for x in &args.args{ + for x in &args.args { let name = x.node.arg.clone(); let type_ann = { let annotation_expr = x @@ -954,7 +998,7 @@ impl TopLevelComposer { temp_def_list, unifier, primitives, - annotation_expr + annotation_expr, )? }; if let TypeAnnotation::TypeVarKind(_ty) = &type_ann { @@ -965,7 +1009,7 @@ impl TopLevelComposer { name, ty: unifier.get_fresh_var().0, // TODO: symbol default value? - default_value: None + default_value: None, }; // push the dummy type and the type annotation // into the list for later unification @@ -984,7 +1028,7 @@ impl TopLevelComposer { temp_def_list, unifier, primitives, - result + result, )?; let dummy_return_type = unifier.get_fresh_var().0; type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); @@ -993,11 +1037,9 @@ impl TopLevelComposer { // TODO: handle var map, to create a new copy of type var // while tracking the type var associated with class let method_var_map: HashMap = HashMap::new(); - let method_type = unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: arg_type, - ret: ret_type, - vars: method_var_map - }.into())); + let method_type = unifier.add_ty(TypeEnum::TFunc( + FunSignature { args: arg_type, ret: ret_type, vars: method_var_map }.into(), + )); // unify now since function type is not in type annotation define // which is fine since type within method_type will be subst later unifier.unify(method_dummy_ty, method_type)?; @@ -1007,9 +1049,12 @@ impl TopLevelComposer { for b in body { let mut defined_fields: HashSet = HashSet::new(); // TODO: check the type of value, field instantiation check - if let ast::StmtKind::AnnAssign { annotation, target, value: _, .. } = &b.node { + if let ast::StmtKind::AnnAssign { annotation, target, value: _, .. } = + &b.node + { if let ast::ExprKind::Attribute { value, attr, .. } = &target.node { - if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "self") { + if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "self") + { if defined_fields.insert(attr.to_string()) { let dummy_field_type = unifier.get_fresh_var().0; class_fields_def.push((attr.to_string(), dummy_field_type)); @@ -1018,9 +1063,10 @@ impl TopLevelComposer { &temp_def_list, unifier, primitives, - annotation.as_ref() + annotation.as_ref(), )?; - type_var_to_concrete_def.insert(dummy_field_type, annotation); + type_var_to_concrete_def + .insert(dummy_field_type, annotation); } else { return Err("same class fields defined twice".into()); } @@ -1032,13 +1078,13 @@ impl TopLevelComposer { } else { continue; } - }; + } Ok(()) } fn get_class_method_def_info( class_methods_def: &[(String, Type, DefinitionId)], - method_name: &str + method_name: &str, ) -> Result<(Type, DefinitionId), String> { for (name, ty, def_id) in class_methods_def { if name == method_name { @@ -1047,4 +1093,4 @@ impl TopLevelComposer { } Err(format!("no method {} in the current class", method_name)) } -} \ No newline at end of file +} diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 671b32e5..709d1742 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -672,7 +672,11 @@ impl Unifier { } fn incompatible_types(&mut self, a: Type, b: Type) -> Result<(), String> { - Err(format!("Cannot unify {} with {}", self.internal_stringify(a), self.internal_stringify(b))) + Err(format!( + "Cannot unify {} with {}", + self.internal_stringify(a), + self.internal_stringify(b) + )) } /// Instantiate a function if it hasn't been instantiated. diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index ee795056..2940c276 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -469,10 +469,7 @@ fn test_rigid_var() { assert_eq!(env.unifier.unify(a, b), Err("Cannot unify var3 with var2".to_string())); env.unifier.unify(list_a, list_x).unwrap(); - assert_eq!( - env.unifier.unify(list_x, list_int), - Err("Cannot unify 0 with var2".to_string()) - ); + assert_eq!(env.unifier.unify(list_x, list_int), Err("Cannot unify 0 with var2".to_string())); env.unifier.replace_rigid_var(a, int); env.unifier.unify(list_x, list_int).unwrap();