From 3a7cd82b9967cb718daaeb630047c384dca019c1 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Mon, 13 Jan 2025 18:22:19 +0800 Subject: [PATCH] [WIP] Fix Class Obj behavior --- nac3artiq/demo/demo.py | 3 + nac3artiq/demo/nac3artiq.so | 1 - nac3artiq/src/codegen.rs | 36 ++++++++-- nac3artiq/src/lib.rs | 12 +++- nac3core/src/toplevel/composer.rs | 112 +++++++++++++++++++++++++++++- 5 files changed, 157 insertions(+), 7 deletions(-) delete mode 120000 nac3artiq/demo/nac3artiq.so diff --git a/nac3artiq/demo/demo.py b/nac3artiq/demo/demo.py index aa135757..9f278372 100644 --- a/nac3artiq/demo/demo.py +++ b/nac3artiq/demo/demo.py @@ -1,5 +1,6 @@ from min_artiq import * +attr1: Kernel[str] = "ss" @nac3 class Demo: @@ -14,6 +15,8 @@ class Demo: @kernel def run(self): + global attr1 + # attr1 = "2" self.core.reset() while True: with parallel: diff --git a/nac3artiq/demo/nac3artiq.so b/nac3artiq/demo/nac3artiq.so deleted file mode 120000 index d05f6c9b..00000000 --- a/nac3artiq/demo/nac3artiq.so +++ /dev/null @@ -1 +0,0 @@ -../../target/release/libnac3artiq.so \ No newline at end of file diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 9baa0afe..fc99ac0d 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -8,8 +8,7 @@ use std::{ use itertools::Itertools; use pyo3::{ - types::{PyDict, PyList}, - PyObject, PyResult, Python, + types::{PyDict, PyList}, PyObject, PyResult, Python }; use super::{symbol_resolver::InnerResolver, timeline::TimeFns}; @@ -38,7 +37,7 @@ use nac3core::{ toplevel::{ helper::{extract_ndims, PrimDef}, numpy::unpack_ndarray_var_tys, - DefinitionId, GenCall, + DefinitionId, GenCall, TopLevelDef, }, typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, }; @@ -982,7 +981,8 @@ pub fn attributes_writeback<'ctx>( values.push((ty, obj.to_basic_value_enum(ctx, generator, ty).unwrap())); } - for val in (*globals).values() { + // For now the global variables are just values so completely useless (14.0 stored as a float object in globals) + for (gloabl_id, val) in &*globals { let val = val.as_ref(py); let ty = inner_resolver.get_obj_type( py, @@ -995,6 +995,25 @@ pub fn attributes_writeback<'ctx>( return Ok(Err(ty)); } let ty = ty.unwrap(); + + if let Some(def_id) = inner_resolver.pyid_to_def.read().get(gloabl_id) { + if let TopLevelDef::Variable { name, .. } = &*top_levels[def_id.0].read() { + // println!("[+] Varaible with type: {:?}\n{:?}\n", ctx.unifier.stringify(*ty), ctx.unifier.get_ty(*ty)); + println!("Sending Value of {:?}", val.to_string()); + if gen_rpc_tag(ctx, ty, &mut scratch_buffer).is_ok() { + let pydict = PyDict::new(py); + pydict.set_item("global", val)?; + pydict.set_item("name", name)?; + host_attributes.append(pydict)?; + values.push(( + ty, + inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(), + )); + } + continue; + } + } + match &*ctx.unifier.get_ty(ty) { TypeEnum::TObj { fields, obj_id, .. } if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => @@ -1003,6 +1022,15 @@ pub fn attributes_writeback<'ctx>( // for non-primitive attributes, they should be in another global let mut attributes = Vec::new(); let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); + + if !obj.is_pointer_value() && gen_rpc_tag(ctx, ty, &mut scratch_buffer).is_ok() { + println!("[-] Other function skipped"); + // values.push((ty, obj)); + // let pydict = PyDict::new(py); + // pydict.set_item("global", val)?; + // host_attributes.append(pydict)?; + // continue; + } for (name, (field_ty, is_mutable)) in fields { if !is_mutable { continue; diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 963e209c..d2b002c9 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -228,7 +228,17 @@ impl Nac3 { } }) } - StmtKind::AnnAssign { .. } => true, + // Allow global declaration with Kernel[X] to pass across + // These are kept in sync between kernel and host + // KernelInvariants are constants, and hence not included here + StmtKind::AnnAssign { ref annotation, .. } => { + match &annotation.node { + ExprKind::Subscript { value, .. } + if matches!(&value.node, ExprKind::Name { id, .. } + if id.to_string().as_str() == "Kernel") => true, + _ => false + } + }, _ => false, }; diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 82767ae1..2b1918fa 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1947,6 +1947,103 @@ impl TopLevelComposer { let temp_def_list = self.extract_def_list(); let unifier = &mut self.unifier; let primitives_store = &self.primitives_ty; + + + // let dummy_field_type = unifier.get_dummy_var().ty; + + // let annotation = match value { + // None => { + // // handle Kernel[T], KernelInvariant[T] + // let (annotation, mutable) = match &annotation.node { + // ExprKind::Subscript { value, slice, .. } + // if matches!( + // &value.node, + // ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into() + // ) => + // { + // (slice, false) + // } + // ExprKind::Subscript { value, slice, .. } + // if matches!( + // &value.node, + // ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into()) + // ) => + // { + // (slice, true) + // } + // _ if core_config.kernel_ann.is_none() => (annotation, true), + // _ => continue, // ignore fields annotated otherwise + // }; + // class_fields_def.push((*attr, dummy_field_type, mutable)); + // annotation + // } + // // Supporting Class Attributes + // Some(boxed_expr) => { + // // Class attributes are set as immutable regardless + // let (annotation, _) = match &annotation.node { + // ExprKind::Subscript { slice, .. } => (slice, false), + // _ if core_config.kernel_ann.is_none() => (annotation, false), + // _ => continue, + // }; + + // match &**boxed_expr { + // ast::Located {location: _, custom: (), node: ExprKind::Constant { value: v, kind: _ }} => { + // // Restricting the types allowed to be defined as class attributes + // match v { + // ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {} + // _ => { + // return Err(HashSet::from([ + // format!( + // "unsupported statement in class definition body (at {})", + // b.location + // ), + // ])) + // } + // } + // class_attributes_def.push((*attr, dummy_field_type, v.clone())); + // } + // _ => { + // return Err(HashSet::from([ + // format!( + // "unsupported statement in class definition body (at {})", + // b.location + // ), + // ])) + // } + // } + // annotation + // } + // }; + // let parsed_annotation = parse_ast_to_type_annotation_kinds( + // class_resolver, + // temp_def_list, + // unifier, + // primitives, + // annotation.as_ref(), + // vec![(class_id, class_type_vars_def.clone())] + // .into_iter() + // .collect::>(), + // )?; + // // find type vars within this return type annotation + // let type_vars_within = + // get_type_var_contained_in_type_annotation(&parsed_annotation); + // // handle the class type var and the method type var + // for type_var_within in type_vars_within { + // let TypeAnnotation::TypeVar(t) = type_var_within else { + // unreachable!("must be type var annotation") + // }; + + // if !class_type_vars_def.contains(&t){ + // return Err(HashSet::from([ + // format!( + // "class fields can only use type \ + // vars over which the class is generic (at {})", + // annotation.location + // ), + // ])) + // } + // } + // type_var_to_concrete_def.insert(dummy_field_type, parsed_annotation); let mut analyze = |variable_def: &Arc>| -> Result<_, HashSet> { let TopLevelDef::Variable { ty: dummy_ty, ty_decl, resolver, loc, .. } = @@ -1955,10 +2052,23 @@ impl TopLevelComposer { // not top level variable def, skip return Ok(()); }; - + let resolver = &**resolver.as_ref().unwrap(); if let Some(ty_decl) = ty_decl { + let ty_decl = match &ty_decl.node { + ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if self.core_config.kernel_ann.map_or(false, |c| id == &c.into()) + ) => + { + slice + } + _ if self.core_config.kernel_ann.is_none() => ty_decl, + _ => unreachable!("Global variables should be annotated with Kernel[]") // ignore fields annotated otherwise + }; + let ty_annotation = parse_ast_to_type_annotation_kinds( resolver, &temp_def_list,