diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index a4cca53..a5f42aa 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -425,7 +425,7 @@ fn rpc_codegen_callback_fn<'ctx>( if obj.is_some() { tag.push(b'O'); } - for arg in fun.0.args.iter() { + for arg in &fun.0.args { gen_rpc_tag(ctx, arg.ty, &mut tag)?; } tag.push(b':'); @@ -461,7 +461,7 @@ fn rpc_codegen_callback_fn<'ctx>( }) .as_pointer_value(); - let arg_length = args.len() + if obj.is_some() { 1 } else { 0 }; + let arg_length = args.len() + usize::from(obj.is_some()); let stacksave = ctx.module.get_function("llvm.stacksave").unwrap_or_else(|| { ctx.module.add_function("llvm.stacksave", ptr_type.fn_type(&[], false), None) @@ -484,11 +484,11 @@ fn rpc_codegen_callback_fn<'ctx>( // -- rpc args handling let mut keys = fun.0.args.clone(); let mut mapping = HashMap::new(); - for (key, value) in args.into_iter() { + for (key, value) in args { mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); } // default value handling - for k in keys.into_iter() { + for k in keys { mapping.insert( k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into() @@ -518,7 +518,7 @@ fn rpc_codegen_callback_fn<'ctx>( ctx.builder.build_gep( args_ptr, &[int32.const_int(i as u64, false)], - &format!("rpc.arg{}", i), + &format!("rpc.arg{i}"), ) }; ctx.builder.build_store(arg_ptr, arg_slot); @@ -621,7 +621,7 @@ pub fn attributes_writeback( ctx: &mut CodeGenContext<'_, '_>, generator: &mut dyn CodeGenerator, inner_resolver: &InnerResolver, - host_attributes: PyObject, + host_attributes: &PyObject, ) -> Result<(), String> { Python::with_gil(|py| -> PyResult> { let host_attributes: &PyList = host_attributes.downcast(py)?; @@ -631,7 +631,7 @@ pub fn attributes_writeback( let zero = int32.const_zero(); let mut values = Vec::new(); let mut scratch_buffer = Vec::new(); - for (_, val) in globals.iter() { + for val in (*globals).values() { let val = val.as_ref(py); let ty = inner_resolver.get_obj_type(py, val, &mut ctx.unifier, &top_levels, &ctx.primitives)?; if let Err(ty) = ty { @@ -646,7 +646,7 @@ pub fn attributes_writeback( // for non-primitive attributes, they should be in another global let mut attributes = Vec::new(); let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); - for (name, (field_ty, is_mutable)) in fields.iter() { + for (name, (field_ty, is_mutable)) in fields { if !is_mutable { continue } @@ -683,7 +683,7 @@ pub fn attributes_writeback( default_value: None }).collect(), ret: ctx.primitives.none, - vars: Default::default() + vars: HashMap::default() }; let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, DefinitionId(0)), args, generator) { diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 30d439e..59b3a91 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -109,7 +109,7 @@ create_exception!(nac3artiq, CompileError, exceptions::PyException); impl Nac3 { fn register_module( &mut self, - module: PyObject, + module: &PyObject, registered_class_ids: &HashSet, ) -> PyResult<()> { let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { @@ -118,12 +118,12 @@ impl Nac3 { })?; let source = fs::read_to_string(&source_file).map_err(|e| { - exceptions::PyIOError::new_err(format!("failed to read input file: {}", e)) + exceptions::PyIOError::new_err(format!("failed to read input file: {e}")) })?; let parser_result = parse_program(&source, source_file.into()) - .map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {}", e)))?; + .map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?; - for mut stmt in parser_result.into_iter() { + for mut stmt in parser_result { let include = match stmt.node { StmtKind::ClassDef { ref decorator_list, ref mut body, ref mut bases, .. @@ -197,7 +197,7 @@ impl Nac3 { fn report_modinit( arg_names: &[String], method_name: &str, - resolver: Arc, + resolver: &Arc, top_level_defs: &[Arc>], unifier: &mut Unifier, primitives: &PrimitiveStore, @@ -205,7 +205,7 @@ impl Nac3 { let base_ty = match resolver.get_symbol_type(unifier, top_level_defs, primitives, "base".into()) { Ok(ty) => ty, - Err(e) => return Some(format!("type error inside object launching kernel: {}", e)), + Err(e) => return Some(format!("type error inside object launching kernel: {e}")), }; let fun_ty = if method_name.is_empty() { @@ -215,8 +215,7 @@ impl Nac3 { Some(t) => t.0, None => { return Some(format!( - "object launching kernel does not have method `{}`", - method_name + "object launching kernel does not have method `{method_name}`" )) } } @@ -237,8 +236,7 @@ impl Nac3 { Some(n) => n, None if default_value.is_none() => { return Some(format!( - "argument `{}` not provided when launching kernel function", - name + "argument `{name}` not provided when launching kernel function" )) } _ => break, @@ -252,8 +250,7 @@ impl Nac3 { Ok(t) => t, Err(e) => { return Some(format!( - "type error ({}) at parameter #{} when calling kernel function", - e, i + "type error ({e}) at parameter #{i} when calling kernel function" )) } }; @@ -322,7 +319,7 @@ impl Nac3 { let mut module_to_resolver_cache: HashMap = HashMap::new(); let mut rpc_ids = vec![]; - for (stmt, path, module) in self.top_levels.iter() { + for (stmt, path, module) in &self.top_levels { let py_module: &PyAny = module.extract(py)?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?; let helper = helper.clone(); @@ -343,7 +340,7 @@ impl Nac3 { let mut name_to_pyid: HashMap = HashMap::new(); let members: &PyDict = py_module.getattr("__dict__").unwrap().downcast().unwrap(); - for (key, val) in members.iter() { + for (key, val) in members { let key: &str = key.extract().unwrap(); let val = id_fn.call1((val,)).unwrap().extract().unwrap(); name_to_pyid.insert(key.into(), val); @@ -355,12 +352,12 @@ impl Nac3 { pyid_to_type: pyid_to_type.clone(), primitive_ids: self.primitive_ids.clone(), global_value_ids: global_value_ids.clone(), - class_names: Default::default(), + class_names: Mutex::default(), name_to_pyid: name_to_pyid.clone(), module: module.clone(), - id_to_pyval: Default::default(), - id_to_primitive: Default::default(), - field_to_val: Default::default(), + id_to_pyval: RwLock::default(), + id_to_primitive: RwLock::default(), + field_to_val: RwLock::default(), helper, string_store: self.string_store.clone(), exception_ids: self.exception_ids.clone(), @@ -377,8 +374,7 @@ impl Nac3 { .register_top_level(stmt.clone(), Some(resolver.clone()), path, false) .map_err(|e| { CompileError::new_err(format!( - "compilation failed\n----------\n{}", - e + "compilation failed\n----------\n{e}" )) })?; if let Some(class_obj) = class_obj { @@ -395,7 +391,7 @@ impl Nac3 { StmtKind::ClassDef { name, body, .. } => { let class_name = name.to_string(); let class_obj = module.getattr(py, class_name.as_str()).unwrap(); - for stmt in body.iter() { + for stmt in body { if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node { if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) { if name == &"__init__".into() { @@ -429,7 +425,7 @@ impl Nac3 { name_to_pyid.insert("base".into(), id_fun.call1((obj,))?.extract()?); let mut arg_names = vec![]; for (i, arg) in args.into_iter().enumerate() { - let name = format!("tmp{}", i); + let name = format!("tmp{i}"); module.add(&name, arg)?; name_to_pyid.insert(name.clone().into(), id_fun.call1((arg,))?.extract()?); arg_names.push(name); @@ -448,10 +444,10 @@ impl Nac3 { pyid_to_type: pyid_to_type.clone(), primitive_ids: self.primitive_ids.clone(), global_value_ids: global_value_ids.clone(), - class_names: Default::default(), - id_to_pyval: Default::default(), - id_to_primitive: Default::default(), - field_to_val: Default::default(), + class_names: Mutex::default(), + id_to_pyval: RwLock::default(), + id_to_primitive: RwLock::default(), + field_to_val: RwLock::default(), name_to_pyid, module: module.to_object(py), helper, @@ -461,7 +457,7 @@ impl Nac3 { }); let resolver = Arc::new(Resolver(inner_resolver.clone())) as Arc; let (_, def_id, _) = composer - .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "".into(), false) + .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false) .unwrap(); let fun_signature = @@ -474,16 +470,11 @@ impl Nac3 { if let Err(e) = composer.start_analysis(true) { // report error of __modinit__ separately - return if !e.contains("") { - Err(CompileError::new_err(format!( - "compilation failed\n----------\n{}", - e - ))) - } else { + return if e.contains("") { let msg = Self::report_modinit( &arg_names, method_name, - resolver.clone(), + &resolver, &composer.extract_def_list(), &mut composer.unifier, &self.primitive, @@ -492,6 +483,10 @@ impl Nac3 { "compilation failed\n----------\n{}", msg.unwrap_or(e) ))) + } else { + Err(CompileError::new_err(format!( + "compilation failed\n----------\n{e}" + ))) } } let top_level = Arc::new(composer.make_top_level_context()); @@ -499,7 +494,7 @@ impl Nac3 { { let rpc_codegen = rpc_codegen_callback(); let defs = top_level.definitions.read(); - for (class_data, id) in rpc_ids.iter() { + for (class_data, id) in &rpc_ids { let mut def = defs[id.0].write(); match &mut *def { TopLevelDef::Function { codegen_callback, .. } => { @@ -507,7 +502,7 @@ impl Nac3 { } TopLevelDef::Class { methods, .. } => { let (class_def, method_name) = class_data.as_ref().unwrap(); - for (name, _, id) in methods.iter() { + for (name, _, id) in &*methods { if name != method_name { continue; } @@ -537,7 +532,7 @@ impl Nac3 { if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *definition { - instance_to_symbol.insert("".to_string(), "__modinit__".into()); + instance_to_symbol.insert(String::new(), "__modinit__".into()); instance_to_stmt[""].clone() } else { unreachable!() @@ -545,7 +540,7 @@ impl Nac3 { }; let task = CodeGenTask { - subst: Default::default(), + subst: Vec::default(), symbol_name: "__modinit__".to_string(), body: instance.body, signature, @@ -562,18 +557,18 @@ impl Nac3 { store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache); let signature = store.add_cty(signature); let attributes_writeback_task = CodeGenTask { - subst: Default::default(), + subst: Vec::default(), symbol_name: "attributes_writeback".to_string(), - body: Arc::new(Default::default()), + body: Arc::new(Vec::default()), signature, resolver, store, unifier_index: instance.unifier_id, - calls: Arc::new(Default::default()), + calls: Arc::new(HashMap::default()), id: 0, }; - let membuffers: Arc>>> = Default::default(); + let membuffers: Arc>>> = Arc::default(); let membuffer = membuffers.clone(); @@ -607,7 +602,7 @@ impl Nac3 { let builder = context.create_builder(); let (_, module, _) = gen_func_impl(&context, &mut generator, ®istry, builder, module, attributes_writeback_task, |generator, ctx| { - attributes_writeback(ctx, generator, inner_resolver.as_ref(), host_attributes) + attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes) }).unwrap(); let buffer = module.write_bitcode_to_memory(); let buffer = buffer.as_slice().into(); @@ -671,7 +666,7 @@ impl Nac3 { link_fn(&main) } - /// Returns the [TargetTriple] used for compiling to [isa]. + /// Returns the [`TargetTriple`] used for compiling to [isa]. fn get_llvm_target_triple(isa: Isa) -> TargetTriple { match isa { Isa::Host => TargetMachine::get_default_triple(), @@ -680,7 +675,7 @@ impl Nac3 { } } - /// Returns the [String] representing the target CPU used for compiling to [isa]. + /// Returns the [`String`] representing the target CPU used for compiling to [isa]. fn get_llvm_target_cpu(isa: Isa) -> String { match isa { Isa::Host => TargetMachine::get_host_cpu_name().to_string(), @@ -689,7 +684,7 @@ impl Nac3 { } } - /// Returns the [String] representing the target features used for compiling to [isa]. + /// Returns the [`String`] representing the target features used for compiling to [isa]. fn get_llvm_target_features(isa: Isa) -> String { match isa { Isa::Host => TargetMachine::get_host_cpu_features().to_string(), @@ -699,7 +694,7 @@ impl Nac3 { } } - /// Returns an instance of [CodeGenTargetMachineOptions] representing the target machine + /// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine /// options used for compiling to [isa]. fn get_llvm_target_options(isa: Isa) -> CodeGenTargetMachineOptions { CodeGenTargetMachineOptions { @@ -711,7 +706,7 @@ impl Nac3 { } } - /// Returns an instance of [TargetMachine] used in compiling and linking of a program to the + /// Returns an instance of [`TargetMachine`] used in compiling and linking of a program to the /// target [isa]. fn get_llvm_target_machine(&self) -> TargetMachine { Nac3::get_llvm_target_options(self.isa) @@ -790,10 +785,9 @@ impl Nac3 { _ => return Err(exceptions::PyValueError::new_err("invalid ISA")), }; let time_fns: &(dyn TimeFns + Sync) = match isa { - Isa::Host => &timeline::EXTERN_TIME_FNS, Isa::RiscV32G => &timeline::NOW_PINNING_TIME_FNS_64, Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS, - Isa::CortexA9 => &timeline::EXTERN_TIME_FNS, + Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS, }; let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; let builtins = vec![ @@ -884,7 +878,7 @@ impl Nac3 { .and_then(|v| v.get_item("_ConstGenericMarker")) .unwrap(), )) - .and_then(|v| v.extract()) + .and_then(PyAny::extract) .unwrap(), int: get_attr_id(builtins_mod, "int"), int32: get_attr_id(numpy_mod, "int32"), @@ -919,11 +913,11 @@ impl Nac3 { primitive, builtins, primitive_ids, - top_levels: Default::default(), - pyid_to_def: Default::default(), + top_levels: Vec::default(), + pyid_to_def: Arc::default(), working_directory, - string_store: Default::default(), - exception_ids: Default::default(), + string_store: Arc::default(), + exception_ids: Arc::default(), deferred_eval_store: DeferredEvaluationStore::new(), llvm_options: CodeGenLLVMOptions { opt_level: OptimizationLevel::Default, @@ -941,11 +935,11 @@ impl Nac3 { let id_fn = PyModule::import(py, "builtins")?.getattr("id")?; let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?; - for function in functions.iter() { + for function in functions { let module = getmodule_fn.call1((function,))?.extract()?; modules.insert(id_fn.call1((&module,))?.extract()?, module); } - for class in classes.iter() { + for class in classes { let module = getmodule_fn.call1((class,))?.extract()?; modules.insert(id_fn.call1((&module,))?.extract()?, module); class_ids.insert(id_fn.call1((class,))?.extract()?); @@ -954,7 +948,7 @@ impl Nac3 { })?; for module in modules.into_values() { - self.register_module(module, &class_ids)?; + self.register_module(&module, &class_ids)?; } Ok(()) } diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 31c9640..bb5f6b0 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -100,25 +100,25 @@ impl StaticValue for PythonValue { ) -> BasicValueEnum<'ctx> { ctx.module .get_global(format!("{}_const", self.id).as_str()) - .map(|val| val.as_pointer_value().into()) - .unwrap_or_else(|| { - Python::with_gil(|py| -> PyResult> { - let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?; - let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false); - let global = ctx.module.add_global( - struct_type, - None, - format!("{}_const", self.id).as_str(), - ); - global.set_constant(true); - global.set_initializer(&ctx.ctx.const_struct( - &[ctx.ctx.i32_type().const_int(id as u64, false).into()], - false, - )); - Ok(global.as_pointer_value().into()) - }) - .unwrap() - }) + .map_or_else( + || Python::with_gil(|py| -> PyResult> { + let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?; + let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false); + let global = ctx.module.add_global( + struct_type, + None, + format!("{}_const", self.id).as_str(), + ); + global.set_constant(true); + global.set_initializer(&ctx.ctx.const_struct( + &[ctx.ctx.i32_type().const_int(id as u64, false).into()], + false, + )); + Ok(global.as_pointer_value().into()) + }) + .unwrap(), + |val| val.as_pointer_value().into(), + ) } fn to_basic_value_enum<'ctx, 'a>( @@ -176,7 +176,7 @@ impl StaticValue for PythonValue { let mut mutable = true; let defs = ctx.top_level.definitions.read(); if let TopLevelDef::Class { fields, .. } = &*defs[def_id.0].read() { - for (field_name, _, is_mutable) in fields.iter() { + for (field_name, _, is_mutable) in fields { if field_name == &name { mutable = *is_mutable; break; @@ -240,7 +240,7 @@ impl InnerResolver { ) -> PyResult> { let mut ty = match self.get_obj_type(py, list.get_item(0)?, unifier, defs, primitives)? { Ok(t) => t, - Err(e) => return Ok(Err(format!("type error ({}) at element #0 of the list", e))), + Err(e) => return Ok(Err(format!("type error ({e}) at element #0 of the list"))), }; for i in 1..len { let b = match list @@ -249,11 +249,11 @@ impl InnerResolver { { Ok(t) => t, Err(e) => { - return Ok(Err(format!("type error ({}) at element #{} of the list", e, i))) + return Ok(Err(format!("type error ({e}) at element #{i} of the list"))) } }; ty = match unifier.unify(ty, b) { - Ok(_) => ty, + Ok(()) => ty, Err(e) => { return Ok(Err(format!( "inhomogeneous type ({}) at element #{i} of the list", @@ -268,7 +268,7 @@ impl InnerResolver { /// Handles python objects that represent types themselves, /// /// Primitives and class types should be themselves, use `ty_id` to check; - /// TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check. + /// `TypeVars` and `GenericAlias`(`A[int, bool]`) should use `ty_ty_id` to check. /// /// The `bool` value returned indicates whether they are instantiated or not fn get_pyty_obj_type( @@ -309,7 +309,7 @@ impl InnerResolver { Ok(Ok((primitives.option, false))) } else if ty_id == self.primitive_ids.none { unreachable!("none cannot be typeid") - } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { + } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() { let def = defs[def_id.0].read(); if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def { // do not handle type var param and concrete check here, and no subst @@ -376,7 +376,7 @@ impl InnerResolver { } Err(err) => return Ok(Err(err)), } - }) + }); } } else { break; @@ -388,7 +388,7 @@ impl InnerResolver { .push((result.clone(), constraints.extract()?, pyty.getattr("__name__")?.extract::()? - )) + )); } (result, is_const_generic) @@ -514,11 +514,10 @@ impl InnerResolver { Ok(ty) => ty, Err(err) => return Ok(Err(err)), }; - if !unifier.is_concrete(ty.0, &[]) && !ty.1 { - panic!( - "virtual class should take concrete parameters in type var ranges" - ) - } + assert!( + unifier.is_concrete(ty.0, &[]) || ty.1, + "virtual class should take concrete parameters in type var ranges" + ); Ok(Ok((unifier.add_ty(TypeEnum::TVirtual { ty: ty.0 }), true))) } else { return Ok(Err(format!( @@ -542,8 +541,7 @@ impl InnerResolver { pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap(); let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap(); Ok(Err(format!( - "{} is not registered with NAC3 (@nac3 decorator missing?)", - str_repr + "{str_repr} is not registered with NAC3 (@nac3 decorator missing?)" ))) } } @@ -624,7 +622,7 @@ impl InnerResolver { self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; match actual_ty { Ok(t) => match unifier.unify(*ty, t) { - Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TList { ty: *ty }))), + Ok(()) => Ok(Ok(unifier.add_ty(TypeEnum::TList { ty: *ty }))), Err(e) => Ok(Err(format!( "type error ({}) for the list", e.to_display(unifier) @@ -648,10 +646,8 @@ impl InnerResolver { (TypeEnum::TObj { obj_id, params, .. }, false) if *obj_id == primitives.option.get_obj_id(unifier) => { - let field_data = match obj.getattr("_nac3_option") { - Ok(d) => d, - // we use `none = Option(None)`, so the obj always have attr `_nac3_option` - Err(_) => unreachable!("cannot be None") + let Ok(field_data) = obj.getattr("_nac3_option") else { + unreachable!("cannot be None") }; // if is `none` let zelf_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; @@ -671,17 +667,15 @@ impl InnerResolver { }) .collect::>(); return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())) - } else { - unreachable!("must be tobj") } + unreachable!("must be tobj") } let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { Ok(t) => t, Err(e) => { return Ok(Err(format!( - "error when getting type of the option object ({})", - e + "error when getting type of the option object ({e})" ))) } }; @@ -706,38 +700,36 @@ impl InnerResolver { .collect::>(); let mut instantiate_obj = || { // loop through non-function fields of the class to get the instantiated value - for field in fields.iter() { + for field in fields { let name: String = (*field.0).into(); if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1.0) { continue; - } else { - let field_data = match obj.getattr(name.as_str()) { - Ok(d) => d, - Err(e) => return Ok(Err(format!("{}", e))), - }; - let ty = match self - .get_obj_type(py, field_data, unifier, defs, primitives)? - { - Ok(t) => t, - Err(e) => { - return Ok(Err(format!( - "error when getting type of field `{}` ({})", - name, e - ))) - } - }; - let field_ty = - unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0); - if let Err(e) = unifier.unify(ty, field_ty) { - // field type mismatch + } + let field_data = match obj.getattr(name.as_str()) { + Ok(d) => d, + Err(e) => return Ok(Err(format!("{e}"))), + }; + let ty = match self + .get_obj_type(py, field_data, unifier, defs, primitives)? + { + Ok(t) => t, + Err(e) => { return Ok(Err(format!( - "error when getting type of field `{name}` ({})", - e.to_display(unifier) - ))); + "error when getting type of field `{name}` ({e})" + ))) } + }; + let field_ty = + unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0); + if let Err(e) = unifier.unify(ty, field_ty) { + // field type mismatch + return Ok(Err(format!( + "error when getting type of field `{name}` ({})", + e.to_display(unifier) + ))); } } - for (_, ty) in var_map.iter() { + for ty in var_map.values() { // must be concrete type if !unifier.is_concrete(*ty, &[]) { return Ok(Err("object is not of concrete type".into())); @@ -758,32 +750,32 @@ impl InnerResolver { // check integer bounds if unifier.unioned(extracted_ty, primitives.int32) { obj.extract::().map_or_else( - |_| Ok(Err(format!("{} is not in the range of int32", obj))), + |_| Ok(Err(format!("{obj} is not in the range of int32"))), |_| Ok(Ok(extracted_ty)) ) } else if unifier.unioned(extracted_ty, primitives.int64) { obj.extract::().map_or_else( - |_| Ok(Err(format!("{} is not in the range of int64", obj))), + |_| Ok(Err(format!("{obj} is not in the range of int64"))), |_| Ok(Ok(extracted_ty)) ) } else if unifier.unioned(extracted_ty, primitives.uint32) { obj.extract::().map_or_else( - |_| Ok(Err(format!("{} is not in the range of uint32", obj))), + |_| Ok(Err(format!("{obj} is not in the range of uint32"))), |_| Ok(Ok(extracted_ty)) ) } else if unifier.unioned(extracted_ty, primitives.uint64) { obj.extract::().map_or_else( - |_| Ok(Err(format!("{} is not in the range of uint64", obj))), + |_| Ok(Err(format!("{obj} is not in the range of uint64"))), |_| Ok(Ok(extracted_ty)) ) } else if unifier.unioned(extracted_ty, primitives.bool) { obj.extract::().map_or_else( - |_| Ok(Err(format!("{} is not in the range of bool", obj))), + |_| Ok(Err(format!("{obj} is not in the range of bool"))), |_| Ok(Ok(extracted_ty)) ) } else if unifier.unioned(extracted_ty, primitives.float) { obj.extract::().map_or_else( - |_| Ok(Err(format!("{} is not in the range of float64", obj))), + |_| Ok(Err(format!("{obj} is not in the range of float64"))), |_| Ok(Ok(extracted_ty)) ) } else { @@ -855,9 +847,8 @@ impl InnerResolver { ctx.module.add_global(arr_ty, Some(AddressSpace::default()), &id_str) }); return Ok(Some(global.as_pointer_value().into())); - } else { - self.global_value_ids.write().insert(id, obj.into()); } + self.global_value_ids.write().insert(id, obj.into()); } let arr: Result>, _> = (0..len) @@ -867,7 +858,7 @@ impl InnerResolver { .and_then(|elem| self.get_obj_value(py, elem, ctx, generator, elem_ty) .map_err( |e| super::CompileError::new_err( - format!("Error getting element {}: {}", i, e)) + format!("Error getting element {i}: {e}")) )) }) .collect(); @@ -921,7 +912,7 @@ impl InnerResolver { .map(|((i, elem), ty)| self .get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| super::CompileError::new_err( - format!("Error getting element {}: {}", i, e) + format!("Error getting element {i}: {e}") ) ) ).collect(); @@ -953,21 +944,19 @@ impl InnerResolver { .get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator, option_val_ty) .map_err(|e| { super::CompileError::new_err(format!( - "Error getting value of Option object: {}", - e + "Error getting value of Option object: {e}" )) })? { Some(v) => { - let global_str = format!("{}_option", id); + let global_str = format!("{id}_option"); { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&global_str).unwrap_or_else(|| { ctx.module.add_global(v.get_type(), Some(AddressSpace::default()), &global_str) }); return Ok(Some(global.as_pointer_value().into())); - } else { - self.global_value_ids.write().insert(id, obj.into()); } + self.global_value_ids.write().insert(id, obj.into()); } let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::default()), &global_str); global.set_initializer(&v); @@ -998,9 +987,8 @@ impl InnerResolver { ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) }); return Ok(Some(global.as_pointer_value().into())); - } else { - self.global_value_ids.write().insert(id, obj.into()); } + self.global_value_ids.write().insert(id, obj.into()); } // should be classes let definition = @@ -1010,7 +998,7 @@ impl InnerResolver { .iter() .map(|(name, ty, _)| { self.get_obj_value(py, obj.getattr(name.to_string().as_str())?, ctx, generator, *ty) - .map_err(|e| super::CompileError::new_err(format!("Error getting field {}: {}", name, e))) + .map_err(|e| super::CompileError::new_err(format!("Error getting field {name}: {e}"))) }) .collect(); let values = values?; @@ -1083,11 +1071,11 @@ impl SymbolResolver for Resolver { let obj: &PyAny = self.0.module.extract(py)?; let members: &PyDict = obj.getattr("__dict__").unwrap().downcast().unwrap(); let mut sym_value = None; - for (key, val) in members.iter() { + for (key, val) in members { let key: &str = key.extract()?; if key == id.to_string() { if let Ok(Ok(v)) = self.0.get_default_param_obj_value(py, val) { - sym_value = Some(v) + sym_value = Some(v); } break; } @@ -1107,43 +1095,41 @@ impl SymbolResolver for Resolver { primitives: &PrimitiveStore, str: StrRef, ) -> Result { - match { + if let Some(ty) = { let id_to_type = self.0.id_to_type.read(); - id_to_type.get(&str).cloned() + id_to_type.get(&str).copied() } { - Some(ty) => Ok(ty), - None => { - let id = match self.0.name_to_pyid.get(&str) { - Some(id) => id, - None => return Err(format!("cannot find symbol `{}`", str)), - }; - let result = match { - let pyid_to_type = self.0.pyid_to_type.read(); - pyid_to_type.get(id).copied() - } { - Some(t) => Ok(t), - None => Python::with_gil(|py| -> PyResult> { - let obj: &PyAny = self.0.module.extract(py)?; - let mut sym_ty = Err(format!("cannot find symbol `{}`", str)); - let members: &PyDict = obj.getattr("__dict__").unwrap().downcast().unwrap(); - for (key, val) in members.iter() { - let key: &str = key.extract()?; - if key == str.to_string() { - sym_ty = self.0.get_obj_type(py, val, unifier, defs, primitives)?; - break; - } + Ok(ty) + } else { + let Some(id) = self.0.name_to_pyid.get(&str) else { + return Err(format!("cannot find symbol `{str}`")) + }; + let result = if let Some(t) = { + let pyid_to_type = self.0.pyid_to_type.read(); + pyid_to_type.get(id).copied() + } { + Ok(t) + } else { + Python::with_gil(|py| -> PyResult> { + let obj: &PyAny = self.0.module.extract(py)?; + let mut sym_ty = Err(format!("cannot find symbol `{str}`")); + let members: &PyDict = obj.getattr("__dict__").unwrap().downcast().unwrap(); + for (key, val) in members { + let key: &str = key.extract()?; + if key == str.to_string() { + sym_ty = self.0.get_obj_type(py, val, unifier, defs, primitives)?; + break; } - if let Ok(t) = sym_ty { - if let TypeEnum::TVar { .. } = &*unifier.get_ty(t) { - self.0.pyid_to_type.write().insert(*id, t); - } + } + if let Ok(t) = sym_ty { + if let TypeEnum::TVar { .. } = &*unifier.get_ty(t) { + self.0.pyid_to_type.write().insert(*id, t); } - Ok(sym_ty) - }) - .unwrap(), - }; - result - } + } + Ok(sym_ty) + }).unwrap() + }; + result } } @@ -1161,7 +1147,7 @@ impl SymbolResolver for Resolver { let obj: &PyAny = self.0.module.extract(py)?; let mut sym_value: Option<(u64, PyObject)> = None; let members: &PyDict = obj.getattr("__dict__").unwrap().downcast().unwrap(); - for (key, val) in members.iter() { + for (key, val) in members { let key: &str = key.extract()?; if key == id.to_string() { let id = self.0.helper.id_fn.call1(py, (val,))?.extract(py)?; @@ -1189,14 +1175,13 @@ impl SymbolResolver for Resolver { fn get_identifier_def(&self, id: StrRef) -> Result { { let id_to_def = self.0.id_to_def.read(); - id_to_def.get(&id).cloned().ok_or_else(|| "".to_string()) + id_to_def.get(&id).copied().ok_or_else(String::new) } .or_else(|_| { let py_id = - self.0.name_to_pyid.get(&id).ok_or(format!("Undefined identifier `{}`", id))?; + self.0.name_to_pyid.get(&id).ok_or(format!("Undefined identifier `{id}`"))?; let result = self.0.pyid_to_def.read().get(py_id).copied().ok_or(format!( - "`{}` is not registered with NAC3 (@nac3 decorator missing?)", - id + "`{id}` is not registered with NAC3 (@nac3 decorator missing?)" ))?; self.0.id_to_def.write().insert(id, result); Ok(result) @@ -1241,7 +1226,7 @@ impl SymbolResolver for Resolver { name, ))); } - unifier.unify(ty, *var).unwrap() + unifier.unify(ty, *var).unwrap(); } Err(err) => return Ok(Err(err)), } @@ -1251,13 +1236,13 @@ impl SymbolResolver for Resolver { } } Ok(Ok(())) - }).unwrap()? + }).unwrap()?; } Ok(()) } fn get_exception_id(&self, tyid: usize) -> usize { let exn_ids = self.0.exception_ids.read(); - exn_ids.get(&tyid).cloned().unwrap_or(0) + exn_ids.get(&tyid).copied().unwrap_or(0) } } diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 6daf50e..de28ec4 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -227,7 +227,7 @@ fn test_primitives() { threads, top_level, &llvm_options, - f + &f ); registry.add_task(task); registry.wait_tasks_complete(handles); @@ -417,7 +417,7 @@ fn test_simple_call() { threads, top_level, &llvm_options, - f + &f ); registry.add_task(task); registry.wait_tasks_complete(handles);