diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 0af1bfcd..813f7984 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -1,12 +1,10 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fs; use std::path::Path; use std::process::Command; use std::sync::Arc; use inkwell::{ - AddressSpace, AtomicOrdering, - values::BasicValueEnum, passes::{PassManager, PassManagerBuilder}, targets::*, OptimizationLevel, @@ -18,13 +16,13 @@ use rustpython_parser::{ parser, }; -use parking_lot::RwLock; +use parking_lot::{RwLock, Mutex}; use nac3core::{ codegen::{CodeGenTask, WithCall, WorkerRegistry}, symbol_resolver::SymbolResolver, - toplevel::{composer::TopLevelComposer, TopLevelContext, TopLevelDef, GenCall}, - typecheck::typedef::{FunSignature, FuncArg}, + toplevel::{composer::TopLevelComposer, TopLevelContext, TopLevelDef}, + typecheck::typedef::FunSignature, }; use nac3core::{ toplevel::DefinitionId, @@ -33,6 +31,7 @@ use nac3core::{ use crate::symbol_resolver::Resolver; +mod builtins; mod symbol_resolver; #[derive(PartialEq, Clone, Copy)] @@ -41,6 +40,17 @@ enum Isa { CortexA9, } +#[derive(Clone)] +pub struct PrimitivePythonId { + int: u64, + int32: u64, + int64: u64, + float: u64, + bool: u64, + list: u64, + tuple: u64, +} + // TopLevelComposer is unsendable as it holds the unification table, which is // unsendable due to Rc. Arc would cause a performance hit. #[pyclass(unsendable, name = "NAC3")] @@ -54,6 +64,8 @@ struct Nac3 { composer: TopLevelComposer, top_level: Option>, to_be_registered: Vec, + primitive_ids: PrimitivePythonId, + global_value_ids: Arc>>, } impl Nac3 { @@ -81,16 +93,18 @@ impl Nac3 { let source = fs::read_to_string(source_file).map_err(|e| { exceptions::PyIOError::new_err(format!("failed to read input file: {}", e)) })?; - let parser_result = parser::parse_program(&source).map_err(|e| { - exceptions::PySyntaxError::new_err(format!("parse error: {}", e)) - })?; + let parser_result = parser::parse_program(&source) + .map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {}", e)))?; let resolver = Arc::new(Box::new(Resolver { id_to_type: self.builtins_ty.clone().into(), id_to_def: self.builtins_def.clone().into(), pyid_to_def: self.pyid_to_def.clone(), pyid_to_type: self.pyid_to_type.clone(), + primitive_ids: self.primitive_ids.clone(), + global_value_ids: self.global_value_ids.clone(), class_names: Default::default(), name_to_pyid: name_to_pyid.clone(), + module: obj, }) as Box); let mut name_to_def = HashMap::new(); let mut name_to_type = HashMap::new(); @@ -163,128 +177,76 @@ impl Nac3 { } } -// ARTIQ timeline control with now-pinning optimization. -fn timeline_builtins(primitive: &PrimitiveStore) -> Vec<(StrRef, FunSignature, Arc)> { - vec![( - "now_mu".into(), - FunSignature { - args: vec![], - ret: primitive.int64, - vars: HashMap::new(), - }, - Arc::new(GenCall::new(Box::new( - |ctx, _, _, _| { - let i64_type = ctx.ctx.i64_type(); - let now = ctx.module.get_global("now").unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now"); - if let BasicValueEnum::IntValue(now_raw) = now_raw { - let i64_32 = i64_type.const_int(32, false).into(); - let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl"); - let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now_lshr").into(); - Some(ctx.builder.build_or(now_lo, now_hi, "now_or").into()) - } else { - unreachable!() - } - } - ))) - ),( - "at_mu".into(), - FunSignature { - args: vec![FuncArg { - name: "t".into(), - ty: primitive.int64, - default_value: None, - }], - ret: primitive.none, - vars: HashMap::new(), - }, - Arc::new(GenCall::new(Box::new( - |ctx, _, _, args| { - let i32_type = ctx.ctx.i32_type(); - let i64_type = ctx.ctx.i64_type(); - let i64_32 = i64_type.const_int(32, false).into(); - if let BasicValueEnum::IntValue(time) = args[0].1 { - let time_hi = ctx.builder.build_int_truncate(ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"), i32_type, "now_trunc"); - let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc"); - let now = ctx.module.get_global("now").unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_hiptr = ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::Generic), "now_bitcast"); - if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { - let now_loptr = unsafe { ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false).into()], "now_gep") }; - ctx.builder.build_store(now_hiptr, time_hi).set_atomic_ordering(AtomicOrdering::SequentiallyConsistent).unwrap(); - ctx.builder.build_store(now_loptr, time_lo).set_atomic_ordering(AtomicOrdering::SequentiallyConsistent).unwrap(); - None - } else { - unreachable!(); - } - } else { - unreachable!(); - } - } - ))) - ),( - "delay_mu".into(), - FunSignature { - args: vec![FuncArg { - name: "dt".into(), - ty: primitive.int64, - default_value: None, - }], - ret: primitive.none, - vars: HashMap::new(), - }, - Arc::new(GenCall::new(Box::new( - |ctx, _, _, args| { - let i32_type = ctx.ctx.i32_type(); - let i64_type = ctx.ctx.i64_type(); - let i64_32 = i64_type.const_int(32, false).into(); - let now = ctx.module.get_global("now").unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now"); - if let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, args[0].1) { - let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl"); - let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now_lshr").into(); - let now_val = ctx.builder.build_or(now_lo, now_hi, "now_or"); - let time = ctx.builder.build_int_add(now_val, dt, "now_add"); - let time_hi = ctx.builder.build_int_truncate(ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"), i32_type, "now_trunc"); - let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc"); - let now_hiptr = ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::Generic), "now_bitcast"); - if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { - let now_loptr = unsafe { ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false).into()], "now_gep") }; - ctx.builder.build_store(now_hiptr, time_hi).set_atomic_ordering(AtomicOrdering::SequentiallyConsistent).unwrap(); - ctx.builder.build_store(now_loptr, time_lo).set_atomic_ordering(AtomicOrdering::SequentiallyConsistent).unwrap(); - None - } else { - unreachable!(); - } - } else { - unreachable!(); - } - } - ))) - )] -} - #[pymethods] impl Nac3 { #[new] - fn new(isa: &str) -> PyResult { + fn new(isa: &str, py: Python) -> PyResult { let isa = match isa { "riscv" => Isa::RiscV, "cortexa9" => Isa::CortexA9, _ => return Err(exceptions::PyValueError::new_err("invalid ISA")), }; let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; - let builtins = if isa == Isa::RiscV { timeline_builtins(&primitive) } else { vec![] }; + let builtins = if isa == Isa::RiscV { + builtins::timeline_builtins(&primitive) + } else { + vec![] + }; let (composer, builtins_def, builtins_ty) = TopLevelComposer::new(builtins); + + let builtins_mod = PyModule::import(py, "builtins").unwrap(); + let id_fn = builtins_mod.getattr("id").unwrap(); + let numpy_mod = PyModule::import(py, "numpy").unwrap(); + let primitive_ids = PrimitivePythonId { + int: id_fn + .call1((builtins_mod.getattr("int").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + int32: id_fn + .call1((numpy_mod.getattr("int32").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + int64: id_fn + .call1((numpy_mod.getattr("int64").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + bool: id_fn + .call1((builtins_mod.getattr("bool").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + float: id_fn + .call1((builtins_mod.getattr("float").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + list: id_fn + .call1((builtins_mod.getattr("list").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + tuple: id_fn + .call1((builtins_mod.getattr("tuple").unwrap(),)) + .unwrap() + .extract() + .unwrap(), + }; + Ok(Nac3 { isa, primitive, builtins_ty, builtins_def, composer, + primitive_ids, top_level: None, pyid_to_def: Default::default(), pyid_to_type: Default::default(), to_be_registered: Default::default(), + global_value_ids: Default::default(), }) } @@ -301,7 +263,7 @@ impl Nac3 { Ok(()) } - fn compile_method(&mut self, class: u64, method_name: String) -> PyResult<()> { + fn compile_method(&mut self, class: u64, method_name: String, py: Python) -> PyResult<()> { let top_level = self.top_level.as_ref().unwrap(); let module_resolver; let instance = { @@ -356,7 +318,7 @@ impl Nac3 { let isa = self.isa; let f = Arc::new(WithCall::new(Box::new(move |module| { let builder = PassManagerBuilder::create(); - builder.set_optimization_level(OptimizationLevel::Aggressive); + builder.set_optimization_level(OptimizationLevel::Default); let passes = PassManager::create(()); builder.populate_module_pass_manager(&passes); passes.run_on(module); @@ -390,9 +352,13 @@ impl Nac3 { }))); let thread_names: Vec = (0..4).map(|i| format!("module{}", i)).collect(); let threads: Vec<_> = thread_names.iter().map(|s| s.as_str()).collect(); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level.clone(), f); - registry.add_task(task); - registry.wait_tasks_complete(handles); + + py.allow_threads(|| { + let (registry, handles) = + WorkerRegistry::create_workers(&threads, top_level.clone(), f); + registry.add_task(task); + registry.wait_tasks_complete(handles); + }); let mut linker_args = vec![ "-shared".to_string(), diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index f163de64..1b1c62e9 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1,32 +1,348 @@ +use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace}; use nac3core::{ + codegen::CodeGenContext, location::Location, - symbol_resolver::{SymbolResolver, SymbolValue}, - toplevel::DefinitionId, + symbol_resolver::SymbolResolver, + toplevel::{DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{Type, Unifier}, + typedef::{Type, TypeEnum, Unifier}, }, }; use parking_lot::{Mutex, RwLock}; +use pyo3::{ + types::{PyList, PyModule, PyTuple}, + PyAny, PyObject, PyResult, Python, +}; use rustpython_parser::ast::StrRef; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use crate::PrimitivePythonId; pub struct Resolver { pub id_to_type: Mutex>, pub id_to_def: Mutex>, + pub global_value_ids: Arc>>, pub class_names: Mutex>, pub pyid_to_def: Arc>>, pub pyid_to_type: Arc>>, + pub primitive_ids: PrimitivePythonId, // module specific pub name_to_pyid: HashMap, + pub module: PyObject, +} + +struct PythonHelper<'a> { + type_fn: &'a PyAny, + len_fn: &'a PyAny, + id_fn: &'a PyAny, +} + +impl Resolver { + fn get_list_elem_type( + &self, + list: &PyAny, + len: usize, + helper: &PythonHelper, + unifier: &mut Unifier, + defs: &[Arc>], + primitives: &PrimitiveStore, + ) -> PyResult> { + let first = self.get_obj_type(list.get_item(0)?, helper, unifier, defs, primitives)?; + Ok((1..len).fold(first, |a, i| { + let b = list + .get_item(i) + .map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)); + a.and_then(|a| { + if let Ok(Ok(Some(ty))) = b { + if unifier.unify(a, ty).is_ok() { + Some(a) + } else { + None + } + } else { + None + } + }) + })) + } + + fn get_obj_type( + &self, + obj: &PyAny, + helper: &PythonHelper, + unifier: &mut Unifier, + defs: &[Arc>], + primitives: &PrimitiveStore, + ) -> PyResult> { + let ty_id: u64 = helper + .id_fn + .call1((helper.type_fn.call1((obj,))?,))? + .extract()?; + + if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { + Ok(Some(primitives.int32)) + } else if ty_id == self.primitive_ids.int64 { + Ok(Some(primitives.int64)) + } else if ty_id == self.primitive_ids.bool { + Ok(Some(primitives.bool)) + } else if ty_id == self.primitive_ids.float { + Ok(Some(primitives.float)) + } else if ty_id == self.primitive_ids.list { + let len: usize = helper.len_fn.call1((obj,))?.extract()?; + if len == 0 { + let var = unifier.get_fresh_var().0; + let list = unifier.add_ty(TypeEnum::TList { ty: var }); + Ok(Some(list)) + } else { + let ty = self.get_list_elem_type(obj, len, helper, unifier, defs, primitives)?; + Ok(ty.map(|ty| unifier.add_ty(TypeEnum::TList { ty }))) + } + } else if ty_id == self.primitive_ids.tuple { + let elements: &PyTuple = obj.cast_as()?; + let types: Result>, _> = elements + .iter() + .map(|elem| self.get_obj_type(elem, helper, unifier, defs, primitives)) + .collect(); + let types = types?; + Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) + } else { + Ok(None) + } + } + + fn get_obj_value<'ctx, 'a>( + &self, + obj: &PyAny, + helper: &PythonHelper, + ctx: &mut CodeGenContext<'ctx, 'a>, + ) -> PyResult>> { + let ty_id: u64 = helper + .id_fn + .call1((helper.type_fn.call1((obj,))?,))? + .extract()?; + if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { + let val: i32 = obj.extract()?; + Ok(Some(ctx.ctx.i32_type().const_int(val as u64, false).into())) + } else if ty_id == self.primitive_ids.int64 { + let val: i64 = obj.extract()?; + Ok(Some(ctx.ctx.i64_type().const_int(val as u64, false).into())) + } else if ty_id == self.primitive_ids.bool { + let val: bool = obj.extract()?; + Ok(Some( + ctx.ctx.bool_type().const_int(val as u64, false).into(), + )) + } else if ty_id == self.primitive_ids.float { + let val: f64 = obj.extract()?; + Ok(Some(ctx.ctx.f64_type().const_float(val).into())) + } else if ty_id == self.primitive_ids.list { + let id: u64 = helper.id_fn.call1((obj,))?.extract()?; + let id_str = id.to_string(); + let len: usize = helper.len_fn.call1((obj,))?.extract()?; + if len == 0 { + let int32 = ctx.ctx.i32_type(); + return Ok(Some( + ctx.ctx + .struct_type( + &[int32.into(), int32.ptr_type(AddressSpace::Generic).into()], + false, + ) + .const_zero() + .into(), + )); + } + let ty = self + .get_list_elem_type( + obj, + len, + helper, + &mut ctx.unifier, + &ctx.top_level.definitions.read(), + &ctx.primitives, + )? + .unwrap(); + let ty = ctx.get_llvm_type(ty); + let arr_ty = ctx.ctx.struct_type( + &[ + ctx.ctx.i32_type().into(), + ty.ptr_type(AddressSpace::Generic).into(), + ], + false, + ); + + { + let mut global_value_ids = self.global_value_ids.lock(); + if global_value_ids.contains(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module + .add_global(arr_ty, Some(AddressSpace::Generic), &id_str) + }); + return Ok(Some(global.as_pointer_value().into())); + } else { + global_value_ids.insert(id); + } + } + + let arr: Result>, _> = (0..len) + .map(|i| { + obj.get_item(i) + .and_then(|elem| self.get_obj_value(elem, helper, ctx)) + }) + .collect(); + let arr = arr?.unwrap(); + + let arr_global = ctx.module.add_global( + ty.array_type(len as u32), + Some(AddressSpace::Generic), + &(id_str.clone() + "_"), + ); + let arr: BasicValueEnum = if ty.is_int_type() { + let arr: Vec<_> = arr + .into_iter() + .map(BasicValueEnum::into_int_value) + .collect(); + ty.into_int_type().const_array(&arr) + } else if ty.is_float_type() { + let arr: Vec<_> = arr + .into_iter() + .map(BasicValueEnum::into_float_value) + .collect(); + ty.into_float_type().const_array(&arr) + } else if ty.is_array_type() { + let arr: Vec<_> = arr + .into_iter() + .map(BasicValueEnum::into_array_value) + .collect(); + ty.into_array_type().const_array(&arr) + } else if ty.is_struct_type() { + let arr: Vec<_> = arr + .into_iter() + .map(BasicValueEnum::into_struct_value) + .collect(); + ty.into_struct_type().const_array(&arr) + } else if ty.is_pointer_type() { + let arr: Vec<_> = arr + .into_iter() + .map(BasicValueEnum::into_pointer_value) + .collect(); + ty.into_pointer_type().const_array(&arr) + } else { + unreachable!() + } + .into(); + arr_global.set_initializer(&arr); + + let val = arr_ty.const_named_struct(&[ + ctx.ctx.i32_type().const_int(len as u64, false).into(), + arr_global + .as_pointer_value() + .const_cast(ty.ptr_type(AddressSpace::Generic)) + .into(), + ]); + + let global = ctx + .module + .add_global(arr_ty, Some(AddressSpace::Generic), &id_str); + global.set_initializer(&val); + + Ok(Some(global.as_pointer_value().into())) + } else if ty_id == self.primitive_ids.tuple { + let id: u64 = helper.id_fn.call1((obj,))?.extract()?; + let id_str = id.to_string(); + let elements: &PyTuple = obj.cast_as()?; + let types: Result>, _> = elements + .iter() + .map(|elem| { + self.get_obj_type( + elem, + helper, + &mut ctx.unifier, + &ctx.top_level.definitions.read(), + &ctx.primitives, + ) + .map(|ty| ty.map(|ty| ctx.get_llvm_type(ty))) + }) + .collect(); + let types = types?.unwrap(); + let ty = ctx.ctx.struct_type(&types, false); + + { + let mut global_value_ids = self.global_value_ids.lock(); + if global_value_ids.contains(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module + .add_global(ty, Some(AddressSpace::Generic), &id_str) + }); + return Ok(Some(global.as_pointer_value().into())); + } else { + global_value_ids.insert(id); + } + } + + let val: Result>, _> = elements + .iter() + .map(|elem| self.get_obj_value(elem, helper, ctx)) + .collect(); + let val = val?.unwrap(); + let val = ctx.ctx.const_struct(&val, false); + let global = ctx + .module + .add_global(ty, Some(AddressSpace::Generic), &id_str); + global.set_initializer(&val); + Ok(Some(global.as_pointer_value().into())) + } else { + Ok(None) + } + } } impl SymbolResolver for Resolver { - fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option { + fn get_symbol_type( + &self, + unifier: &mut Unifier, + defs: &[Arc>], + primitives: &PrimitiveStore, + str: StrRef, + ) -> Option { let mut id_to_type = self.id_to_type.lock(); id_to_type.get(&str).cloned().or_else(|| { let py_id = self.name_to_pyid.get(&str); - let result = py_id.and_then(|id| self.pyid_to_type.read().get(&id).copied()); + let result = py_id.and_then(|id| { + self.pyid_to_type.read().get(&id).copied().or_else(|| { + Python::with_gil(|py| -> PyResult> { + let obj: &PyAny = self.module.extract(py)?; + let members: &PyList = PyModule::import(py, "inspect")? + .getattr("getmembers")? + .call1((obj,))? + .cast_as()?; + let mut sym_ty = None; + for member in members.iter() { + let key: &str = member.get_item(0)?.extract()?; + if key == str.to_string() { + let builtins = PyModule::import(py, "builtins")?; + let helper = PythonHelper { + id_fn: builtins.getattr("id").unwrap(), + len_fn: builtins.getattr("len").unwrap(), + type_fn: builtins.getattr("type").unwrap(), + }; + sym_ty = self.get_obj_type( + member.get_item(1)?, + &helper, + unifier, + defs, + primitives, + )?; + break; + } + } + Ok(sym_ty) + }) + .unwrap() + }) + }); if let Some(result) = &result { id_to_type.insert(str, *result); } @@ -34,8 +350,35 @@ impl SymbolResolver for Resolver { }) } - fn get_symbol_value(&self, _: StrRef) -> Option { - unimplemented!() + fn get_symbol_value<'ctx, 'a>( + &self, + id: StrRef, + ctx: &mut CodeGenContext<'ctx, 'a>, + ) -> Option> { + Python::with_gil(|py| -> PyResult>> { + let obj: &PyAny = self.module.extract(py)?; + let members: &PyList = PyModule::import(py, "inspect")? + .getattr("getmembers")? + .call1((obj,))? + .cast_as()?; + let mut sym_value = None; + for member in members.iter() { + let key: &str = member.get_item(0)?.extract()?; + let val = member.get_item(1)?; + if key == id.to_string() { + let builtins = PyModule::import(py, "builtins")?; + let helper = PythonHelper { + id_fn: builtins.getattr("id").unwrap(), + len_fn: builtins.getattr("len").unwrap(), + type_fn: builtins.getattr("type").unwrap(), + }; + sym_value = self.get_obj_value(val, &helper, ctx)?; + break; + } + } + Ok(sym_value) + }) + .unwrap() } fn get_symbol_location(&self, _: StrRef) -> Option { diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 34407bc0..b3233c1b 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -108,7 +108,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { fun: (&FunSignature, DefinitionId), params: Vec<(Option, BasicValueEnum<'ctx>)>, ) -> Option> { - let definition = self.top_level.definitions.read().get(fun.1.0).cloned().unwrap(); + let definition = self.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); let mut task = None; let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None); let symbol = { @@ -283,7 +283,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } else { unreachable!(); }; - ty.const_int(v.try_into().unwrap(), false).into() + let val: i64 = v.try_into().unwrap(); + ty.const_int(val as u64, false).into() } Constant::Float(v) => { assert!(self.unifier.unioned(ty, self.primitives.float)); @@ -386,8 +387,13 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { self.gen_const(value, ty) } ExprKind::Name { id, .. } => { - let ptr = self.var_assignment.get(id).unwrap(); - self.builder.build_load(*ptr, "load") + let ptr = self.var_assignment.get(id); + if let Some(ptr) = ptr { + self.builder.build_load(*ptr, "load") + } else { + let resolver = self.resolver.clone(); + resolver.get_symbol_value(*id, self).unwrap() + } } ExprKind::List { elts, .. } => { // this shall be optimized later for constant primitive lists... @@ -647,10 +653,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let mut params = args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec(); let kw_iter = keywords.iter().map(|kw| { - ( - Some(*kw.node.arg.as_ref().unwrap()), - self.gen_expr(&kw.node.value).unwrap(), - ) + (Some(*kw.node.arg.as_ref().unwrap()), self.gen_expr(&kw.node.value).unwrap()) }); params.extend(kw_iter); let call = self.calls.get(&expr.location.into()); diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 3f34c664..2c7bad40 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -1,7 +1,7 @@ use crate::{ - codegen::{CodeGenTask, WithCall, WorkerRegistry}, + codegen::{CodeGenTask, WithCall, WorkerRegistry, CodeGenContext}, location::Location, - symbol_resolver::{SymbolResolver, SymbolValue}, + symbol_resolver::SymbolResolver, toplevel::{ composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef, }, @@ -11,6 +11,7 @@ use crate::{ }, }; use indoc::indoc; +use inkwell::values::BasicValueEnum; use parking_lot::RwLock; use rustpython_parser::{ ast::{fold::Fold, StrRef}, @@ -33,11 +34,11 @@ impl Resolver { } impl SymbolResolver for Resolver { - fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option { + fn get_symbol_type(&self, _: &mut Unifier, _: &[Arc>], _: &PrimitiveStore, str: StrRef) -> Option { self.id_to_type.get(&str).cloned() } - fn get_symbol_value(&self, _: StrRef) -> Option { + fn get_symbol_value<'ctx, 'a>(&self, _: StrRef, _: &mut CodeGenContext<'ctx, 'a>) -> Option> { unimplemented!() } diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index d657e4e9..958757af 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::fmt::Debug; use std::{cell::RefCell, sync::Arc}; -use crate::toplevel::{DefinitionId, TopLevelDef}; +use crate::{codegen::CodeGenContext, toplevel::{DefinitionId, TopLevelDef}}; use crate::typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, Unifier}, @@ -11,6 +11,7 @@ use crate::{location::Location, typecheck::typedef::TypeEnum}; use itertools::{chain, izip}; use parking_lot::RwLock; use rustpython_parser::ast::{Expr, StrRef}; +use inkwell::values::BasicValueEnum; #[derive(Clone, PartialEq)] pub enum SymbolValue { @@ -28,12 +29,13 @@ pub trait SymbolResolver { fn get_symbol_type( &self, unifier: &mut Unifier, + top_level_defs: &[Arc>], primitives: &PrimitiveStore, str: StrRef, ) -> Option; // get the top-level definition of identifiers fn get_identifier_def(&self, str: StrRef) -> Option; - fn get_symbol_value(&self, str: StrRef) -> Option; + fn get_symbol_value<'ctx, 'a>(&self, str: StrRef, ctx: &mut CodeGenContext<'ctx, 'a>) -> Option>; fn get_symbol_location(&self, str: StrRef) -> Option; // handle function call etc. } @@ -113,7 +115,7 @@ pub fn parse_type_annotation( } else { // it could be a type variable let ty = resolver - .get_symbol_type(unifier, primitives, *id) + .get_symbol_type(unifier, top_level_defs, primitives, *id) .ok_or_else(|| "unknown type variable name".to_owned())?; if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { Ok(ty) diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index a53bed77..e652236b 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -1,6 +1,7 @@ use crate::{ + codegen::CodeGenContext, location::Location, - symbol_resolver::{SymbolResolver, SymbolValue}, + symbol_resolver::SymbolResolver, toplevel::DefinitionId, typecheck::{ type_inferencer::PrimitiveStore, @@ -34,7 +35,7 @@ impl ResolverInternal { struct Resolver(Arc); impl SymbolResolver for Resolver { - fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option { + fn get_symbol_type(&self, _: &mut Unifier, _: &[Arc>], _: &PrimitiveStore, str: StrRef) -> Option { let ret = self.0.id_to_type.lock().get(&str).cloned(); if ret.is_none() { // println!("unknown here resolver {}", str); @@ -42,7 +43,11 @@ impl SymbolResolver for Resolver { ret } - fn get_symbol_value(&self, _: StrRef) -> Option { + fn get_symbol_value<'ctx, 'a>( + &self, + _: StrRef, + _: &mut CodeGenContext<'ctx, 'a>, + ) -> Option> { unimplemented!() } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 6d4d6f0d..7f33e823 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -62,7 +62,7 @@ pub fn parse_ast_to_type_annotation_kinds( )); } Ok(TypeAnnotation::CustomClassKind { id: obj_id, params: vec![] }) - } else if let Some(ty) = resolver.get_symbol_type(unifier, primitives, *id) { + } else if let Some(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { Ok(TypeAnnotation::TypeVarKind(ty)) } else { diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index fce0589b..c66bacfe 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -57,7 +57,17 @@ impl<'a> Inferencer<'a> { match &expr.node { ExprKind::Name { id, .. } => { if !defined_identifiers.contains(id) { - if self.function_data.resolver.get_identifier_def(*id).is_some() { + if self + .function_data + .resolver + .get_symbol_type( + self.unifier, + &self.top_level.definitions.read(), + self.primitives, + *id, + ) + .is_some() + { defined_identifiers.insert(*id); } else { return Err(format!( diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index af7008ad..eb54ce9c 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -111,32 +111,40 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { if let ast::StmtKind::Assign { targets, value, .. } = node.node { let value = self.fold_expr(*value)?; let value_ty = value.custom.unwrap(); - let targets: Result, _> = targets.into_iter().map(|target| { - if let ast::ExprKind::Name { id, ctx } = target.node { - self.defined_identifiers.insert(id); - let target_ty = if let Some(ty) = self.variable_mapping.get(&id) { - *ty + let targets: Result, _> = targets + .into_iter() + .map(|target| { + if let ast::ExprKind::Name { id, ctx } = target.node { + self.defined_identifiers.insert(id); + let target_ty = if let Some(ty) = self.variable_mapping.get(&id) + { + *ty + } else { + let unifier = &mut self.unifier; + self.function_data + .resolver + .get_symbol_type( + unifier, + &self.top_level.definitions.read(), + self.primitives, + id, + ) + .unwrap_or_else(|| { + self.variable_mapping.insert(id, value_ty); + value_ty + }) + }; + let location = target.location; + self.unifier.unify(value_ty, target_ty).map(|_| Located { + location, + node: ast::ExprKind::Name { id, ctx }, + custom: Some(target_ty), + }) } else { - let unifier = &mut self.unifier; - self - .function_data - .resolver - .get_symbol_type(unifier, self.primitives, id) - .unwrap_or_else(|| { - self.variable_mapping.insert(id, value_ty); - value_ty - }) - }; - let location = target.location; - self.unifier.unify(value_ty, target_ty).map(|_| Located { - location, - node: ast::ExprKind::Name { id, ctx }, - custom: Some(target_ty) - }) - } else { - unreachable!() - } - }).collect(); + unreachable!() + } + }) + .collect(); let targets = targets?; return Ok(Located { location: node.location, @@ -145,7 +153,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { value: Box::new(value), type_comment: None, }, - custom: None + custom: None, }); } else { unreachable!() @@ -207,7 +215,17 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?), ast::ExprKind::Name { id, .. } => { if !self.defined_identifiers.contains(id) { - if self.function_data.resolver.get_identifier_def(*id).is_some() { + if self + .function_data + .resolver + .get_symbol_type( + self.unifier, + &self.top_level.definitions.read(), + self.primitives, + *id, + ) + .is_some() + { self.defined_identifiers.insert(*id); } else { return Err(format!( @@ -359,11 +377,8 @@ impl<'a> Inferencer<'a> { defined_identifiers.insert(*name); } } - let fn_args: Vec<_> = args - .args - .iter() - .map(|v| (v.node.arg, self.unifier.get_fresh_var().0)) - .collect(); + let fn_args: Vec<_> = + args.args.iter().map(|v| (v.node.arg, self.unifier.get_fresh_var().0)).collect(); let mut variable_mapping = self.variable_mapping.clone(); variable_mapping.extend(fn_args.iter().cloned()); let ret = self.unifier.get_fresh_var().0; @@ -596,7 +611,7 @@ impl<'a> Inferencer<'a> { Ok(self .function_data .resolver - .get_symbol_type(unifier, self.primitives, id) + .get_symbol_type(unifier, &self.top_level.definitions.read(), self.primitives, id) .unwrap_or_else(|| { let ty = unifier.get_fresh_var().0; variable_mapping.insert(id, ty); diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 24dc2df1..aea6a326 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -1,11 +1,12 @@ use super::super::typedef::*; use super::*; -use crate::symbol_resolver::*; use crate::{ + codegen::CodeGenContext, location::Location, toplevel::{DefinitionId, TopLevelDef}, }; use indoc::indoc; +use inkwell::values::BasicValueEnum; use itertools::zip; use parking_lot::RwLock; use rustpython_parser::parser::parse_program; @@ -18,11 +19,15 @@ struct Resolver { } impl SymbolResolver for Resolver { - fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option { + fn get_symbol_type(&self, _: &mut Unifier, _: &[Arc>], _: &PrimitiveStore, str: StrRef) -> Option { self.id_to_type.get(&str).cloned() } - fn get_symbol_value(&self, _: StrRef) -> Option { + fn get_symbol_value<'ctx, 'a>( + &self, + _: StrRef, + _: &mut CodeGenContext<'ctx, 'a>, + ) -> Option> { unimplemented!() } @@ -278,7 +283,7 @@ impl TestEnvironment { let top_level = TopLevelContext { definitions: Arc::new(top_level_defs.into()), unifiers: Default::default(), - personality_symbol: None + personality_symbol: None, }; let resolver = Arc::new(Box::new(Resolver { diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 92f0d0c5..4d42cc37 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -1,13 +1,9 @@ -use nac3core::{ - location::Location, - symbol_resolver::{SymbolResolver, SymbolValue}, - toplevel::DefinitionId, - typecheck::{ +use inkwell::values::BasicValueEnum; +use nac3core::{codegen::CodeGenContext, location::Location, symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, Unifier}, - }, -}; -use parking_lot::Mutex; + }}; +use parking_lot::{Mutex, RwLock}; use rustpython_parser::ast::StrRef; use std::{collections::HashMap, sync::Arc}; @@ -30,7 +26,7 @@ impl ResolverInternal { pub struct Resolver(pub Arc); impl SymbolResolver for Resolver { - fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: StrRef) -> Option { + fn get_symbol_type(&self, _: &mut Unifier, _: &[Arc>], _: &PrimitiveStore, str: StrRef) -> Option { let ret = self.0.id_to_type.lock().get(&str).cloned(); if ret.is_none() { // println!("unknown here resolver {}", str); @@ -38,7 +34,7 @@ impl SymbolResolver for Resolver { ret } - fn get_symbol_value(&self, _: StrRef) -> Option { + fn get_symbol_value<'ctx, 'a>(&self, _: StrRef, _: &mut CodeGenContext<'ctx, 'a>) -> Option> { unimplemented!() }