diff --git a/.cargo/config b/.cargo/config deleted file mode 100644 index ac01e73c..00000000 --- a/.cargo/config +++ /dev/null @@ -1,2 +0,0 @@ -[unstable] -extra-link-arg = true diff --git a/flake.lock b/flake.lock index 50d000ce..b7922587 100644 --- a/flake.lock +++ b/flake.lock @@ -2,16 +2,16 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1636698608, - "narHash": "sha256-sxLLeQmH3UrP3UANqXzMLE0bPDgY5aIt04iBoPffG2E=", + "lastModified": 1637636156, + "narHash": "sha256-E2ym4Vcpqu9JYoQDXJZR48gVD+LPPbaCoYveIk7Xu3Y=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "92c881b6a72abce5bb2f5db3f903b4871d13aaa9", + "rev": "b026e1cf87a108dd06fe521f224fdc72fd0b013d", "type": "github" }, "original": { "owner": "NixOS", - "ref": "master", + "ref": "release-21.11", "repo": "nixpkgs", "type": "github" } diff --git a/flake.nix b/flake.nix index bcaa8970..0e89187f 100644 --- a/flake.nix +++ b/flake.nix @@ -1,7 +1,7 @@ { description = "The third-generation ARTIQ compiler"; - inputs.nixpkgs.url = github:NixOS/nixpkgs/master; + inputs.nixpkgs.url = github:NixOS/nixpkgs/release-21.11; outputs = { self, nixpkgs }: let diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 01310103..c3de02f4 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -2,14 +2,14 @@ use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace}; use nac3core::{ codegen::CodeGenContext, location::Location, - symbol_resolver::{StaticValue, SymbolResolver, ValueEnum}, + symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, TypeEnum, Unifier}, }, }; -use nac3parser::ast::StrRef; +use nac3parser::ast::{self, StrRef}; use parking_lot::{Mutex, RwLock}; use pyo3::{ types::{PyList, PyModule, PyTuple}, @@ -458,9 +458,80 @@ impl InnerResolver { } } } + + fn get_default_param_obj_value( + &self, + py: Python, + obj: &PyAny, + ) -> PyResult> { + let ty_id: u64 = self + .helper + .id_fn + .call1(py, (self.helper.type_fn.call1(py, (obj,))?,))? + .extract(py)?; + Ok( + if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { + let val: i32 = obj.extract()?; + Ok(SymbolValue::I32(val)) + } else if ty_id == self.primitive_ids.int64 { + let val: i64 = obj.extract()?; + Ok(SymbolValue::I64(val)) + } else if ty_id == self.primitive_ids.bool { + let val: bool = obj.extract()?; + Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.float { + let val: f64 = obj.extract()?; + Ok(SymbolValue::Double(val)) + } else if ty_id == self.primitive_ids.tuple { + let elements: &PyTuple = obj.cast_as()?; + let elements: Result, String>, _> = elements + .iter() + .map(|elem| self.get_default_param_obj_value(py, elem)) + .collect(); + let elements = match elements? { + Ok(el) => el, + Err(err) => return Ok(Err(err)), + }; + Ok(SymbolValue::Tuple(elements)) + } else { + Err("only primitives values and tuple can be default parameter value".into()) + }, + ) + } } impl SymbolResolver for Resolver { + fn get_default_param_value(&self, expr: &ast::Expr) -> Option { + match &expr.node { + ast::ExprKind::Name { id, .. } => { + Python::with_gil(|py| -> PyResult> { + let obj: &PyAny = self.0.module.extract(py)?; + let members: &PyList = PyModule::import(py, "inspect")? + .getattr("getmembers")? + .call1((obj,))? + .cast_as()?; + let mut sym_value = None; + for member in members.iter() { + let key: &str = member.get_item(0)?.extract()?; + let val = member.get_item(1)?; + if key == id.to_string() { + sym_value = Some( + self.0 + .get_default_param_obj_value(py, val) + .unwrap() + .unwrap(), + ); + break; + } + } + Ok(sym_value) + }) + .unwrap() + } + _ => unimplemented!("other type of expr not supported at {}", expr.location), + } + } + fn get_symbol_type( &self, unifier: &mut Unifier, diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index f4f4518b..140142ef 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -147,8 +147,15 @@ impl ConcreteTypeStore { fields: fields .borrow() .iter() - .map(|(name, ty)| { - (*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1)) + .filter_map(|(name, ty)| { + // here we should not have type vars, but some partial instantiated + // class methods can still have uninstantiated type vars, so + // filter out all the methods, as this will not affect codegen + if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(ty.0) { + None + } else { + Some((*name, (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1))) + } }) .collect(), params: params diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 94a1024a..dd46673e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -7,7 +7,7 @@ use crate::{ }, symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, - typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum}, + typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, }; use inkwell::{ types::{BasicType, BasicTypeEnum}, @@ -19,6 +19,31 @@ use nac3parser::ast::{self, Boolop, Comprehension, Constant, Expr, ExprKind, Ope use super::CodeGenerator; +pub fn get_subst_key( + unifier: &mut Unifier, + obj: Option, + fun_vars: &HashMap, + filter: Option<&Vec>, +) -> String { + let mut vars = obj + .map(|ty| { + if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) { + params.borrow().clone() + } else { + unreachable!() + } + }) + .unwrap_or_default(); + vars.extend(fun_vars.iter()); + let sorted = + vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted(); + sorted + .map(|id| { + unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string()) + }) + .join(", ") +} + impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { pub fn build_gep_and_load( &mut self, @@ -34,23 +59,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { fun: &FunSignature, filter: Option<&Vec>, ) -> String { - let mut vars = obj - .map(|ty| { - if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) { - params.borrow().clone() - } else { - unreachable!() - } - }) - .unwrap_or_default(); - vars.extend(fun.vars.iter()); - let sorted = - vars.keys().filter(|id| filter.map(|v| v.contains(id)).unwrap_or(true)).sorted(); - sorted - .map(|id| { - self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string()) - }) - .join(", ") + get_subst_key(&mut self.unifier, obj, &fun.vars, filter) } pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 7f8859f0..b01aaf66 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -293,7 +293,14 @@ pub fn gen_func<'ctx, G: CodeGenerator + ?Sized>( // this should be unification between variables and concrete types // and should not cause any problem... let b = task.store.to_unifier_type(&mut unifier, &primitives, *b, &mut cache); - unifier.unify(*a, b).unwrap(); + unifier.unify(*a, b).or_else(|err| { + if matches!(&*unifier.get_ty(*a), TypeEnum::TRigidVar { .. }) { + unifier.replace_rigid_var(*a, b); + Ok(()) + } else { + Err(err) + } + }).unwrap() } // rebuild primitive store with unique representatives diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 0939ab41..2bbc75ed 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -36,6 +36,10 @@ impl Resolver { } impl SymbolResolver for Resolver { + fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + unimplemented!() + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index d92ddfda..d7e9d697 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -102,6 +102,7 @@ pub trait SymbolResolver { ) -> Option>; fn get_symbol_location(&self, str: StrRef) -> Option; + fn get_default_param_value(&self, expr: &nac3parser::ast::Expr) -> Option; // handle function call etc. } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index bf9f1561..a8faf130 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -6,6 +6,7 @@ use inkwell::FloatPredicate; use crate::{ symbol_resolver::SymbolValue, typecheck::type_inferencer::{FunctionData, Inferencer}, + codegen::expr::get_subst_key, }; use super::*; @@ -194,7 +195,7 @@ impl TopLevelComposer { signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { args: vec![FuncArg { name: "_".into(), ty: num_ty.0, default_value: None }], ret: float, - vars: var_map, + vars: var_map.clone(), }))), var_id: Default::default(), instance_to_symbol: Default::default(), @@ -397,7 +398,7 @@ impl TopLevelComposer { signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { args: vec![FuncArg { name: "_".into(), ty: num_ty.0, default_value: None }], ret: primitives.0.bool, - vars: Default::default(), + vars: var_map, }))), var_id: Default::default(), instance_to_symbol: Default::default(), @@ -1066,9 +1067,22 @@ impl TopLevelComposer { .into()); } - args.args + let arg_with_default: Vec<(&ast::Located>, Option<&ast::Expr>)> = args + .args .iter() - .map(|x| -> Result { + .rev() + .zip(args + .defaults + .iter() + .rev() + .map(|x| -> Option<&ast::Expr> { Some(x) }) + .chain(std::iter::repeat(None)) + ).collect_vec(); + + arg_with_default + .iter() + .rev() + .map(|(x, default)| -> Result { let annotation = x .node .annotation @@ -1120,7 +1134,19 @@ impl TopLevelComposer { Ok(FuncArg { name: x.node.arg, ty, - default_value: Default::default(), + default_value: match default { + None => None, + Some(default) => Some({ + let v = Self::parse_parameter_default_value(default, resolver)?; + Self::check_default_param_type( + &v, + &type_annotation, + primitives_store, + unifier + ).map_err(|err| format!("{} at {}", err, x.location))?; + v + }) + } }) }) .collect::, _>>()? @@ -1170,8 +1196,17 @@ impl TopLevelComposer { primitives_store.none } }; - var_id.extend_from_slice( - function_var_map.keys().into_iter().copied().collect_vec().as_slice(), + var_id.extend_from_slice(function_var_map + .iter() + .filter_map(|(id, ty)| { + if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.borrow().is_empty()) { + None + } else { + Some(*id) + } + }) + .collect_vec() + .as_slice() ); let function_ty = unifier.add_ty(TypeEnum::TFunc( FunSignature { args: arg_types, ret: return_ty, vars: function_var_map } @@ -1272,7 +1307,20 @@ impl TopLevelComposer { } let mut result = Vec::new(); - for x in &args.args { + + let arg_with_default: Vec<(&ast::Located>, Option<&ast::Expr>)> = args + .args + .iter() + .rev() + .zip(args + .defaults + .iter() + .rev() + .map(|x| -> Option<&ast::Expr> { Some(x) }) + .chain(std::iter::repeat(None)) + ).collect_vec(); + + for (x, default) in arg_with_default.into_iter().rev() { let name = x.node.arg; if name != zelf { let type_ann = { @@ -1317,8 +1365,20 @@ impl TopLevelComposer { let dummy_func_arg = FuncArg { name, ty: unifier.get_fresh_var().0, - // TODO: default value? - default_value: None, + default_value: match default { + None => None, + Some(default) => { + if name == "self".into() { + return Err(format!("`self` parameter cannot take default value at {}", x.location)); + } + Some({ + let v = Self::parse_parameter_default_value(default, class_resolver)?; + Self::check_default_param_type(&v, &type_ann, primitives, unifier) + .map_err(|err| format!("{} at {}", err, x.location))?; + v + }) + } + } }; // push the dummy type and the type annotation // into the list for later unification @@ -1374,9 +1434,20 @@ impl TopLevelComposer { if let TopLevelDef::Function { var_id, .. } = temp_def_list.get(method_id.0).unwrap().write().deref_mut() { - var_id.extend_from_slice( - method_var_map.keys().into_iter().copied().collect_vec().as_slice(), + var_id.extend_from_slice(method_var_map + .iter() + .filter_map(|(id, ty)| { + if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.borrow().is_empty()) { + None + } else { + Some(*id) + } + }) + .collect_vec() + .as_slice() ); + } else { + unreachable!() } let method_type = unifier.add_ty(TypeEnum::TFunc( FunSignature { args: arg_types, ret: ret_type, vars: method_var_map } @@ -1625,11 +1696,14 @@ impl TopLevelComposer { unreachable!("must be init function here") } let all_inited = Self::get_all_assigned_field(body.as_slice())?; - if fields.iter().any(|x| !all_inited.contains(&x.0)) { - return Err(format!( - "fields of class {} not fully initialized", - class_name - )); + for (f, _, _) in fields { + if !all_inited.contains(f) { + return Err(format!( + "fields `{}` of class `{}` not fully initialized", + f, + class_name + )); + } } } } @@ -1648,13 +1722,14 @@ impl TopLevelComposer { simple_name, signature, resolver, + var_id: insted_vars, .. } = &mut *function_def { if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { let FunSignature { args, ret, vars } = &*func_sig.borrow(); // None if is not class method - let self_type = { + let uninst_self_type = { if let Some(class_id) = self.method_class.get(&DefinitionId(id)) { let class_def = self.definition_ast_list.get(class_id.0).unwrap(); let class_def = class_def.0.read(); @@ -1666,7 +1741,7 @@ impl TopLevelComposer { &self.primitives_ty, &ty_ann, )?; - Some(self_ty) + Some((self_ty, type_vars.clone())) } else { unreachable!("must be class def") } @@ -1674,20 +1749,20 @@ impl TopLevelComposer { None } }; + // carefully handle those with bounds, without bounds and no typevars + // if class methods, `vars` also contains all class typevars here let (type_var_subst_comb, no_range_vars) = { let unifier = &mut self.unifier; let mut no_ranges: Vec = Vec::new(); - let var_ids = vars.iter().map(|(id, ty)| { - if matches!(unifier.get_ty(*ty).as_ref(), TypeEnum::TVar { range, .. } if range.borrow().is_empty()) { - no_ranges.push(*ty); - } - *id - }) - .collect_vec(); + let var_ids = vars.keys().copied().collect_vec(); let var_combs = vars .iter() .map(|(_, ty)| { - unifier.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]) + unifier.get_instantiations(*ty).unwrap_or_else(|| { + let rigid = unifier.get_fresh_rigid_var().0; + no_ranges.push(rigid); + vec![rigid] + }) }) .multi_cartesian_product() .collect_vec(); @@ -1717,9 +1792,34 @@ impl TopLevelComposer { }; let self_type = { let unifier = &mut self.unifier; - self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x)) + uninst_self_type + .clone() + .map(|(self_type, type_vars)| { + let subst_for_self = { + let class_ty_var_ids = type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { + *id + } else { + unreachable!("must be type var here"); + } + }) + .collect::>(); + subst + .iter() + .filter_map(|(ty_var_id, ty_var_target)| { + if class_ty_var_ids.contains(ty_var_id) { + Some((*ty_var_id, *ty_var_target)) + } else { + None + } + }) + .collect::>() + }; + unifier.subst(self_type, &subst_for_self).unwrap_or(self_type) + }) }; - let mut identifiers = { // NOTE: none and function args? let mut result: HashSet<_> = HashSet::new(); @@ -1809,22 +1909,12 @@ impl TopLevelComposer { } instance_to_stmt.insert( - // NOTE: refer to codegen/expr/get_subst_key function - { - let unifier = &mut self.unifier; - subst - .keys() - .sorted() - .map(|id| { - let ty = subst.get(id).unwrap(); - unifier.stringify( - *ty, - &mut |id| id.to_string(), - &mut |id| id.to_string(), - ) - }) - .join(", ") - }, + get_subst_key( + &mut self.unifier, + self_type, + &subst, + Some(insted_vars), + ), FunInstance { body: Arc::new(fun_body), unifier_id: 0, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 35da97c9..b9259854 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,3 +1,8 @@ +use std::convert::TryInto; + +use nac3parser::ast::{Constant, Location}; +use crate::symbol_resolver::SymbolValue; + use super::*; impl TopLevelDef { @@ -341,4 +346,121 @@ impl TopLevelComposer { } Ok(result) } + + pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result { + parse_parameter_default_value(default, resolver) + } + + pub fn check_default_param_type(val: &SymbolValue, ty: &TypeAnnotation, primitive: &PrimitiveStore, unifier: &mut Unifier) -> Result<(), String> { + let res = match val { + SymbolValue::Bool(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.bool) { + None + } else { + Some("bool".to_string()) + } + } + SymbolValue::Double(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.float) { + None + } else { + Some("float".to_string()) + } + } + SymbolValue::I32(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.int32) { + None + } else { + Some("int32".to_string()) + } + } + SymbolValue::I64(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.int64) { + None + } else { + Some("int64".to_string()) + } + } + SymbolValue::Tuple(elts) => { + if let TypeAnnotation::Tuple(elts_ty) = ty { + for (e, t) in elts.iter().zip(elts_ty.iter()) { + Self::check_default_param_type(e, t, primitive, unifier)? + } + if elts.len() != elts_ty.len() { + Some(format!("tuple of length {}", elts.len())) + } else { + None + } + } else { + Some("tuple".to_string()) + } + } + }; + if let Some(found) = res { + Err(format!( + "incompatible default parameter type, expect {}, found {}", + ty.stringify(unifier), + found + )) + } else { + Ok(()) + } + } +} + +pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result { + fn handle_constant(val: &Constant, loc: &Location) -> Result { + match val { + Constant::Int(v) => { + if let Ok(v) = v.try_into() { + Ok(SymbolValue::I32(v)) + } else { + Err(format!( + "integer value out of range at {}", + loc + )) + } + } + Constant::Float(v) => Ok(SymbolValue::Double(*v)), + Constant::Bool(v) => Ok(SymbolValue::Bool(*v)), + Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( + tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()? + )), + _ => unimplemented!("this constant is not supported at {}", loc), + } + } + match &default.node { + ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location), + ast::ExprKind::Call { func, args, .. } if { + match &func.node { + ast::ExprKind::Name { id, .. } => *id == "int64".into(), + _ => false, + } + } => { + if args.len() == 1 { + match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(v), .. } => + Ok(SymbolValue::I64(v.try_into().unwrap())), + _ => Err(format!("only allow constant integer here at {}", default.location)) + } + } else { + Err(format!("only allow constant integer here at {}", default.location)) + } + } + ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts + .iter() + .map(|x| parse_parameter_default_value(x, resolver)) + .collect::, _>>()? + )), + ast::ExprKind::Name { id, .. } => { + resolver.get_default_param_value(default).ok_or_else( + || format!( + "`{}` cannot be used as a default parameter at {} (not primitive type or tuple / not defined?)", + id, + default.location + ) + ) + } + _ => Err(format!("unsupported default parameter at {}", default.location)) + } } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index c96a7cac..f6fc15b0 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -23,7 +23,7 @@ use inkwell::values::BasicValueEnum; pub struct DefinitionId(pub usize); pub mod composer; -mod helper; +pub mod helper; mod type_annotation; use composer::*; use type_annotation::*; diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 869c5078..42e38276 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -9,5 +9,5 @@ expression: res_vec "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a=int32], var4]\",\nvar_id: [4]\n}\n", "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b=var3], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"B.foo\",\nsig: \"fn[[b=var3], none]\",\nvar_id: [3]\n}\n", + "Function {\nname: \"B.foo\",\nsig: \"fn[[b=var3], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 2de87dab..b08e9352 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -5,12 +5,12 @@ expression: res_vec --- [ "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"var3\\\"]}\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t=var3], none]\"), (\"fun\", \"fn[[a=int32, b=var3], list[virtual[B[4->bool]]]]\"), (\"foo\", \"fn[[c=C], none]\")],\ntype_vars: [\"var3\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[t=var3], none]\",\nvar_id: [3]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a=int32, b=var3], list[virtual[B[4->bool]]]]\",\nvar_id: [3]\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[c=C], none]\",\nvar_id: [3]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[t=var3], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a=int32, b=var3], list[virtual[B[4->bool]]]]\",\nvar_id: []\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[c=C], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"{class: B, params: [\\\"var4\\\"]}\", \"{class: A, params: [\\\"float\\\"]}\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a=int32, b=var3], list[virtual[B[4->bool]]]]\"), (\"foo\", \"fn[[c=C], none]\")],\ntype_vars: [\"var4\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: [4]\n}\n", - "Function {\nname: \"B.fun\",\nsig: \"fn[[a=int32, b=var3], list[virtual[B[4->bool]]]]\",\nvar_id: [3, 4]\n}\n", + "Function {\nname: \"B.fun\",\nsig: \"fn[[a=int32, b=var3], list[virtual[B[4->bool]]]]\",\nvar_id: [4]\n}\n", "Class {\nname: \"C\",\nancestors: [\"{class: C, params: []}\", \"{class: B, params: [\\\"bool\\\"]}\", \"{class: A, params: [\\\"float\\\"]}\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a=int32, b=var3], list[virtual[B[4->bool]]]]\"), (\"foo\", \"fn[[c=C], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 58d1e088..2083a398 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -4,10 +4,10 @@ expression: res_vec --- [ - "Function {\nname: \"foo\",\nsig: \"fn[[a=list[int32], b=tuple[var3, float]], A[3->B, 4->bool]]\",\nvar_id: [3]\n}\n", + "Function {\nname: \"foo\",\nsig: \"fn[[a=list[int32], b=tuple[var3, float]], A[3->B, 4->bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"var3\\\", \\\"var4\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v=var4], none]\"), (\"fun\", \"fn[[a=var3], var4]\")],\ntype_vars: [\"var3\", \"var4\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v=var4], none]\",\nvar_id: [3, 4]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a=var3], var4]\",\nvar_id: [3, 4]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v=var4], none]\",\nvar_id: [4]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a=var3], var4]\",\nvar_id: [4]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a=A[3->list[float], 4->int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index a482fafd..2ded3960 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -5,8 +5,8 @@ expression: res_vec --- [ "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"var3\\\", \\\"var4\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a=A[3->float, 4->bool], b=B], none]\"), (\"fun\", \"fn[[a=A[3->float, 4->bool]], A[3->bool, 4->int32]]\")],\ntype_vars: [\"var3\", \"var4\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[a=A[3->float, 4->bool], b=B], none]\",\nvar_id: [3, 4]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a=A[3->float, 4->bool]], A[3->bool, 4->int32]]\",\nvar_id: [3, 4]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[a=A[3->float, 4->bool], b=B], none]\",\nvar_id: [4]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a=A[3->float, 4->bool]], A[3->bool, 4->int32]]\",\nvar_id: [4]\n}\n", "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\", \"{class: A, params: [\\\"int64\\\", \\\"bool\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a=A[3->float, 4->bool]], A[3->bool, 4->int32]]\"), (\"foo\", \"fn[[b=B], B]\"), (\"bar\", \"fn[[a=A[3->list[B], 4->int32]], tuple[A[3->virtual[A[3->B, 4->int32]], 4->bool], B]]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b=B], B]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 42569f36..d6f269ac 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -7,12 +7,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b=B], none]\"), (\"foo\", \"fn[[a=var3, b=var4], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b=B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a=var3, b=var4], none]\",\nvar_id: [3, 4]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a=var3, b=var4], none]\",\nvar_id: [4]\n}\n", "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\", \"{class: C, params: []}\", \"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b=B], none]\"), (\"foo\", \"fn[[a=var3, b=var4], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"{class: C, params: []}\", \"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b=B], none]\"), (\"foo\", \"fn[[a=var3, b=var4], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b=B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a=A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a=var3], var4]\",\nvar_id: [3, 4]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a=var3], var4]\",\nvar_id: [4]\n}\n", ] diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index ea15cad8..d1135c7d 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -35,6 +35,10 @@ impl ResolverInternal { struct Resolver(Arc); impl SymbolResolver for Resolver { + fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + unimplemented!() + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 257d582e..307d8794 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -280,9 +280,11 @@ pub fn get_type_from_type_annotation_kinds( { let ok: bool = { // create a temp type var and unify to check compatibility - let temp = - unifier.get_fresh_var_with_range(range.borrow().as_slice()); - unifier.unify(temp.0, p).is_ok() + p == *tvar || { + let temp = + unifier.get_fresh_var_with_range(range.borrow().as_slice()); + unifier.unify(temp.0, p).is_ok() + } }; if ok { result.insert(*id, p); diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 387336d9..cc209bba 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -19,6 +19,10 @@ struct Resolver { } impl SymbolResolver for Resolver { + fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + unimplemented!() + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index eb0cf4bc..0d194b26 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -719,22 +719,19 @@ impl Unifier { /// Returns Some(T) where T is the instantiated type. /// Returns None if the function is already instantiated. fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { - let mut instantiated = false; + let mut instantiated = true; let mut vars = Vec::new(); for (k, v) in fun.vars.iter() { if let TypeEnum::TVar { id, range, .. } = self.unification_table.probe_value(*v).as_ref() { - if k != id { - instantiated = true; - break; + // for class methods that contain type vars not in class declaration, + // as long as there exits one uninstantiated type var, the function is not instantiated, + // and need to do substitution on those type vars + if k == id { + instantiated = false; + vars.push((*k, range.clone())); } - // actually, if the first check succeeded, the function should be uninstatiated. - // The cloned values must be used and would not be wasted. - vars.push((*k, range.clone())); - } else { - instantiated = true; - break; } } if instantiated { diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index fe5864d0..a32432e9 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -1,14 +1,14 @@ use nac3core::{ codegen::CodeGenContext, location::Location, - symbol_resolver::{SymbolResolver, ValueEnum}, + symbol_resolver::{SymbolResolver, SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, Unifier}, }, }; -use nac3parser::ast::StrRef; +use nac3parser::ast::{self, StrRef}; use parking_lot::{Mutex, RwLock}; use std::{collections::HashMap, sync::Arc}; @@ -16,6 +16,7 @@ pub struct ResolverInternal { pub id_to_type: Mutex>, pub id_to_def: Mutex>, pub class_names: Mutex>, + pub module_globals: Mutex>, } impl ResolverInternal { @@ -26,11 +27,24 @@ impl ResolverInternal { pub fn add_id_type(&self, id: StrRef, ty: Type) { self.id_to_type.lock().insert(id, ty); } + + pub fn add_module_global(&self, id: StrRef, val: SymbolValue) { + self.module_globals.lock().insert(id, val); + } } pub struct Resolver(pub Arc); impl SymbolResolver for Resolver { + fn get_default_param_value(&self, expr: &ast::Expr) -> Option { + match &expr.node { + ast::ExprKind::Name { id, .. } => { + self.0.module_globals.lock().get(id).cloned() + } + _ => unimplemented!("other type of expr not supported at {}", expr.location) + } + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 3d455477..85b39325 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -4,8 +4,8 @@ use inkwell::{ OptimizationLevel, }; use nac3core::typecheck::type_inferencer::PrimitiveStore; -use nac3parser::parser; -use std::env; +use nac3parser::{ast::{Expr, ExprKind, StmtKind}, parser}; +use std::{borrow::Borrow, env}; use std::fs; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; @@ -15,7 +15,7 @@ use nac3core::{ WorkerRegistry, }, symbol_resolver::SymbolResolver, - toplevel::{composer::TopLevelComposer, TopLevelDef}, + toplevel::{composer::TopLevelComposer, TopLevelDef, helper::parse_parameter_default_value}, typecheck::typedef::FunSignature, }; @@ -48,6 +48,7 @@ fn main() { id_to_type: builtins_ty.into(), id_to_def: builtins_def.into(), class_names: Default::default(), + module_globals: Default::default(), } .into(); let resolver = @@ -66,6 +67,61 @@ fn main() { ); for stmt in parser_result.into_iter() { + if let StmtKind::Assign { targets, value, .. } = &stmt.node { + fn handle_assignment_pattern( + targets: &[Expr], + value: &Expr, + resolver: &(dyn SymbolResolver + Send + Sync), + internal_resolver: &ResolverInternal, + ) -> Result<(), String> { + if targets.len() == 1 { + match &targets[0].node { + ExprKind::Name { id, .. } => { + let val = parse_parameter_default_value(value.borrow(), resolver)?; + internal_resolver.add_module_global(*id, val); + Ok(()) + } + ExprKind::List { elts, .. } + | ExprKind::Tuple { elts, .. } => { + handle_assignment_pattern(elts, value, resolver, internal_resolver)?; + Ok(()) + } + _ => unreachable!("cannot be assigned") + } + } else { + match &value.node { + ExprKind::List { elts, .. } + | ExprKind::Tuple { elts, .. } => { + if elts.len() != targets.len() { + Err(format!( + "number of elements to unpack does not match (expect {}, found {}) at {}", + targets.len(), + elts.len(), + value.location + )) + } else { + for (tar, val) in targets.iter().zip(elts) { + handle_assignment_pattern( + std::slice::from_ref(tar), + val, + resolver, + internal_resolver + )?; + } + Ok(()) + } + }, + _ => Err(format!("unpack of this expression is not supported at {}", value.location)) + } + } + } + if let Err(err) = handle_assignment_pattern(targets, value, resolver.as_ref(), internal_resolver.as_ref()) { + eprintln!("{}", err); + return; + } + continue; + } + let (name, def_id, ty) = composer .register_top_level(stmt, Some(resolver.clone()), "__main__".into()) .unwrap(); @@ -100,7 +156,11 @@ fn main() { let instance = { let defs = top_level.definitions.read(); - let mut instance = defs[resolver.get_identifier_def("run".into()).unwrap().0].write(); + let mut instance = + defs[resolver + .get_identifier_def("run".into()) + .unwrap_or_else(|| panic!("cannot find run() entry point")).0 + ].write(); if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol,