From a3faa9b7dd3f550f05c50994af087e401d4803f6 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 11 Nov 2021 01:23:27 +0800 Subject: [PATCH 01/10] nac3artiq: symbol reslover handle typevar, virtual and ForwardRef --- nac3artiq/demo/min_artiq.py | 8 +- nac3artiq/src/lib.rs | 33 +++ nac3artiq/src/symbol_resolver.rs | 368 +++++++++++++++++++++++++------ 3 files changed, 343 insertions(+), 66 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 2fbd076..f644eee 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -9,17 +9,23 @@ import nac3artiq __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3", "ms", "us", "ns", "print_int32", "print_int64", - "Core", "TTLOut", "parallel", "sequential"] + "Core", "TTLOut", "parallel", "sequential", "virtual"] T = TypeVar('T') + class KernelInvariant(Generic[T]): pass +# place the `virtual` class infront of the construct of NAC3 object to ensure the +# virtual class is known during the initializing of NAC3 object +class virtual(Generic[T]): + pass import device_db core_arguments = device_db.device_db["core"]["arguments"] + compiler = nac3artiq.NAC3(core_arguments["target"]) allow_registration = True # Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side. diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index fb71ee8..07d9412 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -52,6 +52,10 @@ pub struct PrimitivePythonId { bool: u64, list: u64, tuple: u64, + typevar: u64, + none: u64, + generic_alias: (u64, u64), + virtual_id: u64, } // TopLevelComposer is unsendable as it holds the unification table, which is @@ -267,7 +271,36 @@ impl Nac3 { let builtins_mod = PyModule::import(py, "builtins").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap(); let numpy_mod = PyModule::import(py, "numpy").unwrap(); + let typing_mod = PyModule::import(py, "typing").unwrap(); + let types_mod = PyModule::import(py, "types").unwrap(); let primitive_ids = PrimitivePythonId { + virtual_id: py.eval( + "id(virtual)", + Some(builtins_mod.getattr("globals").unwrap().call0().unwrap().extract().unwrap()), + None + ).unwrap().extract().unwrap(), + generic_alias: ( + id_fn + .call1((typing_mod.getattr("_GenericAlias").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + id_fn + .call1((types_mod.getattr("GenericAlias").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + ), + none: id_fn + .call1((builtins_mod.getattr("None").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + typevar: id_fn + .call1((typing_mod.getattr("TypeVar").unwrap(),)) + .unwrap() + .extract() + .unwrap(), int: id_fn .call1((builtins_mod.getattr("int").unwrap(),)) .unwrap() diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index e420bd8..b27c7df 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -40,6 +40,11 @@ struct PythonHelper<'a> { type_fn: &'a PyAny, len_fn: &'a PyAny, id_fn: &'a PyAny, + eval_type_fn: &'a PyAny, + origin_ty_fn: &'a PyAny, + args_ty_fn: &'a PyAny, + globals_dict: &'a PyAny, + print_fn: &'a PyAny, } impl Resolver { @@ -71,47 +76,51 @@ impl Resolver { })) } - fn get_obj_type( + // handle python objects that represent types themselves + // primitives and class types should be themselves, use `ty_id` to check, + // TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check + // the `bool` value returned indicates whether they are instantiated or not + fn get_pyty_obj_type( &self, - obj: &PyAny, + pyty: &PyAny, helper: &PythonHelper, unifier: &mut Unifier, defs: &[Arc>], primitives: &PrimitiveStore, - ) -> PyResult> { + ) -> PyResult> { + // eval_type use only globals_dict should be fine + let evaluated_ty = helper + .eval_type_fn + .call1((pyty, helper.globals_dict, helper.globals_dict)).unwrap(); let ty_id: u64 = helper .id_fn - .call1((helper.type_fn.call1((obj,))?,))? + .call1((evaluated_ty,))? .extract()?; - + let ty_ty_id: u64 = helper + .id_fn + .call1((helper.type_fn.call1((evaluated_ty,))?,))? + .extract()?; + if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { - Ok(Some(primitives.int32)) + Ok(Ok((primitives.int32, true))) } else if ty_id == self.primitive_ids.int64 { - Ok(Some(primitives.int64)) + Ok(Ok((primitives.int64, true))) } else if ty_id == self.primitive_ids.bool { - Ok(Some(primitives.bool)) + Ok(Ok((primitives.bool, true))) } else if ty_id == self.primitive_ids.float { - Ok(Some(primitives.float)) + Ok(Ok((primitives.float, true))) } else if ty_id == self.primitive_ids.list { - let len: usize = helper.len_fn.call1((obj,))?.extract()?; - if len == 0 { - let var = unifier.get_fresh_var().0; - let list = unifier.add_ty(TypeEnum::TList { ty: var }); - Ok(Some(list)) - } else { - let ty = self.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?; - Ok(ty.map(|ty| unifier.add_ty(TypeEnum::TList { ty }))) - } + // do not handle type var param and concrete check here + let var = unifier.get_fresh_var().0; + let list = unifier.add_ty(TypeEnum::TList { ty: var }); + Ok(Ok((list, false))) } else if ty_id == self.primitive_ids.tuple { - let elements: &PyTuple = obj.cast_as()?; - let types: Result>, _> = elements - .iter() - .map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) - .collect(); - let types = types?; - Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) - } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id) { + // do not handle type var param and concrete check here + Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) + } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { + // println!("getting def"); let def = defs[def_id.0].read(); + // println!("got def"); if let TopLevelDef::Class { object_id, type_vars, @@ -120,35 +129,260 @@ impl Resolver { .. } = &*def { - let var_map: HashMap<_, _> = type_vars - .iter() - .map(|var| { - ( - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { - *id - } else { - unreachable!() - }, - unifier.get_fresh_var().0, - ) - }) - .collect(); - let mut fields_ty = HashMap::new(); - for method in methods.iter() { - fields_ty.insert(method.0, (method.1, false)); - } - for field in fields.iter() { - let name: String = field.0.into(); - let field_data = obj.getattr(&name)?; - let ty = self - .get_obj_type(field_data, helper, unifier, defs, primitives)? - .unwrap_or(primitives.none); - let field_ty = unifier.subst(field.1, &var_map).unwrap_or(field.1); - if unifier.unify(ty, field_ty).is_err() { - // field type mismatch - return Ok(None); + // do not handle type var param and concrete check here, and no subst + Ok(Ok({ + let ty = TypeEnum::TObj { + obj_id: *object_id, + params: RefCell::new({ + type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { + (*id, *x) + } else { unreachable!() } + }).collect() + }), + fields: RefCell::new({ + let mut res = methods + .iter() + .map(|(iden, ty, _)| (*iden, (*ty, false))) + .collect::>(); + res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2)))); + res + }) + }; + // here also false, later insta use python object to check compatible + (unifier.add_ty(ty), false) + })) + } else { + // only object is supported, functions are not supported + unreachable!("function type is not supported, should not be queried") + } + } else if ty_ty_id == self.primitive_ids.typevar { + let constraint_types = { + let constraints = pyty.getattr("__constraints__").unwrap(); + let mut result: Vec = vec![]; + for i in 0.. { + if let Ok(constr) = constraints.get_item(i) { + result.push({ + match self.get_pyty_obj_type(constr, helper, unifier, defs, primitives)? { + Ok((ty, _)) => { + if unifier.is_concrete(ty, &[]) { + ty + } else { + return Ok(Err(format!( + "the {}th constraint of TypeVar `{}` is not concrete", + i + 1, + pyty.getattr("__name__")?.extract::()? + ))) + } + }, + Err(err) => return Ok(Err(err)) + } + }) + } else { + break; + } + } + result + }; + let res = unifier.get_fresh_var_with_range(&constraint_types).0; + Ok(Ok((res, true))) + } else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 { + let origin = helper.origin_ty_fn.call1((evaluated_ty,))?; + let args: &PyTuple = helper.args_ty_fn.call1((evaluated_ty,))?.cast_as()?; + let origin_ty = match self.get_pyty_obj_type(origin, helper, unifier, defs, primitives)? { + Ok((ty, false)) => ty, + Ok((_, true)) => return Ok(Err("instantiated type does not take type parameters".into())), + Err(err) => return Ok(Err(err)) + }; + + match &*unifier.get_ty(origin_ty) { + TypeEnum::TList { .. } => { + if args.len() == 1 { + let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? { + Ok(ty) => ty, + Err(err) => return Ok(Err(err)) + }; + if !unifier.is_concrete(ty.0, &[]) && !ty.1 { + panic!("type list should take concrete parameters in type var ranges") + } + Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true))) + } else { + return Ok(Err(format!("type list needs exactly 1 type parameters, found {}", args.len()))) + } + }, + TypeEnum::TTuple { .. } => { + let args = match args + .iter() + .map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives)) + .collect::, _>>()? + .into_iter() + .collect::, _>>() { + Ok(args) if !args.is_empty() => args + .into_iter() + .map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check { + panic!("type tuple should take concrete parameters in type var ranges") + } else { + x + } + ) + .collect::>(), + Err(err) => return Ok(Err(err)), + _ => return Ok(Err("tuple type needs at least 1 type parameters".to_string())) + }; + Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true))) + }, + TypeEnum::TObj { params, obj_id, .. } => { + let subst = { + let params = &*params.borrow(); + if params.len() != args.len() { + return Ok(Err(format!( + "for class #{}, expect {} type parameters, got {}.", + obj_id.0, + params.len(), + args.len(), + ))) + } + let args = match args + .iter() + .map(|x| self.get_pyty_obj_type(x, helper, unifier, defs, primitives)) + .collect::, _>>()? + .into_iter() + .collect::, _>>() { + Ok(args) => args + .into_iter() + .map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check { + panic!("type class should take concrete parameters in type var ranges") + } else { + x + } + ) + .collect::>(), + Err(err) => return Ok(Err(err)), + }; + params + .iter() + .zip(args.iter()) + .map(|((id, _), ty)| (*id, *ty)) + .collect::>() + }; + Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true))) + }, + TypeEnum::TVirtual { .. } => { + if args.len() == 1 { + let ty = match self.get_pyty_obj_type(args.get_item(0), helper, unifier, defs, primitives)? { + Ok(ty) => ty, + Err(err) => return Ok(Err(err)) + }; + if !unifier.is_concrete(ty.0, &[]) && !ty.1 { + panic!("virtual class should take concrete parameters in type var ranges") + } + Ok(Ok((unifier.add_ty(TypeEnum::TVirtual { ty: ty.0 }), true))) + } else { + return Ok(Err(format!("virtual class needs exactly 1 type parameters, found {}", args.len()))) + } + } + _ => unimplemented!() + } + } else if ty_id == self.primitive_ids.virtual_id { + Ok(Ok(({ + let ty = TypeEnum::TVirtual { ty: unifier.get_fresh_var().0 }; + unifier.add_ty(ty) + }, false))) + } else { + Ok(Err("unknown type".into())) + } + } + + fn get_obj_type( + &self, + obj: &PyAny, + helper: &PythonHelper, + unifier: &mut Unifier, + defs: &[Arc>], + primitives: &PrimitiveStore, + ) -> PyResult> { + let (extracted_ty, inst_check) = match self.get_pyty_obj_type( + { + let ty = helper.type_fn.call1((obj,)).unwrap(); + if [self.primitive_ids.typevar, + self.primitive_ids.generic_alias.0, + self.primitive_ids.generic_alias.1 + ].contains(&helper.id_fn.call1((ty,))?.extract::()?) { + obj + } else { + ty + } + }, + helper, + unifier, + defs, + primitives + )? { + Ok(s) => s, + Err(_) => return Ok(None) + }; + return match (&*unifier.get_ty(extracted_ty), inst_check) { + // do the instantiation for these three types + (TypeEnum::TList { ty }, false) => { + let len: usize = helper.len_fn.call1((obj,))?.extract()?; + if len == 0 { + assert!(matches!( + &*unifier.get_ty(extracted_ty), + TypeEnum::TVar { meta: nac3core::typecheck::typedef::TypeVarMeta::Generic, range, .. } + if range.borrow().is_empty() + )); + Ok(Some(extracted_ty)) + } else { + let actual_ty = self + .get_list_elem_type(obj, len, helper, unifier, defs, primitives)?; + if let Some(actual_ty) = actual_ty { + unifier.unify(*ty, actual_ty).unwrap(); + Ok(Some(extracted_ty)) + } else { + Ok(None) + } + } + } + (TypeEnum::TTuple { .. }, false) => { + let elements: &PyTuple = obj.cast_as()?; + let types: Result>, _> = elements + .iter() + .map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) + .collect(); + let types = types?; + Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) + } + (TypeEnum::TObj { params, fields, .. }, false) => { + let var_map = params + .borrow() + .iter() + .map(|(id_var, ty)| { + if let TypeEnum::TVar { id, range, .. } = &*unifier.get_ty(*ty) { + assert_eq!(*id, *id_var); + (*id, unifier.get_fresh_var_with_range(&range.borrow()).0) + } else { + unreachable!() + } + }) + .collect::>(); + // loop through non-function fields of the class to get the instantiated value + for field in fields.borrow().iter() { + let name: String = (*field.0).into(); + if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(field.1.0) { + continue; + } else { + let field_data = obj.getattr(&name)?; + let ty = self + .get_obj_type(field_data, helper, unifier, defs, primitives)? + .unwrap_or(primitives.none); + let field_ty = unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0); + if unifier.unify(ty, field_ty).is_err() { + // field type mismatch + return Ok(None); + } } - fields_ty.insert(field.0, (ty, field.2)); } for (_, ty) in var_map.iter() { // must be concrete type @@ -156,18 +390,10 @@ impl Resolver { return Ok(None) } } - Ok(Some(unifier.add_ty(TypeEnum::TObj { - obj_id: *object_id, - fields: RefCell::new(fields_ty), - params: RefCell::new(var_map), - }))) - } else { - // only object is supported, functions are not supported - Ok(None) + return Ok(Some(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty))); } - } else { - Ok(None) - } + _ => Ok(Some(extracted_ty)) + }; } fn get_obj_value<'ctx, 'a>( @@ -425,10 +651,16 @@ impl SymbolResolver for Resolver { let key: &str = member.get_item(0)?.extract()?; if key == str.to_string() { let builtins = PyModule::import(py, "builtins")?; + let typings = PyModule::import(py, "typing")?; let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap(), len_fn: builtins.getattr("len").unwrap(), type_fn: builtins.getattr("type").unwrap(), + origin_ty_fn: typings.getattr("get_origin").unwrap(), + args_ty_fn: typings.getattr("get_args").unwrap(), + globals_dict: obj.getattr("__dict__").unwrap(), + eval_type_fn: typings.getattr("_eval_type").unwrap(), + print_fn: builtins.getattr("print").unwrap(), }; sym_ty = self.get_obj_type( member.get_item(1)?, @@ -469,10 +701,16 @@ impl SymbolResolver for Resolver { let val = member.get_item(1)?; if key == id.to_string() { let builtins = PyModule::import(py, "builtins")?; + let typings = PyModule::import(py, "typing")?; let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap(), len_fn: builtins.getattr("len").unwrap(), type_fn: builtins.getattr("type").unwrap(), + origin_ty_fn: typings.getattr("get_origin").unwrap(), + args_ty_fn: typings.getattr("get_args").unwrap(), + globals_dict: obj.getattr("__dict__").unwrap(), + eval_type_fn: typings.getattr("_eval_type").unwrap(), + print_fn: builtins.getattr("print").unwrap(), }; sym_value = self.get_obj_value(val, &helper, ctx)?; break; -- 2.44.1 From 9406a645c7ad1712baaaee1f80a535a5ffc2b963 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 11 Nov 2021 14:18:52 +0800 Subject: [PATCH 02/10] nac3artiq: avoid using py.eval to get id of class virtual --- nac3artiq/src/lib.rs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 07d9412..daa88c2 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -274,11 +274,17 @@ impl Nac3 { let typing_mod = PyModule::import(py, "typing").unwrap(); let types_mod = PyModule::import(py, "types").unwrap(); let primitive_ids = PrimitivePythonId { - virtual_id: py.eval( - "id(virtual)", - Some(builtins_mod.getattr("globals").unwrap().call0().unwrap().extract().unwrap()), - None - ).unwrap().extract().unwrap(), + virtual_id: id_fn + .call1((builtins_mod + .getattr("globals") + .unwrap() + .call0() + .unwrap() + .get_item("virtual") + .unwrap(), + )).unwrap() + .extract() + .unwrap(), generic_alias: ( id_fn .call1((typing_mod.getattr("_GenericAlias").unwrap(),)) -- 2.44.1 From eb1f353acd35c3138300dfb273cd2ed1b6637d9c Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 11 Nov 2021 21:27:23 +0800 Subject: [PATCH 03/10] nac3artiq: remove unnecessary python print from helper --- nac3artiq/src/symbol_resolver.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index b27c7df..722f410 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -44,7 +44,6 @@ struct PythonHelper<'a> { origin_ty_fn: &'a PyAny, args_ty_fn: &'a PyAny, globals_dict: &'a PyAny, - print_fn: &'a PyAny, } impl Resolver { @@ -660,7 +659,6 @@ impl SymbolResolver for Resolver { args_ty_fn: typings.getattr("get_args").unwrap(), globals_dict: obj.getattr("__dict__").unwrap(), eval_type_fn: typings.getattr("_eval_type").unwrap(), - print_fn: builtins.getattr("print").unwrap(), }; sym_ty = self.get_obj_type( member.get_item(1)?, @@ -710,7 +708,6 @@ impl SymbolResolver for Resolver { args_ty_fn: typings.getattr("get_args").unwrap(), globals_dict: obj.getattr("__dict__").unwrap(), eval_type_fn: typings.getattr("_eval_type").unwrap(), - print_fn: builtins.getattr("print").unwrap(), }; sym_value = self.get_obj_value(val, &helper, ctx)?; break; -- 2.44.1 From e8a5843ca70bfd2788fd8f324d0e3be0efe9c9b3 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 11 Nov 2021 01:54:29 +0800 Subject: [PATCH 04/10] nac3standalone: basic resolver typevar handling --- nac3core/src/toplevel/composer.rs | 2 +- nac3standalone/src/main.rs | 42 ++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index bf9f156..013829e 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -534,7 +534,7 @@ impl TopLevelComposer { } } - fn extract_def_list(&self) -> Vec>> { + pub fn extract_def_list(&self) -> Vec>> { self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() } diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 5b4e31d..a4a82e5 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -4,7 +4,7 @@ use inkwell::{ OptimizationLevel, }; use nac3core::typecheck::type_inferencer::PrimitiveStore; -use nac3parser::parser; +use nac3parser::{ast::{ExprKind, StmtKind}, parser}; use std::env; use std::fs; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; @@ -66,6 +66,46 @@ fn main() { ); for stmt in parser_result.into_iter() { + // handle type vars in toplevel + if let StmtKind::Assign { value, targets, .. } = &stmt.node { + assert_eq!(targets.len(), 1, "only support single assignment for now, at {}", targets[0].location); + if let ExprKind::Call { func, args, .. } = &value.node { + if matches!(&func.node, ExprKind::Name { id, .. } if id == &"TypeVar".into()) { + print!("registering typevar {:?}", targets[0].node); + let constraints = args + .iter() + .skip(1) + .map(|x| { + let def_list = &composer.extract_def_list(); + let unifier = &mut composer.unifier; + resolver.parse_type_annotation( + def_list, + unifier, + &primitive, + x + ).unwrap() + }) + .collect::>(); + let res_ty = composer.unifier.get_fresh_var_with_range(&constraints).0; + println!( + " ...registered: {}", + composer.unifier.stringify( + res_ty, + &mut |x| format!("obj{}", x), + &mut |x| format!("tavr{}", x) + ) + ); + internal_resolver.add_id_type( + if let ExprKind::Name { id, .. } = &targets[0].node { *id } else { + panic!("must assign simple name variable as type variable for now") + }, + res_ty + ); + continue; + } + } + } + let (name, def_id, ty) = composer .register_top_level(stmt, Some(resolver.clone()), "__main__".into()) .unwrap(); -- 2.44.1 From 66a9eda3c18fe52379debde7fb2ad94a629ef4ce Mon Sep 17 00:00:00 2001 From: ychenfo Date: Sat, 13 Nov 2021 17:21:59 +0800 Subject: [PATCH 05/10] nac3standalone: fix resolver typevar err msg --- nac3standalone/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index a4a82e5..97b8951 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -97,7 +97,7 @@ fn main() { ); internal_resolver.add_id_type( if let ExprKind::Name { id, .. } = &targets[0].node { *id } else { - panic!("must assign simple name variable as type variable for now") + panic!("must assign simple name variable as type variable") }, res_ty ); -- 2.44.1 From dab06bdb58c102a808466a7e655253c17b791d93 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 11 Nov 2021 02:46:41 +0800 Subject: [PATCH 06/10] nac3core: parse type annotation python forwardref handling --- nac3artiq/demo/min_artiq.py | 1 - nac3core/src/symbol_resolver.rs | 295 ++++++++++++----------- nac3core/src/toplevel/type_annotation.rs | 237 +++++++++--------- 3 files changed, 272 insertions(+), 261 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index f644eee..fcb59ce 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -25,7 +25,6 @@ class virtual(Generic[T]): import device_db core_arguments = device_db.device_db["core"]["arguments"] - compiler = nac3artiq.NAC3(core_arguments["target"]) allow_registration = True # Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side. diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 161d78a..9f9c161 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -13,7 +13,7 @@ use crate::{ use crate::{location::Location, typecheck::typedef::TypeEnum}; use inkwell::values::BasicValueEnum; use itertools::{chain, izip}; -use nac3parser::ast::{Expr, StrRef}; +use nac3parser::ast::{Constant::Str, Expr, StrRef}; use parking_lot::RwLock; #[derive(Clone, PartialEq, Debug)] @@ -79,159 +79,168 @@ pub fn parse_type_annotation( let list_id = ids[6]; let tuple_id = ids[7]; - match &expr.node { - Name { id, .. } => { - if *id == int32_id { - Ok(primitives.int32) - } else if *id == int64_id { - Ok(primitives.int64) - } else if *id == float_id { - Ok(primitives.float) - } else if *id == bool_id { - Ok(primitives.bool) - } else if *id == none_id { - Ok(primitives.none) - } else { - let obj_id = resolver.get_identifier_def(*id); - if let Some(obj_id) = obj_id { - let def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { - if !type_vars.is_empty() { - return Err(format!( - "Unexpected number of type parameters: expected {} but got 0", - type_vars.len() - )); - } - let fields = RefCell::new( - chain( - fields.iter().map(|(k, v, m)| (*k, (*v, *m))), - methods.iter().map(|(k, v, _)| (*k, (*v, false))), - ) - .collect(), - ); - Ok(unifier.add_ty(TypeEnum::TObj { - obj_id, - fields, - params: Default::default(), - })) - } else { - Err("Cannot use function name as type".into()) + let name_handling = |id: &StrRef, unifier: &mut Unifier| { + if *id == int32_id { + Ok(primitives.int32) + } else if *id == int64_id { + Ok(primitives.int64) + } else if *id == float_id { + Ok(primitives.float) + } else if *id == bool_id { + Ok(primitives.bool) + } else if *id == none_id { + Ok(primitives.none) + } else { + let obj_id = resolver.get_identifier_def(*id); + if let Some(obj_id) = obj_id { + let def = top_level_defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { + if !type_vars.is_empty() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got 0", + type_vars.len() + )); } + let fields = RefCell::new( + chain( + fields.iter().map(|(k, v, m)| (*k, (*v, *m))), + methods.iter().map(|(k, v, _)| (*k, (*v, false))), + ) + .collect(), + ); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id, + fields, + params: Default::default(), + })) } else { - // it could be a type variable - let ty = resolver - .get_symbol_type(unifier, top_level_defs, primitives, *id) - .ok_or_else(|| "unknown type variable name".to_owned())?; - if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { - Ok(ty) - } else { - Err(format!("Unknown type annotation {}", id)) - } + Err("Cannot use function name as type".into()) + } + } else { + // it could be a type variable + let ty = resolver + .get_symbol_type(unifier, top_level_defs, primitives, *id) + .ok_or_else(|| "unknown type variable name".to_owned())?; + if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { + Ok(ty) + } else { + Err(format!("Unknown type annotation {}", id)) } } } - Subscript { value, slice, .. } => { - if let Name { id, .. } = &value.node { - if *id == virtual_id { - let ty = parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - slice, - )?; - Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) - } else if *id == list_id { - let ty = parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - slice, - )?; - Ok(unifier.add_ty(TypeEnum::TList { ty })) - } else if *id == tuple_id { - if let Tuple { elts, .. } = &slice.node { - let ty = elts - .iter() - .map(|elt| { - parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - elt, - ) - }) - .collect::, _>>()?; - Ok(unifier.add_ty(TypeEnum::TTuple { ty })) - } else { - Err("Expected multiple elements for tuple".into()) - } - } else { - let types = if let Tuple { elts, .. } = &slice.node { - elts.iter() - .map(|v| { - parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - v, - ) - }) - .collect::, _>>()? - } else { - vec![parse_type_annotation( + }; + + let subscript_name_handle = |id: &StrRef, slice: &Expr, unifier: &mut Unifier| { + if *id == virtual_id { + let ty = parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + slice, + )?; + Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) + } else if *id == list_id { + let ty = parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + slice, + )?; + Ok(unifier.add_ty(TypeEnum::TList { ty })) + } else if *id == tuple_id { + if let Tuple { elts, .. } = &slice.node { + let ty = elts + .iter() + .map(|elt| { + parse_type_annotation( resolver, top_level_defs, unifier, primitives, - slice, - )?] - }; + elt, + ) + }) + .collect::, _>>()?; + Ok(unifier.add_ty(TypeEnum::TTuple { ty })) + } else { + Err("Expected multiple elements for tuple".into()) + } + } else { + let types = if let Tuple { elts, .. } = &slice.node { + elts.iter() + .map(|v| { + parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + v, + ) + }) + .collect::, _>>()? + } else { + vec![parse_type_annotation( + resolver, + top_level_defs, + unifier, + primitives, + slice, + )?] + }; - let obj_id = resolver - .get_identifier_def(*id) - .ok_or_else(|| format!("Unknown type annotation {}", id))?; - let def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { - if types.len() != type_vars.len() { - return Err(format!( - "Unexpected number of type parameters: expected {} but got {}", - type_vars.len(), - types.len() - )); - } - let mut subst = HashMap::new(); - for (var, ty) in izip!(type_vars.iter(), types.iter()) { - let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { - *id - } else { - unreachable!() - }; - subst.insert(id, *ty); - } - let mut fields = fields - .iter() - .map(|(attr, ty, is_mutable)| { - let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (*attr, (ty, *is_mutable)) - }) - .collect::>(); - fields.extend(methods.iter().map(|(attr, ty, _)| { - let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (*attr, (ty, false)) - })); - Ok(unifier.add_ty(TypeEnum::TObj { - obj_id, - fields: fields.into(), - params: subst.into(), - })) - } else { - Err("Cannot use function name as type".into()) - } + let obj_id = resolver + .get_identifier_def(*id) + .ok_or_else(|| format!("Unknown type annotation {}", id))?; + let def = top_level_defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { + if types.len() != type_vars.len() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got {}", + type_vars.len(), + types.len() + )); } + let mut subst = HashMap::new(); + for (var, ty) in izip!(type_vars.iter(), types.iter()) { + let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { + *id + } else { + unreachable!() + }; + subst.insert(id, *ty); + } + let mut fields = fields + .iter() + .map(|(attr, ty, is_mutable)| { + let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (*attr, (ty, *is_mutable)) + }) + .collect::>(); + fields.extend(methods.iter().map(|(attr, ty, _)| { + let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (*attr, (ty, false)) + })); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id, + fields: fields.into(), + params: subst.into(), + })) + } else { + Err("Cannot use function name as type".into()) + } + } + }; + + match &expr.node { + Name { id, .. } => name_handling(id, unifier), + Constant { value: Str(id), .. } => name_handling(&id.clone().into(), unifier), + Subscript { value, slice, .. } => { + if let Name { id, .. } = &value.node { + subscript_name_handle(id, slice, unifier) + } else if let Constant { value: Str(id), .. } = &value.node { + subscript_name_handle(&id.clone().into(), slice, unifier) } else { Err("unsupported type expression".into()) } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 257d582..7d6523a 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,7 +1,7 @@ use std::cell::RefCell; use crate::typecheck::typedef::TypeVarMeta; - +use ast::Constant::Str; use super::*; #[derive(Clone, Debug)] @@ -49,58 +49,127 @@ pub fn parse_ast_to_type_annotation_kinds( primitives: &PrimitiveStore, expr: &ast::Expr, // the key stores the type_var of this topleveldef::class, we only need this field here - mut locked: HashMap>, + locked: HashMap>, ) -> Result { - match &expr.node { - ast::ExprKind::Name { id, .. } => { - if id == &"int32".into() { - Ok(TypeAnnotation::Primitive(primitives.int32)) - } else if id == &"int64".into() { - Ok(TypeAnnotation::Primitive(primitives.int64)) - } else if id == &"float".into() { - Ok(TypeAnnotation::Primitive(primitives.float)) - } else if id == &"bool".into() { - Ok(TypeAnnotation::Primitive(primitives.bool)) - } else if id == &"None".into() { - Ok(TypeAnnotation::Primitive(primitives.none)) - } else if id == &"str".into() { - Ok(TypeAnnotation::Primitive(primitives.str)) - } else if let Some(obj_id) = resolver.get_identifier_def(*id) { - let type_vars = { - let def_read = top_level_defs[obj_id.0].try_read(); - if let Some(def_read) = def_read { - if let TopLevelDef::Class { type_vars, .. } = &*def_read { - type_vars.clone() - } else { - return Err("function cannot be used as a type".into()); - } + let name_handle = |id: &StrRef, unifier: &mut Unifier, locked: HashMap>| { + if id == &"int32".into() { + Ok(TypeAnnotation::Primitive(primitives.int32)) + } else if id == &"int64".into() { + Ok(TypeAnnotation::Primitive(primitives.int64)) + } else if id == &"float".into() { + Ok(TypeAnnotation::Primitive(primitives.float)) + } else if id == &"bool".into() { + Ok(TypeAnnotation::Primitive(primitives.bool)) + } else if id == &"None".into() { + Ok(TypeAnnotation::Primitive(primitives.none)) + } else if id == &"str".into() { + Ok(TypeAnnotation::Primitive(primitives.str)) + } else if let Some(obj_id) = resolver.get_identifier_def(*id) { + let type_vars = { + let def_read = top_level_defs[obj_id.0].try_read(); + if let Some(def_read) = def_read { + if let TopLevelDef::Class { type_vars, .. } = &*def_read { + type_vars.clone() } else { - locked.get(&obj_id).unwrap().clone() + return Err("function cannot be used as a type".into()); } - }; - // check param number here - if !type_vars.is_empty() { - return Err(format!( - "expect {} type variable parameter but got 0", - type_vars.len() - )); - } - Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) - } else if let Some(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { - if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { - Ok(TypeAnnotation::TypeVar(ty)) } else { - Err("not a type variable identifier".into()) + locked.get(&obj_id).unwrap().clone() + } + }; + // check param number here + if !type_vars.is_empty() { + return Err(format!( + "expect {} type variable parameter but got 0", + type_vars.len() + )); + } + Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) + } else if let Some(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { + if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { + Ok(TypeAnnotation::TypeVar(ty)) + } else { + Err("not a type variable identifier".into()) + } + } else { + Err("name cannot be parsed as a type annotation".into()) + } + }; + + let class_name_handle = + |id: &StrRef, slice: &ast::Expr, unifier: &mut Unifier, mut locked: HashMap>| { + if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()] + .contains(id) + { + return Err("keywords cannot be class name".into()); + } + let obj_id = resolver + .get_identifier_def(*id) + .ok_or_else(|| "unknown class name".to_string())?; + let type_vars = { + let def_read = top_level_defs[obj_id.0].try_read(); + if let Some(def_read) = def_read { + if let TopLevelDef::Class { type_vars, .. } = &*def_read { + type_vars.clone() + } else { + unreachable!("must be class here") } } else { - Err("name cannot be parsed as a type annotation".into()) + locked.get(&obj_id).unwrap().clone() } - } - + }; + // we do not check whether the application of type variables are compatible here + let param_type_infos = { + let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + elts.iter().collect_vec() + } else { + vec![slice] + }; + if type_vars.len() != params_ast.len() { + return Err(format!( + "expect {} type parameters but got {}", + type_vars.len(), + params_ast.len() + )); + } + let result = params_ast + .into_iter() + .map(|x| { + parse_ast_to_type_annotation_kinds( + resolver, + top_level_defs, + unifier, + primitives, + x, + { + locked.insert(obj_id, type_vars.clone()); + locked.clone() + }, + ) + }) + .collect::, _>>()?; + // make sure the result do not contain any type vars + let no_type_var = result + .iter() + .all(|x| get_type_var_contained_in_type_annotation(x).is_empty()); + if no_type_var { + result + } else { + return Err("application of type vars to generic class \ + is not currently supported" + .into()); + } + }; + Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) + }; + match &expr.node { + ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked), + ast::ExprKind::Constant { value: Str(id), .. } => name_handle(&id.clone().into(), unifier, locked), // virtual ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into()) + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into()) || + matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "virtual") } => { let def = parse_ast_to_type_annotation_kinds( @@ -120,7 +189,8 @@ pub fn parse_ast_to_type_annotation_kinds( // list ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into()) + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into()) || + matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "list") } => { let def_ann = parse_ast_to_type_annotation_kinds( @@ -137,7 +207,8 @@ pub fn parse_ast_to_type_annotation_kinds( // tuple ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into()) + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into()) || + matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "tuple") } => { if let ast::ExprKind::Tuple { elts, .. } = &slice.node { @@ -163,71 +234,9 @@ pub fn parse_ast_to_type_annotation_kinds( // custom class ast::ExprKind::Subscript { value, slice, .. } => { if let ast::ExprKind::Name { id, .. } = &value.node { - if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()] - .contains(id) - { - return Err("keywords cannot be class name".into()); - } - let obj_id = resolver - .get_identifier_def(*id) - .ok_or_else(|| "unknown class name".to_string())?; - let type_vars = { - let def_read = top_level_defs[obj_id.0].try_read(); - if let Some(def_read) = def_read { - if let TopLevelDef::Class { type_vars, .. } = &*def_read { - type_vars.clone() - } else { - unreachable!("must be class here") - } - } else { - locked.get(&obj_id).unwrap().clone() - } - }; - // we do not check whether the application of type variables are compatible here - let param_type_infos = { - let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { - elts.iter().collect_vec() - } else { - vec![slice.as_ref()] - }; - if type_vars.len() != params_ast.len() { - return Err(format!( - "expect {} type parameters but got {}", - type_vars.len(), - params_ast.len() - )); - } - let result = params_ast - .into_iter() - .map(|x| { - parse_ast_to_type_annotation_kinds( - resolver, - top_level_defs, - unifier, - primitives, - x, - { - locked.insert(obj_id, type_vars.clone()); - locked.clone() - }, - ) - }) - .collect::, _>>()?; - - // make sure the result do not contain any type vars - let no_type_var = result - .iter() - .all(|x| get_type_var_contained_in_type_annotation(x).is_empty()); - if no_type_var { - result - } else { - return Err("application of type vars to generic class \ - is not currently supported" - .into()); - } - }; - - Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) + class_name_handle(id, slice, unifier, locked) + } else if let ast::ExprKind::Constant { value: Str(id), .. } = &value.node { + class_name_handle(&id.clone().into(), slice, unifier, locked) } else { Err("unsupported expression type for class name".into()) } @@ -368,13 +377,7 @@ pub fn get_type_from_type_annotation_kinds( /// But note that here we do not make a duplication of `T`, `V`, we direclty /// use them as they are in the TopLevelDef::Class since those in the /// TopLevelDef::Class.type_vars will be substitute later when seeing applications/instantiations -/// the Type of their fields and methods will also be subst when application/instantiation \ -/// \ -/// Note this implicit self type is different with seeing `A[T, V]` explicitly outside -/// the class def ast body, where it is a new instantiation of the generic class `A`, -/// but equivalent to seeing `A[T, V]` inside the class def body ast, where although we -/// create copies of `T` and `V`, we will find them out as occured type vars in the analyze_class() -/// and unify them with the class generic `T`, `V` +/// the Type of their fields and methods will also be subst when application/instantiation pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation { TypeAnnotation::CustomClass { id: object_id, -- 2.44.1 From 1c5e68aca9be8032a08556d5fc91a2b058e096c4 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Sun, 21 Nov 2021 05:07:36 +0800 Subject: [PATCH 07/10] nac3artiq/nac3core: remove forwardref type annotation support for unstable python API --- nac3artiq/src/symbol_resolver.rs | 18 ++++-------------- nac3core/src/symbol_resolver.rs | 5 +---- nac3core/src/toplevel/type_annotation.rs | 13 +++---------- 3 files changed, 8 insertions(+), 28 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 722f410..454f178 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -40,10 +40,8 @@ struct PythonHelper<'a> { type_fn: &'a PyAny, len_fn: &'a PyAny, id_fn: &'a PyAny, - eval_type_fn: &'a PyAny, origin_ty_fn: &'a PyAny, args_ty_fn: &'a PyAny, - globals_dict: &'a PyAny, } impl Resolver { @@ -87,17 +85,13 @@ impl Resolver { defs: &[Arc>], primitives: &PrimitiveStore, ) -> PyResult> { - // eval_type use only globals_dict should be fine - let evaluated_ty = helper - .eval_type_fn - .call1((pyty, helper.globals_dict, helper.globals_dict)).unwrap(); let ty_id: u64 = helper .id_fn - .call1((evaluated_ty,))? + .call1((pyty,))? .extract()?; let ty_ty_id: u64 = helper .id_fn - .call1((helper.type_fn.call1((evaluated_ty,))?,))? + .call1((helper.type_fn.call1((pyty,))?,))? .extract()?; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { @@ -188,8 +182,8 @@ impl Resolver { let res = unifier.get_fresh_var_with_range(&constraint_types).0; Ok(Ok((res, true))) } else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 { - let origin = helper.origin_ty_fn.call1((evaluated_ty,))?; - let args: &PyTuple = helper.args_ty_fn.call1((evaluated_ty,))?.cast_as()?; + let origin = helper.origin_ty_fn.call1((pyty,))?; + let args: &PyTuple = helper.args_ty_fn.call1((pyty,))?.cast_as()?; let origin_ty = match self.get_pyty_obj_type(origin, helper, unifier, defs, primitives)? { Ok((ty, false)) => ty, Ok((_, true)) => return Ok(Err("instantiated type does not take type parameters".into())), @@ -657,8 +651,6 @@ impl SymbolResolver for Resolver { type_fn: builtins.getattr("type").unwrap(), origin_ty_fn: typings.getattr("get_origin").unwrap(), args_ty_fn: typings.getattr("get_args").unwrap(), - globals_dict: obj.getattr("__dict__").unwrap(), - eval_type_fn: typings.getattr("_eval_type").unwrap(), }; sym_ty = self.get_obj_type( member.get_item(1)?, @@ -706,8 +698,6 @@ impl SymbolResolver for Resolver { type_fn: builtins.getattr("type").unwrap(), origin_ty_fn: typings.getattr("get_origin").unwrap(), args_ty_fn: typings.getattr("get_args").unwrap(), - globals_dict: obj.getattr("__dict__").unwrap(), - eval_type_fn: typings.getattr("_eval_type").unwrap(), }; sym_value = self.get_obj_value(val, &helper, ctx)?; break; diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 9f9c161..1b1c146 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -13,7 +13,7 @@ use crate::{ use crate::{location::Location, typecheck::typedef::TypeEnum}; use inkwell::values::BasicValueEnum; use itertools::{chain, izip}; -use nac3parser::ast::{Constant::Str, Expr, StrRef}; +use nac3parser::ast::{Expr, StrRef}; use parking_lot::RwLock; #[derive(Clone, PartialEq, Debug)] @@ -235,12 +235,9 @@ pub fn parse_type_annotation( match &expr.node { Name { id, .. } => name_handling(id, unifier), - Constant { value: Str(id), .. } => name_handling(&id.clone().into(), unifier), Subscript { value, slice, .. } => { if let Name { id, .. } = &value.node { subscript_name_handle(id, slice, unifier) - } else if let Constant { value: Str(id), .. } = &value.node { - subscript_name_handle(&id.clone().into(), slice, unifier) } else { Err("unsupported type expression".into()) } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 7d6523a..c6482d9 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,7 +1,6 @@ use std::cell::RefCell; use crate::typecheck::typedef::TypeVarMeta; -use ast::Constant::Str; use super::*; #[derive(Clone, Debug)] @@ -164,12 +163,10 @@ pub fn parse_ast_to_type_annotation_kinds( }; match &expr.node { ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked), - ast::ExprKind::Constant { value: Str(id), .. } => name_handle(&id.clone().into(), unifier, locked), // virtual ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into()) || - matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "virtual") + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into()) } => { let def = parse_ast_to_type_annotation_kinds( @@ -189,8 +186,7 @@ pub fn parse_ast_to_type_annotation_kinds( // list ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into()) || - matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "list") + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into()) } => { let def_ann = parse_ast_to_type_annotation_kinds( @@ -207,8 +203,7 @@ pub fn parse_ast_to_type_annotation_kinds( // tuple ast::ExprKind::Subscript { value, slice, .. } if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into()) || - matches!(&value.node, ast::ExprKind::Constant { value: Str(id), .. } if id == "tuple") + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into()) } => { if let ast::ExprKind::Tuple { elts, .. } = &slice.node { @@ -235,8 +230,6 @@ pub fn parse_ast_to_type_annotation_kinds( ast::ExprKind::Subscript { value, slice, .. } => { if let ast::ExprKind::Name { id, .. } = &value.node { class_name_handle(id, slice, unifier, locked) - } else if let ast::ExprKind::Constant { value: Str(id), .. } = &value.node { - class_name_handle(&id.clone().into(), slice, unifier, locked) } else { Err("unsupported expression type for class name".into()) } -- 2.44.1 From fe4fbdc5db33b234c609a5ca388e11da9542e270 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Mon, 22 Nov 2021 15:10:50 +0800 Subject: [PATCH 08/10] cleanup println --- nac3artiq/src/symbol_resolver.rs | 2 -- nac3standalone/src/main.rs | 9 --------- 2 files changed, 11 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 454f178..1e42bc8 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -111,9 +111,7 @@ impl Resolver { // do not handle type var param and concrete check here Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { - // println!("getting def"); let def = defs[def_id.0].read(); - // println!("got def"); if let TopLevelDef::Class { object_id, type_vars, diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 97b8951..b7c161d 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -71,7 +71,6 @@ fn main() { assert_eq!(targets.len(), 1, "only support single assignment for now, at {}", targets[0].location); if let ExprKind::Call { func, args, .. } = &value.node { if matches!(&func.node, ExprKind::Name { id, .. } if id == &"TypeVar".into()) { - print!("registering typevar {:?}", targets[0].node); let constraints = args .iter() .skip(1) @@ -87,14 +86,6 @@ fn main() { }) .collect::>(); let res_ty = composer.unifier.get_fresh_var_with_range(&constraints).0; - println!( - " ...registered: {}", - composer.unifier.stringify( - res_ty, - &mut |x| format!("obj{}", x), - &mut |x| format!("tavr{}", x) - ) - ); internal_resolver.add_id_type( if let ExprKind::Name { id, .. } = &targets[0].node { *id } else { panic!("must assign simple name variable as type variable") -- 2.44.1 From 49240a80ad9ae20d842c0e2a3fa5d9a219eb3670 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Wed, 24 Nov 2021 18:24:16 +0800 Subject: [PATCH 09/10] nac3standalone: iteration over multiple typevar assignment in the same line --- nac3core/src/toplevel/mod.rs | 2 +- nac3standalone/src/main.rs | 97 ++++++++++++++++++++++++++++++++---- 2 files changed, 89 insertions(+), 10 deletions(-) diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index f6fc15b..bcf1e62 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -24,7 +24,7 @@ pub struct DefinitionId(pub usize); pub mod composer; pub mod helper; -mod type_annotation; +pub mod type_annotation; use composer::*; use type_annotation::*; #[cfg(test)] diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index f6c4be6..55c66d9 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -3,8 +3,9 @@ use inkwell::{ targets::*, OptimizationLevel, }; -use nac3core::typecheck::type_inferencer::PrimitiveStore; +use nac3core::typecheck::{type_inferencer::PrimitiveStore, typedef::{Type, Unifier}}; use nac3parser::{ast::{Expr, ExprKind, StmtKind}, parser}; +use parking_lot::RwLock; use std::{borrow::Borrow, env}; use std::fs; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; @@ -15,7 +16,11 @@ use nac3core::{ WorkerRegistry, }, symbol_resolver::SymbolResolver, - toplevel::{composer::TopLevelComposer, TopLevelDef, helper::parse_parameter_default_value}, + toplevel::{ + composer::TopLevelComposer, + TopLevelDef, helper::parse_parameter_default_value, + type_annotation::*, + }, typecheck::typedef::FunSignature, }; @@ -68,25 +73,84 @@ fn main() { for stmt in parser_result.into_iter() { if let StmtKind::Assign { targets, value, .. } = &stmt.node { + fn handle_typevar_definition( + var: &Expr, + resolver: &(dyn SymbolResolver + Send + Sync), + def_list: &[Arc>], + unifier: &mut Unifier, + primitives: &PrimitiveStore, + ) -> Result { + if let ExprKind::Call { func, args, .. } = &var.node { + if matches!(&func.node, ExprKind::Name { id, .. } if id == &"TypeVar".into()) { + let constraints = args + .iter() + .skip(1) + .map(|x| -> Result { + let ty = parse_ast_to_type_annotation_kinds( + resolver, + def_list, + unifier, + primitives, + x, + Default::default(), + )?; + get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty) + }) + .collect::, _>>()?; + Ok(unifier.get_fresh_var_with_range(&constraints).0) + } else { + Err(format!("expression {:?} cannot be handled as a TypeVar in global scope", var)) + } + } else { + Err(format!("expression {:?} cannot be handled as a TypeVar in global scope", var)) + } + } + fn handle_assignment_pattern( targets: &[Expr], value: &Expr, resolver: &(dyn SymbolResolver + Send + Sync), internal_resolver: &ResolverInternal, + def_list: &[Arc>], + unifier: &mut Unifier, + primitives: &PrimitiveStore, ) -> Result<(), String> { if targets.len() == 1 { match &targets[0].node { ExprKind::Name { id, .. } => { - let val = parse_parameter_default_value(value.borrow(), resolver)?; - internal_resolver.add_module_global(*id, val); - Ok(()) + if let Ok(var) = handle_typevar_definition( + value.borrow(), + resolver, + def_list, + unifier, + primitives, + ) { + internal_resolver.add_id_type(*id, var); + Ok(()) + } else if let Ok(val) = parse_parameter_default_value(value.borrow(), resolver) { + internal_resolver.add_module_global(*id, val); + Ok(()) + } else { + Err(format!("fails to evaluate this expression `{:?}` as a constant or TypeVar at {}", + targets[0].node, + targets[0].location, + )) + } } ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => { - handle_assignment_pattern(elts, value, resolver, internal_resolver)?; + handle_assignment_pattern( + elts, + value, + resolver, + internal_resolver, + def_list, + unifier, + primitives + )?; Ok(()) } - _ => unreachable!("cannot be assigned") + _ => Err(format!("assignment to {:?} is not supported at {}", targets[0], targets[0].location)) } } else { match &value.node { @@ -105,7 +169,10 @@ fn main() { std::slice::from_ref(tar), val, resolver, - internal_resolver + internal_resolver, + def_list, + unifier, + primitives )?; } Ok(()) @@ -115,7 +182,19 @@ fn main() { } } } - if let Err(err) = handle_assignment_pattern(targets, value, resolver.as_ref(), internal_resolver.as_ref()) { + + let def_list = composer.extract_def_list(); + let unifier = &mut composer.unifier; + let primitives = &composer.primitives_ty; + if let Err(err) = handle_assignment_pattern( + targets, + value, + resolver.as_ref(), + internal_resolver.as_ref(), + &def_list, + unifier, + primitives, + ) { eprintln!("{}", err); return; } -- 2.44.1 From 6e93e41a3bb3bbe3b1dbbab3ea1e5fc6007cd8e2 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Tue, 30 Nov 2021 21:44:48 +0800 Subject: [PATCH 10/10] nac3artiq: symbol resolver resolve typevar type --- nac3artiq/src/lib.rs | 6 + nac3artiq/src/symbol_resolver.rs | 352 +++++++++++++++++++++++++------ 2 files changed, 293 insertions(+), 65 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 0734aa5..ac08799 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -100,10 +100,13 @@ impl Nac3 { let val = id_fn.call1((member.get_item(1)?,))?.extract()?; name_to_pyid.insert(key.into(), val); } + let typings = PyModule::import(py, "typing")?; let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap().to_object(py), len_fn: builtins.getattr("len").unwrap().to_object(py), type_fn: builtins.getattr("type").unwrap().to_object(py), + origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py), + args_ty_fn: typings.getattr("get_args").unwrap().to_object(py), }; Ok(( module.getattr("__name__")?.extract()?, @@ -442,10 +445,13 @@ impl Nac3 { }; let mut synthesized = parse_program(&synthesized).unwrap(); let builtins = PyModule::import(py, "builtins")?; + let typings = PyModule::import(py, "typing")?; let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap().to_object(py), len_fn: builtins.getattr("len").unwrap().to_object(py), type_fn: builtins.getattr("type").unwrap().to_object(py), + origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py), + args_ty_fn: typings.getattr("get_args").unwrap().to_object(py), }; let resolver = Arc::new(Resolver(Arc::new(InnerResolver { id_to_type: self.builtins_ty.clone().into(), diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 5160b78..f36c146 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -43,6 +43,8 @@ pub struct PythonHelper { pub type_fn: PyObject, pub len_fn: PyObject, pub id_fn: PyObject, + pub origin_ty_fn: PyObject, + pub args_ty_fn: PyObject, } struct PythonValue { @@ -133,47 +135,46 @@ impl InnerResolver { })) } - fn get_obj_type( + // handle python objects that represent types themselves + // primitives and class types should be themselves, use `ty_id` to check, + // TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check + // the `bool` value returned indicates whether they are instantiated or not + fn get_pyty_obj_type( &self, py: Python, - obj: &PyAny, + pyty: &PyAny, unifier: &mut Unifier, defs: &[Arc>], primitives: &PrimitiveStore, - ) -> PyResult> { + ) -> PyResult> { let ty_id: u64 = self .helper .id_fn - .call1(py, (self.helper.type_fn.call1(py, (obj,))?,))? + .call1(py, (pyty,))? + .extract(py)?; + let ty_ty_id: u64 = self + .helper + .id_fn + .call1(py, (self.helper.type_fn.call1(py, (pyty,))?,))? .extract(py)?; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { - Ok(Some(primitives.int32)) + Ok(Ok((primitives.int32, true))) } else if ty_id == self.primitive_ids.int64 { - Ok(Some(primitives.int64)) + Ok(Ok((primitives.int64, true))) } else if ty_id == self.primitive_ids.bool { - Ok(Some(primitives.bool)) + Ok(Ok((primitives.bool, true))) } else if ty_id == self.primitive_ids.float { - Ok(Some(primitives.float)) + Ok(Ok((primitives.float, true))) } else if ty_id == self.primitive_ids.list { - let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; - if len == 0 { - let var = unifier.get_fresh_var().0; - let list = unifier.add_ty(TypeEnum::TList { ty: var }); - Ok(Some(list)) - } else { - let ty = self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; - Ok(ty.map(|ty| unifier.add_ty(TypeEnum::TList { ty }))) - } + // do not handle type var param and concrete check here + let var = unifier.get_fresh_var().0; + let list = unifier.add_ty(TypeEnum::TList { ty: var }); + Ok(Ok((list, false))) } else if ty_id == self.primitive_ids.tuple { - let elements: &PyTuple = obj.cast_as()?; - let types: Result>, _> = elements - .iter() - .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives)) - .collect(); - let types = types?; - Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) - } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id) { + // do not handle type var param and concrete check here + Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) + } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { let def = defs[def_id.0].read(); if let TopLevelDef::Class { object_id, @@ -183,54 +184,275 @@ impl InnerResolver { .. } = &*def { - let var_map: HashMap<_, _> = type_vars - .iter() - .map(|var| { - ( - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { - *id - } else { - unreachable!() - }, - unifier.get_fresh_var().0, - ) - }) - .collect(); - let mut fields_ty = HashMap::new(); - for method in methods.iter() { - fields_ty.insert(method.0, (method.1, false)); - } - for field in fields.iter() { - let name: String = field.0.into(); - let field_data = obj.getattr(&name)?; - let ty = self - .get_obj_type(py, field_data, unifier, defs, primitives)? - .unwrap_or(primitives.none); - let field_ty = unifier.subst(field.1, &var_map).unwrap_or(field.1); - if unifier.unify(ty, field_ty).is_err() { - // field type mismatch - return Ok(None); + // do not handle type var param and concrete check here, and no subst + Ok(Ok({ + let ty = TypeEnum::TObj { + obj_id: *object_id, + params: RefCell::new({ + type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { + (*id, *x) + } else { + println!("{}", unifier.default_stringify(*x)); + unreachable!() + } + }).collect() + }), + fields: RefCell::new({ + let mut res = methods + .iter() + .map(|(iden, ty, _)| (*iden, (*ty, false))) + .collect::>(); + res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2)))); + res + }) + }; + // here also false, later instantiation use python object to check compatible + (unifier.add_ty(ty), false) + })) + } else { + // only object is supported, functions are not supported + unreachable!("function type is not supported, should not be queried") + } + } else if ty_ty_id == self.primitive_ids.typevar { + let constraint_types = { + let constraints = pyty.getattr("__constraints__").unwrap(); + let mut result: Vec = vec![]; + for i in 0.. { + if let Ok(constr) = constraints.get_item(i) { + result.push({ + match self.get_pyty_obj_type(py, constr, unifier, defs, primitives)? { + Ok((ty, _)) => { + if unifier.is_concrete(ty, &[]) { + ty + } else { + return Ok(Err(format!( + "the {}th constraint of TypeVar `{}` is not concrete", + i + 1, + pyty.getattr("__name__")?.extract::()? + ))) + } + }, + Err(err) => return Ok(Err(err)) + } + }) + } else { + break; + } + } + result + }; + let res = unifier.get_fresh_var_with_range(&constraint_types).0; + Ok(Ok((res, true))) + } else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 { + let origin = self.helper.origin_ty_fn.call1(py, (pyty,))?; + let args = self.helper.args_ty_fn.call1(py, (pyty,))?; + let args: &PyTuple = args.cast_as(py)?; + let origin_ty = match self.get_pyty_obj_type(py, origin.as_ref(py), unifier, defs, primitives)? { + Ok((ty, false)) => ty, + Ok((_, true)) => return Ok(Err("instantiated type does not take type parameters".into())), + Err(err) => return Ok(Err(err)) + }; + + match &*unifier.get_ty(origin_ty) { + TypeEnum::TList { .. } => { + if args.len() == 1 { + let ty = match self.get_pyty_obj_type(py, args.get_item(0), unifier, defs, primitives)? { + Ok(ty) => ty, + Err(err) => return Ok(Err(err)) + }; + if !unifier.is_concrete(ty.0, &[]) && !ty.1 { + panic!("type list should take concrete parameters in type var ranges") + } + Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true))) + } else { + return Ok(Err(format!("type list needs exactly 1 type parameters, found {}", args.len()))) + } + }, + TypeEnum::TTuple { .. } => { + let args = match args + .iter() + .map(|x| self.get_pyty_obj_type(py, x, unifier, defs, primitives)) + .collect::, _>>()? + .into_iter() + .collect::, _>>() { + Ok(args) if !args.is_empty() => args + .into_iter() + .map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check { + panic!("type tuple should take concrete parameters in type var ranges") + } else { + x + } + ) + .collect::>(), + Err(err) => return Ok(Err(err)), + _ => return Ok(Err("tuple type needs at least 1 type parameters".to_string())) + }; + Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true))) + }, + TypeEnum::TObj { params, obj_id, .. } => { + let subst = { + let params = &*params.borrow(); + if params.len() != args.len() { + return Ok(Err(format!( + "for class #{}, expect {} type parameters, got {}.", + obj_id.0, + params.len(), + args.len(), + ))) + } + let args = match args + .iter() + .map(|x| self.get_pyty_obj_type(py, x, unifier, defs, primitives)) + .collect::, _>>()? + .into_iter() + .collect::, _>>() { + Ok(args) => args + .into_iter() + .map(|(x, check)| if !unifier.is_concrete(x, &[]) && !check { + panic!("type class should take concrete parameters in type var ranges") + } else { + x + } + ) + .collect::>(), + Err(err) => return Ok(Err(err)), + }; + params + .iter() + .zip(args.iter()) + .map(|((id, _), ty)| (*id, *ty)) + .collect::>() + }; + Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true))) + }, + TypeEnum::TVirtual { .. } => { + if args.len() == 1 { + let ty = match self.get_pyty_obj_type(py, args.get_item(0), unifier, defs, primitives)? { + Ok(ty) => ty, + Err(err) => return Ok(Err(err)) + }; + if !unifier.is_concrete(ty.0, &[]) && !ty.1 { + panic!("virtual class should take concrete parameters in type var ranges") + } + Ok(Ok((unifier.add_ty(TypeEnum::TVirtual { ty: ty.0 }), true))) + } else { + return Ok(Err(format!("virtual class needs exactly 1 type parameters, found {}", args.len()))) + } + } + _ => unimplemented!() + } + } else if ty_id == self.primitive_ids.virtual_id { + Ok(Ok(({ + let ty = TypeEnum::TVirtual { ty: unifier.get_fresh_var().0 }; + unifier.add_ty(ty) + }, false))) + } else { + Ok(Err("unknown type".into())) + } + } + + fn get_obj_type( + &self, + py: Python, + obj: &PyAny, + unifier: &mut Unifier, + defs: &[Arc>], + primitives: &PrimitiveStore, + ) -> PyResult> { + let ty = self.helper.type_fn.call1(py, (obj,)).unwrap(); + let (extracted_ty, inst_check) = match self.get_pyty_obj_type( + py, + { + if [self.primitive_ids.typevar, + self.primitive_ids.generic_alias.0, + self.primitive_ids.generic_alias.1 + ].contains(&self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)?) { + obj + } else { + ty.as_ref(py) + } + }, + unifier, + defs, + primitives + )? { + Ok(s) => s, + Err(_) => return Ok(None) + }; + return match (&*unifier.get_ty(extracted_ty), inst_check) { + // do the instantiation for these three types + (TypeEnum::TList { ty }, false) => { + let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; + if len == 0 { + assert!(matches!( + &*unifier.get_ty(extracted_ty), + TypeEnum::TVar { meta: nac3core::typecheck::typedef::TypeVarMeta::Generic, range, .. } + if range.borrow().is_empty() + )); + Ok(Some(extracted_ty)) + } else { + let actual_ty = self + .get_list_elem_type(py, obj, len, unifier, defs, primitives)?; + if let Some(actual_ty) = actual_ty { + unifier.unify(*ty, actual_ty).unwrap(); + Ok(Some(extracted_ty)) + } else { + Ok(None) + } + } + } + (TypeEnum::TTuple { .. }, false) => { + let elements: &PyTuple = obj.cast_as()?; + let types: Result>, _> = elements + .iter() + .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives)) + .collect(); + let types = types?; + Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) + } + (TypeEnum::TObj { params, fields, .. }, false) => { + let var_map = params + .borrow() + .iter() + .map(|(id_var, ty)| { + if let TypeEnum::TVar { id, range, .. } = &*unifier.get_ty(*ty) { + assert_eq!(*id, *id_var); + (*id, unifier.get_fresh_var_with_range(&range.borrow()).0) + } else { + unreachable!() + } + }) + .collect::>(); + // loop through non-function fields of the class to get the instantiated value + for field in fields.borrow().iter() { + let name: String = (*field.0).into(); + if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(field.1.0) { + continue; + } else { + let field_data = obj.getattr(&name)?; + let ty = self + .get_obj_type(py, field_data, unifier, defs, primitives)? + .unwrap_or(primitives.none); + let field_ty = unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0); + if unifier.unify(ty, field_ty).is_err() { + // field type mismatch + return Ok(None); + } } - fields_ty.insert(field.0, (ty, field.2)); } for (_, ty) in var_map.iter() { // must be concrete type if !unifier.is_concrete(*ty, &[]) { - return Ok(None); + return Ok(None) } } - Ok(Some(unifier.add_ty(TypeEnum::TObj { - obj_id: *object_id, - fields: RefCell::new(fields_ty), - params: RefCell::new(var_map), - }))) - } else { - // only object is supported, functions are not supported - Ok(None) + return Ok(Some(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty))); } - } else { - Ok(None) - } + _ => Ok(Some(extracted_ty)) + }; } fn get_obj_value<'ctx, 'a>( -- 2.44.1