1
0
forked from M-Labs/nac3

nac3artiq: implement attribute writeback

We will only writeback attributes that are supported by the current RPC
implementation: primitives, tuple and lists of lists... of primitives.
This commit is contained in:
pca006132 2022-03-25 22:42:01 +08:00 committed by Gitea
parent ba8ed6c663
commit bf067e2481
6 changed files with 164 additions and 39 deletions

View File

@ -5,6 +5,7 @@ class EmbeddingMap:
self.string_map = {} self.string_map = {}
self.string_reverse_map = {} self.string_reverse_map = {}
self.function_map = {} self.function_map = {}
self.attributes_writeback = []
# preallocate exception names # preallocate exception names
self.preallocate_runtime_exception_names(["RuntimeError", self.preallocate_runtime_exception_names(["RuntimeError",

View File

@ -6,7 +6,7 @@ use nac3core::{
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{DefinitionId, GenCall}, toplevel::{DefinitionId, GenCall},
typecheck::typedef::{FunSignature, Type}, typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum}
}; };
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
@ -15,7 +15,9 @@ use inkwell::{
context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace, context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace,
}; };
use crate::timeline::TimeFns; use pyo3::{PyObject, PyResult, Python, types::{PyDict, PyList}};
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use std::{ use std::{
collections::hash_map::DefaultHasher, collections::hash_map::DefaultHasher,
@ -270,8 +272,6 @@ fn gen_rpc_tag<'ctx, 'a>(
buffer.push(b'l'); buffer.push(b'l');
gen_rpc_tag(ctx, *ty, buffer)?; gen_rpc_tag(ctx, *ty, buffer)?;
} }
// we should return an error, this will be fixed after improving error message
// as this requires returning an error during codegen
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
} }
} }
@ -291,7 +291,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1 .0 as u64, false); let service_id = int32.const_int(fun.1.0 as u64, false);
// -- setup rpc tags // -- setup rpc tags
let mut tag = Vec::new(); let mut tag = Vec::new();
if obj.is_some() { if obj.is_some() {
@ -486,6 +486,81 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
Ok(Some(result)) Ok(Some(result))
} }
pub fn attributes_writeback<'ctx, 'a>(
ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator,
inner_resolver: &InnerResolver,
host_attributes: PyObject,
) -> Result<(), String> {
Python::with_gil(|py| -> PyResult<Result<(), String>> {
let host_attributes = host_attributes.cast_as::<PyList>(py)?;
let top_levels = ctx.top_level.definitions.read();
let globals = inner_resolver.global_value_ids.read();
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let mut values = Vec::new();
let mut scratch_buffer = Vec::new();
for (_, val) in globals.iter() {
let val = val.as_ref(py);
let ty = inner_resolver.get_obj_type(py, val, &mut ctx.unifier, &top_levels, &ctx.primitives)?;
if let Err(ty) = ty {
return Ok(Err(ty))
}
let ty = ty.unwrap();
match &*ctx.unifier.get_ty(ty) {
TypeEnum::TObj { fields, .. } => {
// we only care about primitive attributes
// for non-primitive attributes, they should be in another global
let mut attributes = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap();
for (name, (field_ty, is_mutable)) in fields.iter() {
if !is_mutable {
continue
}
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
attributes.push(name.to_string());
let index = ctx.get_attr_index(ty, *name);
values.push((*field_ty, ctx.build_gep_and_load(
obj.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)])));
}
}
if !attributes.is_empty() {
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
pydict.set_item("fields", attributes)?;
host_attributes.append(pydict)?;
}
},
TypeEnum::TList { ty: elem_ty } => {
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() {
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
host_attributes.append(pydict)?;
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap()));
}
},
_ => {}
}
}
let fun = FunSignature {
args: values.iter().enumerate().map(|(i, (ty, _))| FuncArg {
name: i.to_string().into(),
ty: *ty,
default_value: None
}).collect(),
ret: ctx.primitives.none,
vars: Default::default()
};
let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, DefinitionId(0)), args, generator) {
return Ok(Err(e));
}
Ok(Ok(()))
}).unwrap()?;
Ok(())
}
pub fn rpc_codegen_callback() -> Arc<GenCall> { pub fn rpc_codegen_callback() -> Arc<GenCall> {
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| { Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| {
rpc_codegen_callback_fn(ctx, obj, fun, args, generator) rpc_codegen_callback_fn(ctx, obj, fun, args, generator)

View File

@ -10,6 +10,7 @@ use inkwell::{
targets::*, targets::*,
OptimizationLevel, OptimizationLevel,
}; };
use nac3core::codegen::gen_func_impl;
use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{TypeEnum, Unifier}; use nac3core::typecheck::typedef::{TypeEnum, Unifier};
use nac3parser::{ use nac3parser::{
@ -36,6 +37,7 @@ use nac3core::{
use tempfile::{self, TempDir}; use tempfile::{self, TempDir};
use crate::codegen::attributes_writeback;
use crate::{ use crate::{
codegen::{rpc_codegen_callback, ArtiqCodeGenerator}, codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
symbol_resolver::{InnerResolver, PythonHelper, Resolver, DeferredEvaluationStore}, symbol_resolver::{InnerResolver, PythonHelper, Resolver, DeferredEvaluationStore},
@ -476,6 +478,8 @@ impl Nac3 {
let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py); let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py);
let store_str = embedding_map.getattr("store_str").unwrap().to_object(py); let store_str = embedding_map.getattr("store_str").unwrap().to_object(py);
let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py); let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py);
let host_attributes = embedding_map.getattr("attributes_writeback").unwrap().to_object(py);
let global_value_ids: Arc<RwLock<HashMap<_, _>>> = Arc::new(RwLock::new(HashMap::new()));
let helper = PythonHelper { let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap().to_object(py), id_fn: builtins.getattr("id").unwrap().to_object(py),
len_fn: builtins.getattr("len").unwrap().to_object(py), len_fn: builtins.getattr("len").unwrap().to_object(py),
@ -503,7 +507,6 @@ impl Nac3 {
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new(); let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
let global_value_ids = Arc::new(RwLock::new(HashSet::<u64>::new()));
let mut rpc_ids = vec![]; let mut rpc_ids = vec![];
for (stmt, path, module) in self.top_levels.iter() { for (stmt, path, module) in self.top_levels.iter() {
let py_module: &PyAny = module.extract(py)?; let py_module: &PyAny = module.extract(py)?;
@ -617,7 +620,7 @@ impl Nac3 {
}; };
let mut synthesized = let mut synthesized =
parse_program(&synthesized, "__nac3_synthesized_modinit__".to_string().into()).unwrap(); parse_program(&synthesized, "__nac3_synthesized_modinit__".to_string().into()).unwrap();
let resolver = Arc::new(Resolver(Arc::new(InnerResolver { let inner_resolver = Arc::new(InnerResolver {
id_to_type: builtins_ty.clone().into(), id_to_type: builtins_ty.clone().into(),
id_to_def: builtins_def.clone().into(), id_to_def: builtins_def.clone().into(),
pyid_to_def: self.pyid_to_def.clone(), pyid_to_def: self.pyid_to_def.clone(),
@ -634,17 +637,18 @@ impl Nac3 {
string_store: self.string_store.clone(), string_store: self.string_store.clone(),
exception_ids: self.exception_ids.clone(), exception_ids: self.exception_ids.clone(),
deferred_eval_store: self.deferred_eval_store.clone(), deferred_eval_store: self.deferred_eval_store.clone(),
}))) as Arc<dyn SymbolResolver + Send + Sync>; });
let resolver = Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
let (_, def_id, _) = composer let (_, def_id, _) = composer
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "".into()) .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "".into())
.unwrap(); .unwrap();
let signature = let fun_signature =
FunSignature { args: vec![], ret: self.primitive.none, vars: HashMap::new() }; FunSignature { args: vec![], ret: self.primitive.none, vars: HashMap::new() };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = let signature =
store.from_signature(&mut composer.unifier, &self.primitive, &signature, &mut cache); store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache);
let signature = store.add_cty(signature); let signature = store.add_cty(signature);
if let Err(e) = composer.start_analysis(true) { if let Err(e) = composer.start_analysis(true) {
@ -721,12 +725,29 @@ impl Nac3 {
symbol_name: "__modinit__".to_string(), symbol_name: "__modinit__".to_string(),
body: instance.body, body: instance.body,
signature, signature,
resolver, resolver: resolver.clone(),
store, store,
unifier_index: instance.unifier_id, unifier_index: instance.unifier_id,
calls: instance.calls, calls: instance.calls,
id: 0, id: 0,
}; };
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature =
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache);
let signature = store.add_cty(signature);
let attributes_writeback_task = CodeGenTask {
subst: Default::default(),
symbol_name: "attributes_writeback".to_string(),
body: Arc::new(Default::default()),
signature,
resolver,
store,
unifier_index: instance.unifier_id,
calls: Arc::new(Default::default()),
id: 0,
};
let isa = self.isa; let isa = self.isa;
let working_directory = self.working_directory.path().to_owned(); let working_directory = self.working_directory.path().to_owned();
@ -746,14 +767,27 @@ impl Nac3 {
.map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns))) .map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns)))
.collect(); .collect();
let membuffer = membuffers.clone();
py.allow_threads(|| { py.allow_threads(|| {
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level.clone(), f); let (registry, handles) = WorkerRegistry::create_workers(threads, top_level.clone(), f);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
let mut generator = ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback");
let builder = context.create_builder();
let (_, module, _) = gen_func_impl(&context, &mut generator, &registry, builder, module,
attributes_writeback_task, |generator, ctx| {
attributes_writeback(ctx, generator, inner_resolver.as_ref(), host_attributes)
}).unwrap();
let buffer = module.write_bitcode_to_memory();
let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer);
}); });
let buffers = membuffers.lock();
let context = inkwell::context::Context::create(); let context = inkwell::context::Context::create();
let buffers = membuffers.lock();
let main = context let main = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main")) .create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
.unwrap(); .unwrap();
@ -765,6 +799,11 @@ impl Nac3 {
main.link_in_module(other) main.link_in_module(other)
.map_err(|err| CompileError::new_err(err.to_string()))?; .map_err(|err| CompileError::new_err(err.to_string()))?;
} }
let builder = context.create_builder();
let modinit_return = main.get_function("__modinit__").unwrap().get_last_basic_block().unwrap().get_terminator().unwrap();
builder.position_before(&modinit_return);
builder.build_call(main.get_function("attributes_writeback").unwrap(), &[], "attributes_writeback");
main.link_in_module(load_irrt(&context)) main.link_in_module(load_irrt(&context))
.map_err(|err| CompileError::new_err(err.to_string()))?; .map_err(|err| CompileError::new_err(err.to_string()))?;

