hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
4 changed files with 43 additions and 33 deletions
Showing only changes of commit 832513e210 - Show all commits

View File

@ -35,13 +35,11 @@ impl<'a> Inferencer<'a> {
) -> Result<(), String> { ) -> Result<(), String> {
// there are some cases where the custom field is None // there are some cases where the custom field is None
if let Some(ty) = &expr.custom { if let Some(ty) = &expr.custom {
let ty = self.unifier.get_ty(*ty); if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
let ty = ty.as_ref();
if !ty.is_concrete() {
return Err(format!( return Err(format!(
"expected concrete type at {} but got {}", "expected concrete type at {} but got {}",
expr.location, expr.location,
ty.get_type_name() self.unifier.get_ty(*ty).get_type_name()
)); ));
} }
} }

View File

@ -25,14 +25,18 @@ pub struct PrimitiveStore {
pub none: Type, pub none: Type,
} }
pub struct FunctionData {
pub resolver: Box<dyn SymbolResolver>,
pub return_type: Option<Type>,
pub bound_variables: Vec<Type>,
}
pub struct Inferencer<'a> { pub struct Inferencer<'a> {
pub resolver: &'a mut Box<dyn SymbolResolver>, pub function_data: &'a mut FunctionData,
pub unifier: &'a mut Unifier, pub unifier: &'a mut Unifier,
pub primitives: &'a PrimitiveStore,
pub virtual_checks: &'a mut Vec<(Type, Type)>, pub virtual_checks: &'a mut Vec<(Type, Type)>,
pub variable_mapping: HashMap<String, Type>, pub variable_mapping: HashMap<String, Type>,
pub calls: &'a mut Vec<Rc<Call>>,
pub primitives: &'a PrimitiveStore,
pub return_type: Option<Type>,
} }
struct NaiveFolder(); struct NaiveFolder();
@ -65,6 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
None None
}; };
let annotation_type = self let annotation_type = self
.function_data
.resolver .resolver
.parse_type_name(annotation.as_ref()) .parse_type_name(annotation.as_ref())
.ok_or_else(|| "cannot parse type name".to_string())?; .ok_or_else(|| "cannot parse type name".to_string())?;
@ -93,7 +98,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
} }
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {} ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
ast::StmtKind::Break | ast::StmtKind::Continue => {} ast::StmtKind::Break | ast::StmtKind::Continue => {}
ast::StmtKind::Return { value } => match (value, self.return_type) { ast::StmtKind::Return { value } => match (value, self.function_data.return_type) {
(Some(v), Some(v1)) => { (Some(v), Some(v1)) => {
self.unifier.unify(v.custom.unwrap(), v1)?; self.unifier.unify(v.custom.unwrap(), v1)?;
} }
@ -171,7 +176,6 @@ impl<'a> Inferencer<'a> {
) -> InferenceResult { ) -> InferenceResult {
let call = let call =
Rc::new(Call { posargs: params, kwargs: HashMap::new(), ret, fun: RefCell::new(None) }); Rc::new(Call { posargs: params, kwargs: HashMap::new(), ret, fun: RefCell::new(None) });
self.calls.push(call.clone());
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
let fields = once((method, call)).collect(); let fields = once((method, call)).collect();
let record = self.unifier.add_record(fields); let record = self.unifier.add_record(fields);
@ -207,13 +211,11 @@ impl<'a> Inferencer<'a> {
variable_mapping.extend(fn_args.iter().cloned()); variable_mapping.extend(fn_args.iter().cloned());
let ret = self.unifier.get_fresh_var().0; let ret = self.unifier.get_fresh_var().0;
let mut new_context = Inferencer { let mut new_context = Inferencer {
resolver: self.resolver, function_data: self.function_data,
unifier: self.unifier, unifier: self.unifier,
primitives: self.primitives,
virtual_checks: self.virtual_checks, virtual_checks: self.virtual_checks,
variable_mapping, variable_mapping,
calls: self.calls,
primitives: self.primitives,
return_type: self.return_type,
}; };
let fun = FunSignature { let fun = FunSignature {
args: fn_args args: fn_args
@ -250,13 +252,11 @@ impl<'a> Inferencer<'a> {
} }
let variable_mapping = self.variable_mapping.clone(); let variable_mapping = self.variable_mapping.clone();
let mut new_context = Inferencer { let mut new_context = Inferencer {
resolver: self.resolver, function_data: self.function_data,
unifier: self.unifier, unifier: self.unifier,
virtual_checks: self.virtual_checks, virtual_checks: self.virtual_checks,
variable_mapping, variable_mapping,
calls: self.calls,
primitives: self.primitives, primitives: self.primitives,
return_type: self.return_type,
}; };
let elt = new_context.fold_expr(elt)?; let elt = new_context.fold_expr(elt)?;
let generator = generators.pop().unwrap(); let generator = generators.pop().unwrap();
@ -315,7 +315,7 @@ impl<'a> Inferencer<'a> {
} }
let arg0 = self.fold_expr(args.remove(0))?; let arg0 = self.fold_expr(args.remove(0))?;
let ty = if let Some(arg) = args.pop() { let ty = if let Some(arg) = args.pop() {
self.resolver self.function_data.resolver
.parse_type_name(&arg) .parse_type_name(&arg)
.ok_or_else(|| "error parsing type".to_string())? .ok_or_else(|| "error parsing type".to_string())?
} else { } else {
@ -379,7 +379,6 @@ impl<'a> Inferencer<'a> {
fun: RefCell::new(None), fun: RefCell::new(None),
ret, ret,
}); });
self.calls.push(call.clone());
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
self.unifier.unify(func.custom.unwrap(), call)?; self.unifier.unify(func.custom.unwrap(), call)?;
@ -390,7 +389,7 @@ impl<'a> Inferencer<'a> {
if let Some(ty) = self.variable_mapping.get(id) { if let Some(ty) = self.variable_mapping.get(id) {
Ok(*ty) Ok(*ty)
} else { } else {
Ok(self.resolver.get_symbol_type(id).unwrap_or_else(|| { Ok(self.function_data.resolver.get_symbol_type(id).unwrap_or_else(|| {
let ty = self.unifier.get_fresh_var().0; let ty = self.unifier.get_fresh_var().0;
self.variable_mapping.insert(id.to_string(), ty); self.variable_mapping.insert(id.to_string(), ty);
ty ty

View File

@ -37,8 +37,7 @@ impl SymbolResolver for Resolver {
struct TestEnvironment { struct TestEnvironment {
pub unifier: Unifier, pub unifier: Unifier,
pub resolver: Box<dyn SymbolResolver>, pub function_data: FunctionData,
pub calls: Vec<Rc<Call>>,
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>, pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>, pub identifier_mapping: HashMap<String, Type>,
@ -149,24 +148,25 @@ impl TestEnvironment {
TestEnvironment { TestEnvironment {
unifier, unifier,
function_data: FunctionData {
resolver, resolver,
bound_variables: Vec::new(),
return_type: None
},
primitives, primitives,
id_to_name, id_to_name,
identifier_mapping, identifier_mapping,
calls: Vec::new(),
virtual_checks: Vec::new(), virtual_checks: Vec::new(),
} }
} }
fn get_inferencer(&mut self) -> Inferencer { fn get_inferencer(&mut self) -> Inferencer {
Inferencer { Inferencer {
resolver: &mut self.resolver, function_data: &mut self.function_data,
unifier: &mut self.unifier, unifier: &mut self.unifier,
variable_mapping: Default::default(), variable_mapping: Default::default(),
calls: &mut self.calls,
primitives: &mut self.primitives, primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks, virtual_checks: &mut self.virtual_checks,
return_type: None,
} }
} }
} }

View File

@ -83,10 +83,6 @@ impl TypeEnum {
TypeEnum::TFunc { .. } => "TFunc", TypeEnum::TFunc { .. } => "TFunc",
} }
} }
pub fn is_concrete(&self) -> bool {
!matches!(self, TypeEnum::TVar { .. })
}
} }
pub struct Unifier { pub struct Unifier {
@ -143,6 +139,23 @@ impl Unifier {
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id) (self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
} }
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*;
match &*self.get_ty(a) {
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false,
TList { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => {
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
}
// functions are instantiated for each call sites, so the function type can contain
// type variables.
TFunc { .. } => true,
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
}
}
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> { pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
if self.unification_table.unioned(a, b) { if self.unification_table.unioned(a, b) {
Ok(()) Ok(())
@ -204,7 +217,7 @@ impl Unifier {
} }
for v1 in old_range2.iter() { for v1 in old_range2.iter() {
for v2 in range1.iter() { for v2 in range1.iter() {
if let Ok(result) = self.get_intersection(*v1, *v2){ if let Ok(result) = self.get_intersection(*v1, *v2) {
range2.push(result.unwrap_or(*v2)); range2.push(result.unwrap_or(*v2));
} }
} }
@ -486,7 +499,7 @@ impl Unifier {
Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name())) Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name()))
} }
/// Instantiate a function if it hasn't been instntiated. /// Instantiate a function if it hasn't been instantiated.
/// Returns Some(T) where T is the instantiated type. /// Returns Some(T) where T is the instantiated type.
/// Returns None if the function is already instantiated. /// Returns None if the function is already instantiated.
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {