TypeVar and virtual support in Symbol Resolver #99

Merged
sb10q merged 12 commits from symbol_resolver_typevar into master 2021-12-01 22:44:53 +08:00
3 changed files with 343 additions and 66 deletions
Showing only changes of commit a3faa9b7dd - Show all commits

View File

@ -9,17 +9,23 @@ import nac3artiq
__all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3", __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3",
"ms", "us", "ns", "ms", "us", "ns",
"print_int32", "print_int64", "print_int32", "print_int64",
"Core", "TTLOut", "parallel", "sequential"] "Core", "TTLOut", "parallel", "sequential", "virtual"]
T = TypeVar('T') T = TypeVar('T')
class KernelInvariant(Generic[T]): class KernelInvariant(Generic[T]):
pass pass
# place the `virtual` class infront of the construct of NAC3 object to ensure the
# virtual class is known during the initializing of NAC3 object
Review

Unclear - do you have an example?

Unclear - do you have an example?
Review

Sorry for the unclear comment. Here I mean that before this line in min_artiq.py: compiler = nac3artiq.NAC3(core_arguments["target"]), the class virtual should already be known to CPython (because the construction of compiler needs to read get the id of it)

if we move this class definition below:

compiler = nac3artiq.NAC3(core_arguments["target"])
class virtual(Generic[T]):
    pass

and run the compiler, we will have:

PyErr { type: <class 'KeyError'>, value: KeyError('virtual'), traceback: None }
Sorry for the unclear comment. Here I mean that before this line in `min_artiq.py`: `compiler = nac3artiq.NAC3(core_arguments["target"])`, the class `virtual` should already be known to CPython (because the construction of `compiler` needs to read get the id of it) if we move this class definition below: ```python compiler = nac3artiq.NAC3(core_arguments["target"]) class virtual(Generic[T]): pass ``` and run the compiler, we will have: ``` PyErr { type: <class 'KeyError'>, value: KeyError('virtual'), traceback: None } ```
Review

By the way, why is this different than the KernelInvariant class ?

By the way, why is this different than the KernelInvariant class ?
Review

As I understand it, here during the initialization of the nac3 python object it needs to know the id of class virtual becuase virtual[..] is actually a type that can exist in the range of a typevar, and when resolving typevars, nac3artiq needs the id of it to know that it is handling a virtual.

While KernelInvariant in min_artiq.py is just for cpython to not complain about unknown name and we do not need to know its id because so far it only occurs in class fields definition and nac3 do not need to know any cpython related information about it since we are directly looking into the string in the ast.

As I understand it, here during the initialization of the nac3 python object it needs to know the id of class `virtual` becuase `virtual[..]` is actually a type that can exist in the range of a typevar, and when resolving typevars, nac3artiq needs the id of it to know that it is handling a `virtual`. While KernelInvariant in `min_artiq.py` is just for cpython to not complain about unknown name and we do not need to know its id because so far it only occurs in class fields definition and nac3 do not need to know any cpython related information about it since we are directly looking into the string in the ast.
class virtual(Generic[T]):
pass
import device_db import device_db
core_arguments = device_db.device_db["core"]["arguments"] core_arguments = device_db.device_db["core"]["arguments"]
compiler = nac3artiq.NAC3(core_arguments["target"]) compiler = nac3artiq.NAC3(core_arguments["target"])
allow_registration = True allow_registration = True
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side. # Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.

View File