View File

@ -15,7 +15,7 @@ use pyo3::{
PyAny, PyObject, PyResult, Python, PyAny, PyObject, PyResult, Python,
}; };
use std::{ use std::{
collections::{HashMap, HashSet}, collections::HashMap,
sync::{ sync::{
Arc, Arc,
atomic::{AtomicBool, Ordering::Relaxed} atomic::{AtomicBool, Ordering::Relaxed}
@ -54,7 +54,7 @@ pub struct InnerResolver {
pub id_to_pyval: RwLock<HashMap<StrRef, (u64, PyObject)>>, pub id_to_pyval: RwLock<HashMap<StrRef, (u64, PyObject)>>,
pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>, pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>,
pub field_to_val: RwLock<HashMap<(u64, StrRef), Option<(u64, PyObject)>>>, pub field_to_val: RwLock<HashMap<(u64, StrRef), Option<(u64, PyObject)>>>,
pub global_value_ids: Arc<RwLock<HashSet<u64>>>, pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>,
pub class_names: Mutex<HashMap<StrRef, Type>>, pub class_names: Mutex<HashMap<StrRef, Type>>,
pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>, pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>, pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
@ -503,7 +503,7 @@ impl InnerResolver {
} }
} }
fn get_obj_type( pub fn get_obj_type(
&self, &self,
py: Python, py: Python,
obj: &PyAny, obj: &PyAny,
@ -605,7 +605,7 @@ impl InnerResolver {
unreachable!("must be tobj") unreachable!("must be tobj")
} }
} }
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
Ok(t) => t, Ok(t) => t,
Err(e) => { Err(e) => {
@ -686,7 +686,7 @@ impl InnerResolver {
} }
} }
fn get_obj_value<'ctx, 'a>( pub fn get_obj_value<'ctx, 'a>(
&self, &self,
py: Python, py: Python,
obj: &PyAny, obj: &PyAny,
@ -754,13 +754,13 @@ impl InnerResolver {
.struct_type(&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], false); .struct_type(&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], false);
{ {
if self.global_value_ids.read().contains(&id) { if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str) ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str)
}); });
return Ok(Some(global.as_pointer_value().into())); return Ok(Some(global.as_pointer_value().into()));
} else { } else {
self.global_value_ids.write().insert(id); self.global_value_ids.write().insert(id, obj.into());
} }
} }
@ -834,13 +834,13 @@ impl InnerResolver {
let ty = ctx.ctx.struct_type(&types, false); let ty = ctx.ctx.struct_type(&types, false);
{ {
if self.global_value_ids.read().contains(&id) { if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str) ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str)
}); });
return Ok(Some(global.as_pointer_value().into())); return Ok(Some(global.as_pointer_value().into()));
} else { } else {
self.global_value_ids.write().insert(id); self.global_value_ids.write().insert(id, obj.into());
} }
} }
@ -869,13 +869,13 @@ impl InnerResolver {
Some(v) => { Some(v) => {
let global_str = format!("{}_option", id); let global_str = format!("{}_option", id);
{ {
if self.global_value_ids.read().contains(&id) { if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&global_str).unwrap_or_else(|| { let global = ctx.module.get_global(&global_str).unwrap_or_else(|| {
ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str) ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str)
}); });
return Ok(Some(global.as_pointer_value().into())); return Ok(Some(global.as_pointer_value().into()));
} else { } else {
self.global_value_ids.write().insert(id); self.global_value_ids.write().insert(id, obj.into());
} }
} }
let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str); let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str);
@ -902,13 +902,13 @@ impl InnerResolver {
.get_element_type() .get_element_type()
.into_struct_type(); .into_struct_type();
{ {
if self.global_value_ids.read().contains(&id) { if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str) ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str)
}); });
return Ok(Some(global.as_pointer_value().into())); return Ok(Some(global.as_pointer_value().into()));
} else { } else {
self.global_value_ids.write().insert(id); self.global_value_ids.write().insert(id, obj.into());
} }
} }
// should be classes // should be classes

