From 08947d20c202e8d62c0845735e7a52070351bad5 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Thu, 7 Oct 2021 15:57:45 +0800 Subject: [PATCH] nac3artiq: implements #33 --- nac3artiq/src/symbol_resolver.rs | 119 ++++++++++++++++++++++++++----- 1 file changed, 103 insertions(+), 16 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 1b1c62e9..098f86b1 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -16,6 +16,7 @@ use pyo3::{ }; use rustpython_parser::ast::StrRef; use std::{ + cell::RefCell, collections::{HashMap, HashSet}, sync::Arc, }; @@ -109,6 +110,46 @@ impl Resolver { .collect(); let types = types?; Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) + } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id) { + let def = defs[def_id.0].read(); + if let TopLevelDef::Class { + object_id, + type_vars, + fields, + methods, + .. + } = &*def + { + if type_vars.is_empty() { + let mut fields_ty = HashMap::new(); + for method in methods.iter() { + fields_ty.insert(method.0, method.1); + } + for field in fields.iter() { + let name: String = field.0.into(); + let field_data = obj.getattr(&name)?; + let ty = self + .get_obj_type(field_data, helper, unifier, defs, primitives)? + .unwrap_or(primitives.none); + if unifier.unify(ty, field.1).is_err() { + // field type mismatch + return Ok(None); + } + fields_ty.insert(field.0, ty); + } + Ok(Some(unifier.add_ty(TypeEnum::TObj { + obj_id: *object_id, + fields: RefCell::new(fields_ty), + params: Default::default(), + }))) + } else { + // type var is not supported for now + Ok(None) + } + } else { + // only object is supported, functions are not supported + Ok(None) + } } else { Ok(None) } @@ -142,20 +183,10 @@ impl Resolver { let id: u64 = helper.id_fn.call1((obj,))?.extract()?; let id_str = id.to_string(); let len: usize = helper.len_fn.call1((obj,))?.extract()?; - if len == 0 { - let int32 = ctx.ctx.i32_type(); - return Ok(Some( - ctx.ctx - .struct_type( - &[int32.into(), int32.ptr_type(AddressSpace::Generic).into()], - false, - ) - .const_zero() - .into(), - )); - } - let ty = self - .get_list_elem_type( + let ty = if len == 0 { + ctx.primitives.int32 + } else { + self.get_list_elem_type( obj, len, helper, @@ -163,7 +194,8 @@ impl Resolver { &ctx.top_level.definitions.read(), &ctx.primitives, )? - .unwrap(); + .unwrap() + }; let ty = ctx.get_llvm_type(ty); let arr_ty = ctx.ctx.struct_type( &[ @@ -294,7 +326,62 @@ impl Resolver { global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) } else { - Ok(None) + let id: u64 = helper.id_fn.call1((obj,))?.extract()?; + let id_str = id.to_string(); + let top_level_defs = ctx.top_level.definitions.read(); + let ty = self + .get_obj_type( + obj, + helper, + &mut ctx.unifier, + &top_level_defs, + &ctx.primitives, + )? + .unwrap(); + let ty = ctx + .get_llvm_type(ty) + .into_pointer_type() + .get_element_type() + .into_struct_type() + .as_basic_type_enum(); + { + let mut global_value_ids = self.global_value_ids.lock(); + if global_value_ids.contains(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module + .add_global(ty, Some(AddressSpace::Generic), &id_str) + }); + return Ok(Some(global.as_pointer_value().into())); + } else { + global_value_ids.insert(id); + } + } + // should be classes + let definition = top_level_defs + .get(self.pyid_to_def.read().get(&ty_id).unwrap().0) + .unwrap() + .read(); + if let TopLevelDef::Class { fields, .. } = &*definition { + let values: Result>, _> = fields + .iter() + .map(|(name, _)| { + self.get_obj_value(obj.getattr(&name.to_string())?, helper, ctx) + }) + .collect(); + let values = values?; + if let Some(values) = values { + let val = ctx.ctx.const_struct(&values, false); + let global = ctx + .module + .add_global(ty, Some(AddressSpace::Generic), &id_str); + global.set_initializer(&val); + Ok(Some(global.as_pointer_value().into())) + } else { + Ok(None) + } + } else { + unreachable!() + } } } }