hm-inference #6
|
@ -35,13 +35,11 @@ impl<'a> Inferencer<'a> {
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
// there are some cases where the custom field is None
|
// there are some cases where the custom field is None
|
||||||
if let Some(ty) = &expr.custom {
|
if let Some(ty) = &expr.custom {
|
||||||
let ty = self.unifier.get_ty(*ty);
|
if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
|
||||||
let ty = ty.as_ref();
|
|
||||||
if !ty.is_concrete() {
|
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"expected concrete type at {} but got {}",
|
"expected concrete type at {} but got {}",
|
||||||
expr.location,
|
expr.location,
|
||||||
ty.get_type_name()
|
self.unifier.get_ty(*ty).get_type_name()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,14 +25,18 @@ pub struct PrimitiveStore {
|
||||||
pub none: Type,
|
pub none: Type,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct FunctionData {
|
||||||
|
pub resolver: Box<dyn SymbolResolver>,
|
||||||
|
pub return_type: Option<Type>,
|
||||||
|
pub bound_variables: Vec<Type>,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Inferencer<'a> {
|
pub struct Inferencer<'a> {
|
||||||
pub resolver: &'a mut Box<dyn SymbolResolver>,
|
pub function_data: &'a mut FunctionData,
|
||||||
pub unifier: &'a mut Unifier,
|
pub unifier: &'a mut Unifier,
|
||||||
|
pub primitives: &'a PrimitiveStore,
|
||||||
pub virtual_checks: &'a mut Vec<(Type, Type)>,
|
pub virtual_checks: &'a mut Vec<(Type, Type)>,
|
||||||
pub variable_mapping: HashMap<String, Type>,
|
pub variable_mapping: HashMap<String, Type>,
|
||||||
pub calls: &'a mut Vec<Rc<Call>>,
|
|
||||||
pub primitives: &'a PrimitiveStore,
|
|
||||||
pub return_type: Option<Type>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NaiveFolder();
|
struct NaiveFolder();
|
||||||
|
@ -65,6 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let annotation_type = self
|
let annotation_type = self
|
||||||
|
.function_data
|
||||||
.resolver
|
.resolver
|
||||||
.parse_type_name(annotation.as_ref())
|
.parse_type_name(annotation.as_ref())
|
||||||
.ok_or_else(|| "cannot parse type name".to_string())?;
|
.ok_or_else(|| "cannot parse type name".to_string())?;
|
||||||
|
@ -93,7 +98,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
||||||
ast::StmtKind::Break | ast::StmtKind::Continue => {}
|
ast::StmtKind::Break | ast::StmtKind::Continue => {}
|
||||||
ast::StmtKind::Return { value } => match (value, self.return_type) {
|
ast::StmtKind::Return { value } => match (value, self.function_data.return_type) {
|
||||||
(Some(v), Some(v1)) => {
|
(Some(v), Some(v1)) => {
|
||||||
self.unifier.unify(v.custom.unwrap(), v1)?;
|
self.unifier.unify(v.custom.unwrap(), v1)?;
|
||||||
}
|
}
|
||||||
|
@ -171,7 +176,6 @@ impl<'a> Inferencer<'a> {
|
||||||
) -> InferenceResult {
|
) -> InferenceResult {
|
||||||
let call =
|
let call =
|
||||||
Rc::new(Call { posargs: params, kwargs: HashMap::new(), ret, fun: RefCell::new(None) });
|
Rc::new(Call { posargs: params, kwargs: HashMap::new(), ret, fun: RefCell::new(None) });
|
||||||
self.calls.push(call.clone());
|
|
||||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||||
let fields = once((method, call)).collect();
|
let fields = once((method, call)).collect();
|
||||||
let record = self.unifier.add_record(fields);
|
let record = self.unifier.add_record(fields);
|
||||||
|
@ -207,13 +211,11 @@ impl<'a> Inferencer<'a> {
|
||||||
variable_mapping.extend(fn_args.iter().cloned());
|
variable_mapping.extend(fn_args.iter().cloned());
|
||||||
let ret = self.unifier.get_fresh_var().0;
|
let ret = self.unifier.get_fresh_var().0;
|
||||||
let mut new_context = Inferencer {
|
let mut new_context = Inferencer {
|
||||||
resolver: self.resolver,
|
function_data: self.function_data,
|
||||||
unifier: self.unifier,
|
unifier: self.unifier,
|
||||||
|
primitives: self.primitives,
|
||||||
virtual_checks: self.virtual_checks,
|
virtual_checks: self.virtual_checks,
|
||||||
variable_mapping,
|
variable_mapping,
|
||||||
calls: self.calls,
|
|
||||||
primitives: self.primitives,
|
|
||||||
return_type: self.return_type,
|
|
||||||
};
|
};
|
||||||
let fun = FunSignature {
|
let fun = FunSignature {
|
||||||
args: fn_args
|
args: fn_args
|
||||||
|
@ -250,13 +252,11 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
let variable_mapping = self.variable_mapping.clone();
|
let variable_mapping = self.variable_mapping.clone();
|
||||||
let mut new_context = Inferencer {
|
let mut new_context = Inferencer {
|
||||||
resolver: self.resolver,
|
function_data: self.function_data,
|
||||||
unifier: self.unifier,
|
unifier: self.unifier,
|
||||||
virtual_checks: self.virtual_checks,
|
virtual_checks: self.virtual_checks,
|
||||||
variable_mapping,
|
variable_mapping,
|
||||||
calls: self.calls,
|
|
||||||
primitives: self.primitives,
|
primitives: self.primitives,
|
||||||
return_type: self.return_type,
|
|
||||||
};
|
};
|
||||||
let elt = new_context.fold_expr(elt)?;
|
let elt = new_context.fold_expr(elt)?;
|
||||||
let generator = generators.pop().unwrap();
|
let generator = generators.pop().unwrap();
|
||||||
|
@ -315,7 +315,7 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
let arg0 = self.fold_expr(args.remove(0))?;
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
let ty = if let Some(arg) = args.pop() {
|
let ty = if let Some(arg) = args.pop() {
|
||||||
self.resolver
|
self.function_data.resolver
|
||||||
.parse_type_name(&arg)
|
.parse_type_name(&arg)
|
||||||
.ok_or_else(|| "error parsing type".to_string())?
|
.ok_or_else(|| "error parsing type".to_string())?
|
||||||
} else {
|
} else {
|
||||||
|
@ -379,7 +379,6 @@ impl<'a> Inferencer<'a> {
|
||||||
fun: RefCell::new(None),
|
fun: RefCell::new(None),
|
||||||
ret,
|
ret,
|
||||||
});
|
});
|
||||||
self.calls.push(call.clone());
|
|
||||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||||
self.unifier.unify(func.custom.unwrap(), call)?;
|
self.unifier.unify(func.custom.unwrap(), call)?;
|
||||||
|
|
||||||
|
@ -390,7 +389,7 @@ impl<'a> Inferencer<'a> {
|
||||||
if let Some(ty) = self.variable_mapping.get(id) {
|
if let Some(ty) = self.variable_mapping.get(id) {
|
||||||
Ok(*ty)
|
Ok(*ty)
|
||||||
} else {
|
} else {
|
||||||
Ok(self.resolver.get_symbol_type(id).unwrap_or_else(|| {
|
Ok(self.function_data.resolver.get_symbol_type(id).unwrap_or_else(|| {
|
||||||
let ty = self.unifier.get_fresh_var().0;
|
let ty = self.unifier.get_fresh_var().0;
|
||||||
self.variable_mapping.insert(id.to_string(), ty);
|
self.variable_mapping.insert(id.to_string(), ty);
|
||||||
ty
|
ty
|
||||||
|
|
|
@ -37,8 +37,7 @@ impl SymbolResolver for Resolver {
|
||||||
|
|
||||||
struct TestEnvironment {
|
struct TestEnvironment {
|
||||||
pub unifier: Unifier,
|
pub unifier: Unifier,
|
||||||
pub resolver: Box<dyn SymbolResolver>,
|
pub function_data: FunctionData,
|
||||||
pub calls: Vec<Rc<Call>>,
|
|
||||||
pub primitives: PrimitiveStore,
|
pub primitives: PrimitiveStore,
|
||||||
pub id_to_name: HashMap<usize, String>,
|
pub id_to_name: HashMap<usize, String>,
|
||||||
pub identifier_mapping: HashMap<String, Type>,
|
pub identifier_mapping: HashMap<String, Type>,
|
||||||
|
@ -149,24 +148,25 @@ impl TestEnvironment {
|
||||||
|
|
||||||
TestEnvironment {
|
TestEnvironment {
|
||||||
unifier,
|
unifier,
|
||||||
resolver,
|
function_data: FunctionData {
|
||||||
|
resolver,
|
||||||
|
bound_variables: Vec::new(),
|
||||||
|
return_type: None
|
||||||
|
},
|
||||||
primitives,
|
primitives,
|
||||||
id_to_name,
|
id_to_name,
|
||||||
identifier_mapping,
|
identifier_mapping,
|
||||||
calls: Vec::new(),
|
|
||||||
virtual_checks: Vec::new(),
|
virtual_checks: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_inferencer(&mut self) -> Inferencer {
|
fn get_inferencer(&mut self) -> Inferencer {
|
||||||
Inferencer {
|
Inferencer {
|
||||||
resolver: &mut self.resolver,
|
function_data: &mut self.function_data,
|
||||||
unifier: &mut self.unifier,
|
unifier: &mut self.unifier,
|
||||||
variable_mapping: Default::default(),
|
variable_mapping: Default::default(),
|
||||||
calls: &mut self.calls,
|
|
||||||
primitives: &mut self.primitives,
|
primitives: &mut self.primitives,
|
||||||
virtual_checks: &mut self.virtual_checks,
|
virtual_checks: &mut self.virtual_checks,
|
||||||
return_type: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,10 +83,6 @@ impl TypeEnum {
|
||||||
TypeEnum::TFunc { .. } => "TFunc",
|
TypeEnum::TFunc { .. } => "TFunc",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_concrete(&self) -> bool {
|
|
||||||
!matches!(self, TypeEnum::TVar { .. })
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Unifier {
|
pub struct Unifier {
|
||||||
|
@ -143,6 +139,23 @@ impl Unifier {
|
||||||
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
|
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
||||||
|
use TypeEnum::*;
|
||||||
|
match &*self.get_ty(a) {
|
||||||
|
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||||
|
TCall { .. } => false,
|
||||||
|
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||||
|
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||||
|
TObj { params: vars, .. } => {
|
||||||
|
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
||||||
|
}
|
||||||
|
// functions are instantiated for each call sites, so the function type can contain
|
||||||
|
// type variables.
|
||||||
|
TFunc { .. } => true,
|
||||||
|
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
|
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||||
if self.unification_table.unioned(a, b) {
|
if self.unification_table.unioned(a, b) {
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -204,7 +217,7 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
for v1 in old_range2.iter() {
|
for v1 in old_range2.iter() {
|
||||||
for v2 in range1.iter() {
|
for v2 in range1.iter() {
|
||||||
if let Ok(result) = self.get_intersection(*v1, *v2){
|
if let Ok(result) = self.get_intersection(*v1, *v2) {
|
||||||
range2.push(result.unwrap_or(*v2));
|
range2.push(result.unwrap_or(*v2));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -486,7 +499,7 @@ impl Unifier {
|
||||||
Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name()))
|
Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Instantiate a function if it hasn't been instntiated.
|
/// Instantiate a function if it hasn't been instantiated.
|
||||||
/// Returns Some(T) where T is the instantiated type.
|
/// Returns Some(T) where T is the instantiated type.
|
||||||
/// Returns None if the function is already instantiated.
|
/// Returns None if the function is already instantiated.
|
||||||
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
|
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
|
||||||
|
|
Loading…
Reference in New Issue