optimization (#13) #15

Merged
pca006132 merged 9 commits from optimization into master 2021-09-23 19:58:43 +08:00
22 changed files with 495 additions and 398 deletions
Showing only changes of commit 084efe92af - Show all commits

51
Cargo.lock generated
View File

@ -69,6 +69,12 @@ version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.68" version = "1.0.68"
@ -213,6 +219,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37ab347416e802de484e4d03c7316c48f1ecb56574dfd4a46a80f173ce1de04d" checksum = "37ab347416e802de484e4d03c7316c48f1ecb56574dfd4a46a80f173ce1de04d"
[[package]]
name = "fxhash"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
dependencies = [
"byteorder",
]
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.2.3" version = "0.2.3"
@ -241,6 +256,15 @@ version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04"
[[package]]
name = "hashbrown"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
dependencies = [
"ahash",
]
[[package]] [[package]]
name = "hermit-abi" name = "hermit-abi"
version = "0.1.19" version = "0.1.19"
@ -257,7 +281,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3" checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"hashbrown", "hashbrown 0.9.1",
] ]
[[package]] [[package]]
@ -845,15 +869,19 @@ checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]] [[package]]
name = "rustpython-ast" name = "rustpython-ast"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/pca006132/RustPython?branch=main#c6248660e33a2db8c2d745097ac4bff13598d955" source = "git+https://github.com/m-labs/RustPython?branch=parser-mod#efdf7829ba1a5f87d30df8eaff12a330544f3cbd"
dependencies = [ dependencies = [
"fxhash",
"lazy_static",
"num-bigint 0.4.0", "num-bigint 0.4.0",
"parking_lot",
"string-interner",
] ]
[[package]] [[package]]
name = "rustpython-parser" name = "rustpython-parser"
version = "0.1.2" version = "0.1.2"
source = "git+https://github.com/pca006132/RustPython?branch=main#c6248660e33a2db8c2d745097ac4bff13598d955" source = "git+https://github.com/m-labs/RustPython?branch=parser-mod#efdf7829ba1a5f87d30df8eaff12a330544f3cbd"
dependencies = [ dependencies = [
"ahash", "ahash",
"lalrpop", "lalrpop",
@ -898,6 +926,12 @@ dependencies = [
"pest", "pest",
] ]
[[package]]
name = "serde"
version = "1.0.130"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f12d06de37cf59146fbdecab66aa99f9fe4f78722e3607577a5375d66bd0c913"
[[package]] [[package]]
name = "siphasher" name = "siphasher"
version = "0.3.5" version = "0.3.5"
@ -910,6 +944,17 @@ version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e"
[[package]]
name = "string-interner"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecc77d3a5728ef82235df1f9b9430507f555c7404797f42b49c2403d4c1d8c6c"
dependencies = [
"cfg-if",
"hashbrown 0.11.2",
"serde",
]
[[package]] [[package]]
name = "string_cache" name = "string_cache"
version = "0.8.1" version = "0.8.1"

View File

@ -4,3 +4,7 @@ members = [
"nac3standalone", "nac3standalone",
"nac3embedded", "nac3embedded",
] ]
[profile.release]
debug = true

View File

@ -8,7 +8,7 @@ edition = "2018"
num-bigint = "0.3" num-bigint = "0.3"
num-traits = "0.2" num-traits = "0.2"
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] } inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
rustpython-parser = { git = "https://github.com/pca006132/RustPython", branch = "main" } rustpython-parser = { git = "https://github.com/m-labs/RustPython", branch = "parser-mod" }
itertools = "0.10.1" itertools = "0.10.1"
crossbeam = "0.8.1" crossbeam = "0.8.1"
parking_lot = "0.11.1" parking_lot = "0.11.1"

View File