@ -52,6 +52,10 @@ pub struct PrimitivePythonId {
bool: u64, bool: u64,
list: u64, list: u64,
tuple: u64, tuple: u64,
typevar: u64,
none: u64,
generic_alias: (u64, u64),
virtual_id: u64,
} }
// TopLevelComposer is unsendable as it holds the unification table, which is // TopLevelComposer is unsendable as it holds the unification table, which is
@ -267,7 +271,36 @@ impl Nac3 {
let builtins_mod = PyModule::import(py, "builtins").unwrap(); let builtins_mod = PyModule::import(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap(); let numpy_mod = PyModule::import(py, "numpy").unwrap();
let typing_mod = PyModule::import(py, "typing").unwrap();
let types_mod = PyModule::import(py, "types").unwrap();
let primitive_ids = PrimitivePythonId { let primitive_ids = PrimitivePythonId {
virtual_id: py.eval(
"id(virtual)",
Some(builtins_mod.getattr("globals").unwrap().call0().unwrap().extract().unwrap()),
None
).unwrap().extract().unwrap(),
generic_alias: (
id_fn
.call1((typing_mod.getattr("_GenericAlias").unwrap(),))
.unwrap()
.extract()
.unwrap(),
id_fn
.call1((types_mod.getattr("GenericAlias").unwrap(),))
.unwrap()
.extract()
.unwrap(),
),
none: id_fn
.call1((builtins_mod.getattr("None").unwrap(),))
.unwrap()
.extract()
.unwrap(),
typevar: id_fn
.call1((typing_mod.getattr("TypeVar").unwrap(),))
.unwrap()
.extract()
.unwrap(),
int: id_fn int: id_fn
.call1((builtins_mod.getattr("int").unwrap(),)) .call1((builtins_mod.getattr("int").unwrap(),))
.unwrap() .unwrap()

View File

@ -40,6 +40,11 @@ struct PythonHelper<'a> {
type_fn: &'a PyAny, type_fn: &'a PyAny,
len_fn: &'a PyAny, len_fn: &'a PyAny,
id_fn: &'a PyAny, id_fn: &'a PyAny,
eval_type_fn: &'a PyAny,
origin_ty_fn: &'a PyAny,
args_ty_fn: &'a PyAny,
globals_dict: &'a PyAny,
print_fn: &'a PyAny,
} }
impl Resolver { impl Resolver {
@ -71,47 +76,51 @@ impl Resolver {
})) }))
} }
fn get_obj_type( // handle python objects that represent types themselves
// primitives and class types should be themselves, use `ty_id` to check,
// TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check
// the `bool` value returned indicates whether they are instantiated or not
fn get_pyty_obj_type(
&self, &self,
obj: &PyAny, pyty: &PyAny,
helper: &PythonHelper, helper: &PythonHelper,
unifier: &mut Unifier, unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>], defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
) -> PyResult<Option<Type>> { ) -> PyResult<Result<(Type, bool), String>> {
// eval_type use only globals_dict should be fine
let evaluated_ty = helper
.eval_type_fn
.call1((pyty, helper.globals_dict, helper.globals_dict)).unwrap();
let ty_id: u64 = helper let ty_id: u64 = helper
.id_fn .id_fn
.call1((helper.type_fn.call1((obj,))?,))? .call1((evaluated_ty,))?
.extract()?;
let ty_ty_id: u64 = helper
.id_fn
.call1((helper.type_fn.call1((evaluated_ty,))?,))?
.extract()?; .extract()?;
if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 {
Ok(Some(primitives.int32)) Ok(Ok((primitives.int32, true)))
} else if ty_id == self.primitive_ids.int64 { } else if ty_id == self.primitive_ids.int64 {
Ok(Some(primitives.int64)) Ok(Ok((primitives.int64, true)))
} else if ty_id == self.primitive_ids.bool { } else if ty_id == self.primitive_ids.bool {
Ok(Some(primitives.bool)) Ok(Ok((primitives.bool, true)))
} else if ty_id == self.primitive_ids.float { } else if ty_id == self.primitive_ids.float {
Ok(Some(primitives.float)) Ok(Ok((primitives.float, true)))
} else if ty_id == self.primitive_ids.list { } else if ty_id == self.primitive_ids.list {
let len: usize = helper.len_fn.call1((obj,))?.extract()?; // do not handle type var param and concrete check here
if len == 0 {
let var = unifier.get_fresh_var().0; let var = unifier.get_fresh_var().0;
Outdated
Review

remove println

remove println
let list = unifier.add_ty(TypeEnum::TList { ty: var }); let list = unifier.add_ty(TypeEnum::TList { ty: var });
Ok(Some(list)) Ok(Ok((list, false)))
Outdated
Review

same

same
} else {
let ty = self.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?;
Ok(ty.map(|ty| unifier.add_ty(TypeEnum::TList { ty })))
}
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let elements: &PyTuple = obj.cast_as()?; // do not handle type var param and concrete check here
let types: Result<Option<Vec<_>>, _> = elements Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
.iter() } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
.map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) // println!("getting def");
.collect();
let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id) {
let def = defs[def_id.0].read(); let def = defs[def_id.0].read();
// println!("got def");
if let TopLevelDef::Class { if let TopLevelDef::Class {
object_id, object_id,
type_vars, type_vars,
@ -120,35 +129,260 @@ impl Resolver {
.. ..
} = &*def } = &*def
{ {
let var_map: HashMap<_, _> = type_vars // do not handle type var param and concrete check here, and no subst
Ok(Ok({
let ty = TypeEnum::TObj {
obj_id: *object_id,
params: RefCell::new({
type_vars
.iter() .iter()
.map(|var| { .map(|x| {
( if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { (*id, *x)
*id } else { unreachable!() }
}).collect()
}),
fields: RefCell::new({
let mut res = methods
.iter()
.map(|(iden, ty, _)| (*iden, (*ty, false)))
.collect::<HashMap<_, _>>();
res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2))));
res
})
};
// here also false, later insta use python object to check compatible
(unifier.add_ty(ty), false)
}))
} else {
// only object is supported, functions are not supported
unreachable!("function type is not supported, should not be queried")
}
} else if ty_ty_id == self.primitive_ids.typevar {
let constraint_types = {
let constraints = pyty.getattr("__constraints__").unwrap();
let mut result: Vec<Type> = vec![];
for i in 0.. {
if let Ok(constr) = constraints.get_item(i) {
result.push({
match self.get_pyty_obj_type(constr, helper, unifier, defs, primitives)? {
Ok((ty, _)) => {
if unifier.is_concrete(ty, &[]) {
ty
} else {
return Ok(Err(format!(
"the {}th constraint of TypeVar `{}` is not concrete",
i + 1,
pyty.getattr("__name__")?.extract::<String>()?
)))
}
},
Err(err) => return Ok(Err(err))
}
})
} else {
break;
}
}
result
};
let res = unifier.get_fresh_var_with_range(&constraint_types).0;
Ok(Ok((res, true)))
} else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 {
let origin = helper.origin_ty_fn.call1((evaluated_ty,))?;
let args: &PyTuple = helper.args_ty_fn.call1((evaluated_ty,))?.cast_as()?;
let origin_ty = match self.get_pyty_obj_type(origin, helper, unifier, defs, primitives)? {
Ok((ty, false)) => ty,
Ok((_, true)) => return Ok(Err("instantiated type does not take type parameters".into())),
Err(err) => return Ok(Err(err))
};
match &*unifier.get_ty(origin_ty) {
TypeEnum::TList { .. } => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err))
};
if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
panic!("type list should take concrete parameters in type var ranges")
}
Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true)))
} else {
return Ok(Err(format!("type list needs exactly 1 type parameters, found {}", args.len())))
}
},
TypeEnum::TTuple { .. } => {
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
Ok(args) if !args.is_empty() => args
.into_iter()
.map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check {
panic!("type tuple should take concrete parameters in type var ranges")
} else {
x
}
)
.collect::<Vec<_>>(),
Err(err) => return Ok(Err(err)),
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
};
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true)))
},
TypeEnum::TObj { params, obj_id, .. } => {
let subst = {
let params = &*params.borrow();
if params.len() != args.len() {
return Ok(Err(format!(
"for class #{}, expect {} type parameters, got {}.",
obj_id.0,
params.len(),
args.len(),
)))
}
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
Ok(args) => args
.into_iter()
.map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check {
panic!("type class should take concrete parameters in type var ranges")
} else {
x
}
)
.collect::<Vec<_>>(),
Err(err) => return Ok(Err(err)),
};
params
.iter()
.zip(args.iter())
.map(|((id, _), ty)| (*id, *ty))
.collect::<HashMap<_, _>>()
};
Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true)))
},
TypeEnum::TVirtual { .. } => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err))
};
if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
panic!("virtual class should take concrete parameters in type var ranges")
}
Ok(Ok((unifier.add_ty(TypeEnum::TVirtual { ty: ty.0 }), true)))
} else {
return Ok(Err(format!("virtual class needs exactly 1 type parameters, found {}", args.len())))
}
}
_ => unimplemented!()
}
} else if ty_id == self.primitive_ids.virtual_id {
Ok(Ok(({
let ty = TypeEnum::TVirtual { ty: unifier.get_fresh_var().0 };
unifier.add_ty(ty)
}, false)))
} else {
Ok(Err("unknown type".into()))
}
}
fn get_obj_type(
&self,
obj: &PyAny,
helper: &PythonHelper,
unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
) -> PyResult<Option<Type>> {
let (extracted_ty, inst_check) = match self.get_pyty_obj_type(
{
let ty = helper.type_fn.call1((obj,)).unwrap();
if [self.primitive_ids.typevar,
self.primitive_ids.generic_alias.0,
self.primitive_ids.generic_alias.1
].contains(&helper.id_fn.call1((ty,))?.extract::<u64>()?) {
obj
} else {
ty
}
},
helper,
unifier,
defs,
primitives
)? {
Ok(s) => s,
Err(_) => return Ok(None)
};
return match (&*unifier.get_ty(extracted_ty), inst_check) {
// do the instantiation for these three types
(TypeEnum::TList { ty }, false) => {
let len: usize = helper.len_fn.call1((obj,))?.extract()?;
if len == 0 {
assert!(matches!(
&*unifier.get_ty(extracted_ty),
TypeEnum::TVar { meta: nac3core::typecheck::typedef::TypeVarMeta::Generic, range, .. }
if range.borrow().is_empty()
));
Ok(Some(extracted_ty))
} else {
let actual_ty = self
.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?;
if let Some(actual_ty) = actual_ty {
unifier.unify(*ty, actual_ty).unwrap();
Ok(Some(extracted_ty))
} else {
Ok(None)
}
}
}
(TypeEnum::TTuple { .. }, false) => {
let elements: &PyTuple = obj.cast_as()?;
let types: Result<Option<Vec<_>>, _> = elements
.iter()
.map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives))
.collect();
let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
}
(TypeEnum::TObj { params, fields, .. }, false) => {
let var_map = params
.borrow()
.iter()
.map(|(id_var, ty)| {
if let TypeEnum::TVar { id, range, .. } = &*unifier.get_ty(*ty) {
assert_eq!(*id, *id_var);
(*id, unifier.get_fresh_var_with_range(&range.borrow()).0)
} else { } else {
unreachable!() unreachable!()
},
unifier.get_fresh_var().0,
)
})
.collect();
let mut fields_ty = HashMap::new();
for method in methods.iter() {
fields_ty.insert(method.0, (method.1, false));
} }
for field in fields.iter() { })
let name: String = field.0.into(); .collect::<HashMap<_, _>>();
// loop through non-function fields of the class to get the instantiated value
for field in fields.borrow().iter() {
let name: String = (*field.0).into();
if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(field.1.0) {
continue;
} else {
let field_data = obj.getattr(&name)?; let field_data = obj.getattr(&name)?;
let ty = self let ty = self
.get_obj_type(field_data, helper, unifier, defs, primitives)? .get_obj_type(field_data, helper, unifier, defs, primitives)?
.unwrap_or(primitives.none); .unwrap_or(primitives.none);
let field_ty = unifier.subst(field.1, &var_map).unwrap_or(field.1); let field_ty = unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0);
if unifier.unify(ty, field_ty).is_err() { if unifier.unify(ty, field_ty).is_err() {
// field type mismatch // field type mismatch
return Ok(None); return Ok(None);
} }
fields_ty.insert(field.0, (ty, field.2)); }
} }
for (_, ty) in var_map.iter() { for (_, ty) in var_map.iter() {
// must be concrete type // must be concrete type
@ -156,18 +390,10 @@ impl Resolver {
return Ok(None) return Ok(None)
} }
} }
Ok(Some(unifier.add_ty(TypeEnum::TObj { return Ok(Some(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty)));
obj_id: *object_id,
fields: RefCell::new(fields_ty),
params: RefCell::new(var_map),
})))
} else {
// only object is supported, functions are not supported
Ok(None)
}
} else {
Ok(None)
} }
_ => Ok(Some(extracted_ty))
};
} }
fn get_obj_value<'ctx, 'a>( fn get_obj_value<'ctx, 'a>(
@ -425,10 +651,16 @@ impl SymbolResolver for Resolver {
let key: &str = member.get_item(0)?.extract()?; let key: &str = member.get_item(0)?.extract()?;
if key == str.to_string() { if key == str.to_string() {
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let helper = PythonHelper { let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap(), id_fn: builtins.getattr("id").unwrap(),
len_fn: builtins.getattr("len").unwrap(), len_fn: builtins.getattr("len").unwrap(),
type_fn: builtins.getattr("type").unwrap(), type_fn: builtins.getattr("type").unwrap(),
origin_ty_fn: typings.getattr("get_origin").unwrap(),
args_ty_fn: typings.getattr("get_args").unwrap(),
globals_dict: obj.getattr("__dict__").unwrap(),
eval_type_fn: typings.getattr("_eval_type").unwrap(),
print_fn: builtins.getattr("print").unwrap(),
}; };
sym_ty = self.get_obj_type( sym_ty = self.get_obj_type(
member.get_item(1)?, member.get_item(1)?,
@ -469,10 +701,16 @@ impl SymbolResolver for Resolver {
let val = member.get_item(1)?; let val = member.get_item(1)?;
if key == id.to_string() { if key == id.to_string() {
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let helper = PythonHelper { let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap(), id_fn: builtins.getattr("id").unwrap(),
len_fn: builtins.getattr("len").unwrap(), len_fn: builtins.getattr("len").unwrap(),
type_fn: builtins.getattr("type").unwrap(), type_fn: builtins.getattr("type").unwrap(),
origin_ty_fn: typings.getattr("get_origin").unwrap(),
args_ty_fn: typings.getattr("get_args").unwrap(),
globals_dict: obj.getattr("__dict__").unwrap(),
eval_type_fn: typings.getattr("_eval_type").unwrap(),
print_fn: builtins.getattr("print").unwrap(),
}; };
sym_value = self.get_obj_value(val, &helper, ctx)?; sym_value = self.get_obj_value(val, &helper, ctx)?;
break; break;