forked from M-Labs/nac3
1
0
Fork 0

top-level related changes

This commit is contained in:
pca006132 2021-08-03 13:38:27 +08:00
parent d4d12a9d1d
commit f00c1813e3
6 changed files with 139 additions and 48 deletions

View File

@ -1,8 +1,9 @@
#![allow(dead_code)] #![allow(dead_code)]
mod function_check;
pub mod location; pub mod location;
mod magic_methods; mod magic_methods;
pub mod symbol_resolver; pub mod symbol_resolver;
pub mod typedef; mod top_level;
pub mod type_inferencer; pub mod type_inferencer;
pub mod typedef;
mod unification_table; mod unification_table;
mod function_check;

View File

@ -1,5 +1,6 @@
use super::typedef::Type;
use super::location::Location; use super::location::Location;
use super::top_level::DefinitionId;
use super::typedef::Type;
use rustpython_parser::ast::Expr; use rustpython_parser::ast::Expr;
pub enum SymbolValue<'a> { pub enum SymbolValue<'a> {
@ -14,6 +15,7 @@ pub enum SymbolValue<'a> {
pub trait SymbolResolver { pub trait SymbolResolver {
fn get_symbol_type(&mut self, str: &str) -> Option<Type>; fn get_symbol_type(&mut self, str: &str) -> Option<Type>;
fn parse_type_name(&mut self, expr: &Expr<()>) -> Option<Type>; fn parse_type_name(&mut self, expr: &Expr<()>) -> Option<Type>;
fn get_function_def(&mut self, str: &str) -> DefinitionId;
fn get_symbol_value(&mut self, str: &str) -> Option<SymbolValue>; fn get_symbol_value(&mut self, str: &str) -> Option<SymbolValue>;
fn get_symbol_location(&mut self, str: &str) -> Option<Location>; fn get_symbol_location(&mut self, str: &str) -> Option<Location>;
// handle function call etc. // handle function call etc.

View File

@ -0,0 +1,51 @@
use std::collections::HashMap;
use super::typedef::{SharedUnifier, Type};
use crossbeam::queue::SegQueue;
use crossbeam::sync::ShardedLock;
use rustpython_parser::ast::Stmt;
pub struct DefinitionId(usize);
pub enum TopLevelDef {
Class {
// object ID used for TypeEnum
object_id: usize,
// type variables bounded to the class.
type_vars: Vec<Type>,
// class fields and method signature.
fields: Vec<(String, Type)>,
// class methods, pointing to the corresponding function definition.
methods: Vec<(String, DefinitionId)>,
// ancestor classes, including itself.
ancestors: Vec<DefinitionId>,
},
Function {
signature: Type,
/// Function instance to symbol mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class.
/// Value: function symbol name.
instance_to_symbol: HashMap<String, String>,
/// Function instances to annotated AST mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. Excluding rigid type
/// variables.
/// Value: AST annotated with types together with a unification table index. Could contain
/// rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, (Stmt<Type>, usize)>,
},
}
pub struct CodeGenTask {
pub subst: HashMap<usize, Type>,
pub symbol_name: String,
pub body: Stmt<Type>,
pub unifier: SharedUnifier,
}
pub struct TopLevelContext {
pub definitions: Vec<ShardedLock<TopLevelDef>>,
pub unifiers: Vec<SharedUnifier>,
pub codegen_queue: SegQueue<CodeGenTask>,
}

View File

@ -1,5 +1,6 @@
use super::super::location::Location; use super::super::location::Location;
use super::super::symbol_resolver::*; use super::super::symbol_resolver::*;
use super::super::top_level::DefinitionId;
use super::super::typedef::*; use super::super::typedef::*;
use super::*; use super::*;
use indoc::indoc; use indoc::indoc;
@ -33,6 +34,10 @@ impl SymbolResolver for Resolver {
fn get_symbol_location(&mut self, _: &str) -> Option<Location> { fn get_symbol_location(&mut self, _: &str) -> Option<Location> {
unimplemented!() unimplemented!()
} }
fn get_function_def(&mut self, _: &str) -> DefinitionId {
unimplemented!()
}
} }
struct TestEnvironment { struct TestEnvironment {
@ -48,7 +53,7 @@ struct TestEnvironment {
impl TestEnvironment { impl TestEnvironment {
pub fn basic_test_env() -> TestEnvironment { pub fn basic_test_env() -> TestEnvironment {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: 0, obj_id: 0,
fields: HashMap::new().into(), fields: HashMap::new().into(),
@ -76,9 +81,9 @@ impl TestEnvironment {
}); });
// identifier_mapping.insert("None".into(), none); // identifier_mapping.insert("None".into(), none);
let primitives = PrimitiveStore { int32, int64, float, bool, none }; let primitives = PrimitiveStore { int32, int64, float, bool, none };
set_primirives_magic_methods(&primitives, &mut unifier); set_primirives_magic_methods(&primitives, &mut unifier);
let id_to_name = [ let id_to_name = [
(0, "int32".to_string()), (0, "int32".to_string()),
(1, "int64".to_string()), (1, "int64".to_string()),
@ -95,17 +100,18 @@ impl TestEnvironment {
let mut identifier_mapping = HashMap::new(); let mut identifier_mapping = HashMap::new();
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
let resolver = let resolver = Box::new(Resolver {
Box::new(Resolver { identifier_mapping: identifier_mapping.clone(), class_names: Default::default() }) identifier_mapping: identifier_mapping.clone(),
as Box<dyn SymbolResolver>; class_names: Default::default(),
}) as Box<dyn SymbolResolver>;
TestEnvironment { TestEnvironment {
unifier, unifier,
function_data: FunctionData { function_data: FunctionData {
resolver, resolver,
bound_variables: Vec::new(), bound_variables: Vec::new(),
return_type: None return_type: None,
}, },
primitives, primitives,
id_to_name, id_to_name,
@ -171,7 +177,11 @@ impl TestEnvironment {
})); }));
let bar = unifier.add_ty(TypeEnum::TObj { let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: 6, obj_id: 6,
fields: [("a".into(), int32), ("b".into(), fun)].iter().cloned().collect::<HashMap<_, _>>().into(), fields: [("a".into(), int32), ("b".into(), fun)]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
.into(),
params: Default::default(), params: Default::default(),
}); });
identifier_mapping.insert( identifier_mapping.insert(
@ -185,7 +195,11 @@ impl TestEnvironment {
let bar2 = unifier.add_ty(TypeEnum::TObj { let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: 7, obj_id: 7,
fields: [("a".into(), bool), ("b".into(), fun)].iter().cloned().collect::<HashMap<_, _>>().into(), fields: [("a".into(), bool), ("b".into(), fun)]
.iter()
.cloned()
.collect::<HashMap<_, _>>()
.into(),
params: Default::default(), params: Default::default(),
}); });
identifier_mapping.insert( identifier_mapping.insert(
@ -362,13 +376,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
g = a // b g = a // b
h = a % b h = a % b
"}, "},
[("a", "int32"), [("a", "int32"),
("b", "int32"), ("b", "int32"),
("c", "int32"), ("c", "int32"),
("d", "int32"), ("d", "int32"),
("e", "int32"), ("e", "int32"),
("f", "float"), ("f", "float"),
("g", "int32"), ("g", "int32"),
("h", "int32")].iter().cloned().collect() ("h", "int32")].iter().cloned().collect()
; "int32")] ; "int32")]
#[test_case( #[test_case(
@ -382,13 +396,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
g = a // b g = a // b
h = a % b h = a % b
"}, "},
[("a", "float"), [("a", "float"),
("b", "float"), ("b", "float"),
("c", "float"), ("c", "float"),
("d", "float"), ("d", "float"),
("e", "float"), ("e", "float"),
("f", "float"), ("f", "float"),
("g", "float"), ("g", "float"),
("h", "float")].iter().cloned().collect() ("h", "float")].iter().cloned().collect()
; "float" ; "float"
)] )]
@ -407,13 +421,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
k = a < b k = a < b
l = a != b l = a != b
"}, "},
[("a", "int64"), [("a", "int64"),
("b", "int64"), ("b", "int64"),
("c", "int64"), ("c", "int64"),
("d", "int64"), ("d", "int64"),
("e", "int64"), ("e", "int64"),
("f", "float"), ("f", "float"),
("g", "int64"), ("g", "int64"),
("h", "int64"), ("h", "int64"),
("i", "bool"), ("i", "bool"),
("j", "bool"), ("j", "bool"),
@ -429,10 +443,10 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
d = not a d = not a
e = a != b e = a != b
"}, "},
[("a", "bool"), [("a", "bool"),
("b", "bool"), ("b", "bool"),
("c", "bool"), ("c", "bool"),
("d", "bool"), ("d", "bool"),
("e", "bool")].iter().cloned().collect() ("e", "bool")].iter().cloned().collect()
; "boolean" ; "boolean"
)] )]
@ -469,4 +483,5 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
); );
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
} }
} }

