diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 50ed2e05..81156d00 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -77,6 +77,7 @@ fn test_primitives() { }; let mut virtual_checks = Vec::new(); let mut calls = HashMap::new(); + let mut identifiers = vec!["a".to_string(), "b".to_string()]; let mut inferencer = Inferencer { top_level: &top_level, function_data: &mut function_data, @@ -85,6 +86,7 @@ fn test_primitives() { primitives: &primitives, virtual_checks: &mut virtual_checks, calls: &mut calls, + defined_identifiers: identifiers.clone() }; inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32); inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32); @@ -95,7 +97,6 @@ fn test_primitives() { .collect::, _>>() .unwrap(); - let mut identifiers = vec!["a".to_string(), "b".to_string()]; inferencer.check_block(&statements, &mut identifiers).unwrap(); let top_level = Arc::new(TopLevelContext { definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))), @@ -235,6 +236,7 @@ fn test_simple_call() { }; let mut virtual_checks = Vec::new(); let mut calls = HashMap::new(); + let mut identifiers = vec!["a".to_string(), "foo".into()]; let mut inferencer = Inferencer { top_level: &top_level, function_data: &mut function_data, @@ -243,6 +245,7 @@ fn test_simple_call() { primitives: &primitives, virtual_checks: &mut virtual_checks, calls: &mut calls, + defined_identifiers: identifiers.clone() }; inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32); inferencer.variable_mapping.insert("foo".into(), fun_ty); @@ -273,7 +276,6 @@ fn test_simple_call() { unreachable!() } - let mut identifiers = vec!["a".to_string(), "foo".into()]; inferencer.check_block(&statements_1, &mut identifiers).unwrap(); let top_level = Arc::new(TopLevelContext { definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))), diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 23cd80e2..cf7956de 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -42,7 +42,7 @@ impl<'a> Inferencer<'a> { fn check_expr( &mut self, expr: &Expr>, - defined_identifiers: &[String], + defined_identifiers: &mut Vec, ) -> Result<(), String> { // there are some cases where the custom field is None if let Some(ty) = &expr.custom { @@ -57,10 +57,14 @@ impl<'a> Inferencer<'a> { match &expr.node { ExprKind::Name { id, .. } => { if !defined_identifiers.contains(id) { - return Err(format!( - "unknown identifier {} (use before def?) at {}", - id, expr.location - )); + if self.function_data.resolver.get_identifier_def(id).is_some() { + defined_identifiers.push(id.clone()); + } else { + return Err(format!( + "unknown identifier {} (use before def?) at {}", + id, expr.location + )); + } } } ExprKind::List { elts, .. } @@ -106,7 +110,7 @@ impl<'a> Inferencer<'a> { defined_identifiers.push(arg.node.arg.clone()); } } - self.check_expr(body, &defined_identifiers)?; + self.check_expr(body, &mut defined_identifiers)?; } ExprKind::ListComp { elt, generators, .. } => { // in our type inference stage, we already make sure that there is only 1 generator @@ -115,7 +119,7 @@ impl<'a> Inferencer<'a> { let mut defined_identifiers = defined_identifiers.to_vec(); self.check_pattern(target, &mut defined_identifiers)?; for term in once(elt.as_ref()).chain(ifs.iter()) { - self.check_expr(term, &defined_identifiers)?; + self.check_expr(term, &mut defined_identifiers)?; } } ExprKind::Call { func, args, keywords } => { diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 898d5f3e..e1e4d3f2 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -45,6 +45,7 @@ pub struct FunctionData { pub struct Inferencer<'a> { pub top_level: &'a TopLevelContext, + pub defined_identifiers: Vec, pub function_data: &'a mut FunctionData, pub unifier: &'a mut Unifier, pub primitives: &'a PrimitiveStore, @@ -74,13 +75,14 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { let stmt = match node.node { // we don't want fold over type annotation ast::StmtKind::AnnAssign { target, annotation, value, simple } => { + self.infer_pattern(&target)?; let target = Box::new(self.fold_expr(*target)?); let value = if let Some(v) = value { let ty = Box::new(self.fold_expr(*v)?); self.unify(target.custom.unwrap(), ty.custom.unwrap(), &node.location)?; Some(ty) } else { - None + return Err(format!("declaration without definition is not yet supported, at {}", node.location)) }; let top_level_defs = self.top_level.definitions.read(); let annotation_type = self.function_data.resolver.parse_type_annotation( @@ -97,6 +99,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { node: ast::StmtKind::AnnAssign { target, annotation, value, simple }, } } + ast::StmtKind::For { ref target, .. } => { + self.infer_pattern(target)?; + fold::fold_stmt(self, node)? + } + ast::StmtKind::Assign { ref targets, .. } => { + for target in targets { + self.infer_pattern(target)?; + } + fold::fold_stmt(self, node)? + } _ => fold::fold_stmt(self, node)?, }; match &stmt.node { @@ -146,7 +158,19 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { }; let custom = match &expr.node { ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?), - ast::ExprKind::Name { id, .. } => Some(self.infer_identifier(id)?), + ast::ExprKind::Name { id, .. } => { + if !self.defined_identifiers.contains(id) { + if self.function_data.resolver.get_identifier_def(id.as_str()).is_some() { + self.defined_identifiers.push(id.clone()); + } else { + return Err(format!( + "unknown identifier {} (use before def?) at {}", + id, expr.location + )); + } + } + Some(self.infer_identifier(id)?) + } ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), ast::ExprKind::Attribute { value, attr, ctx: _ } => { @@ -187,6 +211,24 @@ impl<'a> Inferencer<'a> { self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location)) } + fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), String> { + match &pattern.node { + ExprKind::Name { id, .. } => { + if !self.defined_identifiers.contains(id) { + self.defined_identifiers.push(id.clone()); + } + Ok(()) + } + ExprKind::Tuple { elts, .. } => { + for elt in elts.iter() { + self.infer_pattern(elt)?; + } + Ok(()) + } + _ => Ok(()), + } + } + fn build_method_call( &mut self, location: Location, @@ -228,6 +270,13 @@ impl<'a> Inferencer<'a> { ); } + let mut defined_identifiers = self.defined_identifiers.clone(); + for arg in args.args.iter() { + let name = &arg.node.arg; + if !defined_identifiers.contains(name) { + defined_identifiers.push(name.clone()); + } + } let fn_args: Vec<_> = args .args .iter() @@ -236,6 +285,7 @@ impl<'a> Inferencer<'a> { let mut variable_mapping = self.variable_mapping.clone(); variable_mapping.extend(fn_args.iter().cloned()); let ret = self.unifier.get_fresh_var().0; + let mut new_context = Inferencer { function_data: self.function_data, unifier: self.unifier, @@ -243,6 +293,7 @@ impl<'a> Inferencer<'a> { virtual_checks: self.virtual_checks, calls: self.calls, top_level: self.top_level, + defined_identifiers, variable_mapping, }; let fun = FunSignature { @@ -279,6 +330,7 @@ impl<'a> Inferencer<'a> { ); } let variable_mapping = self.variable_mapping.clone(); + let defined_identifiers = self.defined_identifiers.clone(); let mut new_context = Inferencer { function_data: self.function_data, unifier: self.unifier, @@ -287,12 +339,14 @@ impl<'a> Inferencer<'a> { variable_mapping, primitives: self.primitives, calls: self.calls, + defined_identifiers, }; - let elt = new_context.fold_expr(elt)?; let generator = generators.pop().unwrap(); if generator.is_async { return Err("Async iterator not supported.".to_string()); } + new_context.infer_pattern(&generator.target)?; + let elt = new_context.fold_expr(elt)?; let target = new_context.fold_expr(*generator.target)?; let iter = new_context.fold_expr(*generator.iter)?; let ifs: Vec<_> = generator diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index db4bbd38..e081a633 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -313,6 +313,7 @@ impl TestEnvironment { primitives: &mut self.primitives, virtual_checks: &mut self.virtual_checks, calls: &mut self.calls, + defined_identifiers: vec![] } } } @@ -382,6 +383,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect(); defined_identifiers.push("virtual".to_string()); let mut inferencer = env.get_inferencer(); + inferencer.defined_identifiers = defined_identifiers.clone(); let statements = parse_program(source).unwrap(); let statements = statements .into_iter() @@ -523,6 +525,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect(); defined_identifiers.push("virtual".to_string()); let mut inferencer = env.get_inferencer(); + inferencer.defined_identifiers = defined_identifiers.clone(); let statements = parse_program(source).unwrap(); let statements = statements .into_iter() diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index ddf47a59..41983a48 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -112,6 +112,7 @@ fn main() { primitives: &primitives, virtual_checks: &mut virtual_checks, calls: &mut calls, + defined_identifiers: vec![] }; let statements = statements @@ -124,6 +125,7 @@ fn main() { inferencer .check_block(&statements, &mut identifiers) .unwrap(); + let top_level = Arc::new(TopLevelContext { definitions: Arc::new(RwLock::new(std::mem::take( &mut *top_level.definitions.write(),