From dd1be541b87718fd18b43c40610e164e48112325 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Mon, 20 Sep 2021 14:24:16 +0800 Subject: [PATCH] nac3core: allow class to have no __init__, function/method name with module path added to ensure uniqueness --- nac3core/src/codegen/expr.rs | 6 +- nac3core/src/toplevel/composer.rs | 435 +++++++++++++------------- nac3core/src/toplevel/helper.rs | 2 +- nac3core/src/toplevel/test.rs | 100 +++--- nac3core/src/typecheck/typedef/mod.rs | 23 +- 5 files changed, 309 insertions(+), 257 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4d290e27..ea24cce6 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -684,7 +684,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { unreachable!() } }; - return self.gen_call(Some((value.custom.unwrap(), val)), (&signature, fun_id), params); + return self.gen_call( + Some((value.custom.unwrap(), val)), + (&signature, fun_id), + params, + ); } _ => unimplemented!(), } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index d1099a64..959676b3 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -17,11 +17,10 @@ pub struct TopLevelComposer { // keyword list to prevent same user-defined name pub keyword_list: HashSet, // to prevent duplicate definition - pub defined_class_name: HashSet, - pub defined_class_method_name: HashSet, - pub defined_function_name: HashSet, + pub defined_names: HashSet, // get the class def id of a class method pub method_class: HashMap, + // number of built-in function and classes in the definition list, later skip pub built_in_num: usize, } @@ -52,7 +51,7 @@ impl TopLevelComposer { }; let primitives_ty = primitives.0; let mut unifier = primitives.1; - let keyword_list: HashSet = HashSet::from_iter(vec![ + let mut keyword_list: HashSet = HashSet::from_iter(vec![ "Generic".into(), "virtual".into(), "list".into(), @@ -67,9 +66,7 @@ impl TopLevelComposer { "Kernel".into(), "KernelImmutable".into(), ]); - let mut defined_class_method_name: HashSet = Default::default(); - let mut defined_class_name: HashSet = Default::default(); - let mut defined_function_name: HashSet = Default::default(); + let defined_names: HashSet = Default::default(); let method_class: HashMap = Default::default(); let mut built_in_id: HashMap = Default::default(); @@ -91,9 +88,7 @@ impl TopLevelComposer { })), None, )); - defined_class_method_name.insert(name.clone()); - defined_class_name.insert(name.clone()); - defined_function_name.insert(name); + keyword_list.insert(name); } ( @@ -103,9 +98,7 @@ impl TopLevelComposer { primitives_ty, unifier, keyword_list, - defined_class_method_name, - defined_class_name, - defined_function_name, + defined_names, method_class, }, built_in_id, @@ -119,7 +112,7 @@ impl TopLevelComposer { self.definition_ast_list.iter().map(|(x, ..)| x.clone()).collect_vec(), ) .into(), - // FIXME: all the big unifier or? + // NOTE: only one for now unifiers: Arc::new(RwLock::new(vec![( self.unifier.get_shared_unifier(), self.primitives_ty, @@ -139,23 +132,21 @@ impl TopLevelComposer { resolver: Option>>, mod_path: String, ) -> Result<(String, DefinitionId, Option), String> { - let defined_class_name = &mut self.defined_class_name; - let defined_class_method_name = &mut self.defined_class_method_name; - let defined_function_name = &mut self.defined_function_name; + let defined_names = &mut self.defined_names; match &ast.node { - ast::StmtKind::ClassDef { name, body, .. } => { - if self.keyword_list.contains(name) { + ast::StmtKind::ClassDef { name: class_name, body, .. } => { + if self.keyword_list.contains(class_name) { return Err("cannot use keyword as a class name".into()); } - if !defined_class_name.insert({ + if !defined_names.insert({ let mut n = mod_path.clone(); - n.push_str(name.as_str()); + n.push_str(class_name.as_str()); n }) { return Err("duplicate definition of class".into()); } - let class_name = name.to_string(); + let class_name = class_name.clone(); let class_def_id = self.definition_ast_list.len(); // since later when registering class method, ast will still be used, @@ -165,8 +156,8 @@ impl TopLevelComposer { Arc::new(RwLock::new(Self::make_top_level_class_def( class_def_id, resolver.clone(), - name, - Some(constructor_ty) + class_name.as_str(), + Some(constructor_ty), ))), None, ); @@ -187,23 +178,27 @@ impl TopLevelComposer { // we do not push anything to the def list, so we keep track of the index // and then push in the correct order after the for loop let mut class_method_index_offset = 0; - let mut has_init = false; for b in body { if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node { - if self.keyword_list.contains(name) { + if self.keyword_list.contains(method_name) { return Err("cannot use keyword as a method name".into()); } - let global_class_method_name = - Self::make_class_method_name(class_name.clone(), method_name); - if !defined_class_method_name.insert({ - let mut n = mod_path.clone(); - n.push_str(global_class_method_name.as_str()); - n - }) { - return Err("duplicate class method definition".into()); + if method_name.ends_with(|x: char| x.is_ascii_digit()) { + return Err(format!( + "function name `{}` must not end with numbers", + method_name + )); } - if method_name == "__init__" { - has_init = true; + let global_class_method_name = { + let mut n = mod_path.clone(); + n.push_str( + Self::make_class_method_name(class_name.clone(), method_name) + .as_str(), + ); + n + }; + if !defined_names.insert(global_class_method_name.clone()) { + return Err("duplicate class method definition".into()); } let method_def_id = self.definition_ast_list.len() + { // plus 1 here since we already have the class def @@ -232,9 +227,6 @@ impl TopLevelComposer { continue; } } - if !has_init { - return Err("class def must have __init__ method defined".into()); - } // move the ast to the entry of the class in the ast_list class_def_ast.1 = Some(ast); @@ -261,12 +253,16 @@ impl TopLevelComposer { if self.keyword_list.contains(name) { return Err("cannot use keyword as a top level function name".into()); } + if name.ends_with(|x: char| x.is_ascii_digit()) { + return Err(format!("function name `{}` must not end with numbers", name)); + } let fun_name = name.to_string(); - if !defined_function_name.insert({ + let global_fun_name = { let mut n = mod_path; n.push_str(name.as_str()); n - }) { + }; + if !defined_names.insert(global_fun_name.clone()) { return Err("duplicate top level function define".into()); } @@ -274,8 +270,7 @@ impl TopLevelComposer { // add to the definition list self.definition_ast_list.push(( RwLock::new(Self::make_top_level_function_def( - // TODO: is this fun_name or the above name with mod_path? - name.into(), + global_fun_name, name.into(), // dummy here, unify with correct type later ty_to_be_unified, @@ -824,66 +819,110 @@ impl TopLevelComposer { let mut defined_fields: HashSet = HashSet::new(); for b in class_body_ast { - if let ast::StmtKind::FunctionDef { args, returns, name, .. } = &b.node { - let (method_dummy_ty, method_id) = - Self::get_class_method_def_info(class_methods_def, name)?; + match &b.node { + ast::StmtKind::FunctionDef { args, returns, name, .. } => { + let (method_dummy_ty, method_id) = + Self::get_class_method_def_info(class_methods_def, name)?; - // the method var map can surely include the class's generic parameters - let mut method_var_map: HashMap = class_type_vars_def - .iter() - .map(|ty| { - if let TypeEnum::TVar { id, .. } = unifier.get_ty(*ty).as_ref() { - (*id, *ty) - } else { - unreachable!("must be type var here") + // the method var map can surely include the class's generic parameters + let mut method_var_map: HashMap = class_type_vars_def + .iter() + .map(|ty| { + if let TypeEnum::TVar { id, .. } = unifier.get_ty(*ty).as_ref() { + (*id, *ty) + } else { + unreachable!("must be type var here") + } + }) + .collect(); + + let arg_types: Vec = { + // check method parameters cannot have same name + let mut defined_paramter_name: HashSet = HashSet::new(); + let have_unique_fuction_parameter_name = args.args.iter().all(|x| { + defined_paramter_name.insert(x.node.arg.clone()) + && (!keyword_list.contains(&x.node.arg) || x.node.arg == "self") + }); + if !have_unique_fuction_parameter_name { + return Err("class method must have unique parameter names \ + and names thould not be the same as the keywords" + .into()); + } + if name == "__init__" && !defined_paramter_name.contains("self") { + return Err("__init__ function must have a `self` parameter".into()); + } + if !defined_paramter_name.contains("self") { + return Err("currently does not support static method".into()); } - }) - .collect(); - let arg_types: Vec = { - // check method parameters cannot have same name - let mut defined_paramter_name: HashSet = HashSet::new(); - let have_unique_fuction_parameter_name = args.args.iter().all(|x| { - defined_paramter_name.insert(x.node.arg.clone()) - && (!keyword_list.contains(&x.node.arg) || x.node.arg == "self") - }); - if !have_unique_fuction_parameter_name { - return Err("class method must have unique parameter names \ - and names thould not be the same as the keywords" - .into()); - } - if name == "__init__" && !defined_paramter_name.contains("self") { - return Err("__init__ function must have a `self` parameter".into()); - } - if !defined_paramter_name.contains("self") { - return Err("currently does not support static method".into()); - } + let mut result = Vec::new(); + for x in &args.args { + let name = x.node.arg.clone(); + if name != "self" { + let type_ann = { + let annotation_expr = x + .node + .annotation + .as_ref() + .ok_or_else(|| "type annotation needed".to_string())? + .as_ref(); + parse_ast_to_type_annotation_kinds( + class_resolver.as_ref(), + temp_def_list, + unifier, + primitives, + annotation_expr, + vec![(class_id, class_type_vars_def.clone())] + .into_iter() + .collect(), + )? + }; + // find type vars within this method parameter type annotation + let type_vars_within = + get_type_var_contained_in_type_annotation(&type_ann); + // handle the class type var and the method type var + for type_var_within in type_vars_within { + if let TypeAnnotation::TypeVarKind(ty) = type_var_within { + let id = Self::get_var_id(ty, unifier)?; + if let Some(prev_ty) = method_var_map.insert(id, ty) { + // if already in the list, make sure they are the same? + assert_eq!(prev_ty, ty); + } + } else { + unreachable!("must be type var annotation"); + } + } + // finish handling type vars + let dummy_func_arg = FuncArg { + name, + ty: unifier.get_fresh_var().0, + // TODO: default value? + default_value: None, + }; + // push the dummy type and the type annotation + // into the list for later unification + type_var_to_concrete_def + .insert(dummy_func_arg.ty, type_ann.clone()); + result.push(dummy_func_arg) + } + } + result + }; - let mut result = Vec::new(); - for x in &args.args { - let name = x.node.arg.clone(); - if name != "self" { - let type_ann = { - let annotation_expr = x - .node - .annotation - .as_ref() - .ok_or_else(|| "type annotation needed".to_string())? - .as_ref(); - parse_ast_to_type_annotation_kinds( - class_resolver.as_ref(), - temp_def_list, - unifier, - primitives, - annotation_expr, - vec![(class_id, class_type_vars_def.clone())] - .into_iter() - .collect(), - )? - }; - // find type vars within this method parameter type annotation + let ret_type = { + if let Some(result) = returns { + let result = result.as_ref(); + let annotation = parse_ast_to_type_annotation_kinds( + class_resolver.as_ref(), + temp_def_list, + unifier, + primitives, + result, + vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), + )?; + // find type vars within this return type annotation let type_vars_within = - get_type_var_contained_in_type_annotation(&type_ann); + get_type_var_contained_in_type_annotation(&annotation); // handle the class type var and the method type var for type_var_within in type_vars_within { if let TypeAnnotation::TypeVarKind(ty) = type_var_within { @@ -896,128 +935,90 @@ impl TopLevelComposer { unreachable!("must be type var annotation"); } } - // finish handling type vars - let dummy_func_arg = FuncArg { - name, - ty: unifier.get_fresh_var().0, - // TODO: default value? - default_value: None, - }; - // push the dummy type and the type annotation - // into the list for later unification - type_var_to_concrete_def.insert(dummy_func_arg.ty, type_ann.clone()); - result.push(dummy_func_arg) + let dummy_return_type = unifier.get_fresh_var().0; + type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); + dummy_return_type + } else { + // if do not have return annotation, return none + // for uniform handling, still use type annoatation + let dummy_return_type = unifier.get_fresh_var().0; + type_var_to_concrete_def.insert( + dummy_return_type, + TypeAnnotation::PrimitiveKind(primitives.none), + ); + dummy_return_type } - } - result - }; + }; - let ret_type = { - if let Some(result) = returns { - let result = result.as_ref(); - let annotation = parse_ast_to_type_annotation_kinds( - class_resolver.as_ref(), - temp_def_list, - unifier, - primitives, - result, - vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), - )?; - // find type vars within this return type annotation - let type_vars_within = - get_type_var_contained_in_type_annotation(&annotation); - // handle the class type var and the method type var - for type_var_within in type_vars_within { - if let TypeAnnotation::TypeVarKind(ty) = type_var_within { - let id = Self::get_var_id(ty, unifier)?; - if let Some(prev_ty) = method_var_map.insert(id, ty) { - // if already in the list, make sure they are the same? - assert_eq!(prev_ty, ty); - } - } else { - unreachable!("must be type var annotation"); - } - } - let dummy_return_type = unifier.get_fresh_var().0; - type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); - dummy_return_type - } else { - // if do not have return annotation, return none - // for uniform handling, still use type annoatation - let dummy_return_type = unifier.get_fresh_var().0; - type_var_to_concrete_def.insert( - dummy_return_type, - TypeAnnotation::PrimitiveKind(primitives.none), + if let TopLevelDef::Function { var_id, .. } = + temp_def_list.get(method_id.0).unwrap().write().deref_mut() + { + var_id.extend_from_slice( + method_var_map.keys().into_iter().copied().collect_vec().as_slice(), ); - dummy_return_type } - }; + let method_type = unifier.add_ty(TypeEnum::TFunc( + FunSignature { args: arg_types, ret: ret_type, vars: method_var_map } + .into(), + )); - if let TopLevelDef::Function { var_id, .. } = - temp_def_list.get(method_id.0).unwrap().write().deref_mut() - { - var_id.extend_from_slice( - method_var_map.keys().into_iter().copied().collect_vec().as_slice(), - ); + // unify now since function type is not in type annotation define + // which should be fine since type within method_type will be subst later + unifier.unify(method_dummy_ty, method_type)?; } - let method_type = unifier.add_ty(TypeEnum::TFunc( - FunSignature { args: arg_types, ret: ret_type, vars: method_var_map }.into(), - )); + ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => { + if let ast::ExprKind::Name { id: attr, .. } = &target.node { + if defined_fields.insert(attr.to_string()) { + let dummy_field_type = unifier.get_fresh_var().0; + class_fields_def.push((attr.to_string(), dummy_field_type)); - // unify now since function type is not in type annotation define - // which should be fine since type within method_type will be subst later - unifier.unify(method_dummy_ty, method_type)?; - } else if let ast::StmtKind::AnnAssign { target, annotation, value: None, .. } = &b.node - { - if let ast::ExprKind::Name { id: attr, .. } = &target.node { - if defined_fields.insert(attr.to_string()) { - let dummy_field_type = unifier.get_fresh_var().0; - class_fields_def.push((attr.to_string(), dummy_field_type)); - - // handle Kernel[T], KernelImmutable[T] - let annotation = { - match &annotation.as_ref().node { - ast::ExprKind::Subscript { value, slice, .. } - if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Kernel" || id == "KernelImmutable") - } => - { - slice + // handle Kernel[T], KernelImmutable[T] + let annotation = { + match &annotation.as_ref().node { + ast::ExprKind::Subscript { value, slice, .. } + if { + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Kernel" || id == "KernelImmutable") + } => + { + slice + } + _ => annotation, } - _ => annotation, - } - }; + }; - let annotation = parse_ast_to_type_annotation_kinds( - class_resolver.as_ref(), - &temp_def_list, - unifier, - primitives, - annotation.as_ref(), - vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), - )?; - // find type vars within this return type annotation - let type_vars_within = - get_type_var_contained_in_type_annotation(&annotation); - // handle the class type var and the method type var - for type_var_within in type_vars_within { - if let TypeAnnotation::TypeVarKind(t) = type_var_within { - if !class_type_vars_def.contains(&t) { - return Err("class fields can only use type \ - vars declared as class generic type vars" - .into()); + let annotation = parse_ast_to_type_annotation_kinds( + class_resolver.as_ref(), + &temp_def_list, + unifier, + primitives, + annotation.as_ref(), + vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), + )?; + // find type vars within this return type annotation + let type_vars_within = + get_type_var_contained_in_type_annotation(&annotation); + // handle the class type var and the method type var + for type_var_within in type_vars_within { + if let TypeAnnotation::TypeVarKind(t) = type_var_within { + if !class_type_vars_def.contains(&t) { + return Err("class fields can only use type \ + vars declared as class generic type vars" + .into()); + } + } else { + unreachable!("must be type var annotation"); } - } else { - unreachable!("must be type var annotation"); } + type_var_to_concrete_def.insert(dummy_field_type, annotation); + } else { + return Err("same class fields defined twice".into()); } - type_var_to_concrete_def.insert(dummy_field_type, annotation); } else { - return Err("same class fields defined twice".into()); + return Err("unsupported statement type in class definition body".into()); } } - } else { - return Err("unsupported statement type in class definition body".into()); + ast::StmtKind::Pass => {} + _ => return Err("unsupported statement type in class definition body".into()), } } Ok(()) @@ -1162,10 +1163,16 @@ impl TopLevelComposer { for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num) { let mut function_def = def.write(); - if let TopLevelDef::Function { instance_to_stmt, name, simple_name, signature, resolver, .. } = - &mut *function_def + if let TopLevelDef::Function { + instance_to_stmt, + name, + simple_name, + signature, + resolver, + .. + } = &mut *function_def { - if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { + if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { let FunSignature { args, ret, vars } = &*func_sig.borrow(); // None if is not class method let self_type = { @@ -1181,11 +1188,13 @@ impl TopLevelComposer { &ty_ann, )?; if simple_name == "__init__" { - let fn_type = self.unifier.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: args.clone(), - ret: self_ty, - vars: vars.clone() - }))); + let fn_type = self.unifier.add_ty(TypeEnum::TFunc( + RefCell::new(FunSignature { + args: args.clone(), + ret: self_ty, + vars: vars.clone(), + }), + )); self.unifier.unify(fn_type, constructor.unwrap())?; } Some(self_ty) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index f128b275..114c345d 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -93,7 +93,7 @@ impl TopLevelComposer { index: usize, resolver: Option>>, name: &str, - constructor: Option + constructor: Option, ) -> TopLevelDef { TopLevelDef::Class { name: name.to_string(), diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index b6b9843d..730e9344 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -94,7 +94,7 @@ fn test_simple_register(source: Vec<&str>) { let ast = parse_program(s).unwrap(); let ast = ast[0].clone(); - composer.register_top_level(ast, None, "__main__".into()).unwrap(); + composer.register_top_level(ast, None, "".into()).unwrap(); } } @@ -142,7 +142,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s let ast = ast[0].clone(); let (id, def_id, ty) = - composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()).unwrap(); + composer.register_top_level(ast, Some(resolver.clone()), "".into()).unwrap(); internal_resolver.add_id_def(id.clone(), def_id); if let Some(ty) = ty { internal_resolver.add_id_type(id, ty); @@ -151,7 +151,8 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s composer.start_analysis(true).unwrap(); - for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() { + for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() + { let def = &*def.read(); if let TopLevelDef::Function { signature, name, .. } = def { let ty_str = @@ -637,13 +638,24 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s vec!["cyclic inheritance detected"]; "cyclic2" )] +#[test_case( + vec![ + indoc! {" + class A: + pass + "} + ], + vec!["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"]; + "simple pass in class" +)] #[test_case( vec![indoc! {" class A: - pass + def fun3(self): + pass "}], - vec!["class def must have __init__ method defined"]; - "err no __init__" + vec!["function name `fun3` must not end with numbers"]; + "err fun end with number" )] #[test_case( vec![indoc! {" @@ -749,13 +761,13 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { let mut composer: TopLevelComposer = Default::default(); let internal_resolver = make_internal_resolver_with_tvar( - vec![ + vec![ ("T".into(), vec![]), ("V".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int32]), ("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]), ], &mut composer.unifier, - print + print, ); let resolver = Arc::new( Box::new(Resolver(internal_resolver.clone())) as Box @@ -766,7 +778,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { let ast = ast[0].clone(); let (id, def_id, ty) = { - match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) { + match composer.register_top_level(ast, Some(resolver.clone()), "".into()) { Ok(x) => x, Err(msg) => { if print { @@ -792,7 +804,9 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { } } else { // skip 5 to skip primitives - for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() { + for (i, (def, _)) in + composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() + { let def = &*def.read(); if print { @@ -921,13 +935,20 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { let mut composer: TopLevelComposer = Default::default(); let internal_resolver = make_internal_resolver_with_tvar( - vec![ + vec![ ("T".into(), vec![]), - ("V".into(), vec![composer.primitives_ty.float, composer.primitives_ty.int32, composer.primitives_ty.int64]), + ( + "V".into(), + vec![ + composer.primitives_ty.float, + composer.primitives_ty.int32, + composer.primitives_ty.int64, + ], + ), ("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]), ], &mut composer.unifier, - print + print, ); let resolver = Arc::new( Box::new(Resolver(internal_resolver.clone())) as Box @@ -938,7 +959,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { let ast = ast[0].clone(); let (id, def_id, ty) = { - match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) { + match composer.register_top_level(ast, Some(resolver.clone()), "".into()) { Ok(x) => x, Err(msg) => { if print { @@ -964,12 +985,18 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { } } else { // skip 5 to skip primitives - let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier}; - for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() { + let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier }; + for (_i, (def, _)) in + composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() + { let def = &*def.read(); if let TopLevelDef::Function { instance_to_stmt, name, .. } = def { - println!("=========`{}`: number of instances: {}===========", name, instance_to_stmt.len()); + println!( + "=========`{}`: number of instances: {}===========", + name, + instance_to_stmt.len() + ); for inst in instance_to_stmt.iter() { let ast = &inst.1.body; for b in ast { @@ -983,25 +1010,29 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { } } -fn make_internal_resolver_with_tvar(tvars: Vec<(String, Vec)>, unifier: &mut Unifier, print: bool) -> Arc { +fn make_internal_resolver_with_tvar( + tvars: Vec<(String, Vec)>, + unifier: &mut Unifier, + print: bool, +) -> Arc { let res: Arc = ResolverInternal { id_to_def: Default::default(), id_to_type: tvars .into_iter() - .map(|(name, range)| ( - name.clone(), - { + .map(|(name, range)| { + (name.clone(), { let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice()); if print { println!("{}: {:?}, tvar{}", name, ty, id); } ty - } - )) + }) + }) .collect::>() .into(), - class_names: Default::default() - }.into(); + class_names: Default::default(), + } + .into(); if print { println!(); } @@ -1009,7 +1040,7 @@ fn make_internal_resolver_with_tvar(tvars: Vec<(String, Vec)>, unifier: &m } struct TypeToStringFolder<'a> { - unifier: &'a mut Unifier + unifier: &'a mut Unifier, } impl<'a> Fold> for TypeToStringFolder<'a> { @@ -1017,14 +1048,11 @@ impl<'a> Fold> for TypeToStringFolder<'a> { type Error = String; fn map_user(&mut self, user: Option) -> Result { Ok(if let Some(ty) = user { - self.unifier.stringify( - ty, - &mut |id| format!("class{}", id.to_string()), - &mut |id| format!("tvar{}", id.to_string()), - ) - } else { - "None".into() - } - ) + self.unifier.stringify(ty, &mut |id| format!("class{}", id.to_string()), &mut |id| { + format!("tvar{}", id.to_string()) + }) + } else { + "None".into() + }) } -} \ No newline at end of file +} diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 608ed40e..a62c46db 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -797,7 +797,12 @@ impl Unifier { self.subst_impl(a, mapping, &mut HashMap::new()) } - fn subst_impl(&mut self, a: Type, mapping: &VarMap, cache: &mut HashMap>) -> Option { + fn subst_impl( + &mut self, + a: Type, + mapping: &VarMap, + cache: &mut HashMap>, + ) -> Option { use TypeVarMeta::*; let cached = cache.get_mut(&a); if let Some(cached) = cached { @@ -831,9 +836,9 @@ impl Unifier { TypeEnum::TList { ty } => { self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t })) } - TypeEnum::TVirtual { ty } => { - self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })) - } + TypeEnum::TVirtual { ty } => self + .subst_impl(*ty, mapping, cache) + .map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })), TypeEnum::TObj { obj_id, fields, params } => { // Type variables in field types must be present in the type parameter. // If the mapping does not contain any type variables in the @@ -851,7 +856,8 @@ impl Unifier { if need_subst { cache.insert(a, None); let obj_id = *obj_id; - let params = self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone()); + let params = + self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone()); let fields = self .subst_map(&fields.borrow(), mapping, cache) .unwrap_or_else(|| fields.borrow().clone()); @@ -897,7 +903,12 @@ impl Unifier { } } - fn subst_map(&mut self, map: &Mapping, mapping: &VarMap, cache: &mut HashMap>) -> Option> + fn subst_map( + &mut self, + map: &Mapping, + mapping: &VarMap, + cache: &mut HashMap>, + ) -> Option> where K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, {