diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 529eea3..1ea9611 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -8,18 +8,23 @@ import nac3artiq __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3", "ms", "us", "ns", - "Core", "TTLOut", "parallel", "sequential"] + "Core", "TTLOut", "parallel", "sequential", "virtual"] import device_db core_arguments = device_db.device_db["core"]["arguments"] +T = TypeVar('T') +# place the `virtual` class infront of the construct of NAC3 object to ensure the +# virtual class is known during the initializing of NAC3 object +class virtual(Generic[T]): + pass + compiler = nac3artiq.NAC3(core_arguments["target"]) allow_module_registration = True registered_modules = set() -T = TypeVar('T') class KernelInvariant(Generic[T]): pass diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index bfb319a..94a290a 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -51,6 +51,10 @@ pub struct PrimitivePythonId { bool: u64, list: u64, tuple: u64, + typevar: u64, + none: u64, + generic_alias: (u64, u64), + virtual_id: u64, } // TopLevelComposer is unsendable as it holds the unification table, which is @@ -246,7 +250,36 @@ impl Nac3 { let builtins_mod = PyModule::import(py, "builtins").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap(); let numpy_mod = PyModule::import(py, "numpy").unwrap(); + let typing_mod = PyModule::import(py, "typing").unwrap(); + let types_mod = PyModule::import(py, "types").unwrap(); let primitive_ids = PrimitivePythonId { + virtual_id: py.eval( + "id(virtual)", + Some(builtins_mod.getattr("globals").unwrap().call0().unwrap().extract().unwrap()), + None + ).unwrap().extract().unwrap(), + generic_alias: ( + id_fn + .call1((typing_mod.getattr("_GenericAlias").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + id_fn + .call1((types_mod.getattr("GenericAlias").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + ), + none: id_fn + .call1((builtins_mod.getattr("None").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + typevar: id_fn + .call1((typing_mod.getattr("TypeVar").unwrap(),)) + .unwrap() + .extract() + .unwrap(), int: id_fn .call1((builtins_mod.getattr("int").unwrap(),)) .unwrap() diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index e420bd8..b27c7df 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -40,6 +40,11 @@ struct PythonHelper<'a> { type_fn: &'a PyAny, len_fn: &'a PyAny, id_fn: &'a PyAny, + eval_type_fn: &'a PyAny, + origin_ty_fn: &'a PyAny, + args_ty_fn: &'a PyAny, + globals_dict: &'a PyAny, + print_fn: &'a PyAny, } impl Resolver { @@ -71,47 +76,51 @@ impl Resolver { })) } - fn get_obj_type( + // handle python objects that represent types themselves + // primitives and class types should be themselves, use `ty_id` to check, + // TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check + // the `bool` value returned indicates whether they are instantiated or not + fn get_pyty_obj_type( &self, - obj: &PyAny, + pyty: &PyAny, helper: &PythonHelper, unifier: &mut Unifier, defs: &[Arc>], primitives: &PrimitiveStore, - ) -> PyResult> { + ) -> PyResult> { + // eval_type use only globals_dict should be fine + let evaluated_ty = helper + .eval_type_fn + .call1((pyty, helper.globals_dict, helper.globals_dict)).unwrap(); let ty_id: u64 = helper .id_fn - .call1((helper.type_fn.call1((obj,))?,))? + .call1((evaluated_ty,))? .extract()?; - + let ty_ty_id: u64 = helper + .id_fn + .call1((helper.type_fn.call1((evaluated_ty,))?,))? + .extract()?; + if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { - Ok(Some(primitives.int32)) + Ok(Ok((primitives.int32, true))) } else if ty_id == self.primitive_ids.int64 { - Ok(Some(primitives.int64)) + Ok(Ok((primitives.int64, true))) } else if ty_id == self.primitive_ids.bool { - Ok(Some(primitives.bool)) + Ok(Ok((primitives.bool, true))) } else if ty_id == self.primitive_ids.float { - Ok(Some(primitives.float)) + Ok(Ok((primitives.float, true))) } else if ty_id == self.primitive_ids.list { - let len: usize = helper.len_fn.call1((obj,))?.extract()?; - if len == 0 { - let var = unifier.get_fresh_var().0; - let list = unifier.add_ty(TypeEnum::TList { ty: var }); - Ok(Some(list)) - } else { - let ty = self.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?; - Ok(ty.map(|ty| unifier.add_ty(TypeEnum::TList { ty }))) - } + // do not handle type var param and concrete check here + let var = unifier.get_fresh_var().0; + let list = unifier.add_ty(TypeEnum::TList { ty: var }); + Ok(Ok((list, false))) } else if ty_id == self.primitive_ids.tuple { - let elements: &PyTuple = obj.cast_as()?; - let types: Result>, _> = elements - .iter() - .map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) - .collect(); - let types = types?; - Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) - } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id) { + // do not handle type var param and concrete check here + Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) + } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { + // println!("getting def"); let def = defs[def_id.0].read(); + // println!("got def"); if let TopLevelDef::Class { object_id, type_vars, @@ -120,35 +129,260 @@ impl Resolver { .. } = &*def { - let var_map: HashMap<_, _> = type_vars - .iter() - .map(|var| { - ( - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { - *id - } else { - unreachable!() - }, - unifier.get_fresh_var().0, - ) - }) - .collect(); - let mut fields_ty = HashMap::new(); - for method in methods.iter() { - fields_ty.insert(method.0, (method.1, false)); - } - for field in fields.iter() { - let name: String = field.0.into(); - let field_data = obj.getattr(&name)?; - let ty = self - .get_obj_type(field_data, helper, unifier, defs, primitives)? - .unwrap_or(primitives.none); - let field_ty = unifier.subst(field.1, &var_map).unwrap_or(field.1); - if unifier.unify(ty, field_ty).is_err() { - // field type mismatch - return Ok(None); + // do not handle type var param and concrete check here, and no subst + Ok(Ok({ + let ty = TypeEnum::TObj { + obj_id: *object_id, + params: RefCell::new({ + type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { + (*id, *x) + } else { unreachable!() } + }).collect() + }), + fields: RefCell::new({ + let mut res = methods + .iter() + .map(|(iden, ty, _)| (*iden, (*ty, false))) + .collect::>(); + res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2)))); + res + }) + }; + // here also false, later insta use python object to check compatible + (unifier.add_ty(ty), false) + })) + } else { + // only object is supported, functions are not supported + unreachable!("function type is not supported, should not be queried") + } + } else if ty_ty_id == self.primitive_ids.typevar { + let constraint_types = { + let constraints = pyty.getattr("__constraints__").unwrap(); + let mut result: Vec = vec![]; + for i in 0.. { + if let Ok(constr) = constraints.get_item(i) { + result.push({ + match self.get_pyty_obj_type(constr, helper, unifier, defs, primitives)? { + Ok((ty, _)) => { + if unifier.is_concrete(ty, &[]) { + ty + } else { + return Ok(Err(format!( + "the {}th constraint of TypeVar `{}` is not concrete", + i + 1, + pyty.getattr("__name__")?.extract::()? + ))) + } + }, + Err(err) => return Ok(Err(err)) + } + }) + } else { + break; + } + } + result + }; + let res = unifier.get_fresh_var_with_range(&constraint_types).0; + Ok(Ok((res, true))) + } else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 { + let origin = helper.origin_ty_fn.call1((evaluated_ty,))?; + let args: &PyTuple = helper.args_ty_fn.call1((evaluated_ty,))?.cast_as()?; + let origin_ty = match self.get_pyty_obj_type(origin, helper, unifier, defs, primitives)? { + Ok((ty, false)) => ty, + Ok((_, true)) => return Ok(Err("instantiated type does not take type parameters".into())), + Err(err) => return Ok(Err(err)) + }; + + match &*unifier.get_ty(origin_ty) { + TypeEnum::TList { .. } => { + if args.len() == 1 { + let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? { + Ok(ty) => ty, + Err(err) => return Ok(Err(err)) + }; + if !unifier.is_concrete(ty.0, &[]) && !ty.1 { + panic!("type list should take concrete parameters in type var ranges") + } + Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true))) + } else { + return Ok(Err(format!("type list needs exactly 1 type parameters, found {}", args.len()))) + } + }, + TypeEnum::TTuple { .. } => { + let args = match args + .iter() + .map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives)) + .collect::, _>>()? + .into_iter() + .collect::, _>>() { + Ok(args) if !args.is_empty() => args + .into_iter() + .map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check { + panic!("type tuple should take concrete parameters in type var ranges") + } else { + x + } + ) + .collect::>(), + Err(err) => return Ok(Err(err)), + _ => return Ok(Err("tuple type needs at least 1 type parameters".to_string())) + }; + Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true))) + }, + TypeEnum::TObj { params, obj_id, .. } => { + let subst = { + let params = &*params.borrow(); + if params.len() != args.len() { + return Ok(Err(format!( + "for class #{}, expect {} type parameters, got {}.", + obj_id.0, + params.len(), + args.len(), + ))) + } + let args = match args + .iter() + .map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives)) + .collect::, _>>()? + .into_iter() + .collect::, _>>() { + Ok(args) => args + .into_iter() + .map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check { + panic!("type class should take concrete parameters in type var ranges") + } else { + x + } + ) + .collect::>(), + Err(err) => return Ok(Err(err)), + }; + params + .iter() + .zip(args.iter()) + .map(|((id, _), ty)| (*id, *ty)) + .collect::>() + }; + Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true))) + }, + TypeEnum::TVirtual { .. } => { + if args.len() == 1 { + let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? { + Ok(ty) => ty, + Err(err) => return Ok(Err(err)) + }; + if !unifier.is_concrete(ty.0, &[]) && !ty.1 { + panic!("virtual class should take concrete parameters in type var ranges") + } + Ok(Ok((unifier.add_ty(TypeEnum::TVirtual { ty: ty.0 }), true))) + } else { + return Ok(Err(format!("virtual class needs exactly 1 type parameters, found {}", args.len()))) + } + } + _ => unimplemented!() + } + } else if ty_id == self.primitive_ids.virtual_id { + Ok(Ok(({ + let ty = TypeEnum::TVirtual { ty: unifier.get_fresh_var().0 }; + unifier.add_ty(ty) + }, false))) + } else { + Ok(Err("unknown type".into())) + } + } + + fn get_obj_type( + &self, + obj: &PyAny, + helper: &PythonHelper, + unifier: &mut Unifier, + defs: &[Arc>], + primitives: &PrimitiveStore, + ) -> PyResult> { + let (extracted_ty, inst_check) = match self.get_pyty_obj_type( + { + let ty = helper.type_fn.call1((obj,)).unwrap(); + if [self.primitive_ids.typevar, + self.primitive_ids.generic_alias.0, + self.primitive_ids.generic_alias.1 + ].contains(&helper.id_fn.call1((ty,))?.extract::()?) { + obj + } else { + ty + } + }, + helper, + unifier, + defs, + primitives + )? { + Ok(s) => s, + Err(_) => return Ok(None) + }; + return match (&*unifier.get_ty(extracted_ty), inst_check) { + // do the instantiation for these three types + (TypeEnum::TList { ty }, false) => { + let len: usize = helper.len_fn.call1((obj,))?.extract()?; + if len == 0 { + assert!(matches!( + &*unifier.get_ty(extracted_ty), + TypeEnum::TVar { meta: nac3core::typecheck::typedef::TypeVarMeta::Generic, range, .. } + if range.borrow().is_empty() + )); + Ok(Some(extracted_ty)) + } else { + let actual_ty = self + .get_list_elem_type(obj, len, helper, unifier, defs, primitives)?; + if let Some(actual_ty) = actual_ty { + unifier.unify(*ty, actual_ty).unwrap(); + Ok(Some(extracted_ty)) + } else { + Ok(None) + } + } + } + (TypeEnum::TTuple { .. }, false) => { + let elements: &PyTuple = obj.cast_as()?; + let types: Result>, _> = elements + .iter() + .map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) + .collect(); + let types = types?; + Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) + } + (TypeEnum::TObj { params, fields, .. }, false) => { + let var_map = params + .borrow() + .iter() + .map(|(id_var, ty)| { + if let TypeEnum::TVar { id, range, .. } = &*unifier.get_ty(*ty) { + assert_eq!(*id, *id_var); + (*id, unifier.get_fresh_var_with_range(&range.borrow()).0) + } else { + unreachable!() + } + }) + .collect::>(); + // loop through non-function fields of the class to get the instantiated value + for field in fields.borrow().iter() { + let name: String = (*field.0).into(); + if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(field.1.0) { + continue; + } else { + let field_data = obj.getattr(&name)?; + let ty = self + .get_obj_type(field_data, helper, unifier, defs, primitives)? + .unwrap_or(primitives.none); + let field_ty = unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0); + if unifier.unify(ty, field_ty).is_err() { + // field type mismatch + return Ok(None); + } } - fields_ty.insert(field.0, (ty, field.2)); } for (_, ty) in var_map.iter() { // must be concrete type @@ -156,18 +390,10 @@ impl Resolver { return Ok(None) } } - Ok(Some(unifier.add_ty(TypeEnum::TObj { - obj_id: *object_id, - fields: RefCell::new(fields_ty), - params: RefCell::new(var_map), - }))) - } else { - // only object is supported, functions are not supported - Ok(None) + return Ok(Some(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty))); } - } else { - Ok(None) - } + _ => Ok(Some(extracted_ty)) + }; } fn get_obj_value<'ctx, 'a>( @@ -425,10 +651,16 @@ impl SymbolResolver for Resolver { let key: &str = member.get_item(0)?.extract()?; if key == str.to_string() { let builtins = PyModule::import(py, "builtins")?; + let typings = PyModule::import(py, "typing")?; let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap(), len_fn: builtins.getattr("len").unwrap(), type_fn: builtins.getattr("type").unwrap(), + origin_ty_fn: typings.getattr("get_origin").unwrap(), + args_ty_fn: typings.getattr("get_args").unwrap(), + globals_dict: obj.getattr("__dict__").unwrap(), + eval_type_fn: typings.getattr("_eval_type").unwrap(), + print_fn: builtins.getattr("print").unwrap(), }; sym_ty = self.get_obj_type( member.get_item(1)?, @@ -469,10 +701,16 @@ impl SymbolResolver for Resolver { let val = member.get_item(1)?; if key == id.to_string() { let builtins = PyModule::import(py, "builtins")?; + let typings = PyModule::import(py, "typing")?; let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap(), len_fn: builtins.getattr("len").unwrap(), type_fn: builtins.getattr("type").unwrap(), + origin_ty_fn: typings.getattr("get_origin").unwrap(), + args_ty_fn: typings.getattr("get_args").unwrap(), + globals_dict: obj.getattr("__dict__").unwrap(), + eval_type_fn: typings.getattr("_eval_type").unwrap(), + print_fn: builtins.getattr("print").unwrap(), }; sym_value = self.get_obj_value(val, &helper, ctx)?; break;