View File

@ -1,9 +1,10 @@
use itertools::{chain, zip, Itertools}; use itertools::{chain, zip, Itertools};
use std::{borrow::Cow, sync::Arc}; use std::borrow::Cow;
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::iter::once; use std::iter::once;
use std::rc::Rc; use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::unification_table::{UnificationKey, UnificationTable}; use super::unification_table::{UnificationKey, UnificationTable};
@ -89,6 +90,8 @@ impl TypeEnum {
} }
} }
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32)>>;
pub struct Unifier { pub struct Unifier {
unification_table: UnificationTable<Rc<TypeEnum>>, unification_table: UnificationTable<Rc<TypeEnum>>,
var_id: u32, var_id: u32,
@ -100,6 +103,15 @@ impl Unifier {
Unifier { unification_table: UnificationTable::new(), var_id: 0 } Unifier { unification_table: UnificationTable::new(), var_id: 0 }
} }
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
let lock = unifier.lock().unwrap();
Unifier { unification_table: UnificationTable::from_send(&lock.0), var_id: lock.1 }
}
pub fn get_shared_unifier(&self) -> SharedUnifier {
Arc::new(Mutex::new((self.unification_table.get_send(), self.var_id)))
}
/// Register a type to the unifier. /// Register a type to the unifier.
/// Returns a key in the unification_table. /// Returns a key in the unification_table.
pub fn add_ty(&mut self, a: TypeEnum) -> Type { pub fn add_ty(&mut self, a: TypeEnum) -> Type {
@ -373,7 +385,11 @@ impl Unifier {
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => { (TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
self.occur_check(a, b)?; self.occur_check(a, b)?;
for (k, v) in map.borrow().iter() { for (k, v) in map.borrow().iter() {
let ty = fields.borrow().get(k).copied().ok_or_else(|| format!("No such attribute {}", k))?; let ty = fields
.borrow()
.get(k)
.copied()
.ok_or_else(|| format!("No such attribute {}", k))?;
self.unify(ty, *v)?; self.unify(ty, *v)?;
} }
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
@ -385,7 +401,11 @@ impl Unifier {
let ty = self.get_ty(*ty); let ty = self.get_ty(*ty);
if let TObj { fields, .. } = ty.as_ref() { if let TObj { fields, .. } = ty.as_ref() {
for (k, v) in map.borrow().iter() { for (k, v) in map.borrow().iter() {
let ty = fields.borrow().get(k).copied().ok_or_else(|| format!("No such attribute {}", k))?; let ty = fields
.borrow()
.get(k)
.copied()
.ok_or_else(|| format!("No such attribute {}", k))?;
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) { if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
return Err(format!("Cannot access field {} for virtual type", k)); return Err(format!("Cannot access field {} for virtual type", k));
} }
@ -659,7 +679,9 @@ impl Unifier {
if need_subst { if need_subst {
let obj_id = *obj_id; let obj_id = *obj_id;
let params = self.subst_map(&params, mapping).unwrap_or_else(|| params.clone()); let params = self.subst_map(&params, mapping).unwrap_or_else(|| params.clone());
let fields = self.subst_map(&fields.borrow(), mapping).unwrap_or_else(|| fields.borrow().clone()); let fields = self
.subst_map(&fields.borrow(), mapping)
.unwrap_or_else(|| fields.borrow().clone());
Some(self.add_ty(TypeEnum::TObj { obj_id, params, fields: fields.into() })) Some(self.add_ty(TypeEnum::TObj { obj_id, params, fields: fields.into() }))
} else { } else {
None None

View File

@ -71,13 +71,13 @@ impl<V> UnificationTable<Rc<V>>
where where
V: Clone, V: Clone,
{ {
pub fn into_send(self) -> UnificationTable<V> { pub fn get_send(&self) -> UnificationTable<V> {
let values = self.values.iter().map(|v| v.as_ref().clone()).collect(); let values = self.values.iter().map(|v| v.as_ref().clone()).collect();
UnificationTable { parents: self.parents, ranks: self.ranks, values } UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values }
} }
pub fn from_send(table: UnificationTable<V>) -> UnificationTable<Rc<V>> { pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
let values = table.values.into_iter().map(Rc::new).collect(); let values = table.values.iter().cloned().map(Rc::new).collect();
UnificationTable { parents: table.parents, ranks: table.ranks, values } UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values }
} }
} }