diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 9649e68e..59f62129 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -28,7 +28,7 @@ use crate::{ DefinitionId, FunInstance, TopLevelContext, TopLevelDef, }, typecheck::{ - type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, + type_inferencer::{FunctionData, IdentifierInfo, Inferencer, PrimitiveStore}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }, }; @@ -141,7 +141,8 @@ fn test_primitives() { }; let mut virtual_checks = Vec::new(); let mut calls = HashMap::new(); - let mut identifiers: HashSet<_> = ["a".into(), "b".into()].into(); + let mut identifiers: HashMap<_, _> = + ["a".into(), "b".into()].map(|id| (id, IdentifierInfo::default())).into(); let mut inferencer = Inferencer { top_level: &top_level, function_data: &mut function_data, @@ -320,7 +321,8 @@ fn test_simple_call() { }; let mut virtual_checks = Vec::new(); let mut calls = HashMap::new(); - let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].into(); + let mut identifiers: HashMap<_, _> = + ["a".into(), "foo".into()].map(|id| (id, IdentifierInfo::default())).into(); let mut inferencer = Inferencer { top_level: &top_level, function_data: &mut function_data, diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 06d5022b..020c7a8f 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -7,7 +7,7 @@ use crate::{ codegen::{expr::get_subst_key, stmt::exn_constructor}, symbol_resolver::SymbolValue, typecheck::{ - type_inferencer::{FunctionData, Inferencer}, + type_inferencer::{FunctionData, IdentifierInfo, Inferencer}, typedef::{TypeVar, VarMap}, }, }; @@ -2057,11 +2057,12 @@ impl TopLevelComposer { }) }; let mut identifiers = { - let mut result: HashSet<_> = HashSet::new(); + let mut result = HashMap::new(); if self_type.is_some() { - result.insert("self".into()); + result.insert("self".into(), IdentifierInfo::default()); } - result.extend(inst_args.iter().map(|x| x.name)); + result + .extend(inst_args.iter().map(|x| (x.name, IdentifierInfo::default()))); result }; let mut calls: HashMap = HashMap::new(); diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 9dd84091..4b5ccb57 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -1,4 +1,7 @@ -use std::{collections::HashSet, iter::once}; +use std::{ + collections::{HashMap, HashSet}, + iter::once, +}; use nac3parser::ast::{ self, Constant, Expr, ExprKind, @@ -7,7 +10,7 @@ use nac3parser::ast::{ }; use super::{ - type_inferencer::Inferencer, + type_inferencer::{IdentifierInfo, Inferencer}, typedef::{Type, TypeEnum}, }; use crate::toplevel::helper::PrimDef; @@ -24,15 +27,15 @@ impl<'a> Inferencer<'a> { fn check_pattern( &mut self, pattern: &Expr>, - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashMap, ) -> Result<(), HashSet> { match &pattern.node { ExprKind::Name { id, .. } if id == &"none".into() => { Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)])) } ExprKind::Name { id, .. } => { - if !defined_identifiers.contains(id) { - defined_identifiers.insert(*id); + if !defined_identifiers.contains_key(id) { + defined_identifiers.insert(*id, IdentifierInfo::default()); } self.should_have_value(pattern)?; Ok(()) @@ -72,7 +75,7 @@ impl<'a> Inferencer<'a> { fn check_expr( &mut self, expr: &Expr>, - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashMap, ) -> Result<(), HashSet> { // there are some cases where the custom field is None if let Some(ty) = &expr.custom { @@ -93,7 +96,7 @@ impl<'a> Inferencer<'a> { return Ok(()); } self.should_have_value(expr)?; - if !defined_identifiers.contains(id) { + if !defined_identifiers.contains_key(id) { match self.function_data.resolver.get_symbol_type( self.unifier, &self.top_level.definitions.read(), @@ -101,7 +104,7 @@ impl<'a> Inferencer<'a> { *id, ) { Ok(_) => { - self.defined_identifiers.insert(*id); + self.defined_identifiers.insert(*id, IdentifierInfo::default()); } Err(e) => { return Err(HashSet::from([format!( @@ -174,9 +177,7 @@ impl<'a> Inferencer<'a> { let mut defined_identifiers = defined_identifiers.clone(); for arg in &args.args { // TODO: should we check the types here? - if !defined_identifiers.contains(&arg.node.arg) { - defined_identifiers.insert(arg.node.arg); - } + defined_identifiers.entry(arg.node.arg).or_default(); } self.check_expr(body, &mut defined_identifiers)?; } @@ -239,7 +240,7 @@ impl<'a> Inferencer<'a> { fn check_stmt( &mut self, stmt: &Stmt>, - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashMap, ) -> Result> { match &stmt.node { StmtKind::For { target, iter, body, orelse, .. } => { @@ -265,9 +266,11 @@ impl<'a> Inferencer<'a> { let body_returned = self.check_block(body, &mut body_identifiers)?; let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?; - for ident in &body_identifiers { - if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) { - defined_identifiers.insert(*ident); + for ident in body_identifiers.keys() { + if !defined_identifiers.contains_key(ident) + && orelse_identifiers.contains_key(ident) + { + defined_identifiers.insert(*ident, IdentifierInfo::default()); } } Ok(body_returned && orelse_returned) @@ -298,7 +301,7 @@ impl<'a> Inferencer<'a> { let mut defined_identifiers = defined_identifiers.clone(); let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node; if let Some(name) = name { - defined_identifiers.insert(*name); + defined_identifiers.insert(*name, IdentifierInfo::default()); } self.check_block(body, &mut defined_identifiers)?; } @@ -370,7 +373,7 @@ impl<'a> Inferencer<'a> { pub fn check_block( &mut self, block: &[Stmt>], - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashMap, ) -> Result> { let mut ret = false; for stmt in block { diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 48736738..ba09bd5a 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -88,6 +88,20 @@ impl PrimitiveStore { } } +/// Information regarding a defined identifier. +#[derive(Clone, Copy, Debug, Default)] +pub struct IdentifierInfo { + /// Whether this identifier refers to a global variable. + pub is_global: bool, +} + +impl IdentifierInfo { + #[must_use] + pub fn new() -> IdentifierInfo { + IdentifierInfo::default() + } +} + pub struct FunctionData { pub resolver: Arc, pub return_type: Option, @@ -96,7 +110,7 @@ pub struct FunctionData { pub struct Inferencer<'a> { pub top_level: &'a TopLevelContext, - pub defined_identifiers: HashSet, + pub defined_identifiers: HashMap, pub function_data: &'a mut FunctionData, pub unifier: &'a mut Unifier, pub primitives: &'a PrimitiveStore, @@ -228,9 +242,7 @@ impl<'a> Fold<()> for Inferencer<'a> { handler.location, )); if let Some(name) = name { - if !self.defined_identifiers.contains(&name) { - self.defined_identifiers.insert(name); - } + self.defined_identifiers.entry(name).or_default(); if let Some(old_typ) = self.variable_mapping.insert(name, typ) { let loc = handler.location; self.unifier.unify(old_typ, typ).map_err(|e| { @@ -553,7 +565,7 @@ impl<'a> Fold<()> for Inferencer<'a> { unreachable!("must be tobj") } } else { - if !self.defined_identifiers.contains(id) { + if !self.defined_identifiers.contains_key(id) { match self.function_data.resolver.get_symbol_type( self.unifier, &self.top_level.definitions.read(), @@ -561,7 +573,7 @@ impl<'a> Fold<()> for Inferencer<'a> { *id, ) { Ok(_) => { - self.defined_identifiers.insert(*id); + self.defined_identifiers.insert(*id, IdentifierInfo::default()); } Err(e) => { return report_error( @@ -626,8 +638,8 @@ impl<'a> Inferencer<'a> { fn infer_pattern(&mut self, pattern: &ast::Expr) -> Result<(), InferenceError> { match &pattern.node { ExprKind::Name { id, .. } => { - if !self.defined_identifiers.contains(id) { - self.defined_identifiers.insert(*id); + if !self.defined_identifiers.contains_key(id) { + self.defined_identifiers.insert(*id, IdentifierInfo::default()); } Ok(()) } @@ -736,8 +748,8 @@ impl<'a> Inferencer<'a> { let mut defined_identifiers = self.defined_identifiers.clone(); for arg in &args.args { let name = &arg.node.arg; - if !defined_identifiers.contains(name) { - defined_identifiers.insert(*name); + if !defined_identifiers.contains_key(name) { + defined_identifiers.insert(*name, IdentifierInfo::default()); } } let fn_args: Vec<_> = args diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index f81e4ca4..a3e307f4 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -519,7 +519,7 @@ impl TestEnvironment { primitives: &mut self.primitives, virtual_checks: &mut self.virtual_checks, calls: &mut self.calls, - defined_identifiers: HashSet::default(), + defined_identifiers: HashMap::default(), in_handler: false, } } @@ -595,8 +595,9 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s println!("source:\n{source}"); let mut env = TestEnvironment::new(); let id_to_name = std::mem::take(&mut env.id_to_name); - let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect(); - defined_identifiers.insert("virtual".into()); + let mut defined_identifiers: HashMap<_, _> = + env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect(); + defined_identifiers.insert("virtual".into(), IdentifierInfo::default()); let mut inferencer = env.get_inferencer(); inferencer.defined_identifiers.clone_from(&defined_identifiers); let statements = parse_program(source, FileName::default()).unwrap(); @@ -741,8 +742,9 @@ fn test_primitive_magic_methods(source: &str, mapping: &HashMap<&str, &str>) { println!("source:\n{source}"); let mut env = TestEnvironment::basic_test_env(); let id_to_name = std::mem::take(&mut env.id_to_name); - let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect(); - defined_identifiers.insert("virtual".into()); + let mut defined_identifiers: HashMap<_, _> = + env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect(); + defined_identifiers.insert("virtual".into(), IdentifierInfo::default()); let mut inferencer = env.get_inferencer(); inferencer.defined_identifiers.clone_from(&defined_identifiers); let statements = parse_program(source, FileName::default()).unwrap();