hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
4 changed files with 43 additions and 33 deletions
Showing only changes of commit 832513e210 - Show all commits

View File

@ -35,13 +35,11 @@ impl<'a> Inferencer<'a> {
) -> Result<(), String> {
// there are some cases where the custom field is None
if let Some(ty) = &expr.custom {
let ty = self.unifier.get_ty(*ty);
let ty = ty.as_ref();
if !ty.is_concrete() {
if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
return Err(format!(
"expected concrete type at {} but got {}",
expr.location,
ty.get_type_name()
self.unifier.get_ty(*ty).get_type_name()
));
}
}

View File

@ -25,14 +25,18 @@ pub struct PrimitiveStore {
pub none: Type,
}
pub struct FunctionData {
pub resolver: Box<dyn SymbolResolver>,
pub return_type: Option<Type>,
pub bound_variables: Vec<Type>,
}
pub struct Inferencer<'a> {
pub resolver: &'a mut Box<dyn SymbolResolver>,
pub function_data: &'a mut FunctionData,
pub unifier: &'a mut Unifier,
pub primitives: &'a PrimitiveStore,
pub virtual_checks: &'a mut Vec<(Type, Type)>,
pub variable_mapping: HashMap<String, Type>,
pub calls: &'a mut Vec<Rc<Call>>,
pub primitives: &'a PrimitiveStore,
pub return_type: Option<Type>,
}
struct NaiveFolder();
@ -65,6 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
None
};
let annotation_type = self
.function_data
.resolver
.parse_type_name(annotation.as_ref())
.ok_or_else(|| "cannot parse type name".to_string())?;
@ -93,7 +98,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
}
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
ast::StmtKind::Break | ast::StmtKind::Continue => {}
ast::StmtKind::Return { value } => match (value, self.return_type) {
ast::StmtKind::Return { value } => match (value, self.function_data.return_type) {
(Some(v), Some(v1)) => {
self.unifier.unify(v.custom.unwrap(), v1)?;
}
@ -171,7 +176,6 @@ impl<'a> Inferencer<'a> {
) -> InferenceResult {
let call =
Rc::new(Call { posargs: params, kwargs: HashMap::new(), ret, fun: RefCell::new(None) });
self.calls.push(call.clone());
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
let fields = once((method, call)).collect();
let record = self.unifier.add_record(fields);
@ -207,13 +211,11 @@ impl<'a> Inferencer<'a> {
variable_mapping.extend(fn_args.iter().cloned());
let ret = self.unifier.get_fresh_var().0;
let mut new_context = Inferencer {
resolver: self.resolver,
function_data: self.function_data,
unifier: self.unifier,
primitives: self.primitives,
virtual_checks: self.virtual_checks,
variable_mapping,
calls: self.calls,
primitives: self.primitives,
return_type: self.return_type,
};
let fun = FunSignature {
args: fn_args
@ -250,13 +252,11 @@ impl<'a> Inferencer<'a> {
}
let variable_mapping = self.variable_mapping.clone();
let mut new_context = Inferencer {
resolver: self.resolver,
function_data: self.function_data,
unifier: self.unifier,
virtual_checks: self.virtual_checks,
variable_mapping,
calls: self.calls,
primitives: self.primitives,
return_type: self.return_type,
};
let elt = new_context.fold_expr(elt)?;
let generator = generators.pop().unwrap();
@ -315,7 +315,7 @@ impl<'a> Inferencer<'a> {
}
let arg0 = self.fold_expr(args.remove(0))?;
let ty = if let Some(arg) = args.pop() {
self.resolver
self.function_data.resolver
.parse_type_name(&arg)
.ok_or_else(|| "error parsing type".to_string())?
} else {
@ -379,7 +379,6 @@ impl<'a> Inferencer<'a> {
fun: RefCell::new(None),
ret,
});
self.calls.push(call.clone());
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
self.unifier.unify(func.custom.unwrap(), call)?;
@ -390,7 +389,7 @@ impl<'a> Inferencer<'a> {
if let Some(ty) = self.variable_mapping.get(id) {
Ok(*ty)
} else {
Ok(self.resolver.get_symbol_type(id).unwrap_or_else(|| {
Ok(self.function_data.resolver.get_symbol_type(id).unwrap_or_else(|| {
let ty = self.unifier.get_fresh_var().0;
self.variable_mapping.insert(id.to_string(), ty);
ty

View File

@ -37,8 +37,7 @@ impl SymbolResolver for Resolver {
struct TestEnvironment {
pub unifier: Unifier,
pub resolver: Box<dyn SymbolResolver>,
pub calls: Vec<Rc<Call>>,
pub function_data: FunctionData,
pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>,
@ -149,24 +148,25 @@ impl TestEnvironment {
TestEnvironment {
unifier,
resolver,
function_data: FunctionData {
resolver,
bound_variables: Vec::new(),
return_type: None
},
primitives,
id_to_name,
identifier_mapping,
calls: Vec::new(),
virtual_checks: Vec::new(),
}
}
fn get_inferencer(&mut self) -> Inferencer {
Inferencer {
resolver: &mut self.resolver,
function_data: &mut self.function_data,
unifier: &mut self.unifier,
variable_mapping: Default::default(),
calls: &mut self.calls,
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
return_type: None,
}
}
}

View File

@ -83,10 +83,6 @@ impl TypeEnum {
TypeEnum::TFunc { .. } => "TFunc",
}
}
pub fn is_concrete(&self) -> bool {
!matches!(self, TypeEnum::TVar { .. })
}
}
pub struct Unifier {
@ -143,6 +139,23 @@ impl Unifier {
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
}
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*;
match &*self.get_ty(a) {
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false,
TList { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => {
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
}
// functions are instantiated for each call sites, so the function type can contain
// type variables.
TFunc { .. } => true,
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
}
}
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
if self.unification_table.unioned(a, b) {
Ok(())
@ -204,7 +217,7 @@ impl Unifier {
}
for v1 in old_range2.iter() {
for v2 in range1.iter() {
if let Ok(result) = self.get_intersection(*v1, *v2){
if let Ok(result) = self.get_intersection(*v1, *v2) {
range2.push(result.unwrap_or(*v2));
}
}
@ -486,7 +499,7 @@ impl Unifier {
Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name()))
}
/// Instantiate a function if it hasn't been instntiated.
/// Instantiate a function if it hasn't been instantiated.
/// Returns Some(T) where T is the instantiated type.
/// Returns None if the function is already instantiated.
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {