diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 46ee93a..405645d 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -78,7 +78,7 @@ struct Nac3 { pyid_to_def: Arc>>, pyid_to_type: Arc>>, primitive_ids: PrimitivePythonId, - global_value_ids: Arc>>, + global_value_ids: Arc>>, working_directory: TempDir, top_levels: Vec, } @@ -417,6 +417,9 @@ impl Nac3 { class_names: Default::default(), name_to_pyid: name_to_pyid.clone(), module: module.clone(), + id_to_pyval: Default::default(), + id_to_primitive: Default::default(), + field_to_val: Default::default(), helper, }))) as Arc; @@ -468,15 +471,6 @@ impl Nac3 { ) }; let mut synthesized = parse_program(&synthesized).unwrap(); - let builtins = PyModule::import(py, "builtins")?; - let typings = PyModule::import(py, "typing")?; - let helper = PythonHelper { - id_fn: builtins.getattr("id").unwrap().to_object(py), - len_fn: builtins.getattr("len").unwrap().to_object(py), - type_fn: builtins.getattr("type").unwrap().to_object(py), - origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py), - args_ty_fn: typings.getattr("get_args").unwrap().to_object(py), - }; let resolver = Arc::new(Resolver(Arc::new(InnerResolver { id_to_type: self.builtins_ty.clone().into(), id_to_def: self.builtins_def.clone().into(), @@ -485,6 +479,9 @@ impl Nac3 { primitive_ids: self.primitive_ids.clone(), global_value_ids: self.global_value_ids.clone(), class_names: Default::default(), + id_to_pyval: Default::default(), + id_to_primitive: Default::default(), + field_to_val: Default::default(), name_to_pyid, module: module.to_object(py), helper, diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 0eeedf2..e42db75 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -12,7 +12,7 @@ use nac3core::{ use nac3parser::ast::{self, StrRef}; use parking_lot::{Mutex, RwLock}; use pyo3::{ - types::{PyList, PyModule, PyTuple}, + types::{PyDict, PyTuple}, PyAny, PyObject, PyResult, Python, }; use std::{ @@ -23,10 +23,20 @@ use std::{ use crate::PrimitivePythonId; +pub enum PrimitiveValue { + I32(i32), + I64(i64), + F64(f64), + Bool(bool), +} + pub struct InnerResolver { - pub id_to_type: Mutex>, - pub id_to_def: Mutex>, - pub global_value_ids: Arc>>, + pub id_to_type: RwLock>, + pub id_to_def: RwLock>, + pub id_to_pyval: RwLock>, + pub id_to_primitive: RwLock>, + pub field_to_val: RwLock>>, + pub global_value_ids: Arc>>, pub class_names: Mutex>, pub pyid_to_def: Arc>>, pub pyid_to_type: Arc>>, @@ -63,6 +73,20 @@ impl StaticValue for PythonValue { &self, ctx: &mut CodeGenContext<'ctx, 'a>, ) -> BasicValueEnum<'ctx> { + if let Some(val) = self.resolver.id_to_primitive.read().get(&self.id) { + return match val { + PrimitiveValue::I32(val) => ctx.ctx.i32_type().const_int(*val as u64, false).into(), + PrimitiveValue::I64(val) => ctx.ctx.i64_type().const_int(*val as u64, false).into(), + PrimitiveValue::F64(val) => ctx.ctx.f64_type().const_float(*val).into(), + PrimitiveValue::Bool(val) => { + ctx.ctx.bool_type().const_int(*val as u64, false).into() + } + }; + } + if let Some(global) = ctx.module.get_global(&self.id.to_string()) { + return global.as_pointer_value().into(); + } + Python::with_gil(|py| -> PyResult> { self.resolver .get_obj_value(py, self.value.as_ref(py), ctx) @@ -76,34 +100,48 @@ impl StaticValue for PythonValue { name: StrRef, ctx: &mut CodeGenContext<'ctx, 'a>, ) -> Option> { - Python::with_gil(|py| -> PyResult>> { - let helper = &self.resolver.helper; - let ty = helper.type_fn.call1(py, (&self.value,))?; - let ty_id: u64 = helper.id_fn.call1(py, (ty,))?.extract(py)?; - let def_id = { *self.resolver.pyid_to_def.read().get(&ty_id).unwrap() }; - let mut mutable = true; - let defs = ctx.top_level.definitions.read(); - if let TopLevelDef::Class { fields, .. } = &*defs[def_id.0].read() { - for (field_name, _, is_mutable) in fields.iter() { - if field_name == &name { - mutable = *is_mutable; - break; + { + let field_to_val = self.resolver.field_to_val.read(); + field_to_val.get(&(self.id, name)).cloned() + } + .unwrap_or_else(|| { + Python::with_gil(|py| -> PyResult> { + let helper = &self.resolver.helper; + let ty = helper.type_fn.call1(py, (&self.value,))?; + let ty_id: u64 = helper.id_fn.call1(py, (ty,))?.extract(py)?; + let def_id = { *self.resolver.pyid_to_def.read().get(&ty_id).unwrap() }; + let mut mutable = true; + let defs = ctx.top_level.definitions.read(); + if let TopLevelDef::Class { fields, .. } = &*defs[def_id.0].read() { + for (field_name, _, is_mutable) in fields.iter() { + if field_name == &name { + mutable = *is_mutable; + break; + } } } - } - Ok(if mutable { - None - } else { - let obj = self.value.getattr(py, &name.to_string())?; - let id = self.resolver.helper.id_fn.call1(py, (&obj,))?.extract(py)?; - Some(ValueEnum::Static(Arc::new(PythonValue { - id, - value: obj, - resolver: self.resolver.clone(), - }))) + let result = if mutable { + None + } else { + let obj = self.value.getattr(py, &name.to_string())?; + let id = self.resolver.helper.id_fn.call1(py, (&obj,))?.extract(py)?; + Some((id, obj)) + }; + self.resolver + .field_to_val + .write() + .insert((self.id, name), result.clone()); + Ok(result) }) + .unwrap() + }) + .map(|(id, obj)| { + ValueEnum::Static(Arc::new(PythonValue { + id, + value: obj, + resolver: self.resolver.clone(), + })) }) - .unwrap() } } @@ -148,11 +186,7 @@ impl InnerResolver { defs: &[Arc>], primitives: &PrimitiveStore, ) -> PyResult> { - let ty_id: u64 = self - .helper - .id_fn - .call1(py, (pyty,))? - .extract(py)?; + let ty_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; let ty_ty_id: u64 = self .helper .id_fn @@ -198,7 +232,8 @@ impl InnerResolver { } else { unreachable!() } - }).collect() + }) + .collect() }), fields: RefCell::new({ let mut res = methods @@ -207,7 +242,7 @@ impl InnerResolver { .collect::>(); res.extend(fields.clone().into_iter().map(|x| (x.0, (x.1, x.2)))); res - }) + }), }; // here also false, later instantiation use python object to check compatible (unifier.add_ty(ty), false) @@ -232,10 +267,10 @@ impl InnerResolver { "the {}th constraint of TypeVar `{}` is not concrete", i + 1, pyty.getattr("__name__")?.extract::()? - ))) + ))); } - }, - Err(err) => return Ok(Err(err)) + } + Err(err) => return Ok(Err(err)), } }) } else { @@ -246,31 +281,45 @@ impl InnerResolver { }; let res = unifier.get_fresh_var_with_range(&constraint_types).0; Ok(Ok((res, true))) - } else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 { + } else if ty_ty_id == self.primitive_ids.generic_alias.0 + || ty_ty_id == self.primitive_ids.generic_alias.1 + { let origin = self.helper.origin_ty_fn.call1(py, (pyty,))?; let args = self.helper.args_ty_fn.call1(py, (pyty,))?; let args: &PyTuple = args.cast_as(py)?; - let origin_ty = match self.get_pyty_obj_type(py, origin.as_ref(py), unifier, defs, primitives)? { - Ok((ty, false)) => ty, - Ok((_, true)) => return Ok(Err("instantiated type does not take type parameters".into())), - Err(err) => return Ok(Err(err)) - }; + let origin_ty = + match self.get_pyty_obj_type(py, origin.as_ref(py), unifier, defs, primitives)? { + Ok((ty, false)) => ty, + Ok((_, true)) => { + return Ok(Err("instantiated type does not take type parameters".into())) + } + Err(err) => return Ok(Err(err)), + }; match &*unifier.get_ty(origin_ty) { TypeEnum::TList { .. } => { if args.len() == 1 { - let ty = match self.get_pyty_obj_type(py, args.get_item(0), unifier, defs, primitives)? { + let ty = match self.get_pyty_obj_type( + py, + args.get_item(0), + unifier, + defs, + primitives, + )? { Ok(ty) => ty, - Err(err) => return Ok(Err(err)) + Err(err) => return Ok(Err(err)), }; if !unifier.is_concrete(ty.0, &[]) && !ty.1 { panic!("type list should take concrete parameters in type var ranges") } Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true))) } else { - return Ok(Err(format!("type list needs exactly 1 type parameters, found {}", args.len()))) + return Ok(Err(format!( + "type list needs exactly 1 type parameters, found {}", + args.len() + ))); } - }, + } TypeEnum::TTuple { .. } => { let args = match args .iter() @@ -291,7 +340,7 @@ impl InnerResolver { _ => return Ok(Err("tuple type needs at least 1 type parameters".to_string())) }; Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true))) - }, + } TypeEnum::TObj { params, obj_id, .. } => { let subst = { let params = &*params.borrow(); @@ -301,7 +350,7 @@ impl InnerResolver { obj_id.0, params.len(), args.len(), - ))) + ))); } let args = match args .iter() @@ -326,29 +375,48 @@ impl InnerResolver { .map(|((id, _), ty)| (*id, *ty)) .collect::>() }; - Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true))) - }, + Ok(Ok(( + unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), + true, + ))) + } TypeEnum::TVirtual { .. } => { if args.len() == 1 { - let ty = match self.get_pyty_obj_type(py, args.get_item(0), unifier, defs, primitives)? { + let ty = match self.get_pyty_obj_type( + py, + args.get_item(0), + unifier, + defs, + primitives, + )? { Ok(ty) => ty, - Err(err) => return Ok(Err(err)) + Err(err) => return Ok(Err(err)), }; if !unifier.is_concrete(ty.0, &[]) && !ty.1 { - panic!("virtual class should take concrete parameters in type var ranges") + panic!( + "virtual class should take concrete parameters in type var ranges" + ) } Ok(Ok((unifier.add_ty(TypeEnum::TVirtual { ty: ty.0 }), true))) } else { - return Ok(Err(format!("virtual class needs exactly 1 type parameters, found {}", args.len()))) + return Ok(Err(format!( + "virtual class needs exactly 1 type parameters, found {}", + args.len() + ))); } } - _ => unimplemented!() + _ => unimplemented!(), } } else if ty_id == self.primitive_ids.virtual_id { - Ok(Ok(({ - let ty = TypeEnum::TVirtual { ty: unifier.get_fresh_var().0 }; - unifier.add_ty(ty) - }, false))) + Ok(Ok(( + { + let ty = TypeEnum::TVirtual { + ty: unifier.get_fresh_var().0, + }; + unifier.add_ty(ty) + }, + false, + ))) } else { Ok(Err("unknown type".into())) } @@ -366,10 +434,18 @@ impl InnerResolver { let (extracted_ty, inst_check) = match self.get_pyty_obj_type( py, { - if [self.primitive_ids.typevar, + if [ + self.primitive_ids.typevar, self.primitive_ids.generic_alias.0, - self.primitive_ids.generic_alias.1 - ].contains(&self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)?) { + self.primitive_ids.generic_alias.1, + ] + .contains( + &self + .helper + .id_fn + .call1(py, (ty.clone(),))? + .extract::(py)?, + ) { obj } else { ty.as_ref(py) @@ -377,10 +453,10 @@ impl InnerResolver { }, unifier, defs, - primitives + primitives, )? { Ok(s) => s, - Err(_) => return Ok(None) + Err(_) => return Ok(None), }; return match (&*unifier.get_ty(extracted_ty), inst_check) { // do the instantiation for these three types @@ -394,8 +470,8 @@ impl InnerResolver { )); Ok(Some(extracted_ty)) } else { - let actual_ty = self - .get_list_elem_type(py, obj, len, unifier, defs, primitives)?; + let actual_ty = + self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; if let Some(actual_ty) = actual_ty { unifier.unify(*ty, actual_ty).unwrap(); Ok(Some(extracted_ty)) @@ -429,14 +505,14 @@ impl InnerResolver { // loop through non-function fields of the class to get the instantiated value for field in fields.borrow().iter() { let name: String = (*field.0).into(); - if let TypeEnum::TFunc( .. ) = &*unifier.get_ty(field.1.0) { + if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1 .0) { continue; } else { let field_data = obj.getattr(&name)?; let ty = self .get_obj_type(py, field_data, unifier, defs, primitives)? .unwrap_or(primitives.none); - let field_ty = unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0); + let field_ty = unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0); if unifier.unify(ty, field_ty).is_err() { // field type mismatch return Ok(None); @@ -446,12 +522,16 @@ impl InnerResolver { for (_, ty) in var_map.iter() { // must be concrete type if !unifier.is_concrete(*ty, &[]) { - return Ok(None) + return Ok(None); } } - return Ok(Some(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty))); + return Ok(Some( + unifier + .subst(extracted_ty, &var_map) + .unwrap_or(extracted_ty), + )); } - _ => Ok(Some(extracted_ty)) + _ => Ok(Some(extracted_ty)), }; } @@ -466,23 +546,40 @@ impl InnerResolver { .id_fn .call1(py, (self.helper.type_fn.call1(py, (obj,))?,))? .extract(py)?; + let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { let val: i32 = obj.extract()?; + self.id_to_primitive + .write() + .insert(id, PrimitiveValue::I32(val)); Ok(Some(ctx.ctx.i32_type().const_int(val as u64, false).into())) } else if ty_id == self.primitive_ids.int64 { let val: i64 = obj.extract()?; + self.id_to_primitive + .write() + .insert(id, PrimitiveValue::I64(val)); Ok(Some(ctx.ctx.i64_type().const_int(val as u64, false).into())) } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract()?; + self.id_to_primitive + .write() + .insert(id, PrimitiveValue::Bool(val)); Ok(Some( ctx.ctx.bool_type().const_int(val as u64, false).into(), )) } else if ty_id == self.primitive_ids.float { let val: f64 = obj.extract()?; + self.id_to_primitive + .write() + .insert(id, PrimitiveValue::F64(val)); Ok(Some(ctx.ctx.f64_type().const_float(val).into())) } else if ty_id == self.primitive_ids.list { - let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; let id_str = id.to_string(); + + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; let ty = if len == 0 { ctx.primitives.int32 @@ -507,15 +604,14 @@ impl InnerResolver { ); { - let mut global_value_ids = self.global_value_ids.lock(); - if global_value_ids.contains(&id) { + if self.global_value_ids.read().contains(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module .add_global(arr_ty, Some(AddressSpace::Generic), &id_str) }); return Ok(Some(global.as_pointer_value().into())); } else { - global_value_ids.insert(id); + self.global_value_ids.write().insert(id); } } @@ -583,8 +679,12 @@ impl InnerResolver { Ok(Some(global.as_pointer_value().into())) } else if ty_id == self.primitive_ids.tuple { - let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; let id_str = id.to_string(); + + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + let elements: &PyTuple = obj.cast_as()?; let types: Result>, _> = elements .iter() @@ -603,15 +703,14 @@ impl InnerResolver { let ty = ctx.ctx.struct_type(&types, false); { - let mut global_value_ids = self.global_value_ids.lock(); - if global_value_ids.contains(&id) { + if self.global_value_ids.read().contains(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module .add_global(ty, Some(AddressSpace::Generic), &id_str) }); return Ok(Some(global.as_pointer_value().into())); } else { - global_value_ids.insert(id); + self.global_value_ids.write().insert(id); } } @@ -627,8 +726,12 @@ impl InnerResolver { global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) } else { - let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; let id_str = id.to_string(); + + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + let top_level_defs = ctx.top_level.definitions.read(); let ty = self .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? @@ -639,16 +742,16 @@ impl InnerResolver { .get_element_type() .into_struct_type() .as_basic_type_enum(); + { - let mut global_value_ids = self.global_value_ids.lock(); - if global_value_ids.contains(&id) { + if self.global_value_ids.read().contains(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module .add_global(ty, Some(AddressSpace::Generic), &id_str) }); return Ok(Some(global.as_pointer_value().into())); } else { - global_value_ids.insert(id); + self.global_value_ids.write().insert(id); } } // should be classes @@ -727,14 +830,10 @@ impl SymbolResolver for Resolver { ast::ExprKind::Name { id, .. } => { Python::with_gil(|py| -> PyResult> { let obj: &PyAny = self.0.module.extract(py)?; - let members: &PyList = PyModule::import(py, "inspect")? - .getattr("getmembers")? - .call1((obj,))? - .cast_as()?; + let members: &PyDict = obj.getattr("__dict__").unwrap().cast_as().unwrap(); let mut sym_value = None; - for member in members.iter() { - let key: &str = member.get_item(0)?.extract()?; - let val = member.get_item(1)?; + for (key, val) in members.iter() { + let key: &str = key.extract()?; if key == id.to_string() { sym_value = Some( self.0 @@ -760,38 +859,40 @@ impl SymbolResolver for Resolver { primitives: &PrimitiveStore, str: StrRef, ) -> Option { - let mut id_to_type = self.0.id_to_type.lock(); - id_to_type.get(&str).cloned().or_else(|| { + { + let id_to_type = self.0.id_to_type.read(); + id_to_type.get(&str).cloned() + } + .or_else(|| { let py_id = self.0.name_to_pyid.get(&str); let result = py_id.and_then(|id| { - self.0.pyid_to_type.read().get(id).copied().or_else(|| { - Python::with_gil(|py| -> PyResult> { + { + let pyid_to_type = self.0.pyid_to_type.read(); + pyid_to_type.get(id).copied() + } + .or_else(|| { + let result = Python::with_gil(|py| -> PyResult> { let obj: &PyAny = self.0.module.extract(py)?; - let members: &PyList = PyModule::import(py, "inspect")? - .getattr("getmembers")? - .call1((obj,))? - .cast_as()?; let mut sym_ty = None; - for member in members.iter() { - let key: &str = member.get_item(0)?.extract()?; + let members: &PyDict = obj.getattr("__dict__").unwrap().cast_as().unwrap(); + for (key, val) in members.iter() { + let key: &str = key.extract()?; if key == str.to_string() { - sym_ty = self.0.get_obj_type( - py, - member.get_item(1)?, - unifier, - defs, - primitives, - )?; + sym_ty = self.0.get_obj_type(py, val, unifier, defs, primitives)?; break; } } Ok(sym_ty) }) - .unwrap() + .unwrap(); + if let Some(result) = result { + self.0.pyid_to_type.write().insert(*id, result); + } + result }) }); if let Some(result) = &result { - id_to_type.insert(str, *result); + self.0.id_to_type.write().insert(str, *result); } result }) @@ -802,29 +903,37 @@ impl SymbolResolver for Resolver { id: StrRef, _: &mut CodeGenContext<'ctx, 'a>, ) -> Option> { - Python::with_gil(|py| -> PyResult>> { - let obj: &PyAny = self.0.module.extract(py)?; - let members: &PyList = PyModule::import(py, "inspect")? - .getattr("getmembers")? - .call1((obj,))? - .cast_as()?; - let mut sym_value = None; - for member in members.iter() { - let key: &str = member.get_item(0)?.extract()?; - let val = member.get_item(1)?; - if key == id.to_string() { - let id = self.0.helper.id_fn.call1(py, (val,))?.extract(py)?; - sym_value = Some(PythonValue { - id, - value: val.extract()?, - resolver: self.0.clone(), - }); - break; + let sym_value = { + let id_to_val = self.0.id_to_pyval.read(); + id_to_val.get(&id).cloned() + } + .or_else(|| { + Python::with_gil(|py| -> PyResult> { + let obj: &PyAny = self.0.module.extract(py)?; + let mut sym_value: Option<(u64, PyObject)> = None; + let members: &PyDict = obj.getattr("__dict__").unwrap().cast_as().unwrap(); + for (key, val) in members.iter() { + let key: &str = key.extract()?; + if key == id.to_string() { + let id = self.0.helper.id_fn.call1(py, (val,))?.extract(py)?; + sym_value = Some((id, val.extract()?)); + break; + } } - } - Ok(sym_value.map(|v| ValueEnum::Static(Arc::new(v)))) + if let Some((pyid, val)) = &sym_value { + self.0.id_to_pyval.write().insert(id, (*pyid, val.clone())); + } + Ok(sym_value) + }) + .unwrap() + }); + sym_value.map(|(id, v)| { + ValueEnum::Static(Arc::new(PythonValue { + id, + value: v, + resolver: self.0.clone(), + })) }) - .unwrap() } fn get_symbol_location(&self, _: StrRef) -> Option { @@ -832,14 +941,17 @@ impl SymbolResolver for Resolver { } fn get_identifier_def(&self, id: StrRef) -> Option { - let mut id_to_def = self.0.id_to_def.lock(); - id_to_def.get(&id).cloned().or_else(|| { + { + let id_to_def = self.0.id_to_def.read(); + id_to_def.get(&id).cloned() + } + .or_else(|| { let py_id = self.0.name_to_pyid.get(&id); let result = py_id.and_then(|id| self.0.pyid_to_def.read().get(id).copied()); if let Some(result) = &result { - id_to_def.insert(id, *result); + self.0.id_to_def.write().insert(id, *result); } result }) } -} \ No newline at end of file +}