TypeVar and virtual support in Symbol Resolver #99

Merged
sb10q merged 12 commits from symbol_resolver_typevar into master 2021-12-01 22:44:53 +08:00
3 changed files with 343 additions and 66 deletions
Showing only changes of commit a3faa9b7dd - Show all commits

View File

@ -9,17 +9,23 @@ import nac3artiq
__all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3",
"ms", "us", "ns",
"print_int32", "print_int64",
"Core", "TTLOut", "parallel", "sequential"]
"Core", "TTLOut", "parallel", "sequential", "virtual"]
T = TypeVar('T')
class KernelInvariant(Generic[T]):
pass
# place the `virtual` class infront of the construct of NAC3 object to ensure the
# virtual class is known during the initializing of NAC3 object
Review

Unclear - do you have an example?

Unclear - do you have an example?
Review

Sorry for the unclear comment. Here I mean that before this line in min_artiq.py: compiler = nac3artiq.NAC3(core_arguments["target"]), the class virtual should already be known to CPython (because the construction of compiler needs to read get the id of it)

if we move this class definition below:

compiler = nac3artiq.NAC3(core_arguments["target"])
class virtual(Generic[T]):
    pass

and run the compiler, we will have:

PyErr { type: <class 'KeyError'>, value: KeyError('virtual'), traceback: None }
Sorry for the unclear comment. Here I mean that before this line in `min_artiq.py`: `compiler = nac3artiq.NAC3(core_arguments["target"])`, the class `virtual` should already be known to CPython (because the construction of `compiler` needs to read get the id of it) if we move this class definition below: ```python compiler = nac3artiq.NAC3(core_arguments["target"]) class virtual(Generic[T]): pass ``` and run the compiler, we will have: ``` PyErr { type: <class 'KeyError'>, value: KeyError('virtual'), traceback: None } ```
Review

By the way, why is this different than the KernelInvariant class ?

By the way, why is this different than the KernelInvariant class ?
Review

As I understand it, here during the initialization of the nac3 python object it needs to know the id of class virtual becuase virtual[..] is actually a type that can exist in the range of a typevar, and when resolving typevars, nac3artiq needs the id of it to know that it is handling a virtual.

While KernelInvariant in min_artiq.py is just for cpython to not complain about unknown name and we do not need to know its id because so far it only occurs in class fields definition and nac3 do not need to know any cpython related information about it since we are directly looking into the string in the ast.

As I understand it, here during the initialization of the nac3 python object it needs to know the id of class `virtual` becuase `virtual[..]` is actually a type that can exist in the range of a typevar, and when resolving typevars, nac3artiq needs the id of it to know that it is handling a `virtual`. While KernelInvariant in `min_artiq.py` is just for cpython to not complain about unknown name and we do not need to know its id because so far it only occurs in class fields definition and nac3 do not need to know any cpython related information about it since we are directly looking into the string in the ast.
class virtual(Generic[T]):
pass
import device_db
core_arguments = device_db.device_db["core"]["arguments"]
compiler = nac3artiq.NAC3(core_arguments["target"])
allow_registration = True
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.

View File

