diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index c5d66a7c..7e3fddb4 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -650,6 +650,11 @@ impl Nac3 { } } } + TopLevelDef::Variable { .. } => { + return Err(CompileError::new_err(String::from( + "Unsupported @rpc annotation on global variable", + ))) + } } } } diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 313b75d8..59f06bfc 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1470,6 +1470,7 @@ impl SymbolResolver for Resolver { &self, id: StrRef, _: &mut CodeGenContext<'ctx, '_>, + _: &mut dyn CodeGenerator, ) -> Option> { let sym_value = { let id_to_val = self.0.id_to_pyval.read(); diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index c478fb3f..52146145 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -977,6 +977,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( TopLevelDef::Class { .. } => { return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?)) } + TopLevelDef::Variable { .. } => unreachable!(), } } .or_else(|_: String| { @@ -2885,7 +2886,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()), None => { let resolver = ctx.resolver.clone(); - resolver.get_symbol_value(*id, ctx).unwrap() + resolver.get_symbol_value(*id, ctx, generator).unwrap() } }, ExprKind::List { elts, .. } => { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 145eab6a..74a0244d 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -853,10 +853,9 @@ pub fn gen_func_impl< builder.position_at_end(init_bb); let body_bb = context.append_basic_block(fn_val, "body"); + // Store non-vararg argument values into local variables let mut var_assignment = HashMap::new(); let offset = u32::from(has_sret); - - // Store non-vararg argument values into local variables for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) { let param = fn_val.get_nth_param((n as u32) + offset).unwrap(); let local_type = get_llvm_type( diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index fb81c6dc..cfc188c8 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1828,6 +1828,37 @@ pub fn gen_stmt( stmt.location, ); } + StmtKind::Global { names, .. } => { + let registered_globals = ctx + .top_level + .definitions + .read() + .iter() + .filter_map(|def| { + if let TopLevelDef::Variable { simple_name, ty, .. } = &*def.read() { + Some((*simple_name, *ty)) + } else { + None + } + }) + .collect_vec(); + + for id in names { + let Some((_, ty)) = registered_globals.iter().find(|(name, _)| name == id) else { + return Err(format!("{id} is not a global at {}", stmt.location)); + }; + + let resolver = ctx.resolver.clone(); + let ptr = resolver + .get_symbol_value(*id, ctx, generator) + .map(|val| val.to_basic_value_enum(ctx, generator, *ty)) + .transpose()? + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + + ctx.var_assignment.insert(*id, (ptr, None, 0)); + } + } _ => unimplemented!(), }; Ok(()) diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 9649e68e..b2799b04 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -28,7 +28,7 @@ use crate::{ DefinitionId, FunInstance, TopLevelContext, TopLevelDef, }, typecheck::{ - type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, + type_inferencer::{FunctionData, IdentifierInfo, Inferencer, PrimitiveStore}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }, }; @@ -67,6 +67,7 @@ impl SymbolResolver for Resolver { &self, _: StrRef, _: &mut CodeGenContext<'ctx, '_>, + _: &mut dyn CodeGenerator, ) -> Option> { unimplemented!() } @@ -141,7 +142,8 @@ fn test_primitives() { }; let mut virtual_checks = Vec::new(); let mut calls = HashMap::new(); - let mut identifiers: HashSet<_> = ["a".into(), "b".into()].into(); + let mut identifiers: HashMap<_, _> = + ["a".into(), "b".into()].map(|id| (id, IdentifierInfo::default())).into(); let mut inferencer = Inferencer { top_level: &top_level, function_data: &mut function_data, @@ -320,7 +322,8 @@ fn test_simple_call() { }; let mut virtual_checks = Vec::new(); let mut calls = HashMap::new(); - let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].into(); + let mut identifiers: HashMap<_, _> = + ["a".into(), "foo".into()].map(|id| (id, IdentifierInfo::default())).into(); let mut inferencer = Inferencer { top_level: &top_level, function_data: &mut function_data, diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index ccd90a1c..bab823c1 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -369,6 +369,7 @@ pub trait SymbolResolver { &self, str: StrRef, ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, ) -> Option>; fn get_default_param_value(&self, expr: &Expr) -> Option; diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 619cc50b..62a54473 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1,13 +1,13 @@ use std::rc::Rc; -use nac3parser::ast::fold::Fold; +use nac3parser::ast::{fold::Fold, ExprKind}; use super::*; use crate::{ codegen::{expr::get_subst_key, stmt::exn_constructor}, symbol_resolver::SymbolValue, typecheck::{ - type_inferencer::{FunctionData, Inferencer}, + type_inferencer::{FunctionData, IdentifierInfo, Inferencer}, typedef::{TypeVar, VarMap}, }, }; @@ -101,7 +101,8 @@ impl TopLevelComposer { .iter() .map(|def_ast| match *def_ast.0.read() { TopLevelDef::Class { name, .. } => name.to_string(), - TopLevelDef::Function { simple_name, .. } => simple_name.to_string(), + TopLevelDef::Function { simple_name, .. } + | TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(), }) .collect_vec(); @@ -381,8 +382,58 @@ impl TopLevelComposer { )) } + ast::StmtKind::AnnAssign { target, annotation, .. } => { + let ExprKind::Name { id: name, .. } = target.node else { + return Err(format!( + "global variable declaration must be an identifier (at {})", + ast.location + )); + }; + + if self.keyword_list.contains(&name) { + return Err(format!( + "cannot use keyword `{}` as a class name (at {})", + name, + ast.location + )); + } + + let global_var_name = if mod_path.is_empty() { + name.to_string() + } else { + format!("{mod_path}.{name}") + }; + if !defined_names.insert(global_var_name.clone()) { + return Err(format!( + "global variable `{}` defined twice (at {})", + global_var_name, + ast.location + )); + } + + let ty_to_be_unified = self.unifier.get_dummy_var().ty; + self.definition_ast_list.push(( + RwLock::new(Self::make_top_level_variable_def( + global_var_name, + name, + // dummy here, unify with correct type later, + ty_to_be_unified, + *(annotation.clone()), + resolver, + Some(ast.location), + )).into(), + None, + )); + + Ok(( + name, + DefinitionId(self.definition_ast_list.len() - 1), + Some(ty_to_be_unified), + )) + } + _ => Err(format!( - "registrations of constructs other than top level classes/functions are not supported (at {})", + "registrations of constructs other than top level classes/functions/variables are not supported (at {})", ast.location )), } @@ -396,6 +447,7 @@ impl TopLevelComposer { if inference { self.analyze_function_instance()?; } + self.analyze_top_level_variables()?; Ok(()) } @@ -500,6 +552,7 @@ impl TopLevelComposer { } Ok(()) }; + let mut errors = HashSet::new(); for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) { if class_ast.is_none() { @@ -853,7 +906,6 @@ impl TopLevelComposer { let unifier = self.unifier.borrow_mut(); let primitives_store = &self.primitives_ty; - let mut errors = HashSet::new(); let mut analyze = |function_def: &Arc>, function_ast: &Option| { let mut function_def = function_def.write(); let function_def = &mut *function_def; @@ -1128,6 +1180,8 @@ impl TopLevelComposer { })?; Ok(()) }; + + let mut errors = HashSet::new(); for (function_def, function_ast) in def_list.iter().skip(self.builtin_num) { if function_ast.is_none() { continue; @@ -1702,7 +1756,6 @@ impl TopLevelComposer { } } - let mut errors = HashSet::new(); let mut analyze = |i, def: &Arc>, ast: &Option| { let class_def = def.read(); if let TopLevelDef::Class { @@ -1845,6 +1898,8 @@ impl TopLevelComposer { } Ok(()) }; + + let mut errors = HashSet::new(); for (i, (def, ast)) in definition_ast_list.iter().enumerate().skip(self.builtin_num) { if ast.is_none() { continue; @@ -2003,11 +2058,12 @@ impl TopLevelComposer { }) }; let mut identifiers = { - let mut result: HashSet<_> = HashSet::new(); + let mut result = HashMap::new(); if self_type.is_some() { - result.insert("self".into()); + result.insert("self".into(), IdentifierInfo::default()); } - result.extend(inst_args.iter().map(|x| x.name)); + result + .extend(inst_args.iter().map(|x| (x.name, IdentifierInfo::default()))); result }; let mut calls: HashMap = HashMap::new(); @@ -2172,4 +2228,57 @@ impl TopLevelComposer { } Ok(()) } + + /// Step 6. Analyze and populate the types of global variables. + fn analyze_top_level_variables(&mut self) -> Result<(), HashSet> { + let def_list = &self.definition_ast_list; + let temp_def_list = self.extract_def_list(); + let unifier = &mut self.unifier; + let primitives_store = &self.primitives_ty; + + let mut analyze = |variable_def: &Arc>| -> Result<_, HashSet> { + let variable_def = &mut *variable_def.write(); + + let TopLevelDef::Variable { ty: dummy_ty, ty_decl, resolver, loc, .. } = variable_def + else { + // not top level variable def, skip + return Ok(()); + }; + + let resolver = &**resolver.as_ref().unwrap(); + + let ty_annotation = parse_ast_to_type_annotation_kinds( + resolver, + &temp_def_list, + unifier, + primitives_store, + ty_decl, + HashMap::new(), + )?; + let ty_from_ty_annotation = get_type_from_type_annotation_kinds( + &temp_def_list, + unifier, + primitives_store, + &ty_annotation, + &mut None, + )?; + + unifier.unify(*dummy_ty, ty_from_ty_annotation).map_err(|e| { + HashSet::from([e.at(Some(loc.unwrap())).to_display(unifier).to_string()]) + })?; + Ok(()) + }; + + let mut errors = HashSet::new(); + for (variable_def, _) in def_list.iter().skip(self.builtin_num) { + if let Err(e) = analyze(variable_def) { + errors.extend(e); + } + } + if !errors.is_empty() { + return Err(errors); + } + + Ok(()) + } } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 66233ce0..d674c51b 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -391,6 +391,9 @@ impl TopLevelDef { r } ), + TopLevelDef::Variable { name, ty, .. } => { + format!("Variable {{ name: {name:?}, ty: {:?} }}", unifier.stringify(*ty),) + } } } } @@ -592,6 +595,18 @@ impl TopLevelComposer { } } + #[must_use] + pub fn make_top_level_variable_def( + name: String, + simple_name: StrRef, + ty: Type, + ty_decl: Expr, + resolver: Option>, + loc: Option, + ) -> TopLevelDef { + TopLevelDef::Variable { name, simple_name, ty, ty_decl, resolver, loc } + } + #[must_use] pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String { class_name.push('.'); diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index b786aa73..8241c3ac 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -10,7 +10,7 @@ use inkwell::values::BasicValueEnum; use itertools::Itertools; use parking_lot::RwLock; -use nac3parser::ast::{self, Location, Stmt, StrRef}; +use nac3parser::ast::{self, Expr, Location, Stmt, StrRef}; use crate::{ codegen::{CodeGenContext, CodeGenerator}, @@ -148,6 +148,25 @@ pub enum TopLevelDef { /// Definition location. loc: Option, }, + Variable { + /// Qualified name of the global variable, should be unique globally. + name: String, + + /// Simple name, the same as in method/function definition. + simple_name: StrRef, + + /// Type of the global variable. + ty: Type, + + /// The declared type of the global variable. + ty_decl: Expr, + + /// Symbol resolver of the module defined the class. + resolver: Option>, + + /// Definition location. + loc: Option, + }, } pub struct TopLevelContext { diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index cda680c7..077f6ab9 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -65,6 +65,7 @@ impl SymbolResolver for Resolver { &self, _: StrRef, _: &mut CodeGenContext<'ctx, '_>, + _: &mut dyn CodeGenerator, ) -> Option> { unimplemented!() } diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 9dd84091..0f38a341 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -1,4 +1,7 @@ -use std::{collections::HashSet, iter::once}; +use std::{ + collections::{HashMap, HashSet}, + iter::once, +}; use nac3parser::ast::{ self, Constant, Expr, ExprKind, @@ -7,7 +10,7 @@ use nac3parser::ast::{ }; use super::{ - type_inferencer::Inferencer, + type_inferencer::{IdentifierInfo, Inferencer}, typedef::{Type, TypeEnum}, }; use crate::toplevel::helper::PrimDef; @@ -24,15 +27,15 @@ impl<'a> Inferencer<'a> { fn check_pattern( &mut self, pattern: &Expr>, - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashMap, ) -> Result<(), HashSet> { match &pattern.node { ExprKind::Name { id, .. } if id == &"none".into() => { Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)])) } ExprKind::Name { id, .. } => { - if !defined_identifiers.contains(id) { - defined_identifiers.insert(*id); + if !defined_identifiers.contains_key(id) { + defined_identifiers.insert(*id, IdentifierInfo::default()); } self.should_have_value(pattern)?; Ok(()) @@ -72,7 +75,7 @@ impl<'a> Inferencer<'a> { fn check_expr( &mut self, expr: &Expr>, - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashMap, ) -> Result<(), HashSet> { // there are some cases where the custom field is None if let Some(ty) = &expr.custom { @@ -93,7 +96,7 @@ impl<'a> Inferencer<'a> { return Ok(()); } self.should_have_value(expr)?; - if !defined_identifiers.contains(id) { + if !defined_identifiers.contains_key(id) { match self.function_data.resolver.get_symbol_type( self.unifier, &self.top_level.definitions.read(), @@ -101,7 +104,7 @@ impl<'a> Inferencer<'a> { *id, ) { Ok(_) => { - self.defined_identifiers.insert(*id); + self.defined_identifiers.insert(*id, IdentifierInfo::default()); } Err(e) => { return Err(HashSet::from([format!( @@ -174,9 +177,7 @@ impl<'a> Inferencer<'a> { let mut defined_identifiers = defined_identifiers.clone(); for arg in &args.args { // TODO: should we check the types here? - if !defined_identifiers.contains(&arg.node.arg) { - defined_identifiers.insert(arg.node.arg); - } + defined_identifiers.entry(arg.node.arg).or_default(); } self.check_expr(body, &mut defined_identifiers)?; } @@ -239,7 +240,7 @@ impl<'a> Inferencer<'a> { fn check_stmt( &mut self, stmt: &Stmt>, - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashMap, ) -> Result> { match &stmt.node { StmtKind::For { target, iter, body, orelse, .. } => { @@ -265,9 +266,11 @@ impl<'a> Inferencer<'a> { let body_returned = self.check_block(body, &mut body_identifiers)?; let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?; - for ident in &body_identifiers { - if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) { - defined_identifiers.insert(*ident); + for ident in body_identifiers.keys() { + if !defined_identifiers.contains_key(ident) + && orelse_identifiers.contains_key(ident) + { + defined_identifiers.insert(*ident, IdentifierInfo::default()); } } Ok(body_returned && orelse_returned) @@ -298,7 +301,7 @@ impl<'a> Inferencer<'a> { let mut defined_identifiers = defined_identifiers.clone(); let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node; if let Some(name) = name { - defined_identifiers.insert(*name); + defined_identifiers.insert(*name, IdentifierInfo::default()); } self.check_block(body, &mut defined_identifiers)?; } @@ -362,6 +365,40 @@ impl<'a> Inferencer<'a> { } Ok(true) } + StmtKind::Global { names, .. } => { + for id in names { + if let Some(id_info) = defined_identifiers.get(id) { + if !id_info.is_global { + return Err(HashSet::from([format!( + "name '{id}' is assigned to before global declaration at {}", + stmt.location, + )])); + } + + continue; + } + + match self.function_data.resolver.get_symbol_type( + self.unifier, + &self.top_level.definitions.read(), + self.primitives, + *id, + ) { + Ok(_) => { + self.defined_identifiers + .insert(*id, IdentifierInfo { is_global: true }); + } + Err(e) => { + return Err(HashSet::from([format!( + "type error at identifier `{}` ({}) at {}", + id, e, stmt.location + )])) + } + } + } + + Ok(false) + } // break, raise, etc. _ => Ok(false), } @@ -370,7 +407,7 @@ impl<'a> Inferencer<'a> { pub fn check_block( &mut self, block: &[Stmt>], - defined_identifiers: &mut HashSet, + defined_identifiers: &mut HashMap, ) -> Result> { let mut ret = false; for stmt in block { diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 48736738..ce2e76d2 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -88,6 +88,20 @@ impl PrimitiveStore { } } +/// Information regarding a defined identifier. +#[derive(Clone, Copy, Debug, Default)] +pub struct IdentifierInfo { + /// Whether this identifier refers to a global variable. + pub is_global: bool, +} + +impl IdentifierInfo { + #[must_use] + pub fn new() -> IdentifierInfo { + IdentifierInfo::default() + } +} + pub struct FunctionData { pub resolver: Arc, pub return_type: Option, @@ -96,7 +110,7 @@ pub struct FunctionData { pub struct Inferencer<'a> { pub top_level: &'a TopLevelContext, - pub defined_identifiers: HashSet, + pub defined_identifiers: HashMap, pub function_data: &'a mut FunctionData, pub unifier: &'a mut Unifier, pub primitives: &'a PrimitiveStore, @@ -228,9 +242,7 @@ impl<'a> Fold<()> for Inferencer<'a> { handler.location, )); if let Some(name) = name { - if !self.defined_identifiers.contains(&name) { - self.defined_identifiers.insert(name); - } + self.defined_identifiers.entry(name).or_default(); if let Some(old_typ) = self.variable_mapping.insert(name, typ) { let loc = handler.location; self.unifier.unify(old_typ, typ).map_err(|e| { @@ -382,6 +394,7 @@ impl<'a> Fold<()> for Inferencer<'a> { | ast::StmtKind::Continue { .. } | ast::StmtKind::Expr { .. } | ast::StmtKind::For { .. } + | ast::StmtKind::Global { .. } | ast::StmtKind::Pass { .. } | ast::StmtKind::Try { .. } => {} ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { @@ -553,7 +566,7 @@ impl<'a> Fold<()> for Inferencer<'a> { unreachable!("must be tobj") } } else { - if !self.defined_identifiers.contains(id) { + if !self.defined_identifiers.contains_key(id) { match self.function_data.resolver.get_symbol_type( self.unifier, &self.top_level.definitions.read(), @@ -561,7 +574,7 @@ impl<'a> Fold<()> for Inferencer<'a> { *id, ) { Ok(_) => { - self.defined_identifiers.insert(*id); + self.defined_identifiers.insert(*id, IdentifierInfo::default()); } Err(e) => { return report_error( @@ -626,8 +639,8 @@ impl<'a> Inferencer<'a> { fn infer_pattern(&mut self, pattern: &ast::Expr) -> Result<(), InferenceError> { match &pattern.node { ExprKind::Name { id, .. } => { - if !self.defined_identifiers.contains(id) { - self.defined_identifiers.insert(*id); + if !self.defined_identifiers.contains_key(id) { + self.defined_identifiers.insert(*id, IdentifierInfo::default()); } Ok(()) } @@ -736,8 +749,8 @@ impl<'a> Inferencer<'a> { let mut defined_identifiers = self.defined_identifiers.clone(); for arg in &args.args { let name = &arg.node.arg; - if !defined_identifiers.contains(name) { - defined_identifiers.insert(*name); + if !defined_identifiers.contains_key(name) { + defined_identifiers.insert(*name, IdentifierInfo::default()); } } let fn_args: Vec<_> = args diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index f81e4ca4..e56cb283 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -9,7 +9,7 @@ use nac3parser::{ast::FileName, parser::parse_program}; use super::*; use crate::{ - codegen::CodeGenContext, + codegen::{CodeGenContext, CodeGenerator}, symbol_resolver::ValueEnum, toplevel::{helper::PrimDef, DefinitionId, TopLevelDef}, typecheck::{magic_methods::with_fields, typedef::*}, @@ -43,6 +43,7 @@ impl SymbolResolver for Resolver { &self, _: StrRef, _: &mut CodeGenContext<'ctx, '_>, + _: &mut dyn CodeGenerator, ) -> Option> { unimplemented!() } @@ -519,7 +520,7 @@ impl TestEnvironment { primitives: &mut self.primitives, virtual_checks: &mut self.virtual_checks, calls: &mut self.calls, - defined_identifiers: HashSet::default(), + defined_identifiers: HashMap::default(), in_handler: false, } } @@ -595,8 +596,9 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s println!("source:\n{source}"); let mut env = TestEnvironment::new(); let id_to_name = std::mem::take(&mut env.id_to_name); - let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect(); - defined_identifiers.insert("virtual".into()); + let mut defined_identifiers: HashMap<_, _> = + env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect(); + defined_identifiers.insert("virtual".into(), IdentifierInfo::default()); let mut inferencer = env.get_inferencer(); inferencer.defined_identifiers.clone_from(&defined_identifiers); let statements = parse_program(source, FileName::default()).unwrap(); @@ -741,8 +743,9 @@ fn test_primitive_magic_methods(source: &str, mapping: &HashMap<&str, &str>) { println!("source:\n{source}"); let mut env = TestEnvironment::basic_test_env(); let id_to_name = std::mem::take(&mut env.id_to_name); - let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect(); - defined_identifiers.insert("virtual".into()); + let mut defined_identifiers: HashMap<_, _> = + env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect(); + defined_identifiers.insert("virtual".into(), IdentifierInfo::default()); let mut inferencer = env.get_inferencer(); inferencer.defined_identifiers.clone_from(&defined_identifiers); let statements = parse_program(source, FileName::default()).unwrap(); diff --git a/nac3standalone/demo/src/globals.py b/nac3standalone/demo/src/globals.py new file mode 100644 index 00000000..d8b7822d --- /dev/null +++ b/nac3standalone/demo/src/globals.py @@ -0,0 +1,31 @@ +@extern +def output_int32(x: int32): + ... + +@extern +def output_int64(x: int64): + ... + +X: int32 = 0 +Y: int64 = int64(1) + +def f(): + global X, Y + X = 1 + Y = int64(2) + +def run() -> int32: + global X, Y + + output_int32(X) + output_int64(Y) + f() + output_int32(X) + output_int64(Y) + + X = 0 + Y = int64(0) + output_int32(X) + output_int64(Y) + + return 0 \ No newline at end of file diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 631ee737..9f449a48 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -6,7 +6,8 @@ use std::{ use parking_lot::{Mutex, RwLock}; use nac3core::{ - codegen::CodeGenContext, + codegen::{CodeGenContext, CodeGenerator}, + inkwell::{module::Linkage, values::BasicValue}, nac3parser::ast::{self, StrRef}, symbol_resolver::{SymbolResolver, SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, @@ -49,20 +50,51 @@ impl SymbolResolver for Resolver { fn get_symbol_type( &self, - _: &mut Unifier, + unifier: &mut Unifier, _: &[Arc>], - _: &PrimitiveStore, + primitives: &PrimitiveStore, str: StrRef, ) -> Result { - self.0.id_to_type.lock().get(&str).copied().ok_or(format!("cannot get type of {str}")) + self.0 + .id_to_type + .lock() + .get(&str) + .copied() + .or_else(|| { + self.0 + .module_globals + .lock() + .get(&str) + .cloned() + .map(|v| v.get_type(primitives, unifier)) + }) + .ok_or(format!("cannot get type of {str}")) } fn get_symbol_value<'ctx>( &self, - _: StrRef, - _: &mut CodeGenContext<'ctx, '_>, + str: StrRef, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, ) -> Option> { - unimplemented!() + self.0.module_globals.lock().get(&str).cloned().map(|v| { + ctx.module + .get_global(&str.to_string()) + .unwrap_or_else(|| { + let ty = v.get_type(&ctx.primitives, &mut ctx.unifier); + + let init_val = ctx.gen_symbol_val(generator, &v, ty); + let llvm_ty = init_val.get_type(); + + let global = ctx.module.add_global(llvm_ty, None, &str.to_string()); + global.set_linkage(Linkage::LinkOnceAny); + global.set_initializer(&init_val); + + global + }) + .as_basic_value_enum() + .into() + }) } fn get_identifier_def(&self, id: StrRef) -> Result> { diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 83985709..e27c5e1f 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -245,6 +245,34 @@ fn handle_assignment_pattern( } } +fn handle_global_var( + target: &Expr, + value: Option<&Expr>, + resolver: &(dyn SymbolResolver + Send + Sync), + internal_resolver: &ResolverInternal, +) -> Result<(), String> { + let ExprKind::Name { id, .. } = target.node else { + return Err(format!( + "global variable declaration must be an identifier (at {})", + target.location, + )); + }; + + let Some(value) = value else { + return Err(format!("global variable `{id}` must be initialized in its definition")); + }; + + if let Ok(val) = parse_parameter_default_value(value, resolver) { + internal_resolver.add_module_global(id, val); + Ok(()) + } else { + Err(format!( + "failed to evaluate this expression `{:?}` as a constant at {}", + target.node, target.location, + )) + } +} + fn main() { let cli = CommandLineArgs::parse(); let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } = @@ -297,8 +325,7 @@ fn main() { let program = match fs::read_to_string(file_name.clone()) { Ok(program) => program, Err(err) => { - println!("Cannot open input file: {err}"); - return; + panic!("Cannot open input file: {err}"); } }; @@ -340,10 +367,26 @@ fn main() { unifier, primitives, ) { - eprintln!("{err}"); - return; + panic!("{err}"); } } + + StmtKind::AnnAssign { target, value, .. } => { + if let Err(err) = handle_global_var( + target, + value.as_ref().map(Box::as_ref), + resolver.as_ref(), + internal_resolver.as_ref(), + ) { + panic!("{err}"); + } + + let (name, def_id, _) = composer + .register_top_level(stmt, Some(resolver.clone()), "__main__", true) + .unwrap(); + internal_resolver.add_id_def(name, def_id); + } + // allow (and ignore) "from __future__ import annotations" StmtKind::ImportFrom { module, names, .. } if module == &Some("__future__".into())