forked from M-Labs/nac3
top-level related changes
This commit is contained in:
parent
d4d12a9d1d
commit
f00c1813e3
@ -1,8 +1,9 @@
|
||||
#![allow(dead_code)]
|
||||
mod function_check;
|
||||
pub mod location;
|
||||
mod magic_methods;
|
||||
pub mod symbol_resolver;
|
||||
pub mod typedef;
|
||||
mod top_level;
|
||||
pub mod type_inferencer;
|
||||
pub mod typedef;
|
||||
mod unification_table;
|
||||
mod function_check;
|
||||
|
@ -1,5 +1,6 @@
|
||||
use super::typedef::Type;
|
||||
use super::location::Location;
|
||||
use super::top_level::DefinitionId;
|
||||
use super::typedef::Type;
|
||||
use rustpython_parser::ast::Expr;
|
||||
|
||||
pub enum SymbolValue<'a> {
|
||||
@ -14,6 +15,7 @@ pub enum SymbolValue<'a> {
|
||||
pub trait SymbolResolver {
|
||||
fn get_symbol_type(&mut self, str: &str) -> Option<Type>;
|
||||
fn parse_type_name(&mut self, expr: &Expr<()>) -> Option<Type>;
|
||||
fn get_function_def(&mut self, str: &str) -> DefinitionId;
|
||||
fn get_symbol_value(&mut self, str: &str) -> Option<SymbolValue>;
|
||||
fn get_symbol_location(&mut self, str: &str) -> Option<Location>;
|
||||
// handle function call etc.
|
||||
|
51
nac3core/src/typecheck/top_level.rs
Normal file
51
nac3core/src/typecheck/top_level.rs
Normal file
@ -0,0 +1,51 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::typedef::{SharedUnifier, Type};
|
||||
use crossbeam::queue::SegQueue;
|
||||
use crossbeam::sync::ShardedLock;
|
||||
use rustpython_parser::ast::Stmt;
|
||||
|
||||
pub struct DefinitionId(usize);
|
||||
|
||||
pub enum TopLevelDef {
|
||||
Class {
|
||||
// object ID used for TypeEnum
|
||||
object_id: usize,
|
||||
// type variables bounded to the class.
|
||||
type_vars: Vec<Type>,
|
||||
// class fields and method signature.
|
||||
fields: Vec<(String, Type)>,
|
||||
// class methods, pointing to the corresponding function definition.
|
||||
methods: Vec<(String, DefinitionId)>,
|
||||
// ancestor classes, including itself.
|
||||
ancestors: Vec<DefinitionId>,
|
||||
},
|
||||
Function {
|
||||
signature: Type,
|
||||
/// Function instance to symbol mapping
|
||||
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||
/// order, including type variables associated with the class.
|
||||
/// Value: function symbol name.
|
||||
instance_to_symbol: HashMap<String, String>,
|
||||
/// Function instances to annotated AST mapping
|
||||
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||
/// order, including type variables associated with the class. Excluding rigid type
|
||||
/// variables.
|
||||
/// Value: AST annotated with types together with a unification table index. Could contain
|
||||
/// rigid type variables that would be substituted when the function is instantiated.
|
||||
instance_to_stmt: HashMap<String, (Stmt<Type>, usize)>,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct CodeGenTask {
|
||||
pub subst: HashMap<usize, Type>,
|
||||
pub symbol_name: String,
|
||||
pub body: Stmt<Type>,
|
||||
pub unifier: SharedUnifier,
|
||||
}
|
||||
|
||||
pub struct TopLevelContext {
|
||||
pub definitions: Vec<ShardedLock<TopLevelDef>>,
|
||||
pub unifiers: Vec<SharedUnifier>,
|
||||
pub codegen_queue: SegQueue<CodeGenTask>,
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
use super::super::location::Location;
|
||||
use super::super::symbol_resolver::*;
|
||||
use super::super::top_level::DefinitionId;
|
||||
use super::super::typedef::*;
|
||||
use super::*;
|
||||
use indoc::indoc;
|
||||
@ -33,6 +34,10 @@ impl SymbolResolver for Resolver {
|
||||
fn get_symbol_location(&mut self, _: &str) -> Option<Location> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_function_def(&mut self, _: &str) -> DefinitionId {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
struct TestEnvironment {
|
||||
@ -48,7 +53,7 @@ struct TestEnvironment {
|
||||
impl TestEnvironment {
|
||||
pub fn basic_test_env() -> TestEnvironment {
|
||||
let mut unifier = Unifier::new();
|
||||
|
||||
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 0,
|
||||
fields: HashMap::new().into(),
|
||||
@ -76,9 +81,9 @@ impl TestEnvironment {
|
||||
});
|
||||
// identifier_mapping.insert("None".into(), none);
|
||||
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||
|
||||
|
||||
set_primirives_magic_methods(&primitives, &mut unifier);
|
||||
|
||||
|
||||
let id_to_name = [
|
||||
(0, "int32".to_string()),
|
||||
(1, "int64".to_string()),
|
||||
@ -95,17 +100,18 @@ impl TestEnvironment {
|
||||
|
||||
let mut identifier_mapping = HashMap::new();
|
||||
identifier_mapping.insert("None".into(), none);
|
||||
|
||||
let resolver =
|
||||
Box::new(Resolver { identifier_mapping: identifier_mapping.clone(), class_names: Default::default() })
|
||||
as Box<dyn SymbolResolver>;
|
||||
|
||||
let resolver = Box::new(Resolver {
|
||||
identifier_mapping: identifier_mapping.clone(),
|
||||
class_names: Default::default(),
|
||||
}) as Box<dyn SymbolResolver>;
|
||||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
function_data: FunctionData {
|
||||
resolver,
|
||||
bound_variables: Vec::new(),
|
||||
return_type: None
|
||||
return_type: None,
|
||||
},
|
||||
primitives,
|
||||
id_to_name,
|
||||
@ -171,7 +177,11 @@ impl TestEnvironment {
|
||||
}));
|
||||
let bar = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 6,
|
||||
fields: [("a".into(), int32), ("b".into(), fun)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
fields: [("a".into(), int32), ("b".into(), fun)]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<HashMap<_, _>>()
|
||||
.into(),
|
||||
params: Default::default(),
|
||||
});
|
||||
identifier_mapping.insert(
|
||||
@ -185,7 +195,11 @@ impl TestEnvironment {
|
||||
|
||||
let bar2 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 7,
|
||||
fields: [("a".into(), bool), ("b".into(), fun)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
fields: [("a".into(), bool), ("b".into(), fun)]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<HashMap<_, _>>()
|
||||
.into(),
|
||||
params: Default::default(),
|
||||
});
|
||||
identifier_mapping.insert(
|
||||
@ -362,13 +376,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
|
||||
g = a // b
|
||||
h = a % b
|
||||
"},
|
||||
[("a", "int32"),
|
||||
("b", "int32"),
|
||||
("c", "int32"),
|
||||
("d", "int32"),
|
||||
("e", "int32"),
|
||||
("f", "float"),
|
||||
("g", "int32"),
|
||||
[("a", "int32"),
|
||||
("b", "int32"),
|
||||
("c", "int32"),
|
||||
("d", "int32"),
|
||||
("e", "int32"),
|
||||
("f", "float"),
|
||||
("g", "int32"),
|
||||
("h", "int32")].iter().cloned().collect()
|
||||
; "int32")]
|
||||
#[test_case(
|
||||
@ -382,13 +396,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
|
||||
g = a // b
|
||||
h = a % b
|
||||
"},
|
||||
[("a", "float"),
|
||||
("b", "float"),
|
||||
("c", "float"),
|
||||
("d", "float"),
|
||||
("e", "float"),
|
||||
("f", "float"),
|
||||
("g", "float"),
|
||||
[("a", "float"),
|
||||
("b", "float"),
|
||||
("c", "float"),
|
||||
("d", "float"),
|
||||
("e", "float"),
|
||||
("f", "float"),
|
||||
("g", "float"),
|
||||
("h", "float")].iter().cloned().collect()
|
||||
; "float"
|
||||
)]
|
||||
@ -407,13 +421,13 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
|
||||
k = a < b
|
||||
l = a != b
|
||||
"},
|
||||
[("a", "int64"),
|
||||
("b", "int64"),
|
||||
("c", "int64"),
|
||||
("d", "int64"),
|
||||
("e", "int64"),
|
||||
("f", "float"),
|
||||
("g", "int64"),
|
||||
[("a", "int64"),
|
||||
("b", "int64"),
|
||||
("c", "int64"),
|
||||
("d", "int64"),
|
||||
("e", "int64"),
|
||||
("f", "float"),
|
||||
("g", "int64"),
|
||||
("h", "int64"),
|
||||
("i", "bool"),
|
||||
("j", "bool"),
|
||||
@ -429,10 +443,10 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
|
||||
d = not a
|
||||
e = a != b
|
||||
"},
|
||||
[("a", "bool"),
|
||||
("b", "bool"),
|
||||
("c", "bool"),
|
||||
("d", "bool"),
|
||||
[("a", "bool"),
|
||||
("b", "bool"),
|
||||
("c", "bool"),
|
||||
("d", "bool"),
|
||||
("e", "bool")].iter().cloned().collect()
|
||||
; "boolean"
|
||||
)]
|
||||
@ -469,4 +483,5 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
|
||||
);
|
||||
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,10 @@
|
||||
use itertools::{chain, zip, Itertools};
|
||||
use std::{borrow::Cow, sync::Arc};
|
||||
use std::borrow::Cow;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::iter::once;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::unification_table::{UnificationKey, UnificationTable};
|
||||
|
||||
@ -89,6 +90,8 @@ impl TypeEnum {
|
||||
}
|
||||
}
|
||||
|
||||
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32)>>;
|
||||
|
||||
pub struct Unifier {
|
||||
unification_table: UnificationTable<Rc<TypeEnum>>,
|
||||
var_id: u32,
|
||||
@ -100,6 +103,15 @@ impl Unifier {
|
||||
Unifier { unification_table: UnificationTable::new(), var_id: 0 }
|
||||
}
|
||||
|
||||
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
|
||||
let lock = unifier.lock().unwrap();
|
||||
Unifier { unification_table: UnificationTable::from_send(&lock.0), var_id: lock.1 }
|
||||
}
|
||||
|
||||
pub fn get_shared_unifier(&self) -> SharedUnifier {
|
||||
Arc::new(Mutex::new((self.unification_table.get_send(), self.var_id)))
|
||||
}
|
||||
|
||||
/// Register a type to the unifier.
|
||||
/// Returns a key in the unification_table.
|
||||
pub fn add_ty(&mut self, a: TypeEnum) -> Type {
|
||||
@ -373,7 +385,11 @@ impl Unifier {
|
||||
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
|
||||
self.occur_check(a, b)?;
|
||||
for (k, v) in map.borrow().iter() {
|
||||
let ty = fields.borrow().get(k).copied().ok_or_else(|| format!("No such attribute {}", k))?;
|
||||
let ty = fields
|
||||
.borrow()
|
||||
.get(k)
|
||||
.copied()
|
||||
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||
self.unify(ty, *v)?;
|
||||
}
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
@ -385,7 +401,11 @@ impl Unifier {
|
||||
let ty = self.get_ty(*ty);
|
||||
if let TObj { fields, .. } = ty.as_ref() {
|
||||
for (k, v) in map.borrow().iter() {
|
||||
let ty = fields.borrow().get(k).copied().ok_or_else(|| format!("No such attribute {}", k))?;
|
||||
let ty = fields
|
||||
.borrow()
|
||||
.get(k)
|
||||
.copied()
|
||||
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
|
||||
return Err(format!("Cannot access field {} for virtual type", k));
|
||||
}
|
||||
@ -659,7 +679,9 @@ impl Unifier {
|
||||
if need_subst {
|
||||
let obj_id = *obj_id;
|
||||
let params = self.subst_map(¶ms, mapping).unwrap_or_else(|| params.clone());
|
||||
let fields = self.subst_map(&fields.borrow(), mapping).unwrap_or_else(|| fields.borrow().clone());
|
||||
let fields = self
|
||||
.subst_map(&fields.borrow(), mapping)
|
||||
.unwrap_or_else(|| fields.borrow().clone());
|
||||
Some(self.add_ty(TypeEnum::TObj { obj_id, params, fields: fields.into() }))
|
||||
} else {
|
||||
None
|
||||
|
@ -71,13 +71,13 @@ impl<V> UnificationTable<Rc<V>>
|
||||
where
|
||||
V: Clone,
|
||||
{
|
||||
pub fn into_send(self) -> UnificationTable<V> {
|
||||
pub fn get_send(&self) -> UnificationTable<V> {
|
||||
let values = self.values.iter().map(|v| v.as_ref().clone()).collect();
|
||||
UnificationTable { parents: self.parents, ranks: self.ranks, values }
|
||||
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values }
|
||||
}
|
||||
|
||||
pub fn from_send(table: UnificationTable<V>) -> UnificationTable<Rc<V>> {
|
||||
let values = table.values.into_iter().map(Rc::new).collect();
|
||||
UnificationTable { parents: table.parents, ranks: table.ranks, values }
|
||||
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
|
||||
let values = table.values.iter().cloned().map(Rc::new).collect();
|
||||
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values }
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user