hm-inference #6
|
@ -35,13 +35,11 @@ impl<'a> Inferencer<'a> {
|
|||
) -> Result<(), String> {
|
||||
// there are some cases where the custom field is None
|
||||
if let Some(ty) = &expr.custom {
|
||||
let ty = self.unifier.get_ty(*ty);
|
||||
let ty = ty.as_ref();
|
||||
if !ty.is_concrete() {
|
||||
if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
|
||||
return Err(format!(
|
||||
"expected concrete type at {} but got {}",
|
||||
expr.location,
|
||||
ty.get_type_name()
|
||||
self.unifier.get_ty(*ty).get_type_name()
|
||||
));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,14 +25,18 @@ pub struct PrimitiveStore {
|
|||
pub none: Type,
|
||||
}
|
||||
|
||||
pub struct FunctionData {
|
||||
pub resolver: Box<dyn SymbolResolver>,
|
||||
pub return_type: Option<Type>,
|
||||
pub bound_variables: Vec<Type>,
|
||||
}
|
||||
|
||||
pub struct Inferencer<'a> {
|
||||
pub resolver: &'a mut Box<dyn SymbolResolver>,
|
||||
pub function_data: &'a mut FunctionData,
|
||||
pub unifier: &'a mut Unifier,
|
||||
pub primitives: &'a PrimitiveStore,
|
||||
pub virtual_checks: &'a mut Vec<(Type, Type)>,
|
||||
pub variable_mapping: HashMap<String, Type>,
|
||||
pub calls: &'a mut Vec<Rc<Call>>,
|
||||
pub primitives: &'a PrimitiveStore,
|
||||
pub return_type: Option<Type>,
|
||||
}
|
||||
|
||||
struct NaiveFolder();
|
||||
|
@ -65,6 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
None
|
||||
};
|
||||
let annotation_type = self
|
||||
.function_data
|
||||
.resolver
|
||||
.parse_type_name(annotation.as_ref())
|
||||
.ok_or_else(|| "cannot parse type name".to_string())?;
|
||||
|
@ -93,7 +98,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
}
|
||||
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
||||
ast::StmtKind::Break | ast::StmtKind::Continue => {}
|
||||
ast::StmtKind::Return { value } => match (value, self.return_type) {
|
||||
ast::StmtKind::Return { value } => match (value, self.function_data.return_type) {
|
||||
(Some(v), Some(v1)) => {
|
||||
self.unifier.unify(v.custom.unwrap(), v1)?;
|
||||
}
|
||||
|
@ -171,7 +176,6 @@ impl<'a> Inferencer<'a> {
|
|||
) -> InferenceResult {
|
||||
let call =
|
||||
Rc::new(Call { posargs: params, kwargs: HashMap::new(), ret, fun: RefCell::new(None) });
|
||||
self.calls.push(call.clone());
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||
let fields = once((method, call)).collect();
|
||||
let record = self.unifier.add_record(fields);
|
||||
|
@ -207,13 +211,11 @@ impl<'a> Inferencer<'a> {
|
|||
variable_mapping.extend(fn_args.iter().cloned());
|
||||
let ret = self.unifier.get_fresh_var().0;
|
||||
let mut new_context = Inferencer {
|
||||
resolver: self.resolver,
|
||||
function_data: self.function_data,
|
||||
unifier: self.unifier,
|
||||
primitives: self.primitives,
|
||||
virtual_checks: self.virtual_checks,
|
||||
variable_mapping,
|
||||
calls: self.calls,
|
||||
primitives: self.primitives,
|
||||
return_type: self.return_type,
|
||||
};
|
||||
let fun = FunSignature {
|
||||
args: fn_args
|
||||
|
@ -250,13 +252,11 @@ impl<'a> Inferencer<'a> {
|
|||
}
|
||||
let variable_mapping = self.variable_mapping.clone();
|
||||
let mut new_context = Inferencer {
|
||||
resolver: self.resolver,
|
||||
function_data: self.function_data,
|
||||
unifier: self.unifier,
|
||||
virtual_checks: self.virtual_checks,
|
||||
variable_mapping,
|
||||
calls: self.calls,
|
||||
primitives: self.primitives,
|
||||
return_type: self.return_type,
|
||||
};
|
||||
let elt = new_context.fold_expr(elt)?;
|
||||
let generator = generators.pop().unwrap();
|
||||
|
@ -315,7 +315,7 @@ impl<'a> Inferencer<'a> {
|
|||
}
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let ty = if let Some(arg) = args.pop() {
|
||||
self.resolver
|
||||
self.function_data.resolver
|
||||
.parse_type_name(&arg)
|
||||
.ok_or_else(|| "error parsing type".to_string())?
|
||||
} else {
|
||||
|
@ -379,7 +379,6 @@ impl<'a> Inferencer<'a> {
|
|||
fun: RefCell::new(None),
|
||||
ret,
|
||||
});
|
||||
self.calls.push(call.clone());
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||
self.unifier.unify(func.custom.unwrap(), call)?;
|
||||
|
||||
|
@ -390,7 +389,7 @@ impl<'a> Inferencer<'a> {
|
|||
if let Some(ty) = self.variable_mapping.get(id) {
|
||||
Ok(*ty)
|
||||
} else {
|
||||
Ok(self.resolver.get_symbol_type(id).unwrap_or_else(|| {
|
||||
Ok(self.function_data.resolver.get_symbol_type(id).unwrap_or_else(|| {
|
||||
let ty = self.unifier.get_fresh_var().0;
|
||||
self.variable_mapping.insert(id.to_string(), ty);
|
||||
ty
|
||||
|
|
|
@ -37,8 +37,7 @@ impl SymbolResolver for Resolver {
|
|||
|
||||
struct TestEnvironment {
|
||||
pub unifier: Unifier,
|
||||
pub resolver: Box<dyn SymbolResolver>,
|
||||
pub calls: Vec<Rc<Call>>,
|
||||
pub function_data: FunctionData,
|
||||
pub primitives: PrimitiveStore,
|
||||
pub id_to_name: HashMap<usize, String>,
|
||||
pub identifier_mapping: HashMap<String, Type>,
|
||||
|
@ -149,24 +148,25 @@ impl TestEnvironment {
|
|||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
resolver,
|
||||
function_data: FunctionData {
|
||||
resolver,
|
||||
bound_variables: Vec::new(),
|
||||
return_type: None
|
||||
},
|
||||
primitives,
|
||||
id_to_name,
|
||||
identifier_mapping,
|
||||
calls: Vec::new(),
|
||||
virtual_checks: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_inferencer(&mut self) -> Inferencer {
|
||||
Inferencer {
|
||||
resolver: &mut self.resolver,
|
||||
function_data: &mut self.function_data,
|
||||
unifier: &mut self.unifier,
|
||||
variable_mapping: Default::default(),
|
||||
calls: &mut self.calls,
|
||||
primitives: &mut self.primitives,
|
||||
virtual_checks: &mut self.virtual_checks,
|
||||
return_type: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -83,10 +83,6 @@ impl TypeEnum {
|
|||
TypeEnum::TFunc { .. } => "TFunc",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_concrete(&self) -> bool {
|
||||
!matches!(self, TypeEnum::TVar { .. })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Unifier {
|
||||
|
@ -143,6 +139,23 @@ impl Unifier {
|
|||
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
|
||||
}
|
||||
|
||||
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
||||
use TypeEnum::*;
|
||||
match &*self.get_ty(a) {
|
||||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||
TCall { .. } => false,
|
||||
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||
TObj { params: vars, .. } => {
|
||||
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
||||
}
|
||||
// functions are instantiated for each call sites, so the function type can contain
|
||||
// type variables.
|
||||
TFunc { .. } => true,
|
||||
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||
if self.unification_table.unioned(a, b) {
|
||||
Ok(())
|
||||
|
@ -204,7 +217,7 @@ impl Unifier {
|
|||
}
|
||||
for v1 in old_range2.iter() {
|
||||
for v2 in range1.iter() {
|
||||
if let Ok(result) = self.get_intersection(*v1, *v2){
|
||||
if let Ok(result) = self.get_intersection(*v1, *v2) {
|
||||
range2.push(result.unwrap_or(*v2));
|
||||
}
|
||||
}
|
||||
|
@ -486,7 +499,7 @@ impl Unifier {
|
|||
Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name()))
|
||||
}
|
||||
|
||||
/// Instantiate a function if it hasn't been instntiated.
|
||||
/// Instantiate a function if it hasn't been instantiated.
|
||||
/// Returns Some(T) where T is the instantiated type.
|
||||
/// Returns None if the function is already instantiated.
|
||||
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
|
||||
|
|
Loading…
Reference in New Issue