diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs index fbee8ebe..fe686245 100644 --- a/nac3core/src/typecheck/mod.rs +++ b/nac3core/src/typecheck/mod.rs @@ -1,8 +1,9 @@ #![allow(dead_code)] +mod function_check; pub mod location; mod magic_methods; pub mod symbol_resolver; -pub mod typedef; +mod top_level; pub mod type_inferencer; +pub mod typedef; mod unification_table; -mod function_check; diff --git a/nac3core/src/typecheck/symbol_resolver.rs b/nac3core/src/typecheck/symbol_resolver.rs index 669f7632..33d56201 100644 --- a/nac3core/src/typecheck/symbol_resolver.rs +++ b/nac3core/src/typecheck/symbol_resolver.rs @@ -1,5 +1,6 @@ -use super::typedef::Type; use super::location::Location; +use super::top_level::DefinitionId; +use super::typedef::Type; use rustpython_parser::ast::Expr; pub enum SymbolValue<'a> { @@ -14,6 +15,7 @@ pub enum SymbolValue<'a> { pub trait SymbolResolver { fn get_symbol_type(&mut self, str: &str) -> Option; fn parse_type_name(&mut self, expr: &Expr<()>) -> Option; + fn get_function_def(&mut self, str: &str) -> DefinitionId; fn get_symbol_value(&mut self, str: &str) -> Option; fn get_symbol_location(&mut self, str: &str) -> Option; // handle function call etc. diff --git a/nac3core/src/typecheck/top_level.rs b/nac3core/src/typecheck/top_level.rs new file mode 100644 index 00000000..5d63a920 --- /dev/null +++ b/nac3core/src/typecheck/top_level.rs @@ -0,0 +1,51 @@ +use std::collections::HashMap; + +use super::typedef::{SharedUnifier, Type}; +use crossbeam::queue::SegQueue; +use crossbeam::sync::ShardedLock; +use rustpython_parser::ast::Stmt; + +pub struct DefinitionId(usize); + +pub enum TopLevelDef { + Class { + // object ID used for TypeEnum + object_id: usize, + // type variables bounded to the class. + type_vars: Vec, + // class fields and method signature. + fields: Vec<(String, Type)>, + // class methods, pointing to the corresponding function definition. + methods: Vec<(String, DefinitionId)>, + // ancestor classes, including itself. + ancestors: Vec, + }, + Function { + signature: Type, + /// Function instance to symbol mapping + /// Key: string representation of type variable values, sorted by variable ID in ascending + /// order, including type variables associated with the class. + /// Value: function symbol name. + instance_to_symbol: HashMap, + /// Function instances to annotated AST mapping + /// Key: string representation of type variable values, sorted by variable ID in ascending + /// order, including type variables associated with the class. Excluding rigid type + /// variables. + /// Value: AST annotated with types together with a unification table index. Could contain + /// rigid type variables that would be substituted when the function is instantiated. + instance_to_stmt: HashMap, usize)>, + }, +} + +pub struct CodeGenTask { + pub subst: HashMap, + pub symbol_name: String, + pub body: Stmt, + pub unifier: SharedUnifier, +} + +pub struct TopLevelContext { + pub definitions: Vec>, + pub unifiers: Vec, + pub codegen_queue: SegQueue, +} diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index c6b2c394..d7729cc7 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -1,5 +1,6 @@ use super::super::location::Location; use super::super::symbol_resolver::*; +use super::super::top_level::DefinitionId; use super::super::typedef::*; use super::*; use indoc::indoc; @@ -33,6 +34,10 @@ impl SymbolResolver for Resolver { fn get_symbol_location(&mut self, _: &str) -> Option { unimplemented!() } + + fn get_function_def(&mut self, _: &str) -> DefinitionId { + unimplemented!() + } } struct TestEnvironment { @@ -48,7 +53,7 @@ struct TestEnvironment { impl TestEnvironment { pub fn basic_test_env() -> TestEnvironment { let mut unifier = Unifier::new(); - + let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: 0, fields: HashMap::new().into(), @@ -76,9 +81,9 @@ impl TestEnvironment { }); // identifier_mapping.insert("None".into(), none); let primitives = PrimitiveStore { int32, int64, float, bool, none }; - + set_primirives_magic_methods(&primitives, &mut unifier); - + let id_to_name = [ (0, "int32".to_string()), (1, "int64".to_string()), @@ -95,17 +100,18 @@ impl TestEnvironment { let mut identifier_mapping = HashMap::new(); identifier_mapping.insert("None".into(), none); - - let resolver = - Box::new(Resolver { identifier_mapping: identifier_mapping.clone(), class_names: Default::default() }) - as Box; + + let resolver = Box::new(Resolver { + identifier_mapping: identifier_mapping.clone(), + class_names: Default::default(), + }) as Box; TestEnvironment { unifier, function_data: FunctionData { resolver, bound_variables: Vec::new(), - return_type: None + return_type: None, }, primitives, id_to_name, @@ -171,7 +177,11 @@ impl TestEnvironment { })); let bar = unifier.add_ty(TypeEnum::TObj { obj_id: 6, - fields: [("a".into(), int32), ("b".into(), fun)].iter().cloned().collect::>().into(), + fields: [("a".into(), int32), ("b".into(), fun)] + .iter() + .cloned() + .collect::>() + .into(), params: Default::default(), }); identifier_mapping.insert( @@ -185,7 +195,11 @@ impl TestEnvironment { let bar2 = unifier.add_ty(TypeEnum::TObj { obj_id: 7, - fields: [("a".into(), bool), ("b".into(), fun)].iter().cloned().collect::>().into(), + fields: [("a".into(), bool), ("b".into(), fun)] + .iter() + .cloned() + .collect::>() + .into(), params: Default::default(), }); identifier_mapping.insert( @@ -362,13 +376,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st g = a // b h = a % b "}, - [("a", "int32"), - ("b", "int32"), - ("c", "int32"), - ("d", "int32"), - ("e", "int32"), - ("f", "float"), - ("g", "int32"), + [("a", "int32"), + ("b", "int32"), + ("c", "int32"), + ("d", "int32"), + ("e", "int32"), + ("f", "float"), + ("g", "int32"), ("h", "int32")].iter().cloned().collect() ; "int32")] #[test_case( @@ -382,13 +396,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st g = a // b h = a % b "}, - [("a", "float"), - ("b", "float"), - ("c", "float"), - ("d", "float"), - ("e", "float"), - ("f", "float"), - ("g", "float"), + [("a", "float"), + ("b", "float"), + ("c", "float"), + ("d", "float"), + ("e", "float"), + ("f", "float"), + ("g", "float"), ("h", "float")].iter().cloned().collect() ; "float" )] @@ -407,13 +421,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st k = a < b l = a != b "}, - [("a", "int64"), - ("b", "int64"), - ("c", "int64"), - ("d", "int64"), - ("e", "int64"), - ("f", "float"), - ("g", "int64"), + [("a", "int64"), + ("b", "int64"), + ("c", "int64"), + ("d", "int64"), + ("e", "int64"), + ("f", "float"), + ("g", "int64"), ("h", "int64"), ("i", "bool"), ("j", "bool"), @@ -429,10 +443,10 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st d = not a e = a != b "}, - [("a", "bool"), - ("b", "bool"), - ("c", "bool"), - ("d", "bool"), + [("a", "bool"), + ("b", "bool"), + ("c", "bool"), + ("d", "bool"), ("e", "bool")].iter().cloned().collect() ; "boolean" )] @@ -469,4 +483,5 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { ); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); } -} \ No newline at end of file +} + diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 20736dca..8f2567a5 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -1,9 +1,10 @@ use itertools::{chain, zip, Itertools}; -use std::{borrow::Cow, sync::Arc}; +use std::borrow::Cow; use std::cell::RefCell; use std::collections::HashMap; use std::iter::once; use std::rc::Rc; +use std::sync::{Arc, Mutex}; use super::unification_table::{UnificationKey, UnificationTable}; @@ -89,6 +90,8 @@ impl TypeEnum { } } +pub type SharedUnifier = Arc, u32)>>; + pub struct Unifier { unification_table: UnificationTable>, var_id: u32, @@ -100,6 +103,15 @@ impl Unifier { Unifier { unification_table: UnificationTable::new(), var_id: 0 } } + pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier { + let lock = unifier.lock().unwrap(); + Unifier { unification_table: UnificationTable::from_send(&lock.0), var_id: lock.1 } + } + + pub fn get_shared_unifier(&self) -> SharedUnifier { + Arc::new(Mutex::new((self.unification_table.get_send(), self.var_id))) + } + /// Register a type to the unifier. /// Returns a key in the unification_table. pub fn add_ty(&mut self, a: TypeEnum) -> Type { @@ -373,7 +385,11 @@ impl Unifier { (TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => { self.occur_check(a, b)?; for (k, v) in map.borrow().iter() { - let ty = fields.borrow().get(k).copied().ok_or_else(|| format!("No such attribute {}", k))?; + let ty = fields + .borrow() + .get(k) + .copied() + .ok_or_else(|| format!("No such attribute {}", k))?; self.unify(ty, *v)?; } let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); @@ -385,7 +401,11 @@ impl Unifier { let ty = self.get_ty(*ty); if let TObj { fields, .. } = ty.as_ref() { for (k, v) in map.borrow().iter() { - let ty = fields.borrow().get(k).copied().ok_or_else(|| format!("No such attribute {}", k))?; + let ty = fields + .borrow() + .get(k) + .copied() + .ok_or_else(|| format!("No such attribute {}", k))?; if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) { return Err(format!("Cannot access field {} for virtual type", k)); } @@ -659,7 +679,9 @@ impl Unifier { if need_subst { let obj_id = *obj_id; let params = self.subst_map(¶ms, mapping).unwrap_or_else(|| params.clone()); - let fields = self.subst_map(&fields.borrow(), mapping).unwrap_or_else(|| fields.borrow().clone()); + let fields = self + .subst_map(&fields.borrow(), mapping) + .unwrap_or_else(|| fields.borrow().clone()); Some(self.add_ty(TypeEnum::TObj { obj_id, params, fields: fields.into() })) } else { None diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs index 588cf4bb..8c50b5f9 100644 --- a/nac3core/src/typecheck/unification_table.rs +++ b/nac3core/src/typecheck/unification_table.rs @@ -71,13 +71,13 @@ impl UnificationTable> where V: Clone, { - pub fn into_send(self) -> UnificationTable { + pub fn get_send(&self) -> UnificationTable { let values = self.values.iter().map(|v| v.as_ref().clone()).collect(); - UnificationTable { parents: self.parents, ranks: self.ranks, values } + UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values } } - pub fn from_send(table: UnificationTable) -> UnificationTable> { - let values = table.values.into_iter().map(Rc::new).collect(); - UnificationTable { parents: table.parents, ranks: table.ranks, values } + pub fn from_send(table: &UnificationTable) -> UnificationTable> { + let values = table.values.iter().cloned().map(Rc::new).collect(); + UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values } } }