@ -52,6 +52,10 @@ pub struct PrimitivePythonId {
bool: u64,
list: u64,
tuple: u64,
typevar: u64,
none: u64,
generic_alias: (u64, u64),
virtual_id: u64,
}
// TopLevelComposer is unsendable as it holds the unification table, which is
@ -267,7 +271,36 @@ impl Nac3 {
let builtins_mod = PyModule::import(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap();
let typing_mod = PyModule::import(py, "typing").unwrap();
let types_mod = PyModule::import(py, "types").unwrap();
let primitive_ids = PrimitivePythonId {
virtual_id: py.eval(
"id(virtual)",
Some(builtins_mod.getattr("globals").unwrap().call0().unwrap().extract().unwrap()),
None
).unwrap().extract().unwrap(),
generic_alias: (
id_fn
.call1((typing_mod.getattr("_GenericAlias").unwrap(),))
.unwrap()
.extract()
.unwrap(),
id_fn
.call1((types_mod.getattr("GenericAlias").unwrap(),))
.unwrap()
.extract()
.unwrap(),
),
none: id_fn
.call1((builtins_mod.getattr("None").unwrap(),))
.unwrap()
.extract()
.unwrap(),
typevar: id_fn
.call1((typing_mod.getattr("TypeVar").unwrap(),))
.unwrap()
.extract()
.unwrap(),
int: id_fn
.call1((builtins_mod.getattr("int").unwrap(),))
.unwrap()

View File

@ -40,6 +40,11 @@ struct PythonHelper<'a> {
type_fn: &'a PyAny,
len_fn: &'a PyAny,
id_fn: &'a PyAny,
eval_type_fn: &'a PyAny,
origin_ty_fn: &'a PyAny,
args_ty_fn: &'a PyAny,
globals_dict: &'a PyAny,
print_fn: &'a PyAny,
}
impl Resolver {
@ -71,47 +76,51 @@ impl Resolver {
}))
}
fn get_obj_type(
// handle python objects that represent types themselves
// primitives and class types should be themselves, use `ty_id` to check,
// TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check
// the `bool` value returned indicates whether they are instantiated or not
fn get_pyty_obj_type(
&self,
obj: &PyAny,
pyty: &PyAny,
helper: &PythonHelper,
unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
) -> PyResult<Option<Type>> {
) -> PyResult<Result<(Type, bool), String>> {
// eval_type use only globals_dict should be fine
let evaluated_ty = helper
.eval_type_fn
.call1((pyty, helper.globals_dict, helper.globals_dict)).unwrap();
let ty_id: u64 = helper
.id_fn
.call1((helper.type_fn.call1((obj,))?,))?
.call1((evaluated_ty,))?
.extract()?;
let ty_ty_id: u64 = helper
.id_fn
.call1((helper.type_fn.call1((evaluated_ty,))?,))?
.extract()?;
if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 {
Ok(Some(primitives.int32))
Ok(Ok((primitives.int32, true)))
} else if ty_id == self.primitive_ids.int64 {
Ok(Some(primitives.int64))
Ok(Ok((primitives.int64, true)))
} else if ty_id == self.primitive_ids.bool {
Ok(Some(primitives.bool))
Ok(Ok((primitives.bool, true)))
} else if ty_id == self.primitive_ids.float {
Ok(Some(primitives.float))
Ok(Ok((primitives.float, true)))
} else if ty_id == self.primitive_ids.list {
let len: usize = helper.len_fn.call1((obj,))?.extract()?;
if len == 0 {
let var = unifier.get_fresh_var().0;
let list = unifier.add_ty(TypeEnum::TList { ty: var });
Ok(Some(list))
} else {
let ty = self.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?;
Ok(ty.map(|ty| unifier.add_ty(TypeEnum::TList { ty })))
}
// do not handle type var param and concrete check here
let var = unifier.get_fresh_var().0;
Outdated
Review

remove println

remove println
let list = unifier.add_ty(TypeEnum::TList { ty: var });
Ok(Ok((list, false)))
Outdated
Review

same

same
} else if ty_id == self.primitive_ids.tuple {
let elements: &PyTuple = obj.cast_as()?;
let types: Result<Option<Vec<_>>, _> = elements
.iter()
.map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives))
.collect();
let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id) {
// do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
// println!("getting def");
let def = defs[def_id.0].read();
// println!("got def");
if let TopLevelDef::Class {
object_id,
type_vars,
@ -120,35 +129,260 @@ impl Resolver {
..
} = &*def
{
let var_map: HashMap<_, _> = type_vars
.iter()
.map(|var| {
(
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
} else {
unreachable!()
},
unifier.get_fresh_var().0,
)
})
.collect();
let mut fields_ty = HashMap::new();
for method in methods.iter() {
fields_ty.insert(method.0, (method.1, false));
}
for field in fields.iter() {
let name: String = field.0.into();
let field_data = obj.getattr(&name)?;
let ty = self
.get_obj_type(field_data, helper, unifier, defs, primitives)?
.unwrap_or(primitives.none);
let field_ty = unifier.subst(field.1, &var_map).unwrap_or(field.1);
if unifier.unify(ty, field_ty).is_err() {
// field type mismatch
return Ok(None);
// do not handle type var param and concrete check here, and no subst
Ok(Ok({
let ty = TypeEnum::TObj {
obj_id: *object_id,
params: RefCell::new({
type_vars
.iter()
.map(|x| {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
(*id, *x)
} else { unreachable!() }
}).collect()
}),
fields: RefCell::new({
let mut res = methods
.iter()
.map(|(iden, ty, _)| (*iden, (*ty, false)))
.collect::<HashMap<_, _>>();
res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2))));
res
})
};
// here also false, later insta use python object to check compatible
(unifier.add_ty(ty), false)
}))
} else {
// only object is supported, functions are not supported
unreachable!("function type is not supported, should not be queried")
}
} else if ty_ty_id == self.primitive_ids.typevar {
let constraint_types = {
let constraints = pyty.getattr("__constraints__").unwrap();
let mut result: Vec<Type> = vec![];
for i in 0.. {
if let Ok(constr) = constraints.get_item(i) {
result.push({
match self.get_pyty_obj_type(constr, helper, unifier, defs, primitives)? {
Ok((ty, _)) => {
if unifier.is_concrete(ty, &[]) {
ty
} else {
return Ok(Err(format!(
"the {}th constraint of TypeVar `{}` is not concrete",
i + 1,
pyty.getattr("__name__")?.extract::<String>()?
)))
}
},
Err(err) => return Ok(Err(err))
}
})
} else {
break;
}
}
result
};
let res = unifier.get_fresh_var_with_range(&constraint_types).0;
Ok(Ok((res, true)))
} else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 {
let origin = helper.origin_ty_fn.call1((evaluated_ty,))?;
let args: &PyTuple = helper.args_ty_fn.call1((evaluated_ty,))?.cast_as()?;
let origin_ty = match self.get_pyty_obj_type(origin, helper, unifier, defs, primitives)? {
Ok((ty, false)) => ty,
Ok((_, true)) => return Ok(Err("instantiated type does not take type parameters".into())),
Err(err) => return Ok(Err(err))
};
match &*unifier.get_ty(origin_ty) {
TypeEnum::TList { .. } => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err))
};
if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
panic!("type list should take concrete parameters in type var ranges")
}
Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true)))
} else {
return Ok(Err(format!("type list needs exactly 1 type parameters, found {}", args.len())))
}
},
TypeEnum::TTuple { .. } => {
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
Ok(args) if !args.is_empty() => args
.into_iter()
.map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check {
panic!("type tuple should take concrete parameters in type var ranges")
} else {
x
}
)
.collect::<Vec<_>>(),
Err(err) => return Ok(Err(err)),
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
};
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true)))
},
TypeEnum::TObj { params, obj_id, .. } => {
let subst = {
let params = &*params.borrow();
if params.len() != args.len() {
return Ok(Err(format!(
"for class #{}, expect {} type parameters, got {}.",
obj_id.0,
params.len(),
args.len(),
)))
}
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
Ok(args) => args
.into_iter()
.map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check {
panic!("type class should take concrete parameters in type var ranges")
} else {
x
}
)
.collect::<Vec<_>>(),
Err(err) => return Ok(Err(err)),
};
params
.iter()
.zip(args.iter())
.map(|((id, _), ty)| (*id, *ty))
.collect::<HashMap<_, _>>()
};
Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true)))
},
TypeEnum::TVirtual { .. } => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err))
};
if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
panic!("virtual class should take concrete parameters in type var ranges")
}
Ok(Ok((unifier.add_ty(TypeEnum::TVirtual { ty: ty.0 }), true)))
} else {
return Ok(Err(format!("virtual class needs exactly 1 type parameters, found {}", args.len())))
}
}
_ => unimplemented!()
}
} else if ty_id == self.primitive_ids.virtual_id {
Ok(Ok(({
let ty = TypeEnum::TVirtual { ty: unifier.get_fresh_var().0 };
unifier.add_ty(ty)
}, false)))
} else {
Ok(Err("unknown type".into()))
}
}
fn get_obj_type(
&self,
obj: &PyAny,
helper: &PythonHelper,
unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
) -> PyResult<Option<Type>> {
let (extracted_ty, inst_check) = match self.get_pyty_obj_type(
{
let ty = helper.type_fn.call1((obj,)).unwrap();
if [self.primitive_ids.typevar,
self.primitive_ids.generic_alias.0,
self.primitive_ids.generic_alias.1
].contains(&helper.id_fn.call1((ty,))?.extract::<u64>()?) {
obj
} else {
ty
}
},
helper,
unifier,
defs,
primitives
)? {
Ok(s) => s,
Err(_) => return Ok(None)
};
return match (&*unifier.get_ty(extracted_ty), inst_check) {
// do the instantiation for these three types
(TypeEnum::TList { ty }, false) => {
let len: usize = helper.len_fn.call1((obj,))?.extract()?;
if len == 0 {
assert!(matches!(
&*unifier.get_ty(extracted_ty),
TypeEnum::TVar { meta: nac3core::typecheck::typedef::TypeVarMeta::Generic, range, .. }
if range.borrow().is_empty()
));
Ok(Some(extracted_ty))
} else {
let actual_ty = self
.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?;
if let Some(actual_ty) = actual_ty {
unifier.unify(*ty, actual_ty).unwrap();
Ok(Some(extracted_ty))
} else {
Ok(None)
}
}
}
(TypeEnum::TTuple { .. }, false) => {
let elements: &PyTuple = obj.cast_as()?;
let types: Result<Option<Vec<_>>, _> = elements
.iter()
.map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives))
.collect();
let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
}
(TypeEnum::TObj { params, fields, .. }, false) => {
let var_map = params
.borrow()
.iter()
.map(|(id_var, ty)| {
if let TypeEnum::TVar { id, range, .. } = &*unifier.get_ty(*ty) {
assert_eq!(*id, *id_var);
(*id, unifier.get_fresh_var_with_range(&range.borrow()).0)
} else {
unreachable!()
}
})
.collect::<HashMap<_, _>>();
// loop through non-function fields of the class to get the instantiated value
for field in fields.borrow().iter() {
let name: String = (*field.0).into();
if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(field.1.0) {
continue;
} else {
let field_data = obj.getattr(&name)?;
let ty = self
.get_obj_type(field_data, helper, unifier, defs, primitives)?
.unwrap_or(primitives.none);
let field_ty = unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0);
if unifier.unify(ty, field_ty).is_err() {
// field type mismatch
return Ok(None);
}
}
fields_ty.insert(field.0, (ty, field.2));
}
for (_, ty) in var_map.iter() {
// must be concrete type
@ -156,18 +390,10 @@ impl Resolver {
return Ok(None)
}
}
Ok(Some(unifier.add_ty(TypeEnum::TObj {
obj_id: *object_id,
fields: RefCell::new(fields_ty),
params: RefCell::new(var_map),
})))
} else {
// only object is supported, functions are not supported
Ok(None)
return Ok(Some(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty)));
}
} else {
Ok(None)
}
_ => Ok(Some(extracted_ty))
};
}
fn get_obj_value<'ctx, 'a>(
@ -425,10 +651,16 @@ impl SymbolResolver for Resolver {
let key: &str = member.get_item(0)?.extract()?;
if key == str.to_string() {
let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap(),
len_fn: builtins.getattr("len").unwrap(),
type_fn: builtins.getattr("type").unwrap(),
origin_ty_fn: typings.getattr("get_origin").unwrap(),
args_ty_fn: typings.getattr("get_args").unwrap(),
globals_dict: obj.getattr("__dict__").unwrap(),
eval_type_fn: typings.getattr("_eval_type").unwrap(),
print_fn: builtins.getattr("print").unwrap(),
};
sym_ty = self.get_obj_type(
member.get_item(1)?,
@ -469,10 +701,16 @@ impl SymbolResolver for Resolver {
let val = member.get_item(1)?;
if key == id.to_string() {
let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap(),
len_fn: builtins.getattr("len").unwrap(),
type_fn: builtins.getattr("type").unwrap(),
origin_ty_fn: typings.getattr("get_origin").unwrap(),
args_ty_fn: typings.getattr("get_args").unwrap(),
globals_dict: obj.getattr("__dict__").unwrap(),
eval_type_fn: typings.getattr("_eval_type").unwrap(),
print_fn: builtins.getattr("print").unwrap(),
};
sym_value = self.get_obj_value(val, &helper, ctx)?;
break;