View File

@ -355,9 +355,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
} }
pub fn gen_string<G: CodeGenerator, S: Into<String>>( pub fn gen_string<S: Into<String>>(
&mut self, &mut self,
generator: &mut G, generator: &mut dyn CodeGenerator,
s: S, s: S,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
self.gen_const(generator, &nac3parser::ast::Constant::Str(s.into()), self.primitives.str) self.gen_const(generator, &nac3parser::ast::Constant::Str(s.into()), self.primitives.str)

View File

@ -360,13 +360,14 @@ fn need_sret<'ctx>(ctx: &'ctx Context, ty: BasicTypeEnum<'ctx>) -> bool {
need_sret_impl(ctx, ty, true) need_sret_impl(ctx, ty, true)
} }
pub fn gen_func<'ctx, G: CodeGenerator>( pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>> (
context: &'ctx Context, context: &'ctx Context,
generator: &mut G, generator: &mut G,
registry: &WorkerRegistry, registry: &WorkerRegistry,
builder: Builder<'ctx>, builder: Builder<'ctx>,
module: Module<'ctx>, module: Module<'ctx>,
task: CodeGenTask, task: CodeGenTask,
codegen_function: F
) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> { ) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> {
let top_level_ctx = registry.top_level_ctx.clone(); let top_level_ctx = registry.top_level_ctx.clone();
let static_value_store = registry.static_value_store.clone(); let static_value_store = registry.static_value_store.clone();
@ -572,25 +573,34 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
need_sret: has_sret need_sret: has_sret
}; };
let mut err = None; let result = codegen_function(generator, &mut code_gen_context);
for stmt in task.body.iter() {
if let Err(e) = generator.gen_stmt(&mut code_gen_context, stmt) {
err = Some(e);
break;
}
if code_gen_context.is_terminated() {
break;
}
}
// after static analysis, only void functions can have no return at the end. // after static analysis, only void functions can have no return at the end.
if !code_gen_context.is_terminated() { if !code_gen_context.is_terminated() {
code_gen_context.builder.build_return(None); code_gen_context.builder.build_return(None);
} }
let CodeGenContext { builder, module, .. } = code_gen_context; let CodeGenContext { builder, module, .. } = code_gen_context;
if let Some(e) = err { if let Err(e) = result {
return Err((builder, e)); return Err((builder, e));
} }
Ok((builder, module, fn_val)) Ok((builder, module, fn_val))
} }
pub fn gen_func<'ctx, G: CodeGenerator>(
context: &'ctx Context,
generator: &mut G,
registry: &WorkerRegistry,
builder: Builder<'ctx>,
module: Module<'ctx>,
task: CodeGenTask,
) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> {
let body = task.body.clone();
gen_func_impl(context, generator, registry, builder, module, task, |generator, ctx| {
for stmt in body.iter() {
generator.gen_stmt(ctx, stmt)?;
}
Ok(())
})
}