Compare commits

...

2 Commits

7 changed files with 67 additions and 22 deletions

View File

@ -538,13 +538,25 @@ impl InnerResolver {
let types = types?; let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
} }
(TypeEnum::TObj { params: var_map, fields, .. }, false) => { (TypeEnum::TObj { params, fields, .. }, false) => {
self.pyid_to_type.write().insert(ty_id, extracted_ty); let var_map = params
.iter()
.map(|(id_var, ty)| {
if let TypeEnum::TVar { id, range, name, loc, .. } =
&*unifier.get_ty(*ty)
{
assert_eq!(*id, *id_var);
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
} else {
unreachable!()
}
})
.collect::<HashMap<_, _>>();
let mut instantiate_obj = || { let mut instantiate_obj = || {
// loop through non-function fields of the class to get the instantiated value // loop through non-function fields of the class to get the instantiated value
for field in fields.iter() { for field in fields.iter() {
let name: String = (*field.0).into(); let name: String = (*field.0).into();
if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1 .0) { if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1.0) {
continue; continue;
} else { } else {
let field_data = obj.getattr(&name)?; let field_data = obj.getattr(&name)?;
@ -560,7 +572,7 @@ impl InnerResolver {
} }
}; };
let field_ty = let field_ty =
unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0); unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0);
if let Err(e) = unifier.unify(ty, field_ty) { if let Err(e) = unifier.unify(ty, field_ty) {
// field type mismatch // field type mismatch
return Ok(Err(format!( return Ok(Err(format!(
@ -577,14 +589,10 @@ impl InnerResolver {
return Ok(Err("object is not of concrete type".into())); return Ok(Err("object is not of concrete type".into()));
} }
} }
let extracted_ty = unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty);
Ok(Ok(extracted_ty)) Ok(Ok(extracted_ty))
}; };
let result = instantiate_obj(); instantiate_obj()
// do not cache the type if there are errors
if matches!(result, Err(_) | Ok(Err(_))) {
self.pyid_to_type.write().remove(&ty_id);
}
result
} }
_ => Ok(Ok(extracted_ty)), _ => Ok(Ok(extracted_ty)),
} }

View File

@ -353,6 +353,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
let (unifier, primitives) = &top_level_ctx.unifiers.read()[task.unifier_index]; let (unifier, primitives) = &top_level_ctx.unifiers.read()[task.unifier_index];
(Unifier::from_shared_unifier(unifier), *primitives) (Unifier::from_shared_unifier(unifier), *primitives)
}; };
unifier.top_level = Some(top_level_ctx.clone());
let mut cache = HashMap::new(); let mut cache = HashMap::new();
for (a, b) in task.subst.iter() { for (a, b) in task.subst.iter() {

View File

@ -735,10 +735,11 @@ impl TopLevelComposer {
} }
} }
let mut subst_list = Some(Vec::new());
// unification of previously assigned typevar // unification of previously assigned typevar
let mut unification_helper = |ty, def| { let mut unification_helper = |ty, def| {
let target_ty = let target_ty =
get_type_from_type_annotation_kinds(&temp_def_list, unifier, primitives, &def)?; get_type_from_type_annotation_kinds(&temp_def_list, unifier, primitives, &def, &mut subst_list)?;
unifier.unify(ty, target_ty).map_err(|e| e.to_display(unifier).to_string())?; unifier.unify(ty, target_ty).map_err(|e| e.to_display(unifier).to_string())?;
Ok(()) as Result<(), String> Ok(()) as Result<(), String>
}; };
@ -747,6 +748,29 @@ impl TopLevelComposer {
errors.insert(e); errors.insert(e);
} }
} }
for ty in subst_list.unwrap().into_iter() {
if let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) {
let mut new_fields = HashMap::new();
let mut need_subst = false;
for (name, (ty, mutable)) in fields.iter() {
let substituted = unifier.subst(*ty, params);
need_subst |= substituted.is_some();
new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable));
}
if need_subst {
let new_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
params: params.clone(),
fields: new_fields,
});
if let Err(e) = unifier.unify(ty, new_ty) {
errors.insert(e.to_display(unifier).to_string());
}
}
} else {
unreachable!()
}
}
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors.into_iter().sorted().join("\n----------\n")); return Err(errors.into_iter().sorted().join("\n----------\n"));
} }
@ -867,6 +891,7 @@ impl TopLevelComposer {
unifier, unifier,
primitives_store, primitives_store,
&type_annotation, &type_annotation,
&mut None
)?; )?;
Ok(FuncArg { Ok(FuncArg {
@ -934,6 +959,7 @@ impl TopLevelComposer {
unifier, unifier,
primitives_store, primitives_store,
&return_ty_annotation, &return_ty_annotation,
&mut None
)? )?
} else { } else {
primitives_store.none primitives_store.none
@ -1498,6 +1524,7 @@ impl TopLevelComposer {
unifier, unifier,
primitives_ty, primitives_ty,
&make_self_type_annotation(type_vars, *object_id), &make_self_type_annotation(type_vars, *object_id),
&mut None
)?; )?;
if ancestors if ancestors
.iter() .iter()
@ -1666,6 +1693,7 @@ impl TopLevelComposer {
unifier, unifier,
primitives_ty, primitives_ty,
&ty_ann, &ty_ann,
&mut None
)?; )?;
Some((self_ty, type_vars.clone())) Some((self_ty, type_vars.clone()))
} else { } else {

View File

@ -273,6 +273,7 @@ pub fn get_type_from_type_annotation_kinds(
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
ann: &TypeAnnotation, ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>
) -> Result<Type, String> { ) -> Result<Type, String> {
match ann { match ann {
TypeAnnotation::CustomClass { id: obj_id, params } => { TypeAnnotation::CustomClass { id: obj_id, params } => {
@ -294,6 +295,7 @@ pub fn get_type_from_type_annotation_kinds(
unifier, unifier,
primitives, primitives,
x, x,
subst_list
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
@ -349,12 +351,16 @@ pub fn get_type_from_type_annotation_kinds(
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability)) (*name, (subst_ty, *mutability))
})); }));
let need_subst = !subst.is_empty();
Ok(unifier.add_ty(TypeEnum::TObj { let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id, obj_id: *obj_id,
fields: tobj_fields, fields: tobj_fields,
params: subst, params: subst,
})) });
if need_subst {
subst_list.as_mut().map(|wl| wl.push(ty));
}
Ok(ty)
} }
} else { } else {
unreachable!("should be class def here") unreachable!("should be class def here")
@ -367,6 +373,7 @@ pub fn get_type_from_type_annotation_kinds(
unifier, unifier,
primitives, primitives,
ty.as_ref(), ty.as_ref(),
subst_list
)?; )?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} }
@ -376,6 +383,7 @@ pub fn get_type_from_type_annotation_kinds(
unifier, unifier,
primitives, primitives,
ty.as_ref(), ty.as_ref(),
subst_list
)?; )?;
Ok(unifier.add_ty(TypeEnum::TList { ty })) Ok(unifier.add_ty(TypeEnum::TList { ty }))
} }
@ -383,7 +391,7 @@ pub fn get_type_from_type_annotation_kinds(
let tys = tys let tys = tys
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds(top_level_defs, unifier, primitives, x) get_type_from_type_annotation_kinds(top_level_defs, unifier, primitives, x, subst_list)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys })) Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))