@ -12,7 +12,7 @@ use inkwell::{
AddressSpace, AddressSpace,
}; };
use itertools::{chain, izip, zip, Itertools}; use itertools::{chain, izip, zip, Itertools};
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator}; use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator, StrRef};
pub fn assert_int_val<'ctx>(val: BasicValueEnum<'ctx>) -> IntValue<'ctx> { pub fn assert_int_val<'ctx>(val: BasicValueEnum<'ctx>) -> IntValue<'ctx> {
if let BasicValueEnum::IntValue(v) = val { if let BasicValueEnum::IntValue(v) = val {
@ -56,7 +56,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.join(", ") .join(", ")
} }
pub fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize { pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize {
let obj_id = match &*self.unifier.get_ty(ty) { let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id, TypeEnum::TObj { obj_id, .. } => *obj_id,
// we cannot have other types, virtual type should be handled by function calls // we cannot have other types, virtual type should be handled by function calls
@ -106,7 +106,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
&mut self, &mut self,
obj: Option<(Type, BasicValueEnum<'ctx>)>, obj: Option<(Type, BasicValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>, params: Vec<(Option<StrRef>, BasicValueEnum<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>> { ) -> Option<BasicValueEnum<'ctx>> {
let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None); let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None);
let definition = self.top_level.definitions.read().get(fun.1.0).cloned().unwrap(); let definition = self.top_level.definitions.read().get(fun.1.0).cloned().unwrap();
@ -122,7 +122,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
// TODO: what about other fields that require alloca? // TODO: what about other fields that require alloca?
let mut fun_id = None; let mut fun_id = None;
for (name, _, id) in methods.iter() { for (name, _, id) in methods.iter() {
if name == "__init__" { if name == &"__init__".into() {
fun_id = Some(*id); fun_id = Some(*id);
} }
} }
@ -449,7 +449,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {
// note that we would handle class methods directly in calls // note that we would handle class methods directly in calls
let index = self.get_attr_index(value.custom.unwrap(), attr); let index = self.get_attr_index(value.custom.unwrap(), *attr);
let val = self.gen_expr(value).unwrap(); let val = self.gen_expr(value).unwrap();
let ptr = assert_pointer_val(val); let ptr = assert_pointer_val(val);
unsafe { unsafe {
@ -666,7 +666,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
// TODO: handle primitive casts and function pointers // TODO: handle primitive casts and function pointers
let fun = let fun =
self.resolver.get_identifier_def(&id).expect("Unknown identifier"); self.resolver.get_identifier_def(*id).expect("Unknown identifier");
return self.gen_call(None, (&signature, fun), params); return self.gen_call(None, (&signature, fun), params);
} }
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {

View File

@ -18,7 +18,7 @@ use inkwell::{
}; };
use itertools::Itertools; use itertools::Itertools;
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use rustpython_parser::ast::Stmt; use rustpython_parser::ast::{Stmt, StrRef};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
@ -39,7 +39,7 @@ pub struct CodeGenContext<'ctx, 'a> {
pub top_level: &'a TopLevelContext, pub top_level: &'a TopLevelContext,
pub unifier: Unifier, pub unifier: Unifier,
pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>, pub resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
pub var_assignment: HashMap<String, PointerValue<'ctx>>, pub var_assignment: HashMap<StrRef, PointerValue<'ctx>>,
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>, pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
pub calls: Arc<HashMap<CodeLocation, CallId>>, pub calls: Arc<HashMap<CodeLocation, CallId>>,
@ -317,7 +317,7 @@ pub fn gen_func<'ctx>(
let param = fn_val.get_nth_param(n as u32).unwrap(); let param = fn_val.get_nth_param(n as u32).unwrap();
let alloca = builder.build_alloca( let alloca = builder.build_alloca(
get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty), get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty),
&arg.name, &arg.name.to_string(),
); );
builder.build_store(alloca, param); builder.build_store(alloca, param);
var_assignment.insert(arg.name.clone(), alloca); var_assignment.insert(arg.name.clone(), alloca);

View File

@ -30,7 +30,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}) })
} }
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {
let index = self.get_attr_index(value.custom.unwrap(), attr); let index = self.get_attr_index(value.custom.unwrap(), *attr);
let val = self.gen_expr(value).unwrap(); let val = self.gen_expr(value).unwrap();
let ptr = if let BasicValueEnum::PointerValue(v) = val { let ptr = if let BasicValueEnum::PointerValue(v) = val {
v v

View File

@ -12,38 +12,38 @@ use crate::{
}; };
use indoc::indoc; use indoc::indoc;
use parking_lot::RwLock; use parking_lot::RwLock;
use rustpython_parser::{ast::fold::Fold, parser::parse_program}; use rustpython_parser::{ast::{StrRef, fold::Fold}, parser::parse_program};
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
struct Resolver { struct Resolver {
id_to_type: HashMap<String, Type>, id_to_type: HashMap<StrRef, Type>,
id_to_def: RwLock<HashMap<String, DefinitionId>>, id_to_def: RwLock<HashMap<StrRef, DefinitionId>>,
class_names: HashMap<String, Type>, class_names: HashMap<StrRef, Type>,
} }
impl Resolver { impl Resolver {
pub fn add_id_def(&self, id: String, def: DefinitionId) { pub fn add_id_def(&self, id: StrRef, def: DefinitionId) {
self.id_to_def.write().insert(id, def); self.id_to_def.write().insert(id, def);
} }
} }
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option<Type> {
self.id_to_type.get(str).cloned() self.id_to_type.get(&str).cloned()
} }
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> { fn get_symbol_value(&self, _: StrRef) -> Option<SymbolValue> {
unimplemented!() unimplemented!()
} }
fn get_symbol_location(&self, _: &str) -> Option<Location> { fn get_symbol_location(&self, _: StrRef) -> Option<Location> {
unimplemented!() unimplemented!()
} }
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> { fn get_identifier_def(&self, id: StrRef) -> Option<DefinitionId> {
self.id_to_def.read().get(id).cloned() self.id_to_def.read().get(&id).cloned()
} }
} }
@ -77,8 +77,8 @@ fn test_primitives() {
let threads = ["test"]; let threads = ["test"];
let signature = FunSignature { let signature = FunSignature {
args: vec![ args: vec![
FuncArg { name: "a".to_string(), ty: primitives.int32, default_value: None }, FuncArg { name: "a".into(), ty: primitives.int32, default_value: None },
FuncArg { name: "b".to_string(), ty: primitives.int32, default_value: None }, FuncArg { name: "b".into(), ty: primitives.int32, default_value: None },
], ],
ret: primitives.int32, ret: primitives.int32,
vars: HashMap::new(), vars: HashMap::new(),
@ -91,7 +91,7 @@ fn test_primitives() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".to_string(), "b".to_string()].iter().cloned().collect(); let mut identifiers: HashSet<_> = ["a".into(), "b".into()].iter().cloned().collect();
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
@ -121,11 +121,11 @@ fn test_primitives() {
let task = CodeGenTask { let task = CodeGenTask {
subst: Default::default(), subst: Default::default(),
symbol_name: "testing".to_string(), symbol_name: "testing".into(),
body: statements, body: Arc::new(statements),
resolver, resolver,
unifier, unifier,
calls, calls: Arc::new(calls),
signature, signature,
}; };
let f = Arc::new(WithCall::new(Box::new(|module| { let f = Arc::new(WithCall::new(Box::new(|module| {
@ -212,7 +212,7 @@ fn test_simple_call() {
unifier.top_level = Some(top_level.clone()); unifier.top_level = Some(top_level.clone());
let signature = FunSignature { let signature = FunSignature {
args: vec![FuncArg { name: "a".to_string(), ty: primitives.int32, default_value: None }], args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }],
ret: primitives.int32, ret: primitives.int32,
vars: HashMap::new(), vars: HashMap::new(),
}; };
@ -221,7 +221,7 @@ fn test_simple_call() {
let foo_id = top_level.definitions.read().len(); let foo_id = top_level.definitions.read().len();
top_level.definitions.write().push(Arc::new(RwLock::new(TopLevelDef::Function { top_level.definitions.write().push(Arc::new(RwLock::new(TopLevelDef::Function {
name: "foo".to_string(), name: "foo".to_string(),
simple_name: "foo".to_string(), simple_name: "foo".into(),
signature: fun_ty, signature: fun_ty,
var_id: vec![], var_id: vec![],
instance_to_stmt: HashMap::new(), instance_to_stmt: HashMap::new(),
@ -234,7 +234,7 @@ fn test_simple_call() {
id_to_def: RwLock::new(HashMap::new()), id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(), class_names: Default::default(),
}); });
resolver.add_id_def("foo".to_string(), DefinitionId(foo_id)); resolver.add_id_def("foo".into(), DefinitionId(foo_id));
let resolver = Arc::new(resolver as Box<dyn SymbolResolver + Send + Sync>); let resolver = Arc::new(resolver as Box<dyn SymbolResolver + Send + Sync>);
if let TopLevelDef::Function { resolver: r, .. } = if let TopLevelDef::Function { resolver: r, .. } =
@ -253,7 +253,7 @@ fn test_simple_call() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".to_string(), "foo".into()].iter().cloned().collect(); let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].iter().cloned().collect();
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
@ -288,8 +288,8 @@ fn test_simple_call() {
instance_to_stmt.insert( instance_to_stmt.insert(
"".to_string(), "".to_string(),
FunInstance { FunInstance {
body: statements_2, body: Arc::new(statements_2),
calls: inferencer.calls.clone(), calls: Arc::new(inferencer.calls.clone()),
subst: Default::default(), subst: Default::default(),
unifier_id: 0, unifier_id: 0,
}, },
@ -309,10 +309,10 @@ fn test_simple_call() {
let task = CodeGenTask { let task = CodeGenTask {
subst: Default::default(), subst: Default::default(),
symbol_name: "testing".to_string(), symbol_name: "testing".to_string(),
body: statements_1, body: Arc::new(statements_1),
resolver, resolver,
unifier, unifier,
calls: calls1, calls: Arc::new(calls1),
signature, signature,
}; };
let f = Arc::new(WithCall::new(Box::new(|module| { let f = Arc::new(WithCall::new(Box::new(|module| {

View File

@ -10,7 +10,7 @@ use crate::typecheck::{
use crate::{location::Location, typecheck::typedef::TypeEnum}; use crate::{location::Location, typecheck::typedef::TypeEnum};
use itertools::{chain, izip}; use itertools::{chain, izip};
use parking_lot::RwLock; use parking_lot::RwLock;
use rustpython_parser::ast::Expr; use rustpython_parser::ast::{Expr, StrRef};
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub enum SymbolValue { pub enum SymbolValue {
@ -29,15 +29,28 @@ pub trait SymbolResolver {
&self, &self,
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
str: &str, str: StrRef,
) -> Option<Type>; ) -> Option<Type>;
// get the top-level definition of identifiers // get the top-level definition of identifiers
fn get_identifier_def(&self, str: &str) -> Option<DefinitionId>; fn get_identifier_def(&self, str: StrRef) -> Option<DefinitionId>;
fn get_symbol_value(&self, str: &str) -> Option<SymbolValue>; fn get_symbol_value(&self, str: StrRef) -> Option<SymbolValue>;
fn get_symbol_location(&self, str: &str) -> Option<Location>; fn get_symbol_location(&self, str: StrRef) -> Option<Location>;
// handle function call etc. // handle function call etc.
} }
thread_local! {
static IDENTIFIER_ID: [StrRef; 8] = [
"int32".into(),
"int64".into(),
"float".into(),
"bool".into(),
"None".into(),
"virtual".into(),
"list".into(),
"tuple".into()
];
}
// convert type annotation into type // convert type annotation into type
pub fn parse_type_annotation<T>( pub fn parse_type_annotation<T>(
resolver: &dyn SymbolResolver, resolver: &dyn SymbolResolver,
@ -47,15 +60,32 @@ pub fn parse_type_annotation<T>(
expr: &Expr<T>, expr: &Expr<T>,
) -> Result<Type, String> { ) -> Result<Type, String> {
use rustpython_parser::ast::ExprKind::*; use rustpython_parser::ast::ExprKind::*;
let ids = IDENTIFIER_ID.with(|ids| {
*ids
});
let int32_id = ids[0];
let int64_id = ids[1];
let float_id = ids[2];
let bool_id = ids[3];
let none_id = ids[4];
let virtual_id = ids[5];
let list_id = ids[6];
let tuple_id = ids[7];
match &expr.node { match &expr.node {
Name { id, .. } => match id.as_str() { Name { id, .. } => {
"int32" => Ok(primitives.int32), if *id == int32_id {
"int64" => Ok(primitives.int64), Ok(primitives.int32)
"float" => Ok(primitives.float), } else if *id == int64_id {
"bool" => Ok(primitives.bool), Ok(primitives.int64)
"None" => Ok(primitives.none), } else if *id == float_id {
x => { Ok(primitives.float)
let obj_id = resolver.get_identifier_def(x); } else if *id == bool_id {
Ok(primitives.bool)
} else if *id == none_id {
Ok(primitives.none)
} else {
let obj_id = resolver.get_identifier_def(*id);
if let Some(obj_id) = obj_id { if let Some(obj_id) = obj_id {
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
@ -67,8 +97,8 @@ pub fn parse_type_annotation<T>(
} }
let fields = RefCell::new( let fields = RefCell::new(
chain( chain(
fields.iter().map(|(k, v)| (k.clone(), *v)), fields.iter().map(|(k, v)| (*k, *v)),
methods.iter().map(|(k, v, _)| (k.clone(), *v)), methods.iter().map(|(k, v, _)| (*k, *v)),
) )
.collect(), .collect(),
); );
@ -83,121 +113,116 @@ pub fn parse_type_annotation<T>(
} else { } else {
// it could be a type variable // it could be a type variable
let ty = resolver let ty = resolver
.get_symbol_type(unifier, primitives, x) .get_symbol_type(unifier, primitives, *id)
.ok_or_else(|| "unknown type variable name".to_owned())?; .ok_or_else(|| "unknown type variable name".to_owned())?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty) Ok(ty)
} else { } else {
Err(format!("Unknown type annotation {}", x)) Err(format!("Unknown type annotation {}", id))
} }
} }
} }
}, },
Subscript { value, slice, .. } => { Subscript { value, slice, .. } => {
if let Name { id, .. } = &value.node { if let Name { id, .. } = &value.node {
match id.as_str() { if *id == virtual_id {
"virtual" => { let ty = parse_type_annotation(
let ty = parse_type_annotation( resolver,
top_level_defs,
unifier,
primitives,
slice,
)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else if *id == list_id {
let ty = parse_type_annotation(
resolver,
top_level_defs,
unifier,
primitives,
slice,
)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
} else if *id == tuple_id {
if let Tuple { elts, .. } = &slice.node {
let ty = elts
.iter()
.map(|elt| {
parse_type_annotation(
resolver,
top_level_defs,
unifier,
primitives,
elt,
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else {
Err("Expected multiple elements for tuple".into())
}
} else {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|v| {
parse_type_annotation(
resolver,
top_level_defs,
unifier,
primitives,
v,
)
})
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_type_annotation(
resolver, resolver,
top_level_defs, top_level_defs,
unifier, unifier,
primitives, primitives,
slice, slice,
)?; )?]
Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) };
}
"list" => {
let ty = parse_type_annotation(
resolver,
top_level_defs,
unifier,
primitives,
slice,
)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
}
"tuple" => {
if let Tuple { elts, .. } = &slice.node {
let ty = elts
.iter()
.map(|elt| {
parse_type_annotation(
resolver,
top_level_defs,
unifier,
primitives,
elt,
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else {
Err("Expected multiple elements for tuple".into())
}
}
_ => {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|v| {
parse_type_annotation(
resolver,
top_level_defs,
unifier,
primitives,
v,
)
})
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_type_annotation(
resolver,
top_level_defs,
unifier,
primitives,
slice,
)?]
};
let obj_id = resolver let obj_id = resolver
.get_identifier_def(id) .get_identifier_def(*id)
.ok_or_else(|| format!("Unknown type annotation {}", id))?; .ok_or_else(|| format!("Unknown type annotation {}", id))?;
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() { if types.len() != type_vars.len() {
return Err(format!( return Err(format!(
"Unexpected number of type parameters: expected {} but got {}", "Unexpected number of type parameters: expected {} but got {}",
type_vars.len(), type_vars.len(),
types.len() types.len()
)); ));
}
let mut subst = HashMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
} else {
unreachable!()
};
subst.insert(id, *ty);
}
let mut fields = fields
.iter()
.map(|(attr, ty)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(attr.clone(), ty)
})
.collect::<HashMap<_, _>>();
fields.extend(methods.iter().map(|(attr, ty, _)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(attr.clone(), ty)
}));
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields: fields.into(),
params: subst.into(),
}))
} else {
Err("Cannot use function name as type".into())
} }
let mut subst = HashMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
} else {
unreachable!()
};
subst.insert(id, *ty);
}
let mut fields = fields
.iter()
.map(|(attr, ty)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, ty)
})
.collect::<HashMap<_, _>>();
fields.extend(methods.iter().map(|(attr, ty, _)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, ty)
}));
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields: fields.into(),
params: subst.into(),
}))
} else {
Err("Cannot use function name as type".into())
} }
} }
} else { } else {

View File

@ -15,7 +15,7 @@ pub struct TopLevelComposer {
// primitive store // primitive store
pub primitives_ty: PrimitiveStore, pub primitives_ty: PrimitiveStore,
// keyword list to prevent same user-defined name // keyword list to prevent same user-defined name
pub keyword_list: HashSet<String>, pub keyword_list: HashSet<StrRef>,
// to prevent duplicate definition // to prevent duplicate definition
pub defined_names: HashSet<String>, pub defined_names: HashSet<String>,
// get the class def id of a class method // get the class def id of a class method
@ -34,24 +34,39 @@ impl TopLevelComposer {
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol /// return a composer and things to make a "primitive" symbol resolver, so that the symbol
/// resolver can later figure out primitive type definitions when passed a primitive type name /// resolver can later figure out primitive type definitions when passed a primitive type name
pub fn new( pub fn new(
builtins: Vec<(String, FunSignature)>, builtins: Vec<(StrRef, FunSignature)>,
) -> (Self, HashMap<String, DefinitionId>, HashMap<String, Type>) { ) -> (Self, HashMap<StrRef, DefinitionId>, HashMap<StrRef, Type>) {
let primitives = Self::make_primitives(); let primitives = Self::make_primitives();
let mut definition_ast_list = { let mut definition_ast_list = {
let top_level_def_list = vec![ let top_level_def_list = vec![
Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32", None))), Arc::new(RwLock::new(Self::make_top_level_class_def(
Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64", None))), 0,
Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float", None))), None,
Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool", None))), "int32".into(),
Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none", None))), None,
))),
Arc::new(RwLock::new(Self::make_top_level_class_def(
1,
None,
"int64".into(),
None,
))),
Arc::new(RwLock::new(Self::make_top_level_class_def(
2,
None,
"float".into(),
None,
))),
Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool".into(), None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none".into(), None))),
]; ];
let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None]; let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None];
izip!(top_level_def_list, ast_list).collect_vec() izip!(top_level_def_list, ast_list).collect_vec()
}; };
let primitives_ty = primitives.0; let primitives_ty = primitives.0;
let mut unifier = primitives.1; let mut unifier = primitives.1;
let mut keyword_list: HashSet<String> = HashSet::from_iter(vec![ let mut keyword_list: HashSet<StrRef> = HashSet::from_iter(vec![
"Generic".into(), "Generic".into(),
"virtual".into(), "virtual".into(),
"list".into(), "list".into(),
@ -69,8 +84,8 @@ impl TopLevelComposer {
let defined_names: HashSet<String> = Default::default(); let defined_names: HashSet<String> = Default::default();
let method_class: HashMap<DefinitionId, DefinitionId> = Default::default(); let method_class: HashMap<DefinitionId, DefinitionId> = Default::default();
let mut built_in_id: HashMap<String, DefinitionId> = Default::default(); let mut built_in_id: HashMap<StrRef, DefinitionId> = Default::default();
let mut built_in_ty: HashMap<String, Type> = Default::default(); let mut built_in_ty: HashMap<StrRef, Type> = Default::default();
for (name, sig) in builtins { for (name, sig) in builtins {
let fun_sig = unifier.add_ty(TypeEnum::TFunc(RefCell::new(sig))); let fun_sig = unifier.add_ty(TypeEnum::TFunc(RefCell::new(sig)));
@ -78,11 +93,11 @@ impl TopLevelComposer {
built_in_id.insert(name.clone(), DefinitionId(definition_ast_list.len())); built_in_id.insert(name.clone(), DefinitionId(definition_ast_list.len()));
definition_ast_list.push(( definition_ast_list.push((
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
name: name.clone(), name: name.into(),
simple_name: name.clone(), simple_name: name,
signature: fun_sig, signature: fun_sig,
instance_to_stmt: HashMap::new(), instance_to_stmt: HashMap::new(),
instance_to_symbol: [("".to_string(), name.clone())].iter().cloned().collect(), instance_to_symbol: [("".into(), name.into())].iter().cloned().collect(),
var_id: Default::default(), var_id: Default::default(),
resolver: None, resolver: None,
})), })),
@ -131,7 +146,7 @@ impl TopLevelComposer {
ast: ast::Stmt<()>, ast: ast::Stmt<()>,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
mod_path: String, mod_path: String,
) -> Result<(String, DefinitionId, Option<Type>), String> { ) -> Result<(StrRef, DefinitionId, Option<Type>), String> {
let defined_names = &mut self.defined_names; let defined_names = &mut self.defined_names;
match &ast.node { match &ast.node {
ast::StmtKind::ClassDef { name: class_name, body, .. } => { ast::StmtKind::ClassDef { name: class_name, body, .. } => {
@ -140,7 +155,7 @@ impl TopLevelComposer {
} }
if !defined_names.insert({ if !defined_names.insert({
let mut n = mod_path.clone(); let mut n = mod_path.clone();
n.push_str(class_name.as_str()); n.push_str(&class_name.to_string());
n n
}) { }) {
return Err("duplicate definition of class".into()); return Err("duplicate definition of class".into());
@ -156,7 +171,7 @@ impl TopLevelComposer {
Arc::new(RwLock::new(Self::make_top_level_class_def( Arc::new(RwLock::new(Self::make_top_level_class_def(
class_def_id, class_def_id,
resolver.clone(), resolver.clone(),
class_name.as_str(), class_name,
Some(constructor_ty), Some(constructor_ty),
))), ))),
None, None,
@ -167,7 +182,7 @@ impl TopLevelComposer {
// thus cannot return their definition_id // thus cannot return their definition_id
type MethodInfo = ( type MethodInfo = (
// the simple method name without class name // the simple method name without class name
String, StrRef,
// in this top level def, method name is prefixed with the class name // in this top level def, method name is prefixed with the class name
Arc<RwLock<TopLevelDef>>, Arc<RwLock<TopLevelDef>>,
DefinitionId, DefinitionId,
@ -186,8 +201,11 @@ impl TopLevelComposer {
let global_class_method_name = { let global_class_method_name = {
let mut n = mod_path.clone(); let mut n = mod_path.clone();
n.push_str( n.push_str(
Self::make_class_method_name(class_name.clone(), method_name) Self::make_class_method_name(
.as_str(), class_name.into(),
&method_name.to_string(),
)
.as_str(),
); );
n n
}; };
@ -247,22 +265,22 @@ impl TopLevelComposer {
// if self.keyword_list.contains(name) { // if self.keyword_list.contains(name) {
// return Err("cannot use keyword as a top level function name".into()); // return Err("cannot use keyword as a top level function name".into());
// } // }
let fun_name = name.to_string();
let global_fun_name = { let global_fun_name = {
let mut n = mod_path; let mut n = mod_path;
n.push_str(name.as_str()); n.push_str(&name.to_string());
n n
}; };
if !defined_names.insert(global_fun_name.clone()) { if !defined_names.insert(global_fun_name.clone()) {
return Err("duplicate top level function define".into()); return Err("duplicate top level function define".into());
} }
let fun_name = *name;
let ty_to_be_unified = self.unifier.get_fresh_var().0; let ty_to_be_unified = self.unifier.get_fresh_var().0;
// add to the definition list // add to the definition list
self.definition_ast_list.push(( self.definition_ast_list.push((
RwLock::new(Self::make_top_level_function_def( RwLock::new(Self::make_top_level_function_def(
global_fun_name, global_fun_name,
name.into(), *name,
// dummy here, unify with correct type later // dummy here, unify with correct type later
ty_to_be_unified, ty_to_be_unified,
resolver, resolver,
@ -334,7 +352,7 @@ impl TopLevelComposer {
if { if {
matches!( matches!(
&value.node, &value.node,
ast::ExprKind::Name { id, .. } if id == "Generic" ast::ExprKind::Name { id, .. } if id == &"Generic".into()
) )
} => } =>
{ {
@ -432,7 +450,7 @@ impl TopLevelComposer {
ast::ExprKind::Subscript { value, .. } ast::ExprKind::Subscript { value, .. }
if matches!( if matches!(
&value.node, &value.node,
ast::ExprKind::Name { id, .. } if id == "Generic" ast::ExprKind::Name { id, .. } if id == &"Generic".into()
) )
) { ) {
continue; continue;
@ -627,9 +645,9 @@ impl TopLevelComposer {
let mut function_var_map: HashMap<u32, Type> = HashMap::new(); let mut function_var_map: HashMap<u32, Type> = HashMap::new();
let arg_types = { let arg_types = {
// make sure no duplicate parameter // make sure no duplicate parameter
let mut defined_paramter_name: HashSet<String> = HashSet::new(); let mut defined_paramter_name: HashSet<_> = HashSet::new();
let have_unique_fuction_parameter_name = args.args.iter().all(|x| { let have_unique_fuction_parameter_name = args.args.iter().all(|x| {
defined_paramter_name.insert(x.node.arg.clone()) defined_paramter_name.insert(x.node.arg)
&& !keyword_list.contains(&x.node.arg) && !keyword_list.contains(&x.node.arg)
}); });
if !have_unique_fuction_parameter_name { if !have_unique_fuction_parameter_name {
@ -765,7 +783,7 @@ impl TopLevelComposer {
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>, type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
keyword_list: &HashSet<String>, keyword_list: &HashSet<StrRef>,
) -> Result<(), String> { ) -> Result<(), String> {
let mut class_def = class_def.write(); let mut class_def = class_def.write();
let ( let (
@ -809,12 +827,12 @@ impl TopLevelComposer {
let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.as_ref(); let class_resolver = class_resolver.as_ref();
let mut defined_fields: HashSet<String> = HashSet::new(); let mut defined_fields: HashSet<_> = HashSet::new();
for b in class_body_ast { for b in class_body_ast {
match &b.node { match &b.node {
ast::StmtKind::FunctionDef { args, returns, name, .. } => { ast::StmtKind::FunctionDef { args, returns, name, .. } => {
let (method_dummy_ty, method_id) = let (method_dummy_ty, method_id) =
Self::get_class_method_def_info(class_methods_def, name)?; Self::get_class_method_def_info(class_methods_def, *name)?;
// the method var map can surely include the class's generic parameters // the method var map can surely include the class's generic parameters
let mut method_var_map: HashMap<u32, Type> = class_type_vars_def let mut method_var_map: HashMap<u32, Type> = class_type_vars_def
@ -830,27 +848,28 @@ impl TopLevelComposer {
let arg_types: Vec<FuncArg> = { let arg_types: Vec<FuncArg> = {
// check method parameters cannot have same name // check method parameters cannot have same name
let mut defined_paramter_name: HashSet<String> = HashSet::new(); let mut defined_paramter_name: HashSet<_> = HashSet::new();
let zelf: StrRef = "self".into();
let have_unique_fuction_parameter_name = args.args.iter().all(|x| { let have_unique_fuction_parameter_name = args.args.iter().all(|x| {
defined_paramter_name.insert(x.node.arg.clone()) defined_paramter_name.insert(x.node.arg.clone())
&& (!keyword_list.contains(&x.node.arg) || x.node.arg == "self") && (!keyword_list.contains(&x.node.arg) || x.node.arg == zelf)
}); });
if !have_unique_fuction_parameter_name { if !have_unique_fuction_parameter_name {
return Err("class method must have unique parameter names \ return Err("class method must have unique parameter names \
and names thould not be the same as the keywords" and names thould not be the same as the keywords"
.into()); .into());
} }
if name == "__init__" && !defined_paramter_name.contains("self") { if name == &"__init__".into() && !defined_paramter_name.contains(&zelf) {
return Err("__init__ function must have a `self` parameter".into()); return Err("__init__ function must have a `self` parameter".into());
} }
if !defined_paramter_name.contains("self") { if !defined_paramter_name.contains(&zelf) {
return Err("currently does not support static method".into()); return Err("currently does not support static method".into());
} }
let mut result = Vec::new(); let mut result = Vec::new();
for x in &args.args { for x in &args.args {
let name = x.node.arg.clone(); let name = x.node.arg;
if name != "self" { if name != zelf {
let type_ann = { let type_ann = {
let annotation_expr = x let annotation_expr = x
.node .node
@ -962,14 +981,15 @@ impl TopLevelComposer {
if let ast::ExprKind::Name { id: attr, .. } = &target.node { if let ast::ExprKind::Name { id: attr, .. } = &target.node {
if defined_fields.insert(attr.to_string()) { if defined_fields.insert(attr.to_string()) {
let dummy_field_type = unifier.get_fresh_var().0; let dummy_field_type = unifier.get_fresh_var().0;
class_fields_def.push((attr.to_string(), dummy_field_type)); class_fields_def.push((*attr, dummy_field_type));
// handle Kernel[T], KernelImmutable[T] // handle Kernel[T], KernelImmutable[T]
let annotation = { let annotation = {
match &annotation.as_ref().node { match &annotation.as_ref().node {
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Kernel" || id == "KernelImmutable") matches!(&value.node, ast::ExprKind::Name { id, .. }
if id == &"Kernel".into() || id == &"KernelImmutable".into())
} => } =>
{ {
slice slice
@ -1054,19 +1074,19 @@ impl TopLevelComposer {
if let TopLevelDef::Class { methods, fields, .. } = &*base { if let TopLevelDef::Class { methods, fields, .. } = &*base {
// handle methods override // handle methods override
// since we need to maintain the order, create a new list // since we need to maintain the order, create a new list
let mut new_child_methods: Vec<(String, Type, DefinitionId)> = Vec::new(); let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = Vec::new();
let mut is_override: HashSet<String> = HashSet::new(); let mut is_override: HashSet<StrRef> = HashSet::new();
for (anc_method_name, anc_method_ty, anc_method_def_id) in methods { for (anc_method_name, anc_method_ty, anc_method_def_id) in methods {
// find if there is a method with same name in the child class // find if there is a method with same name in the child class
let mut to_be_added = let mut to_be_added =
(anc_method_name.to_string(), *anc_method_ty, *anc_method_def_id); (*anc_method_name, *anc_method_ty, *anc_method_def_id);
for (class_method_name, class_method_ty, class_method_defid) in for (class_method_name, class_method_ty, class_method_defid) in
class_methods_def.iter() class_methods_def.iter()
{ {
if class_method_name == anc_method_name { if class_method_name == anc_method_name {
// ignore and handle self // ignore and handle self
// if is __init__ method, no need to check return type // if is __init__ method, no need to check return type
let ok = class_method_name == "__init__" let ok = class_method_name == &"__init__".into()
|| Self::check_overload_function_type( || Self::check_overload_function_type(
*class_method_ty, *class_method_ty,
*anc_method_ty, *anc_method_ty,
@ -1077,9 +1097,9 @@ impl TopLevelComposer {
return Err("method has same name as ancestors' method, but incompatible type".into()); return Err("method has same name as ancestors' method, but incompatible type".into());
} }
// mark it as added // mark it as added
is_override.insert(class_method_name.to_string()); is_override.insert(*class_method_name);
to_be_added = ( to_be_added = (
class_method_name.to_string(), *class_method_name,
*class_method_ty, *class_method_ty,
*class_method_defid, *class_method_defid,
); );
@ -1094,7 +1114,7 @@ impl TopLevelComposer {
{ {
if !is_override.contains(class_method_name) { if !is_override.contains(class_method_name) {
new_child_methods.push(( new_child_methods.push((
class_method_name.to_string(), *class_method_name,
*class_method_ty, *class_method_ty,
*class_method_defid, *class_method_defid,
)); ));
@ -1105,10 +1125,10 @@ impl TopLevelComposer {
class_methods_def.extend(new_child_methods); class_methods_def.extend(new_child_methods);
// handle class fields // handle class fields
let mut new_child_fields: Vec<(String, Type)> = Vec::new(); let mut new_child_fields: Vec<(StrRef, Type)> = Vec::new();
// let mut is_override: HashSet<String> = HashSet::new(); // let mut is_override: HashSet<_> = HashSet::new();
for (anc_field_name, anc_field_ty) in fields { for (anc_field_name, anc_field_ty) in fields {
let to_be_added = (anc_field_name.to_string(), *anc_field_ty); let to_be_added = (*anc_field_name, *anc_field_ty);
// find if there is a fields with the same name in the child class // find if there is a fields with the same name in the child class
for (class_field_name, ..) in class_fields_def.iter() { for (class_field_name, ..) in class_fields_def.iter() {
if class_field_name == anc_field_name { if class_field_name == anc_field_name {
@ -1135,7 +1155,7 @@ impl TopLevelComposer {
} }
for (class_field_name, class_field_ty) in class_fields_def.iter() { for (class_field_name, class_field_ty) in class_fields_def.iter() {
if !is_override.contains(class_field_name) { if !is_override.contains(class_field_name) {
new_child_fields.push((class_field_name.to_string(), *class_field_ty)); new_child_fields.push((*class_field_name, *class_field_ty));
} }
} }
class_fields_def.drain(..); class_fields_def.drain(..);
@ -1173,7 +1193,7 @@ impl TopLevelComposer {
let mut constructor_args: Vec<FuncArg> = Vec::new(); let mut constructor_args: Vec<FuncArg> = Vec::new();
let mut type_vars: HashMap<u32, Type> = HashMap::new(); let mut type_vars: HashMap<u32, Type> = HashMap::new();
for (name, func_sig, id) in methods { for (name, func_sig, id) in methods {
if name == "__init__" { if name == &"__init__".into() {
init_id = Some(*id); init_id = Some(*id);
if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() { if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() {
let FunSignature { args, vars, .. } = &*sig.borrow(); let FunSignature { args, vars, .. } = &*sig.borrow();
@ -1203,7 +1223,7 @@ impl TopLevelComposer {
let init_ast = let init_ast =
self.definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap(); self.definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap();
if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node { if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node {
if name != "__init__" { if name != &"__init__".into() {
unreachable!("must be init function here") unreachable!("must be init function here")
} }
let all_inited = Self::get_all_assigned_field(body.as_slice())?; let all_inited = Self::get_all_assigned_field(body.as_slice())?;
@ -1303,7 +1323,7 @@ impl TopLevelComposer {
let mut identifiers = { let mut identifiers = {
// NOTE: none and function args? // NOTE: none and function args?
let mut result: HashSet<String> = HashSet::new(); let mut result: HashSet<_> = HashSet::new();
result.insert("None".into()); result.insert("None".into());
if self_type.is_some() { if self_type.is_some() {
result.insert("self".into()); result.insert("self".into());
@ -1331,7 +1351,7 @@ impl TopLevelComposer {
unifier: &mut self.unifier, unifier: &mut self.unifier,
variable_mapping: { variable_mapping: {
// NOTE: none and function args? // NOTE: none and function args?
let mut result: HashMap<String, Type> = HashMap::new(); let mut result: HashMap<StrRef, Type> = HashMap::new();
result.insert("None".into(), self.primitives_ty.none); result.insert("None".into(), self.primitives_ty.none);
if let Some(self_ty) = self_type { if let Some(self_ty) = self_type {
result.insert("self".into(), self_ty); result.insert("self".into(), self_ty);
@ -1350,9 +1370,9 @@ impl TopLevelComposer {
{ {
if !decorator_list.is_empty() if !decorator_list.is_empty()
&& matches!(&decorator_list[0].node, && matches!(&decorator_list[0].node,
ast::ExprKind::Name{ id, .. } if id == "syscall") ast::ExprKind::Name{ id, .. } if id == &"syscall".into())
{ {
instance_to_symbol.insert("".to_string(), simple_name.clone()); instance_to_symbol.insert("".into(), simple_name.to_string());
continue; continue;
} }
body body

View File

@ -92,11 +92,11 @@ impl TopLevelComposer {
pub fn make_top_level_class_def( pub fn make_top_level_class_def(
index: usize, index: usize,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
name: &str, name: StrRef,
constructor: Option<Type>, constructor: Option<Type>,
) -> TopLevelDef { ) -> TopLevelDef {
TopLevelDef::Class { TopLevelDef::Class {
name: name.to_string(), name,
object_id: DefinitionId(index), object_id: DefinitionId(index),
type_vars: Default::default(), type_vars: Default::default(),
fields: Default::default(), fields: Default::default(),
@ -110,7 +110,7 @@ impl TopLevelComposer {
/// when first registering, the type is a invalid value /// when first registering, the type is a invalid value
pub fn make_top_level_function_def( pub fn make_top_level_function_def(
name: String, name: String,
simple_name: String, simple_name: StrRef,
ty: Type, ty: Type,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
) -> TopLevelDef { ) -> TopLevelDef {
@ -132,11 +132,11 @@ impl TopLevelComposer {
} }
pub fn get_class_method_def_info( pub fn get_class_method_def_info(
class_methods_def: &[(String, Type, DefinitionId)], class_methods_def: &[(StrRef, Type, DefinitionId)],
method_name: &str, method_name: StrRef,
) -> Result<(Type, DefinitionId), String> { ) -> Result<(Type, DefinitionId), String> {
for (name, ty, def_id) in class_methods_def { for (name, ty, def_id) in class_methods_def {
if name == method_name { if name == &method_name {
return Ok((*ty, *def_id)); return Ok((*ty, *def_id));
} }
} }
@ -234,7 +234,7 @@ impl TopLevelComposer {
(name, type_var_to_concrete_def.get(ty).unwrap()) (name, type_var_to_concrete_def.get(ty).unwrap())
})) }))
.all(|(this, other)| { .all(|(this, other)| {
if this.0 == "self" && this.0 == other.0 { if this.0 == &"self".into() && this.0 == other.0 {
true true
} else { } else {
this.0 == other.0 this.0 == other.0
@ -269,15 +269,15 @@ impl TopLevelComposer {
) )
} }
pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result<HashSet<String>, String> { pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result<HashSet<StrRef>, String> {
let mut result: HashSet<String> = HashSet::new(); let mut result = HashSet::new();
for s in stmts { for s in stmts {
match &s.node { match &s.node {
ast::StmtKind::AnnAssign { target, .. } ast::StmtKind::AnnAssign { target, .. }
if { if {
if let ast::ExprKind::Attribute { value, .. } = &target.node { if let ast::ExprKind::Attribute { value, .. } = &target.node {
if let ast::ExprKind::Name { id, .. } = &value.node { if let ast::ExprKind::Name { id, .. } = &value.node {
id == "self" id == &"self".into()
} else { } else {
false false
} }
@ -295,7 +295,7 @@ impl TopLevelComposer {
for t in targets { for t in targets {
if let ast::ExprKind::Attribute { value, attr, .. } = &t.node { if let ast::ExprKind::Attribute { value, attr, .. } = &t.node {
if let ast::ExprKind::Name { id, .. } = &value.node { if let ast::ExprKind::Name { id, .. } = &value.node {
if id == "self" { if id == &"self".into() {
result.insert(attr.clone()); result.insert(attr.clone());
} }
} }
@ -312,14 +312,14 @@ impl TopLevelComposer {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?) .intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.cloned() .cloned()
.collect::<HashSet<String>>(); .collect::<HashSet<_>>();
result.extend(inited_for_sure); result.extend(inited_for_sure);
} }
ast::StmtKind::Try { body, orelse, finalbody, .. } => { ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?) .intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.cloned() .cloned()
.collect::<HashSet<String>>(); .collect::<HashSet<_>>();
result.extend(inited_for_sure); result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
} }

View File

@ -15,7 +15,7 @@ use crate::{
}; };
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
use parking_lot::RwLock; use parking_lot::RwLock;
use rustpython_parser::ast::{self, Stmt}; use rustpython_parser::ast::{self, Stmt, StrRef};
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)] #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize); pub struct DefinitionId(pub usize);
@ -40,15 +40,15 @@ pub struct FunInstance {
pub enum TopLevelDef { pub enum TopLevelDef {
Class { Class {
// name for error messages and symbols // name for error messages and symbols
name: String, name: StrRef,
// object ID used for TypeEnum // object ID used for TypeEnum
object_id: DefinitionId, object_id: DefinitionId,
/// type variables bounded to the class. /// type variables bounded to the class.
type_vars: Vec<Type>, type_vars: Vec<Type>,
// class fields // class fields
fields: Vec<(String, Type)>, fields: Vec<(StrRef, Type)>,
// class methods, pointing to the corresponding function definition. // class methods, pointing to the corresponding function definition.
methods: Vec<(String, Type, DefinitionId)>, methods: Vec<(StrRef, Type, DefinitionId)>,
// ancestor classes, including itself. // ancestor classes, including itself.
ancestors: Vec<TypeAnnotation>, ancestors: Vec<TypeAnnotation>,
// symbol resolver of the module defined the class, none if it is built-in type // symbol resolver of the module defined the class, none if it is built-in type
@ -60,7 +60,7 @@ pub enum TopLevelDef {
// prefix for symbol, should be unique globally // prefix for symbol, should be unique globally
name: String, name: String,
// simple name, the same as in method/function definition // simple name, the same as in method/function definition
simple_name: String, simple_name: StrRef,
// function signature. // function signature.
signature: Type, signature: Type,
// instantiated type variable IDs // instantiated type variable IDs

View File

@ -16,17 +16,17 @@ use test_case::test_case;
use super::*; use super::*;
struct ResolverInternal { struct ResolverInternal {
id_to_type: Mutex<HashMap<String, Type>>, id_to_type: Mutex<HashMap<StrRef, Type>>,
id_to_def: Mutex<HashMap<String, DefinitionId>>, id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
class_names: Mutex<HashMap<String, Type>>, class_names: Mutex<HashMap<StrRef, Type>>,
} }
impl ResolverInternal { impl ResolverInternal {
fn add_id_def(&self, id: String, def: DefinitionId) { fn add_id_def(&self, id: StrRef, def: DefinitionId) {
self.id_to_def.lock().insert(id, def); self.id_to_def.lock().insert(id, def);
} }
fn add_id_type(&self, id: String, ty: Type) { fn add_id_type(&self, id: StrRef, ty: Type) {
self.id_to_type.lock().insert(id, ty); self.id_to_type.lock().insert(id, ty);
} }
} }
@ -34,24 +34,24 @@ impl ResolverInternal {
struct Resolver(Arc<ResolverInternal>); struct Resolver(Arc<ResolverInternal>);
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option<Type> {
let ret = self.0.id_to_type.lock().get(str).cloned(); let ret = self.0.id_to_type.lock().get(&str).cloned();
if ret.is_none() { if ret.is_none() {
// println!("unknown here resolver {}", str); // println!("unknown here resolver {}", str);
} }
ret ret
} }
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> { fn get_symbol_value(&self, _: StrRef) -> Option<SymbolValue> {
unimplemented!() unimplemented!()
} }
fn get_symbol_location(&self, _: &str) -> Option<Location> { fn get_symbol_location(&self, _: StrRef) -> Option<Location> {
unimplemented!() unimplemented!()
} }
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> { fn get_identifier_def(&self, id: StrRef) -> Option<DefinitionId> {
self.0.id_to_def.lock().get(id).cloned() self.0.id_to_def.lock().get(&id).cloned()
} }
} }
@ -70,7 +70,7 @@ impl SymbolResolver for Resolver {
class B: class B:
def __init__(self): def __init__(self):
self.b: float = 4.3 self.b: float = 4.3
def fun(self): def fun(self):
self.b = self.b + 3.0 self.b = self.b + 3.0
"}, "},
@ -449,19 +449,19 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
methods: [(\"__init__\", \"fn[[a=class5[2->class2, 3->class3], b=class8], class4]\", DefinitionId(6)), (\"fun\", \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", DefinitionId(7))], methods: [(\"__init__\", \"fn[[a=class5[2->class2, 3->class3], b=class8], class4]\", DefinitionId(6)), (\"fun\", \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", DefinitionId(7))],
type_vars: [UnificationKey(100), UnificationKey(101)] type_vars: [UnificationKey(100), UnificationKey(101)]
}"}, }"},
indoc! {"6: Function { indoc! {"6: Function {
name: \"A.__init__\", name: \"A.__init__\",
sig: \"fn[[a=class5[2->class2, 3->class3], b=class8], class4]\", sig: \"fn[[a=class5[2->class2, 3->class3], b=class8], class4]\",
var_id: [2, 3] var_id: [2, 3]
}"}, }"},
indoc! {"7: Function { indoc! {"7: Function {
name: \"A.fun\", name: \"A.fun\",
sig: \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", sig: \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\",
var_id: [2, 3] var_id: [2, 3]
}"}, }"},
indoc! {"8: Class { indoc! {"8: Class {
name: \"B\", name: \"B\",
def_id: DefinitionId(8), def_id: DefinitionId(8),
@ -470,19 +470,19 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
methods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(9)), (\"fun\", \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", DefinitionId(7)), (\"foo\", \"fn[[b=class8], class8]\", DefinitionId(10)), (\"bar\", \"fn[[a=class5[2->list[class8], 3->class0]], tuple[class5[2->virtual[class5[2->class8, 3->class0]], 3->class3], class8]]\", DefinitionId(11))], methods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(9)), (\"fun\", \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", DefinitionId(7)), (\"foo\", \"fn[[b=class8], class8]\", DefinitionId(10)), (\"bar\", \"fn[[a=class5[2->list[class8], 3->class0]], tuple[class5[2->virtual[class5[2->class8, 3->class0]], 3->class3], class8]]\", DefinitionId(11))],
type_vars: [] type_vars: []
}"}, }"},
indoc! {"9: Function { indoc! {"9: Function {
name: \"B.__init__\", name: \"B.__init__\",
sig: \"fn[[], class4]\", sig: \"fn[[], class4]\",
var_id: [] var_id: []
}"}, }"},
indoc! {"10: Function { indoc! {"10: Function {
name: \"B.foo\", name: \"B.foo\",
sig: \"fn[[b=class8], class8]\", sig: \"fn[[b=class8], class8]\",
var_id: [] var_id: []
}"}, }"},
indoc! {"11: Function { indoc! {"11: Function {
name: \"B.bar\", name: \"B.bar\",
sig: \"fn[[a=class5[2->list[class8], 3->class0]], tuple[class5[2->virtual[class5[2->class8, 3->class0]], 3->class3], class8]]\", sig: \"fn[[a=class5[2->list[class8], 3->class0]], tuple[class5[2->virtual[class5[2->class8, 3->class0]], 3->class3], class8]]\",
@ -648,15 +648,6 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
vec!["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"]; vec!["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"];
"simple pass in class" "simple pass in class"
)] )]
#[test_case(
vec![indoc! {"
class A:
def fun3(self):
pass
"}],
vec!["function name `fun3` must not end with numbers"];
"err fun end with number"
)]
#[test_case( #[test_case(
vec![indoc! {" vec![indoc! {"
class A: class A:
@ -790,7 +781,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
} }
} }
}; };
internal_resolver.add_id_def(id.clone(), def_id); internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty { if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty); internal_resolver.add_id_type(id, ty);
} }
@ -1027,7 +1018,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
); );
for inst in instance_to_stmt.iter() { for inst in instance_to_stmt.iter() {
let ast = &inst.1.body; let ast = &inst.1.body;
for b in ast { for b in ast.iter() {
println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap()); println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap());
println!("--------------------"); println!("--------------------");
} }
@ -1039,7 +1030,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
} }
fn make_internal_resolver_with_tvar( fn make_internal_resolver_with_tvar(
tvars: Vec<(String, Vec<Type>)>, tvars: Vec<(StrRef, Vec<Type>)>,
unifier: &mut Unifier, unifier: &mut Unifier,
print: bool, print: bool,
) -> Arc<ResolverInternal> { ) -> Arc<ResolverInternal> {

View File

@ -30,49 +30,54 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
mut locked: HashMap<DefinitionId, Vec<Type>>, mut locked: HashMap<DefinitionId, Vec<Type>>,
) -> Result<TypeAnnotation, String> { ) -> Result<TypeAnnotation, String> {
match &expr.node { match &expr.node {
ast::ExprKind::Name { id, .. } => match id.as_str() { ast::ExprKind::Name { id, .. } => {
"int32" => Ok(TypeAnnotation::PrimitiveKind(primitives.int32)), if id == &"int32".into() {
"int64" => Ok(TypeAnnotation::PrimitiveKind(primitives.int64)), Ok(TypeAnnotation::PrimitiveKind(primitives.int32))
"float" => Ok(TypeAnnotation::PrimitiveKind(primitives.float)), } else if id == &"int64".into() {
"bool" => Ok(TypeAnnotation::PrimitiveKind(primitives.bool)), Ok(TypeAnnotation::PrimitiveKind(primitives.int64))
"None" => Ok(TypeAnnotation::PrimitiveKind(primitives.none)), } else if id == &"float".into() {
x => { Ok(TypeAnnotation::PrimitiveKind(primitives.float))
if let Some(obj_id) = resolver.get_identifier_def(x) { } else if id == &"bool".into() {
let type_vars = { Ok(TypeAnnotation::PrimitiveKind(primitives.bool))
let def_read = top_level_defs[obj_id.0].try_read(); } else if id == &"None".into() {
if let Some(def_read) = def_read { Ok(TypeAnnotation::PrimitiveKind(primitives.none))
if let TopLevelDef::Class { type_vars, .. } = &*def_read { } else if let Some(obj_id) = resolver.get_identifier_def(*id) {
type_vars.clone() let type_vars = {
} else { let def_read = top_level_defs[obj_id.0].try_read();
return Err("function cannot be used as a type".into()); if let Some(def_read) = def_read {
} if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone()
} else { } else {
locked.get(&obj_id).unwrap().clone() return Err("function cannot be used as a type".into());
} }
};
// check param number here
if !type_vars.is_empty() {
return Err(format!(
"expect {} type variable parameter but got 0",
type_vars.len()
));
}
Ok(TypeAnnotation::CustomClassKind { id: obj_id, params: vec![] })
} else if let Some(ty) = resolver.get_symbol_type(unifier, primitives, id) {
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
Ok(TypeAnnotation::TypeVarKind(ty))
} else { } else {
Err("not a type variable identifier".into()) locked.get(&obj_id).unwrap().clone()
} }
} else { };
Err("name cannot be parsed as a type annotation".into()) // check param number here
if !type_vars.is_empty() {
return Err(format!(
"expect {} type variable parameter but got 0",
type_vars.len()
));
} }
Ok(TypeAnnotation::CustomClassKind { id: obj_id, params: vec![] })
} else if let Some(ty) = resolver.get_symbol_type(unifier, primitives, *id) {
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
Ok(TypeAnnotation::TypeVarKind(ty))
} else {
Err("not a type variable identifier".into())
}
} else {
Err("name cannot be parsed as a type annotation".into())
} }
}, }
// virtual // virtual
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "virtual") } => if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into())
} =>
{ {
let def = parse_ast_to_type_annotation_kinds( let def = parse_ast_to_type_annotation_kinds(
resolver, resolver,
@ -90,7 +95,9 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
// list // list
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "list") } => if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into())
} =>
{ {
let def_ann = parse_ast_to_type_annotation_kinds( let def_ann = parse_ast_to_type_annotation_kinds(
resolver, resolver,
@ -105,7 +112,9 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
// tuple // tuple
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "tuple") } => if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into())
} =>
{ {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node { if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
let type_annotations = elts let type_annotations = elts
@ -130,11 +139,13 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
// custom class // custom class
ast::ExprKind::Subscript { value, slice, .. } => { ast::ExprKind::Subscript { value, slice, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node { if let ast::ExprKind::Name { id, .. } = &value.node {
if vec!["virtual", "Generic", "list", "tuple"].contains(&id.as_str()) { if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()]
.contains(id)
{
return Err("keywords cannot be class name".into()); return Err("keywords cannot be class name".into());
} }
let obj_id = resolver let obj_id = resolver
.get_identifier_def(id) .get_identifier_def(*id)
.ok_or_else(|| "unknown class name".to_string())?; .ok_or_else(|| "unknown class name".to_string())?;
let type_vars = { let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read(); let def_read = top_level_defs[obj_id.0].try_read();
@ -272,12 +283,12 @@ pub fn get_type_from_type_annotation_kinds(
.iter() .iter()
.map(|(name, ty, _)| { .map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(name.clone(), subst_ty) (*name, subst_ty)
}) })
.collect::<HashMap<String, Type>>(); .collect::<HashMap<_, Type>>();
tobj_fields.extend(fields.iter().map(|(name, ty)| { tobj_fields.extend(fields.iter().map(|(name, ty)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(name.clone(), subst_ty) (*name, subst_ty)
})); }));
// println!("tobj_fields: {:?}", tobj_fields); // println!("tobj_fields: {:?}", tobj_fields);

View File

@ -2,14 +2,14 @@ use crate::typecheck::typedef::TypeEnum;
use super::type_inferencer::Inferencer; use super::type_inferencer::Inferencer;
use super::typedef::Type; use super::typedef::Type;
use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind}; use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind, StrRef};
use std::{collections::HashSet, iter::once}; use std::{collections::HashSet, iter::once};
impl<'a> Inferencer<'a> { impl<'a> Inferencer<'a> {
fn check_pattern( fn check_pattern(
&mut self, &mut self,
pattern: &Expr<Option<Type>>, pattern: &Expr<Option<Type>>,
defined_identifiers: &mut HashSet<String>, defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), String> { ) -> Result<(), String> {
match &pattern.node { match &pattern.node {
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
@ -42,7 +42,7 @@ impl<'a> Inferencer<'a> {
fn check_expr( fn check_expr(
&mut self, &mut self,
expr: &Expr<Option<Type>>, expr: &Expr<Option<Type>>,
defined_identifiers: &mut HashSet<String>, defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), String> { ) -> Result<(), String> {
// there are some cases where the custom field is None // there are some cases where the custom field is None
if let Some(ty) = &expr.custom { if let Some(ty) = &expr.custom {
@ -57,7 +57,7 @@ impl<'a> Inferencer<'a> {
match &expr.node { match &expr.node {
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) { if !defined_identifiers.contains(id) {
if self.function_data.resolver.get_identifier_def(id).is_some() { if self.function_data.resolver.get_identifier_def(*id).is_some() {
defined_identifiers.insert(id.clone()); defined_identifiers.insert(id.clone());
} else { } else {
return Err(format!( return Err(format!(
@ -143,7 +143,7 @@ impl<'a> Inferencer<'a> {
fn check_stmt( fn check_stmt(
&mut self, &mut self,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
defined_identifiers: &mut HashSet<String>, defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, String> { ) -> Result<bool, String> {
match &stmt.node { match &stmt.node {
StmtKind::For { target, iter, body, orelse, .. } => { StmtKind::For { target, iter, body, orelse, .. } => {
@ -217,7 +217,7 @@ impl<'a> Inferencer<'a> {
pub fn check_block( pub fn check_block(
&mut self, &mut self,
block: &[Stmt<Option<Type>>], block: &[Stmt<Option<Type>>],
defined_identifiers: &mut HashSet<String>, defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, String> { ) -> Result<bool, String> {
let mut ret = false; let mut ret = false;
for stmt in block { for stmt in block {

View File

@ -10,7 +10,7 @@ use itertools::izip;
use rustpython_parser::ast::{ use rustpython_parser::ast::{
self, self,
fold::{self, Fold}, fold::{self, Fold},
Arguments, Comprehension, ExprKind, Located, Location, Arguments, Comprehension, ExprKind, Located, Location, StrRef,
}; };
#[cfg(test)] #[cfg(test)]
@ -45,12 +45,12 @@ pub struct FunctionData {
pub struct Inferencer<'a> { pub struct Inferencer<'a> {
pub top_level: &'a TopLevelContext, pub top_level: &'a TopLevelContext,
pub defined_identifiers: HashSet<String>, pub defined_identifiers: HashSet<StrRef>,
pub function_data: &'a mut FunctionData, pub function_data: &'a mut FunctionData,
pub unifier: &'a mut Unifier, pub unifier: &'a mut Unifier,
pub primitives: &'a PrimitiveStore, pub primitives: &'a PrimitiveStore,
pub virtual_checks: &'a mut Vec<(Type, Type)>, pub virtual_checks: &'a mut Vec<(Type, Type)>,
pub variable_mapping: HashMap<String, Type>, pub variable_mapping: HashMap<StrRef, Type>,
pub calls: &'a mut HashMap<CodeLocation, CallId>, pub calls: &'a mut HashMap<CodeLocation, CallId>,
} }
@ -163,7 +163,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?), ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?),
ast::ExprKind::Name { id, .. } => { ast::ExprKind::Name { id, .. } => {
if !self.defined_identifiers.contains(id) { if !self.defined_identifiers.contains(id) {
if self.function_data.resolver.get_identifier_def(id.as_str()).is_some() { if self.function_data.resolver.get_identifier_def(*id).is_some() {
self.defined_identifiers.insert(id.clone()); self.defined_identifiers.insert(id.clone());
} else { } else {
return Err(format!( return Err(format!(
@ -172,12 +172,12 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
)); ));
} }
} }
Some(self.infer_identifier(id)?) Some(self.infer_identifier(*id)?)
} }
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
ast::ExprKind::Attribute { value, attr, ctx: _ } => { ast::ExprKind::Attribute { value, attr, ctx: _ } => {
Some(self.infer_attribute(value, attr)?) Some(self.infer_attribute(value, *attr)?)
} }
ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
ast::ExprKind::BinOp { left, op, right } => { ast::ExprKind::BinOp { left, op, right } => {
@ -237,7 +237,7 @@ impl<'a> Inferencer<'a> {
fn build_method_call( fn build_method_call(
&mut self, &mut self,
location: Location, location: Location,
method: String, method: StrRef,
obj: Type, obj: Type,
params: Vec<Type>, params: Vec<Type>,
ret: Option<Type>, ret: Option<Type>,
@ -413,7 +413,7 @@ impl<'a> Inferencer<'a> {
func func
{ {
// handle special functions that cannot be typed in the usual way... // handle special functions that cannot be typed in the usual way...
if id == "virtual" { if id == "virtual".into() {
if args.is_empty() || args.len() > 2 || !keywords.is_empty() { if args.is_empty() || args.len() > 2 || !keywords.is_empty() {
return Err( return Err(
"`virtual` can only accept 1/2 positional arguments.".to_string() "`virtual` can only accept 1/2 positional arguments.".to_string()
@ -448,7 +448,7 @@ impl<'a> Inferencer<'a> {
}); });
} }
// int64 is special because its argument can be a constant larger than int32 // int64 is special because its argument can be a constant larger than int32
if id == "int64" && args.len() == 1 { if id == "int64".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node &args[0].node
{ {
@ -508,8 +508,8 @@ impl<'a> Inferencer<'a> {
Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } }) Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } })
} }
fn infer_identifier(&mut self, id: &str) -> InferenceResult { fn infer_identifier(&mut self, id: StrRef) -> InferenceResult {
if let Some(ty) = self.variable_mapping.get(id) { if let Some(ty) = self.variable_mapping.get(&id) {
Ok(*ty) Ok(*ty)
} else { } else {
let variable_mapping = &mut self.variable_mapping; let variable_mapping = &mut self.variable_mapping;
@ -520,7 +520,7 @@ impl<'a> Inferencer<'a> {
.get_symbol_type(unifier, self.primitives, id) .get_symbol_type(unifier, self.primitives, id)
.unwrap_or_else(|| { .unwrap_or_else(|| {
let ty = unifier.get_fresh_var().0; let ty = unifier.get_fresh_var().0;
variable_mapping.insert(id.to_string(), ty); variable_mapping.insert(id, ty);
ty ty
})) }))
} }
@ -560,9 +560,13 @@ impl<'a> Inferencer<'a> {
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty })) Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
} }
fn infer_attribute(&mut self, value: &ast::Expr<Option<Type>>, attr: &str) -> InferenceResult { fn infer_attribute(
&mut self,
value: &ast::Expr<Option<Type>>,
attr: StrRef,
) -> InferenceResult {
let (attr_ty, _) = self.unifier.get_fresh_var(); let (attr_ty, _) = self.unifier.get_fresh_var();
let fields = once((attr.to_string(), attr_ty)).collect(); let fields = once((attr, attr_ty)).collect();
let record = self.unifier.add_record(fields); let record = self.unifier.add_record(fields);
self.constrain(value.custom.unwrap(), record, &value.location)?; self.constrain(value.custom.unwrap(), record, &value.location)?;
Ok(attr_ty) Ok(attr_ty)
@ -583,10 +587,10 @@ impl<'a> Inferencer<'a> {
op: &ast::Operator, op: &ast::Operator,
right: &ast::Expr<Option<Type>>, right: &ast::Expr<Option<Type>>,
) -> InferenceResult { ) -> InferenceResult {
let method = binop_name(op); let method = binop_name(op).into();
self.build_method_call( self.build_method_call(
location, location,
method.to_string(), method,
left.custom.unwrap(), left.custom.unwrap(),
vec![right.custom.unwrap()], vec![right.custom.unwrap()],
None, None,
@ -598,14 +602,8 @@ impl<'a> Inferencer<'a> {
op: &ast::Unaryop, op: &ast::Unaryop,
operand: &ast::Expr<Option<Type>>, operand: &ast::Expr<Option<Type>>,
) -> InferenceResult { ) -> InferenceResult {
let method = unaryop_name(op); let method = unaryop_name(op).into();
self.build_method_call( self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], None)
operand.location,
method.to_string(),
operand.custom.unwrap(),
vec![],
None,
)
} }
fn infer_compare( fn infer_compare(
@ -617,7 +615,7 @@ impl<'a> Inferencer<'a> {
let boolean = self.primitives.bool; let boolean = self.primitives.bool;
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method = let method =
comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string(); comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.into();
self.build_method_call( self.build_method_call(
a.location, a.location,
method, method,

View File

@ -12,26 +12,26 @@ use rustpython_parser::parser::parse_program;
use test_case::test_case; use test_case::test_case;
struct Resolver { struct Resolver {
id_to_type: HashMap<String, Type>, id_to_type: HashMap<StrRef, Type>,
id_to_def: HashMap<String, DefinitionId>, id_to_def: HashMap<StrRef, DefinitionId>,
class_names: HashMap<String, Type>, class_names: HashMap<StrRef, Type>,
} }
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option<Type> {
self.id_to_type.get(str).cloned() self.id_to_type.get(&str).cloned()
} }
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> { fn get_symbol_value(&self, _: StrRef) -> Option<SymbolValue> {
unimplemented!() unimplemented!()
} }
fn get_symbol_location(&self, _: &str) -> Option<Location> { fn get_symbol_location(&self, _: StrRef) -> Option<Location> {
unimplemented!() unimplemented!()
} }
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> { fn get_identifier_def(&self, id: StrRef) -> Option<DefinitionId> {
self.id_to_def.get(id).cloned() self.id_to_def.get(&id).cloned()
} }
} }
@ -39,8 +39,8 @@ struct TestEnvironment {
pub unifier: Unifier, pub unifier: Unifier,
pub function_data: FunctionData, pub function_data: FunctionData,
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>, pub id_to_name: HashMap<usize, StrRef>,
pub identifier_mapping: HashMap<String, Type>, pub identifier_mapping: HashMap<StrRef, Type>,
pub virtual_checks: Vec<(Type, Type)>, pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, CallId>, pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext, pub top_level: TopLevelContext,
@ -79,11 +79,11 @@ impl TestEnvironment {
set_primitives_magic_methods(&primitives, &mut unifier); set_primitives_magic_methods(&primitives, &mut unifier);
let id_to_name = [ let id_to_name = [
(0, "int32".to_string()), (0, "int32".into()),
(1, "int64".to_string()), (1, "int64".into()),
(2, "float".to_string()), (2, "float".into()),
(3, "bool".to_string()), (3, "bool".into()),
(4, "none".to_string()), (4, "none".into()),
] ]
.iter() .iter()
.cloned() .cloned()
@ -150,7 +150,7 @@ impl TestEnvironment {
for (i, name) in ["int32", "int64", "float", "bool", "none"].iter().enumerate() { for (i, name) in ["int32", "int64", "float", "bool", "none"].iter().enumerate() {
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: name.to_string(), name: (*name).into(),
object_id: DefinitionId(i), object_id: DefinitionId(i),
type_vars: Default::default(), type_vars: Default::default(),
fields: Default::default(), fields: Default::default(),
@ -174,7 +174,7 @@ impl TestEnvironment {
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: "Foo".to_string(), name: "Foo".into(),
object_id: DefinitionId(5), object_id: DefinitionId(5),
type_vars: vec![v0], type_vars: vec![v0],
fields: [("a".into(), v0)].into(), fields: [("a".into(), v0)].into(),
@ -212,7 +212,7 @@ impl TestEnvironment {
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: "Bar".to_string(), name: "Bar".into(),
object_id: DefinitionId(6), object_id: DefinitionId(6),
type_vars: Default::default(), type_vars: Default::default(),
fields: [("a".into(), int32), ("b".into(), fun)].into(), fields: [("a".into(), int32), ("b".into(), fun)].into(),
@ -241,7 +241,7 @@ impl TestEnvironment {
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: "Bar2".to_string(), name: "Bar2".into(),
object_id: DefinitionId(7), object_id: DefinitionId(7),
type_vars: Default::default(), type_vars: Default::default(),
fields: [("a".into(), bool), ("b".into(), fun)].into(), fields: [("a".into(), bool), ("b".into(), fun)].into(),
@ -261,14 +261,14 @@ impl TestEnvironment {
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect(); let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();
let id_to_name = [ let id_to_name = [
(0, "int32".to_string()), (0, "int32".into()),
(1, "int64".to_string()), (1, "int64".into()),
(2, "float".to_string()), (2, "float".into()),
(3, "bool".to_string()), (3, "bool".into()),
(4, "none".to_string()), (4, "none".into()),
(5, "Foo".to_string()), (5, "Foo".into()),
(6, "Bar".to_string()), (6, "Bar".into()),
(7, "Bar2".to_string()), (7, "Bar2".into()),
] ]
.iter() .iter()
.cloned() .cloned()
@ -385,7 +385,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect(); let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.insert("virtual".to_string()); defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone(); inferencer.defined_identifiers = defined_identifiers.clone();
let statements = parse_program(source).unwrap(); let statements = parse_program(source).unwrap();
@ -400,16 +400,16 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
for (k, v) in inferencer.variable_mapping.iter() { for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.stringify( let name = inferencer.unifier.stringify(
*v, *v,
&mut |v| id_to_name.get(&v).unwrap().clone(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
); );
println!("{}: {}", k, name); println!("{}: {}", k, name);
} }
for (k, v) in mapping.iter() { for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(*k).unwrap(); let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap();
let name = inferencer.unifier.stringify( let name = inferencer.unifier.stringify(
*ty, *ty,
&mut |v| id_to_name.get(&v).unwrap().clone(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
); );
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
@ -418,12 +418,12 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) { for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
let a = inferencer.unifier.stringify( let a = inferencer.unifier.stringify(
*a, *a,
&mut |v| id_to_name.get(&v).unwrap().clone(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
); );
let b = inferencer.unifier.stringify( let b = inferencer.unifier.stringify(
*b, *b,
&mut |v| id_to_name.get(&v).unwrap().clone(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
); );
@ -527,7 +527,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
let mut env = TestEnvironment::basic_test_env(); let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect(); let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.insert("virtual".to_string()); defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone(); inferencer.defined_identifiers = defined_identifiers.clone();
let statements = parse_program(source).unwrap(); let statements = parse_program(source).unwrap();
@ -542,16 +542,16 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
for (k, v) in inferencer.variable_mapping.iter() { for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.stringify( let name = inferencer.unifier.stringify(
*v, *v,
&mut |v| id_to_name.get(&v).unwrap().clone(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
); );
println!("{}: {}", k, name); println!("{}: {}", k, name);
} }
for (k, v) in mapping.iter() { for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(*k).unwrap(); let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap();
let name = inferencer.unifier.stringify( let name = inferencer.unifier.stringify(
*ty, *ty,
&mut |v| id_to_name.get(&v).unwrap().clone(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{}", v),
); );
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));

View File

@ -6,6 +6,8 @@ use std::iter::once;
use std::rc::Rc; use std::rc::Rc;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use rustpython_parser::ast::StrRef;
use super::unification_table::{UnificationKey, UnificationTable}; use super::unification_table::{UnificationKey, UnificationTable};
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef}; use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
@ -25,14 +27,14 @@ type VarMap = Mapping<u32>;
#[derive(Clone)] #[derive(Clone)]
pub struct Call { pub struct Call {
pub posargs: Vec<Type>, pub posargs: Vec<Type>,
pub kwargs: HashMap<String, Type>, pub kwargs: HashMap<StrRef, Type>,
pub ret: Type, pub ret: Type,
pub fun: RefCell<Option<Type>>, pub fun: RefCell<Option<Type>>,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct FuncArg { pub struct FuncArg {
pub name: String, pub name: StrRef,
pub ty: Type, pub ty: Type,
pub default_value: Option<SymbolValue>, pub default_value: Option<SymbolValue>,
} }
@ -48,7 +50,7 @@ pub struct FunSignature {
pub enum TypeVarMeta { pub enum TypeVarMeta {
Generic, Generic,
Sequence(RefCell<Mapping<i32>>), Sequence(RefCell<Mapping<i32>>),
Record(RefCell<Mapping<String>>), Record(RefCell<Mapping<StrRef>>),
} }
#[derive(Clone)] #[derive(Clone)]
@ -70,7 +72,7 @@ pub enum TypeEnum {
}, },
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
fields: RefCell<Mapping<String>>, fields: RefCell<Mapping<StrRef>>,
params: RefCell<VarMap>, params: RefCell<VarMap>,
}, },
TVirtual { TVirtual {
@ -141,7 +143,7 @@ impl Unifier {
.borrow() .borrow()
.iter() .iter()
.map(|(name, ty)| { .map(|(name, ty)| {
(name.clone(), self.copy_from(unifier, *ty, type_cache)) (*name, self.copy_from(unifier, *ty, type_cache))
}) })
.collect(), .collect(),
), ),
@ -163,7 +165,7 @@ impl Unifier {
.args .args
.iter() .iter()
.map(|arg| FuncArg { .map(|arg| FuncArg {
name: arg.name.clone(), name: arg.name,
ty: self.copy_from(unifier, arg.ty, type_cache), ty: self.copy_from(unifier, arg.ty, type_cache),
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
}) })
@ -219,7 +221,7 @@ impl Unifier {
self.unification_table.new_key(Rc::new(a)) self.unification_table.new_key(Rc::new(a))
} }
pub fn add_record(&mut self, fields: Mapping<String>) -> Type { pub fn add_record(&mut self, fields: Mapping<StrRef>) -> Type {
let id = self.var_id + 1; let id = self.var_id + 1;
self.var_id += 1; self.var_id += 1;
self.add_ty(TypeEnum::TVar { self.add_ty(TypeEnum::TVar {
@ -563,12 +565,12 @@ impl Unifier {
} }
(TCall(calls), TFunc(signature)) => { (TCall(calls), TFunc(signature)) => {
self.occur_check(a, b)?; self.occur_check(a, b)?;
let required: Vec<String> = signature let required: Vec<StrRef> = signature
.borrow() .borrow()
.args .args
.iter() .iter()
.filter(|v| v.default_value.is_none()) .filter(|v| v.default_value.is_none())
.map(|v| v.name.clone()) .map(|v| v.name)
.rev() .rev()
.collect(); .collect();
// we unify every calls to the function signature. // we unify every calls to the function signature.
@ -590,7 +592,7 @@ impl Unifier {
.borrow() .borrow()
.args .args
.iter() .iter()
.map(|v| (v.name.clone(), v.ty)) .map(|v| (v.name, v.ty))
.rev() .rev()
.collect(); .collect();
for (i, t) in posargs.iter().enumerate() { for (i, t) in posargs.iter().enumerate() {
@ -662,7 +664,7 @@ impl Unifier {
if let TopLevelDef::Class { name, .. } = if let TopLevelDef::Class { name, .. } =
&*top_level.definitions.read()[id].read() &*top_level.definitions.read()[id].read()
{ {
name.clone() name.to_string()
} else { } else {
unreachable!("expected class definition") unreachable!("expected class definition")
} }

View File

@ -149,7 +149,7 @@ impl TestEnvironment {
let mut fields = HashMap::new(); let mut fields = HashMap::new();
while &s[0..1] != "]" { while &s[0..1] != "]" {
let eq = s.find('=').unwrap(); let eq = s.find('=').unwrap();
let key = s[1..eq].to_string(); let key = s[1..eq].into();
let result = self.internal_parse(&s[eq + 1..], mapping); let result = self.internal_parse(&s[eq + 1..], mapping);
fields.insert(key, result.0); fields.insert(key, result.0);
s = result.1; s = result.1;
@ -342,8 +342,8 @@ fn test_recursive_subst() {
let instantiated_ty = env.unifier.get_ty(instantiated); let instantiated_ty = env.unifier.get_ty(instantiated);
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty { if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
let fields = fields.borrow(); let fields = fields.borrow();
assert!(env.unifier.unioned(*fields.get("a").unwrap(), int)); assert!(env.unifier.unioned(*fields.get(&"a".into()).unwrap(), int));
assert!(env.unifier.unioned(*fields.get("rec").unwrap(), instantiated)); assert!(env.unifier.unioned(*fields.get(&"rec".into()).unwrap(), instantiated));
} else { } else {
unreachable!() unreachable!()
} }
@ -358,10 +358,10 @@ fn test_virtual() {
)); ));
let bar = env.unifier.add_ty(TypeEnum::TObj { let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: DefinitionId(5),
fields: [("f".to_string(), fun), ("a".to_string(), int)] fields: [("f".into(), fun), ("a".into(), int)]
.iter() .iter()
.cloned() .cloned()
.collect::<HashMap<_, _>>() .collect::<HashMap<StrRef, _>>()
.into(), .into(),
params: HashMap::new().into(), params: HashMap::new().into(),
}); });
@ -370,15 +370,15 @@ fn test_virtual() {
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
let c = env.unifier.add_record([("f".to_string(), v1)].iter().cloned().collect()); let c = env.unifier.add_record([("f".into(), v1)].iter().cloned().collect());
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
env.unifier.unify(b, c).unwrap(); env.unifier.unify(b, c).unwrap();
assert!(env.unifier.eq(v1, fun)); assert!(env.unifier.eq(v1, fun));
let d = env.unifier.add_record([("a".to_string(), v1)].iter().cloned().collect()); let d = env.unifier.add_record([("a".into(), v1)].iter().cloned().collect());
assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string())); assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string()));
let d = env.unifier.add_record([("b".to_string(), v1)].iter().cloned().collect()); let d = env.unifier.add_record([("b".into(), v1)].iter().cloned().collect());
assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string())); assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string()));
} }

View File

@ -11,5 +11,5 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
pyo3 = { version = "0.12.4", features = ["extension-module"] } pyo3 = { version = "0.12.4", features = ["extension-module"] }
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] } inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
rustpython-parser = { git = "https://github.com/pca006132/RustPython", branch = "main" } rustpython-parser = { git = "https://github.com/m-labs/RustPython", branch = "parser-mod" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }

View File

@ -6,6 +6,6 @@ edition = "2018"
[dependencies] [dependencies]
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] } inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
rustpython-parser = { git = "https://github.com/pca006132/RustPython", branch = "main" } rustpython-parser = { git = "https://github.com/m-labs/RustPython", branch = "parser-mod" }
parking_lot = "0.11.1" parking_lot = "0.11.1"
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }

View File

@ -8,20 +8,21 @@ use nac3core::{
}, },
}; };
use parking_lot::Mutex; use parking_lot::Mutex;
use rustpython_parser::ast::StrRef;
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
pub struct ResolverInternal { pub struct ResolverInternal {
pub id_to_type: Mutex<HashMap<String, Type>>, pub id_to_type: Mutex<HashMap<StrRef, Type>>,
pub id_to_def: Mutex<HashMap<String, DefinitionId>>, pub id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
pub class_names: Mutex<HashMap<String, Type>>, pub class_names: Mutex<HashMap<StrRef, Type>>,
} }
impl ResolverInternal { impl ResolverInternal {
pub fn add_id_def(&self, id: String, def: DefinitionId) { pub fn add_id_def(&self, id: StrRef, def: DefinitionId) {
self.id_to_def.lock().insert(id, def); self.id_to_def.lock().insert(id, def);
} }
pub fn add_id_type(&self, id: String, ty: Type) { pub fn add_id_type(&self, id: StrRef, ty: Type) {
self.id_to_type.lock().insert(id, ty); self.id_to_type.lock().insert(id, ty);
} }
} }
@ -29,23 +30,23 @@ impl ResolverInternal {
pub struct Resolver(pub Arc<ResolverInternal>); pub struct Resolver(pub Arc<ResolverInternal>);
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option<Type> {
let ret = self.0.id_to_type.lock().get(str).cloned(); let ret = self.0.id_to_type.lock().get(&str).cloned();
if ret.is_none() { if ret.is_none() {
// println!("unknown here resolver {}", str); // println!("unknown here resolver {}", str);
} }
ret ret
} }
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> { fn get_symbol_value(&self, _: StrRef) -> Option<SymbolValue> {
unimplemented!() unimplemented!()
} }
fn get_symbol_location(&self, _: &str) -> Option<Location> { fn get_symbol_location(&self, _: StrRef) -> Option<Location> {
unimplemented!() unimplemented!()
} }
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> { fn get_identifier_def(&self, id: StrRef) -> Option<DefinitionId> {
self.0.id_to_def.lock().get(id).cloned() self.0.id_to_def.lock().get(&id).cloned()
} }
} }

View File

@ -92,7 +92,7 @@ fn main() {
let instance = { let instance = {
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
let mut instance = defs[resolver.get_identifier_def("run").unwrap().0].write(); let mut instance = defs[resolver.get_identifier_def("run".into()).unwrap().0].write();
if let TopLevelDef::Function { if let TopLevelDef::Function {
instance_to_stmt, instance_to_stmt,
instance_to_symbol, instance_to_symbol,