From 4d32b431812ed24a436e169cb746c32729a64f86 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Tue, 9 Nov 2021 01:15:41 +0800 Subject: [PATCH 1/7] nac3core: fix polymorphic class method partial instantiation --- nac3core/src/codegen/concrete_type.rs | 10 +++++-- nac3core/src/toplevel/composer.rs | 33 +++++++++++++++++++++--- nac3core/src/toplevel/type_annotation.rs | 8 +++--- nac3core/src/typecheck/typedef/mod.rs | 16 +++++------- 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index f4f4518b..422e4dfc 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -147,8 +147,14 @@ impl ConcreteTypeStore { fields: fields .borrow() .iter() - .map(|(name, ty)| { - (*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1)) + .filter_map(|(name, ty)| { + // filter out functions as they can have type vars and + // will not affect codegen + if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(ty.0) { + None + } else { + Some((*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1))) + } }) .collect(), params: params diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index bf9f1561..57ca2643 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1654,7 +1654,7 @@ impl TopLevelComposer { if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { let FunSignature { args, ret, vars } = &*func_sig.borrow(); // None if is not class method - let self_type = { + let uninst_self_type = { if let Some(class_id) = self.method_class.get(&DefinitionId(id)) { let class_def = self.definition_ast_list.get(class_id.0).unwrap(); let class_def = class_def.0.read(); @@ -1666,7 +1666,7 @@ impl TopLevelComposer { &self.primitives_ty, &ty_ann, )?; - Some(self_ty) + Some((self_ty, type_vars.clone())) } else { unreachable!("must be class def") } @@ -1717,9 +1717,34 @@ impl TopLevelComposer { }; let self_type = { let unifier = &mut self.unifier; - self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x)) + uninst_self_type + .clone() + .map(|(self_type, type_vars)| { + let subst_for_self = { + let class_ty_var_ids = type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { + *id + } else { + unreachable!("must be type var here"); + } + }) + .collect::>(); + subst + .iter() + .filter_map(|(ty_var_id, ty_var_target)| { + if class_ty_var_ids.contains(ty_var_id) { + Some((*ty_var_id, *ty_var_target)) + } else { + None + } + }) + .collect::>() + }; + unifier.subst(self_type, &subst_for_self).unwrap_or(self_type) + }) }; - let mut identifiers = { // NOTE: none and function args? let mut result: HashSet<_> = HashSet::new(); diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 257d582e..307d8794 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -280,9 +280,11 @@ pub fn get_type_from_type_annotation_kinds( { let ok: bool = { // create a temp type var and unify to check compatibility - let temp = - unifier.get_fresh_var_with_range(range.borrow().as_slice()); - unifier.unify(temp.0, p).is_ok() + p == *tvar || { + let temp = + unifier.get_fresh_var_with_range(range.borrow().as_slice()); + unifier.unify(temp.0, p).is_ok() + } }; if ok { result.insert(*id, p); diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index eb0cf4bc..85e619c8 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -719,22 +719,18 @@ impl Unifier { /// Returns Some(T) where T is the instantiated type. /// Returns None if the function is already instantiated. fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { - let mut instantiated = false; + let mut instantiated = true; let mut vars = Vec::new(); for (k, v) in fun.vars.iter() { if let TypeEnum::TVar { id, range, .. } = self.unification_table.probe_value(*v).as_ref() { - if k != id { - instantiated = true; - break; + // need to do this for partial instantiated function + // (in class methods that contains type vars not in class) + if k == id { + instantiated = false; + vars.push((*k, range.clone())); } - // actually, if the first check succeeded, the function should be uninstatiated. - // The cloned values must be used and would not be wasted. - vars.push((*k, range.clone())); - } else { - instantiated = true; - break; } } if instantiated { -- 2.44.2 From c08aad3ffe26a0a7bee5681bd4819ffcabafeed9 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Wed, 10 Nov 2021 23:38:47 +0800 Subject: [PATCH 2/7] nac3core: top level use codegen official get_subst_key --- nac3core/src/codegen/expr.rs | 45 ++++++++++++++++++------------- nac3core/src/toplevel/composer.rs | 23 ++++++---------- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 5b31a15d..4026b8b2 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -7,7 +7,7 @@ use crate::{ }, symbol_resolver::SymbolValue, toplevel::{DefinitionId, TopLevelDef}, - typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum}, + typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, }; use inkwell::{ types::{BasicType, BasicTypeEnum}, @@ -21,6 +21,31 @@ use nac3parser::ast::{ use super::CodeGenerator; +pub fn get_subst_key( + unifier: &mut Unifier, + obj: Option, + fun_vars: &HashMap, + filter: Option<&Vec>, +) -> String { + let mut vars = obj + .map(|ty| { + if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) { + params.borrow().clone() + } else { + unreachable!() + } + }) + .unwrap_or_default(); + vars.extend(fun_vars.iter()); + let sorted = + vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted(); + sorted + .map(|id| { + unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string()) + }) + .join(", ") +} + impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { pub fn build_gep_and_load( &mut self, @@ -36,23 +61,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { fun: &FunSignature, filter: Option<&Vec>, ) -> String { - let mut vars = obj - .map(|ty| { - if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) { - params.borrow().clone() - } else { - unreachable!() - } - }) - .unwrap_or_default(); - vars.extend(fun.vars.iter()); - let sorted = - vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted(); - sorted - .map(|id| { - self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string()) - }) - .join(", ") + get_subst_key(&mut self.unifier, obj, &fun.vars, filter) } pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 57ca2643..8dd5404c 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -6,6 +6,7 @@ use inkwell::FloatPredicate; use crate::{ symbol_resolver::SymbolValue, typecheck::type_inferencer::{FunctionData, Inferencer}, + codegen::expr::get_subst_key, }; use super::*; @@ -1835,21 +1836,13 @@ impl TopLevelComposer { instance_to_stmt.insert( // NOTE: refer to codegen/expr/get_subst_key function - { - let unifier = &mut self.unifier; - subst - .keys() - .sorted() - .map(|id| { - let ty = subst.get(id).unwrap(); - unifier.stringify( - *ty, - &mut |id| id.to_string(), - &mut |id| id.to_string(), - ) - }) - .join(", ") - }, + + get_subst_key( + &mut self.unifier, + self_type, + &subst, + None + ), FunInstance { body: Arc::new(fun_body), unifier_id: 0, -- 2.44.2 From 0a9ed4e24f8e6173775a07625fda898d9350041a Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 11 Nov 2021 01:23:27 +0800 Subject: [PATCH 3/7] nac3artiq: symbol reslover handle typevar, virtual and fForwardRef --- nac3artiq/demo/min_artiq.py | 9 +- nac3artiq/src/lib.rs | 33 +++ nac3artiq/src/symbol_resolver.rs | 368 +++++++++++++++++++++++++------ 3 files changed, 343 insertions(+), 67 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 529eea3d..1ea9611e 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -8,18 +8,23 @@ import nac3artiq __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3", "ms", "us", "ns", - "Core", "TTLOut", "parallel", "sequential"] + "Core", "TTLOut", "parallel", "sequential", "virtual"] import device_db core_arguments = device_db.device_db["core"]["arguments"] +T = TypeVar('T') +# place the `virtual` class infront of the construct of NAC3 object to ensure the +# virtual class is known during the initializing of NAC3 object +class virtual(Generic[T]): + pass + compiler = nac3artiq.NAC3(core_arguments["target"]) allow_module_registration = True registered_modules = set() -T = TypeVar('T') class KernelInvariant(Generic[T]): pass diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index bfb319a8..94a290a6 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -51,6 +51,10 @@ pub struct PrimitivePythonId { bool: u64, list: u64, tuple: u64, + typevar: u64, + none: u64, + generic_alias: (u64, u64), + virtual_id: u64, } // TopLevelComposer is unsendable as it holds the unification table, which is @@ -246,7 +250,36 @@ impl Nac3 { let builtins_mod = PyModule::import(py, "builtins").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap(); let numpy_mod = PyModule::import(py, "numpy").unwrap(); + let typing_mod = PyModule::import(py, "typing").unwrap(); + let types_mod = PyModule::import(py, "types").unwrap(); let primitive_ids = PrimitivePythonId { + virtual_id: py.eval( + "id(virtual)", + Some(builtins_mod.getattr("globals").unwrap().call0().unwrap().extract().unwrap()), + None + ).unwrap().extract().unwrap(), + generic_alias: ( + id_fn + .call1((typing_mod.getattr("_GenericAlias").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + id_fn + .call1((types_mod.getattr("GenericAlias").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + ), + none: id_fn + .call1((builtins_mod.getattr("None").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + typevar: id_fn + .call1((typing_mod.getattr("TypeVar").unwrap(),)) + .unwrap() + .extract() + .unwrap(), int: id_fn .call1((builtins_mod.getattr("int").unwrap(),)) .unwrap() diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index e420bd80..b27c7dfd 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -40,6 +40,11 @@ struct PythonHelper<'a> { type_fn: &'a PyAny, len_fn: &'a PyAny, id_fn: &'a PyAny, + eval_type_fn: &'a PyAny, + origin_ty_fn: &'a PyAny, + args_ty_fn: &'a PyAny, + globals_dict: &'a PyAny, + print_fn: &'a PyAny, } impl Resolver { @@ -71,47 +76,51 @@ impl Resolver { })) } - fn get_obj_type( + // handle python objects that represent types themselves + // primitives and class types should be themselves, use `ty_id` to check, + // TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check + // the `bool` value returned indicates whether they are instantiated or not + fn get_pyty_obj_type( &self, - obj: &PyAny, + pyty: &PyAny, helper: &PythonHelper, unifier: &mut Unifier, defs: &[Arc>], primitives: &PrimitiveStore, - ) -> PyResult> { + ) -> PyResult> { + // eval_type use only globals_dict should be fine + let evaluated_ty = helper + .eval_type_fn + .call1((pyty, helper.globals_dict, helper.globals_dict)).unwrap(); let ty_id: u64 = helper .id_fn - .call1((helper.type_fn.call1((obj,))?,))? + .call1((evaluated_ty,))? .extract()?; - + let ty_ty_id: u64 = helper + .id_fn + .call1((helper.type_fn.call1((evaluated_ty,))?,))? + .extract()?; + if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { - Ok(Some(primitives.int32)) + Ok(Ok((primitives.int32, true))) } else if ty_id == self.primitive_ids.int64 { - Ok(Some(primitives.int64)) + Ok(Ok((primitives.int64, true))) } else if ty_id == self.primitive_ids.bool { - Ok(Some(primitives.bool)) + Ok(Ok((primitives.bool, true))) } else if ty_id == self.primitive_ids.float { - Ok(Some(primitives.float)) + Ok(Ok((primitives.float, true))) } else if ty_id == self.primitive_ids.list { - let len: usize = helper.len_fn.call1((obj,))?.extract()?; - if len == 0 { - let var = unifier.get_fresh_var().0; - let list = unifier.add_ty(TypeEnum::TList { ty: var }); - Ok(Some(list)) - } else { - let ty = self.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?; - Ok(ty.map(|ty| unifier.add_ty(TypeEnum::TList { ty }))) - } + // do not handle type var param and concrete check here + let var = unifier.get_fresh_var().0; + let list = unifier.add_ty(TypeEnum::TList { ty: var }); + Ok(Ok((list, false))) } else if ty_id == self.primitive_ids.tuple { - let elements: &PyTuple = obj.cast_as()?; - let types: Result>, _> = elements - .iter() - .map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) - .collect(); - let types = types?; - Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) - } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id) { + // do not handle type var param and concrete check here + Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) + } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { + // println!("getting def"); let def = defs[def_id.0].read(); + // println!("got def"); if let TopLevelDef::Class { object_id, type_vars, @@ -120,35 +129,260 @@ impl Resolver { .. } = &*def { - let var_map: HashMap<_, _> = type_vars - .iter() - .map(|var| { - ( - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { - *id - } else { - unreachable!() - }, - unifier.get_fresh_var().0, - ) - }) - .collect(); - let mut fields_ty = HashMap::new(); - for method in methods.iter() { - fields_ty.insert(method.0, (method.1, false)); - } - for field in fields.iter() { - let name: String = field.0.into(); - let field_data = obj.getattr(&name)?; - let ty = self - .get_obj_type(field_data, helper, unifier, defs, primitives)? - .unwrap_or(primitives.none); - let field_ty = unifier.subst(field.1, &var_map).unwrap_or(field.1); - if unifier.unify(ty, field_ty).is_err() { - // field type mismatch - return Ok(None); + // do not handle type var param and concrete check here, and no subst + Ok(Ok({ + let ty = TypeEnum::TObj { + obj_id: *object_id, + params: RefCell::new({ + type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { + (*id, *x) + } else { unreachable!() } + }).collect() + }), + fields: RefCell::new({ + let mut res = methods + .iter() + .map(|(iden, ty, _)| (*iden, (*ty, false))) + .collect::>(); + res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2)))); + res + }) + }; + // here also false, later insta use python object to check compatible + (unifier.add_ty(ty), false) + })) + } else { + // only object is supported, functions are not supported + unreachable!("function type is not supported, should not be queried") + } + } else if ty_ty_id == self.primitive_ids.typevar { + let constraint_types = { + let constraints = pyty.getattr("__constraints__").unwrap(); + let mut result: Vec = vec![]; + for i in 0.. { + if let Ok(constr) = constraints.get_item(i) { + result.push({ + match self.get_pyty_obj_type(constr, helper, unifier, defs, primitives)? { + Ok((ty, _)) => { + if unifier.is_concrete(ty, &[]) { + ty + } else { + return Ok(Err(format!( + "the {}th constraint of TypeVar `{}` is not concrete", + i + 1, + pyty.getattr("__name__")?.extract::()? + ))) + } + }, + Err(err) => return Ok(Err(err)) + } + }) + } else { + break; + } + } + result + }; + let res = unifier.get_fresh_var_with_range(&constraint_types).0; + Ok(Ok((res, true))) + } else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 { + let origin = helper.origin_ty_fn.call1((evaluated_ty,))?; + let args: &PyTuple = helper.args_ty_fn.call1((evaluated_ty,))?.cast_as()?; + let origin_ty = match self.get_pyty_obj_type(origin, helper, unifier, defs, primitives)? { + Ok((ty, false)) => ty, + Ok((_, true)) => return Ok(Err("instantiated type does not take type parameters".into())), + Err(err) => return Ok(Err(err)) + }; + + match &*unifier.get_ty(origin_ty) { + TypeEnum::TList { .. } => { + if args.len() == 1 { + let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? { + Ok(ty) => ty, + Err(err) => return Ok(Err(err)) + }; + if !unifier.is_concrete(ty.0, &[]) && !ty.1 { + panic!("type list should take concrete parameters in type var ranges") + } + Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true))) + } else { + return Ok(Err(format!("type list needs exactly 1 type parameters, found {}", args.len()))) + } + }, + TypeEnum::TTuple { .. } => { + let args = match args + .iter() + .map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives)) + .collect::, _>>()? + .into_iter() + .collect::, _>>() { + Ok(args) if !args.is_empty() => args + .into_iter() + .map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check { + panic!("type tuple should take concrete parameters in type var ranges") + } else { + x + } + ) + .collect::>(), + Err(err) => return Ok(Err(err)), + _ => return Ok(Err("tuple type needs at least 1 type parameters".to_string())) + }; + Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true))) + }, + TypeEnum::TObj { params, obj_id, .. } => { + let subst = { + let params = &*params.borrow(); + if params.len() != args.len() { + return Ok(Err(format!( + "for class #{}, expect {} type parameters, got {}.", + obj_id.0, + params.len(), + args.len(), + ))) + } + let args = match args + .iter() + .map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives)) + .collect::, _>>()? + .into_iter() + .collect::, _>>() { + Ok(args) => args + .into_iter() + .map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check { + panic!("type class should take concrete parameters in type var ranges") + } else { + x + } + ) + .collect::>(), + Err(err) => return Ok(Err(err)), + }; + params + .iter() + .zip(args.iter()) + .map(|((id, _), ty)| (*id, *ty)) + .collect::>() + }; + Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true))) + }, + TypeEnum::TVirtual { .. } => { + if args.len() == 1 { + let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? { + Ok(ty) => ty, + Err(err) => return Ok(Err(err)) + }; + if !unifier.is_concrete(ty.0, &[]) && !ty.1 { + panic!("virtual class should take concrete parameters in type var ranges") + } + Ok(Ok((unifier.add_ty(TypeEnum::TVirtual { ty: ty.0 }), true))) + } else { + return Ok(Err(format!("virtual class needs exactly 1 type parameters, found {}", args.len()))) + } + } + _ => unimplemented!() + } + } else if ty_id == self.primitive_ids.virtual_id { + Ok(Ok(({ + let ty = TypeEnum::TVirtual { ty: unifier.get_fresh_var().0 }; + unifier.add_ty(ty) + }, false))) + } else { + Ok(Err("unknown type".into())) + } + } + + fn get_obj_type( + &self, + obj: &PyAny, + helper: &PythonHelper, + unifier: &mut Unifier, + defs: &[Arc>], + primitives: &PrimitiveStore, + ) -> PyResult> { + let (extracted_ty, inst_check) = match self.get_pyty_obj_type( + { + let ty = helper.type_fn.call1((obj,)).unwrap(); + if [self.primitive_ids.typevar, + self.primitive_ids.generic_alias.0, + self.primitive_ids.generic_alias.1 + ].contains(&helper.id_fn.call1((ty,))?.extract::()?) { + obj + } else { + ty + } + }, + helper, + unifier, + defs, + primitives + )? { + Ok(s) => s, + Err(_) => return Ok(None) + }; + return match (&*unifier.get_ty(extracted_ty), inst_check) { + // do the instantiation for these three types + (TypeEnum::TList { ty }, false) => { + let len: usize = helper.len_fn.call1((obj,))?.extract()?; + if len == 0 { + assert!(matches!( + &*unifier.get_ty(extracted_ty), + TypeEnum::TVar { meta: nac3core::typecheck::typedef::TypeVarMeta::Generic, range, .. } + if range.borrow().is_empty() + )); + Ok(Some(extracted_ty)) + } else { + let actual_ty = self + .get_list_elem_type(obj, len, helper, unifier, defs, primitives)?; + if let Some(actual_ty) = actual_ty { + unifier.unify(*ty, actual_ty).unwrap(); + Ok(Some(extracted_ty)) + } else { + Ok(None) + } + } + } + (TypeEnum::TTuple { .. }, false) => { + let elements: &PyTuple = obj.cast_as()?; + let types: Result>, _> = elements + .iter() + .map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) + .collect(); + let types = types?; + Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) + } + (TypeEnum::TObj { params, fields, .. }, false) => { + let var_map = params + .borrow() + .iter() + .map(|(id_var, ty)| { + if let TypeEnum::TVar { id, range, .. } = &*unifier.get_ty(*ty) { + assert_eq!(*id, *id_var); + (*id, unifier.get_fresh_var_with_range(&range.borrow()).0) + } else { + unreachable!() + } + }) + .collect::>(); + // loop through non-function fields of the class to get the instantiated value + for field in fields.borrow().iter() { + let name: String = (*field.0).into(); + if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(field.1.0) { + continue; + } else { + let field_data = obj.getattr(&name)?; + let ty = self + .get_obj_type(field_data, helper, unifier, defs, primitives)? + .unwrap_or(primitives.none); + let field_ty = unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0); + if unifier.unify(ty, field_ty).is_err() { + // field type mismatch + return Ok(None); + } } - fields_ty.insert(field.0, (ty, field.2)); } for (_, ty) in var_map.iter() { // must be concrete type @@ -156,18 +390,10 @@ impl Resolver { return Ok(None) } } - Ok(Some(unifier.add_ty(TypeEnum::TObj { - obj_id: *object_id, - fields: RefCell::new(fields_ty), - params: RefCell::new(var_map), - }))) - } else { - // only object is supported, functions are not supported - Ok(None) + return Ok(Some(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty))); } - } else { - Ok(None) - } + _ => Ok(Some(extracted_ty)) + }; } fn get_obj_value<'ctx, 'a>( @@ -425,10 +651,16 @@ impl SymbolResolver for Resolver { let key: &str = member.get_item(0)?.extract()?; if key == str.to_string() { let builtins = PyModule::import(py, "builtins")?; + let typings = PyModule::import(py, "typing")?; let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap(), len_fn: builtins.getattr("len").unwrap(), type_fn: builtins.getattr("type").unwrap(), + origin_ty_fn: typings.getattr("get_origin").unwrap(), + args_ty_fn: typings.getattr("get_args").unwrap(), + globals_dict: obj.getattr("__dict__").unwrap(), + eval_type_fn: typings.getattr("_eval_type").unwrap(), + print_fn: builtins.getattr("print").unwrap(), }; sym_ty = self.get_obj_type( member.get_item(1)?, @@ -469,10 +701,16 @@ impl SymbolResolver for Resolver { let val = member.get_item(1)?; if key == id.to_string() { let builtins = PyModule::import(py, "builtins")?; + let typings = PyModule::import(py, "typing")?; let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap(), len_fn: builtins.getattr("len").unwrap(), type_fn: builtins.getattr("type").unwrap(), + origin_ty_fn: typings.getattr("get_origin").unwrap(), + args_ty_fn: typings.getattr("get_args").unwrap(), + globals_dict: obj.getattr("__dict__").unwrap(), + eval_type_fn: typings.getattr("_eval_type").unwrap(), + print_fn: builtins.getattr("print").unwrap(), }; sym_value = self.get_obj_value(val, &helper, ctx)?; break; -- 2.44.2 From 77c8b947f4412782fbdcd6a894242390c737db19 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 11 Nov 2021 01:54:29 +0800 Subject: [PATCH 4/7] nac3standalone: basic resolver typevar handling --- nac3core/src/toplevel/composer.rs | 2 +- nac3standalone/src/main.rs | 42 ++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 8dd5404c..9362f02a 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -535,7 +535,7 @@ impl TopLevelComposer { } } - fn extract_def_list(&self) -> Vec>> { + pub fn extract_def_list(&self) -> Vec>> { self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() } diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 5b4e31dd..a4a82e53 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -4,7 +4,7 @@ use inkwell::{ OptimizationLevel, }; use nac3core::typecheck::type_inferencer::PrimitiveStore; -use nac3parser::parser; +use nac3parser::{ast::{ExprKind, StmtKind}, parser}; use std::env; use std::fs; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; @@ -66,6 +66,46 @@ fn main() { ); for stmt in parser_result.into_iter() { + // handle type vars in toplevel + if let StmtKind::Assign { value, targets, .. } = &stmt.node { + assert_eq!(targets.len(), 1, "only support single assignment for now, at {}", targets[0].location); + if let ExprKind::Call { func, args, .. } = &value.node { + if matches!(&func.node, ExprKind::Name { id, .. } if id == &"TypeVar".into()) { + print!("registering typevar {:?}", targets[0].node); + let constraints = args + .iter() + .skip(1) + .map(|x| { + let def_list = &composer.extract_def_list(); + let unifier = &mut composer.unifier; + resolver.parse_type_annotation( + def_list, + unifier, + &primitive, + x + ).unwrap() + }) + .collect::>(); + let res_ty = composer.unifier.get_fresh_var_with_range(&constraints).0; + println!( + " ...registered: {}", + composer.unifier.stringify( + res_ty, + &mut |x| format!("obj{}", x), + &mut |x| format!("tavr{}", x) + ) + ); + internal_resolver.add_id_type( + if let ExprKind::Name { id, .. } = &targets[0].node { *id } else { + panic!("must assign simple name variable as type variable for now") + }, + res_ty + ); + continue; + } + } + } + let (name, def_id, ty) = composer .register_top_level(stmt, Some(resolver.clone()), "__main__".into()) .unwrap(); -- 2.44.2 From 12ab8bcd39d5930a26655a7924dfc5a7d39aa495 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 11 Nov 2021 02:46:41 +0800 Subject: [PATCH 5/7] nac3core: parse type annotation python forwardref handling --- nac3core/src/symbol_resolver.rs | 295 ++++++++++++----------- nac3core/src/toplevel/type_annotation.rs | 237 +++++++++--------- 2 files changed, 272 insertions(+), 260 deletions(-) diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 161d78a9..9f9c161d 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -13,7 +13,7 @@ use crate::{ use crate::{location::Location, typecheck::typedef::TypeEnum}; use inkwell::values::BasicValueEnum; use itertools::{chain, izip}; -use nac3parser::ast::{Expr, StrRef}; +use nac3parser::ast::{Constant::Str, Expr, StrRef}; use parking_lot::RwLock; #[derive(Clone, PartialEq, Debug)] @@ -79,159 +79,168 @@ pub fn parse_type_annotation( let list_id = ids[6]; let tuple_id = ids[7]; - match &expr.node { - Name { id, .. } => { - if *id == int32_id { - Ok(primitives.int32) - } else if *id == int64_id { - Ok(primitives.int64) - } else if *id == float_id { - Ok(primitives.float) - } else if *id == bool_id { - Ok(primitives.bool) - } else if *id == none_id { - Ok(primitives.none) - } else { - let obj_id = resolver.get_identifier_def(*id); - if let Some(obj_id) = obj_id { - let def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { - if !type_vars.is_empty() { - return Err(format!( - "Unexpected number of type parameters: expected {} but got 0", - type_vars.len() - )); - } - let fields = RefCell::new( - chain( - fields.iter().map(|(k, v, m)| (*k, (*v, *m))), - methods.iter().map(|(k, v, _)| (*k, (*v, false))), - ) - .collect(), - ); - Ok(unifier.add_ty(TypeEnum::TObj { - obj_id, - fields, - params: Default::default(), - })) - } else { - Err("Cannot use function name as type".into()) + let name_handling = |id: &StrRef, unifier: &mut Unifier| { + if *id == int32_id { + Ok(primitives.int32) + } else if *id == int64_id { + Ok(primitives.int64) + } else if *id == float_id { + Ok(primitives.float) + } else if *id == bool_id { + Ok(primitives.bool) + } else if *id == none_id { + Ok(primitives.none) + } else { + let obj_id = resolver.get_identifier_def(*id); + if let Some(obj_id) = obj_id { + let def = top_level_defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { + if !type_vars.is_empty() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got 0", + type_vars.len() + )); } + let fields = RefCell::new( + chain( + fields.iter().map(|(k, v, m)| (*k, (*v, *m))), + methods.iter().map(|(k, v, _)| (*k, (*v, false))), + ) + .collect(), + ); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id, + fields, + params: Default::default(), + })) } else { - // it could be a type variable - let ty = resolver - .get_symbol_type(unifier, top_level_defs, primitives, *id) - .ok_or_else(|| "unknown type variable name".to_owned())?; - if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { - Ok(ty) - } else { - Err(format!("Unknown type annotation {}", id)) - } + Err("Cannot use function name as type".into()) + } + } else { + // it could be a type variable + let ty = resolver + .get_symbol_type(unifier, top_level_defs, primitives, *id) + .ok_or_else(|| "unknown type variable name".to_owned())?; + if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { + Ok(ty) + } else { + Err(format!("Unknown type annotation {}", id)) } } } - Subscript { value, slice, .. } => { - if let Name { id, .. } = &value.node { - if *id == virtual_id { - let ty = parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - slice, - )?; - Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) - } else if *id == list_id { - let ty = parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - slice, - )?; - Ok(unifier.add_ty(TypeEnum::TList { ty })) - } else if *id == tuple_id { - if let Tuple { elts, .. } = &slice.node { - let ty = elts - .iter() - .map(|elt| { - parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - elt, - ) - }) - .collect::, _>>()?; - Ok(unifier.add_ty(TypeEnum::TTuple { ty })) - } else { - Err("Expected multiple elements for tuple".into()) - } - } else { - let types = if let Tuple { elts, .. } = &slice.node { - elts.iter() - .map(|v| { - parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - v, - ) - }) - .collect::, _>>()? - } else { - vec![parse_type_annotation( + }; + + let subscript_name_handle = |id: &StrRef, slice: &Expr, unifier: &mut Unifier| { + if *id == virtual_id { + let ty = parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + slice, + )?; + Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) + } else if *id == list_id { + let ty = parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + slice, + )?; + Ok(unifier.add_ty(TypeEnum::TList { ty })) + } else if *id == tuple_id { + if let Tuple { elts, .. } = &slice.node { + let ty = elts + .iter() + .map(|elt| { + parse_type_annotation( resolver, top_level_defs, unifier, primitives, - slice, - )?] - }; + elt, + ) + }) + .collect::, _>>()?; + Ok(unifier.add_ty(TypeEnum::TTuple { ty })) + } else { + Err("Expected multiple elements for tuple".into()) + } + } else { + let types = if let Tuple { elts, .. } = &slice.node { + elts.iter() + .map(|v| { + parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + v, + ) + }) + .collect::, _>>()? + } else { + vec![parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + slice, + )?] + }; - let obj_id = resolver - .get_identifier_def(*id) - .ok_or_else(|| format!("Unknown type annotation {}", id))?; - let def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { - if types.len() != type_vars.len() { - return Err(format!( - "Unexpected number of type parameters: expected {} but got {}", - type_vars.len(), - types.len() - )); - } - let mut subst = HashMap::new(); - for (var, ty) in izip!(type_vars.iter(), types.iter()) { - let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { - *id - } else { - unreachable!() - }; - subst.insert(id, *ty); - } - let mut fields = fields - .iter() - .map(|(attr, ty, is_mutable)| { - let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (*attr, (ty, *is_mutable)) - }) - .collect::>(); - fields.extend(methods.iter().map(|(attr, ty, _)| { - let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (*attr, (ty, false)) - })); - Ok(unifier.add_ty(TypeEnum::TObj { - obj_id, - fields: fields.into(), - params: subst.into(), - })) - } else { - Err("Cannot use function name as type".into()) - } + let obj_id = resolver + .get_identifier_def(*id) + .ok_or_else(|| format!("Unknown type annotation {}", id))?; + let def = top_level_defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { + if types.len() != type_vars.len() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got {}", + type_vars.len(), + types.len() + )); } + let mut subst = HashMap::new(); + for (var, ty) in izip!(type_vars.iter(), types.iter()) { + let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { + *id + } else { + unreachable!() + }; + subst.insert(id, *ty); + } + let mut fields = fields + .iter() + .map(|(attr, ty, is_mutable)| { + let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (*attr, (ty, *is_mutable)) + }) + .collect::>(); + fields.extend(methods.iter().map(|(attr, ty, _)| { + let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (*attr, (ty, false)) + })); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id, + fields: fields.into(), + params: subst.into(), + })) + } else { + Err("Cannot use function name as type".into()) + } + } + }; + + match &expr.node { + Name { id, .. } => name_handling(id, unifier), + Constant { value: Str(id), .. } => name_handling(&id.clone().into(), unifier), + Subscript { value, slice, .. } => { + if let Name { id, .. } = &value.node { + subscript_name_handle(id, slice, unifier) + } else if let Constant { value: Str(id), .. } = &value.node { + subscript_name_handle(&id.clone().into(), slice, unifier) } else { Err("unsupported type expression".into()) } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 307d8794..2ae31c7c 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,7 +1,7 @@ use std::cell::RefCell; use crate::typecheck::typedef::TypeVarMeta; - +use ast::Constant::Str; use super::*; #[derive(Clone, Debug)] @@ -49,58 +49,127 @@ pub fn parse_ast_to_type_annotation_kinds( primitives: &PrimitiveStore, expr: &ast::Expr, // the key stores the type_var of this topleveldef::class, we only need this field here - mut locked: HashMap>, + locked: HashMap>, ) -> Result { - match &expr.node { - ast::ExprKind::Name { id, .. } => { - if id == &"int32".into() { - Ok(TypeAnnotation::Primitive(primitives.int32)) - } else if id == &"int64".into() { - Ok(TypeAnnotation::Primitive(primitives.int64)) - } else if id == &"float".into() { - Ok(TypeAnnotation::Primitive(primitives.float)) - } else if id == &"bool".into() { - Ok(TypeAnnotation::Primitive(primitives.bool)) - } else if id == &"None".into() { - Ok(TypeAnnotation::Primitive(primitives.none)) - } else if id == &"str".into() { - Ok(TypeAnnotation::Primitive(primitives.str)) - } else if let Some(obj_id) = resolver.get_identifier_def(*id) { - let type_vars = { - let def_read = top_level_defs[obj_id.0].try_read(); - if let Some(def_read) = def_read { - if let TopLevelDef::Class { type_vars, .. } = &*def_read { - type_vars.clone() - } else { - return Err("function cannot be used as a type".into()); - } + let name_handle = |id: &StrRef, unifier: &mut Unifier, locked: HashMap>| { + if id == &"int32".into() { + Ok(TypeAnnotation::Primitive(primitives.int32)) + } else if id == &"int64".into() { + Ok(TypeAnnotation::Primitive(primitives.int64)) + } else if id == &"float".into() { + Ok(TypeAnnotation::Primitive(primitives.float)) + } else if id == &"bool".into() { + Ok(TypeAnnotation::Primitive(primitives.bool)) + } else if id == &"None".into() { + Ok(TypeAnnotation::Primitive(primitives.none)) + } else if id == &"str".into() { + Ok(TypeAnnotation::Primitive(primitives.str)) + } else if let Some(obj_id) = resolver.get_identifier_def(*id) { + let type_vars = { + let def_read = top_level_defs[obj_id.0].try_read(); + if let Some(def_read) = def_read { + if let TopLevelDef::Class { type_vars, .. } = &*def_read { + type_vars.clone() } else { - locked.get(&obj_id).unwrap().clone() + return Err("function cannot be used as a type".into()); } - }; - // check param number here - if !type_vars.is_empty() { - return Err(format!( - "expect {} type variable parameter but got 0", - type_vars.len() - )); - } - Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) - } else if let Some(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { - if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { - Ok(TypeAnnotation::TypeVar(ty)) } else { - Err("not a type variable identifier".into()) + locked.get(&obj_id).unwrap().clone() + } + }; + // check param number here + if !type_vars.is_empty() { + return Err(format!( + "expect {} type variable parameter but got 0", + type_vars.len() + )); + } + Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) + } else if let Some(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { + if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { + Ok(TypeAnnotation::TypeVar(ty)) + } else { + Err("not a type variable identifier".into()) + } + } else { + Err("name cannot be parsed as a type annotation".into()) + } + }; + + let class_name_handle = + |id: &StrRef, slice: &ast::Expr, unifier: &mut Unifier, mut locked: HashMap>| { + if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()] + .contains(id) + { + return Err("keywords cannot be class name".into()); + } + let obj_id = resolver + .get_identifier_def(*id) + .ok_or_else(|| "unknown class name".to_string())?; + let type_vars = { + let def_read = top_level_defs[obj_id.0].try_read(); + if let Some(def_read) = def_read { + if let TopLevelDef::Class { type_vars, .. } = &*def_read { + type_vars.clone() + } else { + unreachable!("must be class here") } } else { - Err("name cannot be parsed as a type annotation".into()) + locked.get(&obj_id).unwrap().clone() } - } - + }; + // we do not check whether the application of type variables are compatible here + let param_type_infos = { + let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + elts.iter().collect_vec() + } else { + vec![slice] + }; + if type_vars.len() != params_ast.len() { + return Err(format!( + "expect {} type parameters but got {}", + type_vars.len(), + params_ast.len() + )); + } + let result = params_ast + .into_iter() + .map(|x| { + parse_ast_to_type_annotation_kinds( + resolver, + top_level_defs, + unifier, + primitives, + x, + { + locked.insert(obj_id, type_vars.clone()); + locked.clone() + }, + ) + }) + .collect::, _>>()?; + // make sure the result do not contain any type vars + let no_type_var = result + .iter() + .all(|x| get_type_var_contained_in_type_annotation(x).is_empty()); + if no_type_var { + result + } else { + return Err("application of type vars to generic class \ + is not currently supported" + .into()); + } + }; + Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) + }; + match &expr.node { + ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked), + ast::ExprKind::Constant { value: Str(id), .. } => name_handle(&id.clone().into(), unifier, locked), // virtual ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into()) + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into()) || + matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "virtual") } => { let def = parse_ast_to_type_annotation_kinds( @@ -120,7 +189,8 @@ pub fn parse_ast_to_type_annotation_kinds( // list ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into()) + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into()) || + matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "list") } => { let def_ann = parse_ast_to_type_annotation_kinds( @@ -137,7 +207,8 @@ pub fn parse_ast_to_type_annotation_kinds( // tuple ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into()) + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into()) || + matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "tuple") } => { if let ast::ExprKind::Tuple { elts, .. } = &slice.node { @@ -163,71 +234,9 @@ pub fn parse_ast_to_type_annotation_kinds( // custom class ast::ExprKind::Subscript { value, slice, .. } => { if let ast::ExprKind::Name { id, .. } = &value.node { - if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()] - .contains(id) - { - return Err("keywords cannot be class name".into()); - } - let obj_id = resolver - .get_identifier_def(*id) - .ok_or_else(|| "unknown class name".to_string())?; - let type_vars = { - let def_read = top_level_defs[obj_id.0].try_read(); - if let Some(def_read) = def_read { - if let TopLevelDef::Class { type_vars, .. } = &*def_read { - type_vars.clone() - } else { - unreachable!("must be class here") - } - } else { - locked.get(&obj_id).unwrap().clone() - } - }; - // we do not check whether the application of type variables are compatible here - let param_type_infos = { - let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { - elts.iter().collect_vec() - } else { - vec![slice.as_ref()] - }; - if type_vars.len() != params_ast.len() { - return Err(format!( - "expect {} type parameters but got {}", - type_vars.len(), - params_ast.len() - )); - } - let result = params_ast - .into_iter() - .map(|x| { - parse_ast_to_type_annotation_kinds( - resolver, - top_level_defs, - unifier, - primitives, - x, - { - locked.insert(obj_id, type_vars.clone()); - locked.clone() - }, - ) - }) - .collect::, _>>()?; - - // make sure the result do not contain any type vars - let no_type_var = result - .iter() - .all(|x| get_type_var_contained_in_type_annotation(x).is_empty()); - if no_type_var { - result - } else { - return Err("application of type vars to generic class \ - is not currently supported" - .into()); - } - }; - - Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) + class_name_handle(id, slice, unifier, locked) + } else if let ast::ExprKind::Constant { value: Str(id), .. } = &value.node { + class_name_handle(&id.clone().into(), slice, unifier, locked) } else { Err("unsupported expression type for class name".into()) } @@ -370,13 +379,7 @@ pub fn get_type_from_type_annotation_kinds( /// But note that here we do not make a duplication of `T`, `V`, we direclty /// use them as they are in the TopLevelDef::Class since those in the /// TopLevelDef::Class.type_vars will be substitute later when seeing applications/instantiations -/// the Type of their fields and methods will also be subst when application/instantiation \ -/// \ -/// Note this implicit self type is different with seeing `A[T, V]` explicitly outside -/// the class def ast body, where it is a new instantiation of the generic class `A`, -/// but equivalent to seeing `A[T, V]` inside the class def body ast, where although we -/// create copies of `T` and `V`, we will find them out as occured type vars in the analyze_class() -/// and unify them with the class generic `T`, `V` +/// the Type of their fields and methods will also be subst when application/instantiation pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation { TypeAnnotation::CustomClass { id: object_id, -- 2.44.2 From a2da1ecf05205df6e2a552be0062b83f31d277f6 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 11 Nov 2021 04:14:55 +0800 Subject: [PATCH 6/7] nac3artiq: filter out base class not annotated with nac3 --- nac3artiq/demo/min_artiq.py | 5 +++-- nac3artiq/src/lib.rs | 38 ++++++++++++++++++++++++++++++++----- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 1ea9611e..9e5ad3d4 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -23,7 +23,7 @@ class virtual(Generic[T]): compiler = nac3artiq.NAC3(core_arguments["target"]) allow_module_registration = True registered_modules = set() - +nac3annotated_class_ids = set() class KernelInvariant(Generic[T]): pass @@ -69,6 +69,7 @@ def nac3(cls): All classes containing kernels or portable methods must use this decorator. """ register_module_of(cls) + nac3annotated_class_ids.add(id(cls)) return cls @@ -111,7 +112,7 @@ class Core: def run(self, method, *args, **kwargs): global allow_module_registration if allow_module_registration: - compiler.analyze_modules(registered_modules) + compiler.analyze_modules(registered_modules, nac3annotated_class_ids) allow_module_registration = False if hasattr(method, "__self__"): diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 94a290a6..9e7cae5e 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -11,7 +11,7 @@ use inkwell::{ use pyo3::prelude::*; use pyo3::{exceptions, types::PyList, types::PySet, types::PyBytes}; use nac3parser::{ - ast::{self, StrRef}, + ast::{self, StrRef, Constant::Str}, parser::{self, parse_program}, }; @@ -76,7 +76,7 @@ struct Nac3 { } impl Nac3 { - fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> { + fn register_module_impl(&mut self, obj: PyObject, nac3_annotated_cls: &PySet) -> PyResult<()> { let mut name_to_pyid: HashMap = HashMap::new(); let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let obj: &PyAny = obj.extract(py)?; @@ -111,7 +111,7 @@ impl Nac3 { global_value_ids: self.global_value_ids.clone(), class_names: Default::default(), name_to_pyid: name_to_pyid.clone(), - module: obj, + module: obj.clone(), }) as Arc; let mut name_to_def = HashMap::new(); let mut name_to_type = HashMap::new(); @@ -121,6 +121,7 @@ impl Nac3 { ast::StmtKind::ClassDef { ref decorator_list, ref mut body, + ref mut bases, .. } => { let kernels = decorator_list.iter().any(|decorator| { @@ -146,6 +147,33 @@ impl Nac3 { true } }); + bases.retain(|b| { + Python::with_gil(|py| -> PyResult { + let obj: &PyAny = obj.extract(py)?; + let annot_check = |id: &str| -> bool { + let id = py.eval( + &format!("id({})", id), + Some(obj.getattr("__dict__").unwrap().extract().unwrap()), + None + ).unwrap(); + nac3_annotated_cls.contains(id).unwrap() + }; + match &b.node { + ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string())), + ast::ExprKind::Constant { value: Str(id), .. } => + Ok(annot_check(id.split('[').next().unwrap())), + ast::ExprKind::Subscript { value, .. } => { + match &value.node { + ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string()) || *id == "Generic".into()), + ast::ExprKind::Constant { value: Str(id), .. } => + Ok(annot_check(id.split('[').next().unwrap())), + _ => unreachable!("unsupported base declaration") + } + } + _ => unreachable!("unsupported base declaration") + } + }).unwrap() + }); kernels } ast::StmtKind::FunctionDef { @@ -336,9 +364,9 @@ impl Nac3 { }) } - fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> { + fn analyze_modules(&mut self, modules: &PySet, nac3_annotated_cls: &PySet) -> PyResult<()> { for obj in modules.iter() { - self.register_module_impl(obj.into())?; + self.register_module_impl(obj.into(), nac3_annotated_cls)?; } Ok(()) } -- 2.44.2 From 07e1079f280a9426d61c1d93ef987c7a9f0b8f84 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 11 Nov 2021 14:18:52 +0800 Subject: [PATCH 7/7] nac3artiq: avoid using py.eval to get id of class virtual --- nac3artiq/src/lib.rs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 9e7cae5e..83a52a1d 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -281,11 +281,17 @@ impl Nac3 { let typing_mod = PyModule::import(py, "typing").unwrap(); let types_mod = PyModule::import(py, "types").unwrap(); let primitive_ids = PrimitivePythonId { - virtual_id: py.eval( - "id(virtual)", - Some(builtins_mod.getattr("globals").unwrap().call0().unwrap().extract().unwrap()), - None - ).unwrap().extract().unwrap(), + virtual_id: id_fn + .call1((builtins_mod + .getattr("globals") + .unwrap() + .call0() + .unwrap() + .get_item("virtual") + .unwrap(), + )).unwrap() + .extract() + .unwrap(), generic_alias: ( id_fn .call1((typing_mod.getattr("_GenericAlias").unwrap(),)) -- 2.44.2