polymorphism and inheritance related fixes #92

Closed
ychenfo wants to merge 7 commits from range_with_class into master
10 changed files with 782 additions and 388 deletions

View File

@ -8,18 +8,23 @@ import nac3artiq
__all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3", __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3",
"ms", "us", "ns", "ms", "us", "ns",
"Core", "TTLOut", "parallel", "sequential"] "Core", "TTLOut", "parallel", "sequential", "virtual"]
import device_db import device_db
core_arguments = device_db.device_db["core"]["arguments"] core_arguments = device_db.device_db["core"]["arguments"]
T = TypeVar('T')
# place the `virtual` class infront of the construct of NAC3 object to ensure the
# virtual class is known during the initializing of NAC3 object
class virtual(Generic[T]):
pass
compiler = nac3artiq.NAC3(core_arguments["target"]) compiler = nac3artiq.NAC3(core_arguments["target"])
allow_module_registration = True allow_module_registration = True
registered_modules = set() registered_modules = set()
nac3annotated_class_ids = set()
T = TypeVar('T')
class KernelInvariant(Generic[T]): class KernelInvariant(Generic[T]):
pass pass
@ -64,6 +69,7 @@ def nac3(cls):
All classes containing kernels or portable methods must use this decorator. All classes containing kernels or portable methods must use this decorator.
""" """
register_module_of(cls) register_module_of(cls)
nac3annotated_class_ids.add(id(cls))
return cls return cls
@ -106,7 +112,7 @@ class Core:
def run(self, method, *args, **kwargs): def run(self, method, *args, **kwargs):
global allow_module_registration global allow_module_registration
if allow_module_registration: if allow_module_registration:
compiler.analyze_modules(registered_modules) compiler.analyze_modules(registered_modules, nac3annotated_class_ids)
allow_module_registration = False allow_module_registration = False
if hasattr(method, "__self__"): if hasattr(method, "__self__"):

View File

