diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index bcf023cc..74e55207 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -6,21 +6,20 @@ use nac3core::{ CodeGenContext, CodeGenerator, }, symbol_resolver::ValueEnum, - toplevel::{DefinitionId, GenCall, helper::PRIMITIVE_DEF_IDS}, - typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap} + toplevel::{helper::PRIMITIVE_DEF_IDS, DefinitionId, GenCall}, + typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap}, }; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use inkwell::{ - context::Context, - module::Linkage, - types::IntType, - values::BasicValueEnum, - AddressSpace, + context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace, }; -use pyo3::{PyObject, PyResult, Python, types::{PyDict, PyList}}; +use pyo3::{ + types::{PyDict, PyList}, + PyObject, PyResult, Python, +}; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; @@ -46,7 +45,7 @@ enum ParallelMode { /// /// Each function call within the `with` block (except those within a nested `sequential` block) /// are treated to be executed in parallel. - Deep + Deep, } pub struct ArtiqCodeGenerator<'a> { @@ -96,14 +95,13 @@ impl<'a> ArtiqCodeGenerator<'a> { /// /// Direct-`parallel` block context refers to when the generator is generating statements whose /// closest parent `with` statement is a `with parallel` block. - fn timeline_reset_start( - &mut self, - ctx: &mut CodeGenContext<'_, '_> - ) -> Result<(), String> { + fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> Result<(), String> { if let Some(start) = self.start.clone() { - let start_val = self.gen_expr(ctx, &start)? - .unwrap() - .to_basic_value_enum(ctx, self, start.custom.unwrap())?; + let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum( + ctx, + self, + start.custom.unwrap(), + )?; self.timeline.emit_at_mu(ctx, start_val); } @@ -129,20 +127,20 @@ impl<'a> ArtiqCodeGenerator<'a> { store_name: Option<&str>, ) -> Result<(), String> { if let Some(end) = end { - let old_end = self.gen_expr(ctx, &end)? - .unwrap() - .to_basic_value_enum(ctx, self, end.custom.unwrap())?; - let now = self.timeline.emit_now_mu(ctx); - let max = call_int_smax( - ctx, - old_end.into_int_value(), - now.into_int_value(), - Some("smax") - ); - let end_store = self.gen_store_target( + let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum( ctx, - &end, - store_name.map(|name| format!("{name}.addr")).as_deref())? + self, + end.custom.unwrap(), + )?; + let now = self.timeline.emit_now_mu(ctx); + let max = + call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax")); + let end_store = self + .gen_store_target( + ctx, + &end, + store_name.map(|name| format!("{name}.addr")).as_deref(), + )? .unwrap(); ctx.builder.build_store(end_store, max).unwrap(); } @@ -164,11 +162,14 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { } } - fn gen_block<'ctx, 'a, 'c, I: Iterator>>>( + fn gen_block<'ctx, 'a, 'c, I: Iterator>>>( &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, - stmts: I - ) -> Result<(), String> where Self: Sized { + stmts: I, + ) -> Result<(), String> + where + Self: Sized, + { // Legacy parallel emits timeline end-update/timeline-reset after each top-level statement // in the parallel block if self.parallel_mode == ParallelMode::Legacy { @@ -212,9 +213,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { ctx: &mut CodeGenContext<'_, '_>, stmt: &Stmt>, ) -> Result<(), String> { - let StmtKind::With { items, body, .. } = &stmt.node else { - unreachable!() - }; + let StmtKind::With { items, body, .. } = &stmt.node else { unreachable!() }; if items.len() == 1 && items[0].optional_vars.is_none() { let item = &items[0]; @@ -239,9 +238,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { let old_parallel_mode = self.parallel_mode; let now = if let Some(old_start) = &old_start { - self.gen_expr(ctx, old_start)? - .unwrap() - .to_basic_value_enum(ctx, self, old_start.custom.unwrap())? + self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum( + ctx, + self, + old_start.custom.unwrap(), + )? } else { self.timeline.emit_now_mu(ctx) }; @@ -277,9 +278,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { node: ExprKind::Name { id: end, ctx: name_ctx.clone() }, custom: Some(ctx.primitives.int64), }; - let end = self - .gen_store_target(ctx, &end_expr, Some("end.addr"))? - .unwrap(); + let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap(); ctx.builder.build_store(end, now).unwrap(); self.end = Some(end_expr); self.name_counter += 1; @@ -309,10 +308,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { // set duration let end_expr = self.end.take().unwrap(); - let end_val = self - .gen_expr(ctx, &end_expr)? - .unwrap() - .to_basic_value_enum(ctx, self, end_expr.custom.unwrap())?; + let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum( + ctx, + self, + end_expr.custom.unwrap(), + )?; // inside a sequential block if old_start.is_none() { @@ -416,7 +416,7 @@ fn rpc_codegen_callback_fn<'ctx>( let int32 = ctx.ctx.i32_type(); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); - let service_id = int32.const_int(fun.1.0 as u64, false); + let service_id = int32.const_int(fun.1 .0 as u64, false); // -- setup rpc tags let mut tag = Vec::new(); if obj.is_some() { @@ -461,7 +461,8 @@ fn rpc_codegen_callback_fn<'ctx>( let arg_length = args.len() + usize::from(obj.is_some()); let stackptr = call_stacksave(ctx, Some("rpc.stack")); - let args_ptr = ctx.builder + let args_ptr = ctx + .builder .build_array_alloca( ptr_type, ctx.ctx.i32_type().const_int(arg_length as u64, false), @@ -477,10 +478,8 @@ fn rpc_codegen_callback_fn<'ctx>( } // default value handling for k in keys { - mapping.insert( - k.name, - ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into() - ); + mapping + .insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()); } // reorder the parameters let mut real_params = fun @@ -499,7 +498,8 @@ fn rpc_codegen_callback_fn<'ctx>( } for (i, arg) in real_params.iter().enumerate() { - let arg_slot = generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap(); + let arg_slot = + generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap(); ctx.builder.build_store(arg_slot, *arg).unwrap(); let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap(); let arg_ptr = unsafe { @@ -508,7 +508,8 @@ fn rpc_codegen_callback_fn<'ctx>( &[int32.const_int(i as u64, false)], &format!("rpc.arg{i}"), ) - }.unwrap(); + } + .unwrap(); ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); } @@ -528,11 +529,7 @@ fn rpc_codegen_callback_fn<'ctx>( ) }); ctx.builder - .build_call( - rpc_send, - &[service_id.into(), tag_ptr.into(), args_ptr.into()], - "rpc.send", - ) + .build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") .unwrap(); // reclaim stack space used by arguments @@ -575,13 +572,9 @@ fn rpc_codegen_callback_fn<'ctx>( .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") .unwrap() .into_int_value(); - let is_done = ctx.builder - .build_int_compare( - inkwell::IntPredicate::EQ, - int32.const_zero(), - alloc_size, - "rpc.done", - ) + let is_done = ctx + .builder + .build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done") .unwrap(); ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); @@ -617,9 +610,15 @@ pub fn attributes_writeback( let mut scratch_buffer = Vec::new(); for val in (*globals).values() { let val = val.as_ref(py); - let ty = inner_resolver.get_obj_type(py, val, &mut ctx.unifier, &top_levels, &ctx.primitives)?; + let ty = inner_resolver.get_obj_type( + py, + val, + &mut ctx.unifier, + &top_levels, + &ctx.primitives, + )?; if let Err(ty) = ty { - return Ok(Err(ty)) + return Ok(Err(ty)); } let ty = ty.unwrap(); match &*ctx.unifier.get_ty(ty) { @@ -632,14 +631,19 @@ pub fn attributes_writeback( let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); for (name, (field_ty, is_mutable)) in fields { if !is_mutable { - continue + continue; } if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { attributes.push(name.to_string()); let index = ctx.get_attr_index(ty, *name); - values.push((*field_ty, ctx.build_gep_and_load( - obj.into_pointer_value(), - &[zero, int32.const_int(index as u64, false)], None))); + values.push(( + *field_ty, + ctx.build_gep_and_load( + obj.into_pointer_value(), + &[zero, int32.const_int(index as u64, false)], + None, + ), + )); } } if !attributes.is_empty() { @@ -648,33 +652,44 @@ pub fn attributes_writeback( pydict.set_item("fields", attributes)?; host_attributes.append(pydict)?; } - }, + } TypeEnum::TList { ty: elem_ty } => { if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() { let pydict = PyDict::new(py); pydict.set_item("obj", val)?; host_attributes.append(pydict)?; - values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap())); + values.push(( + ty, + inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(), + )); } - }, + } _ => {} } } let fun = FunSignature { - args: values.iter().enumerate().map(|(i, (ty, _))| FuncArg { - name: i.to_string().into(), - ty: *ty, - default_value: None - }).collect(), + args: values + .iter() + .enumerate() + .map(|(i, (ty, _))| FuncArg { + name: i.to_string().into(), + ty: *ty, + default_value: None, + }) + .collect(), ret: ctx.primitives.none, - vars: VarMap::default() + vars: VarMap::default(), }; - let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); - if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, PRIMITIVE_DEF_IDS.int32), args, generator) { + let args: Vec<_> = + values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); + if let Err(e) = + rpc_codegen_callback_fn(ctx, None, (&fun, PRIMITIVE_DEF_IDS.int32), args, generator) + { return Ok(Err(e)); } Ok(Ok(())) - }).unwrap()?; + }) + .unwrap()?; Ok(()) } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index f8f029ff..3263b4f6 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -14,16 +14,16 @@ use inkwell::{ OptimizationLevel, }; use itertools::Itertools; -use nac3core::codegen::{CodeGenLLVMOptions, CodeGenTargetMachineOptions, gen_func_impl}; +use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions}; use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::typecheck::typedef::{TypeEnum, Unifier, VarMap}; use nac3parser::{ ast::{ExprKind, Stmt, StmtKind, StrRef}, parser::parse_program, }; +use pyo3::create_exception; use pyo3::prelude::*; use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet}; -use pyo3::create_exception; use parking_lot::{Mutex, RwLock}; @@ -46,7 +46,7 @@ use tempfile::{self, TempDir}; use crate::codegen::attributes_writeback; use crate::{ codegen::{rpc_codegen_callback, ArtiqCodeGenerator}, - symbol_resolver::{InnerResolver, PythonHelper, Resolver, DeferredEvaluationStore}, + symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver}, }; mod codegen; @@ -138,9 +138,7 @@ impl Nac3 { for mut stmt in parser_result { let include = match stmt.node { - StmtKind::ClassDef { - ref decorator_list, ref mut body, ref mut bases, .. - } => { + StmtKind::ClassDef { ref decorator_list, ref mut body, ref mut bases, .. } => { let nac3_class = decorator_list.iter().any(|decorator| { if let ExprKind::Name { id, .. } = decorator.node { id.to_string() == "nac3" @@ -160,7 +158,8 @@ impl Nac3 { if *id == "Exception".into() { Ok(true) } else { - let base_obj = module.getattr(py, id.to_string().as_str())?; + let base_obj = + module.getattr(py, id.to_string().as_str())?; let base_id = id_fn.call1((base_obj,))?.extract()?; Ok(registered_class_ids.contains(&base_id)) } @@ -341,8 +340,9 @@ impl Nac3 { let class_obj; if let StmtKind::ClassDef { name, .. } = &stmt.node { let class = py_module.getattr(name.to_string().as_str()).unwrap(); - if issubclass.call1((class, exn_class)).unwrap().extract().unwrap() && - class.getattr("artiq_builtin").is_err() { + if issubclass.call1((class, exn_class)).unwrap().extract().unwrap() + && class.getattr("artiq_builtin").is_err() + { class_obj = Some(class); } else { class_obj = None; @@ -388,12 +388,12 @@ impl Nac3 { let (name, def_id, ty) = composer .register_top_level(stmt.clone(), Some(resolver.clone()), path, false) .map_err(|e| { - CompileError::new_err(format!( - "compilation failed\n----------\n{e}" - )) + CompileError::new_err(format!("compilation failed\n----------\n{e}")) })?; if let Some(class_obj) = class_obj { - self.exception_ids.write().insert(def_id.0, store_obj.call1(py, (class_obj, ))?.extract(py)?); + self.exception_ids + .write() + .insert(def_id.0, store_obj.call1(py, (class_obj,))?.extract(py)?); } match &stmt.node { @@ -470,7 +470,8 @@ impl Nac3 { exception_ids: self.exception_ids.clone(), deferred_eval_store: self.deferred_eval_store.clone(), }); - let resolver = Arc::new(Resolver(inner_resolver.clone())) as Arc; + let resolver = + Arc::new(Resolver(inner_resolver.clone())) as Arc; let (_, def_id, _) = composer .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false) .unwrap(); @@ -479,8 +480,12 @@ impl Nac3 { FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() }; let mut store = ConcreteTypeStore::new(); let mut cache = HashMap::new(); - let signature = - store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache); + let signature = store.from_signature( + &mut composer.unifier, + &self.primitive, + &fun_signature, + &mut cache, + ); let signature = store.add_cty(signature); if let Err(e) = composer.start_analysis(true) { @@ -499,13 +504,11 @@ impl Nac3 { msg.unwrap_or(e.iter().sorted().join("\n----------\n")) ))) } else { - Err(CompileError::new_err( - format!( - "compilation failed\n----------\n{}", - e.iter().sorted().join("\n----------\n"), - ), - )) - } + Err(CompileError::new_err(format!( + "compilation failed\n----------\n{}", + e.iter().sorted().join("\n----------\n"), + ))) + }; } let top_level = Arc::new(composer.make_top_level_context()); @@ -533,7 +536,9 @@ impl Nac3 { py, ( id.0.into_py(py), - class_def.getattr(py, name.to_string().as_str()).unwrap(), + class_def + .getattr(py, name.to_string().as_str()) + .unwrap(), ), ) .unwrap(); @@ -548,7 +553,8 @@ impl Nac3 { let defs = top_level.definitions.read(); let mut definition = defs[def_id.0].write(); let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = - &mut *definition else { + &mut *definition + else { unreachable!() }; @@ -570,8 +576,12 @@ impl Nac3 { let mut store = ConcreteTypeStore::new(); let mut cache = HashMap::new(); - let signature = - store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache); + let signature = store.from_signature( + &mut composer.unifier, + &self.primitive, + &fun_signature, + &mut cache, + ); let signature = store.add_cty(signature); let attributes_writeback_task = CodeGenTask { subst: Vec::default(), @@ -604,23 +614,28 @@ impl Nac3 { let membuffer = membuffers.clone(); py.allow_threads(|| { - let (registry, handles) = WorkerRegistry::create_workers( - threads, - top_level.clone(), - &self.llvm_options, - &f - ); + let (registry, handles) = + WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f); registry.add_task(task); registry.wait_tasks_complete(handles); - let mut generator = ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); + let mut generator = + ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); let context = inkwell::context::Context::create(); let module = context.create_module("attributes_writeback"); let builder = context.create_builder(); - let (_, module, _) = gen_func_impl(&context, &mut generator, ®istry, builder, module, - attributes_writeback_task, |generator, ctx| { + let (_, module, _) = gen_func_impl( + &context, + &mut generator, + ®istry, + builder, + module, + attributes_writeback_task, + |generator, ctx| { attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes) - }).unwrap(); + }, + ) + .unwrap(); let buffer = module.write_bitcode_to_memory(); let buffer = buffer.as_slice().into(); membuffer.lock().push(buffer); @@ -636,11 +651,16 @@ impl Nac3 { .create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main")) .unwrap(); - main.link_in_module(other) - .map_err(|err| CompileError::new_err(err.to_string()))?; + main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?; } let builder = context.create_builder(); - let modinit_return = main.get_function("__modinit__").unwrap().get_last_basic_block().unwrap().get_terminator().unwrap(); + let modinit_return = main + .get_function("__modinit__") + .unwrap() + .get_last_basic_block() + .unwrap() + .get_terminator() + .unwrap(); builder.position_before(&modinit_return); builder .build_call( @@ -662,10 +682,7 @@ impl Nac3 { } // Demote all global variables that will not be referenced in the kernel to private - let preserved_symbols: Vec<&'static [u8]> = vec![ - b"typeinfo", - b"now", - ]; + let preserved_symbols: Vec<&'static [u8]> = vec![b"typeinfo", b"now"]; let mut global_option = main.get_first_global(); while let Some(global) = global_option { if !preserved_symbols.contains(&(global.get_name().to_bytes())) { @@ -674,7 +691,9 @@ impl Nac3 { global_option = global.get_next_global(); } - let target_machine = self.llvm_options.target + let target_machine = self + .llvm_options + .target .create_target_machine(self.llvm_options.opt_level) .expect("couldn't create target machine"); @@ -738,10 +757,7 @@ impl Nac3 { } } -fn link_with_lld( - elf_filename: String, - obj_filename: String, -) -> PyResult<()>{ +fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> { let linker_args = vec![ "-shared".to_string(), "--eh-frame-hdr".to_string(), @@ -760,9 +776,7 @@ fn link_with_lld( return Err(CompileError::new_err("failed to start linker")); } } else { - return Err(CompileError::new_err( - "linker returned non-zero status code", - )); + return Err(CompileError::new_err("linker returned non-zero status code")); } Ok(()) @@ -772,7 +786,7 @@ fn add_exceptions( composer: &mut TopLevelComposer, builtin_def: &mut HashMap, builtin_ty: &mut HashMap, - error_names: &[&str] + error_names: &[&str], ) -> Vec { let mut types = Vec::new(); // note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}" @@ -785,7 +799,7 @@ fn add_exceptions( // constructor id def_id + 1, &mut composer.unifier, - &composer.primitives_ty + &composer.primitives_ty, ); composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None)); composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None)); @@ -834,7 +848,8 @@ impl Nac3 { }, Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap(); + let arg = + args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap(); time_fns.emit_at_mu(ctx, arg); Ok(None) }))), @@ -852,7 +867,8 @@ impl Nac3 { }, Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap(); + let arg = + args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap(); time_fns.emit_delay_mu(ctx, arg); Ok(None) }))), @@ -867,8 +883,9 @@ impl Nac3 { let types_mod = PyModule::import(py, "types").unwrap(); let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap(); - let get_attr_id = |obj: &PyModule, attr| id_fn.call1((obj.getattr(attr).unwrap(),)) - .unwrap().extract().unwrap(); + let get_attr_id = |obj: &PyModule, attr| { + id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap() + }; let primitive_ids = PrimitivePythonId { virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()), generic_alias: ( @@ -877,7 +894,9 @@ impl Nac3 { ), none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()), typevar: get_attr_id(typing_mod, "TypeVar"), - const_generic_marker: get_id(artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap()), + const_generic_marker: get_id( + artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(), + ), int: get_attr_id(builtins_mod, "int"), int32: get_attr_id(numpy_mod, "int32"), int64: get_attr_id(numpy_mod, "int64"), @@ -911,7 +930,7 @@ impl Nac3 { llvm_options: CodeGenLLVMOptions { opt_level: OptimizationLevel::Default, target: Nac3::get_llvm_target_options(isa), - } + }, }) } @@ -952,7 +971,7 @@ impl Nac3 { py: Python, ) -> PyResult<()> { let target_machine = self.get_llvm_target_machine(); - + if self.isa == Isa::Host { let link_fn = |module: &Module| { let working_directory = self.working_directory.path().to_owned(); @@ -961,7 +980,7 @@ impl Nac3 { .expect("couldn't write module to file"); link_with_lld( filename.to_string(), - working_directory.join("module.o").to_string_lossy().to_string() + working_directory.join("module.o").to_string_lossy().to_string(), )?; Ok(()) }; @@ -997,7 +1016,7 @@ impl Nac3 { py: Python, ) -> PyResult { let target_machine = self.get_llvm_target_machine(); - + if self.isa == Isa::Host { let link_fn = |module: &Module| { let working_directory = self.working_directory.path().to_owned(); @@ -1009,7 +1028,7 @@ impl Nac3 { let filename = filename_path.to_str().unwrap(); link_with_lld( filename.to_string(), - working_directory.join("module.o").to_string_lossy().to_string() + working_directory.join("module.o").to_string_lossy().to_string(), )?; Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into()) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 32c0f9d8..9f4118e8 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -3,10 +3,9 @@ use nac3core::{ codegen::{CodeGenContext, CodeGenerator}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, toplevel::{ - DefinitionId, - helper::PRIMITIVE_DEF_IDS, + helper::PRIMITIVE_DEF_IDS, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, - TopLevelDef, + DefinitionId, TopLevelDef, }, typecheck::{ type_inferencer::PrimitiveStore, @@ -22,9 +21,9 @@ use pyo3::{ use std::{ collections::{HashMap, HashSet}, sync::{ + atomic::{AtomicBool, Ordering::Relaxed}, Arc, - atomic::{AtomicBool, Ordering::Relaxed} - } + }, }; use crate::PrimitivePythonId; @@ -58,7 +57,7 @@ impl DeferredEvaluationStore { } } -/// A class field as stored in the [`InnerResolver`], represented by the ID and name of the +/// A class field as stored in the [`InnerResolver`], represented by the ID and name of the /// associated [`PythonValue`]. type ResolverField = (u64, StrRef); /// A class field as stored in Python, represented by the `id()` and [`PyObject`] of the field. @@ -114,27 +113,27 @@ impl StaticValue for PythonValue { ctx: &mut CodeGenContext<'ctx, '_>, _: &mut dyn CodeGenerator, ) -> BasicValueEnum<'ctx> { - ctx.module - .get_global(format!("{}_const", self.id).as_str()) - .map_or_else( - || Python::with_gil(|py| -> PyResult> { - let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?; - let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false); - let global = ctx.module.add_global( - struct_type, - None, - format!("{}_const", self.id).as_str(), - ); - global.set_constant(true); - global.set_initializer(&ctx.ctx.const_struct( - &[ctx.ctx.i32_type().const_int(id as u64, false).into()], - false, - )); - Ok(global.as_pointer_value().into()) - }) - .unwrap(), - |val| val.as_pointer_value().into(), - ) + ctx.module.get_global(format!("{}_const", self.id).as_str()).map_or_else( + || { + Python::with_gil(|py| -> PyResult> { + let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?; + let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false); + let global = ctx.module.add_global( + struct_type, + None, + format!("{}_const", self.id).as_str(), + ); + global.set_constant(true); + global.set_initializer(&ctx.ctx.const_struct( + &[ctx.ctx.i32_type().const_int(id as u64, false).into()], + false, + )); + Ok(global.as_pointer_value().into()) + }) + .unwrap() + }, + |val| val.as_pointer_value().into(), + ) } fn to_basic_value_enum<'ctx, 'a>( @@ -161,7 +160,8 @@ impl StaticValue for PythonValue { self.resolver .get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty) .map(Option::unwrap) - }).map_err(|e| e.to_string()) + }) + .map_err(|e| e.to_string()) } fn get_field<'ctx>( @@ -186,7 +186,7 @@ impl StaticValue for PythonValue { Ok(None) } else { Ok(Some((id, obj))) - } + }; } let def_id = { *self.resolver.pyid_to_def.read().get(&ty_id).unwrap() }; let mut mutable = true; @@ -264,9 +264,7 @@ impl InnerResolver { .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))?? { Ok(t) => t, - Err(e) => { - return Ok(Err(format!("type error ({e}) at element #{i} of the list"))) - } + Err(e) => return Ok(Err(format!("type error ({e}) at element #{i} of the list"))), }; ty = match unifier.unify(ty, b) { Ok(()) => ty, @@ -377,7 +375,7 @@ impl InnerResolver { let constr_id: u64 = self.helper.id_fn.call1(py, (constr,))?.extract(py)?; if constr_id == self.primitive_ids.const_generic_marker { is_const_generic = true; - continue + continue; } if !is_const_generic && needs_defer { @@ -406,11 +404,11 @@ impl InnerResolver { } if !is_const_generic && needs_defer { - self.deferred_eval_store.store.write() - .push((result.clone(), - constraints.extract()?, - pyty.getattr("__name__")?.extract::()? - )); + self.deferred_eval_store.store.write().push(( + result.clone(), + constraints.extract()?, + pyty.getattr("__name__")?.extract::()?, + )); } (result, is_const_generic) @@ -418,7 +416,10 @@ impl InnerResolver { let res = if is_const_generic { if constraint_types.len() != 1 { - return Ok(Err(format!("ConstGeneric expects 1 argument, got {}", constraint_types.len()))) + return Ok(Err(format!( + "ConstGeneric expects 1 argument, got {}", + constraint_types.len() + ))); } unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).0 @@ -572,9 +573,7 @@ impl InnerResolver { let str_fn = pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap(); let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap(); - Ok(Err(format!( - "{str_repr} is not registered with NAC3 (@nac3 decorator missing?)" - ))) + Ok(Err(format!("{str_repr} is not registered with NAC3 (@nac3 decorator missing?)"))) } } @@ -589,31 +588,28 @@ impl InnerResolver { let ty = self.helper.type_fn.call1(py, (obj,)).unwrap(); let py_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; if let Some(ty) = self.pyid_to_type.read().get(&py_obj_id) { - return Ok(Ok(*ty)) + return Ok(Ok(*ty)); } // check if constructor function exists in the methods list let pyid_to_def = self.pyid_to_def.read(); - let constructor_ty = pyid_to_def - .get(&py_obj_id) - .and_then(|def_id| { - defs - .iter() - .find_map(|def| { - if let TopLevelDef::Class { - object_id, methods, constructor, .. - } = &*def.read() { - if object_id == def_id && constructor.is_some() && methods.iter().any(|(s, _, _)| s == &"__init__".into()) { - return *constructor; - } + let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| { + defs.iter().find_map(|def| { + if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() { + if object_id == def_id + && constructor.is_some() + && methods.iter().any(|(s, _, _)| s == &"__init__".into()) + { + return *constructor; } - None - }) - }); + } + None + }) + }); if let Some(ty) = constructor_ty { self.pyid_to_type.write().insert(py_obj_id, ty); - return Ok(Ok(ty)) + return Ok(Ok(ty)); } let (extracted_ty, inst_check) = match self.get_pyty_obj_type( @@ -680,12 +676,8 @@ impl InnerResolver { match actual_ty { Ok(t) => match unifier.unify(ty, t) { Ok(()) => { - let ndarray_ty = make_ndarray_ty( - unifier, - primitives, - Some(ty), - Some(ndims), - ); + let ndarray_ty = + make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims)); Ok(Ok(ndarray_ty)) } @@ -726,7 +718,8 @@ impl InnerResolver { let var_map = params .iter() .map(|(id_var, ty)| { - let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) else { + let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) + else { unreachable!() }; @@ -734,7 +727,7 @@ impl InnerResolver { (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) }) .collect::(); - return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())) + return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())); } let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { @@ -754,8 +747,8 @@ impl InnerResolver { let var_map = params .iter() .map(|(id_var, ty)| { - let TypeEnum::TVar { id, range, name, loc, .. } = - &*unifier.get_ty(*ty) else { + let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) + else { unreachable!() }; @@ -767,25 +760,23 @@ impl InnerResolver { // loop through non-function fields of the class to get the instantiated value for field in fields { let name: String = (*field.0).into(); - if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1.0) { + if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1 .0) { continue; } let field_data = match obj.getattr(name.as_str()) { Ok(d) => d, Err(e) => return Ok(Err(format!("{e}"))), }; - let ty = match self - .get_obj_type(py, field_data, unifier, defs, primitives)? - { - Ok(t) => t, - Err(e) => { - return Ok(Err(format!( - "error when getting type of field `{name}` ({e})" - ))) - } - }; - let field_ty = - unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0); + let ty = + match self.get_obj_type(py, field_data, unifier, defs, primitives)? { + Ok(t) => t, + Err(e) => { + return Ok(Err(format!( + "error when getting type of field `{name}` ({e})" + ))) + } + }; + let field_ty = unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0); if let Err(e) = unifier.unify(ty, field_ty) { // field type mismatch return Ok(Err(format!( @@ -800,14 +791,15 @@ impl InnerResolver { return Ok(Err("object is not of concrete type".into())); } } - let extracted_ty = unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty); + let extracted_ty = + unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty); Ok(Ok(extracted_ty)) }; let result = instantiate_obj(); // update/remove the cache according to the result match result { Ok(Ok(ty)) => self.pyid_to_type.write().insert(py_obj_id, ty), - _ => self.pyid_to_type.write().remove(&py_obj_id) + _ => self.pyid_to_type.write().remove(&py_obj_id), }; result } @@ -816,32 +808,32 @@ impl InnerResolver { if unifier.unioned(extracted_ty, primitives.int32) { obj.extract::().map_or_else( |_| Ok(Err(format!("{obj} is not in the range of int32"))), - |_| Ok(Ok(extracted_ty)) + |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.int64) { obj.extract::().map_or_else( |_| Ok(Err(format!("{obj} is not in the range of int64"))), - |_| Ok(Ok(extracted_ty)) + |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.uint32) { obj.extract::().map_or_else( |_| Ok(Err(format!("{obj} is not in the range of uint32"))), - |_| Ok(Ok(extracted_ty)) + |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.uint64) { obj.extract::().map_or_else( |_| Ok(Err(format!("{obj} is not in the range of uint64"))), - |_| Ok(Ok(extracted_ty)) + |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.bool) { obj.extract::().map_or_else( |_| Ok(Err(format!("{obj} is not in the range of bool"))), - |_| Ok(Ok(extracted_ty)) + |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.float) { obj.extract::().map_or_else( |_| Ok(Err(format!("{obj} is not in the range of float64"))), - |_| Ok(Ok(extracted_ty)) + |_| Ok(Ok(extracted_ty)), ) } else { Ok(Ok(extracted_ty)) @@ -893,8 +885,8 @@ impl InnerResolver { } let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; - let elem_ty = - if let TypeEnum::TList { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref() + let elem_ty = if let TypeEnum::TList { ty } = + ctx.unifier.get_ty_immutable(expected_ty).as_ref() { *ty } else { @@ -918,13 +910,11 @@ impl InnerResolver { let arr: Result>, _> = (0..len) .map(|i| { - obj - .get_item(i) - .and_then(|elem| self.get_obj_value(py, elem, ctx, generator, elem_ty) - .map_err( - |e| super::CompileError::new_err( - format!("Error getting element {i}: {e}")) - )) + obj.get_item(i).and_then(|elem| { + self.get_obj_value(py, elem, ctx, generator, elem_ty).map_err(|e| { + super::CompileError::new_err(format!("Error getting element {i}: {e}")) + }) + }) }) .collect(); let arr = arr?.unwrap(); @@ -956,7 +946,10 @@ impl InnerResolver { arr_global.set_initializer(&arr); let val = arr_ty.const_named_struct(&[ - arr_global.as_pointer_value().const_cast(ty.ptr_type(AddressSpace::default())).into(), + arr_global + .as_pointer_value() + .const_cast(ty.ptr_type(AddressSpace::default())) + .into(), size_t.const_int(len as u64, false).into(), ]); @@ -968,25 +961,21 @@ impl InnerResolver { todo!() } else if ty_id == self.primitive_ids.tuple { let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); - let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { - unreachable!() - }; + let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() }; let tup_tys = ty.iter(); let elements: &PyTuple = obj.downcast()?; assert_eq!(elements.len(), tup_tys.len()); - let val: Result>, _> = - elements - .iter() - .enumerate() - .zip(tup_tys) - .map(|((i, elem), ty)| self - .get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| - super::CompileError::new_err( - format!("Error getting element {i}: {e}") - ) - ) - ).collect(); + let val: Result>, _> = elements + .iter() + .enumerate() + .zip(tup_tys) + .map(|((i, elem), ty)| { + self.get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| { + super::CompileError::new_err(format!("Error getting element {i}: {e}")) + }) + }) + .collect(); let val = val?.unwrap(); let val = ctx.ctx.const_struct(&val, false); Ok(Some(val.into())) @@ -997,7 +986,7 @@ impl InnerResolver { { *params.iter().next().unwrap().1 } - _ => unreachable!("must be option type") + _ => unreachable!("must be option type"), }; if id == self.primitive_ids.none { // for option type, just a null ptr @@ -1009,7 +998,13 @@ impl InnerResolver { )) } else { match self - .get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator, option_val_ty) + .get_obj_value( + py, + obj.getattr("_nac3_option").unwrap(), + ctx, + generator, + option_val_ty, + ) .map_err(|e| { super::CompileError::new_err(format!( "Error getting value of Option object: {e}" @@ -1019,17 +1014,26 @@ impl InnerResolver { let global_str = format!("{id}_option"); { if self.global_value_ids.read().contains_key(&id) { - let global = ctx.module.get_global(&global_str).unwrap_or_else(|| { - ctx.module.add_global(v.get_type(), Some(AddressSpace::default()), &global_str) - }); + let global = + ctx.module.get_global(&global_str).unwrap_or_else(|| { + ctx.module.add_global( + v.get_type(), + Some(AddressSpace::default()), + &global_str, + ) + }); return Ok(Some(global.as_pointer_value().into())); } self.global_value_ids.write().insert(id, obj.into()); } - let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::default()), &global_str); + let global = ctx.module.add_global( + v.get_type(), + Some(AddressSpace::default()), + &global_str, + ); global.set_initializer(&v); Ok(Some(global.as_pointer_value().into())) - }, + } None => Ok(None), } } @@ -1066,8 +1070,16 @@ impl InnerResolver { let values: Result>, _> = fields .iter() .map(|(name, ty, _)| { - self.get_obj_value(py, obj.getattr(name.to_string().as_str())?, ctx, generator, *ty) - .map_err(|e| super::CompileError::new_err(format!("Error getting field {name}: {e}"))) + self.get_obj_value( + py, + obj.getattr(name.to_string().as_str())?, + ctx, + generator, + *ty, + ) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting field {name}: {e}")) + }) }) .collect(); let values = values?; @@ -1119,8 +1131,7 @@ impl InnerResolver { if id == self.primitive_ids.none { Ok(SymbolValue::OptionNone) } else { - self - .get_default_param_obj_value(py, obj.getattr("_nac3_option").unwrap())? + self.get_default_param_obj_value(py, obj.getattr("_nac3_option").unwrap())? .map(|v| SymbolValue::OptionSome(Box::new(v))) } } else { @@ -1149,7 +1160,8 @@ impl SymbolResolver for Resolver { } } Ok(sym_value) - }).unwrap() + }) + .unwrap() } fn get_symbol_type( @@ -1166,7 +1178,7 @@ impl SymbolResolver for Resolver { Ok(ty) } else { let Some(id) = self.0.name_to_pyid.get(&str) else { - return Err(format!("cannot find symbol `{str}`")) + return Err(format!("cannot find symbol `{str}`")); }; let result = if let Some(t) = { let pyid_to_type = self.0.pyid_to_type.read(); @@ -1191,7 +1203,8 @@ impl SymbolResolver for Resolver { } } Ok(sym_ty) - }).unwrap() + }) + .unwrap() }; result } @@ -1242,15 +1255,16 @@ impl SymbolResolver for Resolver { id_to_def.get(&id).copied().ok_or_else(String::new) } .or_else(|_| { - let py_id = self.0.name_to_pyid.get(&id) - .ok_or_else(|| HashSet::from([ - format!("Undefined identifier `{id}`"), - ]))?; - let result = self.0.pyid_to_def.read().get(py_id) - .copied() - .ok_or_else(|| HashSet::from([ - format!("`{id}` is not registered with NAC3 (@nac3 decorator missing?)"), - ]))?; + let py_id = self + .0 + .name_to_pyid + .get(&id) + .ok_or_else(|| HashSet::from([format!("Undefined identifier `{id}`")]))?; + let result = self.0.pyid_to_def.read().get(py_id).copied().ok_or_else(|| { + HashSet::from([format!( + "`{id}` is not registered with NAC3 (@nac3 decorator missing?)" + )]) + })?; self.0.id_to_def.write().insert(id, result); Ok(result) }) @@ -1274,7 +1288,7 @@ impl SymbolResolver for Resolver { &self, unifier: &mut Unifier, defs: &[Arc>], - primitives: &PrimitiveStore + primitives: &PrimitiveStore, ) -> Result<(), String> { // we don't need a lock because this will only be run in a single thread if self.0.deferred_eval_store.needs_defer.load(Relaxed) { @@ -1304,7 +1318,8 @@ impl SymbolResolver for Resolver { } } Ok(Ok(())) - }).unwrap()?; + }) + .unwrap()?; } Ok(()) } diff --git a/nac3artiq/src/timeline.rs b/nac3artiq/src/timeline.rs index bf89d897..97e220bb 100644 --- a/nac3artiq/src/timeline.rs +++ b/nac3artiq/src/timeline.rs @@ -1,10 +1,12 @@ -use inkwell::{values::{BasicValueEnum, CallSiteValue}, AddressSpace, AtomicOrdering}; +use inkwell::{ + values::{BasicValueEnum, CallSiteValue}, + AddressSpace, AtomicOrdering, +}; use itertools::Either; use nac3core::codegen::CodeGenContext; /// Functions for manipulating the timeline. pub trait TimeFns { - /// Emits LLVM IR for `now_mu`. fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>; @@ -27,26 +29,31 @@ impl TimeFns for NowPinningTimeFns64 { .module .get_global("now") .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_hiptr = ctx.builder + let now_hiptr = ctx + .builder .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .map(BasicValueEnum::into_pointer_value) .unwrap(); let now_loptr = unsafe { ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") - }.unwrap(); + } + .unwrap(); - let now_hi = ctx.builder.build_load(now_hiptr, "now.hi") + let now_hi = ctx + .builder + .build_load(now_hiptr, "now.hi") .map(BasicValueEnum::into_int_value) .unwrap(); - let now_lo = ctx.builder.build_load(now_loptr, "now.lo") + let now_lo = ctx + .builder + .build_load(now_loptr, "now.lo") .map(BasicValueEnum::into_int_value) .unwrap(); let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap(); - let shifted_hi = ctx.builder - .build_left_shift(zext_hi, i64_type.const_int(32, false), "") - .unwrap(); + let shifted_hi = + ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap(); let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap(); ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap() } @@ -58,7 +65,8 @@ impl TimeFns for NowPinningTimeFns64 { let i64_32 = i64_type.const_int(32, false); let time = t.into_int_value(); - let time_hi = ctx.builder + let time_hi = ctx + .builder .build_int_truncate( ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(), i32_type, @@ -70,14 +78,16 @@ impl TimeFns for NowPinningTimeFns64 { .module .get_global("now") .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_hiptr = ctx.builder + let now_hiptr = ctx + .builder .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .map(BasicValueEnum::into_pointer_value) .unwrap(); let now_loptr = unsafe { ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") - }.unwrap(); + } + .unwrap(); ctx.builder .build_store(now_hiptr, time_hi) .unwrap() @@ -90,50 +100,49 @@ impl TimeFns for NowPinningTimeFns64 { .unwrap(); } - fn emit_delay_mu<'ctx>( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - dt: BasicValueEnum<'ctx>, - ) { + fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) { let i64_type = ctx.ctx.i64_type(); let i32_type = ctx.ctx.i32_type(); let now = ctx .module .get_global("now") .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_hiptr = ctx.builder + let now_hiptr = ctx + .builder .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .map(BasicValueEnum::into_pointer_value) .unwrap(); let now_loptr = unsafe { ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") - }.unwrap(); + } + .unwrap(); - let now_hi = ctx.builder.build_load(now_hiptr, "now.hi") + let now_hi = ctx + .builder + .build_load(now_hiptr, "now.hi") .map(BasicValueEnum::into_int_value) .unwrap(); - let now_lo = ctx.builder.build_load(now_loptr, "now.lo") + let now_lo = ctx + .builder + .build_load(now_loptr, "now.lo") .map(BasicValueEnum::into_int_value) .unwrap(); let dt = dt.into_int_value(); let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap(); - let shifted_hi = ctx.builder - .build_left_shift(zext_hi, i64_type.const_int(32, false), "") - .unwrap(); + let shifted_hi = + ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap(); let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap(); let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap(); let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap(); - let time_hi = ctx.builder + let time_hi = ctx + .builder .build_int_truncate( - ctx.builder.build_right_shift( - time, - i64_type.const_int(32, false), - false, - "", - ).unwrap(), + ctx.builder + .build_right_shift(time, i64_type.const_int(32, false), false, "") + .unwrap(), i32_type, "time.hi", ) @@ -164,16 +173,16 @@ impl TimeFns for NowPinningTimeFns { .module .get_global("now") .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now") + let now_raw = ctx + .builder + .build_load(now.as_pointer_value(), "now") .map(BasicValueEnum::into_int_value) .unwrap(); let i64_32 = i64_type.const_int(32, false); let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap(); let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap(); - ctx.builder.build_or(now_lo, now_hi, "now_mu") - .map(Into::into) - .unwrap() + ctx.builder.build_or(now_lo, now_hi, "now_mu").map(Into::into).unwrap() } fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { @@ -183,7 +192,8 @@ impl TimeFns for NowPinningTimeFns { let time = t.into_int_value(); - let time_hi = ctx.builder + let time_hi = ctx + .builder .build_int_truncate( ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(), i32_type, @@ -195,14 +205,16 @@ impl TimeFns for NowPinningTimeFns { .module .get_global("now") .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_hiptr = ctx.builder + let now_hiptr = ctx + .builder .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .map(BasicValueEnum::into_pointer_value) .unwrap(); let now_loptr = unsafe { ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") - }.unwrap(); + } + .unwrap(); ctx.builder .build_store(now_hiptr, time_hi) .unwrap() @@ -215,11 +227,7 @@ impl TimeFns for NowPinningTimeFns { .unwrap(); } - fn emit_delay_mu<'ctx>( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - dt: BasicValueEnum<'ctx>, - ) { + fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) { let i32_type = ctx.ctx.i32_type(); let i64_type = ctx.ctx.i64_type(); let i64_32 = i64_type.const_int(32, false); @@ -227,7 +235,8 @@ impl TimeFns for NowPinningTimeFns { .module .get_global("now") .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); - let now_raw = ctx.builder + let now_raw = ctx + .builder .build_load(now.as_pointer_value(), "") .map(BasicValueEnum::into_int_value) .unwrap(); @@ -238,7 +247,8 @@ impl TimeFns for NowPinningTimeFns { let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap(); let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap(); let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap(); - let time_hi = ctx.builder + let time_hi = ctx + .builder .build_int_truncate( ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(), i32_type, @@ -246,14 +256,16 @@ impl TimeFns for NowPinningTimeFns { ) .unwrap(); let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap(); - let now_hiptr = ctx.builder + let now_hiptr = ctx + .builder .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .map(BasicValueEnum::into_pointer_value) .unwrap(); let now_loptr = unsafe { ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") - }.unwrap(); + } + .unwrap(); ctx.builder .build_store(now_hiptr, time_hi) .unwrap() @@ -276,7 +288,8 @@ impl TimeFns for ExternTimeFns { let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| { ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None) }); - ctx.builder.build_call(now_mu, &[], "now_mu") + ctx.builder + .build_call(now_mu, &[], "now_mu") .map(CallSiteValue::try_as_basic_value) .map(Either::unwrap_left) .unwrap() @@ -293,11 +306,7 @@ impl TimeFns for ExternTimeFns { ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap(); } - fn emit_delay_mu<'ctx>( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - dt: BasicValueEnum<'ctx>, - ) { + fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) { let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| { ctx.module.add_function( "delay_mu", diff --git a/nac3ast/src/ast_gen.rs b/nac3ast/src/ast_gen.rs index b5e76df7..09533a2d 100644 --- a/nac3ast/src/ast_gen.rs +++ b/nac3ast/src/ast_gen.rs @@ -1,16 +1,17 @@ // File automatically generated by ast/asdl_rs.py. -pub use crate::location::Location; pub use crate::constant::*; +pub use crate::location::Location; -use std::{fmt, collections::HashMap, cell::RefCell}; -use parking_lot::{Mutex, MutexGuard}; -use string_interner::{DefaultBackend, StringInterner, symbol::SymbolU32}; use fxhash::FxBuildHasher; +use parking_lot::{Mutex, MutexGuard}; +use std::{cell::RefCell, collections::HashMap, fmt}; +use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner}; pub type Interner = StringInterner; lazy_static! { - static ref INTERNER: Mutex = Mutex::new(StringInterner::with_hasher(FxBuildHasher::default())); + static ref INTERNER: Mutex = + Mutex::new(StringInterner::with_hasher(FxBuildHasher::default())); } thread_local! { @@ -54,7 +55,7 @@ impl From<&str> for StrRef { } } -impl From for String{ +impl From for String { fn from(s: StrRef) -> Self { get_str_from_ref(&get_str_ref_lock(), s).to_string() } @@ -89,20 +90,10 @@ impl Located { #[derive(Clone, Debug, PartialEq)] pub enum Mod { - Module { - body: Vec>, - type_ignores: Vec, - }, - Interactive { - body: Vec>, - }, - Expression { - body: Box>, - }, - FunctionType { - argtypes: Vec>, - returns: Box>, - }, + Module { body: Vec>, type_ignores: Vec }, + Interactive { body: Vec> }, + Expression { body: Box> }, + FunctionType { argtypes: Vec>, returns: Box> }, } #[derive(Clone, Debug, PartialEq)] @@ -430,11 +421,7 @@ pub struct Comprehension { #[derive(Clone, Debug, PartialEq)] pub enum ExcepthandlerKind { - ExceptHandler { - type_: Option>>, - name: Option, - body: Vec>, - }, + ExceptHandler { type_: Option>>, name: Option, body: Vec> }, } pub type Excepthandler = Located, U>; @@ -478,13 +465,9 @@ pub struct Withitem { #[derive(Clone, Debug, PartialEq)] pub enum TypeIgnore { - TypeIgnore { - lineno: usize, - tag: String, - }, + TypeIgnore { lineno: usize, tag: String }, } - #[cfg(feature = "fold")] pub mod fold { use super::*; @@ -492,123 +475,159 @@ pub mod fold { pub trait Fold { type TargetU; type Error; - fn map_user(&mut self, user: U) -> Result; - fn fold_mod(&mut self, node: Mod) -> Result, Self::Error> { - fold_mod(self, node) - } - fn fold_stmt(&mut self, node: Stmt) -> Result, Self::Error> { - fold_stmt(self, node) - } - fn fold_expr(&mut self, node: Expr) -> Result, Self::Error> { - fold_expr(self, node) - } - fn fold_expr_context(&mut self, node: ExprContext) -> Result { - fold_expr_context(self, node) - } - fn fold_boolop(&mut self, node: Boolop) -> Result { - fold_boolop(self, node) - } - fn fold_operator(&mut self, node: Operator) -> Result { - fold_operator(self, node) - } - fn fold_unaryop(&mut self, node: Unaryop) -> Result { - fold_unaryop(self, node) - } - fn fold_cmpop(&mut self, node: Cmpop) -> Result { - fold_cmpop(self, node) - } - fn fold_comprehension(&mut self, node: Comprehension) -> Result, Self::Error> { - fold_comprehension(self, node) - } - fn fold_excepthandler(&mut self, node: Excepthandler) -> Result, Self::Error> { - fold_excepthandler(self, node) - } - fn fold_arguments(&mut self, node: Arguments) -> Result, Self::Error> { - fold_arguments(self, node) - } - fn fold_arg(&mut self, node: Arg) -> Result, Self::Error> { - fold_arg(self, node) - } - fn fold_keyword(&mut self, node: Keyword) -> Result, Self::Error> { - fold_keyword(self, node) - } - fn fold_alias(&mut self, node: Alias) -> Result { - fold_alias(self, node) - } - fn fold_withitem(&mut self, node: Withitem) -> Result, Self::Error> { - fold_withitem(self, node) - } - fn fold_type_ignore(&mut self, node: TypeIgnore) -> Result { - fold_type_ignore(self, node) - } + fn map_user(&mut self, user: U) -> Result; + fn fold_mod(&mut self, node: Mod) -> Result, Self::Error> { + fold_mod(self, node) + } + fn fold_stmt(&mut self, node: Stmt) -> Result, Self::Error> { + fold_stmt(self, node) + } + fn fold_expr(&mut self, node: Expr) -> Result, Self::Error> { + fold_expr(self, node) + } + fn fold_expr_context(&mut self, node: ExprContext) -> Result { + fold_expr_context(self, node) + } + fn fold_boolop(&mut self, node: Boolop) -> Result { + fold_boolop(self, node) + } + fn fold_operator(&mut self, node: Operator) -> Result { + fold_operator(self, node) + } + fn fold_unaryop(&mut self, node: Unaryop) -> Result { + fold_unaryop(self, node) + } + fn fold_cmpop(&mut self, node: Cmpop) -> Result { + fold_cmpop(self, node) + } + fn fold_comprehension( + &mut self, + node: Comprehension, + ) -> Result, Self::Error> { + fold_comprehension(self, node) + } + fn fold_excepthandler( + &mut self, + node: Excepthandler, + ) -> Result, Self::Error> { + fold_excepthandler(self, node) + } + fn fold_arguments( + &mut self, + node: Arguments, + ) -> Result, Self::Error> { + fold_arguments(self, node) + } + fn fold_arg(&mut self, node: Arg) -> Result, Self::Error> { + fold_arg(self, node) + } + fn fold_keyword( + &mut self, + node: Keyword, + ) -> Result, Self::Error> { + fold_keyword(self, node) + } + fn fold_alias(&mut self, node: Alias) -> Result { + fold_alias(self, node) + } + fn fold_withitem( + &mut self, + node: Withitem, + ) -> Result, Self::Error> { + fold_withitem(self, node) + } + fn fold_type_ignore(&mut self, node: TypeIgnore) -> Result { + fold_type_ignore(self, node) + } } - fn fold_located + ?Sized, T, MT>(folder: &mut F, node: Located, f: impl FnOnce(&mut F, T) -> Result) -> Result, F::Error> { - Ok(Located { custom: folder.map_user(node.custom)?, location: node.location, node: f(folder, node.node)? }) + fn fold_located + ?Sized, T, MT>( + folder: &mut F, + node: Located, + f: impl FnOnce(&mut F, T) -> Result, + ) -> Result, F::Error> { + Ok(Located { + custom: folder.map_user(node.custom)?, + location: node.location, + node: f(folder, node.node)?, + }) } impl Foldable for Mod { type Mapped = Mod; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_mod(self) } } - pub fn fold_mod + ?Sized>(#[allow(unused)] folder: &mut F, node: Mod) -> Result, F::Error> { + pub fn fold_mod + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Mod, + ) -> Result, F::Error> { match node { - Mod::Module { body,type_ignores } => { - Ok(Mod::Module { - body: Foldable::fold(body, folder)?, - type_ignores: Foldable::fold(type_ignores, folder)?, - }) - } + Mod::Module { body, type_ignores } => Ok(Mod::Module { + body: Foldable::fold(body, folder)?, + type_ignores: Foldable::fold(type_ignores, folder)?, + }), Mod::Interactive { body } => { - Ok(Mod::Interactive { - body: Foldable::fold(body, folder)?, - }) - } - Mod::Expression { body } => { - Ok(Mod::Expression { - body: Foldable::fold(body, folder)?, - }) - } - Mod::FunctionType { argtypes,returns } => { - Ok(Mod::FunctionType { - argtypes: Foldable::fold(argtypes, folder)?, - returns: Foldable::fold(returns, folder)?, - }) + Ok(Mod::Interactive { body: Foldable::fold(body, folder)? }) } + Mod::Expression { body } => Ok(Mod::Expression { body: Foldable::fold(body, folder)? }), + Mod::FunctionType { argtypes, returns } => Ok(Mod::FunctionType { + argtypes: Foldable::fold(argtypes, folder)?, + returns: Foldable::fold(returns, folder)?, + }), } } impl Foldable for Stmt { type Mapped = Stmt; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_stmt(self) } } - pub fn fold_stmt + ?Sized>(#[allow(unused)] folder: &mut F, node: Stmt) -> Result, F::Error> { - fold_located(folder, node, |folder, node| { - match node { - StmtKind::FunctionDef { name,args,body,decorator_list,returns,type_comment,config_comment } => { - Ok(StmtKind::FunctionDef { - name: Foldable::fold(name, folder)?, - args: Foldable::fold(args, folder)?, - body: Foldable::fold(body, folder)?, - decorator_list: Foldable::fold(decorator_list, folder)?, - returns: Foldable::fold(returns, folder)?, - type_comment: Foldable::fold(type_comment, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::AsyncFunctionDef { name,args,body,decorator_list,returns,type_comment,config_comment } => { - Ok(StmtKind::AsyncFunctionDef { - name: Foldable::fold(name, folder)?, - args: Foldable::fold(args, folder)?, - body: Foldable::fold(body, folder)?, - decorator_list: Foldable::fold(decorator_list, folder)?, - returns: Foldable::fold(returns, folder)?, - type_comment: Foldable::fold(type_comment, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::ClassDef { name,bases,keywords,body,decorator_list,config_comment } => { + pub fn fold_stmt + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Stmt, + ) -> Result, F::Error> { + fold_located(folder, node, |folder, node| match node { + StmtKind::FunctionDef { + name, + args, + body, + decorator_list, + returns, + type_comment, + config_comment, + } => Ok(StmtKind::FunctionDef { + name: Foldable::fold(name, folder)?, + args: Foldable::fold(args, folder)?, + body: Foldable::fold(body, folder)?, + decorator_list: Foldable::fold(decorator_list, folder)?, + returns: Foldable::fold(returns, folder)?, + type_comment: Foldable::fold(type_comment, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::AsyncFunctionDef { + name, + args, + body, + decorator_list, + returns, + type_comment, + config_comment, + } => Ok(StmtKind::AsyncFunctionDef { + name: Foldable::fold(name, folder)?, + args: Foldable::fold(args, folder)?, + body: Foldable::fold(body, folder)?, + decorator_list: Foldable::fold(decorator_list, folder)?, + returns: Foldable::fold(returns, folder)?, + type_comment: Foldable::fold(type_comment, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::ClassDef { name, bases, keywords, body, decorator_list, config_comment } => { Ok(StmtKind::ClassDef { name: Foldable::fold(name, folder)?, bases: Foldable::fold(bases, folder)?, @@ -618,19 +637,15 @@ pub mod fold { config_comment: Foldable::fold(config_comment, folder)?, }) } - StmtKind::Return { value,config_comment } => { - Ok(StmtKind::Return { - value: Foldable::fold(value, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::Delete { targets,config_comment } => { - Ok(StmtKind::Delete { - targets: Foldable::fold(targets, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::Assign { targets,value,type_comment,config_comment } => { + StmtKind::Return { value, config_comment } => Ok(StmtKind::Return { + value: Foldable::fold(value, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::Delete { targets, config_comment } => Ok(StmtKind::Delete { + targets: Foldable::fold(targets, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::Assign { targets, value, type_comment, config_comment } => { Ok(StmtKind::Assign { targets: Foldable::fold(targets, folder)?, value: Foldable::fold(value, folder)?, @@ -638,15 +653,13 @@ pub mod fold { config_comment: Foldable::fold(config_comment, folder)?, }) } - StmtKind::AugAssign { target,op,value,config_comment } => { - Ok(StmtKind::AugAssign { - target: Foldable::fold(target, folder)?, - op: Foldable::fold(op, folder)?, - value: Foldable::fold(value, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::AnnAssign { target,annotation,value,simple,config_comment } => { + StmtKind::AugAssign { target, op, value, config_comment } => Ok(StmtKind::AugAssign { + target: Foldable::fold(target, folder)?, + op: Foldable::fold(op, folder)?, + value: Foldable::fold(value, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::AnnAssign { target, annotation, value, simple, config_comment } => { Ok(StmtKind::AnnAssign { target: Foldable::fold(target, folder)?, annotation: Foldable::fold(annotation, folder)?, @@ -655,7 +668,7 @@ pub mod fold { config_comment: Foldable::fold(config_comment, folder)?, }) } - StmtKind::For { target,iter,body,orelse,type_comment,config_comment } => { + StmtKind::For { target, iter, body, orelse, type_comment, config_comment } => { Ok(StmtKind::For { target: Foldable::fold(target, folder)?, iter: Foldable::fold(iter, folder)?, @@ -665,7 +678,7 @@ pub mod fold { config_comment: Foldable::fold(config_comment, folder)?, }) } - StmtKind::AsyncFor { target,iter,body,orelse,type_comment,config_comment } => { + StmtKind::AsyncFor { target, iter, body, orelse, type_comment, config_comment } => { Ok(StmtKind::AsyncFor { target: Foldable::fold(target, folder)?, iter: Foldable::fold(iter, folder)?, @@ -675,31 +688,25 @@ pub mod fold { config_comment: Foldable::fold(config_comment, folder)?, }) } - StmtKind::While { test,body,orelse,config_comment } => { - Ok(StmtKind::While { - test: Foldable::fold(test, folder)?, - body: Foldable::fold(body, folder)?, - orelse: Foldable::fold(orelse, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::If { test,body,orelse,config_comment } => { - Ok(StmtKind::If { - test: Foldable::fold(test, folder)?, - body: Foldable::fold(body, folder)?, - orelse: Foldable::fold(orelse, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::With { items,body,type_comment,config_comment } => { - Ok(StmtKind::With { - items: Foldable::fold(items, folder)?, - body: Foldable::fold(body, folder)?, - type_comment: Foldable::fold(type_comment, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::AsyncWith { items,body,type_comment,config_comment } => { + StmtKind::While { test, body, orelse, config_comment } => Ok(StmtKind::While { + test: Foldable::fold(test, folder)?, + body: Foldable::fold(body, folder)?, + orelse: Foldable::fold(orelse, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::If { test, body, orelse, config_comment } => Ok(StmtKind::If { + test: Foldable::fold(test, folder)?, + body: Foldable::fold(body, folder)?, + orelse: Foldable::fold(orelse, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::With { items, body, type_comment, config_comment } => Ok(StmtKind::With { + items: Foldable::fold(items, folder)?, + body: Foldable::fold(body, folder)?, + type_comment: Foldable::fold(type_comment, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::AsyncWith { items, body, type_comment, config_comment } => { Ok(StmtKind::AsyncWith { items: Foldable::fold(items, folder)?, body: Foldable::fold(body, folder)?, @@ -707,14 +714,12 @@ pub mod fold { config_comment: Foldable::fold(config_comment, folder)?, }) } - StmtKind::Raise { exc,cause,config_comment } => { - Ok(StmtKind::Raise { - exc: Foldable::fold(exc, folder)?, - cause: Foldable::fold(cause, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::Try { body,handlers,orelse,finalbody,config_comment } => { + StmtKind::Raise { exc, cause, config_comment } => Ok(StmtKind::Raise { + exc: Foldable::fold(exc, folder)?, + cause: Foldable::fold(cause, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::Try { body, handlers, orelse, finalbody, config_comment } => { Ok(StmtKind::Try { body: Foldable::fold(body, folder)?, handlers: Foldable::fold(handlers, folder)?, @@ -723,20 +728,16 @@ pub mod fold { config_comment: Foldable::fold(config_comment, folder)?, }) } - StmtKind::Assert { test,msg,config_comment } => { - Ok(StmtKind::Assert { - test: Foldable::fold(test, folder)?, - msg: Foldable::fold(msg, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::Import { names,config_comment } => { - Ok(StmtKind::Import { - names: Foldable::fold(names, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::ImportFrom { module,names,level,config_comment } => { + StmtKind::Assert { test, msg, config_comment } => Ok(StmtKind::Assert { + test: Foldable::fold(test, folder)?, + msg: Foldable::fold(msg, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::Import { names, config_comment } => Ok(StmtKind::Import { + names: Foldable::fold(names, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::ImportFrom { module, names, level, config_comment } => { Ok(StmtKind::ImportFrom { module: Foldable::fold(module, folder)?, names: Foldable::fold(names, folder)?, @@ -744,155 +745,111 @@ pub mod fold { config_comment: Foldable::fold(config_comment, folder)?, }) } - StmtKind::Global { names,config_comment } => { - Ok(StmtKind::Global { - names: Foldable::fold(names, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::Nonlocal { names,config_comment } => { - Ok(StmtKind::Nonlocal { - names: Foldable::fold(names, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } - StmtKind::Expr { value,config_comment } => { - Ok(StmtKind::Expr { - value: Foldable::fold(value, folder)?, - config_comment: Foldable::fold(config_comment, folder)?, - }) - } + StmtKind::Global { names, config_comment } => Ok(StmtKind::Global { + names: Foldable::fold(names, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::Nonlocal { names, config_comment } => Ok(StmtKind::Nonlocal { + names: Foldable::fold(names, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), + StmtKind::Expr { value, config_comment } => Ok(StmtKind::Expr { + value: Foldable::fold(value, folder)?, + config_comment: Foldable::fold(config_comment, folder)?, + }), StmtKind::Pass { config_comment } => { - Ok(StmtKind::Pass { - config_comment: Foldable::fold(config_comment, folder)?, - }) + Ok(StmtKind::Pass { config_comment: Foldable::fold(config_comment, folder)? }) } StmtKind::Break { config_comment } => { - Ok(StmtKind::Break { - config_comment: Foldable::fold(config_comment, folder)?, - }) + Ok(StmtKind::Break { config_comment: Foldable::fold(config_comment, folder)? }) } StmtKind::Continue { config_comment } => { - Ok(StmtKind::Continue { - config_comment: Foldable::fold(config_comment, folder)?, - }) + Ok(StmtKind::Continue { config_comment: Foldable::fold(config_comment, folder)? }) } - } - }) + }) } impl Foldable for Expr { type Mapped = Expr; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_expr(self) } } - pub fn fold_expr + ?Sized>(#[allow(unused)] folder: &mut F, node: Expr) -> Result, F::Error> { - fold_located(folder, node, |folder, node| { - match node { - ExprKind::BoolOp { op,values } => { - Ok(ExprKind::BoolOp { - op: Foldable::fold(op, folder)?, - values: Foldable::fold(values, folder)?, - }) - } - ExprKind::NamedExpr { target,value } => { - Ok(ExprKind::NamedExpr { - target: Foldable::fold(target, folder)?, - value: Foldable::fold(value, folder)?, - }) - } - ExprKind::BinOp { left,op,right } => { - Ok(ExprKind::BinOp { - left: Foldable::fold(left, folder)?, - op: Foldable::fold(op, folder)?, - right: Foldable::fold(right, folder)?, - }) - } - ExprKind::UnaryOp { op,operand } => { - Ok(ExprKind::UnaryOp { - op: Foldable::fold(op, folder)?, - operand: Foldable::fold(operand, folder)?, - }) - } - ExprKind::Lambda { args,body } => { - Ok(ExprKind::Lambda { - args: Foldable::fold(args, folder)?, - body: Foldable::fold(body, folder)?, - }) - } - ExprKind::IfExp { test,body,orelse } => { - Ok(ExprKind::IfExp { - test: Foldable::fold(test, folder)?, - body: Foldable::fold(body, folder)?, - orelse: Foldable::fold(orelse, folder)?, - }) - } - ExprKind::Dict { keys,values } => { - Ok(ExprKind::Dict { - keys: Foldable::fold(keys, folder)?, - values: Foldable::fold(values, folder)?, - }) - } - ExprKind::Set { elts } => { - Ok(ExprKind::Set { - elts: Foldable::fold(elts, folder)?, - }) - } - ExprKind::ListComp { elt,generators } => { - Ok(ExprKind::ListComp { - elt: Foldable::fold(elt, folder)?, - generators: Foldable::fold(generators, folder)?, - }) - } - ExprKind::SetComp { elt,generators } => { - Ok(ExprKind::SetComp { - elt: Foldable::fold(elt, folder)?, - generators: Foldable::fold(generators, folder)?, - }) - } - ExprKind::DictComp { key,value,generators } => { - Ok(ExprKind::DictComp { - key: Foldable::fold(key, folder)?, - value: Foldable::fold(value, folder)?, - generators: Foldable::fold(generators, folder)?, - }) - } - ExprKind::GeneratorExp { elt,generators } => { - Ok(ExprKind::GeneratorExp { - elt: Foldable::fold(elt, folder)?, - generators: Foldable::fold(generators, folder)?, - }) - } + pub fn fold_expr + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Expr, + ) -> Result, F::Error> { + fold_located(folder, node, |folder, node| match node { + ExprKind::BoolOp { op, values } => Ok(ExprKind::BoolOp { + op: Foldable::fold(op, folder)?, + values: Foldable::fold(values, folder)?, + }), + ExprKind::NamedExpr { target, value } => Ok(ExprKind::NamedExpr { + target: Foldable::fold(target, folder)?, + value: Foldable::fold(value, folder)?, + }), + ExprKind::BinOp { left, op, right } => Ok(ExprKind::BinOp { + left: Foldable::fold(left, folder)?, + op: Foldable::fold(op, folder)?, + right: Foldable::fold(right, folder)?, + }), + ExprKind::UnaryOp { op, operand } => Ok(ExprKind::UnaryOp { + op: Foldable::fold(op, folder)?, + operand: Foldable::fold(operand, folder)?, + }), + ExprKind::Lambda { args, body } => Ok(ExprKind::Lambda { + args: Foldable::fold(args, folder)?, + body: Foldable::fold(body, folder)?, + }), + ExprKind::IfExp { test, body, orelse } => Ok(ExprKind::IfExp { + test: Foldable::fold(test, folder)?, + body: Foldable::fold(body, folder)?, + orelse: Foldable::fold(orelse, folder)?, + }), + ExprKind::Dict { keys, values } => Ok(ExprKind::Dict { + keys: Foldable::fold(keys, folder)?, + values: Foldable::fold(values, folder)?, + }), + ExprKind::Set { elts } => Ok(ExprKind::Set { elts: Foldable::fold(elts, folder)? }), + ExprKind::ListComp { elt, generators } => Ok(ExprKind::ListComp { + elt: Foldable::fold(elt, folder)?, + generators: Foldable::fold(generators, folder)?, + }), + ExprKind::SetComp { elt, generators } => Ok(ExprKind::SetComp { + elt: Foldable::fold(elt, folder)?, + generators: Foldable::fold(generators, folder)?, + }), + ExprKind::DictComp { key, value, generators } => Ok(ExprKind::DictComp { + key: Foldable::fold(key, folder)?, + value: Foldable::fold(value, folder)?, + generators: Foldable::fold(generators, folder)?, + }), + ExprKind::GeneratorExp { elt, generators } => Ok(ExprKind::GeneratorExp { + elt: Foldable::fold(elt, folder)?, + generators: Foldable::fold(generators, folder)?, + }), ExprKind::Await { value } => { - Ok(ExprKind::Await { - value: Foldable::fold(value, folder)?, - }) + Ok(ExprKind::Await { value: Foldable::fold(value, folder)? }) } ExprKind::Yield { value } => { - Ok(ExprKind::Yield { - value: Foldable::fold(value, folder)?, - }) + Ok(ExprKind::Yield { value: Foldable::fold(value, folder)? }) } ExprKind::YieldFrom { value } => { - Ok(ExprKind::YieldFrom { - value: Foldable::fold(value, folder)?, - }) + Ok(ExprKind::YieldFrom { value: Foldable::fold(value, folder)? }) } - ExprKind::Compare { left,ops,comparators } => { - Ok(ExprKind::Compare { - left: Foldable::fold(left, folder)?, - ops: Foldable::fold(ops, folder)?, - comparators: Foldable::fold(comparators, folder)?, - }) - } - ExprKind::Call { func,args,keywords } => { - Ok(ExprKind::Call { - func: Foldable::fold(func, folder)?, - args: Foldable::fold(args, folder)?, - keywords: Foldable::fold(keywords, folder)?, - }) - } - ExprKind::FormattedValue { value,conversion,format_spec } => { + ExprKind::Compare { left, ops, comparators } => Ok(ExprKind::Compare { + left: Foldable::fold(left, folder)?, + ops: Foldable::fold(ops, folder)?, + comparators: Foldable::fold(comparators, folder)?, + }), + ExprKind::Call { func, args, keywords } => Ok(ExprKind::Call { + func: Foldable::fold(func, folder)?, + args: Foldable::fold(args, folder)?, + keywords: Foldable::fold(keywords, folder)?, + }), + ExprKind::FormattedValue { value, conversion, format_spec } => { Ok(ExprKind::FormattedValue { value: Foldable::fold(value, folder)?, conversion: Foldable::fold(conversion, folder)?, @@ -900,250 +857,171 @@ pub mod fold { }) } ExprKind::JoinedStr { values } => { - Ok(ExprKind::JoinedStr { - values: Foldable::fold(values, folder)?, - }) + Ok(ExprKind::JoinedStr { values: Foldable::fold(values, folder)? }) } - ExprKind::Constant { value,kind } => { - Ok(ExprKind::Constant { - value: Foldable::fold(value, folder)?, - kind: Foldable::fold(kind, folder)?, - }) - } - ExprKind::Attribute { value,attr,ctx } => { - Ok(ExprKind::Attribute { - value: Foldable::fold(value, folder)?, - attr: Foldable::fold(attr, folder)?, - ctx: Foldable::fold(ctx, folder)?, - }) - } - ExprKind::Subscript { value,slice,ctx } => { - Ok(ExprKind::Subscript { - value: Foldable::fold(value, folder)?, - slice: Foldable::fold(slice, folder)?, - ctx: Foldable::fold(ctx, folder)?, - }) - } - ExprKind::Starred { value,ctx } => { - Ok(ExprKind::Starred { - value: Foldable::fold(value, folder)?, - ctx: Foldable::fold(ctx, folder)?, - }) - } - ExprKind::Name { id,ctx } => { - Ok(ExprKind::Name { - id: Foldable::fold(id, folder)?, - ctx: Foldable::fold(ctx, folder)?, - }) - } - ExprKind::List { elts,ctx } => { - Ok(ExprKind::List { - elts: Foldable::fold(elts, folder)?, - ctx: Foldable::fold(ctx, folder)?, - }) - } - ExprKind::Tuple { elts,ctx } => { - Ok(ExprKind::Tuple { - elts: Foldable::fold(elts, folder)?, - ctx: Foldable::fold(ctx, folder)?, - }) - } - ExprKind::Slice { lower,upper,step } => { - Ok(ExprKind::Slice { - lower: Foldable::fold(lower, folder)?, - upper: Foldable::fold(upper, folder)?, - step: Foldable::fold(step, folder)?, - }) - } - } - }) + ExprKind::Constant { value, kind } => Ok(ExprKind::Constant { + value: Foldable::fold(value, folder)?, + kind: Foldable::fold(kind, folder)?, + }), + ExprKind::Attribute { value, attr, ctx } => Ok(ExprKind::Attribute { + value: Foldable::fold(value, folder)?, + attr: Foldable::fold(attr, folder)?, + ctx: Foldable::fold(ctx, folder)?, + }), + ExprKind::Subscript { value, slice, ctx } => Ok(ExprKind::Subscript { + value: Foldable::fold(value, folder)?, + slice: Foldable::fold(slice, folder)?, + ctx: Foldable::fold(ctx, folder)?, + }), + ExprKind::Starred { value, ctx } => Ok(ExprKind::Starred { + value: Foldable::fold(value, folder)?, + ctx: Foldable::fold(ctx, folder)?, + }), + ExprKind::Name { id, ctx } => Ok(ExprKind::Name { + id: Foldable::fold(id, folder)?, + ctx: Foldable::fold(ctx, folder)?, + }), + ExprKind::List { elts, ctx } => Ok(ExprKind::List { + elts: Foldable::fold(elts, folder)?, + ctx: Foldable::fold(ctx, folder)?, + }), + ExprKind::Tuple { elts, ctx } => Ok(ExprKind::Tuple { + elts: Foldable::fold(elts, folder)?, + ctx: Foldable::fold(ctx, folder)?, + }), + ExprKind::Slice { lower, upper, step } => Ok(ExprKind::Slice { + lower: Foldable::fold(lower, folder)?, + upper: Foldable::fold(upper, folder)?, + step: Foldable::fold(step, folder)?, + }), + }) } impl Foldable for ExprContext { type Mapped = ExprContext; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_expr_context(self) } } - pub fn fold_expr_context + ?Sized>(#[allow(unused)] folder: &mut F, node: ExprContext) -> Result { + pub fn fold_expr_context + ?Sized>( + #[allow(unused)] folder: &mut F, + node: ExprContext, + ) -> Result { match node { - ExprContext::Load { } => { - Ok(ExprContext::Load { - }) - } - ExprContext::Store { } => { - Ok(ExprContext::Store { - }) - } - ExprContext::Del { } => { - Ok(ExprContext::Del { - }) - } + ExprContext::Load {} => Ok(ExprContext::Load {}), + ExprContext::Store {} => Ok(ExprContext::Store {}), + ExprContext::Del {} => Ok(ExprContext::Del {}), } } impl Foldable for Boolop { type Mapped = Boolop; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_boolop(self) } } - pub fn fold_boolop + ?Sized>(#[allow(unused)] folder: &mut F, node: Boolop) -> Result { + pub fn fold_boolop + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Boolop, + ) -> Result { match node { - Boolop::And { } => { - Ok(Boolop::And { - }) - } - Boolop::Or { } => { - Ok(Boolop::Or { - }) - } + Boolop::And {} => Ok(Boolop::And {}), + Boolop::Or {} => Ok(Boolop::Or {}), } } impl Foldable for Operator { type Mapped = Operator; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_operator(self) } } - pub fn fold_operator + ?Sized>(#[allow(unused)] folder: &mut F, node: Operator) -> Result { + pub fn fold_operator + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Operator, + ) -> Result { match node { - Operator::Add { } => { - Ok(Operator::Add { - }) - } - Operator::Sub { } => { - Ok(Operator::Sub { - }) - } - Operator::Mult { } => { - Ok(Operator::Mult { - }) - } - Operator::MatMult { } => { - Ok(Operator::MatMult { - }) - } - Operator::Div { } => { - Ok(Operator::Div { - }) - } - Operator::Mod { } => { - Ok(Operator::Mod { - }) - } - Operator::Pow { } => { - Ok(Operator::Pow { - }) - } - Operator::LShift { } => { - Ok(Operator::LShift { - }) - } - Operator::RShift { } => { - Ok(Operator::RShift { - }) - } - Operator::BitOr { } => { - Ok(Operator::BitOr { - }) - } - Operator::BitXor { } => { - Ok(Operator::BitXor { - }) - } - Operator::BitAnd { } => { - Ok(Operator::BitAnd { - }) - } - Operator::FloorDiv { } => { - Ok(Operator::FloorDiv { - }) - } + Operator::Add {} => Ok(Operator::Add {}), + Operator::Sub {} => Ok(Operator::Sub {}), + Operator::Mult {} => Ok(Operator::Mult {}), + Operator::MatMult {} => Ok(Operator::MatMult {}), + Operator::Div {} => Ok(Operator::Div {}), + Operator::Mod {} => Ok(Operator::Mod {}), + Operator::Pow {} => Ok(Operator::Pow {}), + Operator::LShift {} => Ok(Operator::LShift {}), + Operator::RShift {} => Ok(Operator::RShift {}), + Operator::BitOr {} => Ok(Operator::BitOr {}), + Operator::BitXor {} => Ok(Operator::BitXor {}), + Operator::BitAnd {} => Ok(Operator::BitAnd {}), + Operator::FloorDiv {} => Ok(Operator::FloorDiv {}), } } impl Foldable for Unaryop { type Mapped = Unaryop; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_unaryop(self) } } - pub fn fold_unaryop + ?Sized>(#[allow(unused)] folder: &mut F, node: Unaryop) -> Result { + pub fn fold_unaryop + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Unaryop, + ) -> Result { match node { - Unaryop::Invert { } => { - Ok(Unaryop::Invert { - }) - } - Unaryop::Not { } => { - Ok(Unaryop::Not { - }) - } - Unaryop::UAdd { } => { - Ok(Unaryop::UAdd { - }) - } - Unaryop::USub { } => { - Ok(Unaryop::USub { - }) - } + Unaryop::Invert {} => Ok(Unaryop::Invert {}), + Unaryop::Not {} => Ok(Unaryop::Not {}), + Unaryop::UAdd {} => Ok(Unaryop::UAdd {}), + Unaryop::USub {} => Ok(Unaryop::USub {}), } } impl Foldable for Cmpop { type Mapped = Cmpop; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_cmpop(self) } } - pub fn fold_cmpop + ?Sized>(#[allow(unused)] folder: &mut F, node: Cmpop) -> Result { + pub fn fold_cmpop + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Cmpop, + ) -> Result { match node { - Cmpop::Eq { } => { - Ok(Cmpop::Eq { - }) - } - Cmpop::NotEq { } => { - Ok(Cmpop::NotEq { - }) - } - Cmpop::Lt { } => { - Ok(Cmpop::Lt { - }) - } - Cmpop::LtE { } => { - Ok(Cmpop::LtE { - }) - } - Cmpop::Gt { } => { - Ok(Cmpop::Gt { - }) - } - Cmpop::GtE { } => { - Ok(Cmpop::GtE { - }) - } - Cmpop::Is { } => { - Ok(Cmpop::Is { - }) - } - Cmpop::IsNot { } => { - Ok(Cmpop::IsNot { - }) - } - Cmpop::In { } => { - Ok(Cmpop::In { - }) - } - Cmpop::NotIn { } => { - Ok(Cmpop::NotIn { - }) - } + Cmpop::Eq {} => Ok(Cmpop::Eq {}), + Cmpop::NotEq {} => Ok(Cmpop::NotEq {}), + Cmpop::Lt {} => Ok(Cmpop::Lt {}), + Cmpop::LtE {} => Ok(Cmpop::LtE {}), + Cmpop::Gt {} => Ok(Cmpop::Gt {}), + Cmpop::GtE {} => Ok(Cmpop::GtE {}), + Cmpop::Is {} => Ok(Cmpop::Is {}), + Cmpop::IsNot {} => Ok(Cmpop::IsNot {}), + Cmpop::In {} => Ok(Cmpop::In {}), + Cmpop::NotIn {} => Ok(Cmpop::NotIn {}), } } impl Foldable for Comprehension { type Mapped = Comprehension; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_comprehension(self) } } - pub fn fold_comprehension + ?Sized>(#[allow(unused)] folder: &mut F, node: Comprehension) -> Result, F::Error> { - let Comprehension { target,iter,ifs,is_async } = node; + pub fn fold_comprehension + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Comprehension, + ) -> Result, F::Error> { + let Comprehension { target, iter, ifs, is_async } = node; Ok(Comprehension { target: Foldable::fold(target, folder)?, iter: Foldable::fold(iter, folder)?, @@ -1153,31 +1031,42 @@ pub mod fold { } impl Foldable for Excepthandler { type Mapped = Excepthandler; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_excepthandler(self) } } - pub fn fold_excepthandler + ?Sized>(#[allow(unused)] folder: &mut F, node: Excepthandler) -> Result, F::Error> { - fold_located(folder, node, |folder, node| { - match node { - ExcepthandlerKind::ExceptHandler { type_,name,body } => { + pub fn fold_excepthandler + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Excepthandler, + ) -> Result, F::Error> { + fold_located(folder, node, |folder, node| match node { + ExcepthandlerKind::ExceptHandler { type_, name, body } => { Ok(ExcepthandlerKind::ExceptHandler { type_: Foldable::fold(type_, folder)?, name: Foldable::fold(name, folder)?, body: Foldable::fold(body, folder)?, }) } - } - }) + }) } impl Foldable for Arguments { type Mapped = Arguments; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_arguments(self) } } - pub fn fold_arguments + ?Sized>(#[allow(unused)] folder: &mut F, node: Arguments) -> Result, F::Error> { - let Arguments { posonlyargs,args,vararg,kwonlyargs,kw_defaults,kwarg,defaults } = node; + pub fn fold_arguments + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Arguments, + ) -> Result, F::Error> { + let Arguments { posonlyargs, args, vararg, kwonlyargs, kw_defaults, kwarg, defaults } = + node; Ok(Arguments { posonlyargs: Foldable::fold(posonlyargs, folder)?, args: Foldable::fold(args, folder)?, @@ -1190,56 +1079,77 @@ pub mod fold { } impl Foldable for Arg { type Mapped = Arg; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_arg(self) } } - pub fn fold_arg + ?Sized>(#[allow(unused)] folder: &mut F, node: Arg) -> Result, F::Error> { - fold_located(folder, node, |folder, node| { - let ArgData { arg,annotation,type_comment } = node; - Ok(ArgData { - arg: Foldable::fold(arg, folder)?, - annotation: Foldable::fold(annotation, folder)?, - type_comment: Foldable::fold(type_comment, folder)?, + pub fn fold_arg + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Arg, + ) -> Result, F::Error> { + fold_located(folder, node, |folder, node| { + let ArgData { arg, annotation, type_comment } = node; + Ok(ArgData { + arg: Foldable::fold(arg, folder)?, + annotation: Foldable::fold(annotation, folder)?, + type_comment: Foldable::fold(type_comment, folder)?, + }) }) - }) } impl Foldable for Keyword { type Mapped = Keyword; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_keyword(self) } } - pub fn fold_keyword + ?Sized>(#[allow(unused)] folder: &mut F, node: Keyword) -> Result, F::Error> { - fold_located(folder, node, |folder, node| { - let KeywordData { arg,value } = node; - Ok(KeywordData { - arg: Foldable::fold(arg, folder)?, - value: Foldable::fold(value, folder)?, + pub fn fold_keyword + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Keyword, + ) -> Result, F::Error> { + fold_located(folder, node, |folder, node| { + let KeywordData { arg, value } = node; + Ok(KeywordData { + arg: Foldable::fold(arg, folder)?, + value: Foldable::fold(value, folder)?, + }) }) - }) } impl Foldable for Alias { type Mapped = Alias; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_alias(self) } } - pub fn fold_alias + ?Sized>(#[allow(unused)] folder: &mut F, node: Alias) -> Result { - let Alias { name,asname } = node; - Ok(Alias { - name: Foldable::fold(name, folder)?, - asname: Foldable::fold(asname, folder)?, - }) + pub fn fold_alias + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Alias, + ) -> Result { + let Alias { name, asname } = node; + Ok(Alias { name: Foldable::fold(name, folder)?, asname: Foldable::fold(asname, folder)? }) } impl Foldable for Withitem { type Mapped = Withitem; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_withitem(self) } } - pub fn fold_withitem + ?Sized>(#[allow(unused)] folder: &mut F, node: Withitem) -> Result, F::Error> { - let Withitem { context_expr,optional_vars } = node; + pub fn fold_withitem + ?Sized>( + #[allow(unused)] folder: &mut F, + node: Withitem, + ) -> Result, F::Error> { + let Withitem { context_expr, optional_vars } = node; Ok(Withitem { context_expr: Foldable::fold(context_expr, folder)?, optional_vars: Foldable::fold(optional_vars, folder)?, @@ -1247,19 +1157,22 @@ pub mod fold { } impl Foldable for TypeIgnore { type Mapped = TypeIgnore; - fn fold + ?Sized>(self, folder: &mut F) -> Result { + fn fold + ?Sized>( + self, + folder: &mut F, + ) -> Result { folder.fold_type_ignore(self) } } - pub fn fold_type_ignore + ?Sized>(#[allow(unused)] folder: &mut F, node: TypeIgnore) -> Result { + pub fn fold_type_ignore + ?Sized>( + #[allow(unused)] folder: &mut F, + node: TypeIgnore, + ) -> Result { match node { - TypeIgnore::TypeIgnore { lineno,tag } => { - Ok(TypeIgnore::TypeIgnore { - lineno: Foldable::fold(lineno, folder)?, - tag: Foldable::fold(tag, folder)?, - }) - } + TypeIgnore::TypeIgnore { lineno, tag } => Ok(TypeIgnore::TypeIgnore { + lineno: Foldable::fold(lineno, folder)?, + tag: Foldable::fold(tag, folder)?, + }), } } } - diff --git a/nac3ast/src/constant.rs b/nac3ast/src/constant.rs index d9b0d338..b5e53588 100644 --- a/nac3ast/src/constant.rs +++ b/nac3ast/src/constant.rs @@ -85,33 +85,22 @@ impl crate::fold::Fold for ConstantOptimizer { fn fold_expr(&mut self, node: crate::Expr) -> Result, Self::Error> { match node.node { crate::ExprKind::Tuple { elts, ctx } => { - let elts = elts - .into_iter() - .map(|x| self.fold_expr(x)) - .collect::, _>>()?; - let expr = if elts - .iter() - .all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) - { - let tuple = elts - .into_iter() - .map(|e| match e.node { - crate::ExprKind::Constant { value, .. } => value, - _ => unreachable!(), - }) - .collect(); - crate::ExprKind::Constant { - value: Constant::Tuple(tuple), - kind: None, - } - } else { - crate::ExprKind::Tuple { elts, ctx } - }; - Ok(crate::Expr { - node: expr, - custom: node.custom, - location: node.location, - }) + let elts = + elts.into_iter().map(|x| self.fold_expr(x)).collect::, _>>()?; + let expr = + if elts.iter().all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) { + let tuple = elts + .into_iter() + .map(|e| match e.node { + crate::ExprKind::Constant { value, .. } => value, + _ => unreachable!(), + }) + .collect(); + crate::ExprKind::Constant { value: Constant::Tuple(tuple), kind: None } + } else { + crate::ExprKind::Tuple { elts, ctx } + }; + Ok(crate::Expr { node: expr, custom: node.custom, location: node.location }) } _ => crate::fold::fold_expr(self, node), } @@ -138,18 +127,12 @@ mod tests { Located { location, custom, - node: ExprKind::Constant { - value: 1.into(), - kind: None, - }, + node: ExprKind::Constant { value: 1.into(), kind: None }, }, Located { location, custom, - node: ExprKind::Constant { - value: 2.into(), - kind: None, - }, + node: ExprKind::Constant { value: 2.into(), kind: None }, }, Located { location, @@ -160,26 +143,17 @@ mod tests { Located { location, custom, - node: ExprKind::Constant { - value: 3.into(), - kind: None, - }, + node: ExprKind::Constant { value: 3.into(), kind: None }, }, Located { location, custom, - node: ExprKind::Constant { - value: 4.into(), - kind: None, - }, + node: ExprKind::Constant { value: 4.into(), kind: None }, }, Located { location, custom, - node: ExprKind::Constant { - value: 5.into(), - kind: None, - }, + node: ExprKind::Constant { value: 5.into(), kind: None }, }, ], }, @@ -187,9 +161,7 @@ mod tests { ], }, }; - let new_ast = ConstantOptimizer::new() - .fold_expr(ast) - .unwrap_or_else(|e| match e {}); + let new_ast = ConstantOptimizer::new().fold_expr(ast).unwrap_or_else(|e| match e {}); assert_eq!( new_ast, Located { @@ -199,11 +171,7 @@ mod tests { value: Constant::Tuple(vec![ 1.into(), 2.into(), - Constant::Tuple(vec![ - 3.into(), - 4.into(), - 5.into(), - ]) + Constant::Tuple(vec![3.into(), 4.into(), 5.into(),]) ]), kind: None }, diff --git a/nac3ast/src/fold_helpers.rs b/nac3ast/src/fold_helpers.rs index 5ff83d0b..b7e55121 100644 --- a/nac3ast/src/fold_helpers.rs +++ b/nac3ast/src/fold_helpers.rs @@ -64,11 +64,4 @@ macro_rules! simple_fold { }; } -simple_fold!( - usize, - String, - bool, - StrRef, - constant::Constant, - constant::ConversionFlag -); +simple_fold!(usize, String, bool, StrRef, constant::Constant, constant::ConversionFlag); diff --git a/nac3ast/src/impls.rs b/nac3ast/src/impls.rs index 666acd1f..82398dac 100644 --- a/nac3ast/src/impls.rs +++ b/nac3ast/src/impls.rs @@ -34,10 +34,7 @@ impl ExprKind { ExprKind::Starred { .. } => "starred", ExprKind::Slice { .. } => "slice", ExprKind::JoinedStr { values } => { - if values - .iter() - .any(|e| matches!(e.node, ExprKind::JoinedStr { .. })) - { + if values.iter().any(|e| matches!(e.node, ExprKind::JoinedStr { .. })) { "f-string expression" } else { "literal" diff --git a/nac3ast/src/lib.rs b/nac3ast/src/lib.rs index 2a1786d6..46259896 100644 --- a/nac3ast/src/lib.rs +++ b/nac3ast/src/lib.rs @@ -9,6 +9,6 @@ mod impls; mod location; pub use ast_gen::*; -pub use location::{Location, FileName}; +pub use location::{FileName, Location}; pub type Suite = Vec>; diff --git a/nac3ast/src/location.rs b/nac3ast/src/location.rs index 976dbf05..42664b3d 100644 --- a/nac3ast/src/location.rs +++ b/nac3ast/src/location.rs @@ -1,6 +1,6 @@ //! Datatypes to support source location information. -use std::cmp::Ordering; use crate::ast_gen::StrRef; +use std::cmp::Ordering; use std::fmt; #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -22,7 +22,7 @@ impl From for FileName { pub struct Location { pub row: usize, pub column: usize, - pub file: FileName + pub file: FileName, } impl fmt::Display for Location { @@ -35,12 +35,12 @@ impl Ord for Location { fn cmp(&self, other: &Self) -> Ordering { let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string()); if file_cmp != Ordering::Equal { - return file_cmp + return file_cmp; } let row_cmp = self.row.cmp(&other.row); if row_cmp != Ordering::Equal { - return row_cmp + return row_cmp; } self.column.cmp(&other.column) @@ -76,11 +76,7 @@ impl Location { ) } } - Visualize { - loc: *self, - line, - desc, - } + Visualize { loc: *self, line, desc } } } diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index d8bdf8da..0fdd50cc 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,12 +1,12 @@ -use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::types::BasicTypeEnum; use inkwell::values::BasicValueEnum; +use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; -use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy}; use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; +use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::toplevel::helper::PRIMITIVE_DEF_IDS; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::Type; @@ -14,11 +14,7 @@ use crate::typecheck::typedef::Type; /// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// /// The generated message will contain the function name and the name of the unsupported type. -fn unsupported_type( - ctx: &CodeGenContext<'_, '_>, - fn_name: &str, - tys: &[Type], -) -> ! { +fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) -> ! { unreachable!( "{fn_name}() not supported for '{}'", tys.iter().map(|ty| format!("'{}'", ctx.unifier.stringify(*ty))).join(", "), @@ -40,46 +36,36 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); - ctx.builder - .build_int_z_extend(n, llvm_i32, "zext") - .map(Into::into) - .unwrap() + ctx.builder.build_int_z_extend(n, llvm_i32, "zext").map(Into::into).unwrap() } BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 32 => { - debug_assert!([ - ctx.primitives.int32, - ctx.primitives.uint32, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); + debug_assert!([ctx.primitives.int32, ctx.primitives.uint32,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); n.into() } BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { - debug_assert!([ - ctx.primitives.int64, - ctx.primitives.uint64, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); + debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - ctx.builder - .build_int_truncate(n, llvm_i32, "trunc") - .map(Into::into) - .unwrap() + ctx.builder.build_int_truncate(n, llvm_i32, "trunc").map(Into::into).unwrap() } BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - let to_int64 = ctx.builder - .build_float_to_signed_int(n, ctx.ctx.i64_type(), "") - .unwrap(); - ctx.builder - .build_int_truncate(to_int64, llvm_i32, "conv") - .map(Into::into) - .unwrap() + let to_int64 = + ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); + ctx.builder.build_int_truncate(to_int64, llvm_i32, "conv").map(Into::into).unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -88,15 +74,13 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.int32, None, NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - call_int32(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, "int32", &[n_ty]) + _ => unsupported_type(ctx, "int32", &[n_ty]), }) } @@ -113,30 +97,21 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); + debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { - ctx.builder - .build_int_s_extend(n, llvm_i64, "sext") - .map(Into::into) - .unwrap() + ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() } else { - ctx.builder - .build_int_z_extend(n, llvm_i64, "zext") - .map(Into::into) - .unwrap() + ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() } } BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { - debug_assert!([ - ctx.primitives.int64, - ctx.primitives.uint64, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); + debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); n.into() } @@ -150,7 +125,9 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -159,15 +136,13 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.int64, None, NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - call_int64(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, "int64", &[n_ty]) + _ => unsupported_type(ctx, "int64", &[n_ty]), }) } @@ -186,17 +161,13 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); - ctx.builder - .build_int_z_extend(n, llvm_i32, "zext") - .map(Into::into) - .unwrap() + ctx.builder.build_int_z_extend(n, llvm_i32, "zext").map(Into::into).unwrap() } BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 32 => { - debug_assert!([ - ctx.primitives.int32, - ctx.primitives.uint32, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); + debug_assert!([ctx.primitives.int32, ctx.primitives.uint32,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); n.into() } @@ -207,25 +178,20 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( || ctx.unifier.unioned(n_ty, ctx.primitives.uint64) ); - ctx.builder - .build_int_truncate(n, llvm_i32, "trunc") - .map(Into::into) - .unwrap() + ctx.builder.build_int_truncate(n, llvm_i32, "trunc").map(Into::into).unwrap() } BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - let n_gez = ctx.builder + let n_gez = ctx + .builder .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") .unwrap(); - let to_int32 = ctx.builder - .build_float_to_signed_int(n, llvm_i32, "") - .unwrap(); - let to_uint64 = ctx.builder - .build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "") - .unwrap(); + let to_int32 = ctx.builder.build_float_to_signed_int(n, llvm_i32, "").unwrap(); + let to_uint64 = + ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); ctx.builder .build_select( @@ -237,7 +203,9 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -246,15 +214,13 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.uint32, None, NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - call_uint32(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, "uint32", &[n_ty]) + _ => unsupported_type(ctx, "uint32", &[n_ty]), }) } @@ -271,30 +237,21 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); + debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { - ctx.builder - .build_int_s_extend(n, llvm_i64, "sext") - .map(Into::into) - .unwrap() + ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() } else { - ctx.builder - .build_int_z_extend(n, llvm_i64, "zext") - .map(Into::into) - .unwrap() + ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() } } BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { - debug_assert!([ - ctx.primitives.int64, - ctx.primitives.uint64, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); + debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); n.into() } @@ -302,23 +259,20 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - let val_gez = ctx.builder + let val_gez = ctx + .builder .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") .unwrap(); - let to_int64 = ctx.builder - .build_float_to_signed_int(n, llvm_i64, "") - .unwrap(); - let to_uint64 = ctx.builder - .build_float_to_unsigned_int(n, llvm_i64, "") - .unwrap(); + let to_int64 = ctx.builder.build_float_to_signed_int(n, llvm_i64, "").unwrap(); + let to_uint64 = ctx.builder.build_float_to_unsigned_int(n, llvm_i64, "").unwrap(); - ctx.builder - .build_select(val_gez, to_uint64, to_int64, "conv") - .unwrap() + ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -327,15 +281,13 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.uint64, None, NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - call_uint64(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, "uint64", &[n_ty]) + _ => unsupported_type(ctx, "uint64", &[n_ty]), }) } @@ -358,13 +310,14 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.uint32, ctx.primitives.int64, ctx.primitives.uint64, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); + ] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - if [ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.int64, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)) { + if [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty)) + { ctx.builder .build_signed_int_to_float(n, llvm_f64, "sitofp") .map(Into::into) @@ -383,7 +336,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( n.into() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -392,15 +347,13 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.float, None, NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - call_float(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, "float", &[n_ty]) + _ => unsupported_type(ctx, "float", &[n_ty]), }) } @@ -429,7 +382,9 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -438,15 +393,13 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( ret_elem_ty, None, NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - call_round(generator, ctx, (elem_ty, val), ret_elem_ty) - }, + |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), }) } @@ -469,7 +422,9 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_roundeven(ctx, n, None).into() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -478,15 +433,13 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.float, None, NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - call_numpy_round(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), }) } @@ -515,7 +468,9 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.uint32, ctx.primitives.int64, ctx.primitives.uint64, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); + ] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); ctx.builder .build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME) @@ -532,7 +487,9 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -542,11 +499,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( None, NDArrayValue::from_ptr_val(n, llvm_usize, None), |generator, ctx, val| { - let elem = call_bool( - generator, - ctx, - (elem_ty, val), - )?; + let elem = call_bool(generator, ctx, (elem_ty, val))?; Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) }, @@ -555,7 +508,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), }) } @@ -588,7 +541,9 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( } } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -597,15 +552,13 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( ret_elem_ty, None, NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - call_floor(generator, ctx, (elem_ty, val), ret_elem_ty) - }, + |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), }) } @@ -638,7 +591,9 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( } } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -647,15 +602,13 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( ret_elem_ty, None, NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - call_floor(generator, ctx, (elem_ty, val), ret_elem_ty) - }, + |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), }) } @@ -684,12 +637,14 @@ pub fn call_min<'ctx>( ctx.primitives.uint32, ctx.primitives.int64, ctx.primitives.uint64, - ].iter().any(|ty| ctx.unifier.unioned(common_ty, *ty))); + ] + .iter() + .any(|ty| ctx.unifier.unioned(common_ty, *ty))); - if [ - ctx.primitives.int32, - ctx.primitives.int64, - ].iter().any(|ty| ctx.unifier.unioned(common_ty, *ty)) { + if [ctx.primitives.int32, ctx.primitives.int64] + .iter() + .any(|ty| ctx.unifier.unioned(common_ty, *ty)) + { llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into() } else { llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into() @@ -702,7 +657,7 @@ pub fn call_min<'ctx>( llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into() } - _ => unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]) + _ => unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]), } } @@ -727,25 +682,25 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.int64, ctx.primitives.uint64, ctx.primitives.float, - ].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty))); + ] + .iter() + .any(|ty| ctx.unifier.unioned(a_ty, *ty))); a } - BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx.builder - .build_int_compare( - IntPredicate::NE, - n_sz, - n_sz.get_type().const_zero(), - "", - ) + let n_sz_eqz = ctx + .builder + .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") .unwrap(); ctx.make_assert( @@ -760,8 +715,8 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; unsafe { - let identity = n.data() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); + let identity = + n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); ctx.builder.build_store(accumulator_addr, identity).unwrap(); } @@ -771,9 +726,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( llvm_usize.const_int(1, false), (n_sz, false), |generator, ctx, idx| { - let elem = unsafe { - n.data().get_unchecked(ctx, generator, &idx, None) - }; + let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)); @@ -788,7 +741,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( accumulator } - _ => unsupported_type(ctx, FN_NAME, &[a_ty]) + _ => unsupported_type(ctx, FN_NAME, &[a_ty]), }) } @@ -804,11 +757,7 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { - Some(x1_ty) - } else { - None - }; + let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None }; Ok(match (x1, x2) { (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { @@ -819,7 +768,9 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.int64, ctx.primitives.uint64, ctx.primitives.float, - ].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); + ] + .iter() + .any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } @@ -830,11 +781,15 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { - let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + (x1, x2) + if [&x1_ty, &x2_ty].into_iter().any(|ty| { + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) => + { + let is_ndarray1 = + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -847,18 +802,12 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { unreachable!() }; + } else { + unreachable!() + }; - let x1_scalar_ty = if is_ndarray1 { - dtype - } else { - x1_ty - }; - let x2_scalar_ty = if is_ndarray2 { - dtype - } else { - x2_ty - }; + let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; + let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; numpy::ndarray_elementwise_binop_impl( generator, @@ -870,10 +819,12 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, (lhs, rhs)| { call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, - )?.as_base_value().into() + )? + .as_base_value() + .into() } - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) } @@ -902,12 +853,14 @@ pub fn call_max<'ctx>( ctx.primitives.uint32, ctx.primitives.int64, ctx.primitives.uint64, - ].iter().any(|ty| ctx.unifier.unioned(common_ty, *ty))); + ] + .iter() + .any(|ty| ctx.unifier.unioned(common_ty, *ty))); - if [ - ctx.primitives.int32, - ctx.primitives.int64, - ].iter().any(|ty| ctx.unifier.unioned(common_ty, *ty)) { + if [ctx.primitives.int32, ctx.primitives.int64] + .iter() + .any(|ty| ctx.unifier.unioned(common_ty, *ty)) + { llvm_intrinsics::call_int_smax(ctx, m, n, Some(FN_NAME)).into() } else { llvm_intrinsics::call_int_umax(ctx, m, n, Some(FN_NAME)).into() @@ -920,7 +873,7 @@ pub fn call_max<'ctx>( llvm_intrinsics::call_float_maxnum(ctx, m, n, Some(FN_NAME)).into() } - _ => unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]) + _ => unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]), } } @@ -945,25 +898,25 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.int64, ctx.primitives.uint64, ctx.primitives.float, - ].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty))); + ] + .iter() + .any(|ty| ctx.unifier.unioned(a_ty, *ty))); a } - BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx.builder - .build_int_compare( - IntPredicate::NE, - n_sz, - n_sz.get_type().const_zero(), - "", - ) + let n_sz_eqz = ctx + .builder + .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") .unwrap(); ctx.make_assert( @@ -978,8 +931,8 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; unsafe { - let identity = n.data() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); + let identity = + n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); ctx.builder.build_store(accumulator_addr, identity).unwrap(); } @@ -989,9 +942,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( llvm_usize.const_int(1, false), (n_sz, false), |generator, ctx, idx| { - let elem = unsafe { - n.data().get_unchecked(ctx, generator, &idx, None) - }; + let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)); @@ -1006,7 +957,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( accumulator } - _ => unsupported_type(ctx, FN_NAME, &[a_ty]) + _ => unsupported_type(ctx, FN_NAME, &[a_ty]), }) } @@ -1022,11 +973,7 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { - Some(x1_ty) - } else { - None - }; + let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None }; Ok(match (x1, x2) { (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { @@ -1037,7 +984,9 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.int64, ctx.primitives.uint64, ctx.primitives.float, - ].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); + ] + .iter() + .any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } @@ -1048,11 +997,15 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { - let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + (x1, x2) + if [&x1_ty, &x2_ty].into_iter().any(|ty| { + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) => + { + let is_ndarray1 = + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -1065,18 +1018,12 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { unreachable!() }; + } else { + unreachable!() + }; - let x1_scalar_ty = if is_ndarray1 { - dtype - } else { - x1_ty - }; - let x2_scalar_ty = if is_ndarray2 { - dtype - } else { - x2_ty - }; + let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; + let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; numpy::ndarray_elementwise_binop_impl( generator, @@ -1088,10 +1035,12 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, (lhs, rhs)| { call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, - )?.as_base_value().into() + )? + .as_base_value() + .into() } - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) } @@ -1116,18 +1065,15 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( ctx.primitives.uint32, ctx.primitives.int64, ctx.primitives.uint64, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); + ] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - if [ - ctx.primitives.int32, - ctx.primitives.int64, - ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)) { - llvm_intrinsics::call_int_abs( - ctx, - n, - llvm_i1.const_zero(), - Some(FN_NAME), - ).into() + if [ctx.primitives.int32, ctx.primitives.int64] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty)) + { + llvm_intrinsics::call_int_abs(ctx, n, llvm_i1.const_zero(), Some(FN_NAME)).into() } else { n.into() } @@ -1139,7 +1085,9 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1148,15 +1096,13 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - call_abs(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_abs(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), }) } @@ -1179,7 +1125,9 @@ pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>( irrt::call_isnan(generator, ctx, x).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1198,7 +1146,7 @@ pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>( ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1221,7 +1169,9 @@ pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>( irrt::call_isinf(generator, ctx, x).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1240,7 +1190,7 @@ pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>( ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1263,7 +1213,9 @@ pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_sin(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1272,15 +1224,13 @@ pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_sin(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_sin(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1303,7 +1253,9 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_cos(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1312,15 +1264,13 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_cos(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_cos(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1343,7 +1293,9 @@ pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_exp(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1352,15 +1304,13 @@ pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_exp(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_exp(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1383,7 +1333,9 @@ pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_exp2(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1392,15 +1344,13 @@ pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_exp2(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_exp2(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1423,7 +1373,9 @@ pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_log(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1432,15 +1384,13 @@ pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_log(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_log(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1463,7 +1413,9 @@ pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_log10(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1472,15 +1424,13 @@ pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_log10(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_log10(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1503,7 +1453,9 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_log2(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1512,15 +1464,13 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_log2(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_log2(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1543,7 +1493,9 @@ pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_fabs(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1552,15 +1504,13 @@ pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_fabs(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_fabs(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1583,7 +1533,9 @@ pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_sqrt(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1592,15 +1544,13 @@ pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_sqrt(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_sqrt(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1623,7 +1573,9 @@ pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_roundeven(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1632,15 +1584,13 @@ pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_rint(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_rint(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1663,7 +1613,9 @@ pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_tan(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1672,15 +1624,13 @@ pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_tan(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_tan(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1703,7 +1653,9 @@ pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_asin(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1712,15 +1664,13 @@ pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_arcsin(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_arcsin(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1743,7 +1693,9 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_acos(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1752,15 +1704,13 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_arccos(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_arccos(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1783,7 +1733,9 @@ pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_atan(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1792,15 +1744,13 @@ pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_arctan(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_arctan(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1823,7 +1773,9 @@ pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_sinh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1832,15 +1784,13 @@ pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_sinh(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_sinh(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1863,7 +1813,9 @@ pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_cosh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1872,15 +1824,13 @@ pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_cosh(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_cosh(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1903,7 +1853,9 @@ pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_tanh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1912,15 +1864,13 @@ pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_tanh(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_tanh(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1943,7 +1893,9 @@ pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_asinh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1952,15 +1904,13 @@ pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_arcsinh(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_arcsinh(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -1983,7 +1933,9 @@ pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_acosh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1992,15 +1944,13 @@ pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_arccosh(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_arccosh(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -2023,7 +1973,9 @@ pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_atanh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2032,15 +1984,13 @@ pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_arctanh(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_arctanh(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -2063,7 +2013,9 @@ pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_expm1(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2072,15 +2024,13 @@ pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_expm1(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_expm1(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -2103,7 +2053,9 @@ pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_cbrt(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2112,15 +2064,13 @@ pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_numpy_cbrt(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_numpy_cbrt(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -2143,7 +2093,9 @@ pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_erf(ctx, z, None).into() } - BasicValueEnum::PointerValue(z) if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(z) + if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2152,15 +2104,13 @@ pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(z, llvm_usize, None), - |generator, ctx, val| { - call_scipy_special_erf(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_scipy_special_erf(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[z_ty]) + _ => unsupported_type(ctx, FN_NAME, &[z_ty]), }) } @@ -2183,7 +2133,9 @@ pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_erfc(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2192,15 +2144,13 @@ pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_scipy_special_erfc(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_scipy_special_erfc(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -2223,7 +2173,9 @@ pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>( irrt::call_gamma(ctx, z).into() } - BasicValueEnum::PointerValue(z) if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(z) + if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2232,15 +2184,13 @@ pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(z, llvm_usize, None), - |generator, ctx, val| { - call_scipy_special_gamma(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_scipy_special_gamma(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[z_ty]) + _ => unsupported_type(ctx, FN_NAME, &[z_ty]), }) } @@ -2263,7 +2213,9 @@ pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>( irrt::call_gammaln(ctx, x).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2272,15 +2224,13 @@ pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_scipy_special_gammaln(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_scipy_special_gammaln(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -2303,7 +2253,9 @@ pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>( irrt::call_j0(ctx, x).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2312,15 +2264,13 @@ pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_scipy_special_j0(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_scipy_special_j0(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -2343,7 +2293,9 @@ pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_j1(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + BasicValueEnum::PointerValue(x) + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2352,15 +2304,13 @@ pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>( elem_ty, None, NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, val| { - call_scipy_special_j1(generator, ctx, (elem_ty, val)) - }, + |generator, ctx, val| call_scipy_special_j1(generator, ctx, (elem_ty, val)), )?; ndarray.as_base_value().into() } - _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x_ty]), }) } @@ -2384,11 +2334,15 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_atan2(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { - let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + (x1, x2) + if [&x1_ty, &x2_ty].into_iter().any(|ty| { + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) => + { + let is_ndarray1 = + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2401,18 +2355,12 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { unreachable!() }; + } else { + unreachable!() + }; - let x1_scalar_ty = if is_ndarray1 { - dtype - } else { - x1_ty - }; - let x2_scalar_ty = if is_ndarray2 { - dtype - } else { - x2_ty - }; + let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; + let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; numpy::ndarray_elementwise_binop_impl( generator, @@ -2424,10 +2372,12 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, (lhs, rhs)| { call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, - )?.as_base_value().into() + )? + .as_base_value() + .into() } - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) } @@ -2451,11 +2401,15 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { - let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + (x1, x2) + if [&x1_ty, &x2_ty].into_iter().any(|ty| { + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) => + { + let is_ndarray1 = + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2468,18 +2422,12 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { unreachable!() }; + } else { + unreachable!() + }; - let x1_scalar_ty = if is_ndarray1 { - dtype - } else { - x1_ty - }; - let x2_scalar_ty = if is_ndarray2 { - dtype - } else { - x2_ty - }; + let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; + let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; numpy::ndarray_elementwise_binop_impl( generator, @@ -2491,10 +2439,12 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, (lhs, rhs)| { call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, - )?.as_base_value().into() + )? + .as_base_value() + .into() } - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) } @@ -2518,11 +2468,15 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { - let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + (x1, x2) + if [&x1_ty, &x2_ty].into_iter().any(|ty| { + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) => + { + let is_ndarray1 = + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2535,18 +2489,12 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { unreachable!() }; + } else { + unreachable!() + }; - let x1_scalar_ty = if is_ndarray1 { - dtype - } else { - x1_ty - }; - let x2_scalar_ty = if is_ndarray2 { - dtype - } else { - x2_ty - }; + let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; + let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; numpy::ndarray_elementwise_binop_impl( generator, @@ -2558,10 +2506,12 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, (lhs, rhs)| { call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, - )?.as_base_value().into() + )? + .as_base_value() + .into() } - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) } @@ -2585,11 +2535,15 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { - let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + (x1, x2) + if [&x1_ty, &x2_ty].into_iter().any(|ty| { + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) => + { + let is_ndarray1 = + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2602,18 +2556,12 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { unreachable!() }; + } else { + unreachable!() + }; - let x1_scalar_ty = if is_ndarray1 { - dtype - } else { - x1_ty - }; - let x2_scalar_ty = if is_ndarray2 { - dtype - } else { - x2_ty - }; + let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; + let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; numpy::ndarray_elementwise_binop_impl( generator, @@ -2625,10 +2573,12 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, (lhs, rhs)| { call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, - )?.as_base_value().into() + )? + .as_base_value() + .into() } - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) } @@ -2652,24 +2602,22 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_ldexp(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { - let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + (x1, x2) + if [&x1_ty, &x2_ty].into_iter().any(|ty| { + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) => + { + let is_ndarray1 = + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let dtype = if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else { - x1_ty - }; + let dtype = + if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; let x1_scalar_ty = dtype; - let x2_scalar_ty = if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - x2_ty - }; + let x2_scalar_ty = + if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; numpy::ndarray_elementwise_binop_impl( generator, @@ -2681,10 +2629,12 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, (lhs, rhs)| { call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, - )?.as_base_value().into() + )? + .as_base_value() + .into() } - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) } @@ -2708,11 +2658,15 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_hypot(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { - let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + (x1, x2) + if [&x1_ty, &x2_ty].into_iter().any(|ty| { + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) => + { + let is_ndarray1 = + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2725,18 +2679,12 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { unreachable!() }; + } else { + unreachable!() + }; - let x1_scalar_ty = if is_ndarray1 { - dtype - } else { - x1_ty - }; - let x2_scalar_ty = if is_ndarray2 { - dtype - } else { - x2_ty - }; + let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; + let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; numpy::ndarray_elementwise_binop_impl( generator, @@ -2748,10 +2696,12 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, (lhs, rhs)| { call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, - )?.as_base_value().into() + )? + .as_base_value() + .into() } - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) } @@ -2775,11 +2725,15 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_nextafter(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { - let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + (x1, x2) + if [&x1_ty, &x2_ty].into_iter().any(|ty| { + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) => + { + let is_ndarray1 = + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2792,18 +2746,12 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { unreachable!() }; + } else { + unreachable!() + }; - let x1_scalar_ty = if is_ndarray1 { - dtype - } else { - x1_ty - }; - let x2_scalar_ty = if is_ndarray2 { - dtype - } else { - x2_ty - }; + let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; + let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; numpy::ndarray_elementwise_binop_impl( generator, @@ -2815,9 +2763,11 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, (lhs, rhs)| { call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, - )?.as_base_value().into() + )? + .as_base_value() + .into() } - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) -} \ No newline at end of file +} diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 7eefd74d..36f7ee67 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1,29 +1,28 @@ -use inkwell::{ - AddressSpace, IntPredicate, - types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, - values::{BasicValueEnum, IntValue, PointerValue}, +use crate::codegen::{ + irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, + llvm_intrinsics::call_int_umin, + stmt::gen_for_callback_incrementing, + CodeGenContext, CodeGenerator, }; use inkwell::context::Context; use inkwell::types::{ArrayType, BasicType, StructType}; use inkwell::values::{ArrayValue, BasicValue, StructValue}; -use crate::codegen::{ - CodeGenContext, - CodeGenerator, - irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, - llvm_intrinsics::call_int_umin, - stmt::gen_for_callback_incrementing, +use inkwell::{ + types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, + values::{BasicValueEnum, IntValue, PointerValue}, + AddressSpace, IntPredicate, }; /// A LLVM type that is used to represent a non-primitive type in NAC3. pub trait ProxyType<'ctx>: Into { - /// The LLVM type of which values of this type possess. This is usually a + /// The LLVM type of which values of this type possess. This is usually a /// [LLVM pointer type][PointerType]. type Base: BasicType<'ctx>; - /// The underlying LLVM type used to represent values. This is usually the element type of + /// The underlying LLVM type used to represent values. This is usually the element type of /// [`Base`] if it is a pointer, otherwise this is the same type as `Base`. type Underlying: BasicType<'ctx>; - + /// The type of values represented by this type. type Value: ProxyValue<'ctx>; @@ -64,7 +63,7 @@ pub trait ProxyType<'ctx>: Into { /// A LLVM type that is used to represent a non-primitive value in NAC3. pub trait ProxyValue<'ctx>: Into { - /// The type of LLVM values represented by this instance. This is usually the + /// The type of LLVM values represented by this instance. This is usually the /// [LLVM pointer type][PointerValue]. type Base: BasicValue<'ctx>; @@ -81,7 +80,7 @@ pub trait ProxyValue<'ctx>: Into { /// Returns the [base value][Self::Base] of this proxy. fn as_base_value(&self) -> Self::Base; - /// Loads this value into its [underlying representation][Self::Underlying]. Usually involves a + /// Loads this value into its [underlying representation][Self::Underlying]. Usually involves a /// `getelementptr` if [`Self::Base`] is a [pointer value][PointerValue]. fn as_underlying_value( &self, @@ -152,7 +151,9 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> { } /// An array-like value that can have its array elements accessed as a [`BasicValueEnum`]. -pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndexer<'ctx, Index> { +pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: + ArrayLikeIndexer<'ctx, Index> +{ /// # Safety /// /// This function should be called with a valid index. @@ -181,7 +182,9 @@ pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndex } /// An array-like value that can have its array elements mutated as a [`BasicValueEnum`]. -pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndexer<'ctx, Index> { +pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: + ArrayLikeIndexer<'ctx, Index> +{ /// # Safety /// /// This function should be called with a valid index. @@ -210,9 +213,15 @@ pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndexe } /// An array-like value that can have its array elements accessed as an arbitrary type `T`. -pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayLikeAccessor<'ctx, Index> { +pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: + UntypedArrayLikeAccessor<'ctx, Index> +{ /// Casts an element from [`BasicValueEnum`] into `T`. - fn downcast_to_type(&self, ctx: &mut CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>) -> T; + fn downcast_to_type( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) -> T; /// # Safety /// @@ -242,9 +251,15 @@ pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayL } /// An array-like value that can have its array elements mutated as an arbitrary type `T`. -pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayLikeMutator<'ctx, Index> { +pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: + UntypedArrayLikeMutator<'ctx, Index> +{ /// Casts an element from T into [`BasicValueEnum`]. - fn upcast_from_type(&self, ctx: &mut CodeGenContext<'ctx, '_>, value: T) -> BasicValueEnum<'ctx>; + fn upcast_from_type( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + value: T, + ) -> BasicValueEnum<'ctx>; /// # Safety /// @@ -274,7 +289,8 @@ pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayLi } /// Type alias for a function that casts a [`BasicValueEnum`] into a `T`. -type ValueDowncastFn<'ctx, T> = Box, BasicValueEnum<'ctx>) -> T>; +type ValueDowncastFn<'ctx, T> = + Box, BasicValueEnum<'ctx>) -> T>; /// Type alias for a function that casts a `T` into a [`BasicValueEnum`]. type ValueUpcastFn<'ctx, T> = Box, T) -> BasicValueEnum<'ctx>>; @@ -286,7 +302,9 @@ pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArrayS } impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted> - where Adapted: ArrayLikeValue<'ctx> { +where + Adapted: ArrayLikeValue<'ctx>, +{ /// Creates a [`TypedArrayLikeAdapter`]. /// /// * `adapted` - The value to be adapted. @@ -302,7 +320,9 @@ impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted> } impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted> - where Adapted: ArrayLikeValue<'ctx> { +where + Adapted: ArrayLikeValue<'ctx>, +{ fn element_type( &self, ctx: &CodeGenContext<'ctx, '_>, @@ -328,8 +348,11 @@ impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, A } } -impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> for TypedArrayLikeAdapter<'ctx, T, Adapted> - where Adapted: ArrayLikeIndexer<'ctx, Index> { +impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: ArrayLikeIndexer<'ctx, Index>, +{ unsafe fn ptr_offset_unchecked( &self, ctx: &mut CodeGenContext<'ctx, '_>, @@ -351,21 +374,43 @@ impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> for TypedArrayLikeAd } } -impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> for TypedArrayLikeAdapter<'ctx, T, Adapted> - where Adapted: UntypedArrayLikeAccessor<'ctx, Index> {} -impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> for TypedArrayLikeAdapter<'ctx, T, Adapted> - where Adapted: UntypedArrayLikeMutator<'ctx, Index> {} +impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: UntypedArrayLikeAccessor<'ctx, Index>, +{ +} +impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: UntypedArrayLikeMutator<'ctx, Index>, +{ +} -impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index> for TypedArrayLikeAdapter<'ctx, T, Adapted> - where Adapted: UntypedArrayLikeAccessor<'ctx, Index> { - fn downcast_to_type(&self, ctx: &mut CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>) -> T { +impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index> + for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: UntypedArrayLikeAccessor<'ctx, Index>, +{ + fn downcast_to_type( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) -> T { (self.downcast_fn)(ctx, value) } } -impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index> for TypedArrayLikeAdapter<'ctx, T, Adapted> - where Adapted: UntypedArrayLikeMutator<'ctx, Index> { - fn upcast_from_type(&self, ctx: &mut CodeGenContext<'ctx, '_>, value: T) -> BasicValueEnum<'ctx> { +impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index> + for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: UntypedArrayLikeMutator<'ctx, Index>, +{ + fn upcast_from_type( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + value: T, + ) -> BasicValueEnum<'ctx> { (self.upcast_fn)(ctx, value) } } @@ -427,15 +472,11 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let var_name = name - .map(|v| format!("{v}.addr")) - .unwrap_or_default(); + let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); - ctx.builder.build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[*idx], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) + .unwrap() } fn ptr_offset( @@ -458,9 +499,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { ctx.current_loc, ); - unsafe { - self.ptr_offset_unchecked(ctx, generator, idx, name) - } + unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } } } @@ -476,31 +515,33 @@ pub struct ListType<'ctx> { impl<'ctx> ListType<'ctx> { /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not. - pub fn is_type( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { + pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { let llvm_list_ty = llvm_ty.get_element_type(); let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { - return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")) + return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")); }; if llvm_list_ty.count_fields() != 2 { - return Err(format!("Expected 2 fields in `list`, got {}", llvm_list_ty.count_fields())) + return Err(format!( + "Expected 2 fields in `list`, got {}", + llvm_list_ty.count_fields() + )); } let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap(); let Ok(_) = PointerType::try_from(list_size_ty) else { - return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}")) + return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}")); }; let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap(); let Ok(list_data_ty) = IntType::try_from(list_data_ty) else { - return Err(format!("Expected int type for `list.1`, got {list_data_ty}")) + return Err(format!("Expected int type for `list.1`, got {list_data_ty}")); }; if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!("Expected {}-bit int type for `list.1`, got {}-bit int", - llvm_usize.get_bit_width(), - list_data_ty.get_bit_width())) + return Err(format!( + "Expected {}-bit int type for `list.1`, got {}-bit int", + llvm_usize.get_bit_width(), + list_data_ty.get_bit_width() + )); } Ok(()) @@ -516,10 +557,7 @@ impl<'ctx> ListType<'ctx> { let llvm_usize = generator.get_size_type(ctx); let llvm_list = ctx .struct_type( - &[ - element_type.ptr_type(AddressSpace::default()).into(), - llvm_usize.into(), - ], + &[element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], false, ) .ptr_type(AddressSpace::default()); @@ -555,7 +593,7 @@ impl<'ctx> ListType<'ctx> { .get_field_type_at_index(0) .map(BasicTypeEnum::into_pointer_type) .map(PointerType::get_element_type) - .unwrap() + .unwrap() } } @@ -612,16 +650,17 @@ pub struct ListValue<'ctx> { impl<'ctx> ListValue<'ctx> { /// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an /// instance. - pub fn is_instance( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { + pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { ListType::is_type(value.get_type(), llvm_usize) } /// Creates an [`ListValue`] from a [`PointerValue`]. #[must_use] - pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self { + pub fn from_ptr_val( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); >::Type::from_type(ptr.get_type(), llvm_usize) @@ -635,11 +674,13 @@ impl<'ctx> ListValue<'ctx> { let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); unsafe { - ctx.builder.build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + var_name.as_str(), + ) + .unwrap() } } @@ -649,11 +690,13 @@ impl<'ctx> ListValue<'ctx> { let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default(); unsafe { - ctx.builder.build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + var_name.as_str(), + ) + .unwrap() } } @@ -704,7 +747,8 @@ impl<'ctx> ListValue<'ctx> { .or_else(|| self.name.map(|v| format!("{v}.size"))) .unwrap_or_default(); - ctx.builder.build_load(psize, var_name.as_str()) + ctx.builder + .build_load(psize, var_name.as_str()) .map(BasicValueEnum::into_int_value) .unwrap() } @@ -761,7 +805,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> { ) -> PointerValue<'ctx> { let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - ctx.builder.build_load(self.0.pptr_to_data(ctx), var_name.as_str()) + ctx.builder + .build_load(self.0.pptr_to_data(ctx), var_name.as_str()) .map(BasicValueEnum::into_pointer_value) .unwrap() } @@ -783,15 +828,11 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let var_name = name - .map(|v| format!("{v}.addr")) - .unwrap_or_default(); + let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); - ctx.builder.build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[*idx], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) + .unwrap() } fn ptr_offset( @@ -814,9 +855,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> { ctx.current_loc, ); - unsafe { - self.ptr_offset_unchecked(ctx, generator, idx, name) - } + unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } } } @@ -834,19 +873,26 @@ impl<'ctx> RangeType<'ctx> { pub fn is_type(llvm_ty: PointerType<'ctx>) -> Result<(), String> { let llvm_range_ty = llvm_ty.get_element_type(); let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else { - return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")) + return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")); }; if llvm_range_ty.len() != 3 { - return Err(format!("Expected 3 elements for `range` type, got {}", llvm_range_ty.len())) + return Err(format!( + "Expected 3 elements for `range` type, got {}", + llvm_range_ty.len() + )); } let llvm_range_elem_ty = llvm_range_ty.get_element_type(); let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else { - return Err(format!("Expected int type for `range` element type, got {llvm_range_elem_ty}")) + return Err(format!( + "Expected int type for `range` element type, got {llvm_range_elem_ty}" + )); }; if llvm_range_elem_ty.get_bit_width() != 32 { - return Err(format!("Expected 32-bit int type for `range` element type, got {}", - llvm_range_elem_ty.get_bit_width())) + return Err(format!( + "Expected 32-bit int type for `range` element type, got {}", + llvm_range_elem_ty.get_bit_width() + )); } Ok(()) @@ -872,11 +918,7 @@ impl<'ctx> RangeType<'ctx> { /// Returns the type of all fields of this `range` type. #[must_use] pub fn value_type(&self) -> IntType<'ctx> { - self.as_base_type() - .get_element_type() - .into_array_type() - .get_element_type() - .into_int_type() + self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() } } @@ -897,7 +939,11 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { ) } - fn create_value(&self, value: >::Base, name: Option<&'ctx str>) -> Self::Value { + fn create_value( + &self, + value: >::Base, + name: Option<&'ctx str>, + ) -> Self::Value { debug_assert_eq!(value.get_type(), self.as_base_type()); RangeValue { value, name } @@ -944,11 +990,13 @@ impl<'ctx> RangeValue<'ctx> { let var_name = self.name.map(|v| format!("{v}.start.addr")).unwrap_or_default(); unsafe { - ctx.builder.build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], + var_name.as_str(), + ) + .unwrap() } } @@ -957,11 +1005,13 @@ impl<'ctx> RangeValue<'ctx> { let var_name = self.name.map(|v| format!("{v}.end.addr")).unwrap_or_default(); unsafe { - ctx.builder.build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + var_name.as_str(), + ) + .unwrap() } } @@ -970,20 +1020,18 @@ impl<'ctx> RangeValue<'ctx> { let var_name = self.name.map(|v| format!("{v}.step.addr")).unwrap_or_default(); unsafe { - ctx.builder.build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], + var_name.as_str(), + ) + .unwrap() } } /// Stores the `start` value into this instance. - pub fn store_start( - &self, - ctx: &CodeGenContext<'ctx, '_>, - start: IntValue<'ctx>, - ) { + pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, start: IntValue<'ctx>) { debug_assert_eq!(start.get_type().get_bit_width(), 32); let pstart = self.ptr_to_start(ctx); @@ -998,17 +1046,14 @@ impl<'ctx> RangeValue<'ctx> { .or_else(|| self.name.map(|v| format!("{v}.start"))) .unwrap_or_default(); - ctx.builder.build_load(pstart, var_name.as_str()) + ctx.builder + .build_load(pstart, var_name.as_str()) .map(BasicValueEnum::into_int_value) .unwrap() } /// Stores the `end` value into this instance. - pub fn store_end( - &self, - ctx: &CodeGenContext<'ctx, '_>, - end: IntValue<'ctx>, - ) { + pub fn store_end(&self, ctx: &CodeGenContext<'ctx, '_>, end: IntValue<'ctx>) { debug_assert_eq!(end.get_type().get_bit_width(), 32); let pend = self.ptr_to_end(ctx); @@ -1023,17 +1068,11 @@ impl<'ctx> RangeValue<'ctx> { .or_else(|| self.name.map(|v| format!("{v}.end"))) .unwrap_or_default(); - ctx.builder.build_load(pend, var_name.as_str()) - .map(BasicValueEnum::into_int_value) - .unwrap() + ctx.builder.build_load(pend, var_name.as_str()).map(BasicValueEnum::into_int_value).unwrap() } /// Stores the `step` value into this instance. - pub fn store_step( - &self, - ctx: &CodeGenContext<'ctx, '_>, - step: IntValue<'ctx>, - ) { + pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, step: IntValue<'ctx>) { debug_assert_eq!(step.get_type().get_bit_width(), 32); let pstep = self.ptr_to_step(ctx); @@ -1048,7 +1087,8 @@ impl<'ctx> RangeValue<'ctx> { .or_else(|| self.name.map(|v| format!("{v}.step"))) .unwrap_or_default(); - ctx.builder.build_load(pstep, var_name.as_str()) + ctx.builder + .build_load(pstep, var_name.as_str()) .map(BasicValueEnum::into_int_value) .unwrap() } @@ -1094,45 +1134,51 @@ pub struct NDArrayType<'ctx> { impl<'ctx> NDArrayType<'ctx> { /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_type( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { + pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { let llvm_ndarray_ty = llvm_ty.get_element_type(); let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")) + return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); }; if llvm_ndarray_ty.count_fields() != 3 { - return Err(format!("Expected 3 fields in `NDArray`, got {}", llvm_ndarray_ty.count_fields())) + return Err(format!( + "Expected 3 fields in `NDArray`, got {}", + llvm_ndarray_ty.count_fields() + )); } let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else { - return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")) + return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")); }; if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!("Expected {}-bit int type for `ndarray.0`, got {}-bit int", - llvm_usize.get_bit_width(), - ndarray_ndims_ty.get_bit_width())) + return Err(format!( + "Expected {}-bit int type for `ndarray.0`, got {}-bit int", + llvm_usize.get_bit_width(), + ndarray_ndims_ty.get_bit_width() + )); } let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap(); let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else { - return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}")) + return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}")); }; let ndarray_dims = ndarray_pdims.get_element_type(); let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else { - return Err(format!("Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}")) + return Err(format!( + "Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}" + )); }; if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!("Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int", - llvm_usize.get_bit_width(), - ndarray_dims.get_bit_width())) + return Err(format!( + "Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int", + llvm_usize.get_bit_width(), + ndarray_dims.get_bit_width() + )); } let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap(); let Ok(_) = PointerType::try_from(ndarray_data_ty) else { - return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")) + return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")); }; Ok(()) @@ -1151,13 +1197,13 @@ impl<'ctx> NDArrayType<'ctx> { // // * num_dims: Number of dimensions in the array // * dims: Pointer to an array containing the size of each dimension - // * data: Pointer to an array containing the array data + // * data: Pointer to an array containing the array data let llvm_ndarray = ctx .struct_type( &[ llvm_usize.into(), llvm_usize.ptr_type(AddressSpace::default()).into(), - dtype.ptr_type(AddressSpace::default()).into(), + dtype.ptr_type(AddressSpace::default()).into(), ], false, ) @@ -1193,7 +1239,7 @@ impl<'ctx> NDArrayType<'ctx> { .into_struct_type() .get_field_type_at_index(2) .unwrap() - } + } } impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { @@ -1208,9 +1254,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { name: Option<&'ctx str>, ) -> Self::Value { self.create_value( - generator - .gen_var_alloc(ctx, self.as_underlying_type().into(), name) - .unwrap(), + generator.gen_var_alloc(ctx, self.as_underlying_type().into(), name).unwrap(), name, ) } @@ -1251,16 +1295,17 @@ pub struct NDArrayValue<'ctx> { impl<'ctx> NDArrayValue<'ctx> { /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an /// instance. - pub fn is_instance( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { + pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { NDArrayType::is_type(value.get_type(), llvm_usize) } /// Creates an [`NDArrayValue`] from a [`PointerValue`]. #[must_use] - pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self { + pub fn from_ptr_val( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); >::Type::from_type(ptr.get_type(), llvm_usize) @@ -1273,11 +1318,13 @@ impl<'ctx> NDArrayValue<'ctx> { let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); unsafe { - ctx.builder.build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + var_name.as_str(), + ) + .unwrap() } } @@ -1297,9 +1344,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Returns the number of dimensions of this `NDArray` as a value. pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_load(pndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap() + ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() } /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` @@ -1309,11 +1354,13 @@ impl<'ctx> NDArrayValue<'ctx> { let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); unsafe { - ctx.builder.build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + var_name.as_str(), + ) + .unwrap() } } @@ -1345,11 +1392,13 @@ impl<'ctx> NDArrayValue<'ctx> { let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); unsafe { - ctx.builder.build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + var_name.as_str(), + ) + .unwrap() } } @@ -1427,7 +1476,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { ) -> PointerValue<'ctx> { let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - ctx.builder.build_load(self.0.ptr_to_dims(ctx), var_name.as_str()) + ctx.builder + .build_load(self.0.ptr_to_dims(ctx), var_name.as_str()) .map(BasicValueEnum::into_pointer_value) .unwrap() } @@ -1449,15 +1499,11 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let var_name = name - .map(|v| format!("{v}.addr")) - .unwrap_or_default(); + let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); - ctx.builder.build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[*idx], - var_name.as_str(), - ).unwrap() + ctx.builder + .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) + .unwrap() } fn ptr_offset( @@ -1468,12 +1514,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> name: Option<&str>, ) -> PointerValue<'ctx> { let size = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare( - IntPredicate::ULT, - *idx, - size, - "" - ).unwrap(); + let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); ctx.make_assert( generator, in_range, @@ -1483,9 +1524,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> ctx.current_loc, ); - unsafe { - self.ptr_offset_unchecked(ctx, generator, idx, name) - } + unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } } } @@ -1532,7 +1571,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ) -> PointerValue<'ctx> { let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - ctx.builder.build_load(self.0.ptr_to_data(ctx), var_name.as_str()) + ctx.builder + .build_load(self.0.ptr_to_data(ctx), var_name.as_str()) .map(BasicValueEnum::into_pointer_value) .unwrap() } @@ -1554,11 +1594,9 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - ctx.builder.build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[*idx], - name.unwrap_or_default(), - ).unwrap() + ctx.builder + .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], name.unwrap_or_default()) + .unwrap() } fn ptr_offset( @@ -1569,12 +1607,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { name: Option<&str>, ) -> PointerValue<'ctx> { let data_sz = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare( - IntPredicate::ULT, - *idx, - data_sz, - "" - ).unwrap(); + let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap(); ctx.make_assert( generator, in_range, @@ -1584,16 +1617,16 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx.current_loc, ); - unsafe { - self.ptr_offset_unchecked(ctx, generator, idx, name) - } + unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } } } impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {} impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {} -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> for NDArrayDataProxy<'ctx, '_> { +impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> + for NDArrayDataProxy<'ctx, '_> +{ unsafe fn ptr_offset_unchecked( &self, ctx: &mut CodeGenContext<'ctx, '_>, @@ -1610,21 +1643,23 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { panic!("Expected list[int32] but got {indices_elem_ty}") }; - assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got list[int{}]", indices_elem_ty.get_bit_width()); - - let index = call_ndarray_flatten_index( - generator, - ctx, - *self.0, - indices, + assert_eq!( + indices_elem_ty.get_bit_width(), + 32, + "Expected list[int32] but got list[int{}]", + indices_elem_ty.get_bit_width() ); + let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); + unsafe { - ctx.builder.build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[index], - name.unwrap_or_default(), - ).unwrap() + ctx.builder + .build_in_bounds_gep( + self.base_ptr(ctx, generator), + &[index], + name.unwrap_or_default(), + ) + .unwrap() } } @@ -1638,12 +1673,10 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> let llvm_usize = generator.get_size_type(ctx.ctx); let indices_size = indices.size(ctx, generator); - let nidx_leq_ndims = ctx.builder.build_int_compare( - IntPredicate::SLE, - indices_size, - self.0.load_ndims(ctx), - "" - ).unwrap(); + let nidx_leq_ndims = ctx + .builder + .build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "") + .unwrap(); ctx.make_assert( generator, nidx_leq_ndims, @@ -1668,16 +1701,13 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None), ) }; - let dim_idx = ctx.builder + let dim_idx = ctx + .builder .build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "") .unwrap(); - let dim_lt = ctx.builder.build_int_compare( - IntPredicate::SLT, - dim_idx, - dim_sz, - "" - ).unwrap(); + let dim_lt = + ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap(); ctx.make_assert( generator, @@ -1691,13 +1721,18 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> Ok(()) }, llvm_usize.const_int(1, false), - ).unwrap(); + ) + .unwrap(); - unsafe { - self.ptr_offset_unchecked(ctx, generator, indices, name) - } + unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) } } } -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index> for NDArrayDataProxy<'ctx, '_> {} -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index> for NDArrayDataProxy<'ctx, '_> {} +impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index> + for NDArrayDataProxy<'ctx, '_> +{ +} +impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index> + for NDArrayDataProxy<'ctx, '_> +{ +} diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 87ced722..4b55654c 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -7,9 +7,9 @@ use crate::{ }, }; +use indexmap::IndexMap; use nac3parser::ast::StrRef; use std::collections::HashMap; -use indexmap::IndexMap; pub struct ConcreteTypeStore { store: Vec, @@ -202,9 +202,9 @@ impl ConcreteTypeStore { TypeEnum::TFunc(signature) => { self.from_signature(unifier, primitives, signature, cache) } - TypeEnum::TLiteral { values, .. } => ConcreteTypeEnum::TLiteral { - values: values.clone(), - }, + TypeEnum::TLiteral { values, .. } => { + ConcreteTypeEnum::TLiteral { values: values.clone() } + } _ => unreachable!("{:?}", ty_enum.get_type_name()), }; let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() { @@ -292,9 +292,8 @@ impl ConcreteTypeStore { .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) .collect::(), }), - ConcreteTypeEnum::TLiteral { values, .. } => TypeEnum::TLiteral { - values: values.clone(), - loc: None, + ConcreteTypeEnum::TLiteral { values, .. } => { + TypeEnum::TLiteral { values: values.clone(), loc: None } } }; let result = unifier.add_ty(result); diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0123549d..b898cc2b 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -3,20 +3,11 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use crate::{ codegen::{ classes::{ - ArrayLikeIndexer, - ArrayLikeValue, - ArraySliceValue, - ListValue, - NDArrayValue, - ProxyValue, - RangeValue, - TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue, + RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, - gen_in_range_check, - get_llvm_type, - get_llvm_abi_type, + gen_in_range_check, get_llvm_abi_type, get_llvm_type, irrt::*, llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi}, numpy, @@ -25,29 +16,27 @@ use crate::{ }, symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{ - DefinitionId, helper::PRIMITIVE_DEF_IDS, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, - TopLevelDef, + DefinitionId, TopLevelDef, }, typecheck::{ + magic_methods::{binop_assign_name, binop_name, unaryop_name}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, - magic_methods::{binop_name, binop_assign_name, unaryop_name}, }, }; use inkwell::{ - AddressSpace, attributes::{Attribute, AttributeLoc}, - IntPredicate, types::{AnyType, BasicType, BasicTypeEnum}, - values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue} + values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue}, + AddressSpace, IntPredicate, }; -use itertools::{chain, izip, Itertools, Either}; +use itertools::{chain, izip, Either, Itertools}; use nac3parser::ast::{ self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, }; -use super::{CodeGenerator, llvm_intrinsics::call_memcpy_generic, need_sret}; +use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator}; pub fn get_subst_key( unifier: &mut Unifier, @@ -57,9 +46,7 @@ pub fn get_subst_key( ) -> String { let mut vars = obj .map(|ty| { - let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { - unreachable!() - }; + let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; params.clone() }) .unwrap_or_default(); @@ -128,7 +115,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { SymbolValue::Bool(v) => self.ctx.i8_type().const_int(*v as u64, true).into(), SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(), SymbolValue::Str(v) => { - let str_ptr = self.builder + let str_ptr = self + .builder .build_global_string_ptr(v, "const") .map(|v| v.as_pointer_value().into()) .unwrap(); @@ -144,11 +132,14 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let zero = self.ctx.i32_type().const_zero(); unsafe { for (i, val) in vals.into_iter().enumerate() { - let p = self.builder.build_in_bounds_gep( - ptr, - &[zero, self.ctx.i32_type().const_int(i as u64, false)], - "elemptr", - ).unwrap(); + let p = self + .builder + .build_in_bounds_gep( + ptr, + &[zero, self.ctx.i32_type().const_int(i as u64, false)], + "elemptr", + ) + .unwrap(); self.builder.build_store(p, val).unwrap(); } } @@ -164,7 +155,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { _ => unreachable!("must be option type"), }; let val = self.gen_symbol_val(generator, v, ty); - let ptr = generator.gen_var_alloc(self, val.get_type(), Some("default_opt_some")).unwrap(); + let ptr = generator + .gen_var_alloc(self, val.get_type(), Some("default_opt_some")) + .unwrap(); self.builder.build_store(ptr, val).unwrap(); ptr.into() } @@ -272,7 +265,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { if let Some(v) = self.const_strings.get(v) { Some(*v) } else { - let str_ptr = self.builder + let str_ptr = self + .builder .build_global_string_ptr(v, "const") .map(|v| v.as_pointer_value().into()) .unwrap(); @@ -308,16 +302,22 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { op: &Operator, lhs: BasicValueEnum<'ctx>, rhs: BasicValueEnum<'ctx>, - signed: bool + signed: bool, ) -> BasicValueEnum<'ctx> { let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) else { unreachable!() }; let float = self.ctx.f64_type(); match (op, signed) { - (Operator::Add, _) => self.builder.build_int_add(lhs, rhs, "add").map(Into::into).unwrap(), - (Operator::Sub, _) => self.builder.build_int_sub(lhs, rhs, "sub").map(Into::into).unwrap(), - (Operator::Mult, _) => self.builder.build_int_mul(lhs, rhs, "mul").map(Into::into).unwrap(), + (Operator::Add, _) => { + self.builder.build_int_add(lhs, rhs, "add").map(Into::into).unwrap() + } + (Operator::Sub, _) => { + self.builder.build_int_sub(lhs, rhs, "sub").map(Into::into).unwrap() + } + (Operator::Mult, _) => { + self.builder.build_int_mul(lhs, rhs, "mul").map(Into::into).unwrap() + } (Operator::Div, true) => { let left = self.builder.build_signed_int_to_float(lhs, float, "i2f").unwrap(); let right = self.builder.build_signed_int_to_float(rhs, float, "i2f").unwrap(); @@ -328,11 +328,19 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let right = self.builder.build_unsigned_int_to_float(rhs, float, "i2f").unwrap(); self.builder.build_float_div(left, right, "fdiv").map(Into::into).unwrap() } - (Operator::Mod, true) => self.builder.build_int_signed_rem(lhs, rhs, "mod").map(Into::into).unwrap(), - (Operator::Mod, false) => self.builder.build_int_unsigned_rem(lhs, rhs, "mod").map(Into::into).unwrap(), + (Operator::Mod, true) => { + self.builder.build_int_signed_rem(lhs, rhs, "mod").map(Into::into).unwrap() + } + (Operator::Mod, false) => { + self.builder.build_int_unsigned_rem(lhs, rhs, "mod").map(Into::into).unwrap() + } (Operator::BitOr, _) => self.builder.build_or(lhs, rhs, "or").map(Into::into).unwrap(), - (Operator::BitXor, _) => self.builder.build_xor(lhs, rhs, "xor").map(Into::into).unwrap(), - (Operator::BitAnd, _) => self.builder.build_and(lhs, rhs, "and").map(Into::into).unwrap(), + (Operator::BitXor, _) => { + self.builder.build_xor(lhs, rhs, "xor").map(Into::into).unwrap() + } + (Operator::BitAnd, _) => { + self.builder.build_and(lhs, rhs, "and").map(Into::into).unwrap() + } // Sign-ness of bitshift operators are always determined by the left operand (Operator::LShift | Operator::RShift, signed) => { @@ -350,7 +358,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { rhs }; - let rhs_gez = self.builder + let rhs_gez = self + .builder .build_int_compare(IntPredicate::SGE, rhs, common_type.const_zero(), "") .unwrap(); self.make_assert( @@ -359,18 +368,28 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { "ValueError", "negative shift count", [None, None, None], - self.current_loc + self.current_loc, ); match *op { - Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").map(Into::into).unwrap(), - Operator::RShift => self.builder.build_right_shift(lhs, rhs, signed, "rshift").map(Into::into).unwrap(), - _ => unreachable!() + Operator::LShift => { + self.builder.build_left_shift(lhs, rhs, "lshift").map(Into::into).unwrap() + } + Operator::RShift => self + .builder + .build_right_shift(lhs, rhs, signed, "rshift") + .map(Into::into) + .unwrap(), + _ => unreachable!(), } } - (Operator::FloorDiv, true) => self.builder.build_int_signed_div(lhs, rhs, "floordiv").map(Into::into).unwrap(), - (Operator::FloorDiv, false) => self.builder.build_int_unsigned_div(lhs, rhs, "floordiv").map(Into::into).unwrap(), + (Operator::FloorDiv, true) => { + self.builder.build_int_signed_div(lhs, rhs, "floordiv").map(Into::into).unwrap() + } + (Operator::FloorDiv, false) => { + self.builder.build_int_unsigned_div(lhs, rhs, "floordiv").map(Into::into).unwrap() + } (Operator::Pow, s) => integer_power(generator, self, lhs, rhs, s).into(), // special implementation? (Operator::MatMult, _) => unreachable!(), @@ -385,14 +404,28 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { rhs: BasicValueEnum<'ctx>, ) -> BasicValueEnum<'ctx> { let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else { - unreachable!("Expected (FloatValue, FloatValue), got ({}, {})", lhs.get_type(), rhs.get_type()) + unreachable!( + "Expected (FloatValue, FloatValue), got ({}, {})", + lhs.get_type(), + rhs.get_type() + ) }; match op { - Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").map(Into::into).unwrap(), - Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").map(Into::into).unwrap(), - Operator::Mult => self.builder.build_float_mul(lhs, rhs, "fmul").map(Into::into).unwrap(), - Operator::Div => self.builder.build_float_div(lhs, rhs, "fdiv").map(Into::into).unwrap(), - Operator::Mod => self.builder.build_float_rem(lhs, rhs, "fmod").map(Into::into).unwrap(), + Operator::Add => { + self.builder.build_float_add(lhs, rhs, "fadd").map(Into::into).unwrap() + } + Operator::Sub => { + self.builder.build_float_sub(lhs, rhs, "fsub").map(Into::into).unwrap() + } + Operator::Mult => { + self.builder.build_float_mul(lhs, rhs, "fmul").map(Into::into).unwrap() + } + Operator::Div => { + self.builder.build_float_div(lhs, rhs, "fdiv").map(Into::into).unwrap() + } + Operator::Mod => { + self.builder.build_float_rem(lhs, rhs, "fmod").map(Into::into).unwrap() + } Operator::FloorDiv => { let div = self.builder.build_float_div(lhs, rhs, "fdiv").unwrap(); call_float_floor(self, div, Some("floor")).into() @@ -427,16 +460,16 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let byval_id = Attribute::get_named_enum_kind_id("byval"); let offset = if fun.get_enum_attribute(AttributeLoc::Param(0), sret_id).is_some() { - return_slot = Some(self.builder - .build_alloca( - fun.get_type() - .get_param_types()[0] - .into_pointer_type() - .get_element_type() - .into_struct_type(), - call_name - ) - .unwrap() + return_slot = Some( + self.builder + .build_alloca( + fun.get_type().get_param_types()[0] + .into_pointer_type() + .get_element_type() + .into_struct_type(), + call_name, + ) + .unwrap(), ); loc_params.push((*return_slot.as_ref().unwrap()).into()); 1 @@ -445,10 +478,12 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { }; for (i, param) in params.iter().enumerate() { let loc = AttributeLoc::Param((i + offset) as u32); - if fun.get_enum_attribute(loc, byref_id).is_some() || fun.get_enum_attribute(loc, byval_id).is_some() { + if fun.get_enum_attribute(loc, byref_id).is_some() + || fun.get_enum_attribute(loc, byval_id).is_some() + { // lazy update if loc_params.is_empty() { - loc_params.extend(params[0..i+offset].iter().copied()); + loc_params.extend(params[0..i + offset].iter().copied()); } let slot = gen_var(self, param.get_type(), Some(call_name)).unwrap(); loc_params.push(slot.into()); @@ -490,7 +525,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { result } else { let param: Vec<_> = params.iter().map(|v| (*v).into()).collect(); - self.builder.build_call(fun, ¶m, call_name) + self.builder + .build_call(fun, ¶m, call_name) .map(CallSiteValue::try_as_basic_value) .map(Either::left) .unwrap() @@ -503,14 +539,10 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } /// Helper function for generating a LLVM variable storing a [String]. - pub fn gen_string( - &mut self, - generator: &mut G, - s: S, - ) -> BasicValueEnum<'ctx> - where - G: CodeGenerator + ?Sized, - S: Into, + pub fn gen_string(&mut self, generator: &mut G, s: S) -> BasicValueEnum<'ctx> + where + G: CodeGenerator + ?Sized, + S: Into, { self.gen_const(generator, &Constant::Str(s.into()), self.primitives.str).unwrap() } @@ -534,24 +566,24 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let int32 = self.ctx.i32_type(); let zero = int32.const_zero(); unsafe { - let id_ptr = self.builder - .build_in_bounds_gep(zelf, &[zero, zero], "exn.id") - .unwrap(); + let id_ptr = self.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap(); let id = self.resolver.get_string_id(name); self.builder.build_store(id_ptr, int32.const_int(id as u64, false)).unwrap(); - let ptr = self.builder.build_in_bounds_gep( - zelf, - &[zero, int32.const_int(5, false)], - "exn.msg", - ).unwrap(); + let ptr = self + .builder + .build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg") + .unwrap(); self.builder.build_store(ptr, msg).unwrap(); let i64_zero = self.ctx.i64_type().const_zero(); for (i, attr_ind) in [6, 7, 8].iter().enumerate() { - let ptr = self.builder.build_in_bounds_gep( - zelf, - &[zero, int32.const_int(*attr_ind, false)], - "exn.param", - ).unwrap(); + let ptr = self + .builder + .build_in_bounds_gep( + zelf, + &[zero, int32.const_int(*attr_ind, false)], + "exn.param", + ) + .unwrap(); let val = params[i].map_or(i64_zero, |v| { self.builder.build_int_s_extend(v, self.ctx.i64_type(), "sext").unwrap() }); @@ -609,25 +641,19 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>( def: &TopLevelDef, params: Vec<(Option, ValueEnum<'ctx>)>, ) -> Result, String> { - let TopLevelDef::Class { methods, .. } = def else { - unreachable!() - }; + let TopLevelDef::Class { methods, .. } = def else { unreachable!() }; // TODO: what about other fields that require alloca? let fun_id = methods.iter().find(|method| method.0 == "__init__".into()).map(|method| method.2); let ty = ctx.get_llvm_type(generator, signature.ret).into_pointer_type(); let zelf_ty: BasicTypeEnum = ty.get_element_type().try_into().unwrap(); - let zelf: BasicValueEnum<'ctx> = ctx.builder.build_alloca(zelf_ty, "alloca").map(Into::into).unwrap(); + let zelf: BasicValueEnum<'ctx> = + ctx.builder.build_alloca(zelf_ty, "alloca").map(Into::into).unwrap(); // call `__init__` if there is one if let Some(fun_id) = fun_id { let mut sign = signature.clone(); sign.ret = ctx.primitives.none; - generator.gen_call( - ctx, - Some((signature.ret, zelf.into())), - (&sign, fun_id), - params, - )?; + generator.gen_call(ctx, Some((signature.ret, zelf.into())), (&sign, fun_id), params)?; } Ok(zelf) } @@ -645,7 +671,10 @@ pub fn gen_func_instance<'ctx>( name, instance_to_symbol, instance_to_stmt, var_id, resolver, .. }, key, - ) = fun else { unreachable!() }; + ) = fun + else { + unreachable!() + }; if let Some(sym) = instance_to_symbol.get(&key) { return Ok(sym.clone()); @@ -675,20 +704,13 @@ pub fn gen_func_instance<'ctx>( }) .collect(); - let mut signature = - store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache); + let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache); if let Some(obj) = &obj { - let zelf = - store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache); - let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { - unreachable!() - }; + let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache); + let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() }; - args.insert( - 0, - ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None }, - ); + args.insert(0, ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None }); } let signature = store.add_cty(signature); @@ -751,8 +773,12 @@ pub fn gen_call<'ctx, G: CodeGenerator>( ); } // reorder the parameters - let mut real_params = - fun.0.args.iter().map(|arg| (mapping.remove(&arg.name).unwrap(), arg.ty)).collect_vec(); + let mut real_params = fun + .0 + .args + .iter() + .map(|arg| (mapping.remove(&arg.name).unwrap(), arg.ty)) + .collect_vec(); if let Some(obj) = &obj { real_params.insert(0, (obj.1.clone(), obj.0)); } @@ -814,26 +840,36 @@ pub fn gen_call<'ctx, G: CodeGenerator>( }; let has_sret = ret_type.map_or(false, |ret_type| need_sret(ret_type)); let mut byrefs = Vec::new(); - let mut params = args.iter().enumerate() - .map(|(i, arg)| match ctx.get_llvm_abi_type(generator, arg.ty) { - BasicTypeEnum::StructType(ty) if is_extern => { - byrefs.push((i, ty)); - ty.ptr_type(AddressSpace::default()).into() - }, - x => x - }.into()) + let mut params = args + .iter() + .enumerate() + .map(|(i, arg)| { + match ctx.get_llvm_abi_type(generator, arg.ty) { + BasicTypeEnum::StructType(ty) if is_extern => { + byrefs.push((i, ty)); + ty.ptr_type(AddressSpace::default()).into() + } + x => x, + } + .into() + }) .collect_vec(); if has_sret { params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into()); } let fun_ty = match ret_type { Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, false), - _ => ctx.ctx.void_type().fn_type(¶ms, false) + _ => ctx.ctx.void_type().fn_type(¶ms, false), }; let fun_val = ctx.module.add_function(&symbol, fun_ty, None); let offset = if has_sret { - fun_val.add_attribute(AttributeLoc::Param(0), - ctx.ctx.create_type_attribute(Attribute::get_named_enum_kind_id("sret"), ret_type.unwrap().as_any_type_enum())); + fun_val.add_attribute( + AttributeLoc::Param(0), + ctx.ctx.create_type_attribute( + Attribute::get_named_enum_kind_id("sret"), + ret_type.unwrap().as_any_type_enum(), + ), + ); 1 } else { 0 @@ -843,20 +879,22 @@ pub fn gen_call<'ctx, G: CodeGenerator>( // Structure-Typed parameters of extern functions must **not** be marked as `byval`, as // `byval` explicitly specifies that the argument is to be passed on the stack, which breaks // on most ABIs where the first several arguments are expected to be passed in registers. - let passing_attr_id = Attribute::get_named_enum_kind_id( - if is_extern { "byref" } else { "byval" } - ); + let passing_attr_id = + Attribute::get_named_enum_kind_id(if is_extern { "byref" } else { "byval" }); for (i, ty) in byrefs { fun_val.add_attribute( AttributeLoc::Param((i as u32) + offset), - ctx.ctx.create_type_attribute(passing_attr_id, ty.as_any_type_enum()) + ctx.ctx.create_type_attribute(passing_attr_id, ty.as_any_type_enum()), ); } fun_val }); // Convert boolean parameter values into i1 - let param_vals = fun_val.get_params().iter().zip(param_vals) + let param_vals = fun_val + .get_params() + .iter() + .zip(param_vals) .map(|(p, v)| { if p.is_int_value() && v.is_int_value() { let expected_ty = p.into_int_value().get_type(); @@ -866,7 +904,8 @@ pub fn gen_call<'ctx, G: CodeGenerator>( generator.bool_to_i1(ctx, param_val) } else { param_val - }.into() + } + .into() } else { v } @@ -903,19 +942,16 @@ pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>( ) -> ListValue<'ctx> { let size_t = generator.get_size_type(ctx.ctx); // List structure; type { ty*, size_t } - let arr_ty = ctx.ctx - .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); + let arr_ty = + ctx.ctx.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); - let arr_str_ptr = ctx.builder.build_alloca( - arr_ty, format!("{}.addr", name.unwrap_or("list")).as_str() - ).unwrap(); + let arr_str_ptr = ctx + .builder + .build_alloca(arr_ty, format!("{}.addr", name.unwrap_or("list")).as_str()) + .unwrap(); let list = ListValue::from_ptr_val(arr_str_ptr, size_t, Some("list")); - let length = ctx.builder.build_int_z_extend( - length, - size_t, - "" - ).unwrap(); + let length = ctx.builder.build_int_z_extend(length, size_t, "").unwrap(); list.store_size(ctx, generator, length); list.create_data(ctx, ty, None); @@ -928,9 +964,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, '_>, expr: &Expr>, ) -> Result>, String> { - let ExprKind::ListComp { elt, generators } = &expr.node else { - unreachable!() - }; + let ExprKind::ListComp { elt, generators } = &expr.node else { unreachable!() }; let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); @@ -952,7 +986,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ctx.builder.build_unreachable().unwrap(); } - return Ok(None) + return Ok(None); }; let int32 = ctx.ctx.i32_type(); let size_t = generator.get_size_type(ctx.ctx); @@ -977,22 +1011,24 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap(); let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap(); // in case length is non-positive - let is_valid = ctx.builder - .build_int_compare(IntPredicate::SGT, length, zero_32, "check") - .unwrap(); + let is_valid = + ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap(); - let list_alloc_size = ctx.builder.build_select( - is_valid, - ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len").unwrap(), - zero_size_t, - "listcomp.alloc_size" - ).unwrap(); + let list_alloc_size = ctx + .builder + .build_select( + is_valid, + ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len").unwrap(), + zero_size_t, + "listcomp.alloc_size", + ) + .unwrap(); list = allocate_list( generator, ctx, elem_ty, list_alloc_size.into_int_value(), - Some("listcomp.addr") + Some("listcomp.addr"), ); list_content = list.data().base_ptr(ctx, generator); @@ -1007,11 +1043,14 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ctx.builder.position_at_end(test_bb); // add and test - let tmp = ctx.builder.build_int_add( - ctx.builder.build_load(i, "i").map(BasicValueEnum::into_int_value).unwrap(), - step, - "start_loop", - ).unwrap(); + let tmp = ctx + .builder + .build_int_add( + ctx.builder.build_load(i, "i").map(BasicValueEnum::into_int_value).unwrap(), + step, + "start_loop", + ) + .unwrap(); ctx.builder.build_store(i, tmp).unwrap(); ctx.builder .build_conditional_branch(gen_in_range_check(ctx, tmp, stop, step), body_bb, cont_bb) @@ -1042,17 +1081,26 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ctx.builder.position_at_end(body_bb); let arr_ptr = ctx - .build_gep_and_load(iter_val.into_pointer_value(), &[zero_size_t, zero_32], Some("arr.addr")) + .build_gep_and_load( + iter_val.into_pointer_value(), + &[zero_size_t, zero_32], + Some("arr.addr"), + ) .into_pointer_value(); let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); generator.gen_assign(ctx, target, val.into())?; } // Emits the content of `cont_bb` - let emit_cont_bb = |ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| { - ctx.builder.position_at_end(cont_bb); - list.store_size(ctx, generator, ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap()); - }; + let emit_cont_bb = + |ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| { + ctx.builder.position_at_end(cont_bb); + list.store_size( + ctx, + generator, + ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(), + ); + }; for cond in ifs { let result = if let Some(v) = generator.gen_expr(ctx, cond)? { @@ -1062,7 +1110,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( // no element matches the predicate emit_cont_bb(ctx, generator, list); - return Ok(None) + return Ok(None); }; let result = generator.bool_to_i1(ctx, result); let succ = ctx.ctx.append_basic_block(current, "then"); @@ -1075,7 +1123,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents emit_cont_bb(ctx, generator, list); - return Ok(None) + return Ok(None); }; let i = ctx.builder.build_load(index, "i").map(BasicValueEnum::into_int_value).unwrap(); let elem_ptr = unsafe { ctx.builder.build_gep(list_content, &[i], "elem_ptr") }.unwrap(); @@ -1094,7 +1142,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( Ok(Some(list.as_base_value().into())) } -/// Generates LLVM IR for a binary operator expression using the [`Type`] and +/// Generates LLVM IR for a binary operator expression using the [`Type`] and /// [LLVM value][`BasicValueEnum`] of the operands. pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( generator: &mut G, @@ -1130,16 +1178,18 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( ctx, left_val.into_float_value(), right_val.into_int_value(), - Some("f_pow_i") + Some("f_pow_i"), ); Ok(Some(res.into())) - } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + { let llvm_usize = generator.get_size_type(ctx.ctx); - let is_ndarray1 = ty1.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = ty2.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray1 = + ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); @@ -1147,16 +1197,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - let left_val = NDArrayValue::from_ptr_val( - left_val.into_pointer_value(), - llvm_usize, - None, - ); - let right_val = NDArrayValue::from_ptr_val( - right_val.into_pointer_value(), - llvm_usize, - None, - ); + let left_val = + NDArrayValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None); + let right_val = + NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); let res = if *op == Operator::MatMult { // MatMult is the only binop which is not an elementwise op @@ -1185,17 +1229,21 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( (&Some(ndarray_dtype2), rhs), ctx.current_loc, is_aug_assign, - )?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1) + )? + .unwrap() + .to_basic_value_enum( + ctx, + generator, + ndarray_dtype1, + ) }, )? }; Ok(Some(res.as_base_value().into())) } else { - let (ndarray_dtype, _) = unpack_ndarray_var_tys( - &mut ctx.unifier, - if is_ndarray1 { ty1 } else { ty2 }, - ); + let (ndarray_dtype, _) = + unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); let ndarray_val = NDArrayValue::from_ptr_val( if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), llvm_usize, @@ -1217,7 +1265,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( (&Some(ndarray_dtype), rhs), ctx.current_loc, is_aug_assign, - )?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype) + )? + .unwrap() + .to_basic_value_enum(ctx, generator, ndarray_dtype) }, )?; @@ -1229,10 +1279,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( unreachable!("must be tobj") }; let (op_name, id) = { - let (binop_name, binop_assign_name) = ( - binop_name(op).into(), - binop_assign_name(op).into() - ); + let (binop_name, binop_assign_name) = + (binop_name(op).into(), binop_assign_name(op).into()); // if is aug_assign, try aug_assign operator first if is_aug_assign && fields.contains_key(&binop_assign_name) { (binop_assign_name, *obj_id) @@ -1251,18 +1299,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let fn_ty = fields.get(&op_name).unwrap().0; let fn_ty_enum = ctx.unifier.get_ty_immutable(fn_ty); - let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else { - unreachable!() - }; + let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else { unreachable!() }; sig.clone() }; let fun_id = { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); - let TopLevelDef::Class { methods, .. } = &*obj_def else { - unreachable!() - }; + let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() }; methods.iter().find(|method| method.0 == op_name).unwrap().2 }; @@ -1272,7 +1316,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Some((left_ty.unwrap(), left_val.into())), (&signature, fun_id), vec![(None, right_val.into())], - ).map(|f| f.map(Into::into)) + ) + .map(|f| f.map(Into::into)) } } @@ -1295,12 +1340,12 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( let left_val = if let Some(v) = generator.gen_expr(ctx, left)? { v.to_basic_value_enum(ctx, generator, left.custom.unwrap())? } else { - return Ok(None) + return Ok(None); }; let right_val = if let Some(v) = generator.gen_expr(ctx, right)? { v.to_basic_value_enum(ctx, generator, right.custom.unwrap())? } else { - return Ok(None) + return Ok(None); }; gen_binop_expr_with_values( @@ -1329,12 +1374,9 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( let val = val.into_int_value(); if *op == ast::Unaryop::Not { let not = ctx.builder.build_not(val, "not").unwrap(); - let not_bool = ctx.builder.build_and( - not, - not.get_type().const_int(1, false), - "", - ).unwrap(); - + let not_bool = + ctx.builder.build_and(not, not.get_type().const_int(1, false), "").unwrap(); + not_bool.into() } else { let llvm_i32 = ctx.ctx.i32_type(); @@ -1345,16 +1387,28 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( op, ( &Some(ctx.primitives.int32), - ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap() + ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap(), ), - )?.unwrap() + )? + .unwrap() } - } else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) { + } else if [ + ctx.primitives.int32, + ctx.primitives.int64, + ctx.primitives.uint32, + ctx.primitives.uint64, + ] + .contains(&ty) + { let val = val.into_int_value(); match op { ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").map(Into::into).unwrap(), ast::Unaryop::Invert => ctx.builder.build_not(val, "not").map(Into::into).unwrap(), - ast::Unaryop::Not => ctx.builder.build_xor(val, val.get_type().const_all_ones(), "not").map(Into::into).unwrap(), + ast::Unaryop::Not => ctx + .builder + .build_xor(val, val.get_type().const_all_ones(), "not") + .map(Into::into) + .unwrap(), ast::Unaryop::UAdd => val.into(), } } else if ty == ctx.primitives.float { @@ -1377,23 +1431,20 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( let llvm_usize = generator.get_size_type(ctx.ctx); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - let val = NDArrayValue::from_ptr_val( - val.into_pointer_value(), - llvm_usize, - None, - ); + let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function - let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { - if *op == ast::Unaryop::Invert { - &ast::Unaryop::Not + let op = + if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { + if *op == ast::Unaryop::Invert { + &ast::Unaryop::Not + } else { + unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op)) + } } else { - unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op)) - } - } else { - op - }; + op + }; let res = numpy::ndarray_elementwise_unaryop_impl( generator, @@ -1402,19 +1453,16 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( None, val, |generator, ctx, val| { - gen_unaryop_expr_with_values( - generator, - ctx, - op, - (&Some(ndarray_dtype), val), - )?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype) + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))? + .unwrap() + .to_basic_value_enum(ctx, generator, ndarray_dtype) }, )?; res.as_base_value().into() } else { unimplemented!() - })) + })) } /// Generates LLVM IR for a unary operator expression. @@ -1430,7 +1478,7 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>( let val = if let Some(v) = generator.gen_expr(ctx, operand)? { v.to_basic_value_enum(ctx, generator, operand.custom.unwrap())? } else { - return Ok(None) + return Ok(None); }; gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val)) @@ -1451,29 +1499,28 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let left_ty = ctx.unifier.get_representative(left.0.unwrap()); let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap()); - if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + { let llvm_usize = generator.get_size_type(ctx.ctx); let (Some(left_ty), lhs) = left else { unreachable!() }; let (Some(right_ty), rhs) = comparators[0] else { unreachable!() }; let op = ops[0].clone(); - let is_ndarray1 = left_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = right_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray1 = + left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = + right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); return if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); - + assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - let left_val = NDArrayValue::from_ptr_val( - lhs.into_pointer_value(), - llvm_usize, - None - ); + + let left_val = + NDArrayValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None); let res = numpy::ndarray_elementwise_binop_impl( generator, ctx, @@ -1488,12 +1535,18 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( (Some(ndarray_dtype1), lhs), &[op.clone()], &[(Some(ndarray_dtype2), rhs)], - )?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?; + )? + .unwrap() + .to_basic_value_enum( + ctx, + generator, + ctx.primitives.bool, + )?; Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) }, )?; - + Ok(Some(res.as_base_value().into())) } else { let (ndarray_dtype, _) = unpack_ndarray_var_tys( @@ -1514,14 +1567,20 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( (Some(ndarray_dtype), lhs), &[op.clone()], &[(Some(ndarray_dtype), rhs)], - )?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?; + )? + .unwrap() + .to_basic_value_enum( + ctx, + generator, + ctx.primitives.bool, + )?; Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) }, )?; - + Ok(Some(res.as_base_value().into())) - } + }; } } @@ -1533,71 +1592,82 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let left_ty = ctx.unifier.get_representative(left_ty.unwrap()); let right_ty = ctx.unifier.get_representative(right_ty.unwrap()); - let current = - if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64, ctx.primitives.bool] - .contains(&left_ty) - { - assert!(ctx.unifier.unioned(left_ty, right_ty)); + let current = if [ + ctx.primitives.int32, + ctx.primitives.int64, + ctx.primitives.uint32, + ctx.primitives.uint64, + ctx.primitives.bool, + ] + .contains(&left_ty) + { + assert!(ctx.unifier.unioned(left_ty, right_ty)); - let use_unsigned_ops = [ - ctx.primitives.uint32, - ctx.primitives.uint64, - ].contains(&left_ty); - - let lhs = lhs.into_int_value(); - let rhs = rhs.into_int_value(); - - let op = match op { - ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ, - ast::Cmpop::NotEq => IntPredicate::NE, - _ if left_ty == ctx.primitives.bool => unreachable!(), - ast::Cmpop::Lt => if use_unsigned_ops { + let use_unsigned_ops = + [ctx.primitives.uint32, ctx.primitives.uint64].contains(&left_ty); + + let lhs = lhs.into_int_value(); + let rhs = rhs.into_int_value(); + + let op = match op { + ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ, + ast::Cmpop::NotEq => IntPredicate::NE, + _ if left_ty == ctx.primitives.bool => unreachable!(), + ast::Cmpop::Lt => { + if use_unsigned_ops { IntPredicate::ULT } else { IntPredicate::SLT - }, - ast::Cmpop::LtE => if use_unsigned_ops { + } + } + ast::Cmpop::LtE => { + if use_unsigned_ops { IntPredicate::ULE } else { IntPredicate::SLE - }, - ast::Cmpop::Gt => if use_unsigned_ops { + } + } + ast::Cmpop::Gt => { + if use_unsigned_ops { IntPredicate::UGT } else { IntPredicate::SGT - }, - ast::Cmpop::GtE => if use_unsigned_ops { + } + } + ast::Cmpop::GtE => { + if use_unsigned_ops { IntPredicate::UGE } else { IntPredicate::SGE - }, - _ => unreachable!(), - }; - - ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap() - } else if left_ty == ctx.primitives.float { - assert!(ctx.unifier.unioned(left_ty, right_ty)); - - let lhs = lhs.into_float_value(); - let rhs = rhs.into_float_value(); - - let op = match op { - ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ, - ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE, - ast::Cmpop::Lt => inkwell::FloatPredicate::OLT, - ast::Cmpop::LtE => inkwell::FloatPredicate::OLE, - ast::Cmpop::Gt => inkwell::FloatPredicate::OGT, - ast::Cmpop::GtE => inkwell::FloatPredicate::OGE, - _ => unreachable!(), - }; - ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap() - } else { - unimplemented!() + } + } + _ => unreachable!(), }; + ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap() + } else if left_ty == ctx.primitives.float { + assert!(ctx.unifier.unioned(left_ty, right_ty)); + + let lhs = lhs.into_float_value(); + let rhs = rhs.into_float_value(); + + let op = match op { + ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ, + ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE, + ast::Cmpop::Lt => inkwell::FloatPredicate::OLT, + ast::Cmpop::LtE => inkwell::FloatPredicate::OLE, + ast::Cmpop::Gt => inkwell::FloatPredicate::OGT, + ast::Cmpop::GtE => inkwell::FloatPredicate::OGE, + _ => unreachable!(), + }; + ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap() + } else { + unimplemented!() + }; + Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) })?; - + Ok(Some(match cmp_val { Some(v) => v.into(), None => return Ok(None), @@ -1619,29 +1689,26 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( let left_val = if let Some(v) = generator.gen_expr(ctx, left)? { v.to_basic_value_enum(ctx, generator, left.custom.unwrap())? } else { - return Ok(None) + return Ok(None); }; - let comparator_vals = comparators.iter() + let comparator_vals = comparators + .iter() .map(|cmptor| { Ok(if let Some(v) = generator.gen_expr(ctx, cmptor)? { - Some((cmptor.custom, v.to_basic_value_enum(ctx, generator, cmptor.custom.unwrap())?)) + Some(( + cmptor.custom, + v.to_basic_value_enum(ctx, generator, cmptor.custom.unwrap())?, + )) } else { None }) }) - .take_while(|v| if let Ok(v) = v { - v.is_some() - } else { - true - }) + .take_while(|v| if let Ok(v) = v { v.is_some() } else { true }) .collect::, String>>()?; let comparator_vals = if comparator_vals.len() == comparators.len() { - comparator_vals - .into_iter() - .map(Option::unwrap) - .collect_vec() + comparator_vals.into_iter().map(Option::unwrap).collect_vec() } else { - return Ok(None) + return Ok(None); }; gen_cmpop_expr_with_values( @@ -1675,7 +1742,8 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( unreachable!() }; - let ndims = values.iter() + let ndims = values + .iter() .map(|ndim| match *ndim { SymbolValue::U64(v) => Ok(v), SymbolValue::U32(v) => Ok(v as u64), @@ -1689,16 +1757,11 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( assert!(!ndims.is_empty()); - let ndarray_ndims_ty = ctx.unifier.get_fresh_literal( - ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(), - None, - ); - let ndarray_ty = make_ndarray_ty( - &mut ctx.unifier, - &ctx.primitives, - Some(ty), - Some(ndarray_ndims_ty), - ); + let ndarray_ndims_ty = ctx + .unifier + .get_fresh_literal(ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(), None); + let ndarray_ty = + make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty)); let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); @@ -1723,12 +1786,10 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( generator, ctx, |_, ctx| { - Ok(ctx.builder.build_int_compare( - IntPredicate::SGE, - index, - index.get_type().const_zero(), - "", - ).unwrap()) + Ok(ctx + .builder + .build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "") + .unwrap()) }, |_, _| Ok(Some(index)), |generator, ctx| { @@ -1743,28 +1804,32 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ) }; - let index = ctx.builder.build_int_add( - len, - ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(), - "", - ).unwrap(); + let index = ctx + .builder + .build_int_add( + len, + ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(), + "", + ) + .unwrap(); Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) }, - ).map(|v| v.map(BasicValueEnum::into_int_value)) + ) + .map(|v| v.map(BasicValueEnum::into_int_value)) }; // Converts a slice expression into a slice-range tuple - let expr_to_slice = |generator: &mut G, + let expr_to_slice = |generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, node: &ExprKind>, dim: u64| { match node { ExprKind::Constant { value: Constant::Int(v), .. } => { - let Some(index) = normalize_index( - generator, ctx, llvm_i32.const_int(*v as u64, true), dim, - )? else { - return Ok(None) + let Some(index) = + normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)? + else { + return Ok(None); }; Ok(Some((index, index, llvm_i32.const_int(1, true)))) @@ -1772,27 +1837,24 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ExprKind::Slice { lower, upper, step } => { let dim_sz = unsafe { - v.dim_sizes() - .get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, false), - None, - ) + v.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(dim, false), + None, + ) }; handle_slice_indices(lower, upper, step, ctx, generator, dim_sz) } _ => { - let Some(index) = generator.gen_expr(ctx, slice)? else { - return Ok(None) - }; + let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) }; let index = index .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? .into_int_value(); let Some(index) = normalize_index(generator, ctx, index, dim)? else { - return Ok(None) + return Ok(None); }; Ok(Some((index, index, llvm_i32.const_int(1, true)))) @@ -1802,84 +1864,64 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( Ok(Some(match &slice.node { ExprKind::Tuple { elts, .. } => { - let slices = elts.iter().enumerate() + let slices = elts + .iter() + .enumerate() .map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64)) .take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some)) .collect::, _>>()?; if slices.len() < elts.len() { - return Ok(None) + return Ok(None); } - let slices = slices.into_iter() - .map(Option::unwrap) - .collect_vec(); + let slices = slices.into_iter().map(Option::unwrap).collect_vec(); - numpy::ndarray_sliced_copy( - generator, - ctx, - ty, - v, - &slices, - )?.as_base_value().into() + numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into() } ExprKind::Slice { .. } => { let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { - return Ok(None) + return Ok(None); }; - numpy::ndarray_sliced_copy( - generator, - ctx, - ty, - v, - &[slice], - )?.as_base_value().into() + numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into() } _ => { let index = if let Some(index) = generator.gen_expr(ctx, slice)? { index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() } else { - return Ok(None) - }; - let Some(index) = normalize_index(generator, ctx, index, 0)? else { - return Ok(None) + return Ok(None); }; + let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) }; let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?; ctx.builder.build_store(index_addr, index).unwrap(); if ndims.len() == 1 && ndims[0] == 1 { // Accessing an element from a 1-dimensional `ndarray` - return Ok(Some(v.data() - .get( - ctx, - generator, - &ArraySliceValue::from_ptr_val( - index_addr, - llvm_usize.const_int(1, false), + return Ok(Some( + v.data() + .get( + ctx, + generator, + &ArraySliceValue::from_ptr_val( + index_addr, + llvm_usize.const_int(1, false), + None, + ), None, - ), - None, - ) - .into())) + ) + .into(), + )); } // Accessing an element from a multi-dimensional `ndarray` // Create a new array, remove the top dimension from the dimension-size-list, and copy the // elements over - let subscripted_ndarray = generator.gen_var_alloc( - ctx, - llvm_ndarray_t.into(), - None - )?; - let ndarray = NDArrayValue::from_ptr_val( - subscripted_ndarray, - llvm_usize, - None - ); + let subscripted_ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; + let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None); let num_dims = v.load_ndims(ctx); ndarray.store_ndims( @@ -1923,7 +1965,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ctx, generator, &ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None), - None + None, ); call_memcpy_generic( ctx, @@ -1951,7 +1993,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let int32 = ctx.ctx.i32_type(); let usize = generator.get_size_type(ctx.ctx); let zero = int32.const_int(0, false); - + let loc = ctx.debug_info.0.create_debug_location( ctx.ctx, ctx.current_loc.row as u32, @@ -1964,9 +2006,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( Ok(Some(match &expr.node { ExprKind::Constant { value, .. } => { let ty = expr.custom.unwrap(); - let Some(const_val) = ctx.gen_const(generator, value, ty) else { - return Ok(None) - }; + let Some(const_val) = ctx.gen_const(generator, value, ty) else { return Ok(None) }; const_val.into() } ExprKind::Name { id, .. } if id == &"none".into() => { @@ -1974,19 +2014,21 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(), ctx.unifier.get_ty(ctx.primitives.option).as_ref(), ) { - ( - TypeEnum::TObj { obj_id, params, .. }, - TypeEnum::TObj { obj_id: opt_id, .. }, - ) if *obj_id == *opt_id => ctx - .get_llvm_type(generator, *params.iter().next().unwrap().1) - .ptr_type(AddressSpace::default()) - .const_null() - .into(), + (TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. }) + if *obj_id == *opt_id => + { + ctx.get_llvm_type(generator, *params.iter().next().unwrap().1) + .ptr_type(AddressSpace::default()) + .const_null() + .into() + } _ => unreachable!("must be option type"), } } ExprKind::Name { id, .. } => match ctx.var_assignment.get(id) { - Some((ptr, None, _)) => ctx.builder.build_load(*ptr, id.to_string().as_str()).map(Into::into).unwrap(), + Some((ptr, None, _)) => { + ctx.builder.build_load(*ptr, id.to_string().as_str()).map(Into::into).unwrap() + } Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()), None => { let resolver = ctx.resolver.clone(); @@ -2001,12 +2043,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .map(|x| generator.gen_expr(ctx, x)) .take_while(|v| !matches!(v, Ok(None))) .collect::, _>>()?; - let elements = elements.into_iter().zip(elts) + let elements = elements + .into_iter() + .zip(elts) .map(|(v, x)| v.unwrap().to_basic_value_enum(ctx, generator, x.custom.unwrap())) .collect::, _>>()?; if elements.len() < elts.len() { - return Ok(None) + return Ok(None); } let ty = if elements.is_empty() { @@ -2022,8 +2066,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let arr_str_ptr = allocate_list(generator, ctx, ty, length, Some("list")); let arr_ptr = arr_str_ptr.data(); for (i, v) in elements.iter().enumerate() { - let elem_ptr = arr_ptr - .ptr_offset(ctx, generator, &usize.const_int(i as u64, false), Some("elem_ptr")); + let elem_ptr = arr_ptr.ptr_offset( + ctx, + generator, + &usize.const_int(i as u64, false), + Some("elem_ptr"), + ); ctx.builder.build_store(elem_ptr, *v).unwrap(); } arr_str_ptr.as_base_value().into() @@ -2034,12 +2082,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .map(|x| generator.gen_expr(ctx, x)) .take_while(|v| !matches!(v, Ok(None))) .collect::, _>>()?; - let element_val = elements_val.into_iter().zip(elts) + let element_val = elements_val + .into_iter() + .zip(elts) .map(|(v, x)| v.unwrap().to_basic_value_enum(ctx, generator, x.custom.unwrap())) .collect::, _>>()?; if element_val.len() < elts.len() { - return Ok(None) + return Ok(None); } let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec(); @@ -2047,11 +2097,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let tuple_ptr = ctx.builder.build_alloca(tuple_ty, "tuple").unwrap(); for (i, v) in element_val.into_iter().enumerate() { unsafe { - let ptr = ctx.builder.build_in_bounds_gep( - tuple_ptr, - &[zero, int32.const_int(i as u64, false)], - "ptr", - ).unwrap(); + let ptr = ctx + .builder + .build_in_bounds_gep( + tuple_ptr, + &[zero, int32.const_int(i as u64, false)], + "ptr", + ) + .unwrap(); ctx.builder.build_store(ptr, v).unwrap(); } } @@ -2060,15 +2113,18 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ExprKind::Attribute { value, attr, .. } => { // note that we would handle class methods directly in calls match generator.gen_expr(ctx, value)? { - Some(ValueEnum::Static(v)) => v.get_field(*attr, ctx).map_or_else(|| { - let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?; - let index = ctx.get_attr_index(value.custom.unwrap(), *attr); - Ok(ValueEnum::Dynamic(ctx.build_gep_and_load( - v.into_pointer_value(), - &[zero, int32.const_int(index as u64, false)], - None, - ))) as Result<_, String> - }, Ok)?, + Some(ValueEnum::Static(v)) => v.get_field(*attr, ctx).map_or_else( + || { + let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?; + let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + Ok(ValueEnum::Dynamic(ctx.build_gep_and_load( + v.into_pointer_value(), + &[zero, int32.const_int(index as u64, false)], + None, + ))) as Result<_, String> + }, + Ok, + )?, Some(ValueEnum::Dynamic(v)) => { let index = ctx.get_attr_index(value.custom.unwrap(), *attr); ValueEnum::Dynamic(ctx.build_gep_and_load( @@ -2085,7 +2141,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let left = if let Some(v) = generator.gen_expr(ctx, &values[0])? { v.to_basic_value_enum(ctx, generator, values[0].custom.unwrap())?.into_int_value() } else { - return Ok(None) + return Ok(None); }; let left = generator.bool_to_i1(ctx, left); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); @@ -2101,7 +2157,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ctx.builder.position_at_end(b_bb); let b = if let Some(v) = generator.gen_expr(ctx, &values[1])? { - let b = v.to_basic_value_enum(ctx, generator, values[1].custom.unwrap())?.into_int_value(); + let b = v + .to_basic_value_enum(ctx, generator, values[1].custom.unwrap())? + .into_int_value(); let b = generator.bool_to_i8(ctx, b); ctx.builder.build_unconditional_branch(cont_bb).unwrap(); @@ -2115,7 +2173,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( Boolop::And => { ctx.builder.position_at_end(a_bb); let a = if let Some(v) = generator.gen_expr(ctx, &values[1])? { - let a = v.to_basic_value_enum(ctx, generator, values[1].custom.unwrap())?.into_int_value(); + let a = v + .to_basic_value_enum(ctx, generator, values[1].custom.unwrap())? + .into_int_value(); let a = generator.bool_to_i8(ctx, a); ctx.builder.build_unconditional_branch(cont_bb).unwrap(); @@ -2147,15 +2207,15 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ExprKind::BinOp { op, left, right } => { return gen_binop_expr(generator, ctx, left, op, right, expr.location, false); } - ExprKind::UnaryOp { op, operand } => { - return gen_unaryop_expr(generator, ctx, op, operand) - } + ExprKind::UnaryOp { op, operand } => return gen_unaryop_expr(generator, ctx, op, operand), ExprKind::Compare { left, ops, comparators } => { return gen_cmpop_expr(generator, ctx, left, ops, comparators) } ExprKind::IfExp { test, body, orelse } => { let test = match generator.gen_expr(ctx, test)? { - Some(v) => v.to_basic_value_enum(ctx, generator, test.custom.unwrap())?.into_int_value(), + Some(v) => { + v.to_basic_value_enum(ctx, generator, test.custom.unwrap())?.into_int_value() + } None => return Ok(None), }; let test = generator.bool_to_i1(ctx, test); @@ -2203,7 +2263,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( if let Some(v) = result { ctx.builder.build_load(v, "if_exp_val_load").map(Into::into).unwrap() } else { - return Ok(None) + return Ok(None); } } ExprKind::Call { func, args, keywords } => { @@ -2215,7 +2275,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .collect::, _>>()?; if params.len() < args.len() { - return Ok(None) + return Ok(None); } let kw_iter = keywords.iter().map(|kw| { @@ -2231,9 +2291,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ctx.unifier.get_call_signature(*call).unwrap() } else { let ty = func.custom.unwrap(); - let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) else { - unreachable!() - }; + let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) else { unreachable!() }; sign.clone() }; @@ -2241,18 +2299,15 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( match &func.node { ExprKind::Name { id, .. } => { // TODO: handle primitive casts and function pointers - let fun = ctx - .resolver - .get_identifier_def(*id) - .map_err(|e| format!("{} (at {})", e.iter().next().unwrap(), func.location))?; + let fun = ctx.resolver.get_identifier_def(*id).map_err(|e| { + format!("{} (at {})", e.iter().next().unwrap(), func.location) + })?; return Ok(generator .gen_call(ctx, None, (&signature, fun), params)? .map(Into::into)); } ExprKind::Attribute { value, attr, .. } => { - let Some(val) = generator.gen_expr(ctx, value)? else { - return Ok(None) - }; + let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; let id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) @@ -2264,9 +2319,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let fun_id = { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); - let TopLevelDef::Class { methods, .. } = &*obj_def else { - unreachable!() - }; + let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() }; methods.iter().find(|method| method.0 == *attr).unwrap().2 }; @@ -2276,47 +2329,52 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( && id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() { match val { - ValueEnum::Static(v) => return match v.get_field("_nac3_option".into(), ctx) { - // if is none, raise exception directly - None => { - let err_msg = ctx.gen_string(generator, ""); - let current_fun = ctx - .builder - .get_insert_block() - .unwrap() - .get_parent() - .unwrap(); - let unreachable_block = ctx.ctx.append_basic_block( - current_fun, - "unwrap_none_unreachable" - ); - let exn_block = ctx.ctx.append_basic_block( - current_fun, - "unwrap_none_exception" - ); - ctx.builder.build_unconditional_branch(exn_block).unwrap(); - ctx.builder.position_at_end(exn_block); - ctx.raise_exn( - generator, - "0:UnwrapNoneError", - err_msg, - [None, None, None], - ctx.current_loc - ); - ctx.builder.position_at_end(unreachable_block); - let ptr = ctx - .get_llvm_type(generator, value.custom.unwrap()) - .into_pointer_type() - .const_null(); - Ok(Some(ctx.builder.build_load( - ptr, - "unwrap_none_unreachable_load" - ).map(Into::into).unwrap())) - } - Some(v) => Ok(Some(v)), - }, + ValueEnum::Static(v) => { + return match v.get_field("_nac3_option".into(), ctx) { + // if is none, raise exception directly + None => { + let err_msg = ctx.gen_string(generator, ""); + let current_fun = ctx + .builder + .get_insert_block() + .unwrap() + .get_parent() + .unwrap(); + let unreachable_block = ctx.ctx.append_basic_block( + current_fun, + "unwrap_none_unreachable", + ); + let exn_block = ctx.ctx.append_basic_block( + current_fun, + "unwrap_none_exception", + ); + ctx.builder.build_unconditional_branch(exn_block).unwrap(); + ctx.builder.position_at_end(exn_block); + ctx.raise_exn( + generator, + "0:UnwrapNoneError", + err_msg, + [None, None, None], + ctx.current_loc, + ); + ctx.builder.position_at_end(unreachable_block); + let ptr = ctx + .get_llvm_type(generator, value.custom.unwrap()) + .into_pointer_type() + .const_null(); + Ok(Some( + ctx.builder + .build_load(ptr, "unwrap_none_unreachable_load") + .map(Into::into) + .unwrap(), + )) + } + Some(v) => Ok(Some(v)), + }; + } ValueEnum::Dynamic(BasicValueEnum::PointerValue(ptr)) => { - let not_null = ctx.builder.build_is_not_null(ptr, "unwrap_not_null").unwrap(); + let not_null = + ctx.builder.build_is_not_null(ptr, "unwrap_not_null").unwrap(); ctx.make_assert( generator, not_null, @@ -2325,12 +2383,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( [None, None, None], expr.location, ); - return Ok(Some(ctx.builder.build_load( - ptr, - "unwrap_some_load" - ).map(Into::into).unwrap())) + return Ok(Some( + ctx.builder + .build_load(ptr, "unwrap_some_load") + .map(Into::into) + .unwrap(), + )); } - ValueEnum::Dynamic(_) => unreachable!("option must be static or ptr") + ValueEnum::Dynamic(_) => unreachable!("option must be static or ptr"), } } @@ -2353,34 +2413,35 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( match &*ctx.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TList { ty } => { let v = if let Some(v) = generator.gen_expr(ctx, value)? { - v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value() + v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? + .into_pointer_value() } else { - return Ok(None) + return Ok(None); }; let v = ListValue::from_ptr_val(v, usize, Some("arr")); let ty = ctx.get_llvm_type(generator, *ty); if let ExprKind::Slice { lower, upper, step } = &slice.node { let one = int32.const_int(1, false); let Some((start, end, step)) = handle_slice_indices( - lower, - upper, - step, - ctx, - generator, - v.load_size(ctx, None), - )? else { return Ok(None) }; + lower, + upper, + step, + ctx, + generator, + v.load_size(ctx, None), + )? + else { + return Ok(None); + }; let length = calculate_len_for_slice_range( generator, ctx, start, ctx.builder .build_select( - ctx.builder.build_int_compare( - IntPredicate::SLT, - step, - zero, - "is_neg", - ).unwrap(), + ctx.builder + .build_int_compare(IntPredicate::SLT, step, zero, "is_neg") + .unwrap(), ctx.builder.build_int_sub(end, one, "e_min_one").unwrap(), ctx.builder.build_int_add(end, one, "e_add_one").unwrap(), "final_e", @@ -2397,7 +2458,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ctx, generator, res_array_ret.load_size(ctx, None), - )? else { return Ok(None) }; + )? + else { + return Ok(None); + }; list_slice_assignment( generator, ctx, @@ -2411,23 +2475,27 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { let len = v.load_size(ctx, Some("len")); let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { - v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() + v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())? + .into_int_value() } else { - return Ok(None) + return Ok(None); }; - let raw_index = ctx.builder.build_int_s_extend( - raw_index, - generator.get_size_type(ctx.ctx), - "sext", - ).unwrap(); + let raw_index = ctx + .builder + .build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext") + .unwrap(); // handle negative index - let is_negative = ctx.builder.build_int_compare( - IntPredicate::SLT, - raw_index, - generator.get_size_type(ctx.ctx).const_zero(), - "is_neg", - ).unwrap(); - let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted").unwrap(); + let is_negative = ctx + .builder + .build_int_compare( + IntPredicate::SLT, + raw_index, + generator.get_size_type(ctx.ctx).const_zero(), + "is_neg", + ) + .unwrap(); + let adjusted = + ctx.builder.build_int_add(raw_index, len, "adjusted").unwrap(); let index = ctx .builder .build_select(is_negative, adjusted, raw_index, "index") @@ -2435,12 +2503,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .unwrap(); // unsigned less than is enough, because negative index after adjustment is // bigger than the length (for unsigned cmp) - let bound_check = ctx.builder.build_int_compare( - IntPredicate::ULT, - index, - len, - "inbound", - ).unwrap(); + let bound_check = ctx + .builder + .build_int_compare(IntPredicate::ULT, index, len, "inbound") + .unwrap(); ctx.make_assert( generator, bound_check, @@ -2453,26 +2519,17 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } } TypeEnum::TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (ty, ndims) = params.iter() - .map(|(_, ty)| ty) - .collect_tuple() - .unwrap(); + let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); let v = if let Some(v) = generator.gen_expr(ctx, value)? { - v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value() + v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? + .into_pointer_value() } else { - return Ok(None) + return Ok(None); }; let v = NDArrayValue::from_ptr_val(v, usize, None); - return gen_ndarray_subscript_expr( - generator, - ctx, - *ty, - *ndims, - v, - slice, - ) + return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); } TypeEnum::TTuple { .. } => { let index: u32 = @@ -2493,7 +2550,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let tup = v .to_basic_value_enum(ctx, generator, value.custom.unwrap())? .into_struct_value(); - ctx.builder.build_extract_value(tup, index, "tup_elem").unwrap().into() + ctx.builder + .build_extract_value(tup, index, "tup_elem") + .unwrap() + .into() } } None => return Ok(None), @@ -2501,12 +2561,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } _ => unreachable!("should not be other subscriptable types after type check"), } - }, + } ExprKind::ListComp { .. } => { if let Some(v) = gen_comprehension(generator, ctx, expr)? { v.into() } else { - return Ok(None) + return Ok(None); } } _ => unimplemented!(), diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index e29b8a17..c22d69d9 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -21,7 +21,7 @@ pub fn call_tan<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -53,7 +53,7 @@ pub fn call_asin<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -85,7 +85,7 @@ pub fn call_acos<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -117,7 +117,7 @@ pub fn call_atan<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -149,7 +149,7 @@ pub fn call_sinh<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -181,7 +181,7 @@ pub fn call_cosh<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -213,7 +213,7 @@ pub fn call_tanh<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -245,7 +245,7 @@ pub fn call_asinh<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -277,7 +277,7 @@ pub fn call_acosh<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -309,7 +309,7 @@ pub fn call_atanh<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -341,7 +341,7 @@ pub fn call_expm1<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -373,7 +373,7 @@ pub fn call_cbrt<'ctx>( for attr in ["mustprogress", "nofree", "nosync", "nounwind", "readonly", "willreturn"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -404,7 +404,7 @@ pub fn call_erf<'ctx>( let func = ctx.module.add_function(FN_NAME, fn_type, None); func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0), ); func @@ -434,7 +434,7 @@ pub fn call_erfc<'ctx>( let func = ctx.module.add_function(FN_NAME, fn_type, None); func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0), ); func @@ -465,7 +465,7 @@ pub fn call_j1<'ctx>( let func = ctx.module.add_function(FN_NAME, fn_type, None); func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0), ); func @@ -498,7 +498,7 @@ pub fn call_atan2<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -533,7 +533,7 @@ pub fn call_ldexp<'ctx>( for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); } @@ -566,7 +566,7 @@ pub fn call_hypot<'ctx>( let func = ctx.module.add_function(FN_NAME, fn_type, None); func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0), ); func @@ -598,7 +598,7 @@ pub fn call_nextafter<'ctx>( let func = ctx.module.add_function(FN_NAME, fn_type, None); func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0), ); func @@ -610,4 +610,4 @@ pub fn call_nextafter<'ctx>( .map(|v| v.map_left(BasicValueEnum::into_float_value)) .map(Either::unwrap_left) .unwrap() -} \ No newline at end of file +} diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index 3b023e35..bb822f19 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -1,5 +1,5 @@ use crate::{ - codegen::{classes::ArraySliceValue, expr::*, stmt::*, bool_to_i1, bool_to_i8, CodeGenContext}, + codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext}, symbol_resolver::ValueEnum, toplevel::{DefinitionId, TopLevelDef}, typecheck::typedef::{FunSignature, Type}, @@ -210,7 +210,7 @@ pub trait CodeGenerator { fn bool_to_i1<'ctx>( &self, ctx: &CodeGenContext<'ctx, '_>, - bool_value: IntValue<'ctx> + bool_value: IntValue<'ctx>, ) -> IntValue<'ctx> { bool_to_i1(&ctx.builder, bool_value) } @@ -219,7 +219,7 @@ pub trait CodeGenerator { fn bool_to_i8<'ctx>( &self, ctx: &CodeGenContext<'ctx, '_>, - bool_value: IntValue<'ctx> + bool_value: IntValue<'ctx>, ) -> IntValue<'ctx> { bool_to_i8(&ctx.builder, ctx.ctx, bool_value) } @@ -239,7 +239,6 @@ impl DefaultCodeGenerator { } impl CodeGenerator for DefaultCodeGenerator { - /// Returns the name for this [`CodeGenerator`]. fn get_name(&self) -> &str { &self.name diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 500a3e6d..8677f085 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -2,18 +2,13 @@ use crate::typecheck::typedef::Type; use super::{ classes::{ - ArrayLikeIndexer, - ArrayLikeValue, - ArraySliceValue, - ListValue, - NDArrayValue, - TypedArrayLikeAdapter, - UntypedArrayLikeAccessor, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, + TypedArrayLikeAdapter, UntypedArrayLikeAccessor, }, - CodeGenContext, - CodeGenerator, - llvm_intrinsics, + llvm_intrinsics, CodeGenContext, CodeGenerator, }; +use crate::codegen::classes::TypedArrayLikeAccessor; +use crate::codegen::stmt::gen_for_callback_incrementing; use inkwell::{ attributes::{Attribute, AttributeLoc}, context::Context, @@ -25,8 +20,6 @@ use inkwell::{ }; use itertools::Either; use nac3parser::ast::Expr; -use crate::codegen::classes::TypedArrayLikeAccessor; -use crate::codegen::stmt::gen_for_callback_incrementing; #[must_use] pub fn load_irrt(ctx: &Context) -> Module { @@ -70,12 +63,15 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>( ctx.module.add_function(symbol, fn_type, None) }); // throw exception when exp < 0 - let ge_zero = ctx.builder.build_int_compare( - IntPredicate::SGE, - exp, - exp.get_type().const_zero(), - "assert_int_pow_ge_0", - ).unwrap(); + let ge_zero = ctx + .builder + .build_int_compare( + IntPredicate::SGE, + exp, + exp.get_type().const_zero(), + "assert_int_pow_ge_0", + ) + .unwrap(); ctx.make_assert( generator, ge_zero, @@ -107,12 +103,10 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( }); // assert step != 0, throw exception if not - let not_zero = ctx.builder.build_int_compare( - IntPredicate::NE, - step, - step.get_type().const_zero(), - "range_step_ne", - ).unwrap(); + let not_zero = ctx + .builder + .build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne") + .unwrap(); ctx.make_assert( generator, not_zero, @@ -208,15 +202,18 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( let step = if let Some(v) = generator.gen_expr(ctx, step)? { v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value() } else { - return Ok(None) + return Ok(None); }; // assert step != 0, throw exception if not - let not_zero = ctx.builder.build_int_compare( - IntPredicate::NE, - step, - step.get_type().const_zero(), - "range_step_ne", - ).unwrap(); + let not_zero = ctx + .builder + .build_int_compare( + IntPredicate::NE, + step, + step.get_type().const_zero(), + "range_step_ne", + ) + .unwrap(); ctx.make_assert( generator, not_zero, @@ -226,25 +223,32 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( ctx.current_loc, ); let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap(); - let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg").unwrap(); + let neg = ctx + .builder + .build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg") + .unwrap(); ( match s { Some(s) => { let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else { - return Ok(None) + return Ok(None); }; ctx.builder .build_select( - ctx.builder.build_and( - ctx.builder.build_int_compare( - IntPredicate::EQ, - s, - length, - "s_eq_len", - ).unwrap(), - neg, - "should_minus_one", - ).unwrap(), + ctx.builder + .build_and( + ctx.builder + .build_int_compare( + IntPredicate::EQ, + s, + length, + "s_eq_len", + ) + .unwrap(), + neg, + "should_minus_one", + ) + .unwrap(), ctx.builder.build_int_sub(s, one, "s_min").unwrap(), s, "final_start", @@ -252,14 +256,16 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( .map(BasicValueEnum::into_int_value) .unwrap() } - None => ctx.builder.build_select(neg, len_id, zero, "stt") + None => ctx + .builder + .build_select(neg, len_id, zero, "stt") .map(BasicValueEnum::into_int_value) .unwrap(), }, match e { Some(e) => { let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else { - return Ok(None) + return Ok(None); }; ctx.builder .build_select( @@ -271,7 +277,9 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( .map(BasicValueEnum::into_int_value) .unwrap() } - None => ctx.builder.build_select(neg, zero, len_id, "end") + None => ctx + .builder + .build_select(neg, zero, len_id, "end") .map(BasicValueEnum::into_int_value) .unwrap(), }, @@ -299,15 +307,16 @@ pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>( let i = if let Some(v) = generator.gen_expr(ctx, i)? { v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? } else { - return Ok(None) + return Ok(None); }; - Ok(Some(ctx - .builder - .build_call(func, &[i.into(), length.into()], "bounded_ind") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap())) + Ok(Some( + ctx.builder + .build_call(func, &[i.into(), length.into()], "bounded_ind") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(), + )) } /// This function handles 'end' **inclusively**. @@ -349,47 +358,33 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( let zero = int32.const_zero(); let one = int32.const_int(1, false); let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); - let dest_arr_ptr = ctx.builder.build_pointer_cast( - dest_arr_ptr, - elem_ptr_type, - "dest_arr_ptr_cast", - ).unwrap(); + let dest_arr_ptr = + ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap(); let dest_len = dest_arr.load_size(ctx, Some("dest.len")); let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); - let src_arr_ptr = ctx.builder.build_pointer_cast( - src_arr_ptr, - elem_ptr_type, - "src_arr_ptr_cast", - ).unwrap(); + let src_arr_ptr = + ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap(); let src_len = src_arr.load_size(ctx, Some("src.len")); let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); // index in bound and positive should be done // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and // throw exception if not satisfied - let src_end = ctx.builder + let src_end = ctx + .builder .build_select( - ctx.builder.build_int_compare( - IntPredicate::SLT, - src_idx.2, - zero, - "is_neg", - ).unwrap(), + ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(), ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(), ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(), "final_e", ) .map(BasicValueEnum::into_int_value) .unwrap(); - let dest_end = ctx.builder + let dest_end = ctx + .builder .build_select( - ctx.builder.build_int_compare( - IntPredicate::SLT, - dest_idx.2, - zero, - "is_neg", - ).unwrap(), + ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(), ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(), ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(), "final_e", @@ -400,24 +395,23 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2); let dest_slice_len = calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2); - let src_eq_dest = ctx.builder.build_int_compare( - IntPredicate::EQ, - src_slice_len, - dest_slice_len, - "slice_src_eq_dest", - ).unwrap(); - let src_slt_dest = ctx.builder.build_int_compare( - IntPredicate::SLT, - src_slice_len, - dest_slice_len, - "slice_src_slt_dest", - ).unwrap(); - let dest_step_eq_one = ctx.builder.build_int_compare( - IntPredicate::EQ, - dest_idx.2, - dest_idx.2.get_type().const_int(1, false), - "slice_dest_step_eq_one", - ).unwrap(); + let src_eq_dest = ctx + .builder + .build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest") + .unwrap(); + let src_slt_dest = ctx + .builder + .build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest") + .unwrap(); + let dest_step_eq_one = ctx + .builder + .build_int_compare( + IntPredicate::EQ, + dest_idx.2, + dest_idx.2.get_type().const_int(1, false), + "slice_dest_step_eq_one", + ) + .unwrap(); let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap(); let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap(); ctx.make_assert( @@ -461,17 +455,14 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( .unwrap() }; // update length - let need_update = ctx.builder - .build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update") - .unwrap(); + let need_update = + ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let update_bb = ctx.ctx.append_basic_block(current, "update"); let cont_bb = ctx.ctx.append_basic_block(current, "cont"); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); ctx.builder.position_at_end(update_bb); - let new_len = ctx.builder - .build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len") - .unwrap(); + let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap(); dest_arr.store_size(ctx, generator, new_len); ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.position_at_end(cont_bb); @@ -488,7 +479,8 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( ctx.module.add_function("__nac3_isinf", fn_type, None) }); - let ret = ctx.builder + let ret = ctx + .builder .build_call(intrinsic_fn, &[v.into()], "isinf") .map(CallSiteValue::try_as_basic_value) .map(|v| v.map_left(BasicValueEnum::into_int_value)) @@ -509,7 +501,8 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( ctx.module.add_function("__nac3_isnan", fn_type, None) }); - let ret = ctx.builder + let ret = ctx + .builder .build_call(intrinsic_fn, &[v.into()], "isnan") .map(CallSiteValue::try_as_basic_value) .map(|v| v.map_left(BasicValueEnum::into_int_value)) @@ -520,10 +513,7 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( } /// Generates a call to `gamma` in IR. Returns an `f64` representing the result. -pub fn call_gamma<'ctx>( - ctx: &CodeGenContext<'ctx, '_>, - v: FloatValue<'ctx>, -) -> FloatValue<'ctx> { +pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { @@ -540,10 +530,7 @@ pub fn call_gamma<'ctx>( } /// Generates a call to `gammaln` in IR. Returns an `f64` representing the result. -pub fn call_gammaln<'ctx>( - ctx: &CodeGenContext<'ctx, '_>, - v: FloatValue<'ctx>, -) -> FloatValue<'ctx> { +pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { @@ -560,10 +547,7 @@ pub fn call_gammaln<'ctx>( } /// Generates a call to `j0` in IR. Returns an `f64` representing the result. -pub fn call_j0<'ctx>( - ctx: &CodeGenContext<'ctx, '_>, - v: FloatValue<'ctx>, -) -> FloatValue<'ctx> { +pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { @@ -583,7 +567,7 @@ pub fn call_j0<'ctx>( /// calculated total size. /// /// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. -/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, +/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, /// or [`None`] if starting from the first dimension and ending at the last dimension respectively. pub fn call_ndarray_calc_size<'ctx, G, Dims>( generator: &G, @@ -591,9 +575,10 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>( dims: &Dims, (begin, end): (Option>, Option>), ) -> IntValue<'ctx> - where - G: CodeGenerator + ?Sized, - Dims: ArrayLikeIndexer<'ctx>, { +where + G: CodeGenerator + ?Sized, + Dims: ArrayLikeIndexer<'ctx>, +{ let llvm_i64 = ctx.ctx.i64_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -602,19 +587,14 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>( let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_calc_size", 64 => "__nac3_ndarray_calc_size64", - bw => unreachable!("Unsupported size type bit width: {}", bw) + bw => unreachable!("Unsupported size type bit width: {}", bw), }; let ndarray_calc_size_fn_t = llvm_usize.fn_type( - &[ - llvm_pi64.into(), - llvm_usize.into(), - llvm_usize.into(), - llvm_usize.into(), - ], + &[llvm_pi64.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], false, ); - let ndarray_calc_size_fn = ctx.module.get_function(ndarray_calc_size_fn_name) - .unwrap_or_else(|| { + let ndarray_calc_size_fn = + ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) }); @@ -658,30 +638,22 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_calc_nd_indices", 64 => "__nac3_ndarray_calc_nd_indices64", - bw => unreachable!("Unsupported size type bit width: {}", bw) + bw => unreachable!("Unsupported size type bit width: {}", bw), }; - let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[ - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pi32.into(), - ], - false, - ); + let ndarray_calc_nd_indices_fn = + ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { + let fn_type = llvm_void.fn_type( + &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], + false, + ); - ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) - }); + ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) + }); let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_dims = ndarray.dim_sizes(); - let indices = ctx.builder.build_array_alloca( - llvm_i32, - ndarray_num_dims, - "", - ).unwrap(); + let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); ctx.builder .build_call( @@ -709,9 +681,10 @@ fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( ndarray: NDArrayValue<'ctx>, indices: &Indices, ) -> IntValue<'ctx> - where - G: CodeGenerator + ?Sized, - Indices: ArrayLikeIndexer<'ctx>, { +where + G: CodeGenerator + ?Sized, + Indices: ArrayLikeIndexer<'ctx>, +{ let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -734,26 +707,23 @@ fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_flatten_index", 64 => "__nac3_ndarray_flatten_index64", - bw => unreachable!("Unsupported size type bit width: {}", bw) + bw => unreachable!("Unsupported size type bit width: {}", bw), }; - let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pi32.into(), - llvm_usize.into(), - ], - false, - ); + let ndarray_flatten_index_fn = + ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], + false, + ); - ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) - }); + ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) + }); let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_dims = ndarray.dim_sizes(); - let index = ctx.builder + let index = ctx + .builder .build_call( ndarray_flatten_index_fn, &[ @@ -784,16 +754,11 @@ pub fn call_ndarray_flatten_index<'ctx, G, Index>( ndarray: NDArrayValue<'ctx>, indices: &Index, ) -> IntValue<'ctx> - where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, { - - call_ndarray_flatten_index_impl( - generator, - ctx, - ndarray, - indices, - ) +where + G: CodeGenerator + ?Sized, + Index: ArrayLikeIndexer<'ctx>, +{ + call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) } /// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of @@ -810,22 +775,23 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_calc_broadcast", 64 => "__nac3_ndarray_calc_broadcast64", - bw => unreachable!("Unsupported size type bit width: {}", bw) + bw => unreachable!("Unsupported size type bit width: {}", bw), }; - let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - ], - false, - ); + let ndarray_calc_broadcast_fn = + ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[ + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + ], + false, + ); - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); let lhs_ndims = lhs.load_ndims(ctx); let rhs_ndims = rhs.load_ndims(ctx); @@ -846,36 +812,22 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( }; let llvm_usize_const_one = llvm_usize.const_int(1, false); - let lhs_eqz = ctx.builder.build_int_compare( - IntPredicate::EQ, - lhs_dim_sz, - llvm_usize_const_one, - "", - ).unwrap(); - let rhs_eqz = ctx.builder.build_int_compare( - IntPredicate::EQ, - rhs_dim_sz, - llvm_usize_const_one, - "", - ).unwrap(); - let lhs_or_rhs_eqz = ctx.builder.build_or( - lhs_eqz, - rhs_eqz, - "" - ).unwrap(); + let lhs_eqz = ctx + .builder + .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") + .unwrap(); + let rhs_eqz = ctx + .builder + .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") + .unwrap(); + let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); - let lhs_eq_rhs = ctx.builder.build_int_compare( - IntPredicate::EQ, - lhs_dim_sz, - rhs_dim_sz, - "" - ).unwrap(); + let lhs_eq_rhs = ctx + .builder + .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") + .unwrap(); - let is_compatible = ctx.builder.build_or( - lhs_or_rhs_eqz, - lhs_eq_rhs, - "" - ).unwrap(); + let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); ctx.make_assert( generator, @@ -889,7 +841,8 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( Ok(()) }, llvm_usize.const_int(1, false), - ).unwrap(); + ) + .unwrap(); let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); @@ -923,7 +876,11 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( /// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] /// containing the indices used for accessing `array` corresponding to the index of the broadcasted /// array `broadcast_idx`. -pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>( +pub fn call_ndarray_calc_broadcast_index< + 'ctx, + G: CodeGenerator + ?Sized, + BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, +>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, array: NDArrayValue<'ctx>, @@ -937,21 +894,17 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_calc_broadcast_idx", 64 => "__nac3_ndarray_calc_broadcast_idx64", - bw => unreachable!("Unsupported size type bit width: {}", bw) + bw => unreachable!("Unsupported size type bit width: {}", bw), }; - let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pi32.into(), - llvm_pi32.into(), - ], - false, - ); + let ndarray_calc_broadcast_fn = + ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], + false, + ); - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); let broadcast_size = broadcast_idx.size(ctx, generator); let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); @@ -959,23 +912,13 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc let array_dims = array.dim_sizes().base_ptr(ctx, generator); let array_ndims = array.load_ndims(ctx); let broadcast_idx_ptr = unsafe { - broadcast_idx.ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None - ) + broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; ctx.builder .build_call( ndarray_calc_broadcast_fn, - &[ - array_dims.into(), - array_ndims.into(), - broadcast_idx_ptr.into(), - out_idx.into(), - ], + &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], "", ) .unwrap(); @@ -985,4 +928,4 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc Box::new(|_, v| v.into_int_value()), Box::new(|_, v| v.into()), ) -} \ No newline at end of file +} diff --git a/nac3core/src/codegen/llvm_intrinsics.rs b/nac3core/src/codegen/llvm_intrinsics.rs index 7e7dbe24..46b5aeea 100644 --- a/nac3core/src/codegen/llvm_intrinsics.rs +++ b/nac3core/src/codegen/llvm_intrinsics.rs @@ -1,35 +1,35 @@ -use inkwell::AddressSpace; +use crate::codegen::CodeGenContext; use inkwell::context::Context; use inkwell::intrinsics::Intrinsic; use inkwell::types::AnyTypeEnum::IntType; use inkwell::types::FloatType; use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue}; +use inkwell::AddressSpace; use itertools::Either; -use crate::codegen::CodeGenContext; /// Returns the string representation for the floating-point type `ft` when used in intrinsic /// functions. fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str { // Standard LLVM floating-point types if ft == ctx.f16_type() { - return "f16" + return "f16"; } if ft == ctx.f32_type() { - return "f32" + return "f32"; } if ft == ctx.f64_type() { - return "f64" + return "f64"; } if ft == ctx.f128_type() { - return "f128" + return "f128"; } // Non-standard floating-point types if ft == ctx.x86_f80_type() { - return "f80" + return "f80"; } if ft == ctx.ppc_f128_type() { - return "ppcf128" + return "ppcf128"; } unreachable!() @@ -69,9 +69,7 @@ pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue .and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()])) .unwrap(); - ctx.builder - .build_call(intrinsic_fn, &[ptr.into()], "") - .unwrap(); + ctx.builder.build_call(intrinsic_fn, &[ptr.into()], "").unwrap(); } /// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic. @@ -232,10 +230,12 @@ pub fn call_memcpy<'ctx>( let llvm_len_t = len.get_type(); let intrinsic_fn = Intrinsic::find(FN_NAME) - .and_then(|intrinsic| intrinsic.get_declaration( - &ctx.module, - &[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()], - )) + .and_then(|intrinsic| { + intrinsic.get_declaration( + &ctx.module, + &[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()], + ) + }) .unwrap(); ctx.builder @@ -315,10 +315,9 @@ pub fn call_float_powi<'ctx>( let llvm_power_t = power.get_type(); let intrinsic_fn = Intrinsic::find(FN_NAME) - .and_then(|intrinsic| intrinsic.get_declaration( - &ctx.module, - &[llvm_val_t.into(), llvm_power_t.into()], - )) + .and_then(|intrinsic| { + intrinsic.get_declaration(&ctx.module, &[llvm_val_t.into(), llvm_power_t.into()]) + }) .unwrap(); ctx.builder @@ -442,7 +441,6 @@ pub fn call_float_exp2<'ctx>( .unwrap() } - /// Invokes the [`llvm.log`](https://llvm.org/docs/LangRef.html#llvm-log-intrinsic) intrinsic. pub fn call_float_log<'ctx>( ctx: &CodeGenContext<'ctx, '_>, @@ -672,7 +670,7 @@ pub fn call_float_round<'ctx>( .unwrap() } -/// Invokes the +/// Invokes the /// [`llvm.roundeven`](https://llvm.org/docs/LangRef.html#llvm-roundeven-intrinsic) intrinsic. pub fn call_float_roundeven<'ctx>( ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 31572566..de8a3346 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -2,10 +2,7 @@ use crate::{ codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, symbol_resolver::{StaticValue, SymbolResolver}, toplevel::{ - helper::PRIMITIVE_DEF_IDS, - numpy::unpack_ndarray_var_tys, - TopLevelContext, - TopLevelDef, + helper::PRIMITIVE_DEF_IDS, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef, }, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, @@ -14,24 +11,22 @@ use crate::{ }; use crossbeam::channel::{unbounded, Receiver, Sender}; use inkwell::{ - AddressSpace, - IntPredicate, - OptimizationLevel, attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, builder::Builder, context::Context, + debug_info::{ + AsDIScope, DICompileUnit, DIFlagsConstants, DIScope, DISubprogram, DebugInfoBuilder, + }, module::Module, passes::PassBuilderOptions, targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple}, types::{AnyType, BasicType, BasicTypeEnum}, values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, - debug_info::{ - DebugInfoBuilder, DICompileUnit, DISubprogram, AsDIScope, DIFlagsConstants, DIScope - }, + AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::Itertools; -use nac3parser::ast::{Stmt, StrRef, Location}; +use nac3parser::ast::{Location, Stmt, StrRef}; use parking_lot::{Condvar, Mutex}; use std::collections::{HashMap, HashSet}; use std::sync::{ @@ -91,7 +86,6 @@ pub struct CodeGenTargetMachineOptions { } impl CodeGenTargetMachineOptions { - /// Creates an instance of [`CodeGenTargetMachineOptions`] using the triple of the host machine. /// Other options are set to defaults. #[must_use] @@ -120,13 +114,11 @@ impl CodeGenTargetMachineOptions { /// /// See [`Target::create_target_machine`]. #[must_use] - pub fn create_target_machine( - &self, - level: OptimizationLevel, - ) -> Option { + pub fn create_target_machine(&self, level: OptimizationLevel) -> Option { let triple = TargetTriple::create(self.triple.as_str()); - let target = Target::from_triple(&triple) - .unwrap_or_else(|_| panic!("could not create target from target triple {}", self.triple)); + let target = Target::from_triple(&triple).unwrap_or_else(|_| { + panic!("could not create target from target triple {}", self.triple) + }); target.create_target_machine( &triple, @@ -134,7 +126,7 @@ impl CodeGenTargetMachineOptions { self.features.as_str(), level, self.reloc_mode, - self.code_model + self.code_model, ) } } @@ -205,7 +197,6 @@ pub struct CodeGenContext<'ctx, 'a> { } impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { - /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// contains a [terminator statement][BasicBlock::get_terminator]. pub fn is_terminated(&self) -> bool { @@ -251,7 +242,6 @@ pub struct WorkerRegistry { } impl WorkerRegistry { - /// Creates workers for this registry. #[must_use] pub fn create_workers( @@ -373,7 +363,11 @@ impl WorkerRegistry { *self.task_count.lock() -= 1; self.wait_condvar.notify_all(); } - assert!(errors.is_empty(), "Codegen error: {}", errors.into_iter().sorted().join("\n----------\n")); + assert!( + errors.is_empty(), + "Codegen error: {}", + errors.into_iter().sorted().join("\n----------\n") + ); let result = module.verify(); if let Err(err) = result { @@ -386,13 +380,20 @@ impl WorkerRegistry { .llvm_options .target .create_target_machine(self.llvm_options.opt_level) - .unwrap_or_else(|| panic!("could not create target machine from properties {:?}", self.llvm_options.target)); + .unwrap_or_else(|| { + panic!( + "could not create target machine from properties {:?}", + self.llvm_options.target + ) + }); let passes = format!("default", self.llvm_options.opt_level as u32); let result = module.run_passes(passes.as_str(), &target_machine, pass_options); if let Err(err) = result { - panic!("Failed to run optimization for module `{}`: {}", - module.get_name().to_str().unwrap(), - err.to_string()); + panic!( + "Failed to run optimization for module `{}`: {}", + module.get_name().to_str().unwrap(), + err.to_string() + ); } f.run(&module); @@ -455,20 +456,17 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let element_type = get_llvm_type( - ctx, - module, - generator, - unifier, - top_level, - type_cache, - dtype, + ctx, module, generator, unifier, top_level, type_cache, dtype, ); NDArrayType::new(generator, ctx, element_type).as_base_type().into() } - _ => unreachable!("LLVM type for primitive {} is missing", unifier.stringify(ty)), - } + _ => unreachable!( + "LLVM type for primitive {} is missing", + unifier.stringify(ty) + ), + }; } // a struct with fields in the order of declaration let top_level_defs = top_level.definitions.read(); @@ -484,7 +482,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( let struct_type = ctx.opaque_struct_type(&name); type_cache.insert( unifier.get_representative(ty), - struct_type.ptr_type(AddressSpace::default()).into() + struct_type.ptr_type(AddressSpace::default()).into(), ); let fields = fields_list .iter() @@ -503,24 +501,21 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( struct_type.set_body(&fields, false); struct_type.ptr_type(AddressSpace::default()).into() }; - return ty + return ty; } TTuple { ty } => { // a struct with fields in the order present in the tuple let fields = ty .iter() .map(|ty| { - get_llvm_type( - ctx, module, generator, unifier, top_level, type_cache, *ty, - ) + get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); ctx.struct_type(&fields, false).into() } TList { ty } => { - let element_type = get_llvm_type( - ctx, module, generator, unifier, top_level, type_cache, *ty, - ); + let element_type = + get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty); ListType::new(generator, ctx, element_type).as_base_type().into() } @@ -558,7 +553,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>( ctx.bool_type().into() } else { get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty) - } + }; } /// Whether `sret` is needed for a return value with type `ty`. @@ -574,8 +569,9 @@ fn need_sret(ty: BasicTypeEnum) -> bool { match ty { BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false, BasicTypeEnum::FloatType(_) if maybe_large => false, - BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => - ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false)), + BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => { + ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false)) + } _ => true, } } @@ -583,14 +579,18 @@ fn need_sret(ty: BasicTypeEnum) -> bool { } /// Implementation for generating LLVM IR for a function. -pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>> ( +pub fn gen_func_impl< + 'ctx, + G: CodeGenerator, + F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>, +>( context: &'ctx Context, generator: &mut G, registry: &WorkerRegistry, builder: Builder<'ctx>, module: Module<'ctx>, task: CodeGenTask, - codegen_function: F + codegen_function: F, ) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> { let top_level_ctx = registry.top_level_ctx.clone(); let static_value_store = registry.static_value_store.clone(); @@ -654,7 +654,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte str_type.set_body(&fields, false); str_type.into() } - Some(t) => t.as_basic_type_enum() + Some(t) => t.as_basic_type_enum(), } }), (primitives.range, RangeType::new(context).as_base_type().into()), @@ -671,7 +671,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte exception.set_body(&fields, false); exception.ptr_type(AddressSpace::default()).as_basic_type_enum() } - }) + }), ] .iter() .copied() @@ -679,8 +679,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte // NOTE: special handling of option cannot use this type cache since it contains type var, // handled inside get_llvm_type instead - let ConcreteTypeEnum::TFunc { args, ret, .. } = - task.store.get(task.signature) else { + let ConcreteTypeEnum::TFunc { args, ret, .. } = task.store.get(task.signature) else { unreachable!() }; @@ -697,7 +696,16 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte let ret_type = if unifier.unioned(ret, primitives.none) { None } else { - Some(get_llvm_abi_type(context, &module, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, &primitives, ret)) + Some(get_llvm_abi_type( + context, + &module, + generator, + &mut unifier, + top_level_ctx.as_ref(), + &mut type_cache, + &primitives, + ret, + )) }; let has_sret = ret_type.map_or(false, |ty| need_sret(ty)); @@ -724,7 +732,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte let fn_type = match ret_type { Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, false), - _ => context.void_type().fn_type(¶ms, false) + _ => context.void_type().fn_type(¶ms, false), }; let symbol = &task.symbol_name; @@ -739,9 +747,13 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte fn_val.set_personality_function(personality); } if has_sret { - fn_val.add_attribute(AttributeLoc::Param(0), - context.create_type_attribute(Attribute::get_named_enum_kind_id("sret"), - ret_type.unwrap().as_any_type_enum())); + fn_val.add_attribute( + AttributeLoc::Param(0), + context.create_type_attribute( + Attribute::get_named_enum_kind_id("sret"), + ret_type.unwrap().as_any_type_enum(), + ), + ); } let init_bb = context.append_basic_block(fn_val, "init"); @@ -761,9 +773,8 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte &mut type_cache, arg.ty, ); - let alloca = builder - .build_alloca(local_type, &format!("{}.addr", &arg.name.to_string())) - .unwrap(); + let alloca = + builder.build_alloca(local_type, &format!("{}.addr", &arg.name.to_string())).unwrap(); // Remap boolean parameters into i8 let param = if local_type.is_int_type() && param.is_int_value() { @@ -774,7 +785,8 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte bool_to_i8(&builder, context, param_val) } else { param_val - }.into() + } + .into() } else { param }; @@ -808,10 +820,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte &task .body .first() - .map_or_else( - || "".to_string(), - |f| f.location.file.0.to_string(), - ), + .map_or_else(|| "".to_string(), |f| f.location.file.0.to_string()), /* directory */ "", /* producer */ "NAC3", /* is_optimized */ registry.llvm_options.opt_level != OptimizationLevel::None, @@ -884,10 +893,10 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte row as u32, col as u32, func_scope.as_debug_info_scope(), - None + None, ); code_gen_context.builder.set_current_debug_location(loc); - + let result = codegen_function(generator, &mut code_gen_context); // after static analysis, only void functions can have no return at the end. @@ -949,7 +958,7 @@ fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntV fn bool_to_i8<'ctx>( builder: &Builder<'ctx>, ctx: &'ctx Context, - bool_value: IntValue<'ctx> + bool_value: IntValue<'ctx>, ) -> IntValue<'ctx> { let value_bits = bool_value.get_type().get_bit_width(); match value_bits { @@ -965,7 +974,7 @@ fn bool_to_i8<'ctx>( bool_value.get_type().const_zero(), "", ) - .unwrap() + .unwrap(), ), } } @@ -991,11 +1000,18 @@ fn gen_in_range_check<'ctx>( stop: IntValue<'ctx>, step: IntValue<'ctx>, ) -> IntValue<'ctx> { - let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "").unwrap(); - let lo = ctx.builder.build_select(sign, value, stop, "") + let sign = ctx + .builder + .build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "") + .unwrap(); + let lo = ctx + .builder + .build_select(sign, value, stop, "") .map(BasicValueEnum::into_int_value) .unwrap(); - let hi = ctx.builder.build_select(sign, stop, value, "") + let hi = ctx + .builder + .build_select(sign, stop, value, "") .map(BasicValueEnum::into_int_value) .unwrap(); diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 1e685a77..30eb3b82 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,45 +1,36 @@ -use inkwell::{AddressSpace, IntPredicate, OptimizationLevel, types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}}; -use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; -use nac3parser::ast::{Operator, StrRef}; use crate::{ codegen::{ classes::{ - ArrayLikeIndexer, - ArrayLikeValue, - ListType, - ListValue, - NDArrayType, - NDArrayValue, - ProxyType, - ProxyValue, - TypedArrayLikeAccessor, - TypedArrayLikeAdapter, - TypedArrayLikeMutator, - UntypedArrayLikeAccessor, - UntypedArrayLikeMutator, + ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue, + ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, - CodeGenContext, - CodeGenerator, expr::gen_binop_expr_with_values, irrt::{ - calculate_len_for_slice_range, - call_ndarray_calc_broadcast, - call_ndarray_calc_broadcast_index, - call_ndarray_calc_nd_indices, + calculate_len_for_slice_range, call_ndarray_calc_broadcast, + call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, }, llvm_intrinsics, - llvm_intrinsics::{call_memcpy_generic}, + llvm_intrinsics::call_memcpy_generic, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, + CodeGenContext, CodeGenerator, }, symbol_resolver::ValueEnum, toplevel::{ - DefinitionId, helper::PRIMITIVE_DEF_IDS, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, + DefinitionId, }, typecheck::typedef::{FunSignature, Type, TypeEnum}, }; +use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; +use inkwell::{ + types::BasicType, + values::{BasicValueEnum, IntValue, PointerValue}, + AddressSpace, IntPredicate, OptimizationLevel, +}; +use nac3parser::ast::{Operator, StrRef}; /// Creates an uninitialized `NDArray` instance. fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( @@ -51,16 +42,13 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_ndarray_t = ctx.get_llvm_type(generator, ndarray_ty) + let llvm_ndarray_t = ctx + .get_llvm_type(generator, ndarray_ty) .into_pointer_type() .get_element_type() .into_struct_type(); - let ndarray = generator.gen_var_alloc( - ctx, - llvm_ndarray_t.into(), - None, - )?; + let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None)) } @@ -79,10 +67,15 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( shape_len_fn: LenFn, shape_data_fn: DataFn, ) -> Result, String> - where - G: CodeGenerator + ?Sized, - LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result, String>, - DataFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result, String>, +where + G: CodeGenerator + ?Sized, + LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result, String>, + DataFn: Fn( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &V, + IntValue<'ctx>, + ) -> Result, String>, { let llvm_usize = generator.get_size_type(ctx.ctx); @@ -97,8 +90,14 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( let shape_dim = shape_data_fn(generator, ctx, shape, i)?; debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - let shape_dim_gez = ctx.builder - .build_int_compare(IntPredicate::SGE, shape_dim, shape_dim.get_type().const_zero(), "") + let shape_dim_gez = ctx + .builder + .build_int_compare( + IntPredicate::SGE, + shape_dim, + shape_dim.get_type().const_zero(), + "", + ) .unwrap(); ctx.make_assert( @@ -109,7 +108,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( [None, None, None], ctx.current_loc, ); - + // TODO: Disallow dim_sz > u32_MAX Ok(()) @@ -135,13 +134,10 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( |generator, ctx, i| { let shape_dim = shape_data_fn(generator, ctx, shape, i)?; debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - let shape_dim = ctx.builder - .build_int_z_extend(shape_dim, llvm_usize, "") - .unwrap(); + let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - let ndarray_pdim = unsafe { - ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) - }; + let ndarray_pdim = + unsafe { ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) }; ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); @@ -168,7 +164,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); for shape_dim in shape { - let shape_dim_gez = ctx.builder + let shape_dim_gez = ctx + .builder .build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "") .unwrap(); @@ -194,9 +191,12 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( for (i, shape_dim) in shape.iter().enumerate() { let ndarray_dim = unsafe { - ndarray - .dim_sizes() - .ptr_offset_unchecked(ctx, generator, &llvm_usize.const_int(i as u64, true), None) + ndarray.dim_sizes().ptr_offset_unchecked( + ctx, + generator, + &llvm_usize.const_int(i as u64, true), + None, + ) }; ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap(); @@ -233,9 +233,15 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, ) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) + { ctx.ctx.i32_type().const_zero().into() - } else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) + { ctx.ctx.i64_type().const_zero().into() } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { ctx.ctx.f64_type().const_zero().into() @@ -253,10 +259,16 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, ) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) + { let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); ctx.ctx.i32_type().const_int(1, is_signed).into() - } else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) + { let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); ctx.ctx.i64_type().const_int(1, is_signed).into() } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { @@ -285,9 +297,7 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>( ctx, elem_ty, &shape, - |_, ctx, shape| { - Ok(shape.load_size(ctx, None)) - }, + |_, ctx, shape| Ok(shape.load_size(ctx, None)), |generator, ctx, shape, idx| { Ok(shape.data().get(ctx, generator, &idx, None).into_int_value()) }, @@ -302,9 +312,13 @@ fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( ndarray: NDArrayValue<'ctx>, value_fn: ValueFn, ) -> Result<(), String> - where - G: CodeGenerator + ?Sized, - ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result, String>, +where + G: CodeGenerator + ?Sized, + ValueFn: Fn( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + IntValue<'ctx>, + ) -> Result, String>, { let llvm_usize = generator.get_size_type(ctx.ctx); @@ -321,9 +335,7 @@ fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( llvm_usize.const_zero(), (ndarray_num_elems, false), |generator, ctx, i| { - let elem = unsafe { - ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) - }; + let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; let value = value_fn(generator, ctx, i)?; ctx.builder.build_store(elem, value).unwrap(); @@ -342,25 +354,19 @@ fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>( ndarray: NDArrayValue<'ctx>, value_fn: ValueFn, ) -> Result<(), String> - where - G: CodeGenerator + ?Sized, - ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>) -> Result, String>, +where + G: CodeGenerator + ?Sized, + ValueFn: Fn( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>, + ) -> Result, String>, { - ndarray_fill_flattened( - generator, - ctx, - ndarray, - |generator, ctx, idx| { - let indices = call_ndarray_calc_nd_indices( - generator, - ctx, - idx, - ndarray, - ); + ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| { + let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray); - value_fn(generator, ctx, &indices) - } - ) + value_fn(generator, ctx, &indices) + }) } fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>( @@ -370,22 +376,19 @@ fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>( dest: NDArrayValue<'ctx>, map_fn: MapFn, ) -> Result<(), String> - where - G: CodeGenerator + ?Sized, - MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, BasicValueEnum<'ctx>) -> Result, String>, +where + G: CodeGenerator + ?Sized, + MapFn: Fn( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, { - ndarray_fill_flattened( - generator, - ctx, - dest, - |generator, ctx, i| { - let elem = unsafe { - src.data().get_unchecked(ctx, generator, &i, None) - }; + ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| { + let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) }; - map_fn(generator, ctx, elem) - }, - ) + map_fn(generator, ctx, elem) + }) } /// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of @@ -419,19 +422,25 @@ fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( rhs: (BasicValueEnum<'ctx>, bool), value_fn: ValueFn, ) -> Result, String> - where - G: CodeGenerator + ?Sized, - ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result, String>, +where + G: CodeGenerator + ?Sized, + ValueFn: Fn( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), + ) -> Result, String>, { let llvm_usize = generator.get_size_type(ctx.ctx); let (lhs_val, lhs_scalar) = lhs; let (rhs_val, rhs_scalar) = rhs; - assert!(!(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type()); + assert!( + !(lhs_scalar && rhs_scalar), + "One of the operands must be a ndarray instance: `{}`, `{}`", + lhs_val.get_type(), + rhs_val.get_type() + ); // Assert that all ndarray operands are broadcastable to the target size if !lhs_scalar { @@ -444,36 +453,27 @@ fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); } - ndarray_fill_indexed( - generator, - ctx, - res, - |generator, ctx, idx| { - let lhs_elem = if lhs_scalar { - lhs_val - } else { - let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); - let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); + ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| { + let lhs_elem = if lhs_scalar { + lhs_val + } else { + let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); + let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); - unsafe { - lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) - } - }; + unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } + }; - let rhs_elem = if rhs_scalar { - rhs_val - } else { - let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); - let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - - unsafe { - rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) - } - }; + let rhs_elem = if rhs_scalar { + rhs_val + } else { + let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); + let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - value_fn(generator, ctx, (lhs_elem, rhs_elem)) - }, - )?; + unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } + }; + + value_fn(generator, ctx, (lhs_elem, rhs_elem)) + })?; Ok(res) } @@ -500,16 +500,11 @@ fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>( assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened( - generator, - ctx, - ndarray, - |generator, ctx, _| { - let value = ndarray_zero_value(generator, ctx, elem_ty); + ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { + let value = ndarray_zero_value(generator, ctx, elem_ty); - Ok(value) - } - )?; + Ok(value) + })?; Ok(ndarray) } @@ -536,16 +531,11 @@ fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>( assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened( - generator, - ctx, - ndarray, - |generator, ctx, _| { - let value = ndarray_one_value(generator, ctx, elem_ty); + ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { + let value = ndarray_one_value(generator, ctx, elem_ty); - Ok(value) - } - )?; + Ok(value) + })?; Ok(ndarray) } @@ -562,34 +552,29 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>( fill_value: BasicValueEnum<'ctx>, ) -> Result, String> { let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened( - generator, - ctx, - ndarray, - |generator, ctx, _| { - let value = if fill_value.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); + ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { + let value = if fill_value.is_pointer_value() { + let llvm_i1 = ctx.ctx.bool_type(); - let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; + let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; - call_memcpy_generic( - ctx, - copy, - fill_value.into_pointer_value(), - fill_value.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); + call_memcpy_generic( + ctx, + copy, + fill_value.into_pointer_value(), + fill_value.get_type().size_of().map(Into::into).unwrap(), + llvm_i1.const_zero(), + ); - copy.into() - } else if fill_value.is_int_value() || fill_value.is_float_value() { - fill_value - } else { - unreachable!() - }; + copy.into() + } else if fill_value.is_int_value() || fill_value.is_float_value() { + fill_value + } else { + unreachable!() + }; - Ok(value) - } - )?; + Ok(value) + })?; Ok(ndarray) } @@ -656,7 +641,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( match list_elem_ty { AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { - // The stride of elements in this dimension, i.e. the number of elements between arr[i] + // The stride of elements in this dimension, i.e. the number of elements between arr[i] // and arr[i + 1] in this dimension let stride = call_ndarray_calc_size( generator, @@ -673,20 +658,14 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( (|_, ctx| Ok(src_lst.load_size(ctx, None)), false), |_, _| Ok(llvm_usize.const_int(1, false)), |generator, ctx, i| { - let offset = ctx.builder.build_int_mul( - stride, - i, - "", - ).unwrap(); + let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); - let dst_ptr = unsafe { - ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() - }; + let dst_ptr = + unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; let nested_lst_elem = ListValue::from_ptr_val( - unsafe { - src_lst.data().get_unchecked(ctx, generator, &i, None) - }.into_pointer_value(), + unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) } + .into_pointer_value(), llvm_usize, None, ); @@ -712,11 +691,14 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( _ => { let lst_len = src_lst.load_size(ctx, None); let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); - let cpy_len = ctx.builder.build_int_mul( - ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(), - sizeof_elem, - "" - ).unwrap(); + let cpy_len = ctx + .builder + .build_int_mul( + ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(), + sizeof_elem, + "", + ) + .unwrap(); call_memcpy_generic( ctx, @@ -743,27 +725,19 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( let llvm_i1 = ctx.ctx.bool_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let ndmin = ctx.builder - .build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "") - .unwrap(); + let ndmin = ctx.builder.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "").unwrap(); // TODO(Derppening): Add assertions for sizes of different dimensions // object is not a pointer - 0-dim NDArray if !object.is_pointer_value() { - let ndarray = create_ndarray_const_shape( - generator, - ctx, - elem_ty, - &[], - )?; + let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[])?; unsafe { - ndarray.data() - .set_unchecked(ctx, generator, &llvm_usize.const_zero(), object); + ndarray.data().set_unchecked(ctx, generator, &llvm_usize.const_zero(), object); } - return Ok(ndarray) + return Ok(ndarray); } let object = object.into_pointer_value(); @@ -776,16 +750,16 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( generator, ctx, |_, ctx| { - let copy_nez = ctx.builder + let copy_nez = ctx + .builder .build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "") .unwrap(); - let ndmin_gt_ndims = ctx.builder + let ndmin_gt_ndims = ctx + .builder .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") .unwrap(); - Ok(ctx.builder - .build_and(copy_nez, ndmin_gt_ndims, "") - .unwrap()) + Ok(ctx.builder.build_and(copy_nez, ndmin_gt_ndims, "").unwrap()) }, |generator, ctx| { let ndarray = create_ndarray_dyn_shape( @@ -795,11 +769,13 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( &object, |_, ctx, object| { let ndims = object.load_ndims(ctx); - let ndmin_gt_ndims = ctx.builder + let ndmin_gt_ndims = ctx + .builder .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") .unwrap(); - Ok(ctx.builder + Ok(ctx + .builder .build_select(ndmin_gt_ndims, ndmin, ndims, "") .map(BasicValueEnum::into_int_value) .unwrap()) @@ -814,21 +790,16 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( generator, ctx, |_, ctx| { - Ok(ctx.builder + Ok(ctx + .builder .build_int_compare(IntPredicate::UGE, idx, offset, "") .unwrap()) }, - |_, _| { - Ok(Some(llvm_usize.const_int(1, false))) - }, - |_, ctx| { - Ok(Some(ctx.builder.build_int_sub( - idx, - offset, - "" - ).unwrap())) - }, - )?.map(BasicValueEnum::into_int_value).unwrap()) + |_, _| Ok(Some(llvm_usize.const_int(1, false))), + |_, ctx| Ok(Some(ctx.builder.build_int_sub(idx, offset, "").unwrap())), + )? + .map(BasicValueEnum::into_int_value) + .unwrap()) }, )?; @@ -844,16 +815,14 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( Ok(Some(ndarray.as_base_value())) }, - |_, _| { - Ok(Some(object.as_base_value())) - }, + |_, _| Ok(Some(object.as_base_value())), )?; return Ok(NDArrayValue::from_ptr_val( ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), llvm_usize, None, - )) + )); } // Remaining case: TList @@ -872,11 +841,11 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( &object, |generator, ctx, object| { let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin_gt_ndims = ctx.builder - .build_int_compare(IntPredicate::UGT, ndmin, ndims, "") - .unwrap(); + let ndmin_gt_ndims = + ctx.builder.build_int_compare(IntPredicate::UGT, ndmin, ndims, "").unwrap(); - Ok(ctx.builder + Ok(ctx + .builder .build_select(ndmin_gt_ndims, ndmin, ndims, "") .map(BasicValueEnum::into_int_value) .unwrap()) @@ -886,20 +855,13 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( generator, ctx, |_, ctx| { - Ok(ctx.builder - .build_int_compare(IntPredicate::ULT, idx, offset, "") - .unwrap()) - }, - |_, _| { - Ok(Some(llvm_usize.const_int(1, false))) + Ok(ctx.builder.build_int_compare(IntPredicate::ULT, idx, offset, "").unwrap()) }, + |_, _| Ok(Some(llvm_usize.const_int(1, false))), |generator, ctx| { let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| { ctx.ctx.struct_type( - &[ - elem_ty.ptr_type(AddressSpace::default()).into(), - llvm_usize.into(), - ], + &[elem_ty.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], false, ) }; @@ -909,19 +871,21 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default()); // Cast list to { i8*, usize } since we only care about the size - let lst = generator.gen_var_alloc( - ctx, - ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(), - None, - ).unwrap(); - ctx.builder.build_store( - lst, - ctx.builder.build_bitcast( - object.as_base_value(), - llvm_plist_i8, - "", - ).unwrap(), - ).unwrap(); + let lst = generator + .gen_var_alloc( + ctx, + ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(), + None, + ) + .unwrap(); + ctx.builder + .build_store( + lst, + ctx.builder + .build_bitcast(object.as_base_value(), llvm_plist_i8, "") + .unwrap(), + ) + .unwrap(); let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap(); gen_for_range_callback( @@ -935,32 +899,32 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) .ptr_type(AddressSpace::default()); - let this_dim = ctx.builder + let this_dim = ctx + .builder .build_load(lst, "") .map(BasicValueEnum::into_pointer_value) .map(|v| ctx.builder.build_bitcast(v, plist_plist_i8, "").unwrap()) .map(BasicValueEnum::into_pointer_value) .unwrap(); - let this_dim = ListValue::from_ptr_val( - this_dim, - llvm_usize, - None, - ); + let this_dim = ListValue::from_ptr_val(this_dim, llvm_usize, None); // TODO: Assert this_dim.sz != 0 let next_dim = unsafe { - this_dim.data() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }.into_pointer_value(); - ctx.builder.build_store( - lst, - ctx.builder.build_bitcast( - next_dim, - llvm_plist_i8, - "", - ).unwrap(), - ).unwrap(); + this_dim.data().get_unchecked( + ctx, + generator, + &llvm_usize.const_zero(), + None, + ) + } + .into_pointer_value(); + ctx.builder + .build_store( + lst, + ctx.builder.build_bitcast(next_dim, llvm_plist_i8, "").unwrap(), + ) + .unwrap(); Ok(()) }, @@ -977,7 +941,9 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( Ok(Some(lst.load_size(ctx, None))) }, - )?.map(BasicValueEnum::into_int_value).unwrap()) + )? + .map(BasicValueEnum::into_int_value) + .unwrap()) }, )?; @@ -1010,44 +976,34 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap(); let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap(); - let ndarray = create_ndarray_const_shape( - generator, - ctx, - elem_ty, - &[nrows, ncols], - )?; + let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?; - ndarray_fill_indexed( - generator, - ctx, - ndarray, - |generator, ctx, indices| { - let (row, col) = unsafe { - ( - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None), - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None), - ) - }; + ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| { + let (row, col) = unsafe { + ( + indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None), + indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None), + ) + }; - let col_with_offset = ctx.builder - .build_int_add( - col, - ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(), - "", - ) - .unwrap(); - let is_on_diag = ctx.builder - .build_int_compare(IntPredicate::EQ, row, col_with_offset, "") - .unwrap(); + let col_with_offset = ctx + .builder + .build_int_add( + col, + ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(), + "", + ) + .unwrap(); + let is_on_diag = + ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap(); - let zero = ndarray_zero_value(generator, ctx, elem_ty); - let one = ndarray_one_value(generator, ctx, elem_ty); + let zero = ndarray_zero_value(generator, ctx, elem_ty); + let one = ndarray_one_value(generator, ctx, elem_ty); - let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap(); + let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap(); - Ok(value) - }, - )?; + Ok(value) + })?; Ok(ndarray) } @@ -1085,21 +1041,11 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( (Some(llvm_usize.const_int(dim, false)), None), ); let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); - let cpy_len = ctx.builder.build_int_mul( - stride, - sizeof_elem, - "" - ).unwrap(); + let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); - call_memcpy_generic( - ctx, - dst_slice_ptr, - src_slice_ptr, - cpy_len, - llvm_i1.const_zero(), - ); + call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); - return Ok(()) + return Ok(()); } // The stride of elements in this dimension, i.e. the number of elements between arr[i] and @@ -1134,20 +1080,10 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( |_, _| Ok(step), |generator, ctx, src_i| { // Calculate the offset of the active slice - let src_data_offset = ctx.builder.build_int_mul( - src_stride, - src_i, - "", - ).unwrap(); - let dst_i = ctx.builder - .build_load(dst_i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let dst_data_offset = ctx.builder.build_int_mul( - dst_stride, - dst_i, - "", - ).unwrap(); + let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); + let dst_i = + ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); + let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); let (src_ptr, dst_ptr) = unsafe { ( @@ -1166,13 +1102,10 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( &slices[1..], )?; - let dst_i = ctx.builder - .build_load(dst_i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let dst_i_add1 = ctx.builder - .build_int_add(dst_i, llvm_usize.const_int(1, false), "") - .unwrap(); + let dst_i = + ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); + let dst_i_add1 = + ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap(); ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap(); Ok(()) @@ -1203,11 +1136,9 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( ctx, elem_ty, &this, - |_, ctx, shape| { - Ok(shape.load_ndims(ctx)) - }, - |generator, ctx, shape, idx| { - unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) } + |_, ctx, shape| Ok(shape.load_ndims(ctx)), + |generator, ctx, shape, idx| unsafe { + Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }, )? } else { @@ -1220,36 +1151,39 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( // Populate the first slices.len() dimensions by computing the size of each dim slice for (i, (start, stop, step)) in slices.iter().enumerate() { // HACK: workaround calculate_len_for_slice_range requiring exclusive stop - let stop = ctx.builder + let stop = ctx + .builder .build_select( - ctx.builder.build_int_compare( - IntPredicate::SLT, - *step, - llvm_i32.const_zero(), - "is_neg", - ).unwrap(), - ctx.builder.build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one").unwrap(), - ctx.builder.build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one").unwrap(), + ctx.builder + .build_int_compare( + IntPredicate::SLT, + *step, + llvm_i32.const_zero(), + "is_neg", + ) + .unwrap(), + ctx.builder + .build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one") + .unwrap(), + ctx.builder + .build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one") + .unwrap(), "final_e", ) .map(BasicValueEnum::into_int_value) .unwrap(); let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step); - let slice_len = ctx.builder.build_int_z_extend_or_bit_cast( - slice_len, - llvm_usize, - "" - ).unwrap(); + let slice_len = + ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); unsafe { - ndarray.dim_sizes() - .set_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - slice_len, - ); + ndarray.dim_sizes().set_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(i as u64, false), + slice_len, + ); } } @@ -1268,7 +1202,8 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( Ok(()) }, llvm_usize.const_int(1, false), - ).unwrap(); + ) + .unwrap(); ndarray_init_data(generator, ctx, elem_ty, ndarray) }; @@ -1306,9 +1241,13 @@ pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>( operand: NDArrayValue<'ctx>, map_fn: MapFn, ) -> Result, String> - where - G: CodeGenerator + ?Sized, - MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, BasicValueEnum<'ctx>) -> Result, String>, +where + G: CodeGenerator + ?Sized, + MapFn: Fn( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, { let res = res.unwrap_or_else(|| { create_ndarray_dyn_shape( @@ -1316,39 +1255,30 @@ pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>( ctx, elem_ty, &operand, - |_, ctx, v| { - Ok(v.load_ndims(ctx)) + |_, ctx, v| Ok(v.load_ndims(ctx)), + |generator, ctx, v, idx| unsafe { + Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }, - |generator, ctx, v, idx| { - unsafe { - Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) - } - }, - ).unwrap() + ) + .unwrap() }); - ndarray_fill_mapping( - generator, - ctx, - operand, - res, - |generator, ctx, elem| { - map_fn(generator, ctx, elem) - } - )?; + ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| { + map_fn(generator, ctx, elem) + })?; Ok(res) } /// LLVM-typed implementation for computing elementwise binary operations on two input operands. /// -/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output -/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple. -/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the +/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output +/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple. +/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the /// `value_fn` arguments tuple for all output elements. /// /// The second element of the tuple indicates whether to treat the operand value as a `ndarray` -/// (which would be accessed by its broadcast index) or as a scalar value (which would be +/// (which would be accessed by its broadcast index) or as a scalar value (which would be /// broadcast to all elements). /// /// * `elem_ty` - The element type of the `NDArray`. @@ -1368,24 +1298,32 @@ pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( rhs: (BasicValueEnum<'ctx>, bool), value_fn: ValueFn, ) -> Result, String> - where - G: CodeGenerator + ?Sized, - ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result, String>, +where + G: CodeGenerator + ?Sized, + ValueFn: Fn( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), + ) -> Result, String>, { let llvm_usize = generator.get_size_type(ctx.ctx); let (lhs_val, lhs_scalar) = lhs; let (rhs_val, rhs_scalar) = rhs; - assert!(!(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type()); + assert!( + !(lhs_scalar && rhs_scalar), + "One of the operands must be a ndarray instance: `{}`, `{}`", + lhs_val.get_type(), + rhs_val.get_type() + ); let ndarray = res.unwrap_or_else(|| { if lhs_scalar && rhs_scalar { - let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); - let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); + let lhs_val = + NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); + let rhs_val = + NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); @@ -1394,15 +1332,12 @@ pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( ctx, elem_ty, &ndarray_dims, - |generator, ctx, v| { - Ok(v.size(ctx, generator)) + |generator, ctx, v| Ok(v.size(ctx, generator)), + |generator, ctx, v, idx| unsafe { + Ok(v.get_typed_unchecked(ctx, generator, &idx, None)) }, - |generator, ctx, v, idx| { - unsafe { - Ok(v.get_typed_unchecked(ctx, generator, &idx, None)) - } - }, - ).unwrap() + ) + .unwrap() } else { let ndarray = NDArrayValue::from_ptr_val( if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), @@ -1415,28 +1350,18 @@ pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( ctx, elem_ty, &ndarray, - |_, ctx, v| { - Ok(v.load_ndims(ctx)) + |_, ctx, v| Ok(v.load_ndims(ctx)), + |generator, ctx, v, idx| unsafe { + Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }, - |generator, ctx, v, idx| { - unsafe { - Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) - } - }, - ).unwrap() + ) + .unwrap() } }); - ndarray_broadcast_fill( - generator, - ctx, - ndarray, - lhs, - rhs, - |generator, ctx, elems| { - value_fn(generator, ctx, elems) - }, - )?; + ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| { + value_fn(generator, ctx, elems) + })?; Ok(ndarray) } @@ -1464,12 +1389,9 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( // lhs.ndims == 2 ctx.make_assert( generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - lhs_ndims, - llvm_usize.const_int(2, false), - "", - ).unwrap(), + ctx.builder + .build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "") + .unwrap(), "0:ValueError", "", [None, None, None], @@ -1479,12 +1401,9 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( // rhs.ndims == 2 ctx.make_assert( generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - rhs_ndims, - llvm_usize.const_int(2, false), - "", - ).unwrap(), + ctx.builder + .build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "") + .unwrap(), "0:ValueError", "", [None, None, None], @@ -1497,24 +1416,36 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; let res_dim1 = unsafe { - res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + res.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(1, false), + None, + ) }; let lhs_dim0 = unsafe { lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; let rhs_dim1 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + rhs.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(1, false), + None, + ) }; // res.ndims == 2 ctx.make_assert( generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - res_ndims, - llvm_usize.const_int(2, false), - "", - ).unwrap(), + ctx.builder + .build_int_compare( + IntPredicate::EQ, + res_ndims, + llvm_usize.const_int(2, false), + "", + ) + .unwrap(), "0:ValueError", "", [None, None, None], @@ -1524,12 +1455,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( // res.dims[0] == lhs.dims[0] ctx.make_assert( generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - lhs_dim0, - res_dim0, - "", - ).unwrap(), + ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(), "0:ValueError", "", [None, None, None], @@ -1539,12 +1465,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( // res.dims[1] == rhs.dims[0] ctx.make_assert( generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - rhs_dim1, - res_dim1, - "", - ).unwrap(), + ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(), "0:ValueError", "", [None, None, None], @@ -1555,7 +1476,12 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let lhs_dim1 = unsafe { - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + lhs.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(1, false), + None, + ) }; let rhs_dim0 = unsafe { rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -1564,12 +1490,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( // lhs.dims[1] == rhs.dims[0] ctx.make_assert( generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - lhs_dim1, - rhs_dim0, - "", - ).unwrap(), + ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(), "0:ValueError", "", [None, None, None], @@ -1589,20 +1510,16 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( ctx, elem_ty, &(lhs, rhs), - |_, _, _| { - Ok(llvm_usize.const_int(2, false)) - }, + |_, _, _| Ok(llvm_usize.const_int(2, false)), |generator, ctx, (lhs, rhs), idx| { gen_if_else_expr_callback( generator, ctx, |_, ctx| { - Ok(ctx.builder.build_int_compare( - IntPredicate::EQ, - idx, - llvm_usize.const_zero(), - "", - ).unwrap()) + Ok(ctx + .builder + .build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "") + .unwrap()) }, |generator, ctx| { Ok(Some(unsafe { @@ -1624,132 +1541,124 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( ) })) }, - ).map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) + ) + .map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) }, - ).unwrap() + ) + .unwrap() }); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - ndarray_fill_indexed( - generator, - ctx, - ndarray, - |generator, ctx, idx| { - llvm_intrinsics::call_expect( - ctx, - idx.size(ctx, generator).get_type().const_int(2, false), - idx.size(ctx, generator), - None, - ); + ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| { + llvm_intrinsics::call_expect( + ctx, + idx.size(ctx, generator).get_type().const_int(2, false), + idx.size(ctx, generator), + None, + ); - let common_dim = { - let lhs_idx1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_idx0 = unsafe { - rhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - }; - - let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - - ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() - }; - - let idx0 = unsafe { - let idx0 = idx.get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ); - - ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() - }; - let idx1 = unsafe { - let idx1 = idx.get_typed_unchecked( + let common_dim = { + let lhs_idx1 = unsafe { + lhs.dim_sizes().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(1, false), None, - ); - - ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() + ) + }; + let rhs_idx0 = unsafe { + rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let result_identity = ndarray_zero_value(generator, ctx, elem_ty); - ctx.builder.build_store(result_addr, result_identity).unwrap(); + let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - gen_for_callback_incrementing( - generator, - ctx, - llvm_i32.const_zero(), - (common_dim, false), - |generator, ctx, i| { - let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); + ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() + }; - let ab_idx = generator.gen_array_var_alloc( + let idx0 = unsafe { + let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); + + ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() + }; + let idx1 = unsafe { + let idx1 = + idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); + + ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() + }; + + let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; + let result_identity = ndarray_zero_value(generator, ctx, elem_ty); + ctx.builder.build_store(result_addr, result_identity).unwrap(); + + gen_for_callback_incrementing( + generator, + ctx, + llvm_i32.const_zero(), + (common_dim, false), + |generator, ctx, i| { + let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); + + let ab_idx = generator.gen_array_var_alloc( + ctx, + llvm_i32.into(), + llvm_usize.const_int(2, false), + None, + )?; + + let a = unsafe { + ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); + ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); + + lhs.data().get_unchecked(ctx, generator, &ab_idx, None) + }; + let b = unsafe { + ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); + ab_idx.set_unchecked( ctx, - llvm_i32.into(), - llvm_usize.const_int(2, false), - None, - )?; - - let a = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); - - lhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - let b = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), idx1.into()); - - rhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - - let a_mul_b = gen_binop_expr_with_values( generator, - ctx, - (&Some(elem_ty), a), - &Operator::Mult, - (&Some(elem_ty), b), - ctx.current_loc, - false, - )?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?; + &llvm_usize.const_int(1, false), + idx1.into(), + ); - let result = ctx.builder.build_load(result_addr, "").unwrap(); - let result = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), result), - &Operator::Add, - (&Some(elem_ty), a_mul_b), - ctx.current_loc, - false, - )?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?; - ctx.builder.build_store(result_addr, result).unwrap(); + rhs.data().get_unchecked(ctx, generator, &ab_idx, None) + }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; + let a_mul_b = gen_binop_expr_with_values( + generator, + ctx, + (&Some(elem_ty), a), + &Operator::Mult, + (&Some(elem_ty), b), + ctx.current_loc, + false, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, elem_ty)?; - let result = ctx.builder.build_load(result_addr, "").unwrap(); - Ok(result) - } - )?; + let result = ctx.builder.build_load(result_addr, "").unwrap(); + let result = gen_binop_expr_with_values( + generator, + ctx, + (&Some(elem_ty), result), + &Operator::Add, + (&Some(elem_ty), a_mul_b), + ctx.current_loc, + false, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, elem_ty)?; + ctx.builder.build_store(result_addr, result).unwrap(); + + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + let result = ctx.builder.build_load(result_addr, "").unwrap(); + Ok(result) + })?; Ok(ndarray) } @@ -1767,15 +1676,15 @@ pub fn gen_ndarray_empty<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; - let shape_arg = args[0].1.clone() - .to_basic_value_enum(context, generator, shape_ty)?; + let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; call_ndarray_empty_impl( generator, context, context.primitives.float, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ).map(NDArrayValue::into) + ) + .map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.zeros`. @@ -1791,15 +1700,15 @@ pub fn gen_ndarray_zeros<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; - let shape_arg = args[0].1.clone() - .to_basic_value_enum(context, generator, shape_ty)?; + let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; call_ndarray_zeros_impl( generator, context, context.primitives.float, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ).map(NDArrayValue::into) + ) + .map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.ones`. @@ -1815,15 +1724,15 @@ pub fn gen_ndarray_ones<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; - let shape_arg = args[0].1.clone() - .to_basic_value_enum(context, generator, shape_ty)?; + let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; call_ndarray_ones_impl( generator, context, context.primitives.float, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ).map(NDArrayValue::into) + ) + .map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.full`. @@ -1839,11 +1748,10 @@ pub fn gen_ndarray_full<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; - let shape_arg = args[0].1.clone() - .to_basic_value_enum(context, generator, shape_ty)?; + let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; let fill_value_ty = fun.0.args[1].ty; - let fill_value_arg = args[1].1.clone() - .to_basic_value_enum(context, generator, fill_value_ty)?; + let fill_value_arg = + args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; call_ndarray_full_impl( generator, @@ -1851,7 +1759,8 @@ pub fn gen_ndarray_full<'ctx>( fill_value_ty, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), fill_value_arg, - ).map(NDArrayValue::into) + ) + .map(NDArrayValue::into) } pub fn gen_ndarray_array<'ctx>( @@ -1876,15 +1785,15 @@ pub fn gen_ndarray_array<'ctx>( ty = *elem_ty; } ty - }, + } _ => obj_ty, }; - let obj_arg = args[0].1.clone() - .to_basic_value_enum(context, generator, obj_ty)?; + let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?; let copy_arg = if let Some(arg) = - args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name)) { + args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name)) + { let copy_ty = fun.0.args[1].ty; arg.1.clone().to_basic_value_enum(context, generator, copy_ty)? } else { @@ -1892,11 +1801,12 @@ pub fn gen_ndarray_array<'ctx>( generator, fun.0.args[1].default_value.as_ref().unwrap(), fun.0.args[1].ty, - ) + ) }; let ndmin_arg = if let Some(arg) = - args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) { + args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) + { let ndmin_ty = fun.0.args[2].ty; arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)? } else { @@ -1905,7 +1815,7 @@ pub fn gen_ndarray_array<'ctx>( fun.0.args[2].default_value.as_ref().unwrap(), fun.0.args[2].ty, ) - }; + }; call_ndarray_array_impl( generator, @@ -1914,7 +1824,8 @@ pub fn gen_ndarray_array<'ctx>( obj_arg, copy_arg.into_int_value(), ndmin_arg.into_int_value(), - ).map(NDArrayValue::into) + ) + .map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.eye`. @@ -1929,12 +1840,12 @@ pub fn gen_ndarray_eye<'ctx>( assert!(matches!(args.len(), 1..=3)); let nrows_ty = fun.0.args[0].ty; - let nrows_arg = args[0].1.clone() - .to_basic_value_enum(context, generator, nrows_ty)?; + let nrows_arg = args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)?; let ncols_ty = fun.0.args[1].ty; - let ncols_arg = if let Some(arg) = - args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name)) { + let ncols_arg = if let Some(arg) = + args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name)) + { arg.1.clone().to_basic_value_enum(context, generator, ncols_ty) } else { args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty) @@ -1942,14 +1853,15 @@ pub fn gen_ndarray_eye<'ctx>( let offset_ty = fun.0.args[2].ty; let offset_arg = if let Some(arg) = - args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) { - arg.1.clone().to_basic_value_enum(context, generator, offset_ty) + args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) + { + arg.1.clone().to_basic_value_enum(context, generator, offset_ty) } else { Ok(context.gen_symbol_val( generator, fun.0.args[2].default_value.as_ref().unwrap(), - offset_ty - )) + offset_ty, + )) }?; call_ndarray_eye_impl( @@ -1959,7 +1871,8 @@ pub fn gen_ndarray_eye<'ctx>( nrows_arg.into_int_value(), ncols_arg.into_int_value(), offset_arg.into_int_value(), - ).map(NDArrayValue::into) + ) + .map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.identity`. @@ -1976,8 +1889,7 @@ pub fn gen_ndarray_identity<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let n_ty = fun.0.args[0].ty; - let n_arg = args[0].1.clone() - .to_basic_value_enum(context, generator, n_ty)?; + let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; call_ndarray_eye_impl( generator, @@ -1986,7 +1898,8 @@ pub fn gen_ndarray_identity<'ctx>( n_arg.into_int_value(), n_arg.into_int_value(), llvm_usize.const_zero(), - ).map(NDArrayValue::into) + ) + .map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.copy`. @@ -2004,19 +1917,16 @@ pub fn gen_ndarray_copy<'ctx>( let this_ty = obj.as_ref().unwrap().0; let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); - let this_arg = obj - .as_ref() - .unwrap() - .1 - .clone() - .to_basic_value_enum(context, generator, this_ty)?; + let this_arg = + obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; ndarray_copy_impl( generator, context, this_elem_ty, NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None), - ).map(NDArrayValue::into) + ) + .map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.fill`. @@ -2033,12 +1943,15 @@ pub fn gen_ndarray_fill<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let this_ty = obj.as_ref().unwrap().0; - let this_arg = obj.as_ref().unwrap().1.clone() + let this_arg = obj + .as_ref() + .unwrap() + .1 + .clone() .to_basic_value_enum(context, generator, this_ty)? .into_pointer_value(); let value_ty = fun.0.args[0].ty; - let value_arg = args[0].1.clone() - .to_basic_value_enum(context, generator, value_ty)?; + let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; ndarray_fill_flattened( generator, @@ -2066,8 +1979,8 @@ pub fn gen_ndarray_fill<'ctx>( }; Ok(value) - } + }, )?; Ok(()) -} \ No newline at end of file +} diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index e5843ceb..f14709c8 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -11,10 +11,7 @@ use crate::{ gen_in_range_check, }, toplevel::{ - DefinitionId, - helper::PRIMITIVE_DEF_IDS, - numpy::unpack_ndarray_var_tys, - TopLevelDef, + helper::PRIMITIVE_DEF_IDS, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef, }, typecheck::typedef::{FunSignature, Type, TypeEnum}, }; @@ -116,13 +113,13 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( ctx.var_assignment.insert(*id, (ptr, None, counter)); ptr } - } + }, ExprKind::Attribute { value, attr, .. } => { let index = ctx.get_attr_index(value.custom.unwrap(), *attr); let val = if let Some(v) = generator.gen_expr(ctx, value)? { v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? } else { - return Ok(None) + return Ok(None); }; let BasicValueEnum::PointerValue(ptr) = val else { unreachable!(); @@ -136,7 +133,8 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( ], name.unwrap_or(""), ) - }.unwrap() + } + .unwrap() } ExprKind::Subscript { value, slice, .. } => { match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() { @@ -153,11 +151,13 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( .unwrap() .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? .into_int_value(); - let raw_index = ctx.builder + let raw_index = ctx + .builder .build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext") .unwrap(); // handle negative index - let is_negative = ctx.builder + let is_negative = ctx + .builder .build_int_compare( IntPredicate::SLT, raw_index, @@ -173,13 +173,9 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( .unwrap(); // unsigned less than is enough, because negative index after adjustment is // bigger than the length (for unsigned cmp) - let bound_check = ctx.builder - .build_int_compare( - IntPredicate::ULT, - index, - len, - "inbound", - ) + let bound_check = ctx + .builder + .build_int_compare(IntPredicate::ULT, index, len, "inbound") .unwrap(); ctx.make_assert( generator, @@ -215,7 +211,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( match &target.node { ExprKind::Tuple { elts, .. } => { let BasicValueEnum::StructValue(v) = - value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? else { + value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? + else { unreachable!() }; @@ -230,9 +227,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( ExprKind::Subscript { value: ls, slice, .. } if matches!(&slice.node, ExprKind::Slice { .. }) => { - let ExprKind::Slice { lower, upper, step } = &slice.node else { - unreachable!() - }; + let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() }; let ls = generator .gen_expr(ctx, ls)? @@ -240,14 +235,11 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( .to_basic_value_enum(ctx, generator, ls.custom.unwrap())? .into_pointer_value(); let ls = ListValue::from_ptr_val(ls, llvm_usize, None); - let Some((start, end, step)) = handle_slice_indices( - lower, - upper, - step, - ctx, - generator, - ls.load_size(ctx, None), - )? else { return Ok(()) }; + let Some((start, end, step)) = + handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))? + else { + return Ok(()); + }; let value = value .to_basic_value_enum(ctx, generator, target.custom.unwrap())? .into_pointer_value(); @@ -268,7 +260,10 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( ctx, generator, value.load_size(ctx, None), - )? else { return Ok(()) }; + )? + else { + return Ok(()); + }; list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind); } _ => { @@ -278,7 +273,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( String::from("target.addr") }; let Some(ptr) = generator.gen_store_target(ctx, target, Some(name.as_str()))? else { - return Ok(()) + return Ok(()); }; if let ExprKind::Name { id, .. } = &target.node { @@ -301,9 +296,7 @@ pub fn gen_for( ctx: &mut CodeGenContext<'_, '_>, stmt: &Stmt>, ) -> Result<(), String> { - let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else { - unreachable!() - }; + let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else { unreachable!() }; // var_assignment static values may be changed in another branch // if so, remove the static value as it may not be correct in this branch @@ -316,11 +309,8 @@ pub fn gen_for( let body_bb = ctx.ctx.append_basic_block(current, "for.body"); let cont_bb = ctx.ctx.append_basic_block(current, "for.end"); // if there is no orelse, we just go to cont_bb - let orelse_bb = if orelse.is_empty() { - cont_bb - } else { - ctx.ctx.append_basic_block(current, "for.orelse") - }; + let orelse_bb = + if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") }; // Whether the iterable is a range() expression let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); @@ -334,20 +324,17 @@ pub fn gen_for( let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb)); let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { - v.to_basic_value_enum( - ctx, - generator, - iter.custom.unwrap(), - )? + v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())? } else { - return Ok(()) + return Ok(()); }; if is_iterable_range_expr { let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); // Internal variable for loop; Cannot be assigned let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed - let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))? else { + let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))? + else { unreachable!() }; let (start, stop, step) = destructure_range(ctx, iter_val); @@ -355,16 +342,15 @@ pub fn gen_for( ctx.builder.build_store(i, start).unwrap(); // Check "If step is zero, ValueError is raised." - let rangenez = ctx.builder - .build_int_compare(IntPredicate::NE, step, int32.const_zero(), "") - .unwrap(); + let rangenez = + ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "").unwrap(); ctx.make_assert( generator, rangenez, "ValueError", "range() arg 3 must not be zero", [None, None, None], - ctx.current_loc + ctx.current_loc, ); ctx.builder.build_unconditional_branch(cond_bb).unwrap(); @@ -385,7 +371,8 @@ pub fn gen_for( } ctx.builder.position_at_end(incr_bb); - let next_i = ctx.builder + let next_i = ctx + .builder .build_int_add( ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), step, @@ -410,13 +397,14 @@ pub fn gen_for( .build_gep_and_load( iter_val.into_pointer_value(), &[zero, int32.const_int(1, false)], - Some("len") + Some("len"), ) .into_int_value(); ctx.builder.build_unconditional_branch(cond_bb).unwrap(); ctx.builder.position_at_end(cond_bb); - let index = ctx.builder + let index = ctx + .builder .build_load(index_addr, "for.index") .map(BasicValueEnum::into_int_value) .unwrap(); @@ -424,7 +412,8 @@ pub fn gen_for( ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap(); ctx.builder.position_at_end(incr_bb); - let index = ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap(); + let index = + ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap(); ctx.builder.build_store(index_addr, inc).unwrap(); ctx.builder.build_unconditional_branch(cond_bb).unwrap(); @@ -433,7 +422,8 @@ pub fn gen_for( let arr_ptr = ctx .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr")) .into_pointer_value(); - let index = ctx.builder + let index = ctx + .builder .build_load(index_addr, "for.index") .map(BasicValueEnum::into_int_value) .unwrap(); @@ -496,13 +486,13 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( body: BodyFn, update: UpdateFn, ) -> Result<(), String> - where - G: CodeGenerator + ?Sized, - I: Clone, - InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, - CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result, String>, - BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, - UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, +where + G: CodeGenerator + ?Sized, + I: Clone, + InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, + CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result, String>, + BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, + UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, { let current_bb = ctx.builder.get_insert_block().unwrap(); let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init"); @@ -528,9 +518,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( let cond = cond(generator, ctx, loop_var.clone())?; assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width()); if !ctx.is_terminated() { - ctx.builder - .build_conditional_branch(cond, body_bb, cont_bb) - .unwrap(); + ctx.builder.build_conditional_branch(cond, body_bb, cont_bb).unwrap(); } ctx.builder.position_at_end(body_bb); @@ -551,7 +539,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( Ok(()) } -/// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the +/// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the /// following C code: /// /// ```c @@ -560,7 +548,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( /// } /// ``` /// -/// * `init_val` - The initial value of the loop variable. The type of this value will also be used +/// * `init_val` - The initial value of the loop variable. The type of this value will also be used /// as the type of the loop variable. /// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum /// value should be treated as inclusive (as opposed to exclusive). @@ -574,9 +562,9 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( body: BodyFn, incr_val: IntValue<'ctx>, ) -> Result<(), String> - where - G: CodeGenerator + ?Sized, - BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, +where + G: CodeGenerator + ?Sized, + BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, { let init_val_t = init_val.get_type(); @@ -590,38 +578,23 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( Ok(i_addr) }, |_, ctx, i_addr| { - let cmp_op = if max_val.1 { - IntPredicate::ULE - } else { - IntPredicate::ULT - }; + let cmp_op = if max_val.1 { IntPredicate::ULE } else { IntPredicate::ULT }; - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let max_val = ctx.builder - .build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "") - .unwrap(); + let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); + let max_val = + ctx.builder.build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "").unwrap(); Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap()) }, |generator, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); + let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); body(generator, ctx, i) }, |_, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let incr_val = ctx.builder - .build_int_z_extend_or_bit_cast(incr_val, init_val_t, "") - .unwrap(); + let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); + let incr_val = + ctx.builder.build_int_z_extend_or_bit_cast(incr_val, init_val_t, "").unwrap(); let i = ctx.builder.build_int_add(i, incr_val, "").unwrap(); ctx.builder.build_store(i_addr, i).unwrap(); @@ -632,21 +605,21 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( /// Generates a `for` construct over a `range`-like iterable using lambdas, similar to the following /// C code: -/// +/// /// ```c /// bool incr = start_fn() <= end_fn(); /// for (int i = start_fn(); i /* < or > */ end_fn(); i += step_fn()) { /// body_fn(i); /// } /// ``` -/// +/// /// - `is_unsigned`: Whether to treat the values of the `range` as unsigned. -/// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like +/// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like /// iterable. /// - `stop_fn`: A lambda of IR statements that retrieves the `stop` value of the `range`-like /// iterable. This value will be extended to the size of `start`. /// - `stop_inclusive`: Whether the stop value should be treated as inclusive. -/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like +/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like /// iterable. This value will be extended to the size of `start`. /// - `body_fn`: A lambda of IR statements within the loop body. pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( @@ -658,16 +631,14 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( step_fn: StepFn, body_fn: BodyFn, ) -> Result<(), String> - where - G: CodeGenerator + ?Sized, - StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, - StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, - StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, - BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, +where + G: CodeGenerator + ?Sized, + StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, + StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, + StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, + BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, { - let init_val_t = start_fn(generator, ctx) - .map(IntValue::get_type) - .unwrap(); + let init_val_t = start_fn(generator, ctx).map(IntValue::get_type).unwrap(); gen_for_callback( generator, @@ -688,12 +659,15 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap() }; - let incr = ctx.builder.build_int_compare( - if is_unsigned { IntPredicate::ULE } else { IntPredicate::SLE }, - start, - stop, - "", - ).unwrap(); + let incr = ctx + .builder + .build_int_compare( + if is_unsigned { IntPredicate::ULE } else { IntPredicate::SLE }, + start, + stop, + "", + ) + .unwrap(); Ok((i_addr, incr)) }, @@ -705,10 +679,7 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( (false, false) => (IntPredicate::SLT, IntPredicate::SGT), }; - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); + let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let stop = stop_fn(generator, ctx)?; let stop = if stop.get_type().get_bit_width() == i.get_type().get_bit_width() { stop @@ -718,14 +689,11 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap() }; - let i_lt_end = ctx.builder - .build_int_compare(lt_cmp_op, i, stop, "") - .unwrap(); - let i_gt_end = ctx.builder - .build_int_compare(gt_cmp_op, i, stop, "") - .unwrap(); + let i_lt_end = ctx.builder.build_int_compare(lt_cmp_op, i, stop, "").unwrap(); + let i_gt_end = ctx.builder.build_int_compare(gt_cmp_op, i, stop, "").unwrap(); - let cond = ctx.builder + let cond = ctx + .builder .build_select(incr, i_lt_end, i_gt_end, "") .map(BasicValueEnum::into_int_value) .unwrap(); @@ -733,18 +701,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( Ok(cond) }, |generator, ctx, (i_addr, _)| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); + let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); body_fn(generator, ctx, i) }, |generator, ctx, (i_addr, _)| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); + let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let incr_val = step_fn(generator, ctx)?; let incr_val = if incr_val.get_type().get_bit_width() == i.get_type().get_bit_width() { @@ -769,9 +731,7 @@ pub fn gen_while( ctx: &mut CodeGenContext<'_, '_>, stmt: &Stmt>, ) -> Result<(), String> { - let StmtKind::While { test, body, orelse, .. } = &stmt.node else { - unreachable!() - }; + let StmtKind::While { test, body, orelse, .. } = &stmt.node else { unreachable!() }; // var_assignment static values may be changed in another branch // if so, remove the static value as it may not be correct in this branch @@ -782,8 +742,11 @@ pub fn gen_while( let body_bb = ctx.ctx.append_basic_block(current, "while.body"); let cont_bb = ctx.ctx.append_basic_block(current, "while.cont"); // if there is no orelse, we just go to cont_bb - let orelse_bb = - if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "while.orelse") }; + let orelse_bb = if orelse.is_empty() { + cont_bb + } else { + ctx.ctx.append_basic_block(current, "while.orelse") + }; // store loop bb information and restore it later let loop_bb = ctx.loop_target.replace((test_bb, cont_bb)); ctx.builder.build_unconditional_branch(test_bb).unwrap(); @@ -796,11 +759,9 @@ pub fn gen_while( ctx.builder.build_unreachable().unwrap(); } - return Ok(()) - }; - let BasicValueEnum::IntValue(test) = test else { - unreachable!() + return Ok(()); }; + let BasicValueEnum::IntValue(test) = test else { unreachable!() }; ctx.builder .build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb) @@ -853,12 +814,12 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>( then_fn: ThenFn, else_fn: ElseFn, ) -> Result>, String> - where - G: CodeGenerator + ?Sized, - CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, - ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, - ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, - R: BasicValue<'ctx>, +where + G: CodeGenerator + ?Sized, + CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, + ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, + ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, + R: BasicValue<'ctx>, { let current_bb = ctx.builder.get_insert_block().unwrap(); @@ -893,8 +854,8 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>( let phi = ctx.builder.build_phi(tv_ty, "").unwrap(); phi.add_incoming(&[(&tv, then_end_bb), (&ev, else_end_bb)]); - Some(phi.as_basic_value()) - }, + Some(phi.as_basic_value()) + } (Some(tv), None) => Some(tv.as_basic_value_enum()), (None, Some(ev)) => Some(ev.as_basic_value_enum()), (None, None) => None, @@ -919,11 +880,11 @@ pub fn gen_if_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn>( then_fn: ThenFn, else_fn: ElseFn, ) -> Result<(), String> - where - G: CodeGenerator + ?Sized, - CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, - ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>, - ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>, +where + G: CodeGenerator + ?Sized, + CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, + ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>, + ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>, { gen_if_else_expr_callback( generator, @@ -936,7 +897,7 @@ pub fn gen_if_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn>( |generator, ctx| { else_fn(generator, ctx)?; Ok(None) - } + }, )?; Ok(()) @@ -948,9 +909,7 @@ pub fn gen_if( ctx: &mut CodeGenContext<'_, '_>, stmt: &Stmt>, ) -> Result<(), String> { - let StmtKind::If { test, body, orelse, .. } = &stmt.node else { - unreachable!() - }; + let StmtKind::If { test, body, orelse, .. } = &stmt.node else { unreachable!() }; // var_assignment static values may be changed in another branch // if so, remove the static value as it may not be correct in this branch @@ -969,9 +928,9 @@ pub fn gen_if( }; ctx.builder.build_unconditional_branch(test_bb).unwrap(); ctx.builder.position_at_end(test_bb); - let test = generator - .gen_expr(ctx, test) - .and_then(|v| v.map(|v| v.to_basic_value_enum(ctx, generator, test.custom.unwrap())).transpose())?; + let test = generator.gen_expr(ctx, test).and_then(|v| { + v.map(|v| v.to_basic_value_enum(ctx, generator, test.custom.unwrap())).transpose() + })?; if let Some(BasicValueEnum::IntValue(test)) = test { ctx.builder .build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb) @@ -1077,16 +1036,16 @@ pub fn exn_constructor<'ctx>( }; let defs = ctx.top_level.definitions.read(); let def = defs[zelf_id].read(); - let TopLevelDef::Class { name: zelf_name, .. } = &*def else { - unreachable!() - }; + let TopLevelDef::Class { name: zelf_name, .. } = &*def else { unreachable!() }; let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name); unsafe { let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap(); let id = ctx.resolver.get_string_id(&exception_name); ctx.builder.build_store(id_ptr, int32.const_int(id as u64, false)).unwrap(); - let empty_string = ctx.gen_const(generator, &Constant::Str(String::new()), ctx.primitives.str); - let ptr = ctx.builder + let empty_string = + ctx.gen_const(generator, &Constant::Str(String::new()), ctx.primitives.str); + let ptr = ctx + .builder .build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg") .unwrap(); let msg = if args.is_empty() { @@ -1101,21 +1060,24 @@ pub fn exn_constructor<'ctx>( } else { args.remove(0).1.to_basic_value_enum(ctx, generator, ctx.primitives.int64)? }; - let ptr = ctx.builder + let ptr = ctx + .builder .build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.param") .unwrap(); ctx.builder.build_store(ptr, value).unwrap(); } // set file, func to empty string for i in &[1, 4] { - let ptr = ctx.builder + let ptr = ctx + .builder .build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.str") .unwrap(); ctx.builder.build_store(ptr, empty_string.unwrap()).unwrap(); } // set ints to zero for i in &[2, 3] { - let ptr = ctx.builder + let ptr = ctx + .builder .build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.ints") .unwrap(); ctx.builder.build_store(ptr, zero).unwrap(); @@ -1139,23 +1101,27 @@ pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>( let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); let exception = exception.into_pointer_value(); - let file_ptr = ctx.builder + let file_ptr = ctx + .builder .build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr") .unwrap(); let filename = ctx.gen_string(generator, loc.file.0); ctx.builder.build_store(file_ptr, filename).unwrap(); - let row_ptr = ctx.builder + let row_ptr = ctx + .builder .build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr") .unwrap(); ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)).unwrap(); - let col_ptr = ctx.builder + let col_ptr = ctx + .builder .build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr") .unwrap(); ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)).unwrap(); let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); - let name_ptr = ctx.builder + let name_ptr = ctx + .builder .build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr") .unwrap(); ctx.builder.build_store(name_ptr, fun_name).unwrap(); @@ -1204,7 +1170,8 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( let mut final_data = None; let has_cleanup = !finalbody.is_empty(); if has_cleanup { - let final_state = generator.gen_var_alloc(ctx, ptr_type.into(), Some("try.final_state.addr"))?; + let final_state = + generator.gen_var_alloc(ctx, ptr_type.into(), Some("try.final_state.addr"))?; final_data = Some((final_state, Vec::new(), Vec::new())); if let Some((continue_target, break_target)) = ctx.loop_target { let break_proxy = ctx.ctx.append_basic_block(current_fun, "try.break"); @@ -1219,8 +1186,8 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( } else { let return_target = ctx.ctx.append_basic_block(current_fun, "try.return_target"); ctx.builder.position_at_end(return_target); - let return_value = ctx.return_buffer - .map(|v| ctx.builder.build_load(v, "$ret").unwrap()); + let return_value = + ctx.return_buffer.map(|v| ctx.builder.build_load(v, "$ret").unwrap()); ctx.builder.build_return(return_value.as_ref().map(|v| v as &dyn BasicValue)).unwrap(); ctx.builder.position_at_end(current_block); final_proxy(ctx, return_target, return_proxy, final_data.as_mut().unwrap()); @@ -1250,11 +1217,12 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( &mut ctx.unifier, type_.custom.unwrap(), ); - let obj_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) { - *obj_id - } else { - unreachable!() - }; + let obj_id = + if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) { + *obj_id + } else { + unreachable!() + }; let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name); let exn_id = ctx.resolver.get_string_id(&exception_name); let exn_id_global = @@ -1303,16 +1271,15 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( // run end_catch before continue/break/return let mut final_proxy_lambda = - |ctx: &mut CodeGenContext<'ctx, 'a>, - target: BasicBlock<'ctx>, - block: BasicBlock<'ctx>| final_proxy(ctx, target, block, final_data.as_mut().unwrap()); - let mut redirect_lambda = |ctx: &mut CodeGenContext<'ctx, 'a>, - target: BasicBlock<'ctx>, - block: BasicBlock<'ctx>| { - ctx.builder.position_at_end(block); - ctx.builder.build_unconditional_branch(target).unwrap(); - ctx.builder.position_at_end(body); - }; + |ctx: &mut CodeGenContext<'ctx, 'a>, target: BasicBlock<'ctx>, block: BasicBlock<'ctx>| { + final_proxy(ctx, target, block, final_data.as_mut().unwrap()) + }; + let mut redirect_lambda = + |ctx: &mut CodeGenContext<'ctx, 'a>, target: BasicBlock<'ctx>, block: BasicBlock<'ctx>| { + ctx.builder.position_at_end(block); + ctx.builder.build_unconditional_branch(target).unwrap(); + ctx.builder.position_at_end(body); + }; let redirect = if has_cleanup { &mut final_proxy_lambda as &mut dyn FnMut(&mut CodeGenContext<'ctx, 'a>, BasicBlock<'ctx>, BasicBlock<'ctx>) @@ -1357,12 +1324,9 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( ctx.builder.position_at_end(dispatcher); unsafe { let zero = ctx.ctx.i32_type().const_zero(); - let exnid_ptr = ctx.builder - .build_gep( - exn.as_basic_value().into_pointer_value(), - &[zero, zero], - "exnidptr", - ) + let exnid_ptr = ctx + .builder + .build_gep(exn.as_basic_value().into_pointer_value(), &[zero, zero], "exnidptr") .unwrap(); Some(ctx.builder.build_load(exnid_ptr, "exnid").unwrap()) } @@ -1388,15 +1352,15 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( post_handlers.push(current); ctx.builder.position_at_end(dispatcher_end); if let Some(exn_type) = exn_type { - let dispatcher_cont = - ctx.ctx.append_basic_block(current_fun, "try.dispatcher_cont"); + let dispatcher_cont = ctx.ctx.append_basic_block(current_fun, "try.dispatcher_cont"); let actual_id = exnid.unwrap().into_int_value(); let expected_id = ctx .builder .build_load(exn_type.into_pointer_value(), "expected_id") .map(BasicValueEnum::into_int_value) .unwrap(); - let result = ctx.builder + let result = ctx + .builder .build_int_compare(IntPredicate::EQ, actual_id, expected_id, "exncheck") .unwrap(); ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont).unwrap(); @@ -1522,11 +1486,9 @@ pub fn gen_return( let func = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); let value = if let Some(v_expr) = value.as_ref() { if let Some(v) = generator.gen_expr(ctx, v_expr).transpose() { - Some( - v.and_then(|v| v.to_basic_value_enum(ctx, generator, v_expr.custom.unwrap()))? - ) + Some(v.and_then(|v| v.to_basic_value_enum(ctx, generator, v_expr.custom.unwrap()))?) } else { - return Ok(()) + return Ok(()); } } else { None @@ -1554,7 +1516,8 @@ pub fn gen_return( generator.bool_to_i1(ctx, ret_val) } else { ret_val - }.into() + } + .into() } else { ret_val } @@ -1592,16 +1555,12 @@ pub fn gen_stmt( } StmtKind::AnnAssign { target, value, .. } => { if let Some(value) = value { - let Some(value) = generator.gen_expr(ctx, value)? else { - return Ok(()) - }; + let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) }; generator.gen_assign(ctx, target, value)?; } } StmtKind::Assign { targets, value, .. } => { - let Some(value) = generator.gen_expr(ctx, value)? else { - return Ok(()) - }; + let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) }; for target in targets { generator.gen_assign(ctx, target, value.clone())?; } @@ -1626,7 +1585,7 @@ pub fn gen_stmt( let exc = if let Some(v) = generator.gen_expr(ctx, exc)? { v.to_basic_value_enum(ctx, generator, exc.custom.unwrap())? } else { - return Ok(()) + return Ok(()); }; gen_raise(generator, ctx, Some(&exc), stmt.location); } else { @@ -1637,14 +1596,16 @@ pub fn gen_stmt( let test = if let Some(v) = generator.gen_expr(ctx, test)? { v.to_basic_value_enum(ctx, generator, test.custom.unwrap())? } else { - return Ok(()) + return Ok(()); }; let err_msg = match msg { - Some(msg) => if let Some(v) = generator.gen_expr(ctx, msg)? { - v.to_basic_value_enum(ctx, generator, msg.custom.unwrap())? - } else { - return Ok(()) - }, + Some(msg) => { + if let Some(v) = generator.gen_expr(ctx, msg)? { + v.to_basic_value_enum(ctx, generator, msg.custom.unwrap())? + } else { + return Ok(()); + } + } None => ctx.gen_string(generator, ""), }; ctx.make_assert_impl( @@ -1656,7 +1617,7 @@ pub fn gen_stmt( stmt.location, ); } - _ => unimplemented!() + _ => unimplemented!(), }; Ok(()) } diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 9e7412e5..22416a1d 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -1,13 +1,14 @@ use crate::{ codegen::{ classes::{ListType, NDArrayType, ProxyType, RangeType}, - concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenerator, CodeGenLLVMOptions, - CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry, + concrete_type::ConcreteTypeStore, + CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, + CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry, }, symbol_resolver::{SymbolResolver, ValueEnum}, toplevel::{ - composer::{ComposerConfig, TopLevelComposer}, DefinitionId, FunInstance, TopLevelContext, - TopLevelDef, + composer::{ComposerConfig, TopLevelComposer}, + DefinitionId, FunInstance, TopLevelContext, TopLevelDef, }, typecheck::{ type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, @@ -17,7 +18,7 @@ use crate::{ use indoc::indoc; use inkwell::{ targets::{InitializationConfig, Target}, - OptimizationLevel + OptimizationLevel, }; use nac3parser::{ ast::{fold::Fold, StrRef}, @@ -70,9 +71,7 @@ impl SymbolResolver for Resolver { .read() .get(&id) .cloned() - .ok_or_else(|| HashSet::from([ - format!("cannot find symbol `{}`", id), - ])) + .ok_or_else(|| HashSet::from([format!("cannot find symbol `{}`", id)])) } fn get_string_id(&self, _: &str) -> i32 { @@ -227,12 +226,7 @@ fn test_primitives() { opt_level: OptimizationLevel::Default, target: CodeGenTargetMachineOptions::from_host_triple(), }; - let (registry, handles) = WorkerRegistry::create_workers( - threads, - top_level, - &llvm_options, - &f - ); + let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); registry.add_task(task); registry.wait_tasks_complete(handles); } @@ -417,12 +411,7 @@ fn test_simple_call() { opt_level: OptimizationLevel::Default, target: CodeGenTargetMachineOptions::from_host_triple(), }; - let (registry, handles) = WorkerRegistry::create_workers( - threads, - top_level, - &llvm_options, - &f - ); + let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); registry.add_task(task); registry.wait_tasks_complete(handles); } diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 5e890e69..298a2329 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -1,18 +1,18 @@ use std::fmt::Debug; +use std::rc::Rc; use std::sync::Arc; use std::{collections::HashMap, collections::HashSet, fmt::Display}; -use std::rc::Rc; use crate::{ codegen::{CodeGenContext, CodeGenerator}, - toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation}, + toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, TypeEnum, Unifier, VarMap}, }, }; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; -use itertools::{chain, Itertools, izip}; +use itertools::{chain, izip, Itertools}; use nac3parser::ast::{Constant, Expr, Location, StrRef}; use parking_lot::RwLock; @@ -39,7 +39,7 @@ impl SymbolValue { constant: &Constant, expected_ty: Type, primitives: &PrimitiveStore, - unifier: &mut Unifier + unifier: &mut Unifier, ) -> Result { match constant { Constant::None => { @@ -62,24 +62,16 @@ impl SymbolValue { } else { Err(format!("Expected {expected_ty:?}, but got str")) } - }, + } Constant::Int(i) => { if unifier.unioned(expected_ty, primitives.int32) { - i32::try_from(*i) - .map(SymbolValue::I32) - .map_err(|e| e.to_string()) + i32::try_from(*i).map(SymbolValue::I32).map_err(|e| e.to_string()) } else if unifier.unioned(expected_ty, primitives.int64) { - i64::try_from(*i) - .map(SymbolValue::I64) - .map_err(|e| e.to_string()) + i64::try_from(*i).map(SymbolValue::I64).map_err(|e| e.to_string()) } else if unifier.unioned(expected_ty, primitives.uint32) { - u32::try_from(*i) - .map(SymbolValue::U32) - .map_err(|e| e.to_string()) + u32::try_from(*i).map(SymbolValue::U32).map_err(|e| e.to_string()) } else if unifier.unioned(expected_ty, primitives.uint64) { - u64::try_from(*i) - .map(SymbolValue::U64) - .map_err(|e| e.to_string()) + u64::try_from(*i).map(SymbolValue::U64).map_err(|e| e.to_string()) } else { Err(format!("Expected {}, but got int", unifier.stringify(expected_ty))) } @@ -87,7 +79,10 @@ impl SymbolValue { Constant::Tuple(t) => { let expected_ty = unifier.get_ty(expected_ty); let TypeEnum::TTuple { ty } = expected_ty.as_ref() else { - return Err(format!("Expected {:?}, but got Tuple", expected_ty.get_type_name())) + return Err(format!( + "Expected {:?}, but got Tuple", + expected_ty.get_type_name() + )); }; assert_eq!(ty.len(), t.len()); @@ -105,7 +100,7 @@ impl SymbolValue { } else { Err(format!("Expected {expected_ty:?}, but got float")) } - }, + } _ => Err(format!("Unsupported value type {constant:?}")), } } @@ -113,9 +108,7 @@ impl SymbolValue { /// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value. /// /// * `constant` - The constant to create the value from. - pub fn from_constant_inferred( - constant: &Constant, - ) -> Result { + pub fn from_constant_inferred(constant: &Constant) -> Result { match constant { Constant::None => Ok(SymbolValue::OptionNone), Constant::Bool(b) => Ok(SymbolValue::Bool(*b)), @@ -123,13 +116,19 @@ impl SymbolValue { Constant::Int(i) => { let i = *i; if i >= 0 { - i32::try_from(i).map(SymbolValue::I32) + i32::try_from(i) + .map(SymbolValue::I32) .or_else(|_| i64::try_from(i).map(SymbolValue::I64)) - .map_err(|_| format!("Literal cannot be expressed as any integral type: {i}")) + .map_err(|_| { + format!("Literal cannot be expressed as any integral type: {i}") + }) } else { - u32::try_from(i).map(SymbolValue::U32) + u32::try_from(i) + .map(SymbolValue::U32) .or_else(|_| u64::try_from(i).map(SymbolValue::U64)) - .map_err(|_| format!("Literal cannot be expressed as any integral type: {i}")) + .map_err(|_| { + format!("Literal cannot be expressed as any integral type: {i}") + }) } } Constant::Tuple(t) => { @@ -155,20 +154,19 @@ impl SymbolValue { SymbolValue::Double(_) => primitives.float, SymbolValue::Bool(_) => primitives.bool, SymbolValue::Tuple(vs) => { - let vs_tys = vs - .iter() - .map(|v| v.get_type(primitives, unifier)) - .collect::>(); - unifier.add_ty(TypeEnum::TTuple { - ty: vs_tys, - }) + let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::>(); + unifier.add_ty(TypeEnum::TTuple { ty: vs_tys }) } SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, } } /// Returns the [`TypeAnnotation`] representing the data type of this value. - pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation { + pub fn get_type_annotation( + &self, + primitives: &PrimitiveStore, + unifier: &mut Unifier, + ) -> TypeAnnotation { match self { SymbolValue::Bool(..) | SymbolValue::Double(..) @@ -199,7 +197,11 @@ impl SymbolValue { } /// Returns the [`TypeEnum`] representing the data type of this value. - pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc { + pub fn get_type_enum( + &self, + primitives: &PrimitiveStore, + unifier: &mut Unifier, + ) -> Rc { let ty = self.get_type(primitives, unifier); unifier.get_ty(ty) } @@ -332,7 +334,6 @@ impl<'ctx> From> for ValueEnum<'ctx> { } impl<'ctx> ValueEnum<'ctx> { - /// Converts this [`ValueEnum`] to a [`BasicValueEnum`]. pub fn to_basic_value_enum<'a>( self, @@ -374,7 +375,7 @@ pub trait SymbolResolver { &self, _unifier: &mut Unifier, _top_level_defs: &[Arc>], - _primitives: &PrimitiveStore + _primitives: &PrimitiveStore, ) -> Result<(), String> { Ok(()) } @@ -443,40 +444,29 @@ pub fn parse_type_annotation( let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if !type_vars.is_empty() { - return Err(HashSet::from([ - format!( - "Unexpected number of type parameters: expected {} but got 0", - type_vars.len() - ), - ])) + return Err(HashSet::from([format!( + "Unexpected number of type parameters: expected {} but got 0", + type_vars.len() + )])); } let fields = chain( fields.iter().map(|(k, v, m)| (*k, (*v, *m))), methods.iter().map(|(k, v, _)| (*k, (*v, false))), ) - .collect(); - Ok(unifier.add_ty(TypeEnum::TObj { - obj_id, - fields, - params: VarMap::default(), - })) + .collect(); + Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() })) } else { - Err(HashSet::from([ - format!("Cannot use function name as type at {loc}"), - ])) + Err(HashSet::from([format!("Cannot use function name as type at {loc}")])) } } else { - let ty = resolver - .get_symbol_type(unifier, top_level_defs, primitives, *id) - .map_err(|e| HashSet::from([ - format!("Unknown type annotation at {loc}: {e}"), - ]))?; + let ty = + resolver.get_symbol_type(unifier, top_level_defs, primitives, *id).map_err( + |e| HashSet::from([format!("Unknown type annotation at {loc}: {e}")]), + )?; if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { Ok(ty) } else { - Err(HashSet::from([ - format!("Unknown type annotation {id} at {loc}"), - ])) + Err(HashSet::from([format!("Unknown type annotation {id} at {loc}")])) } } } @@ -499,9 +489,7 @@ pub fn parse_type_annotation( .collect::, _>>()?; Ok(unifier.add_ty(TypeEnum::TTuple { ty })) } else { - Err(HashSet::from([ - "Expected multiple elements for tuple".into() - ])) + Err(HashSet::from(["Expected multiple elements for tuple".into()])) } } else if *id == literal_id { let mut parse_literal = |elt: &Expr| { @@ -509,19 +497,21 @@ pub fn parse_type_annotation( let ty_enum = &*unifier.get_ty_immutable(ty); match ty_enum { TypeEnum::TLiteral { values, .. } => Ok(values.clone()), - _ => Err(HashSet::from([ - format!("Expected literal in type argument for Literal at {}", elt.location), - ])) + _ => Err(HashSet::from([format!( + "Expected literal in type argument for Literal at {}", + elt.location + )])), } }; let values = if let Tuple { elts, .. } = &slice.node { - elts.iter() - .map(&mut parse_literal) - .collect::, _>>()? + elts.iter().map(&mut parse_literal).collect::, _>>()? } else { vec![parse_literal(slice)?] - }.into_iter().flatten().collect_vec(); + } + .into_iter() + .flatten() + .collect_vec(); Ok(unifier.get_fresh_literal(values, Some(slice.location))) } else { @@ -539,13 +529,11 @@ pub fn parse_type_annotation( let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if types.len() != type_vars.len() { - return Err(HashSet::from([ - format!( - "Unexpected number of type parameters: expected {} but got {}", - type_vars.len(), - types.len() - ), - ])) + return Err(HashSet::from([format!( + "Unexpected number of type parameters: expected {} but got {}", + type_vars.len(), + types.len() + )])); } let mut subst = VarMap::new(); for (var, ty) in izip!(type_vars.iter(), types.iter()) { @@ -569,9 +557,7 @@ pub fn parse_type_annotation( })); Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst })) } else { - Err(HashSet::from([ - "Cannot use function name as type".into(), - ])) + Err(HashSet::from(["Cannot use function name as type".into()])) } } }; @@ -582,17 +568,13 @@ pub fn parse_type_annotation( if let Name { id, .. } = &value.node { subscript_name_handle(id, slice, unifier) } else { - Err(HashSet::from([ - format!("unsupported type expression at {}", expr.location), - ])) + Err(HashSet::from([format!("unsupported type expression at {}", expr.location)])) } } Constant { value, .. } => SymbolValue::from_constant_inferred(value) .map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location))) .map_err(|err| HashSet::from([err])), - _ => Err(HashSet::from([ - format!("unsupported type expression at {}", expr.location), - ])), + _ => Err(HashSet::from([format!("unsupported type expression at {}", expr.location)])), } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 5e1c9868..2d5e68ce 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -3,9 +3,9 @@ use std::iter::once; use indexmap::IndexMap; use inkwell::{ attributes::{Attribute, AttributeLoc}, - IntPredicate, types::{BasicMetadataTypeEnum, BasicType}, - values::{BasicMetadataValueEnum, BasicValue, CallSiteValue} + values::{BasicMetadataValueEnum, BasicValue, CallSiteValue}, + IntPredicate, }; use itertools::Either; @@ -13,12 +13,7 @@ use crate::{ codegen::{ builtin_fns, classes::{ - ArrayLikeValue, - NDArrayValue, - ProxyType, - ProxyValue, - RangeValue, - RangeType, + ArrayLikeValue, NDArrayValue, ProxyType, ProxyValue, RangeType, RangeValue, TypedArrayLikeAccessor, }, expr::destructure_range, @@ -27,10 +22,7 @@ use crate::{ stmt::exn_constructor, }, symbol_resolver::SymbolValue, - toplevel::{ - helper::PRIMITIVE_DEF_IDS, - numpy::make_ndarray_ty, - }, + toplevel::{helper::PRIMITIVE_DEF_IDS, numpy::make_ndarray_ty}, typecheck::typedef::VarMap, }; @@ -43,8 +35,8 @@ pub fn get_exn_constructor( class_id: usize, cons_id: usize, unifier: &mut Unifier, - primitives: &PrimitiveStore -)-> (TopLevelDef, TopLevelDef, Type, Type) { + primitives: &PrimitiveStore, +) -> (TopLevelDef, TopLevelDef, Type, Type) { let int32 = primitives.int32; let int64 = primitives.int64; let string = primitives.str; @@ -126,11 +118,10 @@ fn create_fn_by_codegen( name: name.into(), simple_name: name.into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty, vars: var_map.clone(), })), @@ -158,9 +149,7 @@ fn create_fn_by_intrinsic( params: &[(Type, &'static str)], intrinsic_fn: &'static str, ) -> Arc> { - let param_tys = params.iter() - .map(|p| p.0) - .collect_vec(); + let param_tys = params.iter().map(|p| p.0).collect_vec(); create_fn_by_codegen( unifier, @@ -171,21 +160,22 @@ fn create_fn_by_intrinsic( Box::new(move |ctx, _, fun, args, generator| { let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); - assert!(param_tys.iter().zip(&args_ty) + assert!(param_tys + .iter() + .zip(&args_ty) .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); - let args_val = args_ty.iter().zip_eq(args.iter()) - .map(|(ty, arg)| { - arg.1.clone() - .to_basic_value_enum(ctx, generator, *ty) - .unwrap() - }) + let args_val = args_ty + .iter() + .zip_eq(args.iter()) + .map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap()) .map_into::() .collect_vec(); let intrinsic_fn = ctx.module.get_function(intrinsic_fn).unwrap_or_else(|| { let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let param_llvm_ty = param_tys.iter() + let param_llvm_ty = param_tys + .iter() .map(|p| ctx.get_llvm_abi_type(generator, *p)) .map_into::() .collect_vec(); @@ -194,7 +184,8 @@ fn create_fn_by_intrinsic( ctx.module.add_function(intrinsic_fn, fn_type, None) }); - let val = ctx.builder + let val = ctx + .builder .build_call(intrinsic_fn, args_val.as_slice(), name) .map(CallSiteValue::try_as_basic_value) .map(Either::unwrap_left) @@ -223,9 +214,7 @@ fn create_fn_by_extern( extern_fn: &'static str, attrs: &'static [&str], ) -> Arc> { - let param_tys = params.iter() - .map(|p| p.0) - .collect_vec(); + let param_tys = params.iter().map(|p| p.0).collect_vec(); create_fn_by_codegen( unifier, @@ -236,47 +225,49 @@ fn create_fn_by_extern( Box::new(move |ctx, _, fun, args, generator| { let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); - assert!(param_tys.iter().zip(&args_ty) + assert!(param_tys + .iter() + .zip(&args_ty) .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); - let args_val = args_ty.iter().zip_eq(args.iter()) - .map(|(ty, arg)| { - arg.1.clone() - .to_basic_value_enum(ctx, generator, *ty) - .unwrap() - }) + let args_val = args_ty + .iter() + .zip_eq(args.iter()) + .map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap()) .map_into::() .collect_vec(); - let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| { - let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let param_llvm_ty = param_tys.iter() - .map(|p| ctx.get_llvm_abi_type(generator, *p)) - .map_into::() - .collect_vec(); - let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); - let func = ctx.module.add_function(extern_fn, fn_type, None); + let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| { + let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); + let param_llvm_ty = param_tys + .iter() + .map(|p| ctx.get_llvm_abi_type(generator, *p)) + .map_into::() + .collect_vec(); + let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); + let func = ctx.module.add_function(extern_fn, fn_type, None); + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0), + ); + + for attr in attrs { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); + } - for attr in attrs { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) - ); - } + func + }); - func - }); - - let val = ctx.builder - .build_call(intrinsic_fn, &args_val, name) - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(val.into()) + let val = ctx + .builder + .build_call(intrinsic_fn, &args_val, name) + .map(CallSiteValue::try_as_basic_value) + .map(Either::unwrap_left) + .unwrap(); + Ok(val.into()) }), ) } @@ -302,10 +293,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built 32 => SymbolValue::U32(2u32), _ => unreachable!(), }; - let ndims = unifier.add_ty(TypeEnum::TLiteral { - values: vec![value], - loc: None, - }); + let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None }); make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) }; @@ -315,39 +303,27 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built Some("N".into()), None, ); - let num_var_map: VarMap = vec![ - (num_ty.1, num_ty.0), - ].into_iter().collect(); + let num_var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); - let new_type_or_ndarray_ty = |unifier: &mut Unifier, primitives: &PrimitiveStore, scalar_ty: Type| { - let ndarray = make_ndarray_ty(unifier, primitives, Some(scalar_ty), None); + let new_type_or_ndarray_ty = + |unifier: &mut Unifier, primitives: &PrimitiveStore, scalar_ty: Type| { + let ndarray = make_ndarray_ty(unifier, primitives, Some(scalar_ty), None); - unifier.get_fresh_var_with_range( - &[scalar_ty, ndarray], - Some("T".into()), - None, - ) - }; + unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None) + }; let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None); - let float_or_ndarray_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let float_or_ndarray_var_map: VarMap = vec![ - (float_or_ndarray_ty.1, float_or_ndarray_ty.0), - ].into_iter().collect(); + let float_or_ndarray_ty = + unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let float_or_ndarray_var_map: VarMap = + vec![(float_or_ndarray_ty.1, float_or_ndarray_ty.0)].into_iter().collect(); - let num_or_ndarray_ty = unifier.get_fresh_var_with_range( - &[num_ty.0, ndarray_num_ty], - Some("T".into()), - None, - ); - let num_or_ndarray_var_map: VarMap = vec![ - (num_ty.1, num_ty.0), - (num_or_ndarray_ty.1, num_or_ndarray_ty.0), - ].into_iter().collect(); + let num_or_ndarray_ty = + unifier.get_fresh_var_with_range(&[num_ty.0, ndarray_num_ty], Some("T".into()), None); + let num_or_ndarray_var_map: VarMap = + vec![(num_ty.1, num_ty.0), (num_or_ndarray_ty.1, num_or_ndarray_ty.0)] + .into_iter() + .collect(); let exception_fields = vec![ ("__name__".into(), int32, true), @@ -364,9 +340,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built // for Option, is_some and is_none share the same type: () -> bool, // and they are methods under the same class `Option` let (is_some_ty, unwrap_ty, (option_ty_var, option_ty_var_id)) = - if let TypeEnum::TObj { fields, params, .. } = - unifier.get_ty(primitives.option).as_ref() - { + if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(primitives.option).as_ref() { ( *fields.get(&"is_some".into()).unwrap(), *fields.get(&"unwrap".into()).unwrap(), @@ -376,24 +350,16 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built unreachable!() }; - let TypeEnum::TObj { - fields: ndarray_fields, - params: ndarray_params, - .. - } = &*unifier.get_ty(primitives.ndarray) else { + let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } = + &*unifier.get_ty(primitives.ndarray) + else { unreachable!() }; - let (ndarray_dtype_ty, ndarray_dtype_var_id) = ndarray_params - .iter() - .next() - .map(|(var_id, ty)| (*ty, *var_id)) - .unwrap(); - let (ndarray_ndims_ty, ndarray_ndims_var_id) = ndarray_params - .iter() - .nth(1) - .map(|(var_id, ty)| (*ty, *var_id)) - .unwrap(); + let (ndarray_dtype_ty, ndarray_dtype_var_id) = + ndarray_params.iter().next().map(|(var_id, ty)| (*ty, *var_id)).unwrap(); + let (ndarray_ndims_ty, ndarray_ndims_var_id) = + ndarray_params.iter().nth(1).map(|(var_id, ty)| (*ty, *var_id)).unwrap(); let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap(); let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); @@ -503,20 +469,13 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built codegen_callback: Some(Arc::new(GenCall::new(Box::new( |ctx, obj, _, _, generator| { let expect_ty = obj.clone().unwrap().0; - let obj_val = obj.unwrap().1.clone().to_basic_value_enum( - ctx, - generator, - expect_ty, - )?; + let obj_val = + obj.unwrap().1.clone().to_basic_value_enum(ctx, generator, expect_ty)?; let BasicValueEnum::PointerValue(ptr) = obj_val else { unreachable!("option must be ptr") }; - Ok(Some(ctx.builder - .build_is_not_null(ptr, "is_some") - .map(Into::into) - .unwrap() - )) + Ok(Some(ctx.builder.build_is_not_null(ptr, "is_some").map(Into::into).unwrap())) }, )))), loc: None, @@ -532,20 +491,13 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built codegen_callback: Some(Arc::new(GenCall::new(Box::new( |ctx, obj, _, _, generator| { let expect_ty = obj.clone().unwrap().0; - let obj_val = obj.unwrap().1.clone().to_basic_value_enum( - ctx, - generator, - expect_ty, - )?; + let obj_val = + obj.unwrap().1.clone().to_basic_value_enum(ctx, generator, expect_ty)?; let BasicValueEnum::PointerValue(ptr) = obj_val else { unreachable!("option must be ptr") }; - Ok(Some(ctx.builder - .build_is_null(ptr, "is_none") - .map(Into::into) - .unwrap() - )) + Ok(Some(ctx.builder.build_is_null(ptr, "is_none").map(Into::into).unwrap())) }, )))), loc: None, @@ -558,9 +510,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::create_dummy( - String::from("handled in gen_expr"), - ))), + codegen_callback: Some(Arc::new(GenCall::create_dummy(String::from( + "handled in gen_expr", + )))), loc: None, })), Arc::new(RwLock::new(TopLevelDef::Class { @@ -613,7 +565,11 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "int32".into(), simple_name: "int32".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + args: vec![FuncArg { + name: "n".into(), + ty: num_or_ndarray_ty.0, + default_value: None, + }], ret: num_or_ndarray_ty.0, vars: num_or_ndarray_var_map.clone(), })), @@ -635,7 +591,11 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "int64".into(), simple_name: "int64".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + args: vec![FuncArg { + name: "n".into(), + ty: num_or_ndarray_ty.0, + default_value: None, + }], ret: num_or_ndarray_ty.0, vars: num_or_ndarray_var_map.clone(), })), @@ -657,7 +617,11 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "uint32".into(), simple_name: "uint32".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + args: vec![FuncArg { + name: "n".into(), + ty: num_or_ndarray_ty.0, + default_value: None, + }], ret: num_or_ndarray_ty.0, vars: num_or_ndarray_var_map.clone(), })), @@ -679,7 +643,11 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "uint64".into(), simple_name: "uint64".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + args: vec![FuncArg { + name: "n".into(), + ty: num_or_ndarray_ty.0, + default_value: None, + }], ret: num_or_ndarray_ty.0, vars: num_or_ndarray_var_map.clone(), })), @@ -701,7 +669,11 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "float".into(), simple_name: "float".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + args: vec![FuncArg { + name: "n".into(), + ty: num_or_ndarray_ty.0, + default_value: None, + }], ret: num_or_ndarray_ty.0, vars: num_or_ndarray_var_map.clone(), })), @@ -797,7 +769,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { name: "object".into(), ty: tv.0, default_value: None }, - FuncArg { + FuncArg { name: "copy".into(), ty: boolean, default_value: Some(SymbolValue::Bool(true)), @@ -834,9 +806,13 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built FuncArg { name: "M".into(), ty: int32, - default_value: Some(SymbolValue::OptionNone) + default_value: Some(SymbolValue::OptionNone), + }, + FuncArg { + name: "k".into(), + ty: int32, + default_value: Some(SymbolValue::I32(0)), }, - FuncArg { name: "k".into(), ty: int32, default_value: Some(SymbolValue::I32(0)) }, ], ret: ndarray_float_2d, vars: VarMap::default(), @@ -865,80 +841,70 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built }), ), { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + let common_ndim = + unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); + let ndarray_int32 = + make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); + let ndarray_float = + make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int32, ndarray_int32], - Some("R".into()), - None, - ); + let p0_ty = + unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let ret_ty = + unifier.get_fresh_var_with_range(&[int32, ndarray_int32], Some("R".into()), None); create_fn_by_codegen( unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), "round", ret_ty.0, &[(p0_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) + Ok(Some(builtin_fns::call_round( + generator, + ctx, + (arg_ty, arg), + ctx.primitives.int32, + )?)) }), ) }, { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + let common_ndim = + unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); + let ndarray_int64 = + make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); + let ndarray_float = + make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int64, ndarray_int64], - Some("R".into()), - None, - ); + let p0_ty = + unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let ret_ty = + unifier.get_fresh_var_with_range(&[int64, ndarray_int64], Some("R".into()), None); create_fn_by_codegen( unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), "round64", ret_ty.0, &[(p0_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) + Ok(Some(builtin_fns::call_round( + generator, + ctx, + (arg_ty, arg), + ctx.primitives.int64, + )?)) }), ) }, @@ -950,8 +916,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?)) }), @@ -990,52 +955,54 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let ty_i32 = ctx.primitives.int32; for (i, arg) in args.iter().enumerate() { if arg.0 == Some("start".into()) { - start = Some(arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value() + start = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), ); } else if arg.0 == Some("stop".into()) { stop = Some( arg.1 .clone() .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value() + .into_int_value(), ); } else if arg.0 == Some("step".into()) { step = Some( arg.1 .clone() .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value() + .into_int_value(), ); } else if i == 0 { start = Some( arg.1 .clone() .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value() + .into_int_value(), ); } else if i == 1 { stop = Some( arg.1 .clone() .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value() + .into_int_value(), ); } else if i == 2 { step = Some( arg.1 .clone() .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value() + .into_int_value(), ); } } let step = match step { Some(step) => { // assert step != 0, throw exception if not - let not_zero = ctx.builder + let not_zero = ctx + .builder .build_int_compare( IntPredicate::NE, step, @@ -1095,7 +1062,11 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "bool".into(), simple_name: "bool".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + args: vec![FuncArg { + name: "n".into(), + ty: num_or_ndarray_ty.0, + default_value: None, + }], ret: num_or_ndarray_ty.0, vars: num_or_ndarray_var_map.clone(), })), @@ -1114,80 +1085,70 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built loc: None, })), { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + let common_ndim = + unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); + let ndarray_int32 = + make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); + let ndarray_float = + make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int32, ndarray_int32], - Some("R".into()), - None, - ); + let p0_ty = + unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let ret_ty = + unifier.get_fresh_var_with_range(&[int32, ndarray_int32], Some("R".into()), None); create_fn_by_codegen( unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), "floor", ret_ty.0, &[(p0_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) + Ok(Some(builtin_fns::call_floor( + generator, + ctx, + (arg_ty, arg), + ctx.primitives.int32, + )?)) }), ) }, { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + let common_ndim = + unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); + let ndarray_int64 = + make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); + let ndarray_float = + make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int64, ndarray_int64], - Some("R".into()), - None, - ); + let p0_ty = + unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let ret_ty = + unifier.get_fresh_var_with_range(&[int64, ndarray_int64], Some("R".into()), None); create_fn_by_codegen( unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), "floor64", ret_ty.0, &[(p0_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) + Ok(Some(builtin_fns::call_floor( + generator, + ctx, + (arg_ty, arg), + ctx.primitives.int64, + )?)) }), ) }, @@ -1199,87 +1160,81 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) + Ok(Some(builtin_fns::call_floor( + generator, + ctx, + (arg_ty, arg), + ctx.primitives.float, + )?)) }), ), { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + let common_ndim = + unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); + let ndarray_int32 = + make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); + let ndarray_float = + make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int32, ndarray_int32], - Some("R".into()), - None, - ); + let p0_ty = + unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let ret_ty = + unifier.get_fresh_var_with_range(&[int32, ndarray_int32], Some("R".into()), None); create_fn_by_codegen( unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), "ceil", ret_ty.0, &[(p0_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) + Ok(Some(builtin_fns::call_ceil( + generator, + ctx, + (arg_ty, arg), + ctx.primitives.int32, + )?)) }), ) }, { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + let common_ndim = + unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); + let ndarray_int64 = + make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); + let ndarray_float = + make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int64, ndarray_int64], - Some("R".into()), - None, - ); + let p0_ty = + unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let ret_ty = + unifier.get_fresh_var_with_range(&[int64, ndarray_int64], Some("R".into()), None); create_fn_by_codegen( unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), "ceil64", ret_ty.0, &[(p0_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) + Ok(Some(builtin_fns::call_ceil( + generator, + ctx, + (arg_ty, arg), + ctx.primitives.int64, + )?)) }), ) }, @@ -1291,22 +1246,22 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) + Ok(Some(builtin_fns::call_ceil( + generator, + ctx, + (arg_ty, arg), + ctx.primitives.float, + )?)) }), ), Arc::new(RwLock::new({ let tvar = unifier.get_fresh_var(Some("L".into()), None); let list = unifier.add_ty(TypeEnum::TList { ty: tvar.0 }); - let ndims = unifier.get_fresh_const_generic_var(primitives.uint64, Some("N".into()), None); - let ndarray = make_ndarray_ty( - unifier, - primitives, - Some(tvar.0), - Some(ndims.0), - ); + let ndims = + unifier.get_fresh_const_generic_var(primitives.uint64, Some("N".into()), None); + let ndarray = make_ndarray_ty(unifier, primitives, Some(tvar.0), Some(ndims.0)); let arg_ty = unifier.get_fresh_var_with_range( &[list, ndarray, primitives.range], @@ -1319,9 +1274,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }], ret: int32, - vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)] - .into_iter() - .collect(), + vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -1333,9 +1286,13 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); + let arg = + RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); - Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into()) + Some( + calculate_len_for_slice_range(generator, ctx, start, end, step) + .into(), + ) } else { match &*ctx.unifier.get_ty_immutable(arg_ty) { TypeEnum::TList { .. } => { @@ -1351,32 +1308,37 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built if len.get_type().get_bit_width() == 32 { Some(len.into()) } else { - Some(ctx.builder - .build_int_truncate(len, int32, "len2i32") - .map(Into::into) - .unwrap() + Some( + ctx.builder + .build_int_truncate(len, int32, "len2i32") + .map(Into::into) + .unwrap(), ) } } - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } + if *obj_id == PRIMITIVE_DEF_IDS.ndarray => + { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let arg = NDArrayValue::from_ptr_val( arg.into_pointer_value(), llvm_usize, - None + None, ); let ndims = arg.dim_sizes().size(ctx, generator); ctx.make_assert( generator, - ctx.builder.build_int_compare( - IntPredicate::NE, - ndims, - llvm_usize.const_zero(), - "", - ).unwrap(), + ctx.builder + .build_int_compare( + IntPredicate::NE, + ndims, + llvm_usize.const_zero(), + "", + ) + .unwrap(), "0:TypeError", "len() of unsized object", [None, None, None], @@ -1395,11 +1357,12 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built if len.get_type().get_bit_width() == 32 { Some(len.into()) } else { - Some(ctx.builder - .build_int_truncate(len, llvm_i32, "len") - .map(Into::into) - .unwrap() - ) + Some( + ctx.builder + .build_int_truncate(len, llvm_i32, "len") + .map(Into::into) + .unwrap(), + ) } } _ => unreachable!(), @@ -1439,7 +1402,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built })), { let ret_ty = unifier.get_fresh_var(Some("R".into()), None); - let var_map = num_or_ndarray_var_map.clone() + let var_map = num_or_ndarray_var_map + .clone() .into_iter() .chain(once((ret_ty.1, ret_ty.0))) .collect::>(); @@ -1452,8 +1416,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "a")], Box::new(|ctx, _, fun, args, generator| { let a_ty = fun.0.args[0].ty; - let a = args[0].1.clone() - .to_basic_value_enum(ctx, generator, a_ty)?; + let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; Ok(Some(builtin_fns::call_numpy_min(generator, ctx, (a_ty, a))?)) }), @@ -1469,30 +1432,36 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "np_minimum".into(), simple_name: "np_minimum".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), var_id: vec![x1_ty.1, x2_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x2_ty = fun.0.args[1].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_minimum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - })))), + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x2_ty = fun.0.args[1].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_numpy_minimum( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }, + )))), loc: None, })) }, @@ -1525,7 +1494,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built })), { let ret_ty = unifier.get_fresh_var(Some("R".into()), None); - let var_map = num_or_ndarray_var_map.clone() + let var_map = num_or_ndarray_var_map + .clone() .into_iter() .chain(once((ret_ty.1, ret_ty.0))) .collect::>(); @@ -1538,8 +1508,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "a")], Box::new(|ctx, _, fun, args, generator| { let a_ty = fun.0.args[0].ty; - let a = args[0].1.clone() - .to_basic_value_enum(ctx, generator, a_ty)?; + let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; Ok(Some(builtin_fns::call_numpy_max(generator, ctx, (a_ty, a))?)) }), @@ -1555,30 +1524,36 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "np_maximum".into(), simple_name: "np_maximum".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), var_id: vec![x1_ty.1, x2_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x2_ty = fun.0.args[1].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x2_ty = fun.0.args[1].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_maximum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - })))), + Ok(Some(builtin_fns::call_numpy_maximum( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }, + )))), loc: None, })) }, @@ -1586,7 +1561,11 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "abs".into(), simple_name: "abs".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + args: vec![FuncArg { + name: "n".into(), + ty: num_or_ndarray_ty.0, + default_value: None, + }], ret: num_or_ndarray_ty.0, vars: num_or_ndarray_var_map.clone(), })), @@ -1612,8 +1591,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_isnan(generator, ctx, (x_ty, x_val))?)) }), @@ -1626,8 +1604,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_isinf(generator, ctx, (x_ty, x_val))?)) }), @@ -1640,8 +1617,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_sin(generator, ctx, (x_ty, x_val))?)) }), @@ -1654,8 +1630,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_cos(generator, ctx, (x_ty, x_val))?)) }), @@ -1668,8 +1643,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_exp(generator, ctx, (x_ty, x_val))?)) }), @@ -1682,8 +1656,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_exp2(generator, ctx, (x_ty, x_val))?)) }), @@ -1696,8 +1669,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_log(generator, ctx, (x_ty, x_val))?)) }), @@ -1710,8 +1682,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_log10(generator, ctx, (x_ty, x_val))?)) }), @@ -1724,8 +1695,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_log2(generator, ctx, (x_ty, x_val))?)) }), @@ -1738,8 +1708,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_fabs(generator, ctx, (x_ty, x_val))?)) }), @@ -1752,8 +1721,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_sqrt(generator, ctx, (x_ty, x_val))?)) }), @@ -1766,8 +1734,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_rint(generator, ctx, (x_ty, x_val))?)) }), @@ -1780,8 +1747,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_tan(generator, ctx, (x_ty, x_val))?)) }), @@ -1794,8 +1760,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_arcsin(generator, ctx, (x_ty, x_val))?)) }), @@ -1808,8 +1773,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_arccos(generator, ctx, (x_ty, x_val))?)) }), @@ -1822,8 +1786,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_arctan(generator, ctx, (x_ty, x_val))?)) }), @@ -1836,8 +1799,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_sinh(generator, ctx, (x_ty, x_val))?)) }), @@ -1850,8 +1812,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_cosh(generator, ctx, (x_ty, x_val))?)) }), @@ -1864,8 +1825,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_tanh(generator, ctx, (x_ty, x_val))?)) }), @@ -1878,8 +1838,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_arcsinh(generator, ctx, (x_ty, x_val))?)) }), @@ -1892,8 +1851,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_arccosh(generator, ctx, (x_ty, x_val))?)) }), @@ -1906,8 +1864,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_arctanh(generator, ctx, (x_ty, x_val))?)) }), @@ -1920,8 +1877,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_expm1(generator, ctx, (x_ty, x_val))?)) }), @@ -1934,8 +1890,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_numpy_cbrt(generator, ctx, (x_ty, x_val))?)) }), @@ -1948,8 +1903,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "z")], Box::new(|ctx, _, fun, args, generator| { let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)?; + let z_val = args[0].1.clone().to_basic_value_enum(ctx, generator, z_ty)?; Ok(Some(builtin_fns::call_scipy_special_erf(generator, ctx, (z_ty, z_val))?)) }), @@ -1962,8 +1916,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)?; + let z_val = args[0].1.clone().to_basic_value_enum(ctx, generator, z_ty)?; Ok(Some(builtin_fns::call_scipy_special_erfc(generator, ctx, (z_ty, z_val))?)) }), @@ -1976,8 +1929,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "z")], Box::new(|ctx, _, fun, args, generator| { let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)?; + let z_val = args[0].1.clone().to_basic_value_enum(ctx, generator, z_ty)?; Ok(Some(builtin_fns::call_scipy_special_gamma(generator, ctx, (z_ty, z_val))?)) }), @@ -1990,8 +1942,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_scipy_special_gammaln(generator, ctx, (x_ty, x_val))?)) }), @@ -2004,8 +1955,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)?; + let z_val = args[0].1.clone().to_basic_value_enum(ctx, generator, z_ty)?; Ok(Some(builtin_fns::call_scipy_special_j0(generator, ctx, (z_ty, z_val))?)) }), @@ -2018,8 +1968,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; Ok(Some(builtin_fns::call_scipy_special_j1(generator, ctx, (x_ty, x_val))?)) }), @@ -2035,37 +1984,36 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "np_arctan2".into(), simple_name: "np_arctan2".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), var_id: vec![ret_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_arctan2( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_numpy_arctan2( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }, + )))), loc: None, })) }, @@ -2079,37 +2027,36 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "np_copysign".into(), simple_name: "np_copysign".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), var_id: vec![ret_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_copysign( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), + Ok(Some(builtin_fns::call_numpy_copysign( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }, + )))), loc: None, })) }, @@ -2123,37 +2070,36 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "np_fmax".into(), simple_name: "np_fmax".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), var_id: vec![x1_ty.1, x2_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_fmax( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), + Ok(Some(builtin_fns::call_numpy_fmax( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }, + )))), loc: None, })) }, @@ -2167,37 +2113,36 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "np_fmin".into(), simple_name: "np_fmin".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), var_id: vec![x1_ty.1, x2_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_fmin( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), + Ok(Some(builtin_fns::call_numpy_fmin( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }, + )))), loc: None, })) }, @@ -2211,37 +2156,36 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "np_ldexp".into(), simple_name: "np_ldexp".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), var_id: vec![x1_ty.1, x2_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_ldexp( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), + Ok(Some(builtin_fns::call_numpy_ldexp( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }, + )))), loc: None, })) }, @@ -2255,37 +2199,36 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "np_hypot".into(), simple_name: "np_hypot".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), var_id: vec![x1_ty.1, x2_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_hypot( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), + Ok(Some(builtin_fns::call_numpy_hypot( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }, + )))), loc: None, })) }, @@ -2299,37 +2242,36 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "np_nextafter".into(), simple_name: "np_nextafter".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), var_id: vec![x1_ty.1, x2_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_nextafter( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), + Ok(Some(builtin_fns::call_numpy_nextafter( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }, + )))), loc: None, })) }, @@ -2349,7 +2291,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let alloca = generator.gen_var_alloc(ctx, arg_val.get_type(), Some("alloca_some")).unwrap(); + let alloca = generator + .gen_var_alloc(ctx, arg_val.get_type(), Some("alloca_some")) + .unwrap(); ctx.builder.build_store(alloca, arg_val).unwrap(); Ok(Some(alloca.into())) }, @@ -2358,8 +2302,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built })), ]; - let ast_list: Vec>> = - (0..top_level_def_list.len()).map(|_| None).collect(); + let ast_list: Vec>> = (0..top_level_def_list.len()).map(|_| None).collect(); izip!(top_level_def_list, ast_list).collect_vec() } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 1b149ce7..dc3ab12c 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -82,7 +82,8 @@ impl TopLevelComposer { let mut builtin_id = HashMap::default(); let mut builtin_ty = HashMap::default(); - let builtin_name_list = definition_ast_list.iter() + let builtin_name_list = definition_ast_list + .iter() .map(|def_ast| match *def_ast.0.read() { TopLevelDef::Class { name, .. } => name.to_string(), TopLevelDef::Function { simple_name, .. } => simple_name.to_string(), @@ -93,19 +94,24 @@ impl TopLevelComposer { let name = (**name).into(); let def = definition_ast_list[id].0.read(); if let TopLevelDef::Function { name: func_name, simple_name, signature, .. } = &*def { - assert_eq!(name, *simple_name, "Simple name of builtin function should match builtin name list"); + assert_eq!( + name, *simple_name, + "Simple name of builtin function should match builtin name list" + ); // Do not add member functions into the list of builtin IDs; // Here we assume that all builtin top-level functions have the same name and simple // name, and all member functions have something prefixed to its name if *func_name != simple_name.to_string() { - continue + continue; } builtin_ty.insert(name, *signature); builtin_id.insert(name, DefinitionId(id)); - } else if let TopLevelDef::Class { name, constructor, object_id, .. } = &*def - { - assert_eq!(id, object_id.0, "Object id of class '{name}' should match its index in builtin name list"); + } else if let TopLevelDef::Class { name, constructor, object_id, .. } = &*def { + assert_eq!( + id, object_id.0, + "Object id of class '{name}' should match its index in builtin name list" + ); if let Some(constructor) = constructor { builtin_ty.insert(*name, *constructor); } @@ -384,9 +390,9 @@ impl TopLevelComposer { let mut class_def = class_def.write(); let (class_bases_ast, class_def_type_vars, class_resolver) = { if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def { - let Some(ast::Located { - node: ast::StmtKind::ClassDef { bases, .. }, .. - }) = class_ast else { + let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = + class_ast + else { unreachable!() }; @@ -415,12 +421,10 @@ impl TopLevelComposer { } => { if is_generic { - return Err(HashSet::from([ - format!( - "only single Generic[...] is allowed (at {})", - b.location - ), - ])) + return Err(HashSet::from([format!( + "only single Generic[...] is allowed (at {})", + b.location + )])); } is_generic = true; @@ -459,12 +463,10 @@ impl TopLevelComposer { }) }; if !all_unique_type_var { - return Err(HashSet::from([ - format!( - "duplicate type variable occurs (at {})", - slice.location - ), - ])) + return Err(HashSet::from([format!( + "duplicate type variable occurs (at {})", + slice.location + )])); } // add to TopLevelDef @@ -487,7 +489,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors) + return Err(errors); } Ok(()) } @@ -514,9 +516,9 @@ impl TopLevelComposer { } = &mut *class_def { let Some(ast::Located { - node: ast::StmtKind::ClassDef { bases, .. }, - .. - }) = class_ast else { + node: ast::StmtKind::ClassDef { bases, .. }, .. + }) = class_ast + else { unreachable!() }; @@ -543,13 +545,11 @@ impl TopLevelComposer { } if has_base { - return Err(HashSet::from([ - format!( - "a class definition can only have at most one base class \ + return Err(HashSet::from([format!( + "a class definition can only have at most one base class \ declaration and one generic declaration (at {})", - b.location - ), - ])) + b.location + )])); } has_base = true; @@ -567,12 +567,10 @@ impl TopLevelComposer { if let TypeAnnotation::CustomClass { .. } = &base_ty { class_ancestors.push(base_ty); } else { - return Err(HashSet::from([ - format!( - "class base declaration can only be custom class (at {})", - b.location, - ), - ])) + return Err(HashSet::from([format!( + "class base declaration can only be custom class (at {})", + b.location, + )])); } } Ok(()) @@ -589,31 +587,35 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors) + return Err(errors); } // second, get all ancestors let mut ancestors_store: HashMap> = HashMap::default(); - let mut get_all_ancestors = |class_def: &Arc>| -> Result<(), HashSet> { - let class_def = class_def.read(); - let (class_ancestors, class_id) = { - if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def { - (ancestors, *object_id) - } else { - return Ok(()); - } + let mut get_all_ancestors = + |class_def: &Arc>| -> Result<(), HashSet> { + let class_def = class_def.read(); + let (class_ancestors, class_id) = { + if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def { + (ancestors, *object_id) + } else { + return Ok(()); + } + }; + ancestors_store.insert( + class_id, + // if class has direct parents, get all ancestors of its parents. Else just empty + if class_ancestors.is_empty() { + vec![] + } else { + Self::get_all_ancestors_helper( + &class_ancestors[0], + temp_def_list.as_slice(), + )? + }, + ); + Ok(()) }; - ancestors_store.insert( - class_id, - // if class has direct parents, get all ancestors of its parents. Else just empty - if class_ancestors.is_empty() { - vec![] - } else { - Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice())? - }, - ); - Ok(()) - }; for (class_def, ast) in self.definition_ast_list.iter().skip(self.builtin_num) { if ast.is_none() { continue; @@ -623,7 +625,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors) + return Err(errors); } // insert the ancestors to the def list @@ -633,8 +635,7 @@ impl TopLevelComposer { } let mut class_def = class_def.write(); let (class_ancestors, class_id, class_type_vars) = { - if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = - &mut *class_def + if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = &mut *class_def { (ancestors, *object_id, type_vars) } else { @@ -665,8 +666,9 @@ impl TopLevelComposer { ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. } ) { return Err(HashSet::from([ - "Classes inherited from exception should have no custom fields/methods".into() - ])) + "Classes inherited from exception should have no custom fields/methods" + .into(), + ])); } } } @@ -674,7 +676,8 @@ impl TopLevelComposer { // deal with ancestor of Exception object let TopLevelDef::Class { name, ancestors, object_id, .. } = - &mut *self.definition_ast_list[7].0.write() else { + &mut *self.definition_ast_list[7].0.write() + else { unreachable!() }; @@ -713,7 +716,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors) + return Err(errors); } // handle the inherited methods and fields @@ -758,9 +761,14 @@ impl TopLevelComposer { let mut subst_list = Some(Vec::new()); // unification of previously assigned typevar let mut unification_helper = |ty, def| -> Result<(), HashSet> { - let target_ty = - get_type_from_type_annotation_kinds(&temp_def_list, unifier, &def, &mut subst_list)?; - unifier.unify(ty, target_ty) + let target_ty = get_type_from_type_annotation_kinds( + &temp_def_list, + unifier, + &def, + &mut subst_list, + )?; + unifier + .unify(ty, target_ty) .map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?; Ok(()) }; @@ -793,14 +801,16 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors) + return Err(errors); } for (def, _) in def_ast_list.iter().skip(self.builtin_num) { match &*def.read() { TopLevelDef::Class { resolver: Some(resolver), .. } | TopLevelDef::Function { resolver: Some(resolver), .. } => { - if let Err(e) = resolver.handle_deferred_eval(unifier, &temp_def_list, primitives) { + if let Err(e) = + resolver.handle_deferred_eval(unifier, &temp_def_list, primitives) + { errors.insert(e); } } @@ -828,7 +838,8 @@ impl TopLevelComposer { return Ok(()); }; - let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = function_def else { + let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = function_def + else { // not top level function def, skip return Ok(()); }; @@ -857,25 +868,22 @@ impl TopLevelComposer { "top level function must have unique parameter names \ and names should not be the same as the keywords (at {})", x.location - ), - ])) - }} + )])); + } + } - let arg_with_default: Vec<( - &ast::Located>, - Option<&ast::Expr>, - )> = args - .args - .iter() - .rev() - .zip( - args.defaults - .iter() - .rev() - .map(|x| -> Option<&ast::Expr> { Some(x) }) - .chain(std::iter::repeat(None)), - ) - .collect_vec(); + let arg_with_default: Vec<(&ast::Located>, Option<&ast::Expr>)> = + args.args + .iter() + .rev() + .zip( + args.defaults + .iter() + .rev() + .map(|x| -> Option<&ast::Expr> { Some(x) }) + .chain(std::iter::repeat(None)), + ) + .collect_vec(); arg_with_default .iter() @@ -885,12 +893,12 @@ impl TopLevelComposer { .node .annotation .as_ref() - .ok_or_else(|| HashSet::from([ - format!( + .ok_or_else(|| { + HashSet::from([format!( "function parameter `{}` needs type annotation at {}", x.node.arg, x.location - ), - ]))? + )]) + })? .as_ref(); let type_annotation = parse_ast_to_type_annotation_kinds( @@ -926,7 +934,7 @@ impl TopLevelComposer { temp_def_list.as_ref(), unifier, &type_annotation, - &mut None + &mut None, )?; Ok(FuncArg { @@ -935,18 +943,16 @@ impl TopLevelComposer { default_value: match default { None => None, Some(default) => Some({ - let v = Self::parse_parameter_default_value( - default, resolver, - )?; + let v = Self::parse_parameter_default_value(default, resolver)?; Self::check_default_param_type( &v, &type_annotation, primitives_store, unifier, ) - .map_err( - |err| HashSet::from([format!("{} (at {})", err, x.location), - ]))?; + .map_err(|err| { + HashSet::from([format!("{} (at {})", err, x.location)]) + })?; v }), }, @@ -993,7 +999,7 @@ impl TopLevelComposer { &temp_def_list, unifier, &return_ty_annotation, - &mut None + &mut None, )? } else { primitives_store.none @@ -1016,9 +1022,9 @@ impl TopLevelComposer { ret: return_ty, vars: function_var_map, })); - unifier.unify(*dummy_ty, function_ty).map_err(|e| HashSet::from([ - e.at(Some(function_ast.location)).to_display(unifier).to_string(), - ]))?; + unifier.unify(*dummy_ty, function_ty).map_err(|e| { + HashSet::from([e.at(Some(function_ast.location)).to_display(unifier).to_string()]) + })?; Ok(()) }; for (function_def, function_ast) in def_list.iter().skip(self.builtin_num) { @@ -1030,7 +1036,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors) + return Err(errors); } Ok(()) } @@ -1047,14 +1053,9 @@ impl TopLevelComposer { let (keyword_list, core_config) = core_info; let mut class_def = class_def.write(); let TopLevelDef::Class { - object_id, - ancestors, - fields, - methods, - resolver, - type_vars, - .. - } = &mut *class_def else { + object_id, ancestors, fields, methods, resolver, type_vars, .. + } = &mut *class_def + else { unreachable!("here must be toplevel class def"); }; let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast else { @@ -1375,14 +1376,9 @@ impl TopLevelComposer { type_var_to_concrete_def: &mut HashMap, ) -> Result<(), HashSet> { let TopLevelDef::Class { - object_id, - ancestors, - fields, - methods, - resolver, - type_vars, - .. - } = class_def else { + object_id, ancestors, fields, methods, resolver, type_vars, .. + } = class_def + else { unreachable!("here must be class def ast") }; let ( @@ -1414,9 +1410,7 @@ impl TopLevelComposer { for (anc_method_name, anc_method_ty, anc_method_def_id) in methods { // find if there is a method with same name in the child class let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id); - for (class_method_name, class_method_ty, class_method_defid) in - &*class_methods_def - { + for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def { if class_method_name == anc_method_name { // ignore and handle self // if is __init__ method, no need to check return type @@ -1430,27 +1424,20 @@ impl TopLevelComposer { if !ok { return Err(HashSet::from([format!( "method {class_method_name} has same name as ancestors' method, but incompatible type"), - ])) + ])); } // mark it as added is_override.insert(*class_method_name); - to_be_added = - (*class_method_name, *class_method_ty, *class_method_defid); + to_be_added = (*class_method_name, *class_method_ty, *class_method_defid); break; } } new_child_methods.push(to_be_added); } // add those that are not overriding method to the new_child_methods - for (class_method_name, class_method_ty, class_method_defid) in - &*class_methods_def - { + for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def { if !is_override.contains(class_method_name) { - new_child_methods.push(( - *class_method_name, - *class_method_ty, - *class_method_defid, - )); + new_child_methods.push((*class_method_name, *class_method_ty, *class_method_defid)); } } // use the new_child_methods to replace all the elements in `class_methods_def` @@ -1466,8 +1453,8 @@ impl TopLevelComposer { for (class_field_name, ..) in &*class_fields_def { if class_field_name == anc_field_name { return Err(HashSet::from([format!( - "field `{class_field_name}` has already declared in the ancestor classes"), - ])) + "field `{class_field_name}` has already declared in the ancestor classes" + )])); } } new_child_fields.push(to_be_added); @@ -1499,24 +1486,30 @@ impl TopLevelComposer { // first, fix function typevar ids // they may be changed with our use of placeholders for (def, _) in definition_ast_list.iter().skip(self.builtin_num) { - if let TopLevelDef::Function { - signature, - var_id, - .. - } = &mut *def.write() { + if let TopLevelDef::Function { signature, var_id, .. } = &mut *def.write() { if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = - unifier.get_ty(*signature).as_ref() { - let new_var_ids = vars.values().map(|v| match &*unifier.get_ty(*v) { - TypeEnum::TVar{id, ..} => *id, - _ => unreachable!(), - }).collect_vec(); + unifier.get_ty(*signature).as_ref() + { + let new_var_ids = vars + .values() + .map(|v| match &*unifier.get_ty(*v) { + TypeEnum::TVar { id, .. } => *id, + _ => unreachable!(), + }) + .collect_vec(); if new_var_ids != *var_id { let new_signature = FunSignature { args: args.clone(), ret: *ret, - vars: new_var_ids.iter().zip(vars.values()).map(|(id, v)| (*id, *v)).collect(), + vars: new_var_ids + .iter() + .zip(vars.values()) + .map(|(id, v)| (*id, *v)) + .collect(), }; - unifier.unification_table.set_value(*signature, Rc::new(TypeEnum::TFunc(new_signature))); + unifier + .unification_table + .set_value(*signature, Rc::new(TypeEnum::TFunc(new_signature))); *var_id = new_var_ids; } } @@ -1542,7 +1535,7 @@ impl TopLevelComposer { &def_list, unifier, &make_self_type_annotation(type_vars, *object_id), - &mut None + &mut None, )?; if ancestors .iter() @@ -1590,9 +1583,12 @@ impl TopLevelComposer { }; constructors.push((i, signature, definition_extension.len())); definition_extension.push((Arc::new(RwLock::new(cons_fun)), None)); - unifier.unify(constructor.unwrap(), signature).map_err(|e| HashSet::from([ - e.at(Some(ast.as_ref().unwrap().location)).to_display(unifier).to_string() - ]))?; + unifier.unify(constructor.unwrap(), signature).map_err(|e| { + HashSet::from([e + .at(Some(ast.as_ref().unwrap().location)) + .to_display(unifier) + .to_string()]) + })?; return Ok(()); } let mut init_id: Option = None; @@ -1605,7 +1601,8 @@ impl TopLevelComposer { init_id = Some(*id); let func_ty_enum = unifier.get_ty(*func_sig); let TypeEnum::TFunc(FunSignature { args, vars, .. }) = - func_ty_enum.as_ref() else { + func_ty_enum.as_ref() + else { unreachable!("must be typeenum::tfunc") }; @@ -1620,9 +1617,12 @@ impl TopLevelComposer { ret: self_type, vars: contor_type_vars, })); - unifier.unify(constructor.unwrap(), contor_type).map_err(|e| HashSet::from([ - e.at(Some(ast.as_ref().unwrap().location)).to_display(unifier).to_string() - ]))?; + unifier.unify(constructor.unwrap(), contor_type).map_err(|e| { + HashSet::from([e + .at(Some(ast.as_ref().unwrap().location)) + .to_display(unifier) + .to_string()]) + })?; // class field instantiation check if let (Some(init_id), false) = (init_id, fields.is_empty()) { @@ -1641,7 +1641,7 @@ impl TopLevelComposer { class_name, body[0].location, ), - ])) + ])); } } } @@ -1658,11 +1658,12 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors) + return Err(errors); } for (i, signature, id) in constructors { - let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() else { + let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() + else { unreachable!() }; @@ -1697,8 +1698,8 @@ impl TopLevelComposer { } = &mut *function_def { let signature_ty_enum = unifier.get_ty(*signature); - let TypeEnum::TFunc(FunSignature { args, ret, vars }) = - signature_ty_enum.as_ref() else { + let TypeEnum::TFunc(FunSignature { args, ret, vars }) = signature_ty_enum.as_ref() + else { unreachable!("must be typeenum::tfunc") }; @@ -1714,10 +1715,7 @@ impl TopLevelComposer { let ty_ann = make_self_type_annotation(type_vars, *class_id); let self_ty = get_type_from_type_annotation_kinds( - &def_list, - unifier, - &ty_ann, - &mut None + &def_list, unifier, &ty_ann, &mut None, )?; vars.extend(type_vars.iter().map(|ty| { let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else { @@ -1739,7 +1737,9 @@ impl TopLevelComposer { .values() .map(|ty| { unifier.get_instantiations(*ty).unwrap_or_else(|| { - let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty) else { + let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = + &*unifier.get_ty(*ty) + else { unreachable!() }; @@ -1779,8 +1779,7 @@ impl TopLevelComposer { let class_ty_var_ids = type_vars .iter() .map(|x| { - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) - { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { *id } else { unreachable!("must be type var here"); @@ -1839,7 +1838,8 @@ impl TopLevelComposer { }; let ast::StmtKind::FunctionDef { body, decorator_list, .. } = - ast.clone().unwrap().node else { + ast.clone().unwrap().node + else { unreachable!("must be function def ast") }; if !decorator_list.is_empty() @@ -1857,13 +1857,12 @@ impl TopLevelComposer { continue; } - let fun_body = body + let fun_body = body .into_iter() .map(|b| inferencer.fold_stmt(b)) .collect::, _>>()?; - let returned = - inferencer.check_block(fun_body.as_slice(), &mut identifiers)?; + let returned = inferencer.check_block(fun_body.as_slice(), &mut identifiers)?; { // check virtuals let defs = ctx.definitions.read(); @@ -1873,9 +1872,9 @@ impl TopLevelComposer { if let TypeEnum::TObj { obj_id, .. } = &*base { *obj_id } else { - return Err(HashSet::from([ - format!("Base type should be a class (at {loc})"), - ])) + return Err(HashSet::from([format!( + "Base type should be a class (at {loc})" + )])); } }; let subtype_id = { @@ -1887,7 +1886,7 @@ impl TopLevelComposer { let subtype_repr = inferencer.unifier.stringify(*subtype); return Err(HashSet::from([format!( "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"), - ])) + ])); } }; let subtype_entry = defs[subtype_id.0].read(); @@ -1902,7 +1901,7 @@ impl TopLevelComposer { let subtype_repr = inferencer.unifier.stringify(*subtype); return Err(HashSet::from([format!( "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"), - ])) + ])); } } } @@ -1912,7 +1911,9 @@ impl TopLevelComposer { inst_ret, &mut |id| { let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read() - else { unreachable!("must be class id here") }; + else { + unreachable!("must be class id here") + }; name.to_string() }, @@ -1924,11 +1925,16 @@ impl TopLevelComposer { ret_str, name, ast.as_ref().unwrap().location - ),])) + )])); } instance_to_stmt.insert( - get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())), + get_subst_key( + unifier, + self_type, + &subst, + Some(&vars.keys().copied().collect()), + ), FunInstance { body: Arc::new(fun_body), unifier_id: 0, @@ -1950,7 +1956,7 @@ impl TopLevelComposer { } } if !errors.is_empty() { - return Err(errors) + return Err(errors); } Ok(()) } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index fdef020a..f1f72fe8 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -47,7 +47,7 @@ impl PrimitiveDefinitionIds { } /// Returns an iterator over all [`DefinitionId`]s of this instance in indeterminate order. - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.as_vec().into_iter() } @@ -208,7 +208,8 @@ impl TopLevelComposer { }; let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); - let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); + let ndarray_ndims_tvar = + unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None); let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], @@ -219,13 +220,11 @@ impl TopLevelComposer { ]), })); let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { - name: "value".into(), - ty: ndarray_dtype_tvar.0, - default_value: None, - }, - ], + args: vec![FuncArg { + name: "value".into(), + ty: ndarray_dtype_tvar.0, + default_value: None, + }], ret: none, vars: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), @@ -393,9 +392,7 @@ impl TopLevelComposer { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { Ok(*id) } else { - Err(HashSet::from([ - "not type var".to_string(), - ])) + Err(HashSet::from(["not type var".to_string()])) } } @@ -412,25 +409,27 @@ impl TopLevelComposer { let ( TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }), TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }), - ) = (this, other) else { + ) = (this, other) + else { unreachable!("this function must be called with function type") }; // check args - let args_ok = this_args - .iter() - .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) - .zip(other_args.iter().map(|FuncArg { name, ty, .. }| { - (name, type_var_to_concrete_def.get(ty).unwrap()) - })) - .all(|(this, other)| { - if this.0 == &"self".into() && this.0 == other.0 { - true - } else { - this.0 == other.0 - && check_overload_type_annotation_compatible(this.1, other.1, unifier) - } - }); + let args_ok = + this_args + .iter() + .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) + .zip(other_args.iter().map(|FuncArg { name, ty, .. }| { + (name, type_var_to_concrete_def.get(ty).unwrap()) + })) + .all(|(this, other)| { + if this.0 == &"self".into() && this.0 == other.0 { + true + } else { + this.0 == other.0 + && check_overload_type_annotation_compatible(this.1, other.1, unifier) + } + }); // check rets let ret_ok = check_overload_type_annotation_compatible( @@ -473,12 +472,10 @@ impl TopLevelComposer { } } => { - return Err(HashSet::from([ - format!( - "redundant type annotation for class fields at {}", - s.location - ), - ])) + return Err(HashSet::from([format!( + "redundant type annotation for class fields at {}", + s.location + )])) } ast::StmtKind::Assign { targets, .. } => { for t in targets { @@ -602,105 +599,102 @@ pub fn parse_parameter_default_value( Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()?, )), - Constant::None => Err(HashSet::from([ - format!( - "`None` is not supported, use `none` for option type instead ({loc})" - ), - ])), + Constant::None => Err(HashSet::from([format!( + "`None` is not supported, use `none` for option type instead ({loc})" + )])), _ => unimplemented!("this constant is not supported at {}", loc), } } match &default.node { ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location), - ast::ExprKind::Call { func, args, .. } if args.len() == 1 => { - match &func.node { - ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node { - ast::ExprKind::Constant { value: Constant::Int(v), .. } => { - let v: Result = (*v).try_into(); - match v { - Ok(v) => Ok(SymbolValue::I64(v)), - _ => Err(HashSet::from([ - format!("default param value out of range at {}", default.location) - ])), - } + ast::ExprKind::Call { func, args, .. } if args.len() == 1 => match &func.node { + ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(v), .. } => { + let v: Result = (*v).try_into(); + match v { + Ok(v) => Ok(SymbolValue::I64(v)), + _ => Err(HashSet::from([format!( + "default param value out of range at {}", + default.location + )])), } - _ => Err(HashSet::from([ - format!("only allow constant integer here at {}", default.location), - ])) } - ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node { - ast::ExprKind::Constant { value: Constant::Int(v), .. } => { - let v: Result = (*v).try_into(); - match v { - Ok(v) => Ok(SymbolValue::U32(v)), - _ => Err(HashSet::from([ - format!("default param value out of range at {}", default.location), - ])), - } + _ => Err(HashSet::from([format!( + "only allow constant integer here at {}", + default.location + )])), + }, + ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(v), .. } => { + let v: Result = (*v).try_into(); + match v { + Ok(v) => Ok(SymbolValue::U32(v)), + _ => Err(HashSet::from([format!( + "default param value out of range at {}", + default.location + )])), } - _ => Err(HashSet::from([ - format!("only allow constant integer here at {}", default.location), - ])) } - ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node { - ast::ExprKind::Constant { value: Constant::Int(v), .. } => { - let v: Result = (*v).try_into(); - match v { - Ok(v) => Ok(SymbolValue::U64(v)), - _ => Err(HashSet::from([ - format!("default param value out of range at {}", default.location), - ])), - } + _ => Err(HashSet::from([format!( + "only allow constant integer here at {}", + default.location + )])), + }, + ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(v), .. } => { + let v: Result = (*v).try_into(); + match v { + Ok(v) => Ok(SymbolValue::U64(v)), + _ => Err(HashSet::from([format!( + "default param value out of range at {}", + default.location + )])), } - _ => Err(HashSet::from([ - format!("only allow constant integer here at {}", default.location), - ])) } - ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok( - SymbolValue::OptionSome( - Box::new(parse_parameter_default_value(&args[0], resolver)?) - ) - ), - _ => Err(HashSet::from([ - format!("unsupported default parameter at {}", default.location), - ])), - } - } - ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts - .iter() - .map(|x| parse_parameter_default_value(x, resolver)) - .collect::, _>>()? + _ => Err(HashSet::from([format!( + "only allow constant integer here at {}", + default.location + )])), + }, + ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok(SymbolValue::OptionSome( + Box::new(parse_parameter_default_value(&args[0], resolver)?), + )), + _ => Err(HashSet::from([format!( + "unsupported default parameter at {}", + default.location + )])), + }, + ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple( + elts.iter() + .map(|x| parse_parameter_default_value(x, resolver)) + .collect::, _>>()?, )), ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone), ast::ExprKind::Name { id, .. } => { - resolver.get_default_param_value(default).ok_or_else( - || HashSet::from([ - format!( - "`{}` cannot be used as a default parameter at {} \ + resolver.get_default_param_value(default).ok_or_else(|| { + HashSet::from([format!( + "`{}` cannot be used as a default parameter at {} \ (not primitive type, option or tuple / not defined?)", - id, - default.location - ), - ]) - ) + id, default.location + )]) + }) } - _ => Err(HashSet::from([ - format!( - "unsupported default parameter (not primitive type, option or tuple) at {}", - default.location - ), - ])) + _ => Err(HashSet::from([format!( + "unsupported default parameter (not primitive type, option or tuple) at {}", + default.location + )])), } } /// Obtains the element type of an array-like type. pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { match &*unifier.get_ty(ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => - unpack_ndarray_var_tys(unifier, ty).0, + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + unpack_ndarray_var_tys(unifier, ty).0 + } TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty), - _ => ty + _ => ty, } } @@ -721,6 +715,6 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { } TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1, - _ => 0 + _ => 0, } } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index e38207f7..f7fa92b9 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -8,7 +8,9 @@ use std::{ use super::codegen::CodeGenContext; use super::typecheck::type_inferencer::PrimitiveStore; -use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap}; +use super::typecheck::typedef::{ + FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap, +}; use crate::{ codegen::CodeGenerator, symbol_resolver::{SymbolResolver, ValueEnum}, @@ -32,16 +34,15 @@ use type_annotation::*; #[cfg(test)] mod test; -type GenCallCallback = - dyn for<'ctx, 'a> Fn( - &mut CodeGenContext<'ctx, 'a>, - Option<(Type, ValueEnum<'ctx>)>, - (&FunSignature, DefinitionId), - Vec<(Option, ValueEnum<'ctx>)>, - &mut dyn CodeGenerator, - ) -> Result>, String> - + Send - + Sync; +type GenCallCallback = dyn for<'ctx, 'a> Fn( + &mut CodeGenContext<'ctx, 'a>, + Option<(Type, ValueEnum<'ctx>)>, + (&FunSignature, DefinitionId), + Vec<(Option, ValueEnum<'ctx>)>, + &mut dyn CodeGenerator, + ) -> Result>, String> + + Send + + Sync; pub struct GenCall { fp: Box, @@ -53,7 +54,7 @@ impl GenCall { GenCall { fp } } - /// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given + /// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given /// `reason`. #[must_use] pub fn create_dummy(reason: String) -> GenCall { diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index aee09048..99129a9a 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -1,4 +1,3 @@ -use itertools::Itertools; use crate::{ toplevel::helper::PRIMITIVE_DEF_IDS, typecheck::{ @@ -6,9 +5,10 @@ use crate::{ typedef::{Type, TypeEnum, Unifier, VarMap}, }, }; +use itertools::Itertools; /// Creates a `ndarray` [`Type`] with the given type arguments. -/// +/// /// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not /// specialized. /// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not @@ -40,12 +40,10 @@ pub fn subst_ndarray_tvars( debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); if dtype.is_none() && ndims.is_none() { - return ndarray + return ndarray; } - let tvar_ids = params.iter() - .map(|(obj_id, _)| *obj_id) - .collect_vec(); + let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec(); debug_assert_eq!(tvar_ids.len(), 2); let mut tvar_subst = VarMap::new(); @@ -59,45 +57,29 @@ pub fn subst_ndarray_tvars( unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) } -fn unpack_ndarray_tvars( - unifier: &mut Unifier, - ndarray: Type, -) -> Vec<(u32, Type)> { +fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(u32, Type)> { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); debug_assert_eq!(params.len(), 2); - params.iter() + params + .iter() .sorted_by_key(|(obj_id, _)| *obj_id) .map(|(var_id, ty)| (*var_id, *ty)) .collect_vec() } -/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds -/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` +/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds +/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` /// respectively. -pub fn unpack_ndarray_var_ids( - unifier: &mut Unifier, - ndarray: Type, -) -> (u32, u32) { - unpack_ndarray_tvars(unifier, ndarray) - .into_iter() - .map(|v| v.0) - .collect_tuple() - .unwrap() +pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (u32, u32) { + unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap() } /// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to /// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively. -pub fn unpack_ndarray_var_tys( - unifier: &mut Unifier, - ndarray: Type, -) -> (Type, Type) { - unpack_ndarray_tvars(unifier, ndarray) - .into_iter() - .map(|v| v.1) - .collect_tuple() - .unwrap() +pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) { + unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap() } diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index ed38a037..b9514da6 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -65,7 +65,11 @@ impl SymbolResolver for Resolver { } fn get_identifier_def(&self, id: StrRef) -> Result> { - self.0.id_to_def.lock().get(&id).cloned() + self.0 + .id_to_def + .lock() + .get(&id) + .cloned() .ok_or_else(|| HashSet::from(["Unknown identifier".to_string()])) } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index e6769a03..4f42d8d2 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,7 +1,7 @@ +use super::*; use crate::symbol_resolver::SymbolValue; use crate::toplevel::helper::PRIMITIVE_DEF_IDS; use crate::typecheck::typedef::VarMap; -use super::*; use nac3parser::ast::Constant; #[derive(Clone, Debug)] @@ -29,9 +29,7 @@ impl TypeAnnotation { Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty), CustomClass { id, params } => { let class_name = if let Some(ref top) = unifier.top_level { - if let TopLevelDef::Class { name, .. } = - &*top.definitions.read()[id.0].read() - { + if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() { (*name).into() } else { unreachable!() @@ -39,24 +37,26 @@ impl TypeAnnotation { } else { format!("class_def_{}", id.0) }; - format!( - "{}{}", - class_name, - { - let param_list = params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", "); - if param_list.is_empty() { - String::new() - } else { - format!("[{param_list}]") - } + format!("{}{}", class_name, { + let param_list = + params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", "); + if param_list.is_empty() { + String::new() + } else { + format!("[{param_list}]") } - ) + }) + } + Literal(values) => { + format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", ")) } - Literal(values) => format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", ")), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), List(ty) => format!("list[{}]", ty.stringify(unifier)), Tuple(types) => { - format!("tuple[{}]", types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")) + format!( + "tuple[{}]", + types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ") + ) } } } @@ -95,7 +95,10 @@ pub fn parse_ast_to_type_annotation_kinds( } else if id == &"str".into() { Ok(TypeAnnotation::Primitive(primitives.str)) } else if id == &"Exception".into() { - Ok(TypeAnnotation::CustomClass { id: PRIMITIVE_DEF_IDS.exception, params: Vec::default() }) + Ok(TypeAnnotation::CustomClass { + id: PRIMITIVE_DEF_IDS.exception, + params: Vec::default(), + }) } else if let Ok(obj_id) = resolver.get_identifier_def(*id) { let type_vars = { let def_read = top_level_defs[obj_id.0].try_read(); @@ -103,12 +106,10 @@ pub fn parse_ast_to_type_annotation_kinds( if let TopLevelDef::Class { type_vars, .. } = &*def_read { type_vars.clone() } else { - return Err(HashSet::from([ - format!( - "function cannot be used as a type (at {})", - expr.location - ), - ])) + return Err(HashSet::from([format!( + "function cannot be used as a type (at {})", + expr.location + )])); } } else { locked.get(&obj_id).unwrap().clone() @@ -116,13 +117,11 @@ pub fn parse_ast_to_type_annotation_kinds( }; // check param number here if !type_vars.is_empty() { - return Err(HashSet::from([ - format!( - "expect {} type variable parameter but got 0 (at {})", - type_vars.len(), - expr.location, - ), - ])) + return Err(HashSet::from([format!( + "expect {} type variable parameter but got 0 (at {})", + type_vars.len(), + expr.location, + )])); } Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) } else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { @@ -131,14 +130,16 @@ pub fn parse_ast_to_type_annotation_kinds( unifier.unify(var, ty).unwrap(); Ok(TypeAnnotation::TypeVar(ty)) } else { - Err(HashSet::from([ - format!("`{}` is not a valid type annotation (at {})", id, expr.location), - ])) + Err(HashSet::from([format!( + "`{}` is not a valid type annotation (at {})", + id, expr.location + )])) } } else { - Err(HashSet::from([ - format!("`{}` is not a valid type annotation (at {})", id, expr.location), - ])) + Err(HashSet::from([format!( + "`{}` is not a valid type annotation (at {})", + id, expr.location + )])) } }; @@ -147,11 +148,13 @@ pub fn parse_ast_to_type_annotation_kinds( slice: &ast::Expr, unifier: &mut Unifier, mut locked: HashMap>| { - if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()].contains(id) + if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()] + .contains(id) { - return Err(HashSet::from([ - format!("keywords cannot be class name (at {})", expr.location), - ])) + return Err(HashSet::from([format!( + "keywords cannot be class name (at {})", + expr.location + )])); } let obj_id = resolver.get_identifier_def(*id)?; let type_vars = { @@ -174,14 +177,12 @@ pub fn parse_ast_to_type_annotation_kinds( vec![slice] }; if type_vars.len() != params_ast.len() { - return Err(HashSet::from([ - format!( - "expect {} type parameters but got {} (at {})", - type_vars.len(), - params_ast.len(), - params_ast[0].location, - ), - ])) + return Err(HashSet::from([format!( + "expect {} type parameters but got {} (at {})", + type_vars.len(), + params_ast.len(), + params_ast[0].location, + )])); } let result = params_ast .iter() @@ -210,7 +211,7 @@ pub fn parse_ast_to_type_annotation_kinds( "application of type vars to generic class is not currently supported (at {})", params_ast[0].location ), - ])) + ])); } }; Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) @@ -309,9 +310,10 @@ pub fn parse_ast_to_type_annotation_kinds( // Literal ast::ExprKind::Subscript { value, slice, .. } - if { - matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into()) - } => { + if { + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into()) + } => + { let tup_elts = { if let ast::ExprKind::Tuple { elts, .. } = &slice.node { elts.as_slice() @@ -321,20 +323,18 @@ pub fn parse_ast_to_type_annotation_kinds( }; let type_annotations = tup_elts .iter() - .map(|e| { - match &e.node { - ast::ExprKind::Constant { value, .. } => Ok( - TypeAnnotation::Literal(vec![value.clone()]), - ), - _ => parse_ast_to_type_annotation_kinds( - resolver, - top_level_defs, - unifier, - primitives, - e, - locked.clone(), - ), + .map(|e| match &e.node { + ast::ExprKind::Constant { value, .. } => { + Ok(TypeAnnotation::Literal(vec![value.clone()])) } + _ => parse_ast_to_type_annotation_kinds( + resolver, + top_level_defs, + unifier, + primitives, + e, + locked.clone(), + ), }) .collect::, _>>()? .into_iter() @@ -347,9 +347,10 @@ pub fn parse_ast_to_type_annotation_kinds( if type_annotations.len() == 1 { Ok(TypeAnnotation::Literal(type_annotations)) } else { - Err(HashSet::from([ - format!("multiple literal bounds are currently unsupported (at {})", value.location) - ])) + Err(HashSet::from([format!( + "multiple literal bounds are currently unsupported (at {})", + value.location + )])) } } @@ -358,19 +359,19 @@ pub fn parse_ast_to_type_annotation_kinds( if let ast::ExprKind::Name { id, .. } = &value.node { class_name_handle(id, slice, unifier, locked) } else { - Err(HashSet::from([ - format!("unsupported expression type for class name (at {})", value.location) - ])) + Err(HashSet::from([format!( + "unsupported expression type for class name (at {})", + value.location + )])) } } - ast::ExprKind::Constant { value, .. } => { - Ok(TypeAnnotation::Literal(vec![value.clone()])) - } + ast::ExprKind::Constant { value, .. } => Ok(TypeAnnotation::Literal(vec![value.clone()])), - _ => Err(HashSet::from([ - format!("unsupported expression for type annotation (at {})", expr.location), - ])), + _ => Err(HashSet::from([format!( + "unsupported expression for type annotation (at {})", + expr.location + )])), } } @@ -381,7 +382,7 @@ pub fn get_type_from_type_annotation_kinds( top_level_defs: &[Arc>], unifier: &mut Unifier, ann: &TypeAnnotation, - subst_list: &mut Option> + subst_list: &mut Option>, ) -> Result> { match ann { TypeAnnotation::CustomClass { id: obj_id, params } => { @@ -392,24 +393,17 @@ pub fn get_type_from_type_annotation_kinds( }; if type_vars.len() != params.len() { - return Err(HashSet::from([ - format!( - "unexpected number of type parameters: expected {} but got {}", - type_vars.len(), - params.len() - ), - ])) + return Err(HashSet::from([format!( + "unexpected number of type parameters: expected {} but got {}", + type_vars.len(), + params.len() + )])); } let param_ty = params .iter() .map(|x| { - get_type_from_type_annotation_kinds( - top_level_defs, - unifier, - x, - subst_list - ) + get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list) }) .collect::, _>>()?; @@ -419,7 +413,14 @@ pub fn get_type_from_type_annotation_kinds( let mut result = VarMap::new(); for (tvar, p) in type_vars.iter().zip(param_ty) { match unifier.get_ty(*tvar).as_ref() { - TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => { + TypeEnum::TVar { + id, + range, + fields: None, + name, + loc, + is_const_generic: false, + } => { let ok: bool = { // create a temp type var and unify to check compatibility p == *tvar || { @@ -434,18 +435,16 @@ pub fn get_type_from_type_annotation_kinds( if ok { result.insert(*id, p); } else { - return Err(HashSet::from([ - format!( - "cannot apply type {} to type variable with id {:?}", - unifier.internal_stringify( - p, - &mut |id| format!("class{id}"), - &mut |id| format!("typevar{id}"), - &mut None - ), - *id - ) - ])) + return Err(HashSet::from([format!( + "cannot apply type {} to type variable with id {:?}", + unifier.internal_stringify( + p, + &mut |id| format!("class{id}"), + &mut |id| format!("typevar{id}"), + &mut None + ), + *id + )])); } } @@ -454,24 +453,18 @@ pub fn get_type_from_type_annotation_kinds( let ok: bool = { // create a temp type var and unify to check compatibility p == *tvar || { - let temp = unifier.get_fresh_const_generic_var( - ty, - *name, - *loc, - ); + let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc); unifier.unify(temp.0, p).is_ok() } }; if ok { result.insert(*id, p); } else { - return Err(HashSet::from([ - format!( - "cannot apply type {} to type variable {}", - unifier.stringify(p), - name.unwrap_or_else(|| format!("typevar{id}").into()), - ), - ])) + return Err(HashSet::from([format!( + "cannot apply type {} to type variable {}", + unifier.stringify(p), + name.unwrap_or_else(|| format!("typevar{id}").into()), + )])); } } @@ -507,7 +500,8 @@ pub fn get_type_from_type_annotation_kinds( } TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Literal(values) => { - let values = values.iter() + let values = values + .iter() .map(SymbolValue::from_constant_inferred) .collect::, _>>() .map_err(|err| HashSet::from([err]))?; @@ -520,7 +514,7 @@ pub fn get_type_from_type_annotation_kinds( top_level_defs, unifier, ty.as_ref(), - subst_list + subst_list, )?; Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) } @@ -529,7 +523,7 @@ pub fn get_type_from_type_annotation_kinds( top_level_defs, unifier, ty.as_ref(), - subst_list + subst_list, )?; Ok(unifier.add_ty(TypeEnum::TList { ty })) } @@ -607,7 +601,8 @@ pub fn check_overload_type_annotation_compatible( let ( TypeEnum::TVar { id: a, fields: None, .. }, TypeEnum::TVar { id: b, fields: None, .. }, - ) = (a, b) else { + ) = (a, b) + else { unreachable!("must be type var") }; diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index db43e037..bce8fbc5 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -2,15 +2,17 @@ use crate::typecheck::typedef::TypeEnum; use super::type_inferencer::Inferencer; use super::typedef::Type; -use nac3parser::ast::{self, Constant, Expr, ExprKind, Operator::{LShift, RShift}, Stmt, StmtKind, StrRef}; +use nac3parser::ast::{ + self, Constant, Expr, ExprKind, + Operator::{LShift, RShift}, + Stmt, StmtKind, StrRef, +}; use std::{collections::HashSet, iter::once}; impl<'a> Inferencer<'a> { fn should_have_value(&mut self, expr: &Expr>) -> Result<(), HashSet> { if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) { - Err(HashSet::from([ - format!("Error at {}: cannot have value none", expr.location), - ])) + Err(HashSet::from([format!("Error at {}: cannot have value none", expr.location)])) } else { Ok(()) } @@ -22,9 +24,9 @@ impl<'a> Inferencer<'a> { defined_identifiers: &mut HashSet, ) -> Result<(), HashSet> { match &pattern.node { - ExprKind::Name { id, .. } if id == &"none".into() => Err(HashSet::from([ - format!("cannot assign to a `none` (at {})", pattern.location), - ])), + ExprKind::Name { id, .. } if id == &"none".into() => { + Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)])) + } ExprKind::Name { id, .. } => { if !defined_identifiers.contains(id) { defined_identifiers.insert(*id); @@ -44,20 +46,17 @@ impl<'a> Inferencer<'a> { self.should_have_value(value)?; self.check_expr(slice, defined_identifiers)?; if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { - return Err(HashSet::from([ - format!( - "Error at {}: cannot assign to tuple element", - value.location - ), - ])) + return Err(HashSet::from([format!( + "Error at {}: cannot assign to tuple element", + value.location + )])); } Ok(()) } - ExprKind::Constant { .. } => { - Err(HashSet::from([ - format!("cannot assign to a constant (at {})", pattern.location), - ])) - } + ExprKind::Constant { .. } => Err(HashSet::from([format!( + "cannot assign to a constant (at {})", + pattern.location + )])), _ => self.check_expr(pattern, defined_identifiers), } } @@ -69,14 +68,14 @@ impl<'a> Inferencer<'a> { ) -> Result<(), HashSet> { // there are some cases where the custom field is None if let Some(ty) = &expr.custom { - if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) { - return Err(HashSet::from([ - format!( - "expected concrete type at {} but got {}", - expr.location, - self.unifier.get_ty(*ty).get_type_name() - ) - ])) + if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) + && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) + { + return Err(HashSet::from([format!( + "expected concrete type at {} but got {}", + expr.location, + self.unifier.get_ty(*ty).get_type_name() + )])); } } match &expr.node { @@ -96,12 +95,10 @@ impl<'a> Inferencer<'a> { self.defined_identifiers.insert(*id); } Err(e) => { - return Err(HashSet::from([ - format!( - "type error at identifier `{}` ({}) at {}", - id, e, expr.location - ) - ])) + return Err(HashSet::from([format!( + "type error at identifier `{}` ({}) at {}", + id, e, expr.location + )])) } } } @@ -127,17 +124,13 @@ impl<'a> Inferencer<'a> { // Check whether a bitwise shift has a negative RHS constant value if *op == LShift || *op == RShift { if let ExprKind::Constant { value, .. } = &right.node { - let Constant::Int(rhs_val) = value else { - unreachable!() - }; + let Constant::Int(rhs_val) = value else { unreachable!() }; if *rhs_val < 0 { - return Err(HashSet::from([ - format!( - "shift count is negative at {}", - right.location - ), - ])) + return Err(HashSet::from([format!( + "shift count is negative at {}", + right.location + )])); } } } @@ -214,16 +207,16 @@ impl<'a> Inferencer<'a> { /// is freed when the function returns. fn check_return_value_ty(&mut self, ret_ty: Type) -> bool { match &*self.unifier.get_ty_immutable(ret_ty) { - TypeEnum::TObj { .. } => { - [ - self.primitives.int32, - self.primitives.int64, - self.primitives.uint32, - self.primitives.uint64, - self.primitives.float, - self.primitives.bool, - ].iter().any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)) - } + TypeEnum::TObj { .. } => [ + self.primitives.int32, + self.primitives.int64, + self.primitives.uint32, + self.primitives.uint64, + self.primitives.float, + self.primitives.bool, + ] + .iter() + .any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)), TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)), _ => false, } @@ -330,8 +323,11 @@ impl<'a> Inferencer<'a> { if let Some(ret_ty) = value.custom { // Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually // inferred and just generates an unconditional assertion - if matches!(value.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) { - return Ok(true) + if matches!( + value.node, + ExprKind::Constant { value: Constant::Ellipsis, .. } + ) { + return Ok(true); } if !self.check_return_value_ty(ret_ty) { @@ -341,7 +337,7 @@ impl<'a> Inferencer<'a> { self.unifier.stringify(ret_ty), value.location, ), - ])) + ])); } } } diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 6f0211c8..6f1eeced 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,4 +1,3 @@ -use std::cmp::max; use crate::symbol_resolver::SymbolValue; use crate::toplevel::helper::PRIMITIVE_DEF_IDS; use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; @@ -6,11 +5,12 @@ use crate::typecheck::{ type_inferencer::*, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }; +use itertools::Itertools; use nac3parser::ast::StrRef; use nac3parser::ast::{Cmpop, Operator, Unaryop}; +use std::cmp::max; use std::collections::HashMap; use std::rc::Rc; -use itertools::Itertools; #[must_use] pub fn binop_name(op: &Operator) -> &'static str { @@ -255,7 +255,14 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty /// `LShift`, `RShift` pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { - impl_binop(unifier, store, ty, &[store.int32, store.uint32], Some(ty), &[Operator::LShift, Operator::RShift]); + impl_binop( + unifier, + store, + ty, + &[store.int32, store.uint32], + Some(ty), + &[Operator::LShift, Operator::RShift], + ); } /// `Div` @@ -297,7 +304,7 @@ pub fn impl_matmul( store: &PrimitiveStore, ty: Type, other_ty: &[Type], - ret_ty: Option, + ret_ty: Option, ) { impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult]); } @@ -353,7 +360,7 @@ pub fn typeof_ndarray_broadcast( left: Type, right: Type, ) -> Result { - let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); assert!(is_left_ndarray || is_right_ndarray); @@ -375,7 +382,8 @@ pub fn typeof_ndarray_broadcast( _ => unreachable!(), }; - let res_ndims = left_ty_ndims.into_iter() + let res_ndims = left_ty_ndims + .into_iter() .cartesian_product(right_ty_ndims) .map(|(left, right)| { let left_val = u64::try_from(left).unwrap(); @@ -390,11 +398,7 @@ pub fn typeof_ndarray_broadcast( Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims))) } else { - let (ndarray_ty, scalar_ty) = if is_left_ndarray { - (left, right) - } else { - (right, left) - }; + let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) }; let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty); @@ -424,21 +428,17 @@ pub fn typeof_binop( lhs: Type, rhs: Type, ) -> Result, String> { - let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); Ok(Some(match op { - Operator::Add - | Operator::Sub - | Operator::Mult - | Operator::Mod - | Operator::FloorDiv => { + Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => { if is_left_ndarray || is_right_ndarray { typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)? } else if unifier.unioned(lhs, rhs) { lhs } else { - return Ok(None) + return Ok(None); } } @@ -464,12 +464,14 @@ pub fn typeof_binop( (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, (lhs, rhs) if lhs == 0 || rhs == 0 => { return Err(format!( - "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", - (rhs == 0) as u8 - )) + "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", + (rhs == 0) as u8 + )) } (lhs, rhs) => { - return Err(format!("ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported")) + return Err(format!( + "ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported" + )) } } } @@ -480,29 +482,35 @@ pub fn typeof_binop( } else if unifier.unioned(lhs, rhs) { primitives.float } else { - return Ok(None) + return Ok(None); } } Operator::Pow => { if is_left_ndarray || is_right_ndarray { typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)? - } else if [primitives.int32, primitives.int64, primitives.uint32, primitives.uint64, primitives.float].into_iter().any(|ty| unifier.unioned(lhs, ty)) { + } else if [ + primitives.int32, + primitives.int64, + primitives.uint32, + primitives.uint64, + primitives.float, + ] + .into_iter() + .any(|ty| unifier.unioned(lhs, ty)) + { lhs } else { - return Ok(None) + return Ok(None); } } - Operator::LShift - | Operator::RShift => lhs, - Operator::BitOr - | Operator::BitXor - | Operator::BitAnd => { + Operator::LShift | Operator::RShift => lhs, + Operator::BitOr | Operator::BitXor | Operator::BitAnd => { if unifier.unioned(lhs, rhs) { lhs } else { - return Ok(None) + return Ok(None); } } })) @@ -516,31 +524,34 @@ pub fn typeof_unaryop( ) -> Result, String> { let operand_obj_id = operand.obj_id(unifier); - if *op == Unaryop::Not && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) { - return Err("The truth value of an array with more than one element is ambiguous".to_string()) + if *op == Unaryop::Not + && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) + { + return Err( + "The truth value of an array with more than one element is ambiguous".to_string() + ); } Ok(match *op { - Unaryop::Not => { - match operand_obj_id { - Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand), - Some(_) => Some(primitives.bool), - _ => None - } - } + Unaryop::Not => match operand_obj_id { + Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand), + Some(_) => Some(primitives.bool), + _ => None, + }, Unaryop::Invert => { if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { Some(primitives.int32) - } else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) { + } else if operand_obj_id + .is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) + { Some(operand) } else { None } } - Unaryop::UAdd - | Unaryop::USub => { + Unaryop::UAdd | Unaryop::USub => { if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { let (dtype, _) = unpack_ndarray_var_tys(unifier, operand); if dtype.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { @@ -548,13 +559,15 @@ pub fn typeof_unaryop( "The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string() } else { "The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string() - }) + }); } Some(operand) } else if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { Some(primitives.int32) - } else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) { + } else if operand_obj_id + .is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) + { Some(operand) } else { None @@ -571,12 +584,8 @@ pub fn typeof_cmpop( lhs: Type, rhs: Type, ) -> Result, String> { - let is_left_ndarray = lhs - .obj_id(unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_right_ndarray = rhs - .obj_id(unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); Ok(Some(if is_left_ndarray || is_right_ndarray { let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?; @@ -586,7 +595,7 @@ pub fn typeof_cmpop( } else if unifier.unioned(lhs, rhs) { primitives.bool } else { - return Ok(None) + return Ok(None); })) } @@ -643,11 +652,19 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); /* ndarray ===== */ - let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); - let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0)); + let ndarray_usized_ndims_tvar = + unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); + let ndarray_unsized_t = + make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0)); let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); - impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); + impl_basic_arithmetic( + unifier, + store, + ndarray_t, + &[ndarray_unsized_t, ndarray_unsized_dtype_t], + None, + ); impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs index ac36b331..abf61a56 100644 --- a/nac3core/src/typecheck/type_error.rs +++ b/nac3core/src/typecheck/type_error.rs @@ -89,10 +89,7 @@ impl<'a> Display for DisplayTypeError<'a> { IncorrectArgType { name, expected, got } => { let expected = self.unifier.stringify_with_notes(*expected, &mut notes); let got = self.unifier.stringify_with_notes(*got, &mut notes); - write!( - f, - "Incorrect argument type for {name}. Expected {expected}, but got {got}" - ) + write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}") } FieldUnificationError { field, types, loc } => { let lhs = self.unifier.stringify_with_notes(types.0, &mut notes); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 575e5ed2..0297eaee 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1,21 +1,25 @@ use std::collections::{HashMap, HashSet}; use std::convert::{From, TryInto}; use std::iter::once; -use std::{cell::RefCell, sync::Arc}; use std::ops::Not; +use std::{cell::RefCell, sync::Arc}; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use crate::{ - symbol_resolver::{SymbolResolver, SymbolValue}, + symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ helper::{arraylike_flatten_element_type, arraylike_get_ndims, PRIMITIVE_DEF_IDS}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, }, }; -use itertools::{Itertools, izip}; -use nac3parser::ast::{self, fold::{self, Fold}, Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef}; +use itertools::{izip, Itertools}; +use nac3parser::ast::{ + self, + fold::{self, Fold}, + Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef, +}; #[cfg(test)] mod test; @@ -187,9 +191,12 @@ impl<'a> Fold<()> for Inferencer<'a> { } if let Some(old_typ) = self.variable_mapping.insert(name, typ) { let loc = handler.location; - self.unifier.unify(old_typ, typ).map_err(|e| HashSet::from([ - e.at(Some(loc)).to_display(self.unifier).to_string(), - ]))?; + self.unifier.unify(old_typ, typ).map_err(|e| { + HashSet::from([e + .at(Some(loc)) + .to_display(self.unifier) + .to_string()]) + })?; } } let mut type_ = naive_folder.fold_expr(*type_)?; @@ -234,8 +241,12 @@ impl<'a> Fold<()> for Inferencer<'a> { self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?; } else { let list_like_ty = match &*self.unifier.get_ty(iter.custom.unwrap()) { - TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }), - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => todo!(), + TypeEnum::TList { .. } => { + self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }) + } + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + todo!() + } _ => unreachable!(), }; self.unify(list_like_ty, iter.custom.unwrap(), &iter.location)?; @@ -273,13 +284,10 @@ impl<'a> Fold<()> for Inferencer<'a> { let targets: Result, _> = targets .into_iter() .map(|target| { - let ExprKind::Name { id, ctx } = target.node else { - unreachable!() - }; + let ExprKind::Name { id, ctx } = target.node else { unreachable!() }; self.defined_identifiers.insert(id); - let target_ty = if let Some(ty) = self.variable_mapping.get(&id) - { + let target_ty = if let Some(ty) = self.variable_mapping.get(&id) { *ty } else { let unifier: &mut Unifier = self.unifier; @@ -305,8 +313,9 @@ impl<'a> Fold<()> for Inferencer<'a> { }) .collect(); let loc = node.location; - let targets = targets - .map_err(|e| HashSet::from([e.at(Some(loc)).to_display(self.unifier).to_string()]))?; + let targets = targets.map_err(|e| { + HashSet::from([e.at(Some(loc)).to_display(self.unifier).to_string()]) + })?; return Ok(Located { location: node.location, node: ast::StmtKind::Assign { @@ -463,7 +472,7 @@ impl<'a> Fold<()> for Inferencer<'a> { self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; match msg { Some(m) => self.unify(m.custom.unwrap(), self.primitives.str, &m.location)?, - None => () + None => (), } } _ => return report_error("Unsupported statement type", stmt.location), @@ -485,9 +494,7 @@ impl<'a> Fold<()> for Inferencer<'a> { _ => fold::fold_expr(self, node)?, }; let custom = match &expr.node { - ExprKind::Constant { value, .. } => { - Some(self.infer_constant(value, &expr.location)?) - } + ExprKind::Constant { value, .. } => Some(self.infer_constant(value, &expr.location)?), ExprKind::Name { id, .. } => { // the name `none` is special since it may have different types if id == &"none".into() { @@ -497,7 +504,9 @@ impl<'a> Fold<()> for Inferencer<'a> { let var_map = params .iter() .map(|(id_var, ty)| { - let TypeEnum::TVar { id, range, name, loc, .. } = &*self.unifier.get_ty(*ty) else { + let TypeEnum::TVar { id, range, name, loc, .. } = + &*self.unifier.get_ty(*ty) + else { unreachable!() }; @@ -552,9 +561,9 @@ impl<'a> Fold<()> for Inferencer<'a> { ExprKind::IfExp { test, body, orelse } => { Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?) } - ExprKind::ListComp { .. } - | ExprKind::Lambda { .. } - | ExprKind::Call { .. } => expr.custom, // already computed + ExprKind::ListComp { .. } | ExprKind::Lambda { .. } | ExprKind::Call { .. } => { + expr.custom + } // already computed ExprKind::Slice { .. } => { // slices aren't exactly ranges, but for our purposes this should suffice Some(self.primitives.range) @@ -575,11 +584,9 @@ impl<'a> Inferencer<'a> { } fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), HashSet> { - self.unifier - .unify(a, b) - .map_err(|e| HashSet::from([ - e.at(Some(*location)).to_display(self.unifier).to_string(), - ])) + self.unifier.unify(a, b).map_err(|e| { + HashSet::from([e.at(Some(*location)).to_display(self.unifier).to_string()]) + }) } fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), HashSet> { @@ -622,12 +629,15 @@ impl<'a> Inferencer<'a> { loc: Some(location), }; if let Some(ret) = ret { - self.unifier.unify(sign.ret, ret) + self.unifier + .unify(sign.ret, ret) .map_err(|err| { - format!("Cannot unify {} <: {} - {:?}", - self.unifier.stringify(sign.ret), - self.unifier.stringify(ret), - TypeError::new(err.kind, Some(location))) + format!( + "Cannot unify {} <: {} - {:?}", + self.unifier.stringify(sign.ret), + self.unifier.stringify(ret), + TypeError::new(err.kind, Some(location)) + ) }) .unwrap(); } @@ -638,9 +648,12 @@ impl<'a> Inferencer<'a> { .map(|v| v.name) .rev() .collect(); - self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| HashSet::from([ - e.at(Some(location)).to_display(self.unifier).to_string(), - ]))?; + self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| { + HashSet::from([e + .at(Some(location)) + .to_display(self.unifier) + .to_string()]) + })?; return Ok(sign.ret); } } @@ -815,7 +828,7 @@ impl<'a> Inferencer<'a> { keywords: &[Located], ) -> Result>>, HashSet> { let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else { - return Ok(None) + return Ok(None); }; // handle special functions that cannot be typed in the usual way... @@ -824,7 +837,7 @@ impl<'a> Inferencer<'a> { return report_error( "`virtual` can only accept 1/2 positional arguments", *func_location, - ) + ); } let arg0 = self.fold_expr(args.remove(0))?; let ty = if let Some(arg) = args.pop() { @@ -852,19 +865,19 @@ impl<'a> Inferencer<'a> { args: vec![arg0], keywords: vec![], }, - })) + })); } - if [ - "int32", - "float", - "bool", - "round", - "round64", - "np_isnan", - "np_isinf", - ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { - let target_ty = if id == &"int32".into() || id == &"round".into() || id == &"floor".into() || id == &"ceil".into() { + if ["int32", "float", "bool", "round", "round64", "np_isnan", "np_isinf"] + .iter() + .any(|fun_id| id == &(*fun_id).into()) + && args.len() == 1 + { + let target_ty = if id == &"int32".into() + || id == &"round".into() + || id == &"floor".into() + || id == &"ceil".into() + { self.primitives.int32 } else if id == &"round64".into() || id == &"floor64".into() || id == &"ceil64".into() { self.primitives.int64 @@ -872,12 +885,17 @@ impl<'a> Inferencer<'a> { self.primitives.float } else if id == &"bool".into() || id == &"np_isnan".into() || id == &"np_isinf".into() { self.primitives.bool - } else { unreachable!() }; + } else { + unreachable!() + }; let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); - let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let ret = if arg0_ty + .obj_id(self.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + { let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) @@ -886,13 +904,11 @@ impl<'a> Inferencer<'a> { }; let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { - name: "n".into(), - ty: arg0.custom.unwrap(), - default_value: None, - }, - ], + args: vec![FuncArg { + name: "n".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }], ret, vars: VarMap::new(), })); @@ -909,32 +925,28 @@ impl<'a> Inferencer<'a> { args: vec![arg0], keywords: vec![], }, - })) + })); } - if [ - "np_min", - "np_max", - ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { + if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); - let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { - let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); + let ret = + if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); - ndarray_dtype - } else { - arg0_ty - }; + ndarray_dtype + } else { + arg0_ty + }; let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { - name: "a".into(), - ty: arg0.custom.unwrap(), - default_value: None, - }, - ], + args: vec![FuncArg { + name: "a".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }], ret, vars: VarMap::new(), })); @@ -951,7 +963,7 @@ impl<'a> Inferencer<'a> { args: vec![arg0], keywords: vec![], }, - })) + })); } if [ @@ -964,29 +976,32 @@ impl<'a> Inferencer<'a> { "np_ldexp", "np_hypot", "np_nextafter", - ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 { + ] + .iter() + .any(|fun_id| id == &(*fun_id).into()) + && args.len() == 2 + { let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); let arg1 = self.fold_expr(args.remove(0))?; let arg1_ty = arg1.custom.unwrap(); - let arg0_dtype = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { - unpack_ndarray_var_tys(self.unifier, arg0_ty).0 - } else { - arg0_ty - }; + let arg0_dtype = + if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + unpack_ndarray_var_tys(self.unifier, arg0_ty).0 + } else { + arg0_ty + }; - let arg1_dtype = if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { - unpack_ndarray_var_tys(self.unifier, arg1_ty).0 - } else { - arg1_ty - }; + let arg1_dtype = + if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + unpack_ndarray_var_tys(self.unifier, arg1_ty).0 + } else { + arg1_ty + }; - let expected_arg1_dtype = if id == &"np_ldexp".into() { - self.primitives.int32 - } else { - arg0_dtype - }; + let expected_arg1_dtype = + if id == &"np_ldexp".into() { self.primitives.int32 } else { arg0_dtype }; if !self.unifier.unioned(arg1_dtype, expected_arg1_dtype) { return report_error( format!( @@ -995,7 +1010,7 @@ impl<'a> Inferencer<'a> { self.unifier.stringify(arg1_dtype), ).as_str(), arg0.location, - ) + ); } let target_ty = if id == &"np_minimum".into() || id == &"np_maximum".into() { @@ -1004,14 +1019,16 @@ impl<'a> Inferencer<'a> { self.primitives.float }; - let ret = if [ - &arg0_ty, - &arg1_ty, - ].into_iter().any(|arg_ty| arg_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) { + let ret = if [&arg0_ty, &arg1_ty].into_iter().any(|arg_ty| { + arg_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) { // typeof_ndarray_broadcast requires both dtypes to be the same, but ldexp accepts // (float, int32), so convert it to align with the dtype of the first arg let arg1_ty = if id == &"np_ldexp".into() { - if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + if arg1_ty + .obj_id(self.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + { let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty); make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims)) @@ -1032,16 +1049,8 @@ impl<'a> Inferencer<'a> { let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ - FuncArg { - name: "x1".into(), - ty: arg0.custom.unwrap(), - default_value: None, - }, - FuncArg { - name: "x2".into(), - ty: arg1.custom.unwrap(), - default_value: None, - }, + FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None }, + FuncArg { name: "x2".into(), ty: arg1.custom.unwrap(), default_value: None }, ], ret, vars: VarMap::new(), @@ -1059,38 +1068,37 @@ impl<'a> Inferencer<'a> { args: vec![arg0, arg1], keywords: vec![], }, - })) + })); } - // int64, uint32 and uint64 are special because their argument can be a constant outside the + // int64, uint32 and uint64 are special because their argument can be a constant outside the // range of int32s - if [ - "int64", - "uint32", - "uint64", - ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { + if ["int64", "uint32", "uint64"].iter().any(|fun_id| id == &(*fun_id).into()) + && args.len() == 1 + { let target_ty = if id == &"int64".into() { self.primitives.int64 } else if id == &"uint32".into() { self.primitives.uint32 } else if id == &"uint64".into() { self.primitives.uint64 - } else { unreachable!() }; + } else { + unreachable!() + }; // Handle constants first to ensure that their types are not defaulted to int32, which // causes an "Integer out of bound" error - if let ExprKind::Constant { - value: ast::Constant::Int(val), - kind - } = &args[0].node { + if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = &args[0].node { let conv_is_ok = if self.unifier.unioned(target_ty, self.primitives.int64) { i64::try_from(*val).is_ok() } else if self.unifier.unioned(target_ty, self.primitives.uint32) { u32::try_from(*val).is_ok() } else if self.unifier.unioned(target_ty, self.primitives.uint64) { u64::try_from(*val).is_ok() - } else { unreachable!() }; - + } else { + unreachable!() + }; + return if conv_is_ok { Ok(Some(Located { location: args[0].location, @@ -1102,13 +1110,16 @@ impl<'a> Inferencer<'a> { })) } else { report_error("Integer out of bound", args[0].location) - } + }; } let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); - let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let ret = if arg0_ty + .obj_id(self.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + { let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) @@ -1117,13 +1128,11 @@ impl<'a> Inferencer<'a> { }; let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { - name: "n".into(), - ty: arg0.custom.unwrap(), - default_value: None, - }, - ], + args: vec![FuncArg { + name: "n".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }], ret, vars: VarMap::new(), })); @@ -1140,30 +1149,29 @@ impl<'a> Inferencer<'a> { args: vec![arg0], keywords: vec![], }, - })) + })); } // 1-argument ndarray n-dimensional creation functions - if [ - "np_ndarray".into(), - "np_empty".into(), - "np_zeros".into(), - "np_ones".into(), - ].contains(id) && args.len() == 1 { + if ["np_ndarray".into(), "np_empty".into(), "np_zeros".into(), "np_ones".into()] + .contains(id) + && args.len() == 1 + { let ExprKind::List { elts, .. } = &args[0].node else { return report_error( - format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(), - args[0].location - ) + format!( + "Expected List literal for first argument of {id}, got {}", + args[0].node.name() + ) + .as_str(), + args[0].location, + ); }; let ndims = elts.len() as u64; let arg0 = self.fold_expr(args.remove(0))?; - let ndims = self.unifier.get_fresh_literal( - vec![SymbolValue::U64(ndims)], - None, - ); + let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ret = make_ndarray_ty( self.unifier, self.primitives, @@ -1171,13 +1179,11 @@ impl<'a> Inferencer<'a> { Some(ndims), ); let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { - name: "shape".into(), - ty: arg0.custom.unwrap(), - default_value: None, - }, - ], + args: vec![FuncArg { + name: "shape".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }], ret, vars: VarMap::new(), })); @@ -1194,16 +1200,20 @@ impl<'a> Inferencer<'a> { args: vec![arg0], keywords: vec![], }, - })) + })); } // 2-argument ndarray n-dimensional creation functions if id == &"np_full".into() && args.len() == 2 { let ExprKind::List { elts, .. } = &args[0].node else { return report_error( - format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(), - args[0].location - ) + format!( + "Expected List literal for first argument of {id}, got {}", + args[0].node.name() + ) + .as_str(), + args[0].location, + ); }; let ndims = elts.len() as u64; @@ -1212,23 +1222,11 @@ impl<'a> Inferencer<'a> { let arg1 = self.fold_expr(args.remove(0))?; let ty = arg1.custom.unwrap(); - let ndims = self.unifier.get_fresh_literal( - vec![SymbolValue::U64(ndims)], - None, - ); - let ret = make_ndarray_ty( - self.unifier, - self.primitives, - Some(ty), - Some(ndims), - ); + let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); + let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ - FuncArg { - name: "shape".into(), - ty: arg0.custom.unwrap(), - default_value: None, - }, + FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None }, FuncArg { name: "fill_value".into(), ty: arg1.custom.unwrap(), @@ -1251,18 +1249,19 @@ impl<'a> Inferencer<'a> { args: vec![arg0, arg1], keywords: vec![], }, - })) + })); } // 1-argument ndarray n-dimensional creation functions if id == &"np_array".into() && args.len() == 1 { let arg0 = self.fold_expr(args.remove(0))?; - let keywords = keywords.iter() + let keywords = keywords + .iter() .map(|v| fold::fold_keyword(self, v.clone())) .collect::, _>>()?; - let ndmin_kw = keywords.iter() - .find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into())); + let ndmin_kw = + keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into())); let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap()); let ndims = if let Some(ndmin_kw) = ndmin_kw { @@ -1270,30 +1269,22 @@ impl<'a> Inferencer<'a> { ExprKind::Constant { value, .. } => match value { ast::Constant::Int(value) => *value as u64, _ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])), - } + }, - _ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) + _ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap()), } } else { arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) }; - let ndims = self.unifier.get_fresh_literal( - vec![SymbolValue::U64(ndims)], - None, - ); - let ret = make_ndarray_ty( - self.unifier, - self.primitives, - Some(ty), - Some(ndims), - ); + let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); + let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ - FuncArg { + FuncArg { name: "object".into(), ty: arg0.custom.unwrap(), - default_value: None + default_value: None, }, FuncArg { name: "copy".into(), @@ -1322,7 +1313,7 @@ impl<'a> Inferencer<'a> { args: vec![arg0], keywords, }, - })) + })); } Ok(None) @@ -1335,8 +1326,10 @@ impl<'a> Inferencer<'a> { mut args: Vec>, keywords: Vec>, ) -> Result>, HashSet> { - if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? { - return Ok(spec_call_func) + if let Some(spec_call_func) = + self.try_fold_special_call(location, &func, &mut args, &keywords)? + { + return Ok(spec_call_func); } let func = Box::new(self.fold_expr(func)?); @@ -1365,11 +1358,9 @@ impl<'a> Inferencer<'a> { .map(|v| v.name) .rev() .collect(); - self.unifier - .unify_call(&call, func.custom.unwrap(), sign, &required) - .map_err(|e| HashSet::from([ - e.at(Some(location)).to_display(self.unifier).to_string(), - ]))?; + self.unifier.unify_call(&call, func.custom.unwrap(), sign, &required).map_err( + |e| HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]), + )?; return Ok(Located { location, custom: Some(sign.ret), @@ -1403,8 +1394,7 @@ impl<'a> Inferencer<'a> { } else { let variable_mapping = &mut self.variable_mapping; let unifier: &mut Unifier = self.unifier; - self - .function_data + self.function_data .resolver .get_symbol_type(unifier, &self.top_level.definitions.read(), self.primitives, id) .unwrap_or_else(|_| { @@ -1434,8 +1424,9 @@ impl<'a> Inferencer<'a> { Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? })) } ast::Constant::Str(_) => Ok(self.primitives.str), - ast::Constant::None - => report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc), + ast::Constant::None => { + report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc) + } ast::Constant::Ellipsis => Ok(self.unifier.get_fresh_var(None, None).0), _ => report_error("not supported", *loc), } @@ -1471,8 +1462,11 @@ impl<'a> Inferencer<'a> { } (None, _) => { let t = self.unifier.stringify(ty); - report_error(&format!("`{t}::{attr}` field/method does not exist"), value.location) - }, + report_error( + &format!("`{t}::{attr}` field/method does not exist"), + value.location, + ) + } } } else { let attr_ty = self.unifier.get_dummy_var().0; @@ -1509,10 +1503,8 @@ impl<'a> Inferencer<'a> { let method = if let TypeEnum::TObj { fields, .. } = self.unifier.get_ty_immutable(left_ty).as_ref() { - let (binop_name, binop_assign_name) = ( - binop_name(op).into(), - binop_assign_name(op).into() - ); + let (binop_name, binop_assign_name) = + (binop_name(op).into(), binop_assign_name(op).into()); // if is aug_assign, try aug_assign operator first if is_aug_assign && fields.contains_key(&binop_assign_name) { binop_assign_name @@ -1527,22 +1519,11 @@ impl<'a> Inferencer<'a> { // The type of augmented assignment operator should never change Some(left_ty) } else { - typeof_binop( - self.unifier, - self.primitives, - op, - left_ty, - right_ty, - ).map_err(|e| HashSet::from([format!("{e} (at {location})")]))? + typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty) + .map_err(|e| HashSet::from([format!("{e} (at {location})")]))? }; - self.build_method_call( - location, - method, - left_ty, - vec![right_ty], - ret, - ) + self.build_method_call(location, method, left_ty, vec![right_ty], ret) } fn infer_unary_ops( @@ -1553,12 +1534,8 @@ impl<'a> Inferencer<'a> { ) -> InferenceResult { let method = unaryop_name(op).into(); - let ret = typeof_unaryop( - self.unifier, - self.primitives, - op, - operand.custom.unwrap(), - ).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?; + let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap()) + .map_err(|e| HashSet::from([format!("{e} (at {location})")]))?; self.build_method_call(location, method, operand.custom.unwrap(), vec![], ret) } @@ -1570,16 +1547,23 @@ impl<'a> Inferencer<'a> { ops: &[ast::Cmpop], comparators: &[ast::Expr>], ) -> InferenceResult { - if ops.len() > 1 && once(left).chain(comparators).any(|expr| expr.custom.unwrap().obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) { - return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")])) + if ops.len() > 1 + && once(left).chain(comparators).any(|expr| { + expr.custom + .unwrap() + .obj_id(self.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + }) + { + return Err(HashSet::from([String::from( + "Comparator chaining with ndarray types not supported", + )])); } let mut res = None; for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { let method = comparison_name(c) - .ok_or_else(|| HashSet::from([ - "unsupported comparator".to_string() - ]))? + .ok_or_else(|| HashSet::from(["unsupported comparator".to_string()]))? .into(); let ret = typeof_cmpop( @@ -1588,7 +1572,8 @@ impl<'a> Inferencer<'a> { c, a.custom.unwrap(), b.custom.unwrap(), - ).map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?; + ) + .map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?; res.replace(self.build_method_call( location, @@ -1614,28 +1599,29 @@ impl<'a> Inferencer<'a> { TypeEnum::TVar { is_const_generic: false, .. } )); - let constrained_ty = make_ndarray_ty( - self.unifier, - self.primitives, - Some(dummy_tvar), - Some(ndims), - ); + let constrained_ty = + make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims)); self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?; let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else { panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims)) }; - let ndims = values.iter() + let ndims = values + .iter() .map(|ndim| match *ndim { SymbolValue::U64(v) => Ok(v), SymbolValue::U32(v) => Ok(v as u64), - SymbolValue::I32(v) => u64::try_from(v).map_err(|_| HashSet::from([ - format!("Expected non-negative literal for ndarray.ndims, got {v}"), - ])), - SymbolValue::I64(v) => u64::try_from(v).map_err(|_| HashSet::from([ - format!("Expected non-negative literal for ndarray.ndims, got {v}"), - ])), + SymbolValue::I32(v) => u64::try_from(v).map_err(|_| { + HashSet::from([format!( + "Expected non-negative literal for ndarray.ndims, got {v}" + )]) + }), + SymbolValue::I64(v) => u64::try_from(v).map_err(|_| { + HashSet::from([format!( + "Expected non-negative literal for ndarray.ndims, got {v}" + )]) + }), _ => unreachable!(), }) .collect::, _>>()?; @@ -1685,12 +1671,13 @@ impl<'a> Inferencer<'a> { let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let (_, ndims) = + unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) } - _ => unreachable!() + _ => unreachable!(), }; self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?; Ok(list_like_ty) @@ -1698,18 +1685,20 @@ impl<'a> Inferencer<'a> { ExprKind::Constant { value: ast::Constant::Int(val), .. } => { match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let (_, ndims) = + unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); self.infer_subscript_ndarray(value, ty, ndims) } _ => { // the index is a constant, so value can be a sequence. let ind: Option = (*val).try_into().ok(); - let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?; + let ind = + ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?; let map = once(( ind.into(), RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)), )) - .collect(); + .collect(); let seq = self.unifier.add_record(map); self.constrain(value.custom.unwrap(), seq, &value.location)?; Ok(ty) @@ -1717,54 +1706,67 @@ impl<'a> Inferencer<'a> { } } ExprKind::Tuple { elts, .. } => { - if value.custom + if value + .custom .unwrap() .obj_id(self.unifier) .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) - .not() { - return report_error("Tuple slices are only supported for ndarrays", slice.location) + .not() + { + return report_error( + "Tuple slices are only supported for ndarrays", + slice.location, + ); } for elt in elts { if let ExprKind::Slice { lower, upper, step } = &elt.node { for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?; - } + } } else { self.constrain(elt.custom.unwrap(), self.primitives.int32, &elt.location)?; } } let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); - let ndarray_ty = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); + let ndarray_ty = + make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?; Ok(ndarray_ty) } _ => { if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { - return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location) + return report_error( + "Tuple index must be a constant (KernelInvariant is also not supported)", + slice.location, + ); } // the index is not a constant, so value can only be a list-like structure match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TList { .. } => { - self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?; + self.constrain( + slice.custom.unwrap(), + self.primitives.int32, + &slice.location, + )?; let list = self.unifier.add_ty(TypeEnum::TList { ty }); self.constrain(value.custom.unwrap(), list, &value.location)?; Ok(ty) } TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let (_, ndims) = + unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); - let valid_index_tys = [ - self.primitives.int32, - self.primitives.isize(), - ].into_iter().unique().collect_vec(); - let valid_index_ty = self.unifier.get_fresh_var_with_range( - valid_index_tys.as_slice(), - None, - None, - ).0; + let valid_index_tys = [self.primitives.int32, self.primitives.isize()] + .into_iter() + .unique() + .collect_vec(); + let valid_index_ty = self + .unifier + .get_fresh_var_with_range(valid_index_tys.as_slice(), None, None) + .0; self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?; self.infer_subscript_ndarray(value, ty, ndims) } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index e6fe8009..13684e76 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -3,12 +3,12 @@ use super::*; use crate::{ codegen::CodeGenContext, symbol_resolver::ValueEnum, - toplevel::{DefinitionId, helper::PRIMITIVE_DEF_IDS, TopLevelDef}, + toplevel::{helper::PRIMITIVE_DEF_IDS, DefinitionId, TopLevelDef}, }; use indoc::indoc; -use std::iter::zip; use nac3parser::parser::parse_program; use parking_lot::RwLock; +use std::iter::zip; use test_case::test_case; struct Resolver { @@ -44,7 +44,9 @@ impl SymbolResolver for Resolver { } fn get_identifier_def(&self, id: StrRef) -> Result> { - self.id_to_def.get(&id).cloned() + self.id_to_def + .get(&id) + .cloned() .ok_or_else(|| HashSet::from(["Unknown identifier".to_string()])) } @@ -136,7 +138,8 @@ impl TestEnvironment { params: VarMap::new(), }); let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); - let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None); + let ndarray_ndims_tvar = + unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PRIMITIVE_DEF_IDS.ndarray, fields: HashMap::new(), diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 4d6098a3..b90698eb 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -1,12 +1,12 @@ +use indexmap::IndexMap; +use itertools::Itertools; use std::cell::RefCell; use std::collections::HashMap; use std::fmt::Display; +use std::iter::zip; use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::{borrow::Cow, collections::HashSet}; -use std::iter::zip; -use indexmap::IndexMap; -use itertools::Itertools; use nac3parser::ast::{Location, StrRef}; @@ -61,7 +61,7 @@ pub enum RecordKey { } impl Type { - /// Wrapper function for cleaner code so that we don't need to write this long pattern matching + /// Wrapper function for cleaner code so that we don't need to write this long pattern matching /// just to get the field `obj_id`. #[must_use] pub fn obj_id(self, unifier: &Unifier) -> Option { @@ -250,9 +250,9 @@ impl Unifier { } /// Returns the [`UnificationTable`] associated with this `Unifier`. - /// + /// /// # Safety - /// + /// /// The use of this function is discouraged under most circumstances. Only use this function if /// in-place manipulation of type variables and/or type fields is necessary, otherwise prefer to /// [add a new type][`Unifier::add_ty`] and [unify the type][`Unifier::unify`] with an existing @@ -379,7 +379,17 @@ impl Unifier { let id = self.var_id + 1; self.var_id += 1; let range = range.to_vec(); - (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id) + ( + self.add_ty(TypeEnum::TVar { + id, + range, + fields: None, + name, + loc, + is_const_generic: false, + }), + id, + ) } /// Returns a fresh type representing a constant generic variable with the given underlying type `ty`. @@ -391,19 +401,22 @@ impl Unifier { ) -> (Type, u32) { let id = self.var_id + 1; self.var_id += 1; - (self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id) + ( + self.add_ty(TypeEnum::TVar { + id, + range: vec![ty], + fields: None, + name, + loc, + is_const_generic: true, + }), + id, + ) } /// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`. - pub fn get_fresh_literal( - &mut self, - values: Vec, - loc: Option, - ) -> Type { - let ty_enum = TypeEnum::TLiteral { - values: values.into_iter().dedup().collect(), - loc - }; + pub fn get_fresh_literal(&mut self, values: Vec, loc: Option) -> Type { + let ty_enum = TypeEnum::TLiteral { values: values.into_iter().dedup().collect(), loc }; self.add_ty(ty_enum) } @@ -423,7 +436,9 @@ impl Unifier { Some( range .iter() - .flat_map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) + .flat_map(|ty| { + self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]) + }) .collect_vec(), ) } @@ -479,7 +494,7 @@ impl Unifier { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { use TypeEnum::*; match &*self.get_ty(a) { - TRigidVar { .. } + TRigidVar { .. } | TLiteral { .. } // functions are instantiated for each call sites, so the function type can contain // type variables. @@ -487,7 +502,7 @@ impl Unifier { TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, - TList { ty } + TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), @@ -526,9 +541,7 @@ impl Unifier { let instantiated = self.instantiate_fun(b, signature); let r = self.get_ty(instantiated); let r = r.as_ref(); - let TypeEnum::TFunc(signature) = r else { - unreachable!() - }; + let TypeEnum::TFunc(signature) = r else { unreachable!() }; // we check to make sure that all required arguments (those without default // arguments) are provided, and do not provide the same argument twice. let mut required = required.to_vec(); @@ -555,13 +568,10 @@ impl Unifier { if let Some(i) = required.iter().position(|v| v == k) { required.remove(i); } - let i = all_names - .iter() - .position(|v| &v.0 == k) - .ok_or_else(|| { - self.restore_snapshot(); - TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc) - })?; + let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| { + self.restore_snapshot(); + TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc) + })?; let (name, expected) = all_names.remove(i); self.unify_impl(expected, *t, false).map_err(|_| { self.restore_snapshot(); @@ -627,8 +637,17 @@ impl Unifier { }; match (&*ty_a, &*ty_b) { ( - TVar { fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, .. }, - TVar { fields: fields2, id: id2, name: name2, loc: loc2, is_const_generic: false, .. }, + TVar { + fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, .. + }, + TVar { + fields: fields2, + id: id2, + name: name2, + loc: loc2, + is_const_generic: false, + .. + }, ) => { let new_fields = match (fields1, fields2) { (None, None) => None, @@ -750,7 +769,10 @@ impl Unifier { self.set_a_to_b(a, x); } - (TVar { id: id1, range: ty1, is_const_generic: true, .. }, TVar { id: id2, range: ty2, .. }) => { + ( + TVar { id: id1, range: ty1, is_const_generic: true, .. }, + TVar { id: id2, range: ty2, .. }, + ) => { let ty1 = ty1[0]; let ty2 = ty2[0]; @@ -765,17 +787,17 @@ impl Unifier { assert_eq!(tys.len(), 1); assert_eq!(values.len(), 1); - let primitives = &self.primitive_store - .expect("Expected PrimitiveStore to be present"); + let primitives = + &self.primitive_store.expect("Expected PrimitiveStore to be present"); let ty = tys[0]; - let value= &values[0]; + let value = &values[0]; let value_ty = value.get_type(primitives, self); // If the types don't match, try to implicitly promote integers if !self.unioned(ty, value_ty) { let Ok(num_val) = i128::try_from(value.clone()) else { - return Self::incompatible_types(a, b) + return Self::incompatible_types(a, b); }; let can_convert = if self.unioned(ty, primitives.int32) { @@ -791,7 +813,7 @@ impl Unifier { }; if !can_convert { - return Self::incompatible_types(a, b) + return Self::incompatible_types(a, b); } } @@ -816,7 +838,7 @@ impl Unifier { let v2i = symbol_value_to_int(v2); if v1i != v2i { - return Self::incompatible_types(a, b) + return Self::incompatible_types(a, b); } } } @@ -1287,8 +1309,8 @@ impl Unifier { mapping: &VarMap, cache: &mut HashMap>, ) -> Option> - where - K: std::hash::Hash + Eq + Clone, + where + K: std::hash::Hash + Eq + Clone, { let mut map2 = None; for (k, v) in map { diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index fd15cfb2..1c6267e2 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -45,9 +45,9 @@ impl Unifier { } } - fn map_eq(&mut self, map1: &IndexMapping, map2: &IndexMapping) -> bool - where - K: std::hash::Hash + Eq + Clone + fn map_eq(&mut self, map1: &IndexMapping, map2: &IndexMapping) -> bool + where + K: std::hash::Hash + Eq + Clone, { if map1.len() != map2.len() { return false; @@ -342,16 +342,12 @@ fn test_recursive_subst() { with_fields(&mut env.unifier, foo_id, |_unifier, fields| { fields.insert("rec".into(), (foo_id, true)); }); - let TypeEnum::TObj { params, .. } = &*foo_ty else { - unreachable!() - }; + let TypeEnum::TObj { params, .. } = &*foo_ty else { unreachable!() }; let mapping = params.iter().map(|(id, _)| (*id, int)).collect(); let instantiated = env.unifier.subst(foo_id, &mapping).unwrap(); let instantiated_ty = env.unifier.get_ty(instantiated); - let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { - unreachable!() - }; + let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { unreachable!() }; assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int)); assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated)); } @@ -477,7 +473,8 @@ fn test_typevar_range() { assert_eq!( env.unify(a_list, int_list), Err("Incompatible types: list[typevar22] and list[0]\ - \n\nNotes:\n typevar22 ∈ {1}".into()) + \n\nNotes:\n typevar22 ∈ {1}" + .into()) ); let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; @@ -505,7 +502,10 @@ fn test_rigid_var() { assert_eq!(env.unify(a, b), Err("Incompatible types: typevar3 and typevar2".to_string())); env.unifier.unify(list_a, list_x).unwrap(); - assert_eq!(env.unify(list_x, list_int), Err("Incompatible types: list[typevar2] and list[0]".to_string())); + assert_eq!( + env.unify(list_x, list_int), + Err("Incompatible types: list[typevar2] and list[0]".to_string()) + ); env.unifier.replace_rigid_var(a, int); env.unifier.unify(list_x, list_int).unwrap(); diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs index 101057c5..5b0570f9 100644 --- a/nac3core/src/typecheck/unification_table.rs +++ b/nac3core/src/typecheck/unification_table.rs @@ -16,21 +16,10 @@ pub struct UnificationTable { #[derive(Clone, Debug)] enum Action { - Parent { - key: usize, - original_parent: usize, - }, - Value { - key: usize, - original_value: Option, - }, - Rank { - key: usize, - original_rank: u32, - }, - Marker { - generation: u32, - } + Parent { key: usize, original_parent: usize }, + Value { key: usize, original_value: Option }, + Rank { key: usize, original_rank: u32 }, + Marker { generation: u32 }, } impl Default for UnificationTable { @@ -41,7 +30,13 @@ impl Default for UnificationTable { impl UnificationTable { pub fn new() -> UnificationTable { - UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 } + UnificationTable { + parents: Vec::new(), + ranks: Vec::new(), + values: Vec::new(), + log: Vec::new(), + generation: 0, + } } pub fn new_key(&mut self, v: V) -> UnificationKey { @@ -125,7 +120,10 @@ impl UnificationTable { pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) { let (log_len, generation) = snapshot; assert!(self.log.len() >= log_len, "snapshot restoration error"); - assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error"); + assert!( + matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), + "snapshot restoration error" + ); for action in self.log.drain(log_len - 1..).rev() { match action { Action::Parent { key, original_parent } => { @@ -145,7 +143,10 @@ impl UnificationTable { pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) { let (log_len, generation) = snapshot; assert!(self.log.len() >= log_len, "snapshot discard error"); - assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot discard error"); + assert!( + matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), + "snapshot discard error" + ); self.log.clear(); } } @@ -159,11 +160,23 @@ where .enumerate() .map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None }) .collect(); - UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 } + UnificationTable { + parents: self.parents.clone(), + ranks: self.ranks.clone(), + values, + log: Vec::new(), + generation: 0, + } } pub fn from_send(table: &UnificationTable) -> UnificationTable> { let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect(); - UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 } + UnificationTable { + parents: table.parents.clone(), + ranks: table.ranks.clone(), + values, + log: Vec::new(), + generation: 0, + } } } diff --git a/nac3ld/src/dwarf.rs b/nac3ld/src/dwarf.rs index 44df6595..92d2cefa 100644 --- a/nac3ld/src/dwarf.rs +++ b/nac3ld/src/dwarf.rs @@ -32,7 +32,6 @@ pub struct DwarfReader<'a> { } impl<'a> DwarfReader<'a> { - pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader { DwarfReader { slice, virt_addr, base_slice: slice, base_virt_addr: virt_addr } } @@ -170,10 +169,7 @@ fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result Result { +fn read_encoded_pointer_with_pc(reader: &mut DwarfReader, encoding: u8) -> Result { let entry_virt_addr = reader.virt_addr; let mut result = read_encoded_pointer(reader, encoding)?; @@ -223,7 +219,6 @@ pub struct EH_Frame<'a> { } impl<'a> EH_Frame<'a> { - /// Creates an [EH_Frame] using the bytes in the `.eh_frame` section and its address in the ELF /// file. pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> Result { @@ -235,10 +230,7 @@ impl<'a> EH_Frame<'a> { let reader = DwarfReader::from_reader(&self.reader, true); let len = reader.slice.len(); - CFI_Records { - reader, - available: len, - } + CFI_Records { reader, available: len } } } @@ -255,7 +247,6 @@ pub struct CFI_Record<'a> { } impl<'a> CFI_Record<'a> { - pub fn from_reader(cie_reader: &mut DwarfReader<'a>) -> Result, ()> { let length = cie_reader.read_u32(); let fde_reader = match length { @@ -323,10 +314,7 @@ impl<'a> CFI_Record<'a> { } assert_ne!(fde_pointer_encoding, DW_EH_PE_omit); - Ok(CFI_Record { - fde_pointer_encoding, - fde_reader, - }) + Ok(CFI_Record { fde_pointer_encoding, fde_reader }) } /// Returns a [DwarfReader] initialized to the first Frame Description Entry (FDE) of this CFI @@ -340,11 +328,7 @@ impl<'a> CFI_Record<'a> { let reader = self.get_fde_reader(); let len = reader.slice.len(); - FDE_Records { - pointer_encoding: self.fde_pointer_encoding, - reader, - available: len, - } + FDE_Records { pointer_encoding: self.fde_pointer_encoding, reader, available: len } } } @@ -387,7 +371,7 @@ impl<'a> Iterator for CFI_Records<'a> { // Skip this record if it is a FDE if cie_ptr == 0 { // Rewind back to the start of the CFI Record - return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap()) + return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap()); } } } @@ -448,7 +432,6 @@ pub struct EH_Frame_Hdr<'a> { } impl<'a> EH_Frame_Hdr<'a> { - /// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory. /// /// Load address is not known at this point. @@ -459,15 +442,16 @@ impl<'a> EH_Frame_Hdr<'a> { ) -> EH_Frame_Hdr { let mut writer = DwarfWriter::new(eh_frame_hdr_slice); - writer.write_u8(1); // version - writer.write_u8(0x1B); // eh_frame_ptr_enc - PC-relative 4-byte signed value - writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value - writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value + writer.write_u8(1); // version + writer.write_u8(0x1B); // eh_frame_ptr_enc - PC-relative 4-byte signed value + writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value + writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value - let eh_frame_offset = eh_frame_addr - .wrapping_sub(eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::() as u32) * 4)); - writer.write_u32(eh_frame_offset); // eh_frame_ptr - writer.write_u32(0); // `fde_count`, will be written in finalize_fde + let eh_frame_offset = eh_frame_addr.wrapping_sub( + eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::() as u32) * 4), + ); + writer.write_u32(eh_frame_offset); // eh_frame_ptr + writer.write_u32(0); // `fde_count`, will be written in finalize_fde EH_Frame_Hdr { fde_writer: writer, eh_frame_hdr_addr, fdes: Vec::new() } } @@ -492,7 +476,10 @@ impl<'a> EH_Frame_Hdr<'a> { self.fde_writer.write_u32(*init_loc); self.fde_writer.write_u32(*addr); } - LittleEndian::write_u32(&mut self.fde_writer.slice[Self::fde_count_offset()..], self.fdes.len() as u32); + LittleEndian::write_u32( + &mut self.fde_writer.slice[Self::fde_count_offset()..], + self.fdes.len() as u32, + ); } pub fn size_from_eh_frame(eh_frame: &[u8]) -> usize { diff --git a/nac3ld/src/lib.rs b/nac3ld/src/lib.rs index 9af041ec..c6348e1f 100644 --- a/nac3ld/src/lib.rs +++ b/nac3ld/src/lib.rs @@ -205,11 +205,9 @@ impl<'a> Linker<'a> { for reloc in relocs { let sym = match reloc.sym_info() as usize { STN_UNDEF => None, - sym_index => Some( - self.symtab - .get(sym_index) - .ok_or("symbol out of bounds of symbol table")?, - ), + sym_index => { + Some(self.symtab.get(sym_index).ok_or("symbol out of bounds of symbol table")?) + } }; let resolve_symbol_addr = @@ -314,9 +312,8 @@ impl<'a> Linker<'a> { R_RISCV_PCREL_LO12_I => { let expected_offset = sym_option.map_or(0, |sym| sym.st_value); - let indirect_reloc = relocs - .iter() - .find(|reloc| reloc.offset() == expected_offset)?; + let indirect_reloc = + relocs.iter().find(|reloc| reloc.offset() == expected_offset)?; Some(RelocInfo { defined_val: { let indirect_sym = @@ -354,10 +351,7 @@ impl<'a> Linker<'a> { indirect_reloc: None, pc_relative: false, relocate: Some(Box::new(|target_word, value| { - LittleEndian::write_u32( - target_word, - value, - ) + LittleEndian::write_u32(target_word, value) })), }), @@ -386,10 +380,7 @@ impl<'a> Linker<'a> { indirect_reloc: None, pc_relative: false, relocate: Some(Box::new(|target_word, value| { - LittleEndian::write_u16( - target_word, - value as u16, - ) + LittleEndian::write_u16(target_word, value as u16) })), }), @@ -552,9 +543,12 @@ impl<'a> Linker<'a> { eh_frame_hdr_rec.shdr.sh_offset, eh_frame_rec.shdr.sh_offset, ); - eh_frame.cfi_records() - .flat_map(|cfi| cfi.fde_records()) - .for_each(&mut |(init_pos, virt_addr)| eh_frame_hdr.add_fde(init_pos, virt_addr)); + eh_frame.cfi_records().flat_map(|cfi| cfi.fde_records()).for_each(&mut |( + init_pos, + virt_addr, + )| { + eh_frame_hdr.add_fde(init_pos, virt_addr) + }); // Sort FDE entries in .eh_frame_hdr eh_frame_hdr.finalize_fde(); @@ -599,24 +593,22 @@ impl<'a> Linker<'a> { // Section table for the .elf paired with the section name // To be formalized incrementally // Very hashmap-like structure, but the order matters, so it is a vector - let elf_shdrs = vec![ - SectionRecord { - shdr: Elf32_Shdr { - sh_name: 0, - sh_type: 0, - sh_flags: 0, - sh_addr: 0, - sh_offset: 0, - sh_size: 0, - sh_link: 0, - sh_info: 0, - sh_addralign: 0, - sh_entsize: 0, - }, - name: "", - data: vec![0; 0], + let elf_shdrs = vec![SectionRecord { + shdr: Elf32_Shdr { + sh_name: 0, + sh_type: 0, + sh_flags: 0, + sh_addr: 0, + sh_offset: 0, + sh_size: 0, + sh_link: 0, + sh_info: 0, + sh_addralign: 0, + sh_entsize: 0, }, - ]; + name: "", + data: vec![0; 0], + }]; let elf_sh_data_off = mem::size_of::() + mem::size_of::() * 5; // Image of the linked dynamic library, to be formalized incrementally @@ -1010,7 +1002,9 @@ impl<'a> Linker<'a> { let mut hash_bucket: Vec = vec![0; dynsym.len()]; let mut hash_chain: Vec = vec![0; dynsym.len()]; - for (sym_index, (str_start, str_end)) in dynsym_names.iter().enumerate().take(dynsym.len()).skip(1) { + for (sym_index, (str_start, str_end)) in + dynsym_names.iter().enumerate().take(dynsym.len()).skip(1) + { let hash = elf_hash(&dynstr[*str_start..*str_end]); let mut hash_index = hash as usize % hash_bucket.len(); @@ -1253,7 +1247,9 @@ impl<'a> Linker<'a> { update_dynsym_record!(b"__bss_start", bss_offset, bss_elf_index as Elf32_Section); update_dynsym_record!(b"_end", bss_offset, bss_elf_index as Elf32_Section); } else { - for (bss_iter_index, &(bss_section_index, section_name)) in bss_index_vec.iter().enumerate() { + for (bss_iter_index, &(bss_section_index, section_name)) in + bss_index_vec.iter().enumerate() + { let shdr = &shdrs[bss_section_index]; let bss_elf_index = linker.load_section( shdr, diff --git a/nac3parser/src/config_comment_helper.rs b/nac3parser/src/config_comment_helper.rs index dbc47257..c91e14af 100644 --- a/nac3parser/src/config_comment_helper.rs +++ b/nac3parser/src/config_comment_helper.rs @@ -1,15 +1,15 @@ -use lalrpop_util::ParseError; -use nac3ast::*; use crate::ast::Ident; use crate::ast::Location; -use crate::token::Tok; use crate::error::*; +use crate::token::Tok; +use lalrpop_util::ParseError; +use nac3ast::*; pub fn make_config_comment( com_loc: Location, stmt_loc: Location, nac3com_above: Vec<(Ident, Tok)>, - nac3com_end: Option + nac3com_end: Option, ) -> Result, ParseError> { if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() { return Err(ParseError::User { @@ -23,18 +23,21 @@ pub fn make_config_comment( ) ) } - }) + }); }; - Ok( - nac3com_above - .into_iter() - .map(|(com, _)| com) - .chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter())) - .collect() - ) + Ok(nac3com_above + .into_iter() + .map(|(com, _)| com) + .chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter())) + .collect()) } -pub fn handle_small_stmt(stmts: &mut [Stmt], nac3com_above: Vec<(Ident, Tok)>, nac3com_end: Option, com_above_loc: Location) -> Result<(), ParseError> { +pub fn handle_small_stmt( + stmts: &mut [Stmt], + nac3com_above: Vec<(Ident, Tok)>, + nac3com_end: Option, + com_above_loc: Location, +) -> Result<(), ParseError> { if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() { return Err(ParseError::User { error: LexicalError { @@ -47,17 +50,12 @@ pub fn handle_small_stmt(stmts: &mut [Stmt], nac3com_above: Vec<(Ident, To ) ) } - }) + }); } - apply_config_comments( - &mut stmts[0], - nac3com_above - .into_iter() - .map(|(com, _)| com).collect() - ); + apply_config_comments(&mut stmts[0], nac3com_above.into_iter().map(|(com, _)| com).collect()); apply_config_comments( stmts.last_mut().unwrap(), - nac3com_end.map_or_else(Vec::new, |com| vec![com]) + nac3com_end.map_or_else(Vec::new, |com| vec![com]), ); Ok(()) } @@ -72,7 +70,7 @@ fn apply_config_comments(stmt: &mut Stmt, comments: Vec) { | StmtKind::AnnAssign { config_comment, .. } | StmtKind::Break { config_comment, .. } | StmtKind::Continue { config_comment, .. } - | StmtKind::Return { config_comment, .. } + | StmtKind::Return { config_comment, .. } | StmtKind::Raise { config_comment, .. } | StmtKind::Import { config_comment, .. } | StmtKind::ImportFrom { config_comment, .. } @@ -80,6 +78,8 @@ fn apply_config_comments(stmt: &mut Stmt, comments: Vec) { | StmtKind::Nonlocal { config_comment, .. } | StmtKind::Assert { config_comment, .. } => config_comment.extend(comments), - _ => { unreachable!("only small statements should call this function") } + _ => { + unreachable!("only small statements should call this function") + } } } diff --git a/nac3parser/src/error.rs b/nac3parser/src/error.rs index 7d3bc8ac..e196c059 100644 --- a/nac3parser/src/error.rs +++ b/nac3parser/src/error.rs @@ -145,35 +145,27 @@ impl From> for ParseError { fn from(err: LalrpopError) -> Self { match err { // TODO: Are there cases where this isn't an EOF? - LalrpopError::InvalidToken { location } => ParseError { - error: ParseErrorType::Eof, - location, - }, - LalrpopError::ExtraToken { token } => ParseError { - error: ParseErrorType::ExtraToken(token.1), - location: token.0, - }, - LalrpopError::User { error } => ParseError { - error: ParseErrorType::Lexical(error.error), - location: error.location, - }, + LalrpopError::InvalidToken { location } => { + ParseError { error: ParseErrorType::Eof, location } + } + LalrpopError::ExtraToken { token } => { + ParseError { error: ParseErrorType::ExtraToken(token.1), location: token.0 } + } + LalrpopError::User { error } => { + ParseError { error: ParseErrorType::Lexical(error.error), location: error.location } + } LalrpopError::UnrecognizedToken { token, expected } => { // Hacky, but it's how CPython does it. See PyParser_AddToken, // in particular "Only one possible expected token" comment. - let expected = if expected.len() == 1 { - Some(expected[0].clone()) - } else { - None - }; + let expected = if expected.len() == 1 { Some(expected[0].clone()) } else { None }; ParseError { error: ParseErrorType::UnrecognizedToken(token.1, expected), location: token.0, } } - LalrpopError::UnrecognizedEof { location, .. } => ParseError { - error: ParseErrorType::Eof, - location, - }, + LalrpopError::UnrecognizedEof { location, .. } => { + ParseError { error: ParseErrorType::Eof, location } + } } } } diff --git a/nac3parser/src/fstring.rs b/nac3parser/src/fstring.rs index 6910ddc9..e0d6318a 100644 --- a/nac3parser/src/fstring.rs +++ b/nac3parser/src/fstring.rs @@ -15,10 +15,7 @@ struct FStringParser<'a> { impl<'a> FStringParser<'a> { fn new(source: &'a str, str_location: Location) -> Self { - Self { - chars: source.chars().peekable(), - str_location, - } + Self { chars: source.chars().peekable(), str_location } } #[inline] @@ -251,17 +248,11 @@ impl<'a> FStringParser<'a> { } if !content.is_empty() { - values.push(self.expr(ExprKind::Constant { - value: content.into(), - kind: None, - })) + values.push(self.expr(ExprKind::Constant { value: content.into(), kind: None })) } let s = match values.len() { - 0 => self.expr(ExprKind::Constant { - value: String::new().into(), - kind: None, - }), + 0 => self.expr(ExprKind::Constant { value: String::new().into(), kind: None }), 1 => values.into_iter().next().unwrap(), _ => self.expr(ExprKind::JoinedStr { values }), }; @@ -277,9 +268,7 @@ fn parse_fstring_expr(source: &str) -> Result { /// Parse an fstring from a string, located at a certain position in the sourcecode. /// In case of errors, we will get the location and the error returned. pub fn parse_located_fstring(source: &str, location: Location) -> Result { - FStringParser::new(source, location) - .parse() - .map_err(|error| FStringError { error, location }) + FStringParser::new(source, location).parse().map_err(|error| FStringError { error, location }) } #[cfg(test)] diff --git a/nac3parser/src/function.rs b/nac3parser/src/function.rs index a6fa07e0..a6969e0f 100644 --- a/nac3parser/src/function.rs +++ b/nac3parser/src/function.rs @@ -69,10 +69,7 @@ pub fn parse_args(func_args: Vec) -> Result { diff --git a/nac3parser/src/lexer.rs b/nac3parser/src/lexer.rs index ae26c5bc..2329f4a4 100644 --- a/nac3parser/src/lexer.rs +++ b/nac3parser/src/lexer.rs @@ -3,12 +3,12 @@ //! This means source code is translated into separate tokens. pub use super::token::Tok; -use crate::ast::{Location, FileName}; +use crate::ast::{FileName, Location}; use crate::error::{LexicalError, LexicalErrorType}; use std::char; use std::cmp::Ordering; -use std::str::FromStr; use std::num::IntErrorKind; +use std::str::FromStr; use unic_emoji_char::is_emoji_presentation; use unic_ucd_ident::{is_xid_continue, is_xid_start}; @@ -32,20 +32,14 @@ impl IndentationLevel { if self.spaces <= other.spaces { Ok(Ordering::Less) } else { - Err(LexicalError { - location, - error: LexicalErrorType::TabError, - }) + Err(LexicalError { location, error: LexicalErrorType::TabError }) } } Ordering::Greater => { if self.spaces >= other.spaces { Ok(Ordering::Greater) } else { - Err(LexicalError { - location, - error: LexicalErrorType::TabError, - }) + Err(LexicalError { location, error: LexicalErrorType::TabError }) } } Ordering::Equal => Ok(self.spaces.cmp(&other.spaces)), @@ -63,7 +57,7 @@ pub struct Lexer> { chr1: Option, chr2: Option, location: Location, - config_comment_prefix: Option<&'static str> + config_comment_prefix: Option<&'static str>, } pub static KEYWORDS: phf::Map<&'static str, Tok> = phf::phf_map! { @@ -136,11 +130,7 @@ where T: Iterator, { pub fn new(source: T) -> Self { - let mut nlh = NewlineHandler { - source, - chr0: None, - chr1: None, - }; + let mut nlh = NewlineHandler { source, chr0: None, chr1: None }; nlh.shift(); nlh.shift(); nlh @@ -195,7 +185,7 @@ where location: start, chr1: None, chr2: None, - config_comment_prefix: Some(" nac3:") + config_comment_prefix: Some(" nac3:"), }; lxr.next_char(); lxr.next_char(); @@ -287,15 +277,15 @@ where let end_pos = self.get_pos(); let value = match i128::from_str_radix(&value_text, radix) { Ok(value) => value, - Err(e) => { - match e.kind() { - IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX, - _ => return Err(LexicalError { + Err(e) => match e.kind() { + IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX, + _ => { + return Err(LexicalError { error: LexicalErrorType::OtherError(format!("{:?}", e)), location: start_pos, - }), + }) } - } + }, }; Ok((start_pos, Tok::Int { value }, end_pos)) } @@ -338,14 +328,7 @@ where if self.chr0 == Some('j') || self.chr0 == Some('J') { self.next_char(); let end_pos = self.get_pos(); - Ok(( - start_pos, - Tok::Complex { - real: 0.0, - imag: value, - }, - end_pos, - )) + Ok((start_pos, Tok::Complex { real: 0.0, imag: value }, end_pos)) } else { let end_pos = self.get_pos(); Ok((start_pos, Tok::Float { value }, end_pos)) @@ -364,7 +347,7 @@ where let value = value_text.parse::().ok(); let nonzero = match value { Some(value) => value != 0i128, - None => true + None => true, }; if start_is_zero && nonzero { return Err(LexicalError { @@ -433,9 +416,8 @@ where fn lex_comment(&mut self) -> Option { self.next_char(); // if possibly nac3 pseudocomment, special handling for `# nac3:` - let (mut prefix, mut is_comment) = self - .config_comment_prefix - .map_or_else(|| ("".chars(), false), |v| (v.chars(), true)); + let (mut prefix, mut is_comment) = + self.config_comment_prefix.map_or_else(|| ("".chars(), false), |v| (v.chars(), true)); // for the correct location of config comment let mut start_loc = self.location; start_loc.go_left(); @@ -460,22 +442,20 @@ where return Some(( start_loc, Tok::ConfigComment { content: content.trim().into() }, - self.location + self.location, )); } } } } self.next_char(); - }; + } } fn unicode_literal(&mut self, literal_number: usize) -> Result { let mut p: u32 = 0u32; - let unicode_error = LexicalError { - error: LexicalErrorType::UnicodeError, - location: self.get_pos(), - }; + let unicode_error = + LexicalError { error: LexicalErrorType::UnicodeError, location: self.get_pos() }; for i in 1..=literal_number { match self.next_char() { Some(c) => match c.to_digit(16) { @@ -530,10 +510,8 @@ where } } } - unicode_names2::character(&name).ok_or(LexicalError { - error: LexicalErrorType::UnicodeError, - location: start_pos, - }) + unicode_names2::character(&name) + .ok_or(LexicalError { error: LexicalErrorType::UnicodeError, location: start_pos }) } fn lex_string( @@ -650,14 +628,9 @@ where let end_pos = self.get_pos(); let tok = if is_bytes { - Tok::Bytes { - value: string_content.chars().map(|c| c as u8).collect(), - } + Tok::Bytes { value: string_content.chars().map(|c| c as u8).collect() } } else { - Tok::String { - value: string_content, - is_fstring, - } + Tok::String { value: string_content, is_fstring } }; Ok((start_pos, tok, end_pos)) @@ -842,11 +815,7 @@ where let tok_start = self.get_pos(); self.next_char(); let tok_end = self.get_pos(); - self.emit(( - tok_start, - Tok::Name { name: c.to_string().into() }, - tok_end, - )); + self.emit((tok_start, Tok::Name { name: c.to_string().into() }, tok_end)); } else { self.consume_character(c)?; } @@ -1439,14 +1408,8 @@ class Foo(A, B): assert_eq!( tokens, vec![ - Tok::String { - value: "\\\\".to_owned(), - is_fstring: false, - }, - Tok::String { - value: "\\".to_owned(), - is_fstring: false, - }, + Tok::String { value: "\\\\".to_owned(), is_fstring: false }, + Tok::String { value: "\\".to_owned(), is_fstring: false }, Tok::Newline, ] ); @@ -1459,27 +1422,13 @@ class Foo(A, B): assert_eq!( tokens, vec![ - Tok::Int { - value: 47i128, - }, - Tok::Int { - value: 13i128, - }, - Tok::Int { - value: 0i128, - }, - Tok::Int { - value: 123i128, - }, + Tok::Int { value: 47i128 }, + Tok::Int { value: 13i128 }, + Tok::Int { value: 0i128 }, + Tok::Int { value: 123i128 }, Tok::Float { value: 0.2 }, - Tok::Complex { - real: 0.0, - imag: 2.0, - }, - Tok::Complex { - real: 0.0, - imag: 2.2, - }, + Tok::Complex { real: 0.0, imag: 2.0 }, + Tok::Complex { real: 0.0, imag: 2.2 }, Tok::Newline, ] ); @@ -1539,21 +1488,13 @@ class Foo(A, B): assert_eq!( tokens, vec![ - Tok::Name { - name: String::from("avariable").into(), - }, + Tok::Name { name: String::from("avariable").into() }, Tok::Equal, - Tok::Int { - value: 99i128 - }, + Tok::Int { value: 99i128 }, Tok::Plus, - Tok::Int { - value: 2i128 - }, + Tok::Int { value: 2i128 }, Tok::Minus, - Tok::Int { - value: 0i128 - }, + Tok::Int { value: 0i128 }, Tok::Newline, ] ); @@ -1740,42 +1681,15 @@ class Foo(A, B): assert_eq!( tokens, vec![ - Tok::String { - value: String::from("double"), - is_fstring: false, - }, - Tok::String { - value: String::from("single"), - is_fstring: false, - }, - Tok::String { - value: String::from("can't"), - is_fstring: false, - }, - Tok::String { - value: String::from("\\\""), - is_fstring: false, - }, - Tok::String { - value: String::from("\t\r\n"), - is_fstring: false, - }, - Tok::String { - value: String::from("\\g"), - is_fstring: false, - }, - Tok::String { - value: String::from("raw\\'"), - is_fstring: false, - }, - Tok::String { - value: String::from("Đ"), - is_fstring: false, - }, - Tok::String { - value: String::from("\u{80}\u{0}a"), - is_fstring: false, - }, + Tok::String { value: String::from("double"), is_fstring: false }, + Tok::String { value: String::from("single"), is_fstring: false }, + Tok::String { value: String::from("can't"), is_fstring: false }, + Tok::String { value: String::from("\\\""), is_fstring: false }, + Tok::String { value: String::from("\t\r\n"), is_fstring: false }, + Tok::String { value: String::from("\\g"), is_fstring: false }, + Tok::String { value: String::from("raw\\'"), is_fstring: false }, + Tok::String { value: String::from("Đ"), is_fstring: false }, + Tok::String { value: String::from("\u{80}\u{0}a"), is_fstring: false }, Tok::Newline, ] ); @@ -1840,41 +1754,17 @@ class Foo(A, B): fn test_raw_byte_literal() { let source = r"rb'\x1z'"; let tokens = lex_source(source); - assert_eq!( - tokens, - vec![ - Tok::Bytes { - value: b"\\x1z".to_vec() - }, - Tok::Newline - ] - ); + assert_eq!(tokens, vec![Tok::Bytes { value: b"\\x1z".to_vec() }, Tok::Newline]); let source = r"rb'\\'"; let tokens = lex_source(source); - assert_eq!( - tokens, - vec![ - Tok::Bytes { - value: b"\\\\".to_vec() - }, - Tok::Newline - ] - ) + assert_eq!(tokens, vec![Tok::Bytes { value: b"\\\\".to_vec() }, Tok::Newline]) } #[test] fn test_escape_octet() { let source = r##"b'\43a\4\1234'"##; let tokens = lex_source(source); - assert_eq!( - tokens, - vec![ - Tok::Bytes { - value: b"#a\x04S4".to_vec() - }, - Tok::Newline - ] - ) + assert_eq!(tokens, vec![Tok::Bytes { value: b"#a\x04S4".to_vec() }, Tok::Newline]) } #[test] @@ -1883,13 +1773,7 @@ class Foo(A, B): let tokens = lex_source(source); assert_eq!( tokens, - vec![ - Tok::String { - value: "\u{2002}".to_owned(), - is_fstring: false, - }, - Tok::Newline - ] + vec![Tok::String { value: "\u{2002}".to_owned(), is_fstring: false }, Tok::Newline] ) } } diff --git a/nac3parser/src/lib.rs b/nac3parser/src/lib.rs index 5e253059..991cf301 100644 --- a/nac3parser/src/lib.rs +++ b/nac3parser/src/lib.rs @@ -31,5 +31,5 @@ lalrpop_mod!( #[allow(unused)] python ); -pub mod token; pub mod config_comment_helper; +pub mod token; diff --git a/nac3parser/src/parser.rs b/nac3parser/src/parser.rs index b8968d59..a4e89c41 100644 --- a/nac3parser/src/parser.rs +++ b/nac3parser/src/parser.rs @@ -75,9 +75,7 @@ pub fn parse(source: &str, mode: Mode, file: FileName) -> Result if *value != i128::MAX { write!(f, "'{}'", value) } else { write!(f, "'#OFL#'") }, + Name { name } => { + write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name)) + } + Int { value } => { + if *value != i128::MAX { + write!(f, "'{}'", value) + } else { + write!(f, "'#OFL#'") + } + } Float { value } => write!(f, "'{}'", value), Complex { real, imag } => write!(f, "{}j{}", real, imag), String { value, is_fstring } => { @@ -134,7 +142,11 @@ impl fmt::Display for Tok { } f.write_str("\"") } - ConfigComment { content } => write!(f, "ConfigComment: '{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)), + ConfigComment { content } => write!( + f, + "ConfigComment: '{}'", + ast::get_str_from_ref(&ast::get_str_ref_lock(), *content) + ), Newline => f.write_str("Newline"), Indent => f.write_str("Indent"), Dedent => f.write_str("Dedent"), diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index d7fb1b8c..5fe0d4f5 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -9,8 +9,8 @@ use nac3core::{ }; use nac3parser::ast::{self, StrRef}; use parking_lot::{Mutex, RwLock}; -use std::{collections::HashMap, sync::Arc}; use std::collections::HashSet; +use std::{collections::HashMap, sync::Arc}; pub struct ResolverInternal { pub id_to_type: Mutex>, @@ -63,10 +63,12 @@ impl SymbolResolver for Resolver { } fn get_identifier_def(&self, id: StrRef) -> Result> { - self.0.id_to_def.lock().get(&id).copied() - .ok_or_else(|| HashSet::from([ - format!("Undefined identifier `{id}`"), - ])) + self.0 + .id_to_def + .lock() + .get(&id) + .copied() + .ok_or_else(|| HashSet::from([format!("Undefined identifier `{id}`")])) } fn get_string_id(&self, s: &str) -> i32 { diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 6c4b4aba..349063f0 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -1,14 +1,11 @@ use clap::Parser; use inkwell::{ - memory_buffer::MemoryBuffer, - passes::PassBuilderOptions, - support::is_multithreaded, - targets::*, + memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*, OptimizationLevel, }; use parking_lot::{Mutex, RwLock}; -use std::{collections::HashMap, fs, path::Path, sync::Arc}; use std::collections::HashSet; +use std::{collections::HashMap, fs, path::Path, sync::Arc}; use nac3core::{ codegen::{ @@ -18,7 +15,7 @@ use nac3core::{ symbol_resolver::SymbolResolver, toplevel::{ composer::{ComposerConfig, TopLevelComposer}, - helper::parse_parameter_default_value, + helper::parse_parameter_default_value, type_annotation::*, TopLevelDef, }, @@ -78,19 +75,18 @@ fn handle_typevar_definition( primitives: &PrimitiveStore, ) -> Result> { let ExprKind::Call { func, args, .. } = &var.node else { - return Err(HashSet::from([ - format!( - "expression {var:?} cannot be handled as a generic parameter in global scope" - ), - ])) + return Err(HashSet::from([format!( + "expression {var:?} cannot be handled as a generic parameter in global scope" + )])); }; match &func.node { ExprKind::Name { id, .. } if id == &"TypeVar".into() => { let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else { - return Err(HashSet::from([ - format!("Expected string constant for first parameter of `TypeVar`, got {:?}", &args[0].node), - ])) + return Err(HashSet::from([format!( + "Expected string constant for first parameter of `TypeVar`, got {:?}", + &args[0].node + )])); }; let generic_name: StrRef = ty_name.to_string().into(); @@ -106,17 +102,15 @@ fn handle_typevar_definition( x, HashMap::default(), )?; - get_type_from_type_annotation_kinds( - def_list, unifier, &ty, &mut None - ) + get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None) }) .collect::, _>>()?; let loc = func.location; if constraints.len() == 1 { - return Err(HashSet::from([ - format!("A single constraint is not allowed (at {loc})"), - ])) + return Err(HashSet::from([format!( + "A single constraint is not allowed (at {loc})" + )])); } Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0) @@ -124,18 +118,17 @@ fn handle_typevar_definition( ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => { if args.len() != 2 { - return Err(HashSet::from([ - format!("Expected 2 arguments for `ConstGeneric`, got {}", args.len()), - ])) + return Err(HashSet::from([format!( + "Expected 2 arguments for `ConstGeneric`, got {}", + args.len() + )])); } let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else { - return Err(HashSet::from([ - format!( - "Expected string constant for first parameter of `ConstGeneric`, got {:?}", - &args[0].node - ), - ])) + return Err(HashSet::from([format!( + "Expected string constant for first parameter of `ConstGeneric`, got {:?}", + &args[0].node + )])); }; let generic_name: StrRef = ty_name.to_string().into(); @@ -147,19 +140,16 @@ fn handle_typevar_definition( &args[1], HashMap::default(), )?; - let constraint = get_type_from_type_annotation_kinds( - def_list, unifier, &ty, &mut None - )?; + let constraint = + get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?; let loc = func.location; Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0) } - _ => Err(HashSet::from([ - format!( - "expression {var:?} cannot be handled as a generic parameter in global scope" - ), - ])) + _ => Err(HashSet::from([format!( + "expression {var:?} cannot be handled as a generic parameter in global scope" + )])), } } @@ -175,18 +165,12 @@ fn handle_assignment_pattern( if targets.len() == 1 { match &targets[0].node { ExprKind::Name { id, .. } => { - if let Ok(var) = handle_typevar_definition( - value, - resolver, - def_list, - unifier, - primitives, - ) { + if let Ok(var) = + handle_typevar_definition(value, resolver, def_list, unifier, primitives) + { internal_resolver.add_id_type(*id, var); Ok(()) - } else if let Ok(val) = - parse_parameter_default_value(value, resolver) - { + } else if let Ok(val) = parse_parameter_default_value(value, resolver) { internal_resolver.add_module_global(*id, val); Ok(()) } else { @@ -238,10 +222,7 @@ fn handle_assignment_pattern( )) } } - _ => Err(format!( - "unpack of this expression is not supported at {}", - value.location - )), + _ => Err(format!("unpack of this expression is not supported at {}", value.location)), } } } @@ -250,15 +231,8 @@ fn main() { const SIZE_T: u32 = usize::BITS; let cli = CommandLineArgs::parse(); - let CommandLineArgs { - file_name, - threads, - opt_level, - emit_llvm, - triple, - mcpu, - target_features, - } = cli; + let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } = + cli; Target::initialize_all(&InitializationConfig::default()); @@ -270,9 +244,7 @@ fn main() { let target_features = target_features.unwrap_or_default(); let threads = if is_multithreaded() { if threads == 0 { - std::thread::available_parallelism() - .map(|threads| threads.get() as u32) - .unwrap_or(1u32) + std::thread::available_parallelism().map(|threads| threads.get() as u32).unwrap_or(1u32) } else { threads } @@ -308,7 +280,8 @@ fn main() { class_names: Mutex::default(), module_globals: Mutex::default(), str_store: Mutex::default(), - }.into(); + } + .into(); let resolver = Arc::new(Resolver(internal_resolver.clone())) as Arc; @@ -332,13 +305,19 @@ fn main() { eprintln!("{err}"); return; } - }, + } // allow (and ignore) "from __future__ import annotations" StmtKind::ImportFrom { module, names, .. } - if module == &Some("__future__".into()) && names.len() == 1 && names[0].name == "annotations".into() => (), + if module == &Some("__future__".into()) + && names.len() == 1 + && names[0].name == "annotations".into() => + { + () + } _ => { - let (name, def_id, ty) = - composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true).unwrap(); + let (name, def_id, ty) = composer + .register_top_level(stmt, Some(resolver.clone()), "__main__", true) + .unwrap(); internal_resolver.add_id_def(name, def_id); if let Some(ty) = ty { internal_resolver.add_id_type(name, ty); @@ -364,7 +343,8 @@ fn main() { .unwrap_or_else(|_| panic!("cannot find run() entry point")) .0] .write(); - let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance else { + let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance + else { unreachable!() }; instance_to_symbol.insert(String::new(), "run".to_string()); @@ -444,7 +424,8 @@ fn main() { function_iter = func.get_next_function(); } - let target_machine = llvm_options.target + let target_machine = llvm_options + .target .create_target_machine(llvm_options.opt_level) .expect("couldn't create target machine"); diff --git a/runkernel/src/main.rs b/runkernel/src/main.rs index 84f1e323..8c6d2d18 100644 --- a/runkernel/src/main.rs +++ b/runkernel/src/main.rs @@ -47,12 +47,11 @@ pub extern "C" fn __nac3_personality(_state: u32, _exception_object: u32, _conte unimplemented!(); } - fn main() { let filename = env::args().nth(1).unwrap(); unsafe { let lib = libloading::Library::new(filename).unwrap(); - let func: libloading::Symbol = lib.get(b"__modinit__").unwrap(); + let func: libloading::Symbol = lib.get(b"__modinit__").unwrap(); func() } }