View File

@ -520,7 +520,7 @@ impl Unifier {
match (&*ty_a, &*ty_b) { match (&*ty_a, &*ty_b) {
( (
TVar { fields: fields1, id, name: name1, loc: loc1, .. }, TVar { fields: fields1, id, name: name1, loc: loc1, .. },
TVar { fields: fields2, name: name2, loc: loc2, .. }, TVar { fields: fields2, id: id2, name: name2, loc: loc2, .. },
) => { ) => {
let new_fields = match (fields1, fields2) { let new_fields = match (fields1, fields2) {
(None, None) => None, (None, None) => None,
@ -570,7 +570,7 @@ impl Unifier {
self.unification_table.set_value( self.unification_table.set_value(
a, a,
Rc::new(TypeEnum::TVar { Rc::new(TypeEnum::TVar {
id: *id, id: name1.map_or(*id2, |_| *id),
fields: new_fields, fields: new_fields,
range, range,
name: name1.or(*name2), name: name1.or(*name2),

View File

@ -61,7 +61,7 @@ impl<V> UnificationTable<V> {
if self.ranks[a] < self.ranks[b] { if self.ranks[a] < self.ranks[b] {
std::mem::swap(&mut a, &mut b); std::mem::swap(&mut a, &mut b);
} }
self.log.push(Action::Parent { key: b, original_parent: a }); self.log.push(Action::Parent { key: b, original_parent: self.parents[b] });
self.parents[b] = a; self.parents[b] = a;
if self.ranks[a] == self.ranks[b] { if self.ranks[a] == self.ranks[b] {
self.log.push(Action::Rank { key: a, original_rank: self.ranks[a] }); self.log.push(Action::Rank { key: a, original_rank: self.ranks[a] });
@ -106,7 +106,7 @@ impl<V> UnificationTable<V> {
// a = parent.parent // a = parent.parent
let a = self.parents[parent]; let a = self.parents[parent];
// root.parent = parent.parent // root.parent = parent.parent
self.log.push(Action::Parent { key: root, original_parent: a }); self.log.push(Action::Parent { key: root, original_parent: self.parents[root] });
self.parents[root] = a; self.parents[root] = a;
root = parent; root = parent;
// parent = root.parent // parent = root.parent
@ -145,8 +145,8 @@ impl<V> UnificationTable<V> {
pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) { pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot; let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot discard error"); assert!(self.log.len() >= log_len, "snapshot discard error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if generation == gen), "snapshot discard error"); assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot discard error");
self.log.truncate(log_len - 1); self.log.clear();
} }
} }

View File

@ -85,7 +85,7 @@ fn main() {
Default::default(), Default::default(),
)?; )?;
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, def_list, unifier, primitives, &ty, &mut None
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;