diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index c63be8e3..c4e8fc77 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -87,6 +87,7 @@ struct Nac3 { working_directory: TempDir, top_levels: Vec<TopLevelComponent>, string_store: Arc<RwLock<HashMap<String, i32>>>, + exception_ids: Arc<RwLock<HashMap<usize, usize>>>, } create_exception!(nac3artiq, CompileError, exceptions::PyException); @@ -395,6 +396,7 @@ impl Nac3 { pyid_to_def: Default::default(), working_directory, string_store: Default::default(), + exception_ids: Default::default(), }) } @@ -442,6 +444,8 @@ impl Nac3 { let builtins = PyModule::import(py, "builtins")?; let typings = PyModule::import(py, "typing")?; let id_fn = builtins.getattr("id")?; + let issubclass = builtins.getattr("issubclass")?; + let exn_class = builtins.getattr("Exception")?; let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py); let store_str = embedding_map.getattr("store_str").unwrap().to_object(py); let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py); @@ -451,7 +455,7 @@ impl Nac3 { type_fn: builtins.getattr("type").unwrap().to_object(py), origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py), args_ty_fn: typings.getattr("get_args").unwrap().to_object(py), - store_obj, + store_obj: store_obj.clone(), store_str, }; let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new(); @@ -463,6 +467,18 @@ impl Nac3 { let py_module: &PyAny = module.extract(py)?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?; let helper = helper.clone(); + let class_obj; + if let StmtKind::ClassDef { name, .. } = &stmt.node { + let class = py_module.getattr(name.to_string()).unwrap(); + if issubclass.call1((class, exn_class)).unwrap().extract().unwrap() && + class.getattr("artiq_builtin").is_err() { + class_obj = Some(class); + } else { + class_obj = None; + } + } else { + class_obj = None; + } let (name_to_pyid, resolver) = module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| { let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new(); @@ -488,6 +504,7 @@ impl Nac3 { field_to_val: Default::default(), helper, string_store: self.string_store.clone(), + exception_ids: self.exception_ids.clone(), }))) as Arc<dyn SymbolResolver + Send + Sync>; let name_to_pyid = Rc::new(name_to_pyid); @@ -504,6 +521,9 @@ impl Nac3 { e )) })?; + if let Some(class_obj) = class_obj { + self.exception_ids.write().insert(def_id.0, store_obj.call1(py, (class_obj, ))?.extract(py)?); + } match &stmt.node { StmtKind::FunctionDef { decorator_list, .. } => { @@ -569,6 +589,7 @@ impl Nac3 { module: module.to_object(py), helper, string_store: self.string_store.clone(), + exception_ids: self.exception_ids.clone(), }))) as Arc<dyn SymbolResolver + Send + Sync>; let (_, def_id, _) = composer .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "".into()) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index e104ec56..d59b0fac 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -41,6 +41,7 @@ pub struct InnerResolver { pub primitive_ids: PrimitivePythonId, pub helper: PythonHelper, pub string_store: Arc<RwLock<HashMap<String, i32>>>, + pub exception_ids: Arc<RwLock<HashMap<usize, usize>>>, // module specific pub name_to_pyid: HashMap<StrRef, u64>, pub module: PyObject, @@ -977,4 +978,9 @@ impl SymbolResolver for Resolver { id } } + + fn get_exception_id(&self, tyid: usize) -> usize { + let exn_ids = self.0.exception_ids.read(); + exn_ids.get(&tyid).cloned().unwrap_or(0) + } } diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 96b39972..1bc72857 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -468,7 +468,7 @@ pub fn exn_constructor<'ctx, 'a>( let def = defs[zelf_id].read(); let zelf_name = if let TopLevelDef::Class { name, .. } = &*def { *name } else { unreachable!() }; - let exception_name = format!("0:{}", zelf_name); + let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name); unsafe { let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id"); let id = ctx.resolver.get_string_id(&exception_name); @@ -641,7 +641,13 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( &mut ctx.unifier, type_.custom.unwrap(), ); - let exn_id = ctx.resolver.get_string_id(&format!("0:{}", exn_name)); + let obj_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) { + *obj_id + } else { + unreachable!() + }; + let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name); + let exn_id = ctx.resolver.get_string_id(&exception_name); let exn_id_global = ctx.module.add_global(ctx.ctx.i32_type(), None, &format!("exn.{}", exn_id)); exn_id_global.set_initializer(&ctx.ctx.i32_type().const_int(exn_id as u64, false)); diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index a70c91bf..0a22cc06 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -136,6 +136,7 @@ pub trait SymbolResolver { fn get_default_param_value(&self, expr: &nac3parser::ast::Expr) -> Option<SymbolValue>; fn get_string_id(&self, s: &str) -> i32; + fn get_exception_id(&self, tyid: usize) -> usize; // handle function call etc. } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index bf36adcd..694d0a06 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1281,6 +1281,7 @@ impl TopLevelComposer { )); } } + ast::StmtKind::Assign { .. } => {}, // we don't class attributes ast::StmtKind::Pass { .. } => {} ast::StmtKind::Expr { value: _, .. } => {} // typically a docstring; ignoring all expressions matches CPython behavior _ => {