From 084efe92af5c90d4d7f0096f3bd035de62f80056 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Wed, 22 Sep 2021 17:19:27 +0800 Subject: [PATCH] nac3core: use string interning --- Cargo.lock | 51 +++- Cargo.toml | 4 + nac3core/Cargo.toml | 2 +- nac3core/src/codegen/expr.rs | 12 +- nac3core/src/codegen/mod.rs | 6 +- nac3core/src/codegen/stmt.rs | 2 +- nac3core/src/codegen/test.rs | 50 ++-- nac3core/src/symbol_resolver.rs | 251 ++++++++++-------- nac3core/src/toplevel/composer.rs | 130 +++++---- nac3core/src/toplevel/helper.rs | 26 +- nac3core/src/toplevel/mod.rs | 10 +- nac3core/src/toplevel/test.rs | 51 ++-- nac3core/src/toplevel/type_annotation.rs | 93 ++++--- nac3core/src/typecheck/function_check.rs | 12 +- nac3core/src/typecheck/type_inferencer/mod.rs | 48 ++-- .../src/typecheck/type_inferencer/test.rs | 76 +++--- nac3core/src/typecheck/typedef/mod.rs | 24 +- nac3core/src/typecheck/typedef/test.rs | 16 +- nac3embedded/Cargo.toml | 2 +- nac3standalone/Cargo.toml | 2 +- nac3standalone/src/basic_symbol_resolver.rs | 23 +- nac3standalone/src/main.rs | 2 +- 22 files changed, 495 insertions(+), 398 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f0084e6c..468d85a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -69,6 +69,12 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + [[package]] name = "cc" version = "1.0.68" @@ -213,6 +219,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37ab347416e802de484e4d03c7316c48f1ecb56574dfd4a46a80f173ce1de04d" +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + [[package]] name = "getrandom" version = "0.2.3" @@ -241,6 +256,15 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash", +] + [[package]] name = "hermit-abi" version = "0.1.19" @@ -257,7 +281,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.9.1", ] [[package]] @@ -845,15 +869,19 @@ checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" [[package]] name = "rustpython-ast" version = "0.1.0" -source = "git+https://github.com/pca006132/RustPython?branch=main#c6248660e33a2db8c2d745097ac4bff13598d955" +source = "git+https://github.com/m-labs/RustPython?branch=parser-mod#efdf7829ba1a5f87d30df8eaff12a330544f3cbd" dependencies = [ + "fxhash", + "lazy_static", "num-bigint 0.4.0", + "parking_lot", + "string-interner", ] [[package]] name = "rustpython-parser" version = "0.1.2" -source = "git+https://github.com/pca006132/RustPython?branch=main#c6248660e33a2db8c2d745097ac4bff13598d955" +source = "git+https://github.com/m-labs/RustPython?branch=parser-mod#efdf7829ba1a5f87d30df8eaff12a330544f3cbd" dependencies = [ "ahash", "lalrpop", @@ -898,6 +926,12 @@ dependencies = [ "pest", ] +[[package]] +name = "serde" +version = "1.0.130" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f12d06de37cf59146fbdecab66aa99f9fe4f78722e3607577a5375d66bd0c913" + [[package]] name = "siphasher" version = "0.3.5" @@ -910,6 +944,17 @@ version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" +[[package]] +name = "string-interner" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc77d3a5728ef82235df1f9b9430507f555c7404797f42b49c2403d4c1d8c6c" +dependencies = [ + "cfg-if", + "hashbrown 0.11.2", + "serde", +] + [[package]] name = "string_cache" version = "0.8.1" diff --git a/Cargo.toml b/Cargo.toml index 938324ad..fc88f147 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,7 @@ members = [ "nac3standalone", "nac3embedded", ] + +[profile.release] +debug = true + diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 20404c56..6be20907 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -8,7 +8,7 @@ edition = "2018" num-bigint = "0.3" num-traits = "0.2" inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] } -rustpython-parser = { git = "https://github.com/pca006132/RustPython", branch = "main" } +rustpython-parser = { git = "https://github.com/m-labs/RustPython", branch = "parser-mod" } itertools = "0.10.1" crossbeam = "0.8.1" parking_lot = "0.11.1" diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 11df52a9..0a7b5da6 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -12,7 +12,7 @@ use inkwell::{ AddressSpace, }; use itertools::{chain, izip, zip, Itertools}; -use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator}; +use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator, StrRef}; pub fn assert_int_val<'ctx>(val: BasicValueEnum<'ctx>) -> IntValue<'ctx> { if let BasicValueEnum::IntValue(v) = val { @@ -56,7 +56,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { .join(", ") } - pub fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize { + pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize { let obj_id = match &*self.unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } => *obj_id, // we cannot have other types, virtual type should be handled by function calls @@ -106,7 +106,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { &mut self, obj: Option<(Type, BasicValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), - params: Vec<(Option, BasicValueEnum<'ctx>)>, + params: Vec<(Option, BasicValueEnum<'ctx>)>, ) -> Option> { let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None); let definition = self.top_level.definitions.read().get(fun.1.0).cloned().unwrap(); @@ -122,7 +122,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { // TODO: what about other fields that require alloca? let mut fun_id = None; for (name, _, id) in methods.iter() { - if name == "__init__" { + if name == &"__init__".into() { fun_id = Some(*id); } } @@ -449,7 +449,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } ExprKind::Attribute { value, attr, .. } => { // note that we would handle class methods directly in calls - let index = self.get_attr_index(value.custom.unwrap(), attr); + let index = self.get_attr_index(value.custom.unwrap(), *attr); let val = self.gen_expr(value).unwrap(); let ptr = assert_pointer_val(val); unsafe { @@ -666,7 +666,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ExprKind::Name { id, .. } => { // TODO: handle primitive casts and function pointers let fun = - self.resolver.get_identifier_def(&id).expect("Unknown identifier"); + self.resolver.get_identifier_def(*id).expect("Unknown identifier"); return self.gen_call(None, (&signature, fun), params); } ExprKind::Attribute { value, attr, .. } => { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index fca41345..b32733da 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -18,7 +18,7 @@ use inkwell::{ }; use itertools::Itertools; use parking_lot::{Condvar, Mutex}; -use rustpython_parser::ast::Stmt; +use rustpython_parser::ast::{Stmt, StrRef}; use std::collections::HashMap; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -39,7 +39,7 @@ pub struct CodeGenContext<'ctx, 'a> { pub top_level: &'a TopLevelContext, pub unifier: Unifier, pub resolver: Arc>, - pub var_assignment: HashMap>, + pub var_assignment: HashMap>, pub type_cache: HashMap>, pub primitives: PrimitiveStore, pub calls: Arc>, @@ -317,7 +317,7 @@ pub fn gen_func<'ctx>( let param = fn_val.get_nth_param(n as u32).unwrap(); let alloca = builder.build_alloca( get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty), - &arg.name, + &arg.name.to_string(), ); builder.build_store(alloca, param); var_assignment.insert(arg.name.clone(), alloca); diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index df6e823b..b8ce89a3 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -30,7 +30,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { }) } ExprKind::Attribute { value, attr, .. } => { - let index = self.get_attr_index(value.custom.unwrap(), attr); + let index = self.get_attr_index(value.custom.unwrap(), *attr); let val = self.gen_expr(value).unwrap(); let ptr = if let BasicValueEnum::PointerValue(v) = val { v diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 70f76e5c..1173b97e 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -12,38 +12,38 @@ use crate::{ }; use indoc::indoc; use parking_lot::RwLock; -use rustpython_parser::{ast::fold::Fold, parser::parse_program}; +use rustpython_parser::{ast::{StrRef, fold::Fold}, parser::parse_program}; use std::cell::RefCell; use std::collections::{HashMap, HashSet}; use std::sync::Arc; struct Resolver { - id_to_type: HashMap, - id_to_def: RwLock>, - class_names: HashMap, + id_to_type: HashMap, + id_to_def: RwLock>, + class_names: HashMap, } impl Resolver { - pub fn add_id_def(&self, id: String, def: DefinitionId) { + pub fn add_id_def(&self, id: StrRef, def: DefinitionId) { self.id_to_def.write().insert(id, def); } } impl SymbolResolver for Resolver { - fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { - self.id_to_type.get(str).cloned() + fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option { + self.id_to_type.get(&str).cloned() } - fn get_symbol_value(&self, _: &str) -> Option { + fn get_symbol_value(&self, _: StrRef) -> Option { unimplemented!() } - fn get_symbol_location(&self, _: &str) -> Option { + fn get_symbol_location(&self, _: StrRef) -> Option { unimplemented!() } - fn get_identifier_def(&self, id: &str) -> Option { - self.id_to_def.read().get(id).cloned() + fn get_identifier_def(&self, id: StrRef) -> Option { + self.id_to_def.read().get(&id).cloned() } } @@ -77,8 +77,8 @@ fn test_primitives() { let threads = ["test"]; let signature = FunSignature { args: vec![ - FuncArg { name: "a".to_string(), ty: primitives.int32, default_value: None }, - FuncArg { name: "b".to_string(), ty: primitives.int32, default_value: None }, + FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }, + FuncArg { name: "b".into(), ty: primitives.int32, default_value: None }, ], ret: primitives.int32, vars: HashMap::new(), @@ -91,7 +91,7 @@ fn test_primitives() { }; let mut virtual_checks = Vec::new(); let mut calls = HashMap::new(); - let mut identifiers: HashSet<_> = ["a".to_string(), "b".to_string()].iter().cloned().collect(); + let mut identifiers: HashSet<_> = ["a".into(), "b".into()].iter().cloned().collect(); let mut inferencer = Inferencer { top_level: &top_level, function_data: &mut function_data, @@ -121,11 +121,11 @@ fn test_primitives() { let task = CodeGenTask { subst: Default::default(), - symbol_name: "testing".to_string(), - body: statements, + symbol_name: "testing".into(), + body: Arc::new(statements), resolver, unifier, - calls, + calls: Arc::new(calls), signature, }; let f = Arc::new(WithCall::new(Box::new(|module| { @@ -212,7 +212,7 @@ fn test_simple_call() { unifier.top_level = Some(top_level.clone()); let signature = FunSignature { - args: vec![FuncArg { name: "a".to_string(), ty: primitives.int32, default_value: None }], + args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }], ret: primitives.int32, vars: HashMap::new(), }; @@ -221,7 +221,7 @@ fn test_simple_call() { let foo_id = top_level.definitions.read().len(); top_level.definitions.write().push(Arc::new(RwLock::new(TopLevelDef::Function { name: "foo".to_string(), - simple_name: "foo".to_string(), + simple_name: "foo".into(), signature: fun_ty, var_id: vec![], instance_to_stmt: HashMap::new(), @@ -234,7 +234,7 @@ fn test_simple_call() { id_to_def: RwLock::new(HashMap::new()), class_names: Default::default(), }); - resolver.add_id_def("foo".to_string(), DefinitionId(foo_id)); + resolver.add_id_def("foo".into(), DefinitionId(foo_id)); let resolver = Arc::new(resolver as Box); if let TopLevelDef::Function { resolver: r, .. } = @@ -253,7 +253,7 @@ fn test_simple_call() { }; let mut virtual_checks = Vec::new(); let mut calls = HashMap::new(); - let mut identifiers: HashSet<_> = ["a".to_string(), "foo".into()].iter().cloned().collect(); + let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].iter().cloned().collect(); let mut inferencer = Inferencer { top_level: &top_level, function_data: &mut function_data, @@ -288,8 +288,8 @@ fn test_simple_call() { instance_to_stmt.insert( "".to_string(), FunInstance { - body: statements_2, - calls: inferencer.calls.clone(), + body: Arc::new(statements_2), + calls: Arc::new(inferencer.calls.clone()), subst: Default::default(), unifier_id: 0, }, @@ -309,10 +309,10 @@ fn test_simple_call() { let task = CodeGenTask { subst: Default::default(), symbol_name: "testing".to_string(), - body: statements_1, + body: Arc::new(statements_1), resolver, unifier, - calls: calls1, + calls: Arc::new(calls1), signature, }; let f = Arc::new(WithCall::new(Box::new(|module| { diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index e28e2c7f..d657e4e9 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -10,7 +10,7 @@ use crate::typecheck::{ use crate::{location::Location, typecheck::typedef::TypeEnum}; use itertools::{chain, izip}; use parking_lot::RwLock; -use rustpython_parser::ast::Expr; +use rustpython_parser::ast::{Expr, StrRef}; #[derive(Clone, PartialEq)] pub enum SymbolValue { @@ -29,15 +29,28 @@ pub trait SymbolResolver { &self, unifier: &mut Unifier, primitives: &PrimitiveStore, - str: &str, + str: StrRef, ) -> Option; // get the top-level definition of identifiers - fn get_identifier_def(&self, str: &str) -> Option; - fn get_symbol_value(&self, str: &str) -> Option; - fn get_symbol_location(&self, str: &str) -> Option; + fn get_identifier_def(&self, str: StrRef) -> Option; + fn get_symbol_value(&self, str: StrRef) -> Option; + fn get_symbol_location(&self, str: StrRef) -> Option; // handle function call etc. } +thread_local! { + static IDENTIFIER_ID: [StrRef; 8] = [ + "int32".into(), + "int64".into(), + "float".into(), + "bool".into(), + "None".into(), + "virtual".into(), + "list".into(), + "tuple".into() + ]; +} + // convert type annotation into type pub fn parse_type_annotation( resolver: &dyn SymbolResolver, @@ -47,15 +60,32 @@ pub fn parse_type_annotation( expr: &Expr, ) -> Result { use rustpython_parser::ast::ExprKind::*; + let ids = IDENTIFIER_ID.with(|ids| { + *ids + }); + let int32_id = ids[0]; + let int64_id = ids[1]; + let float_id = ids[2]; + let bool_id = ids[3]; + let none_id = ids[4]; + let virtual_id = ids[5]; + let list_id = ids[6]; + let tuple_id = ids[7]; + match &expr.node { - Name { id, .. } => match id.as_str() { - "int32" => Ok(primitives.int32), - "int64" => Ok(primitives.int64), - "float" => Ok(primitives.float), - "bool" => Ok(primitives.bool), - "None" => Ok(primitives.none), - x => { - let obj_id = resolver.get_identifier_def(x); + Name { id, .. } => { + if *id == int32_id { + Ok(primitives.int32) + } else if *id == int64_id { + Ok(primitives.int64) + } else if *id == float_id { + Ok(primitives.float) + } else if *id == bool_id { + Ok(primitives.bool) + } else if *id == none_id { + Ok(primitives.none) + } else { + let obj_id = resolver.get_identifier_def(*id); if let Some(obj_id) = obj_id { let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { @@ -67,8 +97,8 @@ pub fn parse_type_annotation( } let fields = RefCell::new( chain( - fields.iter().map(|(k, v)| (k.clone(), *v)), - methods.iter().map(|(k, v, _)| (k.clone(), *v)), + fields.iter().map(|(k, v)| (*k, *v)), + methods.iter().map(|(k, v, _)| (*k, *v)), ) .collect(), ); @@ -83,121 +113,116 @@ pub fn parse_type_annotation( } else { // it could be a type variable let ty = resolver - .get_symbol_type(unifier, primitives, x) + .get_symbol_type(unifier, primitives, *id) .ok_or_else(|| "unknown type variable name".to_owned())?; if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { Ok(ty) } else { - Err(format!("Unknown type annotation {}", x)) + Err(format!("Unknown type annotation {}", id)) } } } }, Subscript { value, slice, .. } => { if let Name { id, .. } = &value.node { - match id.as_str() { - "virtual" => { - let ty = parse_type_annotation( + if *id == virtual_id { + let ty = parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + slice, + )?; + Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) + } else if *id == list_id { + let ty = parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + slice, + )?; + Ok(unifier.add_ty(TypeEnum::TList { ty })) + } else if *id == tuple_id { + if let Tuple { elts, .. } = &slice.node { + let ty = elts + .iter() + .map(|elt| { + parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + elt, + ) + }) + .collect::, _>>()?; + Ok(unifier.add_ty(TypeEnum::TTuple { ty })) + } else { + Err("Expected multiple elements for tuple".into()) + } + } else { + let types = if let Tuple { elts, .. } = &slice.node { + elts.iter() + .map(|v| { + parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + v, + ) + }) + .collect::, _>>()? + } else { + vec![parse_type_annotation( resolver, top_level_defs, unifier, primitives, slice, - )?; - Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) - } - "list" => { - let ty = parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - slice, - )?; - Ok(unifier.add_ty(TypeEnum::TList { ty })) - } - "tuple" => { - if let Tuple { elts, .. } = &slice.node { - let ty = elts - .iter() - .map(|elt| { - parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - elt, - ) - }) - .collect::, _>>()?; - Ok(unifier.add_ty(TypeEnum::TTuple { ty })) - } else { - Err("Expected multiple elements for tuple".into()) - } - } - _ => { - let types = if let Tuple { elts, .. } = &slice.node { - elts.iter() - .map(|v| { - parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - v, - ) - }) - .collect::, _>>()? - } else { - vec![parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - slice, - )?] - }; + )?] + }; - let obj_id = resolver - .get_identifier_def(id) - .ok_or_else(|| format!("Unknown type annotation {}", id))?; - let def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { - if types.len() != type_vars.len() { - return Err(format!( - "Unexpected number of type parameters: expected {} but got {}", - type_vars.len(), - types.len() - )); - } - let mut subst = HashMap::new(); - for (var, ty) in izip!(type_vars.iter(), types.iter()) { - let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { - *id - } else { - unreachable!() - }; - subst.insert(id, *ty); - } - let mut fields = fields - .iter() - .map(|(attr, ty)| { - let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (attr.clone(), ty) - }) - .collect::>(); - fields.extend(methods.iter().map(|(attr, ty, _)| { - let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (attr.clone(), ty) - })); - Ok(unifier.add_ty(TypeEnum::TObj { - obj_id, - fields: fields.into(), - params: subst.into(), - })) - } else { - Err("Cannot use function name as type".into()) + let obj_id = resolver + .get_identifier_def(*id) + .ok_or_else(|| format!("Unknown type annotation {}", id))?; + let def = top_level_defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { + if types.len() != type_vars.len() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got {}", + type_vars.len(), + types.len() + )); } + let mut subst = HashMap::new(); + for (var, ty) in izip!(type_vars.iter(), types.iter()) { + let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { + *id + } else { + unreachable!() + }; + subst.insert(id, *ty); + } + let mut fields = fields + .iter() + .map(|(attr, ty)| { + let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (*attr, ty) + }) + .collect::>(); + fields.extend(methods.iter().map(|(attr, ty, _)| { + let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (*attr, ty) + })); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id, + fields: fields.into(), + params: subst.into(), + })) + } else { + Err("Cannot use function name as type".into()) } } } else { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 5c6817b6..dc67db4f 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -15,7 +15,7 @@ pub struct TopLevelComposer { // primitive store pub primitives_ty: PrimitiveStore, // keyword list to prevent same user-defined name - pub keyword_list: HashSet, + pub keyword_list: HashSet, // to prevent duplicate definition pub defined_names: HashSet, // get the class def id of a class method @@ -34,24 +34,39 @@ impl TopLevelComposer { /// return a composer and things to make a "primitive" symbol resolver, so that the symbol /// resolver can later figure out primitive type definitions when passed a primitive type name pub fn new( - builtins: Vec<(String, FunSignature)>, - ) -> (Self, HashMap, HashMap) { + builtins: Vec<(StrRef, FunSignature)>, + ) -> (Self, HashMap, HashMap) { let primitives = Self::make_primitives(); let mut definition_ast_list = { let top_level_def_list = vec![ - Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32", None))), - Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64", None))), - Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float", None))), - Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool", None))), - Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none", None))), + Arc::new(RwLock::new(Self::make_top_level_class_def( + 0, + None, + "int32".into(), + None, + ))), + Arc::new(RwLock::new(Self::make_top_level_class_def( + 1, + None, + "int64".into(), + None, + ))), + Arc::new(RwLock::new(Self::make_top_level_class_def( + 2, + None, + "float".into(), + None, + ))), + Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool".into(), None))), + Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none".into(), None))), ]; let ast_list: Vec>> = vec![None, None, None, None, None]; izip!(top_level_def_list, ast_list).collect_vec() }; let primitives_ty = primitives.0; let mut unifier = primitives.1; - let mut keyword_list: HashSet = HashSet::from_iter(vec![ + let mut keyword_list: HashSet = HashSet::from_iter(vec![ "Generic".into(), "virtual".into(), "list".into(), @@ -69,8 +84,8 @@ impl TopLevelComposer { let defined_names: HashSet = Default::default(); let method_class: HashMap = Default::default(); - let mut built_in_id: HashMap = Default::default(); - let mut built_in_ty: HashMap = Default::default(); + let mut built_in_id: HashMap = Default::default(); + let mut built_in_ty: HashMap = Default::default(); for (name, sig) in builtins { let fun_sig = unifier.add_ty(TypeEnum::TFunc(RefCell::new(sig))); @@ -78,11 +93,11 @@ impl TopLevelComposer { built_in_id.insert(name.clone(), DefinitionId(definition_ast_list.len())); definition_ast_list.push(( Arc::new(RwLock::new(TopLevelDef::Function { - name: name.clone(), - simple_name: name.clone(), + name: name.into(), + simple_name: name, signature: fun_sig, instance_to_stmt: HashMap::new(), - instance_to_symbol: [("".to_string(), name.clone())].iter().cloned().collect(), + instance_to_symbol: [("".into(), name.into())].iter().cloned().collect(), var_id: Default::default(), resolver: None, })), @@ -131,7 +146,7 @@ impl TopLevelComposer { ast: ast::Stmt<()>, resolver: Option>>, mod_path: String, - ) -> Result<(String, DefinitionId, Option), String> { + ) -> Result<(StrRef, DefinitionId, Option), String> { let defined_names = &mut self.defined_names; match &ast.node { ast::StmtKind::ClassDef { name: class_name, body, .. } => { @@ -140,7 +155,7 @@ impl TopLevelComposer { } if !defined_names.insert({ let mut n = mod_path.clone(); - n.push_str(class_name.as_str()); + n.push_str(&class_name.to_string()); n }) { return Err("duplicate definition of class".into()); @@ -156,7 +171,7 @@ impl TopLevelComposer { Arc::new(RwLock::new(Self::make_top_level_class_def( class_def_id, resolver.clone(), - class_name.as_str(), + class_name, Some(constructor_ty), ))), None, @@ -167,7 +182,7 @@ impl TopLevelComposer { // thus cannot return their definition_id type MethodInfo = ( // the simple method name without class name - String, + StrRef, // in this top level def, method name is prefixed with the class name Arc>, DefinitionId, @@ -186,8 +201,11 @@ impl TopLevelComposer { let global_class_method_name = { let mut n = mod_path.clone(); n.push_str( - Self::make_class_method_name(class_name.clone(), method_name) - .as_str(), + Self::make_class_method_name( + class_name.into(), + &method_name.to_string(), + ) + .as_str(), ); n }; @@ -247,22 +265,22 @@ impl TopLevelComposer { // if self.keyword_list.contains(name) { // return Err("cannot use keyword as a top level function name".into()); // } - let fun_name = name.to_string(); let global_fun_name = { let mut n = mod_path; - n.push_str(name.as_str()); + n.push_str(&name.to_string()); n }; if !defined_names.insert(global_fun_name.clone()) { return Err("duplicate top level function define".into()); } + let fun_name = *name; let ty_to_be_unified = self.unifier.get_fresh_var().0; // add to the definition list self.definition_ast_list.push(( RwLock::new(Self::make_top_level_function_def( global_fun_name, - name.into(), + *name, // dummy here, unify with correct type later ty_to_be_unified, resolver, @@ -334,7 +352,7 @@ impl TopLevelComposer { if { matches!( &value.node, - ast::ExprKind::Name { id, .. } if id == "Generic" + ast::ExprKind::Name { id, .. } if id == &"Generic".into() ) } => { @@ -432,7 +450,7 @@ impl TopLevelComposer { ast::ExprKind::Subscript { value, .. } if matches!( &value.node, - ast::ExprKind::Name { id, .. } if id == "Generic" + ast::ExprKind::Name { id, .. } if id == &"Generic".into() ) ) { continue; @@ -627,9 +645,9 @@ impl TopLevelComposer { let mut function_var_map: HashMap = HashMap::new(); let arg_types = { // make sure no duplicate parameter - let mut defined_paramter_name: HashSet = HashSet::new(); + let mut defined_paramter_name: HashSet<_> = HashSet::new(); let have_unique_fuction_parameter_name = args.args.iter().all(|x| { - defined_paramter_name.insert(x.node.arg.clone()) + defined_paramter_name.insert(x.node.arg) && !keyword_list.contains(&x.node.arg) }); if !have_unique_fuction_parameter_name { @@ -765,7 +783,7 @@ impl TopLevelComposer { unifier: &mut Unifier, primitives: &PrimitiveStore, type_var_to_concrete_def: &mut HashMap, - keyword_list: &HashSet, + keyword_list: &HashSet, ) -> Result<(), String> { let mut class_def = class_def.write(); let ( @@ -809,12 +827,12 @@ impl TopLevelComposer { let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref(); - let mut defined_fields: HashSet = HashSet::new(); + let mut defined_fields: HashSet<_> = HashSet::new(); for b in class_body_ast { match &b.node { ast::StmtKind::FunctionDef { args, returns, name, .. } => { let (method_dummy_ty, method_id) = - Self::get_class_method_def_info(class_methods_def, name)?; + Self::get_class_method_def_info(class_methods_def, *name)?; // the method var map can surely include the class's generic parameters let mut method_var_map: HashMap = class_type_vars_def @@ -830,27 +848,28 @@ impl TopLevelComposer { let arg_types: Vec = { // check method parameters cannot have same name - let mut defined_paramter_name: HashSet = HashSet::new(); + let mut defined_paramter_name: HashSet<_> = HashSet::new(); + let zelf: StrRef = "self".into(); let have_unique_fuction_parameter_name = args.args.iter().all(|x| { defined_paramter_name.insert(x.node.arg.clone()) - && (!keyword_list.contains(&x.node.arg) || x.node.arg == "self") + && (!keyword_list.contains(&x.node.arg) || x.node.arg == zelf) }); if !have_unique_fuction_parameter_name { return Err("class method must have unique parameter names \ and names thould not be the same as the keywords" .into()); } - if name == "__init__" && !defined_paramter_name.contains("self") { + if name == &"__init__".into() && !defined_paramter_name.contains(&zelf) { return Err("__init__ function must have a `self` parameter".into()); } - if !defined_paramter_name.contains("self") { + if !defined_paramter_name.contains(&zelf) { return Err("currently does not support static method".into()); } let mut result = Vec::new(); for x in &args.args { - let name = x.node.arg.clone(); - if name != "self" { + let name = x.node.arg; + if name != zelf { let type_ann = { let annotation_expr = x .node @@ -962,14 +981,15 @@ impl TopLevelComposer { if let ast::ExprKind::Name { id: attr, .. } = &target.node { if defined_fields.insert(attr.to_string()) { let dummy_field_type = unifier.get_fresh_var().0; - class_fields_def.push((attr.to_string(), dummy_field_type)); + class_fields_def.push((*attr, dummy_field_type)); // handle Kernel[T], KernelImmutable[T] let annotation = { match &annotation.as_ref().node { ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Kernel" || id == "KernelImmutable") + matches!(&value.node, ast::ExprKind::Name { id, .. } + if id == &"Kernel".into() || id == &"KernelImmutable".into()) } => { slice @@ -1054,19 +1074,19 @@ impl TopLevelComposer { if let TopLevelDef::Class { methods, fields, .. } = &*base { // handle methods override // since we need to maintain the order, create a new list - let mut new_child_methods: Vec<(String, Type, DefinitionId)> = Vec::new(); - let mut is_override: HashSet = HashSet::new(); + let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = Vec::new(); + let mut is_override: HashSet = HashSet::new(); for (anc_method_name, anc_method_ty, anc_method_def_id) in methods { // find if there is a method with same name in the child class let mut to_be_added = - (anc_method_name.to_string(), *anc_method_ty, *anc_method_def_id); + (*anc_method_name, *anc_method_ty, *anc_method_def_id); for (class_method_name, class_method_ty, class_method_defid) in class_methods_def.iter() { if class_method_name == anc_method_name { // ignore and handle self // if is __init__ method, no need to check return type - let ok = class_method_name == "__init__" + let ok = class_method_name == &"__init__".into() || Self::check_overload_function_type( *class_method_ty, *anc_method_ty, @@ -1077,9 +1097,9 @@ impl TopLevelComposer { return Err("method has same name as ancestors' method, but incompatible type".into()); } // mark it as added - is_override.insert(class_method_name.to_string()); + is_override.insert(*class_method_name); to_be_added = ( - class_method_name.to_string(), + *class_method_name, *class_method_ty, *class_method_defid, ); @@ -1094,7 +1114,7 @@ impl TopLevelComposer { { if !is_override.contains(class_method_name) { new_child_methods.push(( - class_method_name.to_string(), + *class_method_name, *class_method_ty, *class_method_defid, )); @@ -1105,10 +1125,10 @@ impl TopLevelComposer { class_methods_def.extend(new_child_methods); // handle class fields - let mut new_child_fields: Vec<(String, Type)> = Vec::new(); - // let mut is_override: HashSet = HashSet::new(); + let mut new_child_fields: Vec<(StrRef, Type)> = Vec::new(); + // let mut is_override: HashSet<_> = HashSet::new(); for (anc_field_name, anc_field_ty) in fields { - let to_be_added = (anc_field_name.to_string(), *anc_field_ty); + let to_be_added = (*anc_field_name, *anc_field_ty); // find if there is a fields with the same name in the child class for (class_field_name, ..) in class_fields_def.iter() { if class_field_name == anc_field_name { @@ -1135,7 +1155,7 @@ impl TopLevelComposer { } for (class_field_name, class_field_ty) in class_fields_def.iter() { if !is_override.contains(class_field_name) { - new_child_fields.push((class_field_name.to_string(), *class_field_ty)); + new_child_fields.push((*class_field_name, *class_field_ty)); } } class_fields_def.drain(..); @@ -1173,7 +1193,7 @@ impl TopLevelComposer { let mut constructor_args: Vec = Vec::new(); let mut type_vars: HashMap = HashMap::new(); for (name, func_sig, id) in methods { - if name == "__init__" { + if name == &"__init__".into() { init_id = Some(*id); if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() { let FunSignature { args, vars, .. } = &*sig.borrow(); @@ -1203,7 +1223,7 @@ impl TopLevelComposer { let init_ast = self.definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap(); if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node { - if name != "__init__" { + if name != &"__init__".into() { unreachable!("must be init function here") } let all_inited = Self::get_all_assigned_field(body.as_slice())?; @@ -1303,7 +1323,7 @@ impl TopLevelComposer { let mut identifiers = { // NOTE: none and function args? - let mut result: HashSet = HashSet::new(); + let mut result: HashSet<_> = HashSet::new(); result.insert("None".into()); if self_type.is_some() { result.insert("self".into()); @@ -1331,7 +1351,7 @@ impl TopLevelComposer { unifier: &mut self.unifier, variable_mapping: { // NOTE: none and function args? - let mut result: HashMap = HashMap::new(); + let mut result: HashMap = HashMap::new(); result.insert("None".into(), self.primitives_ty.none); if let Some(self_ty) = self_type { result.insert("self".into(), self_ty); @@ -1350,9 +1370,9 @@ impl TopLevelComposer { { if !decorator_list.is_empty() && matches!(&decorator_list[0].node, - ast::ExprKind::Name{ id, .. } if id == "syscall") + ast::ExprKind::Name{ id, .. } if id == &"syscall".into()) { - instance_to_symbol.insert("".to_string(), simple_name.clone()); + instance_to_symbol.insert("".into(), simple_name.to_string()); continue; } body diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index c024631c..3d585c9c 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -92,11 +92,11 @@ impl TopLevelComposer { pub fn make_top_level_class_def( index: usize, resolver: Option>>, - name: &str, + name: StrRef, constructor: Option, ) -> TopLevelDef { TopLevelDef::Class { - name: name.to_string(), + name, object_id: DefinitionId(index), type_vars: Default::default(), fields: Default::default(), @@ -110,7 +110,7 @@ impl TopLevelComposer { /// when first registering, the type is a invalid value pub fn make_top_level_function_def( name: String, - simple_name: String, + simple_name: StrRef, ty: Type, resolver: Option>>, ) -> TopLevelDef { @@ -132,11 +132,11 @@ impl TopLevelComposer { } pub fn get_class_method_def_info( - class_methods_def: &[(String, Type, DefinitionId)], - method_name: &str, + class_methods_def: &[(StrRef, Type, DefinitionId)], + method_name: StrRef, ) -> Result<(Type, DefinitionId), String> { for (name, ty, def_id) in class_methods_def { - if name == method_name { + if name == &method_name { return Ok((*ty, *def_id)); } } @@ -234,7 +234,7 @@ impl TopLevelComposer { (name, type_var_to_concrete_def.get(ty).unwrap()) })) .all(|(this, other)| { - if this.0 == "self" && this.0 == other.0 { + if this.0 == &"self".into() && this.0 == other.0 { true } else { this.0 == other.0 @@ -269,15 +269,15 @@ impl TopLevelComposer { ) } - pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result, String> { - let mut result: HashSet = HashSet::new(); + pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result, String> { + let mut result = HashSet::new(); for s in stmts { match &s.node { ast::StmtKind::AnnAssign { target, .. } if { if let ast::ExprKind::Attribute { value, .. } = &target.node { if let ast::ExprKind::Name { id, .. } = &value.node { - id == "self" + id == &"self".into() } else { false } @@ -295,7 +295,7 @@ impl TopLevelComposer { for t in targets { if let ast::ExprKind::Attribute { value, attr, .. } = &t.node { if let ast::ExprKind::Name { id, .. } = &value.node { - if id == "self" { + if id == &"self".into() { result.insert(attr.clone()); } } @@ -312,14 +312,14 @@ impl TopLevelComposer { let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) .cloned() - .collect::>(); + .collect::>(); result.extend(inited_for_sure); } ast::StmtKind::Try { body, orelse, finalbody, .. } => { let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) .cloned() - .collect::>(); + .collect::>(); result.extend(inited_for_sure); result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 887d9a2d..e61bc08d 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -15,7 +15,7 @@ use crate::{ }; use itertools::{izip, Itertools}; use parking_lot::RwLock; -use rustpython_parser::ast::{self, Stmt}; +use rustpython_parser::ast::{self, Stmt, StrRef}; #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)] pub struct DefinitionId(pub usize); @@ -40,15 +40,15 @@ pub struct FunInstance { pub enum TopLevelDef { Class { // name for error messages and symbols - name: String, + name: StrRef, // object ID used for TypeEnum object_id: DefinitionId, /// type variables bounded to the class. type_vars: Vec, // class fields - fields: Vec<(String, Type)>, + fields: Vec<(StrRef, Type)>, // class methods, pointing to the corresponding function definition. - methods: Vec<(String, Type, DefinitionId)>, + methods: Vec<(StrRef, Type, DefinitionId)>, // ancestor classes, including itself. ancestors: Vec, // symbol resolver of the module defined the class, none if it is built-in type @@ -60,7 +60,7 @@ pub enum TopLevelDef { // prefix for symbol, should be unique globally name: String, // simple name, the same as in method/function definition - simple_name: String, + simple_name: StrRef, // function signature. signature: Type, // instantiated type variable IDs diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index cac627af..ab3f9517 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -16,17 +16,17 @@ use test_case::test_case; use super::*; struct ResolverInternal { - id_to_type: Mutex>, - id_to_def: Mutex>, - class_names: Mutex>, + id_to_type: Mutex>, + id_to_def: Mutex>, + class_names: Mutex>, } impl ResolverInternal { - fn add_id_def(&self, id: String, def: DefinitionId) { + fn add_id_def(&self, id: StrRef, def: DefinitionId) { self.id_to_def.lock().insert(id, def); } - fn add_id_type(&self, id: String, ty: Type) { + fn add_id_type(&self, id: StrRef, ty: Type) { self.id_to_type.lock().insert(id, ty); } } @@ -34,24 +34,24 @@ impl ResolverInternal { struct Resolver(Arc); impl SymbolResolver for Resolver { - fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { - let ret = self.0.id_to_type.lock().get(str).cloned(); + fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option { + let ret = self.0.id_to_type.lock().get(&str).cloned(); if ret.is_none() { // println!("unknown here resolver {}", str); } ret } - fn get_symbol_value(&self, _: &str) -> Option { + fn get_symbol_value(&self, _: StrRef) -> Option { unimplemented!() } - fn get_symbol_location(&self, _: &str) -> Option { + fn get_symbol_location(&self, _: StrRef) -> Option { unimplemented!() } - fn get_identifier_def(&self, id: &str) -> Option { - self.0.id_to_def.lock().get(id).cloned() + fn get_identifier_def(&self, id: StrRef) -> Option { + self.0.id_to_def.lock().get(&id).cloned() } } @@ -70,7 +70,7 @@ impl SymbolResolver for Resolver { class B: def __init__(self): self.b: float = 4.3 - + def fun(self): self.b = self.b + 3.0 "}, @@ -449,19 +449,19 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s methods: [(\"__init__\", \"fn[[a=class5[2->class2, 3->class3], b=class8], class4]\", DefinitionId(6)), (\"fun\", \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", DefinitionId(7))], type_vars: [UnificationKey(100), UnificationKey(101)] }"}, - + indoc! {"6: Function { name: \"A.__init__\", sig: \"fn[[a=class5[2->class2, 3->class3], b=class8], class4]\", var_id: [2, 3] }"}, - + indoc! {"7: Function { name: \"A.fun\", sig: \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", var_id: [2, 3] }"}, - + indoc! {"8: Class { name: \"B\", def_id: DefinitionId(8), @@ -470,19 +470,19 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s methods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(9)), (\"fun\", \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", DefinitionId(7)), (\"foo\", \"fn[[b=class8], class8]\", DefinitionId(10)), (\"bar\", \"fn[[a=class5[2->list[class8], 3->class0]], tuple[class5[2->virtual[class5[2->class8, 3->class0]], 3->class3], class8]]\", DefinitionId(11))], type_vars: [] }"}, - + indoc! {"9: Function { name: \"B.__init__\", sig: \"fn[[], class4]\", var_id: [] }"}, - + indoc! {"10: Function { name: \"B.foo\", sig: \"fn[[b=class8], class8]\", var_id: [] }"}, - + indoc! {"11: Function { name: \"B.bar\", sig: \"fn[[a=class5[2->list[class8], 3->class0]], tuple[class5[2->virtual[class5[2->class8, 3->class0]], 3->class3], class8]]\", @@ -648,15 +648,6 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s vec!["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"]; "simple pass in class" )] -#[test_case( - vec![indoc! {" - class A: - def fun3(self): - pass - "}], - vec!["function name `fun3` must not end with numbers"]; - "err fun end with number" -)] #[test_case( vec![indoc! {" class A: @@ -790,7 +781,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { } } }; - internal_resolver.add_id_def(id.clone(), def_id); + internal_resolver.add_id_def(id, def_id); if let Some(ty) = ty { internal_resolver.add_id_type(id, ty); } @@ -1027,7 +1018,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { ); for inst in instance_to_stmt.iter() { let ast = &inst.1.body; - for b in ast { + for b in ast.iter() { println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap()); println!("--------------------"); } @@ -1039,7 +1030,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { } fn make_internal_resolver_with_tvar( - tvars: Vec<(String, Vec)>, + tvars: Vec<(StrRef, Vec)>, unifier: &mut Unifier, print: bool, ) -> Arc { diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 18cc880e..6d4d6f0d 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -30,49 +30,54 @@ pub fn parse_ast_to_type_annotation_kinds( mut locked: HashMap>, ) -> Result { match &expr.node { - ast::ExprKind::Name { id, .. } => match id.as_str() { - "int32" => Ok(TypeAnnotation::PrimitiveKind(primitives.int32)), - "int64" => Ok(TypeAnnotation::PrimitiveKind(primitives.int64)), - "float" => Ok(TypeAnnotation::PrimitiveKind(primitives.float)), - "bool" => Ok(TypeAnnotation::PrimitiveKind(primitives.bool)), - "None" => Ok(TypeAnnotation::PrimitiveKind(primitives.none)), - x => { - if let Some(obj_id) = resolver.get_identifier_def(x) { - let type_vars = { - let def_read = top_level_defs[obj_id.0].try_read(); - if let Some(def_read) = def_read { - if let TopLevelDef::Class { type_vars, .. } = &*def_read { - type_vars.clone() - } else { - return Err("function cannot be used as a type".into()); - } + ast::ExprKind::Name { id, .. } => { + if id == &"int32".into() { + Ok(TypeAnnotation::PrimitiveKind(primitives.int32)) + } else if id == &"int64".into() { + Ok(TypeAnnotation::PrimitiveKind(primitives.int64)) + } else if id == &"float".into() { + Ok(TypeAnnotation::PrimitiveKind(primitives.float)) + } else if id == &"bool".into() { + Ok(TypeAnnotation::PrimitiveKind(primitives.bool)) + } else if id == &"None".into() { + Ok(TypeAnnotation::PrimitiveKind(primitives.none)) + } else if let Some(obj_id) = resolver.get_identifier_def(*id) { + let type_vars = { + let def_read = top_level_defs[obj_id.0].try_read(); + if let Some(def_read) = def_read { + if let TopLevelDef::Class { type_vars, .. } = &*def_read { + type_vars.clone() } else { - locked.get(&obj_id).unwrap().clone() + return Err("function cannot be used as a type".into()); } - }; - // check param number here - if !type_vars.is_empty() { - return Err(format!( - "expect {} type variable parameter but got 0", - type_vars.len() - )); - } - Ok(TypeAnnotation::CustomClassKind { id: obj_id, params: vec![] }) - } else if let Some(ty) = resolver.get_symbol_type(unifier, primitives, id) { - if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { - Ok(TypeAnnotation::TypeVarKind(ty)) } else { - Err("not a type variable identifier".into()) + locked.get(&obj_id).unwrap().clone() } - } else { - Err("name cannot be parsed as a type annotation".into()) + }; + // check param number here + if !type_vars.is_empty() { + return Err(format!( + "expect {} type variable parameter but got 0", + type_vars.len() + )); } + Ok(TypeAnnotation::CustomClassKind { id: obj_id, params: vec![] }) + } else if let Some(ty) = resolver.get_symbol_type(unifier, primitives, *id) { + if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { + Ok(TypeAnnotation::TypeVarKind(ty)) + } else { + Err("not a type variable identifier".into()) + } + } else { + Err("name cannot be parsed as a type annotation".into()) } - }, + } // virtual ast::ExprKind::Subscript { value, slice, .. } - if { matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "virtual") } => + if { + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into()) + } => { let def = parse_ast_to_type_annotation_kinds( resolver, @@ -90,7 +95,9 @@ pub fn parse_ast_to_type_annotation_kinds( // list ast::ExprKind::Subscript { value, slice, .. } - if { matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "list") } => + if { + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into()) + } => { let def_ann = parse_ast_to_type_annotation_kinds( resolver, @@ -105,7 +112,9 @@ pub fn parse_ast_to_type_annotation_kinds( // tuple ast::ExprKind::Subscript { value, slice, .. } - if { matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "tuple") } => + if { + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into()) + } => { if let ast::ExprKind::Tuple { elts, .. } = &slice.node { let type_annotations = elts @@ -130,11 +139,13 @@ pub fn parse_ast_to_type_annotation_kinds( // custom class ast::ExprKind::Subscript { value, slice, .. } => { if let ast::ExprKind::Name { id, .. } = &value.node { - if vec!["virtual", "Generic", "list", "tuple"].contains(&id.as_str()) { + if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()] + .contains(id) + { return Err("keywords cannot be class name".into()); } let obj_id = resolver - .get_identifier_def(id) + .get_identifier_def(*id) .ok_or_else(|| "unknown class name".to_string())?; let type_vars = { let def_read = top_level_defs[obj_id.0].try_read(); @@ -272,12 +283,12 @@ pub fn get_type_from_type_annotation_kinds( .iter() .map(|(name, ty, _)| { let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (name.clone(), subst_ty) + (*name, subst_ty) }) - .collect::>(); + .collect::>(); tobj_fields.extend(fields.iter().map(|(name, ty)| { let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (name.clone(), subst_ty) + (*name, subst_ty) })); // println!("tobj_fields: {:?}", tobj_fields); diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 03306615..6c2085a8 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -2,14 +2,14 @@ use crate::typecheck::typedef::TypeEnum; use super::type_inferencer::Inferencer; use super::typedef::Type; -use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind}; +use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind, StrRef}; use std::{collections::HashSet, iter::once}; impl<'a> Inferencer<'a> { fn check_pattern( &mut self, pattern: &Expr>, - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashSet, ) -> Result<(), String> { match &pattern.node { ExprKind::Name { id, .. } => { @@ -42,7 +42,7 @@ impl<'a> Inferencer<'a> { fn check_expr( &mut self, expr: &Expr>, - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashSet, ) -> Result<(), String> { // there are some cases where the custom field is None if let Some(ty) = &expr.custom { @@ -57,7 +57,7 @@ impl<'a> Inferencer<'a> { match &expr.node { ExprKind::Name { id, .. } => { if !defined_identifiers.contains(id) { - if self.function_data.resolver.get_identifier_def(id).is_some() { + if self.function_data.resolver.get_identifier_def(*id).is_some() { defined_identifiers.insert(id.clone()); } else { return Err(format!( @@ -143,7 +143,7 @@ impl<'a> Inferencer<'a> { fn check_stmt( &mut self, stmt: &Stmt>, - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashSet, ) -> Result { match &stmt.node { StmtKind::For { target, iter, body, orelse, .. } => { @@ -217,7 +217,7 @@ impl<'a> Inferencer<'a> { pub fn check_block( &mut self, block: &[Stmt>], - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashSet, ) -> Result { let mut ret = false; for stmt in block { diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 608129e4..4301e696 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -10,7 +10,7 @@ use itertools::izip; use rustpython_parser::ast::{ self, fold::{self, Fold}, - Arguments, Comprehension, ExprKind, Located, Location, + Arguments, Comprehension, ExprKind, Located, Location, StrRef, }; #[cfg(test)] @@ -45,12 +45,12 @@ pub struct FunctionData { pub struct Inferencer<'a> { pub top_level: &'a TopLevelContext, - pub defined_identifiers: HashSet, + pub defined_identifiers: HashSet, pub function_data: &'a mut FunctionData, pub unifier: &'a mut Unifier, pub primitives: &'a PrimitiveStore, pub virtual_checks: &'a mut Vec<(Type, Type)>, - pub variable_mapping: HashMap, + pub variable_mapping: HashMap, pub calls: &'a mut HashMap, } @@ -163,7 +163,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?), ast::ExprKind::Name { id, .. } => { if !self.defined_identifiers.contains(id) { - if self.function_data.resolver.get_identifier_def(id.as_str()).is_some() { + if self.function_data.resolver.get_identifier_def(*id).is_some() { self.defined_identifiers.insert(id.clone()); } else { return Err(format!( @@ -172,12 +172,12 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { )); } } - Some(self.infer_identifier(id)?) + Some(self.infer_identifier(*id)?) } ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), ast::ExprKind::Attribute { value, attr, ctx: _ } => { - Some(self.infer_attribute(value, attr)?) + Some(self.infer_attribute(value, *attr)?) } ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ast::ExprKind::BinOp { left, op, right } => { @@ -237,7 +237,7 @@ impl<'a> Inferencer<'a> { fn build_method_call( &mut self, location: Location, - method: String, + method: StrRef, obj: Type, params: Vec, ret: Option, @@ -413,7 +413,7 @@ impl<'a> Inferencer<'a> { func { // handle special functions that cannot be typed in the usual way... - if id == "virtual" { + if id == "virtual".into() { if args.is_empty() || args.len() > 2 || !keywords.is_empty() { return Err( "`virtual` can only accept 1/2 positional arguments.".to_string() @@ -448,7 +448,7 @@ impl<'a> Inferencer<'a> { }); } // int64 is special because its argument can be a constant larger than int32 - if id == "int64" && args.len() == 1 { + if id == "int64".into() && args.len() == 1 { if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = &args[0].node { @@ -508,8 +508,8 @@ impl<'a> Inferencer<'a> { Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } }) } - fn infer_identifier(&mut self, id: &str) -> InferenceResult { - if let Some(ty) = self.variable_mapping.get(id) { + fn infer_identifier(&mut self, id: StrRef) -> InferenceResult { + if let Some(ty) = self.variable_mapping.get(&id) { Ok(*ty) } else { let variable_mapping = &mut self.variable_mapping; @@ -520,7 +520,7 @@ impl<'a> Inferencer<'a> { .get_symbol_type(unifier, self.primitives, id) .unwrap_or_else(|| { let ty = unifier.get_fresh_var().0; - variable_mapping.insert(id.to_string(), ty); + variable_mapping.insert(id, ty); ty })) } @@ -560,9 +560,13 @@ impl<'a> Inferencer<'a> { Ok(self.unifier.add_ty(TypeEnum::TTuple { ty })) } - fn infer_attribute(&mut self, value: &ast::Expr>, attr: &str) -> InferenceResult { + fn infer_attribute( + &mut self, + value: &ast::Expr>, + attr: StrRef, + ) -> InferenceResult { let (attr_ty, _) = self.unifier.get_fresh_var(); - let fields = once((attr.to_string(), attr_ty)).collect(); + let fields = once((attr, attr_ty)).collect(); let record = self.unifier.add_record(fields); self.constrain(value.custom.unwrap(), record, &value.location)?; Ok(attr_ty) @@ -583,10 +587,10 @@ impl<'a> Inferencer<'a> { op: &ast::Operator, right: &ast::Expr>, ) -> InferenceResult { - let method = binop_name(op); + let method = binop_name(op).into(); self.build_method_call( location, - method.to_string(), + method, left.custom.unwrap(), vec![right.custom.unwrap()], None, @@ -598,14 +602,8 @@ impl<'a> Inferencer<'a> { op: &ast::Unaryop, operand: &ast::Expr>, ) -> InferenceResult { - let method = unaryop_name(op); - self.build_method_call( - operand.location, - method.to_string(), - operand.custom.unwrap(), - vec![], - None, - ) + let method = unaryop_name(op).into(); + self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], None) } fn infer_compare( @@ -617,7 +615,7 @@ impl<'a> Inferencer<'a> { let boolean = self.primitives.bool; for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { let method = - comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string(); + comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.into(); self.build_method_call( a.location, method, diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 7e73f772..2f8fdb67 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -12,26 +12,26 @@ use rustpython_parser::parser::parse_program; use test_case::test_case; struct Resolver { - id_to_type: HashMap, - id_to_def: HashMap, - class_names: HashMap, + id_to_type: HashMap, + id_to_def: HashMap, + class_names: HashMap, } impl SymbolResolver for Resolver { - fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { - self.id_to_type.get(str).cloned() + fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option { + self.id_to_type.get(&str).cloned() } - fn get_symbol_value(&self, _: &str) -> Option { + fn get_symbol_value(&self, _: StrRef) -> Option { unimplemented!() } - fn get_symbol_location(&self, _: &str) -> Option { + fn get_symbol_location(&self, _: StrRef) -> Option { unimplemented!() } - fn get_identifier_def(&self, id: &str) -> Option { - self.id_to_def.get(id).cloned() + fn get_identifier_def(&self, id: StrRef) -> Option { + self.id_to_def.get(&id).cloned() } } @@ -39,8 +39,8 @@ struct TestEnvironment { pub unifier: Unifier, pub function_data: FunctionData, pub primitives: PrimitiveStore, - pub id_to_name: HashMap, - pub identifier_mapping: HashMap, + pub id_to_name: HashMap, + pub identifier_mapping: HashMap, pub virtual_checks: Vec<(Type, Type)>, pub calls: HashMap, pub top_level: TopLevelContext, @@ -79,11 +79,11 @@ impl TestEnvironment { set_primitives_magic_methods(&primitives, &mut unifier); let id_to_name = [ - (0, "int32".to_string()), - (1, "int64".to_string()), - (2, "float".to_string()), - (3, "bool".to_string()), - (4, "none".to_string()), + (0, "int32".into()), + (1, "int64".into()), + (2, "float".into()), + (3, "bool".into()), + (4, "none".into()), ] .iter() .cloned() @@ -150,7 +150,7 @@ impl TestEnvironment { for (i, name) in ["int32", "int64", "float", "bool", "none"].iter().enumerate() { top_level_defs.push( RwLock::new(TopLevelDef::Class { - name: name.to_string(), + name: (*name).into(), object_id: DefinitionId(i), type_vars: Default::default(), fields: Default::default(), @@ -174,7 +174,7 @@ impl TestEnvironment { }); top_level_defs.push( RwLock::new(TopLevelDef::Class { - name: "Foo".to_string(), + name: "Foo".into(), object_id: DefinitionId(5), type_vars: vec![v0], fields: [("a".into(), v0)].into(), @@ -212,7 +212,7 @@ impl TestEnvironment { }); top_level_defs.push( RwLock::new(TopLevelDef::Class { - name: "Bar".to_string(), + name: "Bar".into(), object_id: DefinitionId(6), type_vars: Default::default(), fields: [("a".into(), int32), ("b".into(), fun)].into(), @@ -241,7 +241,7 @@ impl TestEnvironment { }); top_level_defs.push( RwLock::new(TopLevelDef::Class { - name: "Bar2".to_string(), + name: "Bar2".into(), object_id: DefinitionId(7), type_vars: Default::default(), fields: [("a".into(), bool), ("b".into(), fun)].into(), @@ -261,14 +261,14 @@ impl TestEnvironment { let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect(); let id_to_name = [ - (0, "int32".to_string()), - (1, "int64".to_string()), - (2, "float".to_string()), - (3, "bool".to_string()), - (4, "none".to_string()), - (5, "Foo".to_string()), - (6, "Bar".to_string()), - (7, "Bar2".to_string()), + (0, "int32".into()), + (1, "int64".into()), + (2, "float".into()), + (3, "bool".into()), + (4, "none".into()), + (5, "Foo".into()), + (6, "Bar".into()), + (7, "Bar2".into()), ] .iter() .cloned() @@ -385,7 +385,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st let mut env = TestEnvironment::new(); let id_to_name = std::mem::take(&mut env.id_to_name); let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect(); - defined_identifiers.insert("virtual".to_string()); + defined_identifiers.insert("virtual".into()); let mut inferencer = env.get_inferencer(); inferencer.defined_identifiers = defined_identifiers.clone(); let statements = parse_program(source).unwrap(); @@ -400,16 +400,16 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st for (k, v) in inferencer.variable_mapping.iter() { let name = inferencer.unifier.stringify( *v, - &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), ); println!("{}: {}", k, name); } for (k, v) in mapping.iter() { - let ty = inferencer.variable_mapping.get(*k).unwrap(); + let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap(); let name = inferencer.unifier.stringify( *ty, - &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), ); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); @@ -418,12 +418,12 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) { let a = inferencer.unifier.stringify( *a, - &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), ); let b = inferencer.unifier.stringify( *b, - &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), ); @@ -527,7 +527,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { let mut env = TestEnvironment::basic_test_env(); let id_to_name = std::mem::take(&mut env.id_to_name); let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect(); - defined_identifiers.insert("virtual".to_string()); + defined_identifiers.insert("virtual".into()); let mut inferencer = env.get_inferencer(); inferencer.defined_identifiers = defined_identifiers.clone(); let statements = parse_program(source).unwrap(); @@ -542,16 +542,16 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { for (k, v) in inferencer.variable_mapping.iter() { let name = inferencer.unifier.stringify( *v, - &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), ); println!("{}: {}", k, name); } for (k, v) in mapping.iter() { - let ty = inferencer.variable_mapping.get(*k).unwrap(); + let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap(); let name = inferencer.unifier.stringify( *ty, - &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| format!("v{}", v), ); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index a62c46db..a961aa75 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -6,6 +6,8 @@ use std::iter::once; use std::rc::Rc; use std::sync::{Arc, Mutex}; +use rustpython_parser::ast::StrRef; + use super::unification_table::{UnificationKey, UnificationTable}; use crate::symbol_resolver::SymbolValue; use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef}; @@ -25,14 +27,14 @@ type VarMap = Mapping; #[derive(Clone)] pub struct Call { pub posargs: Vec, - pub kwargs: HashMap, + pub kwargs: HashMap, pub ret: Type, pub fun: RefCell>, } #[derive(Clone)] pub struct FuncArg { - pub name: String, + pub name: StrRef, pub ty: Type, pub default_value: Option, } @@ -48,7 +50,7 @@ pub struct FunSignature { pub enum TypeVarMeta { Generic, Sequence(RefCell>), - Record(RefCell>), + Record(RefCell>), } #[derive(Clone)] @@ -70,7 +72,7 @@ pub enum TypeEnum { }, TObj { obj_id: DefinitionId, - fields: RefCell>, + fields: RefCell>, params: RefCell, }, TVirtual { @@ -141,7 +143,7 @@ impl Unifier { .borrow() .iter() .map(|(name, ty)| { - (name.clone(), self.copy_from(unifier, *ty, type_cache)) + (*name, self.copy_from(unifier, *ty, type_cache)) }) .collect(), ), @@ -163,7 +165,7 @@ impl Unifier { .args .iter() .map(|arg| FuncArg { - name: arg.name.clone(), + name: arg.name, ty: self.copy_from(unifier, arg.ty, type_cache), default_value: arg.default_value.clone(), }) @@ -219,7 +221,7 @@ impl Unifier { self.unification_table.new_key(Rc::new(a)) } - pub fn add_record(&mut self, fields: Mapping) -> Type { + pub fn add_record(&mut self, fields: Mapping) -> Type { let id = self.var_id + 1; self.var_id += 1; self.add_ty(TypeEnum::TVar { @@ -563,12 +565,12 @@ impl Unifier { } (TCall(calls), TFunc(signature)) => { self.occur_check(a, b)?; - let required: Vec = signature + let required: Vec = signature .borrow() .args .iter() .filter(|v| v.default_value.is_none()) - .map(|v| v.name.clone()) + .map(|v| v.name) .rev() .collect(); // we unify every calls to the function signature. @@ -590,7 +592,7 @@ impl Unifier { .borrow() .args .iter() - .map(|v| (v.name.clone(), v.ty)) + .map(|v| (v.name, v.ty)) .rev() .collect(); for (i, t) in posargs.iter().enumerate() { @@ -662,7 +664,7 @@ impl Unifier { if let TopLevelDef::Class { name, .. } = &*top_level.definitions.read()[id].read() { - name.clone() + name.to_string() } else { unreachable!("expected class definition") } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index d652ceca..fe56d2d7 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -149,7 +149,7 @@ impl TestEnvironment { let mut fields = HashMap::new(); while &s[0..1] != "]" { let eq = s.find('=').unwrap(); - let key = s[1..eq].to_string(); + let key = s[1..eq].into(); let result = self.internal_parse(&s[eq + 1..], mapping); fields.insert(key, result.0); s = result.1; @@ -342,8 +342,8 @@ fn test_recursive_subst() { let instantiated_ty = env.unifier.get_ty(instantiated); if let TypeEnum::TObj { fields, .. } = &*instantiated_ty { let fields = fields.borrow(); - assert!(env.unifier.unioned(*fields.get("a").unwrap(), int)); - assert!(env.unifier.unioned(*fields.get("rec").unwrap(), instantiated)); + assert!(env.unifier.unioned(*fields.get(&"a".into()).unwrap(), int)); + assert!(env.unifier.unioned(*fields.get(&"rec".into()).unwrap(), instantiated)); } else { unreachable!() } @@ -358,10 +358,10 @@ fn test_virtual() { )); let bar = env.unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(5), - fields: [("f".to_string(), fun), ("a".to_string(), int)] + fields: [("f".into(), fun), ("a".into(), int)] .iter() .cloned() - .collect::>() + .collect::>() .into(), params: HashMap::new().into(), }); @@ -370,15 +370,15 @@ fn test_virtual() { let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); - let c = env.unifier.add_record([("f".to_string(), v1)].iter().cloned().collect()); + let c = env.unifier.add_record([("f".into(), v1)].iter().cloned().collect()); env.unifier.unify(a, b).unwrap(); env.unifier.unify(b, c).unwrap(); assert!(env.unifier.eq(v1, fun)); - let d = env.unifier.add_record([("a".to_string(), v1)].iter().cloned().collect()); + let d = env.unifier.add_record([("a".into(), v1)].iter().cloned().collect()); assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string())); - let d = env.unifier.add_record([("b".to_string(), v1)].iter().cloned().collect()); + let d = env.unifier.add_record([("b".into(), v1)].iter().cloned().collect()); assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string())); } diff --git a/nac3embedded/Cargo.toml b/nac3embedded/Cargo.toml index f3eeaa1c..df584d15 100644 --- a/nac3embedded/Cargo.toml +++ b/nac3embedded/Cargo.toml @@ -11,5 +11,5 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.12.4", features = ["extension-module"] } inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] } -rustpython-parser = { git = "https://github.com/pca006132/RustPython", branch = "main" } +rustpython-parser = { git = "https://github.com/m-labs/RustPython", branch = "parser-mod" } nac3core = { path = "../nac3core" } diff --git a/nac3standalone/Cargo.toml b/nac3standalone/Cargo.toml index a22ba57e..5542d5f6 100644 --- a/nac3standalone/Cargo.toml +++ b/nac3standalone/Cargo.toml @@ -6,6 +6,6 @@ edition = "2018" [dependencies] inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] } -rustpython-parser = { git = "https://github.com/pca006132/RustPython", branch = "main" } +rustpython-parser = { git = "https://github.com/m-labs/RustPython", branch = "parser-mod" } parking_lot = "0.11.1" nac3core = { path = "../nac3core" } diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index f8fe760e..92f0d0c5 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -8,20 +8,21 @@ use nac3core::{ }, }; use parking_lot::Mutex; +use rustpython_parser::ast::StrRef; use std::{collections::HashMap, sync::Arc}; pub struct ResolverInternal { - pub id_to_type: Mutex>, - pub id_to_def: Mutex>, - pub class_names: Mutex>, + pub id_to_type: Mutex>, + pub id_to_def: Mutex>, + pub class_names: Mutex>, } impl ResolverInternal { - pub fn add_id_def(&self, id: String, def: DefinitionId) { + pub fn add_id_def(&self, id: StrRef, def: DefinitionId) { self.id_to_def.lock().insert(id, def); } - pub fn add_id_type(&self, id: String, ty: Type) { + pub fn add_id_type(&self, id: StrRef, ty: Type) { self.id_to_type.lock().insert(id, ty); } } @@ -29,23 +30,23 @@ impl ResolverInternal { pub struct Resolver(pub Arc); impl SymbolResolver for Resolver { - fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { - let ret = self.0.id_to_type.lock().get(str).cloned(); + fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option { + let ret = self.0.id_to_type.lock().get(&str).cloned(); if ret.is_none() { // println!("unknown here resolver {}", str); } ret } - fn get_symbol_value(&self, _: &str) -> Option { + fn get_symbol_value(&self, _: StrRef) -> Option { unimplemented!() } - fn get_symbol_location(&self, _: &str) -> Option { + fn get_symbol_location(&self, _: StrRef) -> Option { unimplemented!() } - fn get_identifier_def(&self, id: &str) -> Option { - self.0.id_to_def.lock().get(id).cloned() + fn get_identifier_def(&self, id: StrRef) -> Option { + self.0.id_to_def.lock().get(&id).cloned() } } diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index a9a36ce4..ec4c0d77 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -92,7 +92,7 @@ fn main() { let instance = { let defs = top_level.definitions.read(); - let mut instance = defs[resolver.get_identifier_def("run").unwrap().0].write(); + let mut instance = defs[resolver.get_identifier_def("run".into()).unwrap().0].write(); if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol,