hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
5 changed files with 221 additions and 80 deletions
Showing only changes of commit de8b67b605 - Show all commits

View File

@ -1,6 +1,13 @@
use crate::location::Location; use std::cell::RefCell;
use crate::top_level::DefinitionId; use std::collections::HashMap;
use crate::typecheck::typedef::Type;
use crate::top_level::{DefinitionId, TopLevelContext, TopLevelDef};
use crate::typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, Unifier},
};
use crate::{location::Location, typecheck::typedef::TypeEnum};
use itertools::{chain, izip};
use rustpython_parser::ast::Expr; use rustpython_parser::ast::Expr;
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
@ -15,11 +22,121 @@ pub enum SymbolValue {
} }
pub trait SymbolResolver { pub trait SymbolResolver {
fn get_symbol_type(&self, str: &str) -> Option<Type>; fn get_symbol_type(
fn parse_type_name(&self, expr: &Expr<()>) -> Option<Type>; &self,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
str: &str,
) -> Option<Type>;
fn get_identifier_def(&self, str: &str) -> DefinitionId; fn get_identifier_def(&self, str: &str) -> DefinitionId;
fn get_symbol_value(&self, str: &str) -> Option<SymbolValue>; fn get_symbol_value(&self, str: &str) -> Option<SymbolValue>;
fn get_symbol_location(&self, str: &str) -> Option<Location>; fn get_symbol_location(&self, str: &str) -> Option<Location>;
fn get_module_resolver(&self, module_name: &str) -> Option<&dyn SymbolResolver>; // NOTE: for getting imported modules' symbol resolver? // handle function call etc.
// handle function call etc. }
impl dyn SymbolResolver {
pub fn parse_type_annotation<T>(
&self,
top_level: &TopLevelContext,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, String> {
use rustpython_parser::ast::ExprKind::*;
match &expr.node {
Name { id, .. } => match id.as_str() {
"int32" => Ok(primitives.int32),
"int64" => Ok(primitives.int64),
"float" => Ok(primitives.float),
"bool" => Ok(primitives.bool),
"None" => Ok(primitives.none),
x => {
let obj_id = self.get_identifier_def(x);
let defs = top_level.definitions.read();
let def = defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if !type_vars.is_empty() {
return Err(format!(
"Unexpected number of type parameters: expected {} but got 0",
type_vars.len()
));
}
let fields = RefCell::new(
chain(
fields.iter().map(|(k, v)| (k.clone(), *v)),
methods.iter().map(|(k, v, _)| (k.clone(), *v)),
)
.collect(),
);
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields,
params: Default::default(),
}))
} else {
Err("Cannot use function name as type".into())
}
}
},
Subscript { value, slice, .. } => {
if let Name { id, .. } = &value.node {
if id == "virtual" {
let ty =
self.parse_type_annotation(top_level, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|v| {
self.parse_type_annotation(top_level, unifier, primitives, v)
})
.collect::<Result<Vec<_>, _>>()?
} else {
vec![self.parse_type_annotation(top_level, unifier, primitives, slice)?]
};
let obj_id = self.get_identifier_def(id);
let defs = top_level.definitions.read();
let def = defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() {
return Err(format!(
"Unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
types.len()
));
}
let mut subst = HashMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
} else {
unreachable!()
};
subst.insert(id, *ty);
}
let mut fields = fields
.iter()
.map(|(attr, ty)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(attr.clone(), ty)
})
.collect::<HashMap<_, _>>();
fields.extend(methods.iter().map(|(attr, ty, _)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(attr.clone(), ty)
}));
let fields = RefCell::new(fields);
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
} else {
Err("Cannot use function name as type".into())
}
}
} else {
Err("unsupported type expression".into())
}
}
_ => Err("unsupported type expression".into()),
}
}
} }

View File

@ -440,37 +440,3 @@ impl TopLevelComposer {
} }
} }
pub fn parse_type_var<T>(
input: &ast::Expr<T>,
resolver: &dyn SymbolResolver,
) -> Result<Type, String> {
match &input.node {
ast::ExprKind::Name { id, .. } => resolver
.get_symbol_type(id)
.ok_or_else(|| "unknown type variable identifer".to_string()),
ast::ExprKind::Attribute { value, attr, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node {
let next_resolver = resolver
.get_module_resolver(id)
.ok_or_else(|| "unknown imported module".to_string())?;
next_resolver
.get_symbol_type(attr)
.ok_or_else(|| "unknown type variable identifer".to_string())
} else {
unimplemented!()
// recursively resolve attr thing, FIXME: new problem: how do we handle this?
// # A.py
// class A:
// T = TypeVar('T', int, bool)
// pass
// # B.py
// import A
// class B(Generic[A.A.T]):
// pass
}
}
_ => Err("not supported".into()),
}
}

View File

