hm-inference #6
|
@ -1,6 +1,13 @@
|
|||
use crate::location::Location;
|
||||
use crate::top_level::DefinitionId;
|
||||
use crate::typecheck::typedef::Type;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::top_level::{DefinitionId, TopLevelContext, TopLevelDef};
|
||||
use crate::typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{Type, Unifier},
|
||||
};
|
||||
use crate::{location::Location, typecheck::typedef::TypeEnum};
|
||||
use itertools::{chain, izip};
|
||||
use rustpython_parser::ast::Expr;
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
|
@ -15,11 +22,121 @@ pub enum SymbolValue {
|
|||
}
|
||||
|
||||
pub trait SymbolResolver {
|
||||
fn get_symbol_type(&self, str: &str) -> Option<Type>;
|
||||
fn parse_type_name(&self, expr: &Expr<()>) -> Option<Type>;
|
||||
fn get_symbol_type(
|
||||
&self,
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
str: &str,
|
||||
) -> Option<Type>;
|
||||
fn get_identifier_def(&self, str: &str) -> DefinitionId;
|
||||
fn get_symbol_value(&self, str: &str) -> Option<SymbolValue>;
|
||||
fn get_symbol_location(&self, str: &str) -> Option<Location>;
|
||||
fn get_module_resolver(&self, module_name: &str) -> Option<&dyn SymbolResolver>; // NOTE: for getting imported modules' symbol resolver?
|
||||
// handle function call etc.
|
||||
}
|
||||
|
||||
impl dyn SymbolResolver {
|
||||
pub fn parse_type_annotation<T>(
|
||||
&self,
|
||||
top_level: &TopLevelContext,
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
expr: &Expr<T>,
|
||||
) -> Result<Type, String> {
|
||||
use rustpython_parser::ast::ExprKind::*;
|
||||
match &expr.node {
|
||||
Name { id, .. } => match id.as_str() {
|
||||
"int32" => Ok(primitives.int32),
|
||||
"int64" => Ok(primitives.int64),
|
||||
"float" => Ok(primitives.float),
|
||||
"bool" => Ok(primitives.bool),
|
||||
"None" => Ok(primitives.none),
|
||||
x => {
|
||||
let obj_id = self.get_identifier_def(x);
|
||||
let defs = top_level.definitions.read();
|
||||
let def = defs[obj_id.0].read();
|
||||
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
|
||||
if !type_vars.is_empty() {
|
||||
return Err(format!(
|
||||
"Unexpected number of type parameters: expected {} but got 0",
|
||||
type_vars.len()
|
||||
));
|
||||
}
|
||||
let fields = RefCell::new(
|
||||
chain(
|
||||
fields.iter().map(|(k, v)| (k.clone(), *v)),
|
||||
methods.iter().map(|(k, v, _)| (k.clone(), *v)),
|
||||
)
|
||||
.collect(),
|
||||
);
|
||||
Ok(unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id,
|
||||
fields,
|
||||
params: Default::default(),
|
||||
}))
|
||||
} else {
|
||||
Err("Cannot use function name as type".into())
|
||||
}
|
||||
}
|
||||
},
|
||||
Subscript { value, slice, .. } => {
|
||||
if let Name { id, .. } = &value.node {
|
||||
if id == "virtual" {
|
||||
let ty =
|
||||
self.parse_type_annotation(top_level, unifier, primitives, slice)?;
|
||||
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
|
||||
} else {
|
||||
let types = if let Tuple { elts, .. } = &slice.node {
|
||||
elts.iter()
|
||||
.map(|v| {
|
||||
self.parse_type_annotation(top_level, unifier, primitives, v)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
} else {
|
||||
vec![self.parse_type_annotation(top_level, unifier, primitives, slice)?]
|
||||
};
|
||||
|
||||
let obj_id = self.get_identifier_def(id);
|
||||
let defs = top_level.definitions.read();
|
||||
let def = defs[obj_id.0].read();
|
||||
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
|
||||
if types.len() != type_vars.len() {
|
||||
return Err(format!(
|
||||
"Unexpected number of type parameters: expected {} but got {}",
|
||||
type_vars.len(),
|
||||
types.len()
|
||||
));
|
||||
}
|
||||
let mut subst = HashMap::new();
|
||||
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
|
||||
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
|
||||
*id
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
subst.insert(id, *ty);
|
||||
}
|
||||
let mut fields = fields
|
||||
.iter()
|
||||
.map(|(attr, ty)| {
|
||||
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
(attr.clone(), ty)
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
fields.extend(methods.iter().map(|(attr, ty, _)| {
|
||||
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
(attr.clone(), ty)
|
||||
}));
|
||||
let fields = RefCell::new(fields);
|
||||
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
|
||||
} else {
|
||||
Err("Cannot use function name as type".into())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err("unsupported type expression".into())
|
||||
}
|
||||
}
|
||||
_ => Err("unsupported type expression".into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -440,37 +440,3 @@ impl TopLevelComposer {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn parse_type_var<T>(
|
||||
input: &ast::Expr<T>,
|
||||
resolver: &dyn SymbolResolver,
|
||||
) -> Result<Type, String> {
|
||||
match &input.node {
|
||||
ast::ExprKind::Name { id, .. } => resolver
|
||||
.get_symbol_type(id)
|
||||
.ok_or_else(|| "unknown type variable identifer".to_string()),
|
||||
|
||||
ast::ExprKind::Attribute { value, attr, .. } => {
|
||||
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||
let next_resolver = resolver
|
||||
.get_module_resolver(id)
|
||||
.ok_or_else(|| "unknown imported module".to_string())?;
|
||||
next_resolver
|
||||
.get_symbol_type(attr)
|
||||
.ok_or_else(|| "unknown type variable identifer".to_string())
|
||||
} else {
|
||||
unimplemented!()
|
||||
// recursively resolve attr thing, FIXME: new problem: how do we handle this?
|
||||
// # A.py
|
||||
// class A:
|
||||
// T = TypeVar('T', int, bool)
|
||||
// pass
|
||||
// # B.py
|
||||
// import A
|
||||
// class B(Generic[A.A.T]):
|
||||
// pass
|
||||
}
|
||||
}
|
||||
|
||||
_ => Err("not supported".into()),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::{cell::RefCell, sync::Arc};
|
|||
|
||||
use super::magic_methods::*;
|
||||
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
|
||||
use crate::symbol_resolver::SymbolResolver;
|
||||
use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext};
|
||||
use itertools::izip;
|
||||
use rustpython_parser::ast::{
|
||||
self,
|
||||
|
@ -44,6 +44,7 @@ pub struct FunctionData {
|
|||
}
|
||||
|
||||
pub struct Inferencer<'a> {
|
||||
pub top_level: &'a TopLevelContext,
|
||||
pub function_data: &'a mut FunctionData,
|
||||
pub unifier: &'a mut Unifier,
|
||||
pub primitives: &'a PrimitiveStore,
|
||||
|
@ -81,11 +82,12 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
} else {
|
||||
None
|
||||
};
|
||||
let annotation_type = self
|
||||
.function_data
|
||||
.resolver
|
||||
.parse_type_name(annotation.as_ref())
|
||||
.ok_or_else(|| "cannot parse type name".to_string())?;
|
||||
let annotation_type = self.function_data.resolver.parse_type_annotation(
|
||||
self.top_level,
|
||||
self.unifier,
|
||||
&self.primitives,
|
||||
annotation.as_ref(),
|
||||
)?;
|
||||
self.unifier.unify(annotation_type, target.custom.unwrap())?;
|
||||
let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?);
|
||||
Located {
|
||||
|
@ -235,6 +237,7 @@ impl<'a> Inferencer<'a> {
|
|||
primitives: self.primitives,
|
||||
virtual_checks: self.virtual_checks,
|
||||
calls: self.calls,
|
||||
top_level: self.top_level,
|
||||
variable_mapping,
|
||||
};
|
||||
let fun = FunSignature {
|
||||
|
@ -275,6 +278,7 @@ impl<'a> Inferencer<'a> {
|
|||
function_data: self.function_data,
|
||||
unifier: self.unifier,
|
||||
virtual_checks: self.virtual_checks,
|
||||
top_level: self.top_level,
|
||||
variable_mapping,
|
||||
primitives: self.primitives,
|
||||
calls: self.calls,
|
||||
|
@ -336,10 +340,12 @@ impl<'a> Inferencer<'a> {
|
|||
}
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let ty = if let Some(arg) = args.pop() {
|
||||
self.function_data
|
||||
.resolver
|
||||
.parse_type_name(&arg)
|
||||
.ok_or_else(|| "error parsing type".to_string())?
|
||||
self.function_data.resolver.parse_type_annotation(
|
||||
self.top_level,
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
&arg,
|
||||
)?
|
||||
} else {
|
||||
self.unifier.get_fresh_var().0
|
||||
};
|
||||
|
@ -412,7 +418,11 @@ impl<'a> Inferencer<'a> {
|
|||
if let Some(ty) = self.variable_mapping.get(id) {
|
||||
Ok(*ty)
|
||||
} else {
|
||||
Ok(self.function_data.resolver.get_symbol_type(id).unwrap_or_else(|| {
|
||||
Ok(self
|
||||
.function_data
|
||||
.resolver
|
||||
.get_symbol_type(self.unifier, self.primitives, id)
|
||||
.unwrap_or_else(|| {
|
||||
let ty = self.unifier.get_fresh_var().0;
|
||||
self.variable_mapping.insert(id.to_string(), ty);
|
||||
ty
|
||||
|
|
|
@ -1,30 +1,23 @@
|
|||
use super::super::typedef::*;
|
||||
use super::*;
|
||||
use crate::location::Location;
|
||||
use crate::symbol_resolver::*;
|
||||
use crate::top_level::DefinitionId;
|
||||
use crate::{location::Location, top_level::TopLevelDef};
|
||||
use indoc::indoc;
|
||||
use itertools::zip;
|
||||
use rustpython_parser::ast;
|
||||
use parking_lot::RwLock;
|
||||
use rustpython_parser::parser::parse_program;
|
||||
use test_case::test_case;
|
||||
|
||||
struct Resolver {
|
||||
identifier_mapping: HashMap<String, Type>,
|
||||
id_to_type: HashMap<String, Type>,
|
||||
id_to_def: HashMap<String, DefinitionId>,
|
||||
class_names: HashMap<String, Type>,
|
||||
}
|
||||
|
||||
impl SymbolResolver for Resolver {
|
||||
fn get_symbol_type(&self, str: &str) -> Option<Type> {
|
||||
self.identifier_mapping.get(str).cloned()
|
||||
}
|
||||
|
||||
fn parse_type_name(&self, ty: &ast::Expr<()>) -> Option<Type> {
|
||||
if let ExprKind::Name { id, .. } = &ty.node {
|
||||
self.class_names.get(id).cloned()
|
||||
} else {
|
||||
unimplemented!()
|
||||
}
|
||||
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
|
||||
self.id_to_type.get(str).cloned()
|
||||
}
|
||||
|
||||
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
|
||||
|
@ -35,12 +28,8 @@ impl SymbolResolver for Resolver {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_identifier_def(&self, _: &str) -> DefinitionId {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_module_resolver(&self, _: &str) -> Option<&dyn SymbolResolver> {
|
||||
unimplemented!()
|
||||
fn get_identifier_def(&self, id: &str) -> DefinitionId {
|
||||
self.id_to_def.get(id).cloned().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -52,6 +41,7 @@ struct TestEnvironment {
|
|||
pub identifier_mapping: HashMap<String, Type>,
|
||||
pub virtual_checks: Vec<(Type, Type)>,
|
||||
pub calls: HashMap<CodeLocation, Arc<Call>>,
|
||||
pub top_level: TopLevelContext,
|
||||
}
|
||||
|
||||
impl TestEnvironment {
|
||||
|
@ -101,11 +91,17 @@ impl TestEnvironment {
|
|||
identifier_mapping.insert("None".into(), none);
|
||||
|
||||
let resolver = Arc::new(Resolver {
|
||||
identifier_mapping: identifier_mapping.clone(),
|
||||
id_to_type: identifier_mapping.clone(),
|
||||
id_to_def: Default::default(),
|
||||
class_names: Default::default(),
|
||||
}) as Arc<dyn SymbolResolver>;
|
||||
|
||||
TestEnvironment {
|
||||
top_level: TopLevelContext {
|
||||
definitions: Default::default(),
|
||||
unifiers: Default::default(),
|
||||
conetexts: Default::default(),
|
||||
},
|
||||
unifier,
|
||||
function_data: FunctionData {
|
||||
resolver,
|
||||
|
@ -123,6 +119,7 @@ impl TestEnvironment {
|
|||
fn new() -> TestEnvironment {
|
||||
let mut unifier = Unifier::new();
|
||||
let mut identifier_mapping = HashMap::new();
|
||||
let mut top_level_defs = Vec::new();
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
fields: HashMap::new().into(),
|
||||
|
@ -149,6 +146,16 @@ impl TestEnvironment {
|
|||
params: HashMap::new(),
|
||||
});
|
||||
identifier_mapping.insert("None".into(), none);
|
||||
for i in 0..5 {
|
||||
top_level_defs.push(RwLock::new(TopLevelDef::Class {
|
||||
object_id: DefinitionId(i),
|
||||
type_vars: Default::default(),
|
||||
fields: Default::default(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
}));
|
||||
}
|
||||
|
||||
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||
|
||||
|
@ -159,6 +166,14 @@ impl TestEnvironment {
|
|||
fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
|
||||
params: [(id, v0)].iter().cloned().collect(),
|
||||
});
|
||||
top_level_defs.push(RwLock::new(TopLevelDef::Class {
|
||||
object_id: DefinitionId(5),
|
||||
type_vars: vec![v0],
|
||||
fields: [("a".into(), v0)].into(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
}));
|
||||
|
||||
identifier_mapping.insert(
|
||||
"Foo".into(),
|
||||
|
@ -183,6 +198,14 @@ impl TestEnvironment {
|
|||
.into(),
|
||||
params: Default::default(),
|
||||
});
|
||||
top_level_defs.push(RwLock::new(TopLevelDef::Class {
|
||||
object_id: DefinitionId(6),
|
||||
type_vars: Default::default(),
|
||||
fields: [("a".into(), int32), ("b".into(), fun)].into(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
}));
|
||||
identifier_mapping.insert(
|
||||
"Bar".into(),
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
|
@ -201,6 +224,14 @@ impl TestEnvironment {
|
|||
.into(),
|
||||
params: Default::default(),
|
||||
});
|
||||
top_level_defs.push(RwLock::new(TopLevelDef::Class {
|
||||
object_id: DefinitionId(7),
|
||||
type_vars: Default::default(),
|
||||
fields: [("a".into(), bool), ("b".into(), fun)].into(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
}));
|
||||
identifier_mapping.insert(
|
||||
"Bar2".into(),
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
|
@ -225,12 +256,28 @@ impl TestEnvironment {
|
|||
.cloned()
|
||||
.collect();
|
||||
|
||||
let resolver =
|
||||
Arc::new(Resolver { identifier_mapping: identifier_mapping.clone(), class_names })
|
||||
as Arc<dyn SymbolResolver>;
|
||||
let top_level = TopLevelContext {
|
||||
definitions: Arc::new(RwLock::new(top_level_defs)),
|
||||
unifiers: Default::default(),
|
||||
conetexts: Default::default(),
|
||||
};
|
||||
|
||||
let resolver = Arc::new(Resolver {
|
||||
id_to_type: identifier_mapping.clone(),
|
||||
id_to_def: [
|
||||
("Foo".into(), DefinitionId(5)),
|
||||
("Bar".into(), DefinitionId(6)),
|
||||
("Bar2".into(), DefinitionId(7)),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect(),
|
||||
class_names,
|
||||
}) as Arc<dyn SymbolResolver>;
|
||||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
top_level,
|
||||
function_data: FunctionData {
|
||||
resolver,
|
||||
bound_variables: Vec::new(),
|
||||
|
@ -246,6 +293,7 @@ impl TestEnvironment {
|
|||
|
||||
fn get_inferencer(&mut self) -> Inferencer {
|
||||
Inferencer {
|
||||
top_level: &self.top_level,
|
||||
function_data: &mut self.function_data,
|
||||
unifier: &mut self.unifier,
|
||||
variable_mapping: Default::default(),
|
||||
|
|
|
@ -645,7 +645,7 @@ impl Unifier {
|
|||
/// If this returns Some(T), T would be the substituted type.
|
||||
/// If this returns None, the result type would be the original type
|
||||
/// (no substitution has to be done).
|
||||
fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
|
||||
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
|
||||
use TypeVarMeta::*;
|
||||
let ty = self.unification_table.probe_value(a).clone();
|
||||
// this function would only be called when we instantiate functions.
|
||||
|
|
Loading…
Reference in New Issue