polymorphism and inheritance related fixes #92

Closed
ychenfo wants to merge 7 commits from range_with_class into master
10 changed files with 782 additions and 388 deletions

View File

@ -8,18 +8,23 @@ import nac3artiq
__all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3", __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3",
"ms", "us", "ns", "ms", "us", "ns",
"Core", "TTLOut", "parallel", "sequential"] "Core", "TTLOut", "parallel", "sequential", "virtual"]
import device_db import device_db
core_arguments = device_db.device_db["core"]["arguments"] core_arguments = device_db.device_db["core"]["arguments"]
T = TypeVar('T')
# place the `virtual` class infront of the construct of NAC3 object to ensure the
# virtual class is known during the initializing of NAC3 object
class virtual(Generic[T]):
pass
compiler = nac3artiq.NAC3(core_arguments["target"]) compiler = nac3artiq.NAC3(core_arguments["target"])
allow_module_registration = True allow_module_registration = True
registered_modules = set() registered_modules = set()
nac3annotated_class_ids = set()
T = TypeVar('T')
class KernelInvariant(Generic[T]): class KernelInvariant(Generic[T]):
pass pass
@ -64,6 +69,7 @@ def nac3(cls):
All classes containing kernels or portable methods must use this decorator. All classes containing kernels or portable methods must use this decorator.
""" """
register_module_of(cls) register_module_of(cls)
nac3annotated_class_ids.add(id(cls))
return cls return cls
@ -106,7 +112,7 @@ class Core:
def run(self, method, *args, **kwargs): def run(self, method, *args, **kwargs):
global allow_module_registration global allow_module_registration
if allow_module_registration: if allow_module_registration:
compiler.analyze_modules(registered_modules) compiler.analyze_modules(registered_modules, nac3annotated_class_ids)
allow_module_registration = False allow_module_registration = False
if hasattr(method, "__self__"): if hasattr(method, "__self__"):

View File