@ -11,7 +11,7 @@ use inkwell::{
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::{exceptions, types::PyList, types::PySet, types::PyBytes}; use pyo3::{exceptions, types::PyList, types::PySet, types::PyBytes};
use nac3parser::{ use nac3parser::{
ast::{self, StrRef}, ast::{self, StrRef, Constant::Str},
parser::{self, parse_program}, parser::{self, parse_program},
}; };
@ -51,6 +51,10 @@ pub struct PrimitivePythonId {
bool: u64, bool: u64,
list: u64, list: u64,
tuple: u64, tuple: u64,
typevar: u64,
none: u64,
generic_alias: (u64, u64),
virtual_id: u64,
} }
// TopLevelComposer is unsendable as it holds the unification table, which is // TopLevelComposer is unsendable as it holds the unification table, which is
@ -72,7 +76,7 @@ struct Nac3 {
} }
impl Nac3 { impl Nac3 {
fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> { fn register_module_impl(&mut self, obj: PyObject, nac3_annotated_cls: &PySet) -> PyResult<()> {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new(); let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let obj: &PyAny = obj.extract(py)?; let obj: &PyAny = obj.extract(py)?;
@ -107,7 +111,7 @@ impl Nac3 {
global_value_ids: self.global_value_ids.clone(), global_value_ids: self.global_value_ids.clone(),
class_names: Default::default(), class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(), name_to_pyid: name_to_pyid.clone(),
module: obj, module: obj.clone(),
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
let mut name_to_def = HashMap::new(); let mut name_to_def = HashMap::new();
let mut name_to_type = HashMap::new(); let mut name_to_type = HashMap::new();
@ -117,6 +121,7 @@ impl Nac3 {
ast::StmtKind::ClassDef { ast::StmtKind::ClassDef {
ref decorator_list, ref decorator_list,
ref mut body, ref mut body,
ref mut bases,
.. ..
} => { } => {
let kernels = decorator_list.iter().any(|decorator| { let kernels = decorator_list.iter().any(|decorator| {
@ -142,6 +147,33 @@ impl Nac3 {
true true
} }
}); });
bases.retain(|b| {
Python::with_gil(|py| -> PyResult<bool> {
let obj: &PyAny = obj.extract(py)?;
let annot_check = |id: &str| -> bool {
let id = py.eval(
&format!("id({})", id),
Some(obj.getattr("__dict__").unwrap().extract().unwrap()),
None
).unwrap();
nac3_annotated_cls.contains(id).unwrap()
};
match &b.node {
ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string())),
ast::ExprKind::Constant { value: Str(id), .. } =>
Ok(annot_check(id.split('[').next().unwrap())),
ast::ExprKind::Subscript { value, .. } => {
match &value.node {
ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string()) || *id == "Generic".into()),
ast::ExprKind::Constant { value: Str(id), .. } =>
Ok(annot_check(id.split('[').next().unwrap())),
_ => unreachable!("unsupported base declaration")
}
}
_ => unreachable!("unsupported base declaration")
}
}).unwrap()
});
kernels kernels
} }
ast::StmtKind::FunctionDef { ast::StmtKind::FunctionDef {
@ -246,7 +278,42 @@ impl Nac3 {
let builtins_mod = PyModule::import(py, "builtins").unwrap(); let builtins_mod = PyModule::import(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap(); let numpy_mod = PyModule::import(py, "numpy").unwrap();
let typing_mod = PyModule::import(py, "typing").unwrap();
let types_mod = PyModule::import(py, "types").unwrap();
let primitive_ids = PrimitivePythonId { let primitive_ids = PrimitivePythonId {
virtual_id: id_fn
.call1((builtins_mod
.getattr("globals")
.unwrap()
.call0()
.unwrap()
.get_item("virtual")
.unwrap(),
)).unwrap()
.extract()
.unwrap(),
generic_alias: (
id_fn
.call1((typing_mod.getattr("_GenericAlias").unwrap(),))
.unwrap()
.extract()
.unwrap(),
id_fn
.call1((types_mod.getattr("GenericAlias").unwrap(),))
.unwrap()
.extract()
.unwrap(),
),
none: id_fn
.call1((builtins_mod.getattr("None").unwrap(),))
.unwrap()
.extract()
.unwrap(),
typevar: id_fn
.call1((typing_mod.getattr("TypeVar").unwrap(),))
.unwrap()
.extract()
.unwrap(),
int: id_fn int: id_fn
.call1((builtins_mod.getattr("int").unwrap(),)) .call1((builtins_mod.getattr("int").unwrap(),))
.unwrap() .unwrap()
@ -303,9 +370,9 @@ impl Nac3 {
}) })
} }
fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> { fn analyze_modules(&mut self, modules: &PySet, nac3_annotated_cls: &PySet) -> PyResult<()> {
for obj in modules.iter() { for obj in modules.iter() {
self.register_module_impl(obj.into())?; self.register_module_impl(obj.into(), nac3_annotated_cls)?;
} }
Ok(()) Ok(())
} }

View File

@ -40,6 +40,11 @@ struct PythonHelper<'a> {
type_fn: &'a PyAny, type_fn: &'a PyAny,
len_fn: &'a PyAny, len_fn: &'a PyAny,
id_fn: &'a PyAny, id_fn: &'a PyAny,
eval_type_fn: &'a PyAny,
origin_ty_fn: &'a PyAny,
args_ty_fn: &'a PyAny,
globals_dict: &'a PyAny,
print_fn: &'a PyAny,
} }
impl Resolver { impl Resolver {
@ -71,47 +76,51 @@ impl Resolver {
})) }))
} }
fn get_obj_type( // handle python objects that represent types themselves
// primitives and class types should be themselves, use `ty_id` to check,
// TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check
// the `bool` value returned indicates whether they are instantiated or not
fn get_pyty_obj_type(
&self, &self,
obj: &PyAny, pyty: &PyAny,
helper: &PythonHelper, helper: &PythonHelper,
unifier: &mut Unifier, unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>], defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
) -> PyResult<Option<Type>> { ) -> PyResult<Result<(Type, bool), String>> {
// eval_type use only globals_dict should be fine
let evaluated_ty = helper
.eval_type_fn
.call1((pyty, helper.globals_dict, helper.globals_dict)).unwrap();
let ty_id: u64 = helper let ty_id: u64 = helper
.id_fn .id_fn
.call1((helper.type_fn.call1((obj,))?,))? .call1((evaluated_ty,))?
.extract()?;
let ty_ty_id: u64 = helper
.id_fn
.call1((helper.type_fn.call1((evaluated_ty,))?,))?
.extract()?; .extract()?;
if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 {
Ok(Some(primitives.int32)) Ok(Ok((primitives.int32, true)))
} else if ty_id == self.primitive_ids.int64 { } else if ty_id == self.primitive_ids.int64 {
Ok(Some(primitives.int64)) Ok(Ok((primitives.int64, true)))
} else if ty_id == self.primitive_ids.bool { } else if ty_id == self.primitive_ids.bool {
Ok(Some(primitives.bool)) Ok(Ok((primitives.bool, true)))
} else if ty_id == self.primitive_ids.float { } else if ty_id == self.primitive_ids.float {
Ok(Some(primitives.float)) Ok(Ok((primitives.float, true)))
} else if ty_id == self.primitive_ids.list { } else if ty_id == self.primitive_ids.list {
let len: usize = helper.len_fn.call1((obj,))?.extract()?; // do not handle type var param and concrete check here
if len == 0 {
let var = unifier.get_fresh_var().0; let var = unifier.get_fresh_var().0;
let list = unifier.add_ty(TypeEnum::TList { ty: var }); let list = unifier.add_ty(TypeEnum::TList { ty: var });
Ok(Some(list)) Ok(Ok((list, false)))
} else {
let ty = self.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?;
Ok(ty.map(|ty| unifier.add_ty(TypeEnum::TList { ty })))
}
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let elements: &PyTuple = obj.cast_as()?; // do not handle type var param and concrete check here
let types: Result<Option<Vec<_>>, _> = elements Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
.iter() } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
.map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) // println!("getting def");
.collect();
let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id) {
let def = defs[def_id.0].read(); let def = defs[def_id.0].read();
// println!("got def");
if let TopLevelDef::Class { if let TopLevelDef::Class {
object_id, object_id,
type_vars, type_vars,
@ -120,35 +129,260 @@ impl Resolver {
.. ..
} = &*def } = &*def
{ {
let var_map: HashMap<_, _> = type_vars // do not handle type var param and concrete check here, and no subst
Ok(Ok({
let ty = TypeEnum::TObj {
obj_id: *object_id,
params: RefCell::new({
type_vars
.iter() .iter()
.map(|var| { .map(|x| {
( if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { (*id, *x)
*id } else { unreachable!() }
}).collect()
}),
fields: RefCell::new({
let mut res = methods
.iter()
.map(|(iden, ty, _)| (*iden, (*ty, false)))
.collect::<HashMap<_, _>>();
res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2))));
res
})
};
// here also false, later insta use python object to check compatible
(unifier.add_ty(ty), false)
}))
} else {
// only object is supported, functions are not supported
unreachable!("function type is not supported, should not be queried")
}
} else if ty_ty_id == self.primitive_ids.typevar {
let constraint_types = {
let constraints = pyty.getattr("__constraints__").unwrap();
let mut result: Vec<Type> = vec![];
for i in 0.. {
if let Ok(constr) = constraints.get_item(i) {
result.push({
match self.get_pyty_obj_type(constr, helper, unifier, defs, primitives)? {
Ok((ty, _)) => {
if unifier.is_concrete(ty, &[]) {
ty
} else {
return Ok(Err(format!(
"the {}th constraint of TypeVar `{}` is not concrete",
i + 1,
pyty.getattr("__name__")?.extract::<String>()?
)))
}
},
Err(err) => return Ok(Err(err))
}
})
} else {
break;
}
}
result
};
let res = unifier.get_fresh_var_with_range(&constraint_types).0;
Ok(Ok((res, true)))
} else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 {
let origin = helper.origin_ty_fn.call1((evaluated_ty,))?;
let args: &PyTuple = helper.args_ty_fn.call1((evaluated_ty,))?.cast_as()?;
let origin_ty = match self.get_pyty_obj_type(origin, helper, unifier, defs, primitives)? {
Ok((ty, false)) => ty,
Ok((_, true)) => return Ok(Err("instantiated type does not take type parameters".into())),
Err(err) => return Ok(Err(err))
};
match &*unifier.get_ty(origin_ty) {
TypeEnum::TList { .. } => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err))
};
if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
panic!("type list should take concrete parameters in type var ranges")
}
Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true)))
} else {
return Ok(Err(format!("type list needs exactly 1 type parameters, found {}", args.len())))
}
},
TypeEnum::TTuple { .. } => {
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
Ok(args) if !args.is_empty() => args
.into_iter()
.map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check {
panic!("type tuple should take concrete parameters in type var ranges")
} else {
x
}
)
.collect::<Vec<_>>(),
Err(err) => return Ok(Err(err)),
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
};
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true)))
},
TypeEnum::TObj { params, obj_id, .. } => {
let subst = {
let params = &*params.borrow();
if params.len() != args.len() {
return Ok(Err(format!(
"for class #{}, expect {} type parameters, got {}.",
obj_id.0,
params.len(),
args.len(),
)))
}
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
Ok(args) => args
.into_iter()
.map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check {
panic!("type class should take concrete parameters in type var ranges")
} else {
x
}
)
.collect::<Vec<_>>(),
Err(err) => return Ok(Err(err)),
};
params
.iter()
.zip(args.iter())
.map(|((id, _), ty)| (*id, *ty))
.collect::<HashMap<_, _>>()
};
Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true)))
},
TypeEnum::TVirtual { .. } => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err))
};
if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
panic!("virtual class should take concrete parameters in type var ranges")
}
Ok(Ok((unifier.add_ty(TypeEnum::TVirtual { ty: ty.0 }), true)))
} else {
return Ok(Err(format!("virtual class needs exactly 1 type parameters, found {}", args.len())))
}
}
_ => unimplemented!()
}
} else if ty_id == self.primitive_ids.virtual_id {
Ok(Ok(({
let ty = TypeEnum::TVirtual { ty: unifier.get_fresh_var().0 };
unifier.add_ty(ty)
}, false)))
} else {
Ok(Err("unknown type".into()))
}
}
fn get_obj_type(
&self,
obj: &PyAny,
helper: &PythonHelper,
unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
) -> PyResult<Option<Type>> {
let (extracted_ty, inst_check) = match self.get_pyty_obj_type(
{
let ty = helper.type_fn.call1((obj,)).unwrap();
if [self.primitive_ids.typevar,
self.primitive_ids.generic_alias.0,
self.primitive_ids.generic_alias.1
].contains(&helper.id_fn.call1((ty,))?.extract::<u64>()?) {
obj
} else {
ty
}
},
helper,
unifier,
defs,
primitives
)? {
Ok(s) => s,
Err(_) => return Ok(None)
};
return match (&*unifier.get_ty(extracted_ty), inst_check) {
// do the instantiation for these three types
(TypeEnum::TList { ty }, false) => {
let len: usize = helper.len_fn.call1((obj,))?.extract()?;
if len == 0 {
assert!(matches!(
&*unifier.get_ty(extracted_ty),
TypeEnum::TVar { meta: nac3core::typecheck::typedef::TypeVarMeta::Generic, range, .. }
if range.borrow().is_empty()
));
Ok(Some(extracted_ty))
} else {
let actual_ty = self
.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?;
if let Some(actual_ty) = actual_ty {
unifier.unify(*ty, actual_ty).unwrap();
Ok(Some(extracted_ty))
} else {
Ok(None)
}
}
}
(TypeEnum::TTuple { .. }, false) => {
let elements: &PyTuple = obj.cast_as()?;
let types: Result<Option<Vec<_>>, _> = elements
.iter()
.map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives))
.collect();
let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
}
(TypeEnum::TObj { params, fields, .. }, false) => {
let var_map = params
.borrow()
.iter()
.map(|(id_var, ty)| {
if let TypeEnum::TVar { id, range, .. } = &*unifier.get_ty(*ty) {
assert_eq!(*id, *id_var);
(*id, unifier.get_fresh_var_with_range(&range.borrow()).0)
} else { } else {
unreachable!() unreachable!()
},
unifier.get_fresh_var().0,
)
})
.collect();
let mut fields_ty = HashMap::new();
for method in methods.iter() {
fields_ty.insert(method.0, (method.1, false));
} }
for field in fields.iter() { })
let name: String = field.0.into(); .collect::<HashMap<_, _>>();
// loop through non-function fields of the class to get the instantiated value
for field in fields.borrow().iter() {
let name: String = (*field.0).into();
if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(field.1.0) {
continue;
} else {
let field_data = obj.getattr(&name)?; let field_data = obj.getattr(&name)?;
let ty = self let ty = self
.get_obj_type(field_data, helper, unifier, defs, primitives)? .get_obj_type(field_data, helper, unifier, defs, primitives)?
.unwrap_or(primitives.none); .unwrap_or(primitives.none);
let field_ty = unifier.subst(field.1, &var_map).unwrap_or(field.1); let field_ty = unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0);
if unifier.unify(ty, field_ty).is_err() { if unifier.unify(ty, field_ty).is_err() {
// field type mismatch // field type mismatch
return Ok(None); return Ok(None);
} }
fields_ty.insert(field.0, (ty, field.2)); }
} }
for (_, ty) in var_map.iter() { for (_, ty) in var_map.iter() {
// must be concrete type // must be concrete type
@ -156,18 +390,10 @@ impl Resolver {
return Ok(None) return Ok(None)
} }
} }
Ok(Some(unifier.add_ty(TypeEnum::TObj { return Ok(Some(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty)));
obj_id: *object_id,
fields: RefCell::new(fields_ty),
params: RefCell::new(var_map),
})))
} else {
// only object is supported, functions are not supported
Ok(None)
}
} else {
Ok(None)
} }
_ => Ok(Some(extracted_ty))
};
} }
fn get_obj_value<'ctx, 'a>( fn get_obj_value<'ctx, 'a>(
@ -425,10 +651,16 @@ impl SymbolResolver for Resolver {
let key: &str = member.get_item(0)?.extract()?; let key: &str = member.get_item(0)?.extract()?;
if key == str.to_string() { if key == str.to_string() {
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let helper = PythonHelper { let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap(), id_fn: builtins.getattr("id").unwrap(),
len_fn: builtins.getattr("len").unwrap(), len_fn: builtins.getattr("len").unwrap(),
type_fn: builtins.getattr("type").unwrap(), type_fn: builtins.getattr("type").unwrap(),
origin_ty_fn: typings.getattr("get_origin").unwrap(),
args_ty_fn: typings.getattr("get_args").unwrap(),
globals_dict: obj.getattr("__dict__").unwrap(),
eval_type_fn: typings.getattr("_eval_type").unwrap(),
print_fn: builtins.getattr("print").unwrap(),
}; };
sym_ty = self.get_obj_type( sym_ty = self.get_obj_type(
member.get_item(1)?, member.get_item(1)?,
@ -469,10 +701,16 @@ impl SymbolResolver for Resolver {
let val = member.get_item(1)?; let val = member.get_item(1)?;
if key == id.to_string() { if key == id.to_string() {
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let helper = PythonHelper { let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap(), id_fn: builtins.getattr("id").unwrap(),
len_fn: builtins.getattr("len").unwrap(), len_fn: builtins.getattr("len").unwrap(),
type_fn: builtins.getattr("type").unwrap(), type_fn: builtins.getattr("type").unwrap(),
origin_ty_fn: typings.getattr("get_origin").unwrap(),
args_ty_fn: typings.getattr("get_args").unwrap(),
globals_dict: obj.getattr("__dict__").unwrap(),
eval_type_fn: typings.getattr("_eval_type").unwrap(),
print_fn: builtins.getattr("print").unwrap(),
}; };
sym_value = self.get_obj_value(val, &helper, ctx)?; sym_value = self.get_obj_value(val, &helper, ctx)?;
break; break;

View File

@ -147,8 +147,14 @@ impl ConcreteTypeStore {
fields: fields fields: fields
.borrow() .borrow()
.iter() .iter()
.map(|(name, ty)| { .filter_map(|(name, ty)| {
(*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1)) // filter out functions as they can have type vars and
// will not affect codegen
if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(ty.0) {
None
} else {
Some((*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1)))
}
}) })
.collect(), .collect(),
params: params params: params

View File

@ -7,7 +7,7 @@ use crate::{
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum}, typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
}; };
use inkwell::{ use inkwell::{
types::{BasicType, BasicTypeEnum}, types::{BasicType, BasicTypeEnum},
@ -21,6 +21,31 @@ use nac3parser::ast::{
use super::CodeGenerator; use super::CodeGenerator;
pub fn get_subst_key(
unifier: &mut Unifier,
obj: Option<Type>,
fun_vars: &HashMap<u32, Type>,
filter: Option<&Vec<u32>>,
) -> String {
let mut vars = obj
.map(|ty| {
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) {
params.borrow().clone()
} else {
unreachable!()
}
})
.unwrap_or_default();
vars.extend(fun_vars.iter());
let sorted =
vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted();
sorted
.map(|id| {
unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
})
.join(", ")
}
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
pub fn build_gep_and_load( pub fn build_gep_and_load(
&mut self, &mut self,
@ -36,23 +61,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fun: &FunSignature, fun: &FunSignature,
filter: Option<&Vec<u32>>, filter: Option<&Vec<u32>>,
) -> String { ) -> String {
let mut vars = obj get_subst_key(&mut self.unifier, obj, &fun.vars, filter)
.map(|ty| {
if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
params.borrow().clone()
} else {
unreachable!()
}
})
.unwrap_or_default();
vars.extend(fun.vars.iter());
let sorted =
vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted();
sorted
.map(|id| {
self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
})
.join(", ")
} }
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize { pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize {

View File

@ -13,7 +13,7 @@ use crate::{
use crate::{location::Location, typecheck::typedef::TypeEnum}; use crate::{location::Location, typecheck::typedef::TypeEnum};
use inkwell::values::BasicValueEnum; use inkwell::values::BasicValueEnum;
use itertools::{chain, izip}; use itertools::{chain, izip};
use nac3parser::ast::{Expr, StrRef}; use nac3parser::ast::{Constant::Str, Expr, StrRef};
use parking_lot::RwLock; use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
@ -79,8 +79,7 @@ pub fn parse_type_annotation<T>(
let list_id = ids[6]; let list_id = ids[6];
let tuple_id = ids[7]; let tuple_id = ids[7];
match &expr.node { let name_handling = |id: &StrRef, unifier: &mut Unifier| {
Name { id, .. } => {
if *id == int32_id { if *id == int32_id {
Ok(primitives.int32) Ok(primitives.int32)
} else if *id == int64_id { } else if *id == int64_id {
@ -129,9 +128,9 @@ pub fn parse_type_annotation<T>(
} }
} }
} }
} };
Subscript { value, slice, .. } => {
if let Name { id, .. } = &value.node { let subscript_name_handle = |id: &StrRef, slice: &Expr<T>, unifier: &mut Unifier| {
if *id == virtual_id { if *id == virtual_id {
let ty = parse_type_annotation( let ty = parse_type_annotation(
resolver, resolver,
@ -232,6 +231,16 @@ pub fn parse_type_annotation<T>(
Err("Cannot use function name as type".into()) Err("Cannot use function name as type".into())
} }
} }
};
match &expr.node {
Name { id, .. } => name_handling(id, unifier),
Constant { value: Str(id), .. } => name_handling(&id.clone().into(), unifier),
Subscript { value, slice, .. } => {
if let Name { id, .. } = &value.node {
subscript_name_handle(id, slice, unifier)
} else if let Constant { value: Str(id), .. } = &value.node {
subscript_name_handle(&id.clone().into(), slice, unifier)
} else { } else {
Err("unsupported type expression".into()) Err("unsupported type expression".into())
} }

View File

@ -6,6 +6,7 @@ use inkwell::FloatPredicate;
use crate::{ use crate::{
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
typecheck::type_inferencer::{FunctionData, Inferencer}, typecheck::type_inferencer::{FunctionData, Inferencer},
codegen::expr::get_subst_key,
}; };
use super::*; use super::*;
@ -534,7 +535,7 @@ impl TopLevelComposer {
} }
} }
fn extract_def_list(&self) -> Vec<Arc<RwLock<TopLevelDef>>> { pub fn extract_def_list(&self) -> Vec<Arc<RwLock<TopLevelDef>>> {
self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec()
} }
@ -1654,7 +1655,7 @@ impl TopLevelComposer {
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
let FunSignature { args, ret, vars } = &*func_sig.borrow(); let FunSignature { args, ret, vars } = &*func_sig.borrow();
// None if is not class method // None if is not class method
let self_type = { let uninst_self_type = {
if let Some(class_id) = self.method_class.get(&DefinitionId(id)) { if let Some(class_id) = self.method_class.get(&DefinitionId(id)) {
let class_def = self.definition_ast_list.get(class_id.0).unwrap(); let class_def = self.definition_ast_list.get(class_id.0).unwrap();
let class_def = class_def.0.read(); let class_def = class_def.0.read();
@ -1666,7 +1667,7 @@ impl TopLevelComposer {
&self.primitives_ty, &self.primitives_ty,
&ty_ann, &ty_ann,
)?; )?;
Some(self_ty) Some((self_ty, type_vars.clone()))
} else { } else {
unreachable!("must be class def") unreachable!("must be class def")
} }
@ -1717,9 +1718,34 @@ impl TopLevelComposer {
}; };
let self_type = { let self_type = {
let unifier = &mut self.unifier; let unifier = &mut self.unifier;
self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x)) uninst_self_type
.clone()
.map(|(self_type, type_vars)| {
let subst_for_self = {
let class_ty_var_ids = type_vars
.iter()
.map(|x| {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
*id
} else {
unreachable!("must be type var here");
}
})
.collect::<HashSet<_>>();
subst
.iter()
.filter_map(|(ty_var_id, ty_var_target)| {
if class_ty_var_ids.contains(ty_var_id) {
Some((*ty_var_id, *ty_var_target))
} else {
None
}
})
.collect::<HashMap<_, _>>()
};
unifier.subst(self_type, &subst_for_self).unwrap_or(self_type)
})
}; };
let mut identifiers = { let mut identifiers = {
// NOTE: none and function args? // NOTE: none and function args?
let mut result: HashSet<_> = HashSet::new(); let mut result: HashSet<_> = HashSet::new();
@ -1810,21 +1836,13 @@ impl TopLevelComposer {
instance_to_stmt.insert( instance_to_stmt.insert(
// NOTE: refer to codegen/expr/get_subst_key function // NOTE: refer to codegen/expr/get_subst_key function
{
let unifier = &mut self.unifier; get_subst_key(
subst &mut self.unifier,
.keys() self_type,
.sorted() &subst,
.map(|id| { None
let ty = subst.get(id).unwrap(); ),
unifier.stringify(
*ty,
&mut |id| id.to_string(),
&mut |id| id.to_string(),
)
})
.join(", ")
},
FunInstance { FunInstance {
body: Arc::new(fun_body), body: Arc::new(fun_body),
unifier_id: 0, unifier_id: 0,

View File

@ -1,7 +1,7 @@
use std::cell::RefCell; use std::cell::RefCell;
use crate::typecheck::typedef::TypeVarMeta; use crate::typecheck::typedef::TypeVarMeta;
use ast::Constant::Str;
use super::*; use super::*;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -49,10 +49,9 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
expr: &ast::Expr<T>, expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here // the key stores the type_var of this topleveldef::class, we only need this field here
mut locked: HashMap<DefinitionId, Vec<Type>>, locked: HashMap<DefinitionId, Vec<Type>>,
) -> Result<TypeAnnotation, String> { ) -> Result<TypeAnnotation, String> {
match &expr.node { let name_handle = |id: &StrRef, unifier: &mut Unifier, locked: HashMap<DefinitionId, Vec<Type>>| {
ast::ExprKind::Name { id, .. } => {
if id == &"int32".into() { if id == &"int32".into() {
Ok(TypeAnnotation::Primitive(primitives.int32)) Ok(TypeAnnotation::Primitive(primitives.int32))
} else if id == &"int64".into() { } else if id == &"int64".into() {
@ -95,74 +94,10 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} else { } else {
Err("name cannot be parsed as a type annotation".into()) Err("name cannot be parsed as a type annotation".into())
} }
} };
// virtual let class_name_handle =
ast::ExprKind::Subscript { value, slice, .. } |id: &StrRef, slice: &ast::Expr<T>, unifier: &mut Unifier, mut locked: HashMap<DefinitionId, Vec<Type>>| {
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into())
} =>
{
let def = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
if !matches!(def, TypeAnnotation::CustomClass { .. }) {
unreachable!("must be concretized custom class kind in the virtual")
}
Ok(TypeAnnotation::Virtual(def.into()))
}
// list
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into())
} =>
{
let def_ann = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
Ok(TypeAnnotation::List(def_ann.into()))
}
// tuple
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into())
} =>
{
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
let type_annotations = elts
.iter()
.map(|e| {
parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(TypeAnnotation::Tuple(type_annotations))
} else {
Err("Expect multiple elements for tuple".into())
}
}
// custom class
ast::ExprKind::Subscript { value, slice, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node {
if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()] if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()]
.contains(id) .contains(id)
{ {
@ -188,7 +123,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.iter().collect_vec() elts.iter().collect_vec()
} else { } else {
vec![slice.as_ref()] vec![slice]
}; };
if type_vars.len() != params_ast.len() { if type_vars.len() != params_ast.len() {
return Err(format!( return Err(format!(
@ -213,7 +148,6 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
// make sure the result do not contain any type vars // make sure the result do not contain any type vars
let no_type_var = result let no_type_var = result
.iter() .iter()
@ -226,8 +160,83 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
.into()); .into());
} }
}; };
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
};
match &expr.node {
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
ast::ExprKind::Constant { value: Str(id), .. } => name_handle(&id.clone().into(), unifier, locked),
// virtual
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into()) ||
matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "virtual")
} =>
{
let def = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
if !matches!(def, TypeAnnotation::CustomClass { .. }) {
unreachable!("must be concretized custom class kind in the virtual")
}
Ok(TypeAnnotation::Virtual(def.into()))
}
// list
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into()) ||
matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "list")
} =>
{
let def_ann = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
Ok(TypeAnnotation::List(def_ann.into()))
}
// tuple
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into()) ||
matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "tuple")
} =>
{
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
let type_annotations = elts
.iter()
.map(|e| {
parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(TypeAnnotation::Tuple(type_annotations))
} else {
Err("Expect multiple elements for tuple".into())
}
}
// custom class
ast::ExprKind::Subscript { value, slice, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node {
class_name_handle(id, slice, unifier, locked)
} else if let ast::ExprKind::Constant { value: Str(id), .. } = &value.node {
class_name_handle(&id.clone().into(), slice, unifier, locked)
} else { } else {
Err("unsupported expression type for class name".into()) Err("unsupported expression type for class name".into())
} }
@ -280,9 +289,11 @@ pub fn get_type_from_type_annotation_kinds(
{ {
let ok: bool = { let ok: bool = {
// create a temp type var and unify to check compatibility // create a temp type var and unify to check compatibility
p == *tvar || {
let temp = let temp =
unifier.get_fresh_var_with_range(range.borrow().as_slice()); unifier.get_fresh_var_with_range(range.borrow().as_slice());
unifier.unify(temp.0, p).is_ok() unifier.unify(temp.0, p).is_ok()
}
}; };
if ok { if ok {
result.insert(*id, p); result.insert(*id, p);
@ -368,13 +379,7 @@ pub fn get_type_from_type_annotation_kinds(
/// But note that here we do not make a duplication of `T`, `V`, we direclty /// But note that here we do not make a duplication of `T`, `V`, we direclty
/// use them as they are in the TopLevelDef::Class since those in the /// use them as they are in the TopLevelDef::Class since those in the
/// TopLevelDef::Class.type_vars will be substitute later when seeing applications/instantiations /// TopLevelDef::Class.type_vars will be substitute later when seeing applications/instantiations
/// the Type of their fields and methods will also be subst when application/instantiation \ /// the Type of their fields and methods will also be subst when application/instantiation
/// \
/// Note this implicit self type is different with seeing `A[T, V]` explicitly outside
/// the class def ast body, where it is a new instantiation of the generic class `A`,
/// but equivalent to seeing `A[T, V]` inside the class def body ast, where although we
/// create copies of `T` and `V`, we will find them out as occured type vars in the analyze_class()
/// and unify them with the class generic `T`, `V`
pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation { pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation {
TypeAnnotation::CustomClass { TypeAnnotation::CustomClass {
id: object_id, id: object_id,

View File

@ -719,22 +719,18 @@ impl Unifier {
/// Returns Some(T) where T is the instantiated type. /// Returns Some(T) where T is the instantiated type.
/// Returns None if the function is already instantiated. /// Returns None if the function is already instantiated.
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
let mut instantiated = false; let mut instantiated = true;
let mut vars = Vec::new(); let mut vars = Vec::new();
for (k, v) in fun.vars.iter() { for (k, v) in fun.vars.iter() {
if let TypeEnum::TVar { id, range, .. } = if let TypeEnum::TVar { id, range, .. } =
self.unification_table.probe_value(*v).as_ref() self.unification_table.probe_value(*v).as_ref()
{ {
if k != id { // need to do this for partial instantiated function
instantiated = true; // (in class methods that contains type vars not in class)
break; if k == id {
} instantiated = false;
// actually, if the first check succeeded, the function should be uninstatiated.
// The cloned values must be used and would not be wasted.
vars.push((*k, range.clone())); vars.push((*k, range.clone()));
} else { }
instantiated = true;
break;
} }
} }
if instantiated { if instantiated {

View File

@ -4,7 +4,7 @@ use inkwell::{
OptimizationLevel, OptimizationLevel,
}; };
use nac3core::typecheck::type_inferencer::PrimitiveStore; use nac3core::typecheck::type_inferencer::PrimitiveStore;
use nac3parser::parser; use nac3parser::{ast::{ExprKind, StmtKind}, parser};
use std::env; use std::env;
use std::fs; use std::fs;
use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime};
@ -66,6 +66,46 @@ fn main() {
); );
for stmt in parser_result.into_iter() { for stmt in parser_result.into_iter() {
// handle type vars in toplevel
if let StmtKind::Assign { value, targets, .. } = &stmt.node {
assert_eq!(targets.len(), 1, "only support single assignment for now, at {}", targets[0].location);
if let ExprKind::Call { func, args, .. } = &value.node {
if matches!(&func.node, ExprKind::Name { id, .. } if id == &"TypeVar".into()) {
print!("registering typevar {:?}", targets[0].node);
let constraints = args
.iter()
.skip(1)
.map(|x| {
let def_list = &composer.extract_def_list();
let unifier = &mut composer.unifier;
resolver.parse_type_annotation(
def_list,
unifier,
&primitive,
x
).unwrap()
})
.collect::<Vec<_>>();
let res_ty = composer.unifier.get_fresh_var_with_range(&constraints).0;
println!(
" ...registered: {}",
composer.unifier.stringify(
res_ty,
&mut |x| format!("obj{}", x),
&mut |x| format!("tavr{}", x)
)
);
internal_resolver.add_id_type(
if let ExprKind::Name { id, .. } = &targets[0].node { *id } else {
panic!("must assign simple name variable as type variable for now")
},
res_ty
);
continue;
}
}
}
let (name, def_id, ty) = composer let (name, def_id, ty) = composer
.register_top_level(stmt, Some(resolver.clone()), "__main__".into()) .register_top_level(stmt, Some(resolver.clone()), "__main__".into())
.unwrap(); .unwrap();