From 8d6f8086f549ab690f9394973a49af902bd28ae4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Bourdeauducq?= Date: Fri, 4 Oct 2024 14:48:49 +0800 Subject: [PATCH] handle global variables --- src/basic_symbol_resolver.rs | 51 ++++++++++++++---- src/main.rs | 100 ++++++++++++++++++++++++++--------- 2 files changed, 116 insertions(+), 35 deletions(-) diff --git a/src/basic_symbol_resolver.rs b/src/basic_symbol_resolver.rs index 8481e82..9f449a4 100644 --- a/src/basic_symbol_resolver.rs +++ b/src/basic_symbol_resolver.rs @@ -1,6 +1,14 @@ -use nac3core::nac3parser::ast::{self, StrRef}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use parking_lot::{Mutex, RwLock}; + use nac3core::{ - codegen::CodeGenContext, + codegen::{CodeGenContext, CodeGenerator}, + inkwell::{module::Linkage, values::BasicValue}, + nac3parser::ast::{self, StrRef}, symbol_resolver::{SymbolResolver, SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ @@ -8,9 +16,6 @@ use nac3core::{ typedef::{Type, Unifier}, }, }; -use parking_lot::{Mutex, RwLock}; -use std::collections::HashSet; -use std::{collections::HashMap, sync::Arc}; pub struct ResolverInternal { pub id_to_type: Mutex>, @@ -45,9 +50,9 @@ impl SymbolResolver for Resolver { fn get_symbol_type( &self, - _: &mut Unifier, + unifier: &mut Unifier, _: &[Arc>], - _: &PrimitiveStore, + primitives: &PrimitiveStore, str: StrRef, ) -> Result { self.0 @@ -55,15 +60,41 @@ impl SymbolResolver for Resolver { .lock() .get(&str) .copied() + .or_else(|| { + self.0 + .module_globals + .lock() + .get(&str) + .cloned() + .map(|v| v.get_type(primitives, unifier)) + }) .ok_or(format!("cannot get type of {str}")) } fn get_symbol_value<'ctx>( &self, - _: StrRef, - _: &mut CodeGenContext<'ctx, '_>, + str: StrRef, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, ) -> Option> { - unimplemented!() + self.0.module_globals.lock().get(&str).cloned().map(|v| { + ctx.module + .get_global(&str.to_string()) + .unwrap_or_else(|| { + let ty = v.get_type(&ctx.primitives, &mut ctx.unifier); + + let init_val = ctx.gen_symbol_val(generator, &v, ty); + let llvm_ty = init_val.get_type(); + + let global = ctx.module.add_global(llvm_ty, None, &str.to_string()); + global.set_linkage(Linkage::LinkOnceAny); + global.set_initializer(&init_val); + + global + }) + .as_basic_value_enum() + .into() + }) } fn get_identifier_def(&self, id: StrRef) -> Result> { diff --git a/src/main.rs b/src/main.rs index af0bd2f..9bf1074 100644 --- a/src/main.rs +++ b/src/main.rs @@ -199,6 +199,36 @@ fn handle_assignment_pattern( } } +fn handle_global_var( + target: &nac3parser::ast::Expr, + value: Option<&nac3parser::ast::Expr>, + resolver: &(dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync), + internal_resolver: &ResolverInternal, +) -> Result<(), String> { + let nac3parser::ast::ExprKind::Name { id, .. } = target.node else { + return Err(format!( + "global variable declaration must be an identifier (at {})", + target.location, + )); + }; + + let Some(value) = value else { + return Err(format!( + "global variable `{id}` must be initialized in its definition" + )); + }; + + if let Ok(val) = toplevel::helper::parse_parameter_default_value(value, resolver) { + internal_resolver.add_module_global(id, val); + Ok(()) + } else { + Err(format!( + "failed to evaluate this expression `{:?}` as a constant at {}", + target.node, target.location, + )) + } +} + fn compile(code: &String, run_symbol: &String, output_filename: &Path) -> Result<(), String> { let mut target_machine_options = codegen::CodeGenTargetMachineOptions::from_host(); target_machine_options.reloc_mode = inkwell::targets::RelocMode::PIC; @@ -241,34 +271,54 @@ fn compile(code: &String, run_symbol: &String, output_filename: &Path) -> Result } }; for mut stmt in parser_result { - if let nac3parser::ast::StmtKind::Assign { targets, value, .. } = stmt.node { - let def_list = composer.extract_def_list(); - let unifier = &mut composer.unifier; - let primitives = &composer.primitives_ty; - handle_assignment_pattern( - &targets, - &value, - resolver.as_ref(), - internal_resolver.as_ref(), - &def_list, - unifier, - primitives, - )?; - } else { - if let nac3parser::ast::StmtKind::FunctionDef { name, .. } = &mut stmt.node { - if name.to_string() == "run" { - *name = run_symbol.as_str().into(); - } + match stmt.node { + nac3parser::ast::StmtKind::Assign { targets, value, .. } => { + let def_list = composer.extract_def_list(); + let unifier = &mut composer.unifier; + let primitives = &composer.primitives_ty; + handle_assignment_pattern( + &targets, + &value, + resolver.as_ref(), + internal_resolver.as_ref(), + &def_list, + unifier, + primitives, + )?; } - match composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true) { - Ok((name, def_id, ty)) => { - internal_resolver.add_id_def(name, def_id); - if let Some(ty) = ty { - internal_resolver.add_id_type(name, ty); + nac3parser::ast::StmtKind::AnnAssign { + ref target, + ref value, + .. + } => { + handle_global_var( + &target, + value.as_ref().map(Box::as_ref), + resolver.as_ref(), + internal_resolver.as_ref(), + )?; + + let (name, def_id, _) = composer + .register_top_level(stmt, Some(resolver.clone()), "__main__", true) + .unwrap(); + internal_resolver.add_id_def(name, def_id); + } + _ => { + if let nac3parser::ast::StmtKind::FunctionDef { name, .. } = &mut stmt.node { + if name.to_string() == "run" { + *name = run_symbol.as_str().into(); } } - Err(err) => { - return Err(format!("composer error: {}", err)); + match composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true) { + Ok((name, def_id, ty)) => { + internal_resolver.add_id_def(name, def_id); + if let Some(ty) = ty { + internal_resolver.add_id_type(name, ty); + } + } + Err(err) => { + return Err(format!("composer error: {}", err)); + } } } }