@ -11,7 +11,7 @@ use inkwell::{
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::{exceptions, types::PyList, types::PySet, types::PyBytes}; use pyo3::{exceptions, types::PyList, types::PySet, types::PyBytes};
use nac3parser::{ use nac3parser::{
ast::{self, StrRef}, ast::{self, StrRef, Constant::Str},
parser::{self, parse_program}, parser::{self, parse_program},
}; };
@ -51,6 +51,10 @@ pub struct PrimitivePythonId {
bool: u64, bool: u64,
list: u64, list: u64,
tuple: u64, tuple: u64,
typevar: u64,
none: u64,
generic_alias: (u64, u64),
virtual_id: u64,
} }
// TopLevelComposer is unsendable as it holds the unification table, which is // TopLevelComposer is unsendable as it holds the unification table, which is
@ -72,7 +76,7 @@ struct Nac3 {
} }
impl Nac3 { impl Nac3 {
fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> { fn register_module_impl(&mut self, obj: PyObject, nac3_annotated_cls: &PySet) -> PyResult<()> {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new(); let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let obj: &PyAny = obj.extract(py)?; let obj: &PyAny = obj.extract(py)?;
@ -107,7 +111,7 @@ impl Nac3 {
global_value_ids: self.global_value_ids.clone(), global_value_ids: self.global_value_ids.clone(),
class_names: Default::default(), class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(), name_to_pyid: name_to_pyid.clone(),
module: obj, module: obj.clone(),
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
let mut name_to_def = HashMap::new(); let mut name_to_def = HashMap::new();
let mut name_to_type = HashMap::new(); let mut name_to_type = HashMap::new();
@ -117,6 +121,7 @@ impl Nac3 {
ast::StmtKind::ClassDef { ast::StmtKind::ClassDef {
ref decorator_list, ref decorator_list,
ref mut body, ref mut body,
ref mut bases,
.. ..
} => { } => {
let kernels = decorator_list.iter().any(|decorator| { let kernels = decorator_list.iter().any(|decorator| {
@ -142,6 +147,33 @@ impl Nac3 {
true true
} }
}); });
bases.retain(|b| {
Python::with_gil(|py| -> PyResult<bool> {
let obj: &PyAny = obj.extract(py)?;
let annot_check = |id: &str| -> bool {
let id = py.eval(
&format!("id({})", id),
Some(obj.getattr("__dict__").unwrap().extract().unwrap()),
None
).unwrap();
nac3_annotated_cls.contains(id).unwrap()
};
match &b.node {
ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string())),
ast::ExprKind::Constant { value: Str(id), .. } =>
Ok(annot_check(id.split('[').next().unwrap())),
ast::ExprKind::Subscript { value, .. } => {
match &value.node {
ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string()) || *id == "Generic".into()),
ast::ExprKind::Constant { value: Str(id), .. } =>
Ok(annot_check(id.split('[').next().unwrap())),
_ => unreachable!("unsupported base declaration")
}
}
_ => unreachable!("unsupported base declaration")
}
}).unwrap()
});
kernels kernels
} }
ast::StmtKind::FunctionDef { ast::StmtKind::FunctionDef {
@ -246,7 +278,42 @@ impl Nac3 {
let builtins_mod = PyModule::import(py, "builtins").unwrap(); let builtins_mod = PyModule::import(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap(); let numpy_mod = PyModule::import(py, "numpy").unwrap();
let typing_mod = PyModule::import(py, "typing").unwrap();
let types_mod = PyModule::import(py, "types").unwrap();
let primitive_ids = PrimitivePythonId { let primitive_ids = PrimitivePythonId {
virtual_id: id_fn
.call1((builtins_mod
.getattr("globals")
.unwrap()
.call0()
.unwrap()
.get_item("virtual")
.unwrap(),
)).unwrap()
.extract()
.unwrap(),
generic_alias: (
id_fn
.call1((typing_mod.getattr("_GenericAlias").unwrap(),))
.unwrap()
.extract()
.unwrap(),
id_fn
.call1((types_mod.getattr("GenericAlias").unwrap(),))
.unwrap()
.extract()
.unwrap(),
),
none: id_fn
.call1((builtins_mod.getattr("None").unwrap(),))
.unwrap()
.extract()
.unwrap(),
typevar: id_fn
.call1((typing_mod.getattr("TypeVar").unwrap(),))
.unwrap()
.extract()
.unwrap(),
int: id_fn int: id_fn
.call1((builtins_mod.getattr("int").unwrap(),)) .call1((builtins_mod.getattr("int").unwrap(),))
.unwrap() .unwrap()
@ -303,9 +370,9 @@ impl Nac3 {
}) })
} }
fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> { fn analyze_modules(&mut self, modules: &PySet, nac3_annotated_cls: &PySet) -> PyResult<()> {
for obj in modules.iter() { for obj in modules.iter() {
self.register_module_impl(obj.into())?; self.register_module_impl(obj.into(), nac3_annotated_cls)?;
} }
Ok(()) Ok(())
} }

View File

@ -40,6 +40,11 @@ struct PythonHelper<'a> {
type_fn: &'a PyAny, type_fn: &'a PyAny,
len_fn: &'a PyAny, len_fn: &'a PyAny,
id_fn: &'a PyAny, id_fn: &'a PyAny,
eval_type_fn: &'a PyAny,
origin_ty_fn: &'a PyAny,
args_ty_fn: &'a PyAny,
globals_dict: &'a PyAny,
print_fn: &'a PyAny,
} }
impl Resolver { impl Resolver {
@ -71,47 +76,51 @@ impl Resolver {
})) }))
} }
fn get_obj_type( // handle python objects that represent types themselves
// primitives and class types should be themselves, use `ty_id` to check,
// TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check
// the `bool` value returned indicates whether they are instantiated or not
fn get_pyty_obj_type(
&self, &self,
obj: &PyAny, pyty: &PyAny,
helper: &PythonHelper, helper: &PythonHelper,
unifier: &mut Unifier, unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>], defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
) -> PyResult<Option<Type>> { ) -> PyResult<Result<(Type, bool), String>> {
// eval_type use only globals_dict should be fine
let evaluated_ty = helper
.eval_type_fn
.call1((pyty, helper.globals_dict, helper.globals_dict)).unwrap();
let ty_id: u64 = helper let ty_id: u64 = helper
.id_fn .id_fn
.call1((helper.type_fn.call1((obj,))?,))? .call1((evaluated_ty,))?
.extract()?; .extract()?;
let ty_ty_id: u64 = helper
.id_fn
.call1((helper.type_fn.call1((evaluated_ty,))?,))?
.extract()?;
if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 {
Ok(Some(primitives.int32)) Ok(Ok((primitives.int32, true)))
} else if ty_id == self.primitive_ids.int64 { } else if ty_id == self.primitive_ids.int64 {
Ok(Some(primitives.int64)) Ok(Ok((primitives.int64, true)))
} else if ty_id == self.primitive_ids.bool { } else if ty_id == self.primitive_ids.bool {
Ok(Some(primitives.bool)) Ok(Ok((primitives.bool, true)))
} else if ty_id == self.primitive_ids.float { } else if ty_id == self.primitive_ids.float {
Ok(Some(primitives.float)) Ok(Ok((primitives.float, true)))
} else if ty_id == self.primitive_ids.list { } else if ty_id == self.primitive_ids.list {
let len: usize = helper.len_fn.call1((obj,))?.extract()?; // do not handle type var param and concrete check here
if len == 0 { let var = unifier.get_fresh_var().0;
let var = unifier.get_fresh_var().0; let list = unifier.add_ty(TypeEnum::TList { ty: var });
let list = unifier.add_ty(TypeEnum::TList { ty: var }); Ok(Ok((list, false)))
Ok(Some(list))
} else {
let ty = self.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?;
Ok(ty.map(|ty| unifier.add_ty(TypeEnum::TList { ty })))
}
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let elements: &PyTuple = obj.cast_as()?; // do not handle type var param and concrete check here
let types: Result<Option<Vec<_>>, _> = elements Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
.iter() } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
.map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) // println!("getting def");
.collect();
let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id) {
let def = defs[def_id.0].read(); let def = defs[def_id.0].read();
// println!("got def");
if let TopLevelDef::Class { if let TopLevelDef::Class {
object_id, object_id,
type_vars, type_vars,
@ -120,35 +129,260 @@ impl Resolver {
.. ..
} = &*def } = &*def
{ {
let var_map: HashMap<_, _> = type_vars // do not handle type var param and concrete check here, and no subst
.iter() Ok(Ok({
.map(|var| { let ty = TypeEnum::TObj {
( obj_id: *object_id,
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { params: RefCell::new({
*id type_vars
} else { .iter()
unreachable!() .map(|x| {
}, if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
unifier.get_fresh_var().0, (*id, *x)
) } else { unreachable!() }
}) }).collect()
.collect(); }),
let mut fields_ty = HashMap::new(); fields: RefCell::new({
for method in methods.iter() { let mut res = methods
fields_ty.insert(method.0, (method.1, false)); .iter()
} .map(|(iden, ty, _)| (*iden, (*ty, false)))
for field in fields.iter() { .collect::<HashMap<_, _>>();
let name: String = field.0.into(); res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2))));
let field_data = obj.getattr(&name)?; res
let ty = self })
.get_obj_type(field_data, helper, unifier, defs, primitives)? };
.unwrap_or(primitives.none); // here also false, later insta use python object to check compatible
let field_ty = unifier.subst(field.1, &var_map).unwrap_or(field.1); (unifier.add_ty(ty), false)
if unifier.unify(ty, field_ty).is_err() { }))
// field type mismatch } else {
return Ok(None); // only object is supported, functions are not supported
unreachable!("function type is not supported, should not be queried")
}
} else if ty_ty_id == self.primitive_ids.typevar {
let constraint_types = {
let constraints = pyty.getattr("__constraints__").unwrap();
let mut result: Vec<Type> = vec![];
for i in 0.. {
if let Ok(constr) = constraints.get_item(i) {
result.push({
match self.get_pyty_obj_type(constr, helper, unifier, defs, primitives)? {
Ok((ty, _)) => {
if unifier.is_concrete(ty, &[]) {
ty
} else {
return Ok(Err(format!(
"the {}th constraint of TypeVar `{}` is not concrete",
i + 1,
pyty.getattr("__name__")?.extract::<String>()?
)))
}
},
Err(err) => return Ok(Err(err))
}
})
} else {
break;
}
}
result
};
let res = unifier.get_fresh_var_with_range(&constraint_types).0;
Ok(Ok((res, true)))
} else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 {
let origin = helper.origin_ty_fn.call1((evaluated_ty,))?;
let args: &PyTuple = helper.args_ty_fn.call1((evaluated_ty,))?.cast_as()?;
let origin_ty = match self.get_pyty_obj_type(origin, helper, unifier, defs, primitives)? {
Ok((ty, false)) => ty,
Ok((_, true)) => return Ok(Err("instantiated type does not take type parameters".into())),
Err(err) => return Ok(Err(err))
};
match &*unifier.get_ty(origin_ty) {
TypeEnum::TList { .. } => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err))
};
if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
panic!("type list should take concrete parameters in type var ranges")
}
Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true)))
} else {
return Ok(Err(format!("type list needs exactly 1 type parameters, found {}", args.len())))
}
},
TypeEnum::TTuple { .. } => {
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
Ok(args) if !args.is_empty() => args
.into_iter()
.map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check {
panic!("type tuple should take concrete parameters in type var ranges")
} else {
x
}
)
.collect::<Vec<_>>(),
Err(err) => return Ok(Err(err)),
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
};
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true)))
},
TypeEnum::TObj { params, obj_id, .. } => {
let subst = {
let params = &*params.borrow();
if params.len() != args.len() {
return Ok(Err(format!(
"for class #{}, expect {} type parameters, got {}.",
obj_id.0,
params.len(),
args.len(),
)))
}
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
Ok(args) => args
.into_iter()
.map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check {
panic!("type class should take concrete parameters in type var ranges")
} else {
x
}
)
.collect::<Vec<_>>(),
Err(err) => return Ok(Err(err)),
};
params
.iter()
.zip(args.iter())
.map(|((id, _), ty)| (*id, *ty))
.collect::<HashMap<_, _>>()
};
Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true)))
},
TypeEnum::TVirtual { .. } => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err))
};
if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
panic!("virtual class should take concrete parameters in type var ranges")
}
Ok(Ok((unifier.add_ty(TypeEnum::TVirtual { ty: ty.0 }), true)))
} else {
return Ok(Err(format!("virtual class needs exactly 1 type parameters, found {}", args.len())))
}
}
_ => unimplemented!()
}
} else if ty_id == self.primitive_ids.virtual_id {
Ok(Ok(({
let ty = TypeEnum::TVirtual { ty: unifier.get_fresh_var().0 };
unifier.add_ty(ty)
}, false)))
} else {
Ok(Err("unknown type".into()))
}
}
fn get_obj_type(
&self,
obj: &PyAny,
helper: &PythonHelper,
unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
) -> PyResult<Option<Type>> {
let (extracted_ty, inst_check) = match self.get_pyty_obj_type(
{
let ty = helper.type_fn.call1((obj,)).unwrap();
if [self.primitive_ids.typevar,
self.primitive_ids.generic_alias.0,
self.primitive_ids.generic_alias.1
].contains(&helper.id_fn.call1((ty,))?.extract::<u64>()?) {
obj
} else {
ty
}
},
helper,
unifier,
defs,
primitives
)? {
Ok(s) => s,
Err(_) => return Ok(None)
};
return match (&*unifier.get_ty(extracted_ty), inst_check) {
// do the instantiation for these three types
(TypeEnum::TList { ty }, false) => {
let len: usize = helper.len_fn.call1((obj,))?.extract()?;
if len == 0 {
assert!(matches!(
&*unifier.get_ty(extracted_ty),
TypeEnum::TVar { meta: nac3core::typecheck::typedef::TypeVarMeta::Generic, range, .. }
if range.borrow().is_empty()
));
Ok(Some(extracted_ty))
} else {
let actual_ty = self
.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?;
if let Some(actual_ty) = actual_ty {
unifier.unify(*ty, actual_ty).unwrap();
Ok(Some(extracted_ty))
} else {
Ok(None)
}
}
}
(TypeEnum::TTuple { .. }, false) => {
let elements: &PyTuple = obj.cast_as()?;
let types: Result<Option<Vec<_>>, _> = elements
.iter()
.map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives))
.collect();
let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
}
(TypeEnum::TObj { params, fields, .. }, false) => {
let var_map = params
.borrow()
.iter()
.map(|(id_var, ty)| {
if let TypeEnum::TVar { id, range, .. } = &*unifier.get_ty(*ty) {
assert_eq!(*id, *id_var);
(*id, unifier.get_fresh_var_with_range(&range.borrow()).0)
} else {
unreachable!()
}
})
.collect::<HashMap<_, _>>();
// loop through non-function fields of the class to get the instantiated value
for field in fields.borrow().iter() {
let name: String = (*field.0).into();
if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(field.1.0) {
continue;
} else {
let field_data = obj.getattr(&name)?;
let ty = self
.get_obj_type(field_data, helper, unifier, defs, primitives)?
.unwrap_or(primitives.none);
let field_ty = unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0);
if unifier.unify(ty, field_ty).is_err() {
// field type mismatch
return Ok(None);
}
} }
fields_ty.insert(field.0, (ty, field.2));
} }
for (_, ty) in var_map.iter() { for (_, ty) in var_map.iter() {
// must be concrete type // must be concrete type
@ -156,18 +390,10 @@ impl Resolver {
return Ok(None) return Ok(None)
} }
} }
Ok(Some(unifier.add_ty(TypeEnum::TObj { return Ok(Some(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty)));
obj_id: *object_id,
fields: RefCell::new(fields_ty),
params: RefCell::new(var_map),
})))
} else {
// only object is supported, functions are not supported
Ok(None)
} }
} else { _ => Ok(Some(extracted_ty))
Ok(None) };
}
} }
fn get_obj_value<'ctx, 'a>( fn get_obj_value<'ctx, 'a>(
@ -425,10 +651,16 @@ impl SymbolResolver for Resolver {
let key: &str = member.get_item(0)?.extract()?; let key: &str = member.get_item(0)?.extract()?;
if key == str.to_string() { if key == str.to_string() {
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let helper = PythonHelper { let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap(), id_fn: builtins.getattr("id").unwrap(),
len_fn: builtins.getattr("len").unwrap(), len_fn: builtins.getattr("len").unwrap(),
type_fn: builtins.getattr("type").unwrap(), type_fn: builtins.getattr("type").unwrap(),
origin_ty_fn: typings.getattr("get_origin").unwrap(),
args_ty_fn: typings.getattr("get_args").unwrap(),
globals_dict: obj.getattr("__dict__").unwrap(),
eval_type_fn: typings.getattr("_eval_type").unwrap(),
print_fn: builtins.getattr("print").unwrap(),
}; };
sym_ty = self.get_obj_type( sym_ty = self.get_obj_type(
member.get_item(1)?, member.get_item(1)?,
@ -469,10 +701,16 @@ impl SymbolResolver for Resolver {
let val = member.get_item(1)?; let val = member.get_item(1)?;
if key == id.to_string() { if key == id.to_string() {
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let helper = PythonHelper { let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap(), id_fn: builtins.getattr("id").unwrap(),
len_fn: builtins.getattr("len").unwrap(), len_fn: builtins.getattr("len").unwrap(),
type_fn: builtins.getattr("type").unwrap(), type_fn: builtins.getattr("type").unwrap(),
origin_ty_fn: typings.getattr("get_origin").unwrap(),
args_ty_fn: typings.getattr("get_args").unwrap(),
globals_dict: obj.getattr("__dict__").unwrap(),
eval_type_fn: typings.getattr("_eval_type").unwrap(),
print_fn: builtins.getattr("print").unwrap(),
}; };
sym_value = self.get_obj_value(val, &helper, ctx)?; sym_value = self.get_obj_value(val, &helper, ctx)?;
break; break;

View File

@ -147,8 +147,14 @@ impl ConcreteTypeStore {
fields: fields fields: fields
.borrow() .borrow()
.iter() .iter()
.map(|(name, ty)| { .filter_map(|(name, ty)| {
(*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1)) // filter out functions as they can have type vars and
// will not affect codegen
if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(ty.0) {
None
} else {
Some((*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1)))
}
}) })
.collect(), .collect(),
params: params params: params

View File

@ -7,7 +7,7 @@ use crate::{
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum}, typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
}; };
use inkwell::{ use inkwell::{
types::{BasicType, BasicTypeEnum}, types::{BasicType, BasicTypeEnum},
@ -21,6 +21,31 @@ use nac3parser::ast::{
use super::CodeGenerator; use super::CodeGenerator;
pub fn get_subst_key(
unifier: &mut Unifier,
obj: Option<Type>,
fun_vars: &HashMap<u32, Type>,
filter: Option<&Vec<u32>>,
) -> String {
let mut vars = obj
.map(|ty| {
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) {
params.borrow().clone()
} else {
unreachable!()
}
})
.unwrap_or_default();
vars.extend(fun_vars.iter());
let sorted =
vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted();
sorted
.map(|id| {
unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
})
.join(", ")
}
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
pub fn build_gep_and_load( pub fn build_gep_and_load(
&mut self, &mut self,
@ -36,23 +61,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fun: &FunSignature, fun: &FunSignature,
filter: Option<&Vec<u32>>, filter: Option<&Vec<u32>>,
) -> String { ) -> String {
let mut vars = obj get_subst_key(&mut self.unifier, obj, &fun.vars, filter)
.map(|ty| {
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
params.borrow().clone()
} else {
unreachable!()
}
})
.unwrap_or_default();
vars.extend(fun.vars.iter());
let sorted =
vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted();
sorted
.map(|id| {
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
})
.join(", ")
} }
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize { pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize {

View File

@ -13,7 +13,7 @@ use crate::{
use crate::{location::Location, typecheck::typedef::TypeEnum}; use crate::{location::Location, typecheck::typedef::TypeEnum};
use inkwell::values::BasicValueEnum; use inkwell::values::BasicValueEnum;
use itertools::{chain, izip}; use itertools::{chain, izip};
use nac3parser::ast::{Expr, StrRef}; use nac3parser::ast::{Constant::Str, Expr, StrRef};
use parking_lot::RwLock; use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
@ -79,159 +79,168 @@ pub fn parse_type_annotation<T>(
let list_id = ids[6]; let list_id = ids[6];
let tuple_id = ids[7]; let tuple_id = ids[7];
match &expr.node { let name_handling = |id: &StrRef, unifier: &mut Unifier| {
Name { id, .. } => { if *id == int32_id {
if *id == int32_id { Ok(primitives.int32)
Ok(primitives.int32) } else if *id == int64_id {
} else if *id == int64_id { Ok(primitives.int64)
Ok(primitives.int64) } else if *id == float_id {
} else if *id == float_id { Ok(primitives.float)
Ok(primitives.float) } else if *id == bool_id {
} else if *id == bool_id { Ok(primitives.bool)
Ok(primitives.bool) } else if *id == none_id {
} else if *id == none_id { Ok(primitives.none)
Ok(primitives.none) } else {
} else { let obj_id = resolver.get_identifier_def(*id);
let obj_id = resolver.get_identifier_def(*id); if let Some(obj_id) = obj_id {
if let Some(obj_id) = obj_id { let def = top_level_defs[obj_id.0].read();
let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if !type_vars.is_empty() {
if !type_vars.is_empty() { return Err(format!(
return Err(format!( "Unexpected number of type parameters: expected {} but got 0",
"Unexpected number of type parameters: expected {} but got 0", type_vars.len()
type_vars.len() ));
));
}
let fields = RefCell::new(
chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
)
.collect(),
);
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields,
params: Default::default(),
}))
} else {
Err("Cannot use function name as type".into())
} }
let fields = RefCell::new(
chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
)
.collect(),
);
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields,
params: Default::default(),
}))
} else { } else {
// it could be a type variable Err("Cannot use function name as type".into())
let ty = resolver }
.get_symbol_type(unifier, top_level_defs, primitives, *id) } else {
.ok_or_else(|| "unknown type variable name".to_owned())?; // it could be a type variable
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { let ty = resolver
Ok(ty) .get_symbol_type(unifier, top_level_defs, primitives, *id)
} else { .ok_or_else(|| "unknown type variable name".to_owned())?;
Err(format!("Unknown type annotation {}", id)) if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
} Ok(ty)
} else {
Err(format!("Unknown type annotation {}", id))
} }
} }
} }
Subscript { value, slice, .. } => { };
if let Name { id, .. } = &value.node {
if *id == virtual_id { let subscript_name_handle = |id: &StrRef, slice: &Expr<T>, unifier: &mut Unifier| {
let ty = parse_type_annotation( if *id == virtual_id {
resolver, let ty = parse_type_annotation(
top_level_defs, resolver,
unifier, top_level_defs,
primitives, unifier,
slice, primitives,
)?; slice,
Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) )?;
} else if *id == list_id { Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
let ty = parse_type_annotation( } else if *id == list_id {
resolver, let ty = parse_type_annotation(
top_level_defs, resolver,
unifier, top_level_defs,
primitives, unifier,
slice, primitives,
)?; slice,
Ok(unifier.add_ty(TypeEnum::TList { ty })) )?;
} else if *id == tuple_id { Ok(unifier.add_ty(TypeEnum::TList { ty }))
if let Tuple { elts, .. } = &slice.node { } else if *id == tuple_id {
let ty = elts if let Tuple { elts, .. } = &slice.node {
.iter() let ty = elts
.map(|elt| { .iter()
parse_type_annotation( .map(|elt| {
resolver, parse_type_annotation(
top_level_defs,
unifier,
primitives,
elt,
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else {
Err("Expected multiple elements for tuple".into())
}
} else {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|v| {
parse_type_annotation(
resolver,
top_level_defs,
unifier,
primitives,
v,
)
})
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_type_annotation(
resolver, resolver,
top_level_defs, top_level_defs,
unifier, unifier,
primitives, primitives,
slice, elt,
)?] )
}; })
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else {
Err("Expected multiple elements for tuple".into())
}
} else {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|v| {
parse_type_annotation(
resolver,
top_level_defs,
unifier,
primitives,
v,
)
})
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_type_annotation(
resolver,
top_level_defs,
unifier,
primitives,
slice,
)?]
};
let obj_id = resolver let obj_id = resolver
.get_identifier_def(*id) .get_identifier_def(*id)
.ok_or_else(|| format!("Unknown type annotation {}", id))?; .ok_or_else(|| format!("Unknown type annotation {}", id))?;
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() { if types.len() != type_vars.len() {
return Err(format!( return Err(format!(
"Unexpected number of type parameters: expected {} but got {}", "Unexpected number of type parameters: expected {} but got {}",
type_vars.len(), type_vars.len(),
types.len() types.len()
)); ));
}
let mut subst = HashMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
} else {
unreachable!()
};
subst.insert(id, *ty);
}
let mut fields = fields
.iter()
.map(|(attr, ty, is_mutable)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, (ty, *is_mutable))
})
.collect::<HashMap<_, _>>();
fields.extend(methods.iter().map(|(attr, ty, _)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, (ty, false))
}));
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields: fields.into(),
params: subst.into(),
}))
} else {
Err("Cannot use function name as type".into())
}
} }
let mut subst = HashMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
} else {
unreachable!()
};
subst.insert(id, *ty);
}
let mut fields = fields
.iter()
.map(|(attr, ty, is_mutable)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, (ty, *is_mutable))
})
.collect::<HashMap<_, _>>();
fields.extend(methods.iter().map(|(attr, ty, _)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, (ty, false))
}));
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields: fields.into(),
params: subst.into(),
}))
} else {
Err("Cannot use function name as type".into())
}
}
};
match &expr.node {
Name { id, .. } => name_handling(id, unifier),
Constant { value: Str(id), .. } => name_handling(&id.clone().into(), unifier),
Subscript { value, slice, .. } => {
if let Name { id, .. } = &value.node {
subscript_name_handle(id, slice, unifier)
} else if let Constant { value: Str(id), .. } = &value.node {
subscript_name_handle(&id.clone().into(), slice, unifier)
} else { } else {
Err("unsupported type expression".into()) Err("unsupported type expression".into())
} }

View File

@ -6,6 +6,7 @@ use inkwell::FloatPredicate;
use crate::{ use crate::{
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
typecheck::type_inferencer::{FunctionData, Inferencer}, typecheck::type_inferencer::{FunctionData, Inferencer},
codegen::expr::get_subst_key,
}; };
use super::*; use super::*;
@ -534,7 +535,7 @@ impl TopLevelComposer {
} }
} }
fn extract_def_list(&self) -> Vec<Arc<RwLock<TopLevelDef>>> { pub fn extract_def_list(&self) -> Vec<Arc<RwLock<TopLevelDef>>> {
self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec()
} }
@ -1654,7 +1655,7 @@ impl TopLevelComposer {
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
let FunSignature { args, ret, vars } = &*func_sig.borrow(); let FunSignature { args, ret, vars } = &*func_sig.borrow();
// None if is not class method // None if is not class method
let self_type = { let uninst_self_type = {
if let Some(class_id) = self.method_class.get(&DefinitionId(id)) { if let Some(class_id) = self.method_class.get(&DefinitionId(id)) {
let class_def = self.definition_ast_list.get(class_id.0).unwrap(); let class_def = self.definition_ast_list.get(class_id.0).unwrap();
let class_def = class_def.0.read(); let class_def = class_def.0.read();
@ -1666,7 +1667,7 @@ impl TopLevelComposer {
&self.primitives_ty, &self.primitives_ty,
&ty_ann, &ty_ann,
)?; )?;
Some(self_ty) Some((self_ty, type_vars.clone()))
} else { } else {
unreachable!("must be class def") unreachable!("must be class def")
} }
@ -1717,9 +1718,34 @@ impl TopLevelComposer {
}; };
let self_type = { let self_type = {
let unifier = &mut self.unifier; let unifier = &mut self.unifier;
self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x)) uninst_self_type
.clone()
.map(|(self_type, type_vars)| {
let subst_for_self = {
let class_ty_var_ids = type_vars
.iter()
.map(|x| {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
*id
} else {
unreachable!("must be type var here");
}
})
.collect::<HashSet<_>>();
subst
.iter()
.filter_map(|(ty_var_id, ty_var_target)| {
if class_ty_var_ids.contains(ty_var_id) {
Some((*ty_var_id, *ty_var_target))
} else {
None
}
})
.collect::<HashMap<_, _>>()
};
unifier.subst(self_type, &subst_for_self).unwrap_or(self_type)
})
}; };
let mut identifiers = { let mut identifiers = {
// NOTE: none and function args? // NOTE: none and function args?
let mut result: HashSet<_> = HashSet::new(); let mut result: HashSet<_> = HashSet::new();
@ -1810,21 +1836,13 @@ impl TopLevelComposer {
instance_to_stmt.insert( instance_to_stmt.insert(
// NOTE: refer to codegen/expr/get_subst_key function // NOTE: refer to codegen/expr/get_subst_key function
{
let unifier = &mut self.unifier; get_subst_key(
subst &mut self.unifier,
.keys() self_type,
.sorted() &subst,
.map(|id| { None
let ty = subst.get(id).unwrap(); ),
unifier.stringify(
*ty,
&mut |id| id.to_string(),
&mut |id| id.to_string(),
)
})
.join(", ")
},
FunInstance { FunInstance {
body: Arc::new(fun_body), body: Arc::new(fun_body),
unifier_id: 0, unifier_id: 0,

View File

@ -1,7 +1,7 @@
use std::cell::RefCell; use std::cell::RefCell;
use crate::typecheck::typedef::TypeVarMeta; use crate::typecheck::typedef::TypeVarMeta;
use ast::Constant::Str;
use super::*; use super::*;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -49,58 +49,127 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
expr: &ast::Expr<T>, expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here // the key stores the type_var of this topleveldef::class, we only need this field here
mut locked: HashMap<DefinitionId, Vec<Type>>, locked: HashMap<DefinitionId, Vec<Type>>,
) -> Result<TypeAnnotation, String> { ) -> Result<TypeAnnotation, String> {
match &expr.node { let name_handle = |id: &StrRef, unifier: &mut Unifier, locked: HashMap<DefinitionId, Vec<Type>>| {
ast::ExprKind::Name { id, .. } => { if id == &"int32".into() {
if id == &"int32".into() { Ok(TypeAnnotation::Primitive(primitives.int32))
Ok(TypeAnnotation::Primitive(primitives.int32)) } else if id == &"int64".into() {
} else if id == &"int64".into() { Ok(TypeAnnotation::Primitive(primitives.int64))
Ok(TypeAnnotation::Primitive(primitives.int64)) } else if id == &"float".into() {
} else if id == &"float".into() { Ok(TypeAnnotation::Primitive(primitives.float))
Ok(TypeAnnotation::Primitive(primitives.float)) } else if id == &"bool".into() {
} else if id == &"bool".into() { Ok(TypeAnnotation::Primitive(primitives.bool))
Ok(TypeAnnotation::Primitive(primitives.bool)) } else if id == &"None".into() {
} else if id == &"None".into() { Ok(TypeAnnotation::Primitive(primitives.none))
Ok(TypeAnnotation::Primitive(primitives.none)) } else if id == &"str".into() {
} else if id == &"str".into() { Ok(TypeAnnotation::Primitive(primitives.str))
Ok(TypeAnnotation::Primitive(primitives.str)) } else if let Some(obj_id) = resolver.get_identifier_def(*id) {
} else if let Some(obj_id) = resolver.get_identifier_def(*id) { let type_vars = {
let type_vars = { let def_read = top_level_defs[obj_id.0].try_read();
let def_read = top_level_defs[obj_id.0].try_read(); if let Some(def_read) = def_read {
if let Some(def_read) = def_read { if let TopLevelDef::Class { type_vars, .. } = &*def_read {
if let TopLevelDef::Class { type_vars, .. } = &*def_read { type_vars.clone()
type_vars.clone()
} else {
return Err("function cannot be used as a type".into());
}
} else { } else {
locked.get(&obj_id).unwrap().clone() return Err("function cannot be used as a type".into());
} }
};
// check param number here
if !type_vars.is_empty() {
return Err(format!(
"expect {} type variable parameter but got 0",
type_vars.len()
));
}
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
} else if let Some(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
Ok(TypeAnnotation::TypeVar(ty))
} else { } else {
Err("not a type variable identifier".into()) locked.get(&obj_id).unwrap().clone()
}
};
// check param number here
if !type_vars.is_empty() {
return Err(format!(
"expect {} type variable parameter but got 0",
type_vars.len()
));
}
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
} else if let Some(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
Ok(TypeAnnotation::TypeVar(ty))
} else {
Err("not a type variable identifier".into())
}
} else {
Err("name cannot be parsed as a type annotation".into())
}
};
let class_name_handle =
|id: &StrRef, slice: &ast::Expr<T>, unifier: &mut Unifier, mut locked: HashMap<DefinitionId, Vec<Type>>| {
if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()]
.contains(id)
{
return Err("keywords cannot be class name".into());
}
let obj_id = resolver
.get_identifier_def(*id)
.ok_or_else(|| "unknown class name".to_string())?;
let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read();
if let Some(def_read) = def_read {
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone()
} else {
unreachable!("must be class here")
} }
} else { } else {
Err("name cannot be parsed as a type annotation".into()) locked.get(&obj_id).unwrap().clone()
} }
} };
// we do not check whether the application of type variables are compatible here
let param_type_infos = {
let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.iter().collect_vec()
} else {
vec![slice]
};
if type_vars.len() != params_ast.len() {
return Err(format!(
"expect {} type parameters but got {}",
type_vars.len(),
params_ast.len()
));
}
let result = params_ast
.into_iter()
.map(|x| {
parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
x,
{
locked.insert(obj_id, type_vars.clone());
locked.clone()
},
)
})
.collect::<Result<Vec<_>, _>>()?;
// make sure the result do not contain any type vars
let no_type_var = result
.iter()
.all(|x| get_type_var_contained_in_type_annotation(x).is_empty());
if no_type_var {
result
} else {
return Err("application of type vars to generic class \
is not currently supported"
.into());
}
};
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
};
match &expr.node {
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
ast::ExprKind::Constant { value: Str(id), .. } => name_handle(&id.clone().into(), unifier, locked),
// virtual // virtual
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into()) matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into()) ||
matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "virtual")
} => } =>
{ {
let def = parse_ast_to_type_annotation_kinds( let def = parse_ast_to_type_annotation_kinds(
@ -120,7 +189,8 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
// list // list
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into()) matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into()) ||
matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "list")
} => } =>
{ {
let def_ann = parse_ast_to_type_annotation_kinds( let def_ann = parse_ast_to_type_annotation_kinds(
@ -137,7 +207,8 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
// tuple // tuple
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into()) matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into()) ||
matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "tuple")
} => } =>
{ {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node { if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
@ -163,71 +234,9 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
// custom class // custom class
ast::ExprKind::Subscript { value, slice, .. } => { ast::ExprKind::Subscript { value, slice, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node { if let ast::ExprKind::Name { id, .. } = &value.node {
if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()] class_name_handle(id, slice, unifier, locked)
.contains(id) } else if let ast::ExprKind::Constant { value: Str(id), .. } = &value.node {
{ class_name_handle(&id.clone().into(), slice, unifier, locked)
return Err("keywords cannot be class name".into());
}
let obj_id = resolver
.get_identifier_def(*id)
.ok_or_else(|| "unknown class name".to_string())?;
let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read();
if let Some(def_read) = def_read {
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone()
} else {
unreachable!("must be class here")
}
} else {
locked.get(&obj_id).unwrap().clone()
}
};
// we do not check whether the application of type variables are compatible here
let param_type_infos = {
let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.iter().collect_vec()
} else {
vec![slice.as_ref()]
};
if type_vars.len() != params_ast.len() {
return Err(format!(
"expect {} type parameters but got {}",
type_vars.len(),
params_ast.len()
));
}
let result = params_ast
.into_iter()
.map(|x| {
parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
x,
{
locked.insert(obj_id, type_vars.clone());
locked.clone()
},
)
})
.collect::<Result<Vec<_>, _>>()?;
// make sure the result do not contain any type vars
let no_type_var = result
.iter()
.all(|x| get_type_var_contained_in_type_annotation(x).is_empty());
if no_type_var {
result
} else {
return Err("application of type vars to generic class \
is not currently supported"
.into());
}
};
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
} else { } else {
Err("unsupported expression type for class name".into()) Err("unsupported expression type for class name".into())
} }
@ -280,9 +289,11 @@ pub fn get_type_from_type_annotation_kinds(
{ {
let ok: bool = { let ok: bool = {
// create a temp type var and unify to check compatibility // create a temp type var and unify to check compatibility
let temp = p == *tvar || {
unifier.get_fresh_var_with_range(range.borrow().as_slice()); let temp =
unifier.unify(temp.0, p).is_ok() unifier.get_fresh_var_with_range(range.borrow().as_slice());
unifier.unify(temp.0, p).is_ok()
}
}; };
if ok { if ok {
result.insert(*id, p); result.insert(*id, p);
@ -368,13 +379,7 @@ pub fn get_type_from_type_annotation_kinds(
/// But note that here we do not make a duplication of `T`, `V`, we direclty /// But note that here we do not make a duplication of `T`, `V`, we direclty
/// use them as they are in the TopLevelDef::Class since those in the /// use them as they are in the TopLevelDef::Class since those in the
/// TopLevelDef::Class.type_vars will be substitute later when seeing applications/instantiations /// TopLevelDef::Class.type_vars will be substitute later when seeing applications/instantiations
/// the Type of their fields and methods will also be subst when application/instantiation \ /// the Type of their fields and methods will also be subst when application/instantiation
/// \
/// Note this implicit self type is different with seeing `A[T, V]` explicitly outside
/// the class def ast body, where it is a new instantiation of the generic class `A`,
/// but equivalent to seeing `A[T, V]` inside the class def body ast, where although we
/// create copies of `T` and `V`, we will find them out as occured type vars in the analyze_class()
/// and unify them with the class generic `T`, `V`
pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation { pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation {
TypeAnnotation::CustomClass { TypeAnnotation::CustomClass {
id: object_id, id: object_id,

View File

@ -719,22 +719,18 @@ impl Unifier {
/// Returns Some(T) where T is the instantiated type. /// Returns Some(T) where T is the instantiated type.
/// Returns None if the function is already instantiated. /// Returns None if the function is already instantiated.
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
let mut instantiated = false; let mut instantiated = true;
let mut vars = Vec::new(); let mut vars = Vec::new();
for (k, v) in fun.vars.iter() { for (k, v) in fun.vars.iter() {
if let TypeEnum::TVar { id, range, .. } = if let TypeEnum::TVar { id, range, .. } =
self.unification_table.probe_value(*v).as_ref() self.unification_table.probe_value(*v).as_ref()
{ {
if k != id { // need to do this for partial instantiated function
instantiated = true; // (in class methods that contains type vars not in class)
break; if k == id {
instantiated = false;
vars.push((*k, range.clone()));
} }
// actually, if the first check succeeded, the function should be uninstatiated.
// The cloned values must be used and would not be wasted.
vars.push((*k, range.clone()));
} else {
instantiated = true;
break;
} }
} }
if instantiated { if instantiated {

View File

@ -4,7 +4,7 @@ use inkwell::{
OptimizationLevel, OptimizationLevel,
}; };
use nac3core::typecheck::type_inferencer::PrimitiveStore; use nac3core::typecheck::type_inferencer::PrimitiveStore;
use nac3parser::parser; use nac3parser::{ast::{ExprKind, StmtKind}, parser};
use std::env; use std::env;
use std::fs; use std::fs;
use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime};
@ -66,6 +66,46 @@ fn main() {
); );
for stmt in parser_result.into_iter() { for stmt in parser_result.into_iter() {
// handle type vars in toplevel
if let StmtKind::Assign { value, targets, .. } = &stmt.node {
assert_eq!(targets.len(), 1, "only support single assignment for now, at {}", targets[0].location);
if let ExprKind::Call { func, args, .. } = &value.node {
if matches!(&func.node, ExprKind::Name { id, .. } if id == &"TypeVar".into()) {
print!("registering typevar {:?}", targets[0].node);
let constraints = args
.iter()
.skip(1)
.map(|x| {
let def_list = &composer.extract_def_list();
let unifier = &mut composer.unifier;
resolver.parse_type_annotation(
def_list,
unifier,
&primitive,
x
).unwrap()
})
.collect::<Vec<_>>();
let res_ty = composer.unifier.get_fresh_var_with_range(&constraints).0;
println!(
" ...registered: {}",
composer.unifier.stringify(
res_ty,
&mut |x| format!("obj{}", x),
&mut |x| format!("tavr{}", x)
)
);
internal_resolver.add_id_type(
if let ExprKind::Name { id, .. } = &targets[0].node { *id } else {
panic!("must assign simple name variable as type variable for now")
},
res_ty
);
continue;
}
}
}
let (name, def_id, ty) = composer let (name, def_id, ty) = composer
.register_top_level(stmt, Some(resolver.clone()), "__main__".into()) .register_top_level(stmt, Some(resolver.clone()), "__main__".into())
.unwrap(); .unwrap();