1
0
forked from M-Labs/nac3

remove self kind and extra primitive info in the return of top level composer constructor, adding some helper function for type annotation

This commit is contained in:
ychenfo 2021-08-25 13:39:55 +08:00
parent e2b11c3fee
commit 862d205f67
3 changed files with 132 additions and 84 deletions

View File

@ -1,6 +1,4 @@
use std::borrow::BorrowMut; use std::{collections::{HashMap, HashSet}, sync::Arc, ops::{Deref, DerefMut}, borrow::BorrowMut};
use std::ops::{Deref, DerefMut};
use std::{collections::HashMap, collections::HashSet, sync::Arc};
use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier}; use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier};
@ -87,6 +85,12 @@ pub struct TopLevelComposer {
pub keyword_list: Vec<String>, pub keyword_list: Vec<String>,
} }
impl Default for TopLevelComposer {
fn default() -> Self {
Self::new()
}
}
impl TopLevelComposer { impl TopLevelComposer {
pub fn make_top_level_context(self) -> TopLevelContext { pub fn make_top_level_context(self) -> TopLevelContext {
TopLevelContext { TopLevelContext {
@ -134,7 +138,7 @@ impl TopLevelComposer {
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol /// return a composer and things to make a "primitive" symbol resolver, so that the symbol
/// resolver can later figure out primitive type definitions when passed a primitive type name /// resolver can later figure out primitive type definitions when passed a primitive type name
// TODO: add list and tuples? // TODO: add list and tuples?
pub fn new() -> (Vec<(String, DefinitionId, Type)>, Self) { pub fn new() -> Self {
let primitives = Self::make_primitives(); let primitives = Self::make_primitives();
let top_level_def_list = vec![ let top_level_def_list = vec![
@ -147,7 +151,7 @@ impl TopLevelComposer {
let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None]; let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None];
let composer = TopLevelComposer { TopLevelComposer {
definition_ast_list: izip!(top_level_def_list, ast_list).collect_vec(), definition_ast_list: izip!(top_level_def_list, ast_list).collect_vec(),
primitives_ty: primitives.0, primitives_ty: primitives.0,
unifier: primitives.1, unifier: primitives.1,
@ -165,17 +169,7 @@ impl TopLevelComposer {
"none".into(), "none".into(),
"None".into(), "None".into(),
], ],
}; }
(
vec![
("int32".into(), DefinitionId(0), composer.primitives_ty.int32),
("int64".into(), DefinitionId(1), composer.primitives_ty.int64),
("float".into(), DefinitionId(2), composer.primitives_ty.float),
("bool".into(), DefinitionId(3), composer.primitives_ty.bool),
("none".into(), DefinitionId(4), composer.primitives_ty.none),
],
composer,
)
} }
/// already include the definition_id of itself inside the ancestors vector /// already include the definition_id of itself inside the ancestors vector
@ -191,7 +185,7 @@ impl TopLevelComposer {
type_vars: Default::default(), type_vars: Default::default(),
fields: Default::default(), fields: Default::default(),
methods: Default::default(), methods: Default::default(),
ancestors: vec![TypeAnnotation::SelfTypeKind(DefinitionId(index))], ancestors: Default::default(),
resolver, resolver,
} }
} }
@ -265,6 +259,8 @@ impl TopLevelComposer {
DefinitionId, DefinitionId,
Type, Type,
)> = Vec::new(); )> = Vec::new();
// we do not push anything to the def list, so we keep track of the index
// and then push in the correct order after the for loop
let mut class_method_index_offset = 0; let mut class_method_index_offset = 0;
for b in body { for b in body {
if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node { if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node {
@ -274,6 +270,7 @@ impl TopLevelComposer {
return Err("duplicate class method definition".into()); return Err("duplicate class method definition".into());
} }
let method_def_id = self.definition_ast_list.len() + { let method_def_id = self.definition_ast_list.len() + {
// plus 1 here since we already have the class def
class_method_index_offset += 1; class_method_index_offset += 1;
class_method_index_offset class_method_index_offset
}; };
@ -388,7 +385,10 @@ impl TopLevelComposer {
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params // should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Generic") matches!(
&value.node,
ast::ExprKind::Name { id, .. } if id == "Generic"
)
} => } =>
{ {
if !is_generic { if !is_generic {
@ -437,14 +437,9 @@ impl TopLevelComposer {
let type_vars = type_vars let type_vars = type_vars
.into_iter() .into_iter()
.map(|x| { .map(|x| {
let range = unifier.get_ty(x); // must be type var here after previous check
if let TypeEnum::TVar { id, range, .. } = range.as_ref() { let dup = duplicate_type_var(unifier, x);
let range = &*range.borrow(); (dup.1, (dup.0).0)
let range = range.as_slice();
(*id, unifier.get_fresh_var_with_range(range).0)
} else {
unreachable!("must be type var here after previous check");
}
}) })
.collect_vec(); .collect_vec();
@ -465,13 +460,13 @@ impl TopLevelComposer {
let temp_def_list = self.extract_def_list(); let temp_def_list = self.extract_def_list();
for (class_def, class_ast) in self.definition_ast_list.iter_mut() { for (class_def, class_ast) in self.definition_ast_list.iter_mut() {
let mut class_def = class_def.write(); let mut class_def = class_def.write();
let (class_bases, class_ancestors, class_resolver) = { let (class_bases, class_ancestors, class_resolver, class_id) = {
if let TopLevelDef::Class { ancestors, resolver, .. } = class_def.deref_mut() { if let TopLevelDef::Class { ancestors, resolver, object_id, .. } = class_def.deref_mut() {
if let Some(ast::Located { if let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, .. node: ast::StmtKind::ClassDef { bases, .. }, ..
}) = class_ast }) = class_ast
{ {
(bases, ancestors, resolver) (bases, ancestors, resolver, *object_id)
} else { } else {
unreachable!("must be both class") unreachable!("must be both class")
} }
@ -511,7 +506,7 @@ impl TopLevelComposer {
b, b,
)?; )?;
if let TypeAnnotation::ConcretizedCustomClassKind { .. } = &base_ty { if let TypeAnnotation::CustomClassKind { .. } = &base_ty {
// TODO: check to prevent cyclic base class // TODO: check to prevent cyclic base class
class_ancestors.push(base_ty); class_ancestors.push(base_ty);
} else { } else {
@ -520,6 +515,11 @@ impl TopLevelComposer {
); );
} }
} }
// push self to the ancestors
class_ancestors.push(
make_self_type_annotation(&temp_def_list, class_id, self.unifier.borrow_mut())?
)
} }
Ok(()) Ok(())
} }
@ -611,7 +611,7 @@ impl TopLevelComposer {
primitives_store, primitives_store,
annotation, annotation,
)?; )?;
if let TypeEnum::TVar { id, range, .. } = if let TypeEnum::TVar { id, .. } =
unifier.get_ty(ty).as_ref() unifier.get_ty(ty).as_ref()
{ {
if let Some(occured_ty) = occured_type_var.get(id) { if let Some(occured_ty) = occured_type_var.get(id) {
@ -619,19 +619,16 @@ impl TopLevelComposer {
ty = *occured_ty; ty = *occured_ty;
} else { } else {
// if not, create a duplicate // if not, create a duplicate
let range = range.borrow(); let ty_copy = duplicate_type_var(unifier, ty);
let range = range.as_slice(); ty = ty_copy.0.0;
let ty_copy = unifier.get_fresh_var_with_range(range);
ty = ty_copy.0;
occured_type_var.insert(*id, ty); occured_type_var.insert(*id, ty);
function_var_map.insert(ty_copy.1, ty_copy.0); function_var_map.insert(ty_copy.1, ty_copy.0.0);
} }
} }
Ok(FuncArg { Ok(FuncArg {
name: x.node.arg.clone(), name: x.node.arg.clone(),
ty, ty,
// TODO: function type var
default_value: Default::default(), default_value: Default::default(),
}) })
}) })
@ -779,7 +776,7 @@ impl TopLevelComposer {
unifier.unify(*ty, associated[0].1)?; unifier.unify(*ty, associated[0].1)?;
} }
_ => { _ => {
unreachable!("should not be duplicate type var"); unreachable!("there should not be duplicate type var");
} }
} }
@ -813,7 +810,10 @@ impl TopLevelComposer {
default_value: None, default_value: None,
}; };
type_var_to_concrete_def type_var_to_concrete_def
.insert(dummy_func_arg.ty, TypeAnnotation::SelfTypeKind(*class_id)); .insert(
dummy_func_arg.ty,
make_self_type_annotation(temp_def_list, *class_id, unifier)?
);
result.push(dummy_func_arg); result.push(dummy_func_arg);
} }
} }
@ -839,7 +839,10 @@ impl TopLevelComposer {
// if is the "__init__" function, the return type is self // if is the "__init__" function, the return type is self
let dummy_return_type = unifier.get_fresh_var().0; let dummy_return_type = unifier.get_fresh_var().0;
type_var_to_concrete_def type_var_to_concrete_def
.insert(dummy_return_type, TypeAnnotation::SelfTypeKind(*class_id)); .insert(
dummy_return_type,
make_self_type_annotation(temp_def_list, *class_id, unifier)?
);
dummy_return_type dummy_return_type
} }
}; };
@ -855,7 +858,7 @@ impl TopLevelComposer {
if name == "__init__" { if name == "__init__" {
for b in body { for b in body {
let mut defined_fields: HashSet<String> = HashSet::new(); let mut defined_fields: HashSet<String> = HashSet::new();
// TODO: check the type of value, field instantiation check // TODO: check the type of value, field instantiation check?
if let ast::StmtKind::AnnAssign { annotation, target, value: _, .. } = if let ast::StmtKind::AnnAssign { annotation, target, value: _, .. } =
&b.node &b.node
{ {

View File

@ -1,15 +1,17 @@
use super::*; use super::*;
use crate::typecheck::typedef::TypeVarMeta;
#[derive(Clone)] #[derive(Clone)]
pub enum TypeAnnotation { pub enum TypeAnnotation {
PrimitiveKind(Type), PrimitiveKind(Type),
ConcretizedCustomClassKind { // we use type vars kind at
// params to represent self type
CustomClassKind {
id: DefinitionId, id: DefinitionId,
// can not be type var, others are all fine // can not be type var, others are all fine
// TODO: can also be type var? // TODO: can also be type var?
params: Vec<TypeAnnotation>, params: Vec<TypeAnnotation>,
}, },
// can only be ConcretizedCustomClassKind // can only be CustomClassKind
VirtualKind(Box<TypeAnnotation>), VirtualKind(Box<TypeAnnotation>),
// the first u32 refers to the var_id of the // the first u32 refers to the var_id of the
// TVar returned by the symbol resolver, // TVar returned by the symbol resolver,
@ -17,7 +19,6 @@ pub enum TypeAnnotation {
// associated with class/functions // associated with class/functions
// since when associating we create a copy of type vars // since when associating we create a copy of type vars
TypeVarKind(u32, Type), TypeVarKind(u32, Type),
SelfTypeKind(DefinitionId),
} }
pub fn parse_ast_to_type_annotation_kinds<T>( pub fn parse_ast_to_type_annotation_kinds<T>(
@ -38,7 +39,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
if let Some(obj_id) = resolver.get_identifier_def(x) { if let Some(obj_id) = resolver.get_identifier_def(x) {
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { .. } = &*def { if let TopLevelDef::Class { .. } = &*def {
Ok(TypeAnnotation::ConcretizedCustomClassKind { Ok(TypeAnnotation::CustomClassKind {
id: obj_id, id: obj_id,
params: vec![], params: vec![],
}) })
@ -46,17 +47,16 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
Err("function cannot be used as a type".into()) Err("function cannot be used as a type".into())
} }
} else if let Some(ty) = resolver.get_symbol_type(unifier, primitives, id) { } else if let Some(ty) = resolver.get_symbol_type(unifier, primitives, id) {
if let TypeEnum::TVar { id, meta: TypeVarMeta::Generic, range } = if let TypeEnum::TVar { id, .. } = unifier.get_ty(ty).as_ref()
unifier.get_ty(ty).as_ref()
{ {
// NOTE: always create a new one here // NOTE: always create a new one here
// and later unify if needed // and later unify if needed
// but record the var_id of the original type var returned by symbol resolver // but record the var_id of the original type var
let range = range.borrow(); // returned by symbol resolver
let range = range.as_slice();
Ok(TypeAnnotation::TypeVarKind( Ok(TypeAnnotation::TypeVarKind(
*id, *id,
unifier.get_fresh_var_with_range(range).0, // TODO: maybe not duplicate will also be fine here?
duplicate_type_var(unifier, ty).0.0
)) ))
} else { } else {
Err("not a type variable identifier".into()) Err("not a type variable identifier".into())
@ -67,7 +67,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} }
}, },
// TODO: subscript or call // TODO: subscript or call?
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "virtual") } => if { matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "virtual") } =>
{ {
@ -78,7 +78,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
)?; )?;
if !matches!(def, TypeAnnotation::ConcretizedCustomClassKind { .. }) { if !matches!(def, TypeAnnotation::CustomClassKind { .. }) {
unreachable!("must be concretized custom class kind in the virtual") unreachable!("must be concretized custom class kind in the virtual")
} }
Ok(TypeAnnotation::VirtualKind(def.into())) Ok(TypeAnnotation::VirtualKind(def.into()))
@ -116,7 +116,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
)?] )?]
}; };
// NOTE: allow type var in class generic application list // NOTE: allow type var in class generic application list
Ok(TypeAnnotation::ConcretizedCustomClassKind { Ok(TypeAnnotation::CustomClassKind {
id: obj_id, id: obj_id,
params: param_type_infos, params: param_type_infos,
}) })
@ -139,7 +139,7 @@ pub fn get_type_from_type_annotation_kinds(
ann: &TypeAnnotation, ann: &TypeAnnotation,
) -> Result<Type, String> { ) -> Result<Type, String> {
match ann { match ann {
TypeAnnotation::ConcretizedCustomClassKind { id, params } => { TypeAnnotation::CustomClassKind { id, params } => {
let class_def = top_level_defs[id.0].read(); let class_def = top_level_defs[id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*class_def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*class_def {
if type_vars.len() != params.len() { if type_vars.len() != params.len() {
@ -160,6 +160,7 @@ pub fn get_type_from_type_annotation_kinds(
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
// FIXME: TODO: cannot directy subst type var here? need to subst types in fields/methods
let subst = type_vars let subst = type_vars
.iter() .iter()
.map(|x| { .map(|x| {
@ -195,33 +196,6 @@ pub fn get_type_from_type_annotation_kinds(
unreachable!("should be class def here") unreachable!("should be class def here")
} }
} }
TypeAnnotation::SelfTypeKind(obj_id) => {
let class_def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*class_def {
let subst = type_vars
.iter()
.map(|x| {
if let TypeEnum::TVar { id, .. } = unifier.get_ty(x.1).as_ref() {
(*id, x.1)
} else {
unreachable!()
}
})
.collect::<HashMap<u32, Type>>();
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| (name.clone(), *ty))
.collect::<HashMap<String, Type>>();
tobj_fields.extend(fields.clone().into_iter());
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields.into(),
params: subst.into(),
}))
} else {
unreachable!("should be class def here")
}
}
TypeAnnotation::PrimitiveKind(ty) => Ok(*ty), TypeAnnotation::PrimitiveKind(ty) => Ok(*ty),
TypeAnnotation::TypeVarKind(_, ty) => Ok(*ty), TypeAnnotation::TypeVarKind(_, ty) => Ok(*ty),
TypeAnnotation::VirtualKind(ty) => { TypeAnnotation::VirtualKind(ty) => {
@ -235,3 +209,74 @@ pub fn get_type_from_type_annotation_kinds(
} }
} }
} }
/// the first return is the duplicated type \
/// the second return is the var_id of the duplicated type \
/// the third return is the var_id of the original type
#[inline]
pub fn duplicate_type_var(
unifier: &mut Unifier,
type_var: Type
) -> ((Type, u32), u32) {
let ty = unifier.get_ty(type_var);
if let TypeEnum::TVar { id, range, .. } = ty.as_ref() {
let range = range.borrow();
let range = range.as_slice();
(unifier.get_fresh_var_with_range(range), *id)
} else {
unreachable!("must be type var here to be duplicated");
}
}
/// given an def id, return a type annotation of self \
/// ```python
/// class A(Generic[T, V]):
/// def fun(self):
/// ```
/// the type of `self` should be equivalent to `A[T, V]`, where `T`, `V`
/// considered to be type variables associated with the class
pub fn make_self_type_annotation(
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
def_id: DefinitionId,
unifier: &mut Unifier,
) -> Result<TypeAnnotation, String> {
let obj_def = top_level_defs
.get(def_id.0)
.ok_or_else(|| "invalid definition id".to_string())?;
let obj_def = obj_def.read();
let obj_def = obj_def.deref();
if let TopLevelDef::Class { type_vars, .. } = obj_def {
Ok(TypeAnnotation::CustomClassKind {
id: def_id,
params: type_vars
.iter()
.map(|(var_id, ty)| TypeAnnotation::TypeVarKind(
*var_id,
duplicate_type_var(unifier, *ty).0.0
))
.collect_vec()
})
} else {
unreachable!("must be top level class def here")
}
}
/// get all the occurences of type vars contained in a type annotation
/// e.g. `A[int, B[T], V]` => [T, V]
pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<TypeAnnotation> {
let mut result: Vec<TypeAnnotation> = Vec::new();
match ann {
TypeAnnotation::TypeVarKind( .. ) => result.push(ann.clone()),
TypeAnnotation::VirtualKind(ann) => result.extend(
get_type_var_contained_in_type_annotation(ann.as_ref())
),
TypeAnnotation::CustomClassKind { params, .. } => {
for p in params {
result.extend(get_type_var_contained_in_type_annotation(p));
}
},
_ => { }
}
result
}

View File

@ -51,7 +51,7 @@ fn main() {
} }
}; };
let (_, composer) = TopLevelComposer::new(); let composer = TopLevelComposer::new();
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());