@ -5,7 +5,7 @@ use std::{cell::RefCell, sync::Arc};
use super::magic_methods::*; use super::magic_methods::*;
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}; use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
use crate::symbol_resolver::SymbolResolver; use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
use itertools::izip; use itertools::izip;
use rustpython_parser::ast::{ use rustpython_parser::ast::{
self, self,
@ -44,6 +44,7 @@ pub struct FunctionData {
} }
pub struct Inferencer<'a> { pub struct Inferencer<'a> {
pub top_level: &'a TopLevelContext,
pub function_data: &'a mut FunctionData, pub function_data: &'a mut FunctionData,
pub unifier: &'a mut Unifier, pub unifier: &'a mut Unifier,
pub primitives: &'a PrimitiveStore, pub primitives: &'a PrimitiveStore,
@ -81,11 +82,12 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
} else { } else {
None None
}; };
let annotation_type = self let annotation_type = self.function_data.resolver.parse_type_annotation(
.function_data self.top_level,
.resolver self.unifier,
.parse_type_name(annotation.as_ref()) &self.primitives,
.ok_or_else(|| "cannot parse type name".to_string())?; annotation.as_ref(),
)?;
self.unifier.unify(annotation_type, target.custom.unwrap())?; self.unifier.unify(annotation_type, target.custom.unwrap())?;
let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?); let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?);
Located { Located {
@ -235,6 +237,7 @@ impl<'a> Inferencer<'a> {
primitives: self.primitives, primitives: self.primitives,
virtual_checks: self.virtual_checks, virtual_checks: self.virtual_checks,
calls: self.calls, calls: self.calls,
top_level: self.top_level,
variable_mapping, variable_mapping,
}; };
let fun = FunSignature { let fun = FunSignature {
@ -275,6 +278,7 @@ impl<'a> Inferencer<'a> {
function_data: self.function_data, function_data: self.function_data,
unifier: self.unifier, unifier: self.unifier,
virtual_checks: self.virtual_checks, virtual_checks: self.virtual_checks,
top_level: self.top_level,
variable_mapping, variable_mapping,
primitives: self.primitives, primitives: self.primitives,
calls: self.calls, calls: self.calls,
@ -336,10 +340,12 @@ impl<'a> Inferencer<'a> {
} }
let arg0 = self.fold_expr(args.remove(0))?; let arg0 = self.fold_expr(args.remove(0))?;
let ty = if let Some(arg) = args.pop() { let ty = if let Some(arg) = args.pop() {
self.function_data self.function_data.resolver.parse_type_annotation(
.resolver self.top_level,
.parse_type_name(&arg) self.unifier,
.ok_or_else(|| "error parsing type".to_string())? self.primitives,
&arg,
)?
} else { } else {
self.unifier.get_fresh_var().0 self.unifier.get_fresh_var().0
}; };
@ -412,11 +418,15 @@ impl<'a> Inferencer<'a> {
if let Some(ty) = self.variable_mapping.get(id) { if let Some(ty) = self.variable_mapping.get(id) {
Ok(*ty) Ok(*ty)
} else { } else {
Ok(self.function_data.resolver.get_symbol_type(id).unwrap_or_else(|| { Ok(self
let ty = self.unifier.get_fresh_var().0; .function_data
self.variable_mapping.insert(id.to_string(), ty); .resolver
ty .get_symbol_type(self.unifier, self.primitives, id)
})) .unwrap_or_else(|| {
let ty = self.unifier.get_fresh_var().0;
self.variable_mapping.insert(id.to_string(), ty);
ty
}))
} }
} }

View File

@ -1,30 +1,23 @@
use super::super::typedef::*; use super::super::typedef::*;
use super::*; use super::*;
use crate::location::Location;
use crate::symbol_resolver::*; use crate::symbol_resolver::*;
use crate::top_level::DefinitionId; use crate::top_level::DefinitionId;
use crate::{location::Location, top_level::TopLevelDef};
use indoc::indoc; use indoc::indoc;
use itertools::zip; use itertools::zip;
use rustpython_parser::ast; use parking_lot::RwLock;
use rustpython_parser::parser::parse_program; use rustpython_parser::parser::parse_program;
use test_case::test_case; use test_case::test_case;
struct Resolver { struct Resolver {
identifier_mapping: HashMap<String, Type>, id_to_type: HashMap<String, Type>,
id_to_def: HashMap<String, DefinitionId>,
class_names: HashMap<String, Type>, class_names: HashMap<String, Type>,
} }
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_symbol_type(&self, str: &str) -> Option<Type> { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
self.identifier_mapping.get(str).cloned() self.id_to_type.get(str).cloned()
}
fn parse_type_name(&self, ty: &ast::Expr<()>) -> Option<Type> {
if let ExprKind::Name { id, .. } = &ty.node {
self.class_names.get(id).cloned()
} else {
unimplemented!()
}
} }
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> { fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
@ -35,12 +28,8 @@ impl SymbolResolver for Resolver {
unimplemented!() unimplemented!()
} }
fn get_identifier_def(&self, _: &str) -> DefinitionId { fn get_identifier_def(&self, id: &str) -> DefinitionId {
unimplemented!() self.id_to_def.get(id).cloned().unwrap()
}
fn get_module_resolver(&self, _: &str) -> Option<&dyn SymbolResolver> {
unimplemented!()
} }
} }
@ -52,6 +41,7 @@ struct TestEnvironment {
pub identifier_mapping: HashMap<String, Type>, pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>, pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, Arc<Call>>, pub calls: HashMap<CodeLocation, Arc<Call>>,
pub top_level: TopLevelContext,
} }
impl TestEnvironment { impl TestEnvironment {
@ -101,11 +91,17 @@ impl TestEnvironment {
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
let resolver = Arc::new(Resolver { let resolver = Arc::new(Resolver {
identifier_mapping: identifier_mapping.clone(), id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(),
class_names: Default::default(), class_names: Default::default(),
}) as Arc<dyn SymbolResolver>; }) as Arc<dyn SymbolResolver>;
TestEnvironment { TestEnvironment {
top_level: TopLevelContext {
definitions: Default::default(),
unifiers: Default::default(),
conetexts: Default::default(),
},
unifier, unifier,
function_data: FunctionData { function_data: FunctionData {
resolver, resolver,
@ -123,6 +119,7 @@ impl TestEnvironment {
fn new() -> TestEnvironment { fn new() -> TestEnvironment {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let mut identifier_mapping = HashMap::new(); let mut identifier_mapping = HashMap::new();
let mut top_level_defs = Vec::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: DefinitionId(0),
fields: HashMap::new().into(), fields: HashMap::new().into(),
@ -149,6 +146,16 @@ impl TestEnvironment {
params: HashMap::new(), params: HashMap::new(),
}); });
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
for i in 0..5 {
top_level_defs.push(RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(i),
type_vars: Default::default(),
fields: Default::default(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
}));
}
let primitives = PrimitiveStore { int32, int64, float, bool, none }; let primitives = PrimitiveStore { int32, int64, float, bool, none };
@ -159,6 +166,14 @@ impl TestEnvironment {
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(), fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
params: [(id, v0)].iter().cloned().collect(), params: [(id, v0)].iter().cloned().collect(),
}); });
top_level_defs.push(RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(5),
type_vars: vec![v0],
fields: [("a".into(), v0)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
}));
identifier_mapping.insert( identifier_mapping.insert(
"Foo".into(), "Foo".into(),
@ -183,6 +198,14 @@ impl TestEnvironment {
.into(), .into(),
params: Default::default(), params: Default::default(),
}); });
top_level_defs.push(RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(6),
type_vars: Default::default(),
fields: [("a".into(), int32), ("b".into(), fun)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
}));
identifier_mapping.insert( identifier_mapping.insert(
"Bar".into(), "Bar".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
@ -201,6 +224,14 @@ impl TestEnvironment {
.into(), .into(),
params: Default::default(), params: Default::default(),
}); });
top_level_defs.push(RwLock::new(TopLevelDef::Class {
object_id: DefinitionId(7),
type_vars: Default::default(),
fields: [("a".into(), bool), ("b".into(), fun)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
}));
identifier_mapping.insert( identifier_mapping.insert(
"Bar2".into(), "Bar2".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
@ -225,12 +256,28 @@ impl TestEnvironment {
.cloned() .cloned()
.collect(); .collect();
let resolver = let top_level = TopLevelContext {
Arc::new(Resolver { identifier_mapping: identifier_mapping.clone(), class_names }) definitions: Arc::new(RwLock::new(top_level_defs)),
as Arc<dyn SymbolResolver>; unifiers: Default::default(),
conetexts: Default::default(),
};
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: [
("Foo".into(), DefinitionId(5)),
("Bar".into(), DefinitionId(6)),
("Bar2".into(), DefinitionId(7)),
]
.iter()
.cloned()
.collect(),
class_names,
}) as Arc<dyn SymbolResolver>;
TestEnvironment { TestEnvironment {
unifier, unifier,
top_level,
function_data: FunctionData { function_data: FunctionData {
resolver, resolver,
bound_variables: Vec::new(), bound_variables: Vec::new(),
@ -246,6 +293,7 @@ impl TestEnvironment {
fn get_inferencer(&mut self) -> Inferencer { fn get_inferencer(&mut self) -> Inferencer {
Inferencer { Inferencer {
top_level: &self.top_level,
function_data: &mut self.function_data, function_data: &mut self.function_data,
unifier: &mut self.unifier, unifier: &mut self.unifier,
variable_mapping: Default::default(), variable_mapping: Default::default(),

View File

@ -645,7 +645,7 @@ impl Unifier {
/// If this returns Some(T), T would be the substituted type. /// If this returns Some(T), T would be the substituted type.
/// If this returns None, the result type would be the original type /// If this returns None, the result type would be the original type
/// (no substitution has to be done). /// (no substitution has to be done).
fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> { pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
use TypeVarMeta::*; use TypeVarMeta::*;
let ty = self.unification_table.probe_value(a).clone(); let ty = self.unification_table.probe_value(a).clone();
// this function would only be called when we instantiate functions. // this function would only be called when we instantiate functions.