diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index a462cf93..62dae23d 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -2,6 +2,7 @@ use nac3core::{ codegen::{expr::gen_call, stmt::gen_with, CodeGenContext, CodeGenerator}, toplevel::DefinitionId, typecheck::typedef::{FunSignature, Type}, + symbol_resolver::ValueEnum, }; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; @@ -38,13 +39,13 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { fn gen_call<'ctx, 'a>( &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, - obj: Option<(Type, BasicValueEnum<'ctx>)>, + obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), - params: Vec<(Option, BasicValueEnum<'ctx>)>, + params: Vec<(Option, ValueEnum<'ctx>)>, ) -> Option> { let result = gen_call(self, ctx, obj, fun, params); if let Some(end) = self.end.clone() { - let old_end = self.gen_expr(ctx, &end).unwrap(); + let old_end = self.gen_expr(ctx, &end).unwrap().to_basic_value_enum(ctx); let now = self.timeline.emit_now_mu(ctx); let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { let i64 = ctx.ctx.i64_type(); @@ -64,7 +65,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { ctx.builder.build_store(end_store, max); } if let Some(start) = self.start.clone() { - let start_val = self.gen_expr(ctx, &start).unwrap(); + let start_val = self.gen_expr(ctx, &start).unwrap().to_basic_value_enum(ctx); self.timeline.emit_at_mu(ctx, start_val); } result @@ -96,7 +97,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { let old_start = self.start.take(); let old_end = self.end.take(); let now = if let Some(old_start) = &old_start { - self.gen_expr(ctx, old_start).unwrap() + self.gen_expr(ctx, old_start).unwrap().to_basic_value_enum(ctx) } else { self.timeline.emit_now_mu(ctx) }; @@ -145,7 +146,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { } // set duration let end_expr = self.end.take().unwrap(); - let end_val = self.gen_expr(ctx, &end_expr).unwrap(); + let end_val = self.gen_expr(ctx, &end_expr).unwrap().to_basic_value_enum(ctx); // inside an sequential block if old_start.is_none() { @@ -153,7 +154,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { } // inside a parallel block, should update the outer max now_mu if let Some(old_end) = &old_end { - let outer_end_val = self.gen_expr(ctx, old_end).unwrap(); + let outer_end_val = self.gen_expr(ctx, old_end).unwrap().to_basic_value_enum(ctx); let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { let i64 = ctx.ctx.i64_type(); ctx.module.add_function( diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index fb71ee89..61e16b94 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -8,12 +8,12 @@ use inkwell::{ targets::*, OptimizationLevel, }; -use pyo3::prelude::*; -use pyo3::{exceptions, types::PyList, types::PySet, types::PyBytes}; use nac3parser::{ ast::{self, StrRef}, parser::{self, parse_program}, }; +use pyo3::prelude::*; +use pyo3::{exceptions, types::PyBytes, types::PyList, types::PySet}; use parking_lot::{Mutex, RwLock}; @@ -27,7 +27,10 @@ use nac3core::{ use tempfile::{self, TempDir}; -use crate::{codegen::ArtiqCodeGenerator, symbol_resolver::Resolver}; +use crate::{ + codegen::ArtiqCodeGenerator, + symbol_resolver::{InnerResolver, PythonHelper, Resolver}, +}; mod codegen; mod symbol_resolver; @@ -73,33 +76,45 @@ struct Nac3 { } impl Nac3 { - fn register_module(&mut self, module: PyObject, registered_class_ids: &HashSet) -> PyResult<()> { + fn register_module( + &mut self, + module: PyObject, + registered_class_ids: &HashSet, + ) -> PyResult<()> { let mut name_to_pyid: HashMap = HashMap::new(); - let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { - let module: &PyAny = module.extract(py)?; - let builtins = PyModule::import(py, "builtins")?; - let id_fn = builtins.getattr("id")?; - let members: &PyList = PyModule::import(py, "inspect")? - .getattr("getmembers")? - .call1((module,))? - .cast_as()?; - for member in members.iter() { - let key: &str = member.get_item(0)?.extract()?; - let val = id_fn.call1((member.get_item(1)?,))?.extract()?; - name_to_pyid.insert(key.into(), val); - } - Ok(( - module.getattr("__name__")?.extract()?, - module.getattr("__file__")?.extract()?, - )) - })?; + let (module_name, source_file, helper) = + Python::with_gil(|py| -> PyResult<(String, String, PythonHelper)> { + let module: &PyAny = module.extract(py)?; + let builtins = PyModule::import(py, "builtins")?; + let id_fn = builtins.getattr("id")?; + let members: &PyList = PyModule::import(py, "inspect")? + .getattr("getmembers")? + .call1((module,))? + .cast_as()?; + for member in members.iter() { + let key: &str = member.get_item(0)?.extract()?; + let val = id_fn.call1((member.get_item(1)?,))?.extract()?; + name_to_pyid.insert(key.into(), val); + } + let helper = PythonHelper { + id_fn: builtins.getattr("id").unwrap().to_object(py), + len_fn: builtins.getattr("len").unwrap().to_object(py), + type_fn: builtins.getattr("type").unwrap().to_object(py), + }; + Ok(( + module.getattr("__name__")?.extract()?, + module.getattr("__file__")?.extract()?, + helper, + )) + })?; let source = fs::read_to_string(source_file).map_err(|e| { exceptions::PyIOError::new_err(format!("failed to read input file: {}", e)) })?; let parser_result = parser::parse_program(&source) .map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {}", e)))?; - let resolver = Arc::new(Resolver { + + let resolver = Arc::new(Resolver(Arc::new(InnerResolver { id_to_type: self.builtins_ty.clone().into(), id_to_def: self.builtins_def.clone().into(), pyid_to_def: self.pyid_to_def.clone(), @@ -109,7 +124,8 @@ impl Nac3 { class_names: Default::default(), name_to_pyid: name_to_pyid.clone(), module: module.clone(), - }) as Arc; + helper, + }))) as Arc; let mut name_to_def = HashMap::new(); let mut name_to_type = HashMap::new(); @@ -140,10 +156,11 @@ impl Nac3 { let base_obj = module.getattr(py, id.to_string())?; let base_id = id_fn.call1((base_obj,))?.extract()?; Ok(registered_class_ids.contains(&base_id)) - }, - _ => Ok(true) + } + _ => Ok(true), } - }).unwrap() + }) + .unwrap() }); body.retain(|stmt| { if let ast::StmtKind::FunctionDef { @@ -306,7 +323,11 @@ impl Nac3 { }; let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); - fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap(); + fs::write( + working_directory.path().join("kernel.ld"), + include_bytes!("kernel.ld"), + ) + .unwrap(); Ok(Nac3 { isa, @@ -320,29 +341,30 @@ impl Nac3 { pyid_to_def: Default::default(), pyid_to_type: Default::default(), global_value_ids: Default::default(), - working_directory + working_directory, }) } fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> { - let (modules, class_ids) = Python::with_gil(|py| -> PyResult<(HashMap, HashSet)> { - let mut modules: HashMap = HashMap::new(); - let mut class_ids: HashSet = HashSet::new(); + let (modules, class_ids) = + Python::with_gil(|py| -> PyResult<(HashMap, HashSet)> { + let mut modules: HashMap = HashMap::new(); + let mut class_ids: HashSet = HashSet::new(); - let id_fn = PyModule::import(py, "builtins")?.getattr("id")?; - let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?; + let id_fn = PyModule::import(py, "builtins")?.getattr("id")?; + let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?; - for function in functions.iter() { - let module = getmodule_fn.call1((function,))?.extract()?; - modules.insert(id_fn.call1((&module,))?.extract()?, module); - } - for class in classes.iter() { - let module = getmodule_fn.call1((class,))?.extract()?; - modules.insert(id_fn.call1((&module,))?.extract()?, module); - class_ids.insert(id_fn.call1((class,))?.extract()?); - } - Ok((modules, class_ids)) - })?; + for function in functions.iter() { + let module = getmodule_fn.call1((function,))?.extract()?; + modules.insert(id_fn.call1((&module,))?.extract()?, module); + } + for class in classes.iter() { + let module = getmodule_fn.call1((class,))?.extract()?; + modules.insert(id_fn.call1((&module,))?.extract()?, module); + class_ids.insert(id_fn.call1((class,))?.extract()?); + } + Ok((modules, class_ids)) + })?; for module in modules.into_values() { self.register_module(module, &class_ids)?; @@ -380,7 +402,13 @@ impl Nac3 { ) }; let mut synthesized = parse_program(&synthesized).unwrap(); - let resolver = Arc::new(Resolver { + let builtins = PyModule::import(py, "builtins")?; + let helper = PythonHelper { + id_fn: builtins.getattr("id").unwrap().to_object(py), + len_fn: builtins.getattr("len").unwrap().to_object(py), + type_fn: builtins.getattr("type").unwrap().to_object(py), + }; + let resolver = Arc::new(Resolver(Arc::new(InnerResolver { id_to_type: self.builtins_ty.clone().into(), id_to_def: self.builtins_def.clone().into(), pyid_to_def: self.pyid_to_def.clone(), @@ -390,7 +418,8 @@ impl Nac3 { class_names: Default::default(), name_to_pyid, module: module.to_object(py), - }) as Arc; + helper, + }))) as Arc; let (_, def_id, _) = self .composer .register_top_level( @@ -443,6 +472,7 @@ impl Nac3 { store, unifier_index: instance.unifier_id, calls: instance.calls, + id: 0, }; let isa = self.isa; let working_directory = self.working_directory.path().to_owned(); @@ -454,9 +484,18 @@ impl Nac3 { passes.run_on(module); let (triple, features) = match isa { - Isa::Host => (TargetMachine::get_default_triple(), TargetMachine::get_host_cpu_features().to_string()), - Isa::RiscV32G => (TargetTriple::create("riscv32-unknown-linux"), "+a,+m,+f,+d".to_string()), - Isa::RiscV32IMA => (TargetTriple::create("riscv32-unknown-linux"), "+a,+m".to_string()), + Isa::Host => ( + TargetMachine::get_default_triple(), + TargetMachine::get_host_cpu_features().to_string(), + ), + Isa::RiscV32G => ( + TargetTriple::create("riscv32-unknown-linux"), + "+a,+m,+f,+d".to_string(), + ), + Isa::RiscV32IMA => ( + TargetTriple::create("riscv32-unknown-linux"), + "+a,+m".to_string(), + ), Isa::CortexA9 => ( TargetTriple::create("armv7-unknown-linux-gnueabihf"), "+dsp,+fp16,+neon,+vfp3".to_string(), @@ -502,11 +541,24 @@ impl Nac3 { filename.to_string(), ]; if isa != Isa::Host { - linker_args.push("-T".to_string() + self.working_directory.path().join("kernel.ld").to_str().unwrap()); + linker_args.push( + "-T".to_string() + + self + .working_directory + .path() + .join("kernel.ld") + .to_str() + .unwrap(), + ); } linker_args.extend(thread_names.iter().map(|name| { let name_o = name.to_owned() + ".o"; - self.working_directory.path().join(name_o.as_str()).to_str().unwrap().to_string() + self.working_directory + .path() + .join(name_o.as_str()) + .to_str() + .unwrap() + .to_string() })); if let Ok(linker_status) = Command::new("ld.lld").args(linker_args).status() { if !linker_status.success() { diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 16525096..0a85824b 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -2,19 +2,19 @@ use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace}; use nac3core::{ codegen::CodeGenContext, location::Location, - symbol_resolver::{SymbolResolver, SymbolValue}, + symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, TypeEnum, Unifier}, }, }; +use nac3parser::ast::{self, StrRef}; use parking_lot::{Mutex, RwLock}; use pyo3::{ types::{PyList, PyModule, PyTuple}, PyAny, PyObject, PyResult, Python, }; -use nac3parser::ast::{self, StrRef}; use std::{ cell::RefCell, collections::{HashMap, HashSet}, @@ -23,7 +23,7 @@ use std::{ use crate::PrimitivePythonId; -pub struct Resolver { +pub struct InnerResolver { pub id_to_type: Mutex>, pub id_to_def: Mutex>, pub global_value_ids: Arc>>, @@ -31,32 +31,94 @@ pub struct Resolver { pub pyid_to_def: Arc>>, pub pyid_to_type: Arc>>, pub primitive_ids: PrimitivePythonId, + pub helper: PythonHelper, // module specific pub name_to_pyid: HashMap, pub module: PyObject, } -struct PythonHelper<'a> { - type_fn: &'a PyAny, - len_fn: &'a PyAny, - id_fn: &'a PyAny, +pub struct Resolver(pub Arc); + +pub struct PythonHelper { + pub type_fn: PyObject, + pub len_fn: PyObject, + pub id_fn: PyObject, } -impl Resolver { +struct PythonValue { + id: u64, + value: PyObject, + resolver: Arc, +} + +impl StaticValue for PythonValue { + fn get_unique_identifier(&self) -> u64 { + self.id + } + + fn to_basic_value_enum<'ctx, 'a>( + &self, + ctx: &mut CodeGenContext<'ctx, 'a>, + ) -> BasicValueEnum<'ctx> { + Python::with_gil(|py| -> PyResult> { + self.resolver + .get_obj_value(py, self.value.as_ref(py), ctx) + .map(Option::unwrap) + }) + .unwrap() + } + + fn get_field<'ctx, 'a>( + &self, + name: StrRef, + ctx: &mut CodeGenContext<'ctx, 'a>, + ) -> Option> { + Python::with_gil(|py| -> PyResult>> { + let helper = &self.resolver.helper; + let ty = helper.type_fn.call1(py, (&self.value,))?; + let ty_id: u64 = helper.id_fn.call1(py, (ty,))?.extract(py)?; + let def_id = { *self.resolver.pyid_to_def.read().get(&ty_id).unwrap() }; + let mut mutable = true; + let defs = ctx.top_level.definitions.read(); + if let TopLevelDef::Class { fields, .. } = &*defs[def_id.0].read() { + for (field_name, _, is_mutable) in fields.iter() { + if field_name == &name { + mutable = *is_mutable; + break; + } + } + } + Ok(if mutable { + None + } else { + let obj = self.value.getattr(py, &name.to_string())?; + let id = self.resolver.helper.id_fn.call1(py, (&obj,))?.extract(py)?; + Some(ValueEnum::Static(Arc::new(PythonValue { + id, + value: obj, + resolver: self.resolver.clone(), + }))) + }) + }) + .unwrap() + } +} + +impl InnerResolver { fn get_list_elem_type( &self, + py: Python, list: &PyAny, len: usize, - helper: &PythonHelper, unifier: &mut Unifier, defs: &[Arc>], primitives: &PrimitiveStore, ) -> PyResult> { - let first = self.get_obj_type(list.get_item(0)?, helper, unifier, defs, primitives)?; + let first = self.get_obj_type(py, list.get_item(0)?, unifier, defs, primitives)?; Ok((1..len).fold(first, |a, i| { let b = list .get_item(i) - .map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)); + .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives)); a.and_then(|a| { if let Ok(Ok(Some(ty))) = b { if unifier.unify(a, ty).is_ok() { @@ -73,16 +135,17 @@ impl Resolver { fn get_obj_type( &self, + py: Python, obj: &PyAny, - helper: &PythonHelper, unifier: &mut Unifier, defs: &[Arc>], primitives: &PrimitiveStore, ) -> PyResult> { - let ty_id: u64 = helper + let ty_id: u64 = self + .helper .id_fn - .call1((helper.type_fn.call1((obj,))?,))? - .extract()?; + .call1(py, (self.helper.type_fn.call1(py, (obj,))?,))? + .extract(py)?; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { Ok(Some(primitives.int32)) @@ -93,20 +156,20 @@ impl Resolver { } else if ty_id == self.primitive_ids.float { Ok(Some(primitives.float)) } else if ty_id == self.primitive_ids.list { - let len: usize = helper.len_fn.call1((obj,))?.extract()?; + let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; if len == 0 { let var = unifier.get_fresh_var().0; let list = unifier.add_ty(TypeEnum::TList { ty: var }); Ok(Some(list)) } else { - let ty = self.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?; + let ty = self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; Ok(ty.map(|ty| unifier.add_ty(TypeEnum::TList { ty }))) } } else if ty_id == self.primitive_ids.tuple { let elements: &PyTuple = obj.cast_as()?; let types: Result>, _> = elements .iter() - .map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) + .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives)) .collect(); let types = types?; Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) @@ -141,7 +204,7 @@ impl Resolver { let name: String = field.0.into(); let field_data = obj.getattr(&name)?; let ty = self - .get_obj_type(field_data, helper, unifier, defs, primitives)? + .get_obj_type(py, field_data, unifier, defs, primitives)? .unwrap_or(primitives.none); let field_ty = unifier.subst(field.1, &var_map).unwrap_or(field.1); if unifier.unify(ty, field_ty).is_err() { @@ -153,7 +216,7 @@ impl Resolver { for (_, ty) in var_map.iter() { // must be concrete type if !unifier.is_concrete(*ty, &[]) { - return Ok(None) + return Ok(None); } } Ok(Some(unifier.add_ty(TypeEnum::TObj { @@ -172,14 +235,15 @@ impl Resolver { fn get_obj_value<'ctx, 'a>( &self, + py: Python, obj: &PyAny, - helper: &PythonHelper, ctx: &mut CodeGenContext<'ctx, 'a>, ) -> PyResult>> { - let ty_id: u64 = helper + let ty_id: u64 = self + .helper .id_fn - .call1((helper.type_fn.call1((obj,))?,))? - .extract()?; + .call1(py, (self.helper.type_fn.call1(py, (obj,))?,))? + .extract(py)?; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { let val: i32 = obj.extract()?; Ok(Some(ctx.ctx.i32_type().const_int(val as u64, false).into())) @@ -195,16 +259,16 @@ impl Resolver { let val: f64 = obj.extract()?; Ok(Some(ctx.ctx.f64_type().const_float(val).into())) } else if ty_id == self.primitive_ids.list { - let id: u64 = helper.id_fn.call1((obj,))?.extract()?; + let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; let id_str = id.to_string(); - let len: usize = helper.len_fn.call1((obj,))?.extract()?; + let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; let ty = if len == 0 { ctx.primitives.int32 } else { self.get_list_elem_type( + py, obj, len, - helper, &mut ctx.unifier, &ctx.top_level.definitions.read(), &ctx.primitives, @@ -236,7 +300,7 @@ impl Resolver { let arr: Result>, _> = (0..len) .map(|i| { obj.get_item(i) - .and_then(|elem| self.get_obj_value(elem, helper, ctx)) + .and_then(|elem| self.get_obj_value(py, elem, ctx)) }) .collect(); let arr = arr?.unwrap(); @@ -297,15 +361,15 @@ impl Resolver { Ok(Some(global.as_pointer_value().into())) } else if ty_id == self.primitive_ids.tuple { - let id: u64 = helper.id_fn.call1((obj,))?.extract()?; + let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; let id_str = id.to_string(); let elements: &PyTuple = obj.cast_as()?; let types: Result>, _> = elements .iter() .map(|elem| { self.get_obj_type( + py, elem, - helper, &mut ctx.unifier, &ctx.top_level.definitions.read(), &ctx.primitives, @@ -331,7 +395,7 @@ impl Resolver { let val: Result>, _> = elements .iter() - .map(|elem| self.get_obj_value(elem, helper, ctx)) + .map(|elem| self.get_obj_value(py, elem, ctx)) .collect(); let val = val?.unwrap(); let val = ctx.ctx.const_struct(&val, false); @@ -341,17 +405,11 @@ impl Resolver { global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) } else { - let id: u64 = helper.id_fn.call1((obj,))?.extract()?; + let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; let id_str = id.to_string(); let top_level_defs = ctx.top_level.definitions.read(); let ty = self - .get_obj_type( - obj, - helper, - &mut ctx.unifier, - &top_level_defs, - &ctx.primitives, - )? + .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? .unwrap(); let ty = ctx .get_llvm_type(ty) @@ -380,7 +438,7 @@ impl Resolver { let values: Result>, _> = fields .iter() .map(|(name, _, _)| { - self.get_obj_value(obj.getattr(&name.to_string())?, helper, ctx) + self.get_obj_value(py, obj.getattr(&name.to_string())?, ctx) }) .collect(); let values = values?; @@ -400,11 +458,16 @@ impl Resolver { } } - fn get_default_param_obj_value(&self, obj: &PyAny, helper: &PythonHelper) -> PyResult> { - let ty_id: u64 = helper + fn get_default_param_obj_value( + &self, + py: Python, + obj: &PyAny, + ) -> PyResult> { + let ty_id: u64 = self + .helper .id_fn - .call1((helper.type_fn.call1((obj,))?,))? - .extract()?; + .call1(py, (self.helper.type_fn.call1(py, (obj,))?,))? + .extract(py)?; Ok( if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { let val: i32 = obj.extract()?; @@ -422,60 +485,49 @@ impl Resolver { let elements: &PyTuple = obj.cast_as()?; let elements: Result, String>, _> = elements .iter() - .map(|elem| { - self.get_default_param_obj_value( - elem, - helper - ) - }) + .map(|elem| self.get_default_param_obj_value(py, elem)) .collect(); let elements = match elements? { Ok(el) => el, - Err(err) => return Ok(Err(err)) + Err(err) => return Ok(Err(err)), }; Ok(SymbolValue::Tuple(elements)) } else { Err("only primitives values and tuple can be default parameter value".into()) - } + }, ) } } impl SymbolResolver for Resolver { - fn get_default_param_value( - &self, - expr: &ast::Expr - ) -> Option { + fn get_default_param_value(&self, expr: &ast::Expr) -> Option { match &expr.node { ast::ExprKind::Name { id, .. } => { - Python::with_gil( - |py| -> PyResult> { - let obj: &PyAny = self.module.extract(py)?; - let members: &PyList = PyModule::import(py, "inspect")? - .getattr("getmembers")? - .call1((obj,))? - .cast_as()?; - let mut sym_value = None; - for member in members.iter() { - let key: &str = member.get_item(0)?.extract()?; - let val = member.get_item(1)?; - if key == id.to_string() { - let builtins = PyModule::import(py, "builtins")?; - let helper = PythonHelper { - id_fn: builtins.getattr("id").unwrap(), - len_fn: builtins.getattr("len").unwrap(), - type_fn: builtins.getattr("type").unwrap(), - }; - sym_value = Some(self.get_default_param_obj_value(val, &helper).unwrap().unwrap()); - break; - } + Python::with_gil(|py| -> PyResult> { + let obj: &PyAny = self.0.module.extract(py)?; + let members: &PyList = PyModule::import(py, "inspect")? + .getattr("getmembers")? + .call1((obj,))? + .cast_as()?; + let mut sym_value = None; + for member in members.iter() { + let key: &str = member.get_item(0)?.extract()?; + let val = member.get_item(1)?; + if key == id.to_string() { + sym_value = Some( + self.0 + .get_default_param_obj_value(py, val) + .unwrap() + .unwrap(), + ); + break; } - Ok(sym_value) } - ) + Ok(sym_value) + }) .unwrap() } - _ => unimplemented!("other type of expr not supported at {}", expr.location) + _ => unimplemented!("other type of expr not supported at {}", expr.location), } } @@ -486,13 +538,13 @@ impl SymbolResolver for Resolver { primitives: &PrimitiveStore, str: StrRef, ) -> Option { - let mut id_to_type = self.id_to_type.lock(); + let mut id_to_type = self.0.id_to_type.lock(); id_to_type.get(&str).cloned().or_else(|| { - let py_id = self.name_to_pyid.get(&str); + let py_id = self.0.name_to_pyid.get(&str); let result = py_id.and_then(|id| { - self.pyid_to_type.read().get(id).copied().or_else(|| { + self.0.pyid_to_type.read().get(id).copied().or_else(|| { Python::with_gil(|py| -> PyResult> { - let obj: &PyAny = self.module.extract(py)?; + let obj: &PyAny = self.0.module.extract(py)?; let members: &PyList = PyModule::import(py, "inspect")? .getattr("getmembers")? .call1((obj,))? @@ -501,15 +553,9 @@ impl SymbolResolver for Resolver { for member in members.iter() { let key: &str = member.get_item(0)?.extract()?; if key == str.to_string() { - let builtins = PyModule::import(py, "builtins")?; - let helper = PythonHelper { - id_fn: builtins.getattr("id").unwrap(), - len_fn: builtins.getattr("len").unwrap(), - type_fn: builtins.getattr("type").unwrap(), - }; - sym_ty = self.get_obj_type( + sym_ty = self.0.get_obj_type( + py, member.get_item(1)?, - &helper, unifier, defs, primitives, @@ -532,10 +578,10 @@ impl SymbolResolver for Resolver { fn get_symbol_value<'ctx, 'a>( &self, id: StrRef, - ctx: &mut CodeGenContext<'ctx, 'a>, - ) -> Option> { - Python::with_gil(|py| -> PyResult>> { - let obj: &PyAny = self.module.extract(py)?; + _: &mut CodeGenContext<'ctx, 'a>, + ) -> Option> { + Python::with_gil(|py| -> PyResult>> { + let obj: &PyAny = self.0.module.extract(py)?; let members: &PyList = PyModule::import(py, "inspect")? .getattr("getmembers")? .call1((obj,))? @@ -545,17 +591,16 @@ impl SymbolResolver for Resolver { let key: &str = member.get_item(0)?.extract()?; let val = member.get_item(1)?; if key == id.to_string() { - let builtins = PyModule::import(py, "builtins")?; - let helper = PythonHelper { - id_fn: builtins.getattr("id").unwrap(), - len_fn: builtins.getattr("len").unwrap(), - type_fn: builtins.getattr("type").unwrap(), - }; - sym_value = self.get_obj_value(val, &helper, ctx)?; + let id = self.0.helper.id_fn.call1(py, (val,))?.extract(py)?; + sym_value = Some(PythonValue { + id, + value: val.extract()?, + resolver: self.0.clone(), + }); break; } } - Ok(sym_value) + Ok(sym_value.map(|v| ValueEnum::Static(Arc::new(v)))) }) .unwrap() } @@ -565,10 +610,10 @@ impl SymbolResolver for Resolver { } fn get_identifier_def(&self, id: StrRef) -> Option { - let mut id_to_def = self.id_to_def.lock(); + let mut id_to_def = self.0.id_to_def.lock(); id_to_def.get(&id).cloned().or_else(|| { - let py_id = self.name_to_pyid.get(&id); - let result = py_id.and_then(|id| self.pyid_to_def.read().get(id).copied()); + let py_id = self.0.name_to_pyid.get(&id); + let result = py_id.and_then(|id| self.0.pyid_to_def.read().get(id).copied()); if let Some(result) = &result { id_to_def.insert(id, *result); } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index b267d330..dd46673e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -5,7 +5,7 @@ use crate::{ concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, get_llvm_type, CodeGenContext, CodeGenTask, }, - symbol_resolver::SymbolValue, + symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, }; @@ -15,9 +15,7 @@ use inkwell::{ AddressSpace, }; use itertools::{chain, izip, zip, Itertools}; -use nac3parser::ast::{ - self, Boolop, Comprehension, Constant, Expr, ExprKind, Operator, StrRef, -}; +use nac3parser::ast::{self, Boolop, Comprehension, Constant, Expr, ExprKind, Operator, StrRef}; use super::CodeGenerator; @@ -231,7 +229,7 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, 'a>, signature: &FunSignature, def: &TopLevelDef, - params: Vec<(Option, BasicValueEnum<'ctx>)>, + params: Vec<(Option, ValueEnum<'ctx>)>, ) -> BasicValueEnum<'ctx> { match def { TopLevelDef::Class { methods, .. } => { @@ -244,12 +242,17 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator + ?Sized>( } let ty = ctx.get_llvm_type(signature.ret).into_pointer_type(); let zelf_ty: BasicTypeEnum = ty.get_element_type().try_into().unwrap(); - let zelf = ctx.builder.build_alloca(zelf_ty, "alloca").into(); + let zelf: BasicValueEnum<'ctx> = ctx.builder.build_alloca(zelf_ty, "alloca").into(); // call `__init__` if there is one if let Some(fun_id) = fun_id { let mut sign = signature.clone(); sign.ret = ctx.primitives.none; - generator.gen_call(ctx, Some((signature.ret, zelf)), (&sign, fun_id), params); + generator.gen_call( + ctx, + Some((signature.ret, zelf.into())), + (&sign, fun_id), + params, + ); } zelf } @@ -259,8 +262,9 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator + ?Sized>( pub fn gen_func_instance<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, - obj: Option<(Type, BasicValueEnum<'ctx>)>, + obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, &mut TopLevelDef, String), + id: usize, ) -> String { if let ( sign, @@ -272,8 +276,8 @@ pub fn gen_func_instance<'ctx, 'a>( { instance_to_symbol.get(&key).cloned().unwrap_or_else(|| { let symbol = format!("{}.{}", name, instance_to_symbol.len()); - instance_to_symbol.insert(key, symbol.clone()); - let key = ctx.get_subst_key(obj.map(|a| a.0), sign, Some(var_id)); + instance_to_symbol.insert(key.clone(), symbol.clone()); + let key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), sign, Some(var_id)); let instance = instance_to_stmt.get(&key).unwrap(); let mut store = ConcreteTypeStore::new(); @@ -316,6 +320,7 @@ pub fn gen_func_instance<'ctx, 'a>( signature, store, unifier_index: instance.unifier_id, + id, }); symbol }) @@ -327,20 +332,86 @@ pub fn gen_func_instance<'ctx, 'a>( pub fn gen_call<'ctx, 'a, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, - obj: Option<(Type, BasicValueEnum<'ctx>)>, + obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), - params: Vec<(Option, BasicValueEnum<'ctx>)>, + params: Vec<(Option, ValueEnum<'ctx>)>, ) -> Option> { let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); - let key = ctx.get_subst_key(obj.map(|a| a.0), fun.0, None); + + let id; + let key; + let param_vals; let symbol = { // make sure this lock guard is dropped at the end of this scope... let def = definition.read(); match &*def { - TopLevelDef::Function { instance_to_symbol, codegen_callback, .. } => { + TopLevelDef::Function { + instance_to_symbol, + instance_to_stmt, + codegen_callback, + .. + } => { if let Some(callback) = codegen_callback { + // TODO: Change signature + let obj = obj.map(|(t, v)| (t, v.to_basic_value_enum(ctx))); + let params = params + .into_iter() + .map(|(name, val)| (name, val.to_basic_value_enum(ctx))) + .collect(); return callback.run(ctx, obj, fun, params); } + let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None); + let mut keys = fun.0.args.clone(); + let mut mapping = HashMap::new(); + for (key, value) in params.into_iter() { + mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); + } + // default value handling + for k in keys.into_iter() { + mapping.insert(k.name, ctx.gen_symbol_val(&k.default_value.unwrap()).into()); + } + // reorder the parameters + let mut real_params = + fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec(); + if let Some(obj) = &obj { + real_params.insert(0, obj.1.clone()); + } + + let static_params = real_params + .iter() + .enumerate() + .filter_map(|(i, v)| { + if let ValueEnum::Static(s) = v { + Some((i, s.clone())) + } else { + None + } + }) + .collect_vec(); + id = { + let ids = static_params + .iter() + .map(|(i, v)| (*i, v.get_unique_identifier())) + .collect_vec(); + let mut store = ctx.static_value_store.lock(); + match store.lookup.get(&ids) { + Some(index) => *index, + None => { + let length = store.store.len(); + store.lookup.insert(ids, length); + store.store.push(static_params.into_iter().collect()); + length + } + } + }; + // special case: extern functions + key = if instance_to_stmt.is_empty() { + "".to_string() + } else { + format!("{}:{}", id, old_key) + }; + param_vals = + real_params.into_iter().map(|p| p.to_basic_value_enum(ctx)).collect_vec(); instance_to_symbol.get(&key).cloned() } TopLevelDef::Class { .. } => { @@ -349,7 +420,7 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator + ?Sized>( } } .unwrap_or_else(|| { - generator.gen_func_instance(ctx, obj, (fun.0, &mut *definition.write(), key)) + generator.gen_func_instance(ctx, obj.clone(), (fun.0, &mut *definition.write(), key), id) }); let fun_val = ctx.module.get_function(&symbol).unwrap_or_else(|| { let mut args = fun.0.args.clone(); @@ -364,21 +435,7 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator + ?Sized>( }; ctx.module.add_function(&symbol, fun_ty, None) }); - let mut keys = fun.0.args.clone(); - let mut mapping = HashMap::new(); - for (key, value) in params.into_iter() { - mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); - } - // default value handling - for k in keys.into_iter() { - mapping.insert(k.name, ctx.gen_symbol_val(&k.default_value.unwrap())); - } - // reorder the parameters - let mut params = fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec(); - if let Some(obj) = obj { - params.insert(0, obj.1); - } - ctx.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left() + ctx.builder.build_call(fun_val, ¶m_vals, "call").try_as_basic_value().left() } pub fn destructure_range<'ctx, 'a>( @@ -435,7 +492,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator + ?Sized>( let cont_bb = ctx.ctx.append_basic_block(current, "cont"); let Comprehension { target, iter, ifs, .. } = &generators[0]; - let iter_val = generator.gen_expr(ctx, iter).unwrap(); + let iter_val = generator.gen_expr(ctx, iter).unwrap().to_basic_value_enum(ctx); let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); @@ -534,10 +591,11 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator + ?Sized>( ) .into_pointer_value(); let val = ctx.build_gep_and_load(arr_ptr, &[tmp]); - generator.gen_assign(ctx, target, val); + generator.gen_assign(ctx, target, val.into()); } for cond in ifs.iter() { - let result = generator.gen_expr(ctx, cond).unwrap().into_int_value(); + let result = + generator.gen_expr(ctx, cond).unwrap().to_basic_value_enum(ctx).into_int_value(); let succ = ctx.ctx.append_basic_block(current, "then"); ctx.builder.build_conditional_branch(result, succ, test_bb); ctx.builder.position_at_end(succ); @@ -545,7 +603,8 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator + ?Sized>( let elem = generator.gen_expr(ctx, elt).unwrap(); let i = ctx.builder.build_load(index, "i").into_int_value(); let elem_ptr = unsafe { ctx.builder.build_gep(list_content, &[i], "elem_ptr") }; - ctx.builder.build_store(elem_ptr, elem); + let val = elem.to_basic_value_enum(ctx); + ctx.builder.build_store(elem_ptr, val); ctx.builder .build_store(index, ctx.builder.build_int_add(i, int32.const_int(1, false), "inc")); ctx.builder.build_unconditional_branch(test_bb); @@ -562,27 +621,29 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, expr: &Expr>, -) -> Option> { +) -> Option> { let int32 = ctx.ctx.i32_type(); let zero = int32.const_int(0, false); Some(match &expr.node { ExprKind::Constant { value, .. } => { let ty = expr.custom.unwrap(); - ctx.gen_const(value, ty) + ctx.gen_const(value, ty).into() } - ExprKind::Name { id, .. } => { - let ptr = ctx.var_assignment.get(id); - if let Some(ptr) = ptr { - ctx.builder.build_load(*ptr, "load") - } else { + ExprKind::Name { id, .. } => match ctx.var_assignment.get(id) { + Some((ptr, None, _)) => ctx.builder.build_load(*ptr, "load").into(), + Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()), + None => { let resolver = ctx.resolver.clone(); resolver.get_symbol_value(*id, ctx).unwrap() } - } + }, ExprKind::List { elts, .. } => { // this shall be optimized later for constant primitive lists... // we should use memcpy for that instead of generating thousands of stores - let elements = elts.iter().map(|x| generator.gen_expr(ctx, x).unwrap()).collect_vec(); + let elements = elts + .iter() + .map(|x| generator.gen_expr(ctx, x).unwrap().to_basic_value_enum(ctx)) + .collect_vec(); let ty = if elements.is_empty() { int32.into() } else { elements[0].get_type() }; let length = int32.const_int(elements.len() as u64, false); let arr_str_ptr = allocate_list(ctx, ty, length); @@ -602,8 +663,10 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( arr_str_ptr.into() } ExprKind::Tuple { elts, .. } => { - let element_val = - elts.iter().map(|x| generator.gen_expr(ctx, x).unwrap()).collect_vec(); + let element_val = elts + .iter() + .map(|x| generator.gen_expr(ctx, x).unwrap().to_basic_value_enum(ctx)) + .collect_vec(); let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec(); let tuple_ty = ctx.ctx.struct_type(&element_ty, false); let tuple_ptr = ctx.builder.build_alloca(tuple_ty, "tuple"); @@ -621,13 +684,31 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( } ExprKind::Attribute { value, attr, .. } => { // note that we would handle class methods directly in calls - let index = ctx.get_attr_index(value.custom.unwrap(), *attr); - let ptr = generator.gen_expr(ctx, value).unwrap().into_pointer_value(); - ctx.build_gep_and_load(ptr, &[zero, int32.const_int(index as u64, false)]) + match generator.gen_expr(ctx, value).unwrap() { + ValueEnum::Static(v) => v.get_field(*attr, ctx).unwrap_or_else(|| { + let v = v.to_basic_value_enum(ctx); + let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + ValueEnum::Dynamic(ctx.build_gep_and_load( + v.into_pointer_value(), + &[zero, int32.const_int(index as u64, false)], + )) + }), + ValueEnum::Dynamic(v) => { + let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + ValueEnum::Dynamic(ctx.build_gep_and_load( + v.into_pointer_value(), + &[zero, int32.const_int(index as u64, false)], + )) + } + } } ExprKind::BoolOp { op, values } => { // requires conditional branches for short-circuiting... - let left = generator.gen_expr(ctx, &values[0]).unwrap().into_int_value(); + let left = generator + .gen_expr(ctx, &values[0]) + .unwrap() + .to_basic_value_enum(ctx) + .into_int_value(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let a_bb = ctx.ctx.append_basic_block(current, "a"); let b_bb = ctx.ctx.append_basic_block(current, "b"); @@ -639,13 +720,21 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( let a = ctx.ctx.bool_type().const_int(1, false); ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.position_at_end(b_bb); - let b = generator.gen_expr(ctx, &values[1]).unwrap().into_int_value(); + let b = generator + .gen_expr(ctx, &values[1]) + .unwrap() + .to_basic_value_enum(ctx) + .into_int_value(); ctx.builder.build_unconditional_branch(cont_bb); (a, b) } Boolop::And => { ctx.builder.position_at_end(a_bb); - let a = generator.gen_expr(ctx, &values[1]).unwrap().into_int_value(); + let a = generator + .gen_expr(ctx, &values[1]) + .unwrap() + .to_basic_value_enum(ctx) + .into_int_value(); ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.position_at_end(b_bb); let b = ctx.ctx.bool_type().const_int(0, false); @@ -656,13 +745,13 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( ctx.builder.position_at_end(cont_bb); let phi = ctx.builder.build_phi(ctx.ctx.bool_type(), "phi"); phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]); - phi.as_basic_value() + phi.as_basic_value().into() } ExprKind::BinOp { op, left, right } => { let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); - let left = generator.gen_expr(ctx, left).unwrap(); - let right = generator.gen_expr(ctx, right).unwrap(); + let left = generator.gen_expr(ctx, left).unwrap().to_basic_value_enum(ctx); + let right = generator.gen_expr(ctx, right).unwrap().to_basic_value_enum(ctx); // we can directly compare the types, because we've got their representatives // which would be unchanged until further unification, which we would never do @@ -674,10 +763,11 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( } else { unimplemented!() } + .into() } ExprKind::UnaryOp { op, operand } => { let ty = ctx.unifier.get_representative(operand.custom.unwrap()); - let val = generator.gen_expr(ctx, operand).unwrap(); + let val = generator.gen_expr(ctx, operand).unwrap().to_basic_value_enum(ctx); if ty == ctx.primitives.bool { let val = val.into_int_value(); match op { @@ -734,8 +824,8 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs), ) = ( - generator.gen_expr(ctx, lhs).unwrap(), - generator.gen_expr(ctx, rhs).unwrap(), + generator.gen_expr(ctx, lhs).unwrap().to_basic_value_enum(ctx), + generator.gen_expr(ctx, rhs).unwrap().to_basic_value_enum(ctx), ) { (lhs, rhs) } else { @@ -756,8 +846,8 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs), ) = ( - generator.gen_expr(ctx, lhs).unwrap(), - generator.gen_expr(ctx, rhs).unwrap(), + generator.gen_expr(ctx, lhs).unwrap().to_basic_value_enum(ctx), + generator.gen_expr(ctx, rhs).unwrap().to_basic_value_enum(ctx), ) { (lhs, rhs) } else { @@ -782,22 +872,23 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( .into() // as there should be at least 1 element, it should never be none } ExprKind::IfExp { test, body, orelse } => { - let test = generator.gen_expr(ctx, test).unwrap().into_int_value(); + let test = + generator.gen_expr(ctx, test).unwrap().to_basic_value_enum(ctx).into_int_value(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let then_bb = ctx.ctx.append_basic_block(current, "then"); let else_bb = ctx.ctx.append_basic_block(current, "else"); let cont_bb = ctx.ctx.append_basic_block(current, "cont"); ctx.builder.build_conditional_branch(test, then_bb, else_bb); ctx.builder.position_at_end(then_bb); - let a = generator.gen_expr(ctx, body).unwrap(); + let a = generator.gen_expr(ctx, body).unwrap().to_basic_value_enum(ctx); ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.position_at_end(else_bb); - let b = generator.gen_expr(ctx, orelse).unwrap(); + let b = generator.gen_expr(ctx, orelse).unwrap().to_basic_value_enum(ctx); ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.position_at_end(cont_bb); let phi = ctx.builder.build_phi(a.get_type(), "ifexpr"); phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]); - phi.as_basic_value() + phi.as_basic_value().into() } ExprKind::Call { func, args, keywords } => { let mut params = @@ -825,7 +916,9 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( ExprKind::Name { id, .. } => { // TODO: handle primitive casts and function pointers let fun = ctx.resolver.get_identifier_def(*id).expect("Unknown identifier"); - return generator.gen_call(ctx, None, (&signature, fun), params); + return generator + .gen_call(ctx, None, (&signature, fun), params) + .map(|v| v.into()); } ExprKind::Attribute { value, attr, .. } => { let val = generator.gen_expr(ctx, value).unwrap(); @@ -851,12 +944,14 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( unreachable!() } }; - return generator.gen_call( - ctx, - Some((value.custom.unwrap(), val)), - (&signature, fun_id), - params, - ); + return generator + .gen_call( + ctx, + Some((value.custom.unwrap(), val)), + (&signature, fun_id), + params, + ) + .map(|v| v.into()); } _ => unimplemented!(), } @@ -867,19 +962,36 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator + ?Sized>( unimplemented!() } else { // TODO: bound check - let v = generator.gen_expr(ctx, value).unwrap().into_pointer_value(); - let index = generator.gen_expr(ctx, slice).unwrap().into_int_value(); + let v = generator + .gen_expr(ctx, value) + .unwrap() + .to_basic_value_enum(ctx) + .into_pointer_value(); + let index = generator + .gen_expr(ctx, slice) + .unwrap() + .to_basic_value_enum(ctx) + .into_int_value(); let arr_ptr = ctx.build_gep_and_load(v, &[int32.const_zero(), int32.const_int(1, false)]); ctx.build_gep_and_load(arr_ptr.into_pointer_value(), &[index]) } } else { - let v = generator.gen_expr(ctx, value).unwrap().into_pointer_value(); - let index = generator.gen_expr(ctx, slice).unwrap().into_int_value(); + let v = generator + .gen_expr(ctx, value) + .unwrap() + .to_basic_value_enum(ctx) + .into_pointer_value(); + let index = generator + .gen_expr(ctx, slice) + .unwrap() + .to_basic_value_enum(ctx) + .into_int_value(); ctx.build_gep_and_load(v, &[int32.const_zero(), index]) } } - ExprKind::ListComp { .. } => gen_comprehension(generator, ctx, expr), + .into(), + ExprKind::ListComp { .. } => gen_comprehension(generator, ctx, expr).into(), _ => unimplemented!(), }) } diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index 0f5bd3a0..cf66967e 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -1,5 +1,6 @@ use crate::{ codegen::{expr::*, stmt::*, CodeGenContext}, + symbol_resolver::ValueEnum, toplevel::{DefinitionId, TopLevelDef}, typecheck::typedef::{FunSignature, Type}, }; @@ -18,9 +19,9 @@ pub trait CodeGenerator { fn gen_call<'ctx, 'a>( &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, - obj: Option<(Type, BasicValueEnum<'ctx>)>, + obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), - params: Vec<(Option, BasicValueEnum<'ctx>)>, + params: Vec<(Option, ValueEnum<'ctx>)>, ) -> Option> { gen_call(self, ctx, obj, fun, params) } @@ -34,7 +35,7 @@ pub trait CodeGenerator { ctx: &mut CodeGenContext<'ctx, 'a>, signature: &FunSignature, def: &TopLevelDef, - params: Vec<(Option, BasicValueEnum<'ctx>)>, + params: Vec<(Option, ValueEnum<'ctx>)>, ) -> BasicValueEnum<'ctx> { gen_constructor(self, ctx, signature, def, params) } @@ -49,10 +50,11 @@ pub trait CodeGenerator { fn gen_func_instance<'ctx, 'a>( &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, - obj: Option<(Type, BasicValueEnum<'ctx>)>, + obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, &mut TopLevelDef, String), + id: usize, ) -> String { - gen_func_instance(ctx, obj, fun) + gen_func_instance(ctx, obj, fun, id) } /// Generate the code for an expression. @@ -60,7 +62,7 @@ pub trait CodeGenerator { &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, expr: &Expr>, - ) -> Option> { + ) -> Option> { gen_expr(self, ctx, expr) } @@ -88,7 +90,7 @@ pub trait CodeGenerator { &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, target: &Expr>, - value: BasicValueEnum<'ctx>, + value: ValueEnum<'ctx>, ) { gen_assign(self, ctx, target, value) } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index b6c94f93..b01aaf66 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,5 +1,5 @@ use crate::{ - symbol_resolver::SymbolResolver, + symbol_resolver::{StaticValue, SymbolResolver}, toplevel::{TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, @@ -18,8 +18,8 @@ use inkwell::{ AddressSpace, OptimizationLevel, }; use itertools::Itertools; -use parking_lot::{Condvar, Mutex}; use nac3parser::ast::{Stmt, StrRef}; +use parking_lot::{Condvar, Mutex}; use std::collections::HashMap; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -29,8 +29,8 @@ use std::thread; pub mod concrete_type; pub mod expr; -pub mod stmt; mod generator; +pub mod stmt; #[cfg(test)] mod test; @@ -38,6 +38,14 @@ mod test; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; +#[derive(Default)] +pub struct StaticValueStore { + pub lookup: HashMap, usize>, + pub store: Vec>>, +} + +pub type VarValue<'ctx> = (PointerValue<'ctx>, Option>, i64); + pub struct CodeGenContext<'ctx, 'a> { pub ctx: &'ctx Context, pub builder: Builder<'ctx>, @@ -45,7 +53,8 @@ pub struct CodeGenContext<'ctx, 'a> { pub top_level: &'a TopLevelContext, pub unifier: Unifier, pub resolver: Arc, - pub var_assignment: HashMap>, + pub static_value_store: Arc>, + pub var_assignment: HashMap>, pub type_cache: HashMap>, pub primitives: PrimitiveStore, pub calls: Arc>, @@ -80,6 +89,8 @@ pub struct WorkerRegistry { task_count: Mutex, thread_count: usize, wait_condvar: Condvar, + top_level_ctx: Arc, + static_value_store: Arc>, } impl WorkerRegistry { @@ -92,23 +103,29 @@ impl WorkerRegistry { let task_count = Mutex::new(0); let wait_condvar = Condvar::new(); + // init: 0 to be empty + let mut static_value_store: StaticValueStore = Default::default(); + static_value_store.lookup.insert(Default::default(), 0); + static_value_store.store.push(Default::default()); + let registry = Arc::new(WorkerRegistry { sender: Arc::new(sender), receiver: Arc::new(receiver), thread_count: generators.len(), panicked: AtomicBool::new(false), + static_value_store: Arc::new(Mutex::new(static_value_store)), task_count, wait_condvar, + top_level_ctx, }); let mut handles = Vec::new(); for mut generator in generators.into_iter() { - let top_level_ctx = top_level_ctx.clone(); let registry = registry.clone(); let registry2 = registry.clone(); let f = f.clone(); let handle = thread::spawn(move || { - registry.worker_thread(generator.as_mut(), top_level_ctx, f); + registry.worker_thread(generator.as_mut(), f); }); let handle = thread::spawn(move || { if let Err(e) = handle.join() { @@ -161,12 +178,7 @@ impl WorkerRegistry { self.sender.send(Some(task)).unwrap(); } - fn worker_thread( - &self, - generator: &mut G, - top_level_ctx: Arc, - f: Arc, - ) { + fn worker_thread(&self, generator: &mut G, f: Arc) { let context = Context::create(); let mut builder = context.create_builder(); let mut module = context.create_module(generator.get_name()); @@ -177,8 +189,7 @@ impl WorkerRegistry { pass_builder.populate_function_pass_manager(&passes); while let Some(task) = self.receiver.recv().unwrap() { - let result = - gen_func(&context, generator, self, builder, module, task, top_level_ctx.clone()); + let result = gen_func(&context, generator, self, builder, module, task); builder = result.0; module = result.1; passes.run_on(&result.2); @@ -208,6 +219,7 @@ pub struct CodeGenTask { pub calls: Arc>, pub unifier_index: usize, pub resolver: Arc, + pub id: usize, } fn get_llvm_type<'ctx>( @@ -268,8 +280,9 @@ pub fn gen_func<'ctx, G: CodeGenerator + ?Sized>( builder: Builder<'ctx>, module: Module<'ctx>, task: CodeGenTask, - top_level_ctx: Arc, ) -> (Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>) { + let top_level_ctx = registry.top_level_ctx.clone(); + let static_value_store = registry.static_value_store.clone(); let (mut unifier, primitives) = { let (unifier, primitives) = &top_level_ctx.unifiers.read()[task.unifier_index]; (Unifier::from_shared_unifier(unifier), *primitives) @@ -306,7 +319,10 @@ pub fn gen_func<'ctx, G: CodeGenerator + ?Sized>( (unifier.get_representative(primitives.int64), context.i64_type().into()), (unifier.get_representative(primitives.float), context.f64_type().into()), (unifier.get_representative(primitives.bool), context.bool_type().into()), - (unifier.get_representative(primitives.str), context.i8_type().ptr_type(AddressSpace::Generic).into()), + ( + unifier.get_representative(primitives.str), + context.i8_type().ptr_type(AddressSpace::Generic).into(), + ), ] .iter() .cloned() @@ -366,8 +382,17 @@ pub fn gen_func<'ctx, G: CodeGenerator + ?Sized>( &arg.name.to_string(), ); builder.build_store(alloca, param); - var_assignment.insert(arg.name, alloca); + var_assignment.insert(arg.name, (alloca, None, 0)); } + let static_values = { + let store = registry.static_value_store.lock(); + store.store[task.id].clone() + }; + for (k, v) in static_values.into_iter() { + let (_, static_val, _) = var_assignment.get_mut(&args[k].name).unwrap(); + *static_val = Some(v); + } + builder.build_unconditional_branch(body_bb); builder.position_at_end(body_bb); @@ -385,6 +410,7 @@ pub fn gen_func<'ctx, G: CodeGenerator + ?Sized>( builder, module, unifier, + static_value_store, }; let mut returned = false; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 1b00f456..cbad16d7 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1,4 +1,4 @@ -use super::{expr::destructure_range, CodeGenContext, CodeGenerator}; +use super::{expr::destructure_range, CodeGenContext, CodeGenerator, super::symbol_resolver::ValueEnum}; use crate::typecheck::typedef::Type; use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; use nac3parser::ast::{Expr, ExprKind, Stmt, StmtKind}; @@ -22,14 +22,14 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator + ?Sized>( // very similar to gen_expr, but we don't do an extra load at the end // and we flatten nested tuples match &pattern.node { - ExprKind::Name { id, .. } => ctx.var_assignment.get(id).cloned().unwrap_or_else(|| { + ExprKind::Name { id, .. } => ctx.var_assignment.get(id).map(|v| v.0).unwrap_or_else(|| { let ptr = generator.gen_var_alloc(ctx, pattern.custom.unwrap()); - ctx.var_assignment.insert(*id, ptr); + ctx.var_assignment.insert(*id, (ptr, None, 0)); ptr }), ExprKind::Attribute { value, attr, .. } => { let index = ctx.get_attr_index(value.custom.unwrap(), *attr); - let val = generator.gen_expr(ctx, value).unwrap(); + let val = generator.gen_expr(ctx, value).unwrap().to_basic_value_enum(ctx); let ptr = if let BasicValueEnum::PointerValue(v) = val { v } else { @@ -48,8 +48,13 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator + ?Sized>( } ExprKind::Subscript { value, slice, .. } => { let i32_type = ctx.ctx.i32_type(); - let v = generator.gen_expr(ctx, value).unwrap().into_pointer_value(); - let index = generator.gen_expr(ctx, slice).unwrap().into_int_value(); + let v = generator + .gen_expr(ctx, value) + .unwrap() + .to_basic_value_enum(ctx) + .into_pointer_value(); + let index = + generator.gen_expr(ctx, slice).unwrap().to_basic_value_enum(ctx).into_int_value(); unsafe { let arr_ptr = ctx .build_gep_and_load(v, &[i32_type.const_zero(), i32_type.const_int(1, false)]) @@ -65,24 +70,32 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, target: &Expr>, - value: BasicValueEnum<'ctx>, + value: ValueEnum<'ctx>, ) { - let i32_type = ctx.ctx.i32_type(); if let ExprKind::Tuple { elts, .. } = &target.node { - if let BasicValueEnum::PointerValue(ptr) = value { + if let BasicValueEnum::PointerValue(ptr) = value.to_basic_value_enum(ctx) { + let i32_type = ctx.ctx.i32_type(); for (i, elt) in elts.iter().enumerate() { let v = ctx.build_gep_and_load( ptr, &[i32_type.const_zero(), i32_type.const_int(i as u64, false)], ); - generator.gen_assign(ctx, elt, v); + generator.gen_assign(ctx, elt, v.into()); } } else { unreachable!() } } else { let ptr = generator.gen_store_target(ctx, target); - ctx.builder.build_store(ptr, value); + if let ExprKind::Name { id, .. } = &target.node { + let (_, static_value, counter) = ctx.var_assignment.get_mut(id).unwrap(); + *counter += 1; + if let ValueEnum::Static(s) = &value { + *static_value = Some(s.clone()); + } + } + let val = value.to_basic_value_enum(ctx); + ctx.builder.build_store(ptr, val); } } @@ -92,6 +105,10 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator + ?Sized>( stmt: &Stmt>, ) { if let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node { + // var_assignment static values may be changed in another branch + // if so, remove the static value as it may not be correct in this branch + let var_assignment = ctx.var_assignment.clone(); + let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); @@ -104,7 +121,7 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator + ?Sized>( // store loop bb information and restore it later let loop_bb = ctx.loop_bb.replace((test_bb, cont_bb)); - let iter_val = generator.gen_expr(ctx, iter).unwrap(); + let iter_val = generator.gen_expr(ctx, iter).unwrap().to_basic_value_enum(ctx); if ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range) { // setup let iter_val = iter_val.into_pointer_value(); @@ -160,12 +177,18 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator + ?Sized>( ) .into_pointer_value(); let val = ctx.build_gep_and_load(arr_ptr, &[tmp]); - generator.gen_assign(ctx, target, val); + generator.gen_assign(ctx, target, val.into()); } for stmt in body.iter() { generator.gen_stmt(ctx, stmt); } + for (k, (_, _, counter)) in var_assignment.iter() { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } + } ctx.builder.build_unconditional_branch(test_bb); if !orelse.is_empty() { ctx.builder.position_at_end(orelse_bb); @@ -174,6 +197,12 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator + ?Sized>( } ctx.builder.build_unconditional_branch(cont_bb); } + for (k, (_, _, counter)) in var_assignment.iter() { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } + } ctx.builder.position_at_end(cont_bb); ctx.loop_bb = loop_bb; } else { @@ -187,6 +216,10 @@ pub fn gen_while<'ctx, 'a, G: CodeGenerator + ?Sized>( stmt: &Stmt>, ) { if let StmtKind::While { test, body, orelse, .. } = &stmt.node { + // var_assignment static values may be changed in another branch + // if so, remove the static value as it may not be correct in this branch + let var_assignment = ctx.var_assignment.clone(); + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let test_bb = ctx.ctx.append_basic_block(current, "test"); let body_bb = ctx.ctx.append_basic_block(current, "body"); @@ -198,7 +231,7 @@ pub fn gen_while<'ctx, 'a, G: CodeGenerator + ?Sized>( let loop_bb = ctx.loop_bb.replace((test_bb, cont_bb)); ctx.builder.build_unconditional_branch(test_bb); ctx.builder.position_at_end(test_bb); - let test = generator.gen_expr(ctx, test).unwrap(); + let test = generator.gen_expr(ctx, test).unwrap().to_basic_value_enum(ctx); if let BasicValueEnum::IntValue(test) = test { ctx.builder.build_conditional_branch(test, body_bb, orelse_bb); } else { @@ -208,6 +241,12 @@ pub fn gen_while<'ctx, 'a, G: CodeGenerator + ?Sized>( for stmt in body.iter() { generator.gen_stmt(ctx, stmt); } + for (k, (_, _, counter)) in var_assignment.iter() { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } + } ctx.builder.build_unconditional_branch(test_bb); if !orelse.is_empty() { ctx.builder.position_at_end(orelse_bb); @@ -216,6 +255,12 @@ pub fn gen_while<'ctx, 'a, G: CodeGenerator + ?Sized>( } ctx.builder.build_unconditional_branch(cont_bb); } + for (k, (_, _, counter)) in var_assignment.iter() { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } + } ctx.builder.position_at_end(cont_bb); ctx.loop_bb = loop_bb; } else { @@ -229,6 +274,10 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator + ?Sized>( stmt: &Stmt>, ) -> bool { if let StmtKind::If { test, body, orelse, .. } = &stmt.node { + // var_assignment static values may be changed in another branch + // if so, remove the static value as it may not be correct in this branch + let var_assignment = ctx.var_assignment.clone(); + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let test_bb = ctx.ctx.append_basic_block(current, "test"); let body_bb = ctx.ctx.append_basic_block(current, "body"); @@ -242,7 +291,7 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator + ?Sized>( }; ctx.builder.build_unconditional_branch(test_bb); ctx.builder.position_at_end(test_bb); - let test = generator.gen_expr(ctx, test).unwrap(); + let test = generator.gen_expr(ctx, test).unwrap().to_basic_value_enum(ctx); if let BasicValueEnum::IntValue(test) = test { ctx.builder.build_conditional_branch(test, body_bb, orelse_bb); } else { @@ -256,6 +305,13 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator + ?Sized>( break; } } + for (k, (_, _, counter)) in var_assignment.iter() { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } + } + if !exited { if cont_bb.is_none() { cont_bb = Some(ctx.ctx.append_basic_block(current, "cont")); @@ -285,6 +341,12 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator + ?Sized>( if let Some(cont_bb) = cont_bb { ctx.builder.position_at_end(cont_bb); } + for (k, (_, _, counter)) in var_assignment.iter() { + let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); + if counter != counter2 { + *static_val = None; + } + } then_exited && else_exited } else { unreachable!() @@ -306,12 +368,14 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>( stmt: &Stmt>, ) -> bool { match &stmt.node { - StmtKind::Pass { .. } => {} + StmtKind::Pass { .. } => {} StmtKind::Expr { value, .. } => { generator.gen_expr(ctx, value); } StmtKind::Return { value, .. } => { - let value = value.as_ref().map(|v| generator.gen_expr(ctx, v).unwrap()); + let value = value + .as_ref() + .map(|v| generator.gen_expr(ctx, v).unwrap().to_basic_value_enum(ctx)); let value = value.as_ref().map(|v| v as &dyn BasicValue); ctx.builder.build_return(value); return true; @@ -325,14 +389,14 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>( StmtKind::Assign { targets, value, .. } => { let value = generator.gen_expr(ctx, value).unwrap(); for target in targets.iter() { - generator.gen_assign(ctx, target, value); + generator.gen_assign(ctx, target, value.clone()); } } StmtKind::Continue { .. } => { ctx.builder.build_unconditional_branch(ctx.loop_bb.unwrap().0); return true; } - StmtKind::Break { .. }=> { + StmtKind::Break { .. } => { ctx.builder.build_unconditional_branch(ctx.loop_bb.unwrap().1); return true; } @@ -344,8 +408,8 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>( let value = { let ty1 = ctx.unifier.get_representative(target.custom.unwrap()); let ty2 = ctx.unifier.get_representative(value.custom.unwrap()); - let left = generator.gen_expr(ctx, target).unwrap(); - let right = generator.gen_expr(ctx, value).unwrap(); + let left = generator.gen_expr(ctx, target).unwrap().to_basic_value_enum(ctx); + let right = generator.gen_expr(ctx, value).unwrap().to_basic_value_enum(ctx); // we can directly compare the types, because we've got their representatives // which would be unchanged until further unification, which we would never do @@ -358,7 +422,7 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>( unimplemented!() } }; - generator.gen_assign(ctx, target, value); + generator.gen_assign(ctx, target, value.into()); } _ => unimplemented!(), }; diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index f9e2810f..2bbc75ed 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -4,7 +4,7 @@ use crate::{ WithCall, WorkerRegistry, }, location::Location, - symbol_resolver::SymbolResolver, + symbol_resolver::{SymbolResolver, ValueEnum}, toplevel::{ composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef, }, @@ -14,12 +14,11 @@ use crate::{ }, }; use indoc::indoc; -use inkwell::values::BasicValueEnum; -use parking_lot::RwLock; use nac3parser::{ ast::{fold::Fold, StrRef}, parser::parse_program, }; +use parking_lot::RwLock; use std::cell::RefCell; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -55,7 +54,7 @@ impl SymbolResolver for Resolver { &self, _: StrRef, _: &mut CodeGenContext<'ctx, 'a>, - ) -> Option> { + ) -> Option> { unimplemented!() } @@ -147,6 +146,7 @@ fn test_primitives() { resolver, store, signature, + id: 0, }; let f = Arc::new(WithCall::new(Box::new(|module| { // the following IR is equivalent to @@ -314,6 +314,7 @@ fn test_simple_call() { resolver, signature, store, + id: 0, }; let f = Arc::new(WithCall::new(Box::new(|module| { let expected = indoc! {" diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 516260a8..b2133fe2 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -11,7 +11,7 @@ use crate::{ toplevel::{DefinitionId, TopLevelDef}, }; use crate::{location::Location, typecheck::typedef::TypeEnum}; -use inkwell::values::BasicValueEnum; +use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue}; use itertools::{chain, izip}; use nac3parser::ast::{Expr, StrRef}; use parking_lot::RwLock; @@ -23,8 +23,63 @@ pub enum SymbolValue { Double(f64), Bool(bool), Tuple(Vec), - // we should think about how to implement bytes later... - // Bytes(&'a [u8]), +} + +pub trait StaticValue { + fn get_unique_identifier(&self) -> u64; + + fn to_basic_value_enum<'ctx, 'a>( + &self, + ctx: &mut CodeGenContext<'ctx, 'a>, + ) -> BasicValueEnum<'ctx>; + + fn get_field<'ctx, 'a>( + &self, + name: StrRef, + ctx: &mut CodeGenContext<'ctx, 'a>, + ) -> Option>; +} + +#[derive(Clone)] +pub enum ValueEnum<'ctx> { + Static(Arc), + Dynamic(BasicValueEnum<'ctx>), +} + +impl<'ctx> From> for ValueEnum<'ctx> { + fn from(v: BasicValueEnum<'ctx>) -> Self { + ValueEnum::Dynamic(v) + } +} + +impl<'ctx> From> for ValueEnum<'ctx> { + fn from(v: PointerValue<'ctx>) -> Self { + ValueEnum::Dynamic(v.into()) + } +} + +impl<'ctx> From> for ValueEnum<'ctx> { + fn from(v: IntValue<'ctx>) -> Self { + ValueEnum::Dynamic(v.into()) + } +} + +impl<'ctx> From> for ValueEnum<'ctx> { + fn from(v: FloatValue<'ctx>) -> Self { + ValueEnum::Dynamic(v.into()) + } +} + +impl<'ctx> ValueEnum<'ctx> { + pub fn to_basic_value_enum<'a>( + self, + ctx: &mut CodeGenContext<'ctx, 'a>, + ) -> BasicValueEnum<'ctx> { + match self { + ValueEnum::Static(v) => v.to_basic_value_enum(ctx), + ValueEnum::Dynamic(v) => v, + } + } } pub trait SymbolResolver { @@ -36,13 +91,16 @@ pub trait SymbolResolver { primitives: &PrimitiveStore, str: StrRef, ) -> Option; + // get the top-level definition of identifiers fn get_identifier_def(&self, str: StrRef) -> Option; + fn get_symbol_value<'ctx, 'a>( &self, str: StrRef, ctx: &mut CodeGenContext<'ctx, 'a>, - ) -> Option>; + ) -> Option>; + fn get_symbol_location(&self, str: StrRef) -> Option; fn get_default_param_value(&self, expr: &nac3parser::ast::Expr) -> Option; // handle function call etc. diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 9fcd313f..ebbfe6cc 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1,7 +1,7 @@ use std::cell::RefCell; use nac3parser::ast::fold::Fold; -use inkwell::FloatPredicate; +use inkwell::{FloatPredicate, IntPredicate}; use crate::{ symbol_resolver::SymbolValue, @@ -195,7 +195,7 @@ impl TopLevelComposer { signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { args: vec![FuncArg { name: "_".into(), ty: num_ty.0, default_value: None }], ret: float, - vars: var_map, + vars: var_map.clone(), }))), var_id: Default::default(), instance_to_symbol: Default::default(), @@ -398,7 +398,7 @@ impl TopLevelComposer { signature: primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { args: vec![FuncArg { name: "_".into(), ty: num_ty.0, default_value: None }], ret: primitives.0.bool, - vars: Default::default(), + vars: var_map, }))), var_id: Default::default(), instance_to_symbol: Default::default(), @@ -415,15 +415,12 @@ impl TopLevelComposer { if ctx.unifier.unioned(arg_ty, boolean) { Some(arg) } else if ctx.unifier.unioned(arg_ty, int32) || ctx.unifier.unioned(arg_ty, int64) { - Some( - ctx.builder - .build_int_truncate( - arg.into_int_value(), - ctx.ctx.bool_type(), - "trunc", - ) - .into(), - ) + Some(ctx.builder.build_int_compare( + IntPredicate::NE, + ctx.ctx.i64_type().const_zero(), + arg.into_int_value(), + "bool", + ).into()) } else if ctx.unifier.unioned(arg_ty, float) { let val = ctx.builder. build_float_compare( @@ -1106,7 +1103,7 @@ impl TopLevelComposer { )); } } - + let arg_with_default: Vec<(&ast::Located>, Option<&ast::Expr>)> = args .args .iter() @@ -1349,7 +1346,7 @@ impl TopLevelComposer { } let mut result = Vec::new(); - + let arg_with_default: Vec<(&ast::Located>, Option<&ast::Expr>)> = args .args .iter() @@ -1361,7 +1358,7 @@ impl TopLevelComposer { .map(|x| -> Option<&ast::Expr> { Some(x) }) .chain(std::iter::repeat(None)) ).collect_vec(); - + for (x, default) in arg_with_default.into_iter().rev() { let name = x.node.arg; if name != zelf { diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 3aed02b2..1fcfa6ad 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -1,7 +1,7 @@ use crate::{ codegen::CodeGenContext, location::Location, - symbol_resolver::SymbolResolver, + symbol_resolver::{SymbolResolver, ValueEnum}, toplevel::DefinitionId, typecheck::{ type_inferencer::PrimitiveStore, @@ -54,7 +54,7 @@ impl SymbolResolver for Resolver { &self, _: StrRef, _: &mut CodeGenContext<'ctx, 'a>, - ) -> Option> { + ) -> Option> { unimplemented!() } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index da51aa9d..cc209bba 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -3,10 +3,10 @@ use super::*; use crate::{ codegen::CodeGenContext, location::Location, + symbol_resolver::ValueEnum, toplevel::{DefinitionId, TopLevelDef}, }; use indoc::indoc; -use inkwell::values::BasicValueEnum; use itertools::zip; use nac3parser::parser::parse_program; use parking_lot::RwLock; @@ -37,7 +37,7 @@ impl SymbolResolver for Resolver { &self, _: StrRef, _: &mut CodeGenContext<'ctx, 'a>, - ) -> Option> { + ) -> Option> { unimplemented!() } diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 9983d507..a32432e9 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -1,16 +1,15 @@ -use inkwell::values::BasicValueEnum; use nac3core::{ codegen::CodeGenContext, location::Location, - symbol_resolver::{SymbolResolver, SymbolValue}, + symbol_resolver::{SymbolResolver, SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, Unifier}, }, }; -use parking_lot::{Mutex, RwLock}; use nac3parser::ast::{self, StrRef}; +use parking_lot::{Mutex, RwLock}; use std::{collections::HashMap, sync::Arc}; pub struct ResolverInternal { @@ -64,7 +63,7 @@ impl SymbolResolver for Resolver { &self, _: StrRef, _: &mut CodeGenContext<'ctx, 'a>, - ) -> Option> { + ) -> Option> { unimplemented!() } diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index f6c4be68..85b39325 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -183,6 +183,7 @@ fn main() { store, unifier_index: instance.unifier_id, calls: instance.calls, + id: 0, }; let f = Arc::new(WithCall::new(Box::new(move |module| { let builder = PassManagerBuilder::create();