From f52ba9f15140335a082c124602c43ea0539e33c7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 4 Feb 2025 17:13:33 +0800 Subject: [PATCH] [core] codegen/irrt: Refactor IRRT to use more create/infer fns --- nac3artiq/src/codegen.rs | 82 +++--------- nac3artiq/src/timeline.rs | 48 +++---- nac3core/src/codegen/extern_fns.rs | 104 ++++++++------- nac3core/src/codegen/irrt/list.rs | 72 +++++------ nac3core/src/codegen/irrt/math.rs | 135 +++++++++----------- nac3core/src/codegen/irrt/ndarray/basic.rs | 91 ++++--------- nac3core/src/codegen/irrt/ndarray/iter.rs | 23 +--- nac3core/src/codegen/irrt/ndarray/matmul.rs | 28 +--- nac3core/src/codegen/irrt/range.rs | 27 ++-- nac3core/src/codegen/irrt/slice.rs | 30 +++-- nac3core/src/codegen/irrt/string.rs | 50 +++----- 11 files changed, 270 insertions(+), 420 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 572accc..fb6992b 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -15,7 +15,7 @@ use pyo3::{ use super::{symbol_resolver::InnerResolver, timeline::TimeFns}; use nac3core::{ codegen::{ - expr::{destructure_range, gen_call}, + expr::{create_fn_and_call, destructure_range, gen_call, infer_and_call_function}, llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, type_aligned_alloca, @@ -914,47 +914,14 @@ fn rpc_codegen_callback_fn<'ctx>( } // call - if is_async { - let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| { - ctx.module.add_function( - "rpc_send_async", - ctx.ctx.void_type().fn_type( - &[ - int32.into(), - tag_ptr_type.ptr_type(AddressSpace::default()).into(), - ptr_type.ptr_type(AddressSpace::default()).into(), - ], - false, - ), - None, - ) - }); - ctx.builder - .build_call( - rpc_send_async, - &[service_id.into(), tag_ptr.into(), args_ptr.into()], - "rpc.send", - ) - .unwrap(); - } else { - let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { - ctx.module.add_function( - "rpc_send", - ctx.ctx.void_type().fn_type( - &[ - int32.into(), - tag_ptr_type.ptr_type(AddressSpace::default()).into(), - ptr_type.ptr_type(AddressSpace::default()).into(), - ], - false, - ), - None, - ) - }); - ctx.builder - .build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") - .unwrap(); - } + infer_and_call_function( + ctx, + if is_async { "rpc_send_async" } else { "rpc_send" }, + None, + &[service_id.into(), tag_ptr.into(), args_ptr.into()], + Some("rpc.send"), + None, + ); // reclaim stack space used by arguments call_stackrestore(ctx, stackptr); @@ -1168,29 +1135,22 @@ fn polymorphic_print<'ctx>( debug_assert!(!fmt.is_empty()); debug_assert_eq!(fmt.as_bytes().last().unwrap(), &0u8); - let fn_name = if as_rtio { "rtio_log" } else { "core_log" }; - let print_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| { - let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let fn_t = if as_rtio { - let llvm_void = ctx.ctx.void_type(); - llvm_void.fn_type(&[llvm_pi8.into()], true) - } else { - let llvm_i32 = ctx.ctx.i32_type(); - llvm_i32.fn_type(&[llvm_pi8.into()], true) - }; - ctx.module.add_function(fn_name, fn_t, None) - }); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); let fmt = ctx.gen_string(generator, fmt); let fmt = unsafe { fmt.get_field_at_index_unchecked(0) }.into_pointer_value(); - ctx.builder - .build_call( - print_fn, - &once(fmt.into()).chain(args).map(BasicValueEnum::into).collect_vec(), - "", - ) - .unwrap(); + create_fn_and_call( + ctx, + if as_rtio { "rtio_log" } else { "core_log" }, + if as_rtio { None } else { Some(llvm_i32.into()) }, + &[llvm_pi8.into()], + &once(fmt.into()).chain(args).map(BasicValueEnum::into).collect_vec(), + true, + None, + None, + ); }; let llvm_i32 = ctx.ctx.i32_type(); diff --git a/nac3artiq/src/timeline.rs b/nac3artiq/src/timeline.rs index f51c553..5d6fc79 100644 --- a/nac3artiq/src/timeline.rs +++ b/nac3artiq/src/timeline.rs @@ -1,11 +1,6 @@ -use itertools::Either; - use nac3core::{ - codegen::CodeGenContext, - inkwell::{ - values::{BasicValueEnum, CallSiteValue}, - AddressSpace, AtomicOrdering, - }, + codegen::{expr::infer_and_call_function, CodeGenContext}, + inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering}, }; /// Functions for manipulating the timeline. @@ -288,36 +283,27 @@ pub struct ExternTimeFns {} impl TimeFns for ExternTimeFns { fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> { - let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| { - ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None) - }); - ctx.builder - .build_call(now_mu, &[], "now_mu") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + "now_mu", + Some(ctx.ctx.i64_type().into()), + &[], + Some("now_mu"), + None, + ) + .unwrap() } fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { - let at_mu = ctx.module.get_function("at_mu").unwrap_or_else(|| { - ctx.module.add_function( - "at_mu", - ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false), - None, - ) - }); - ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap(); + assert_eq!(t.get_type(), ctx.ctx.i64_type().into()); + + infer_and_call_function(ctx, "at_mu", None, &[t], Some("at_mu"), None); } fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) { - let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| { - ctx.module.add_function( - "delay_mu", - ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false), - None, - ) - }); - ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu").unwrap(); + assert_eq!(dt.get_type(), ctx.ctx.i64_type().into()); + + infer_and_call_function(ctx, "delay_mu", None, &[dt], Some("delay_mu"), None); } } diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index 6412fbd..8dcc2f2 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -1,10 +1,9 @@ use inkwell::{ attributes::{Attribute, AttributeLoc}, - values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, + values::{BasicValueEnum, FloatValue, IntValue}, }; -use itertools::Either; -use super::CodeGenContext; +use super::{expr::infer_and_call_function, CodeGenContext}; /// Macro to generate extern function /// Both function return type and function parameter type are `FloatValue` @@ -46,24 +45,23 @@ macro_rules! generate_extern_fn { let llvm_f64 = ctx.ctx.f64_type(); $(debug_assert_eq!($args.get_type(), llvm_f64);)* - let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false); - let func = ctx.module.add_function(FN_NAME, fn_type, None); - for attr in [$($attributes),*] { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); - } - func - }); - - ctx.builder - .build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default()) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + FN_NAME, + Some(llvm_f64.into()), + &[$($args.into()),*], + name, + Some(&|func| { + for attr in [$($attributes),*] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + }) + ) + .map(BasicValueEnum::into_float_value) + .unwrap() } }; } @@ -112,25 +110,23 @@ pub fn call_ldexp<'ctx>( debug_assert_eq!(arg.get_type(), llvm_f64); debug_assert_eq!(exp.get_type(), llvm_i32); - let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false); - let func = ctx.module.add_function(FN_NAME, fn_type, None); - for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); - } - - func - }); - - ctx.builder - .build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default()) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + FN_NAME, + Some(llvm_f64.into()), + &[arg.into(), exp.into()], + name, + Some(&|func| { + for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + }), + ) + .map(BasicValueEnum::into_float_value) + .unwrap() } /// Macro to generate `np_linalg` and `sp_linalg` functions @@ -163,20 +159,22 @@ macro_rules! generate_linalg_extern_fn { name: Option<&str>, ){ const FN_NAME: &str = $extern_fn; - let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { - let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false); - let func = ctx.module.add_function(FN_NAME, fn_type, None); - for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); - } - func - }); - - ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap(); + infer_and_call_function( + ctx, + FN_NAME, + None, + &[$($input_matrix.into(),)*], + name, + Some(&|func| { + for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + }), + ); } }; } diff --git a/nac3core/src/codegen/irrt/list.rs b/nac3core/src/codegen/irrt/list.rs index c01e2cb..20daace 100644 --- a/nac3core/src/codegen/irrt/list.rs +++ b/nac3core/src/codegen/irrt/list.rs @@ -1,13 +1,14 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValueEnum, CallSiteValue, IntValue}, + values::{BasicValueEnum, IntValue}, AddressSpace, IntPredicate, }; -use itertools::Either; use super::calculate_len_for_slice_range; use crate::codegen::{ + expr::infer_and_call_function, macros::codegen_unreachable, + stmt::gen_if_callback, values::{ArrayLikeValue, ListValue}, CodeGenContext, CodeGenerator, }; @@ -36,25 +37,6 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( assert_eq!(src_idx.2.get_type(), llvm_i32); let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8); - let slice_assign_fun = { - let ty_vec = vec![ - llvm_i32.into(), // dest start idx - llvm_i32.into(), // dest end idx - llvm_i32.into(), // dest step - elem_ptr_type.into(), // dest arr ptr - llvm_i32.into(), // dest arr len - llvm_i32.into(), // src start idx - llvm_i32.into(), // src end idx - llvm_i32.into(), // src step - elem_ptr_type.into(), // src arr ptr - llvm_i32.into(), // src arr len - llvm_i32.into(), // size - ]; - ctx.module.get_function(fun_symbol).unwrap_or_else(|| { - let fn_t = llvm_i32.fn_type(ty_vec.as_slice(), false); - ctx.module.add_function(fun_symbol, fn_t, None) - }) - }; let zero = llvm_i32.const_zero(); let one = llvm_i32.const_int(1, false); @@ -127,7 +109,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( ); let new_len = { - let args = vec![ + let args = [ dest_idx.0.into(), // dest start idx dest_idx.1.into(), // dest end idx dest_idx.2.into(), // dest step @@ -150,25 +132,35 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( } .into(), ]; - ctx.builder - .build_call(slice_assign_fun, args.as_slice(), "slice_assign") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + fun_symbol, + Some(llvm_i32.into()), + &args, + Some("slice_assign"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() }; // update length - let need_update = - ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); - let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); - let update_bb = ctx.ctx.append_basic_block(current, "update"); - let cont_bb = ctx.ctx.append_basic_block(current, "cont"); - ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); - ctx.builder.position_at_end(update_bb); - let new_len = - ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); - dest_arr.store_size(ctx, new_len); - ctx.builder.build_unconditional_branch(cont_bb).unwrap(); - ctx.builder.position_at_end(cont_bb); + gen_if_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update") + .unwrap()) + }, + |_, ctx| { + let new_len = + ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); + dest_arr.store_size(ctx, new_len); + Ok(()) + }, + |_, _| Ok(()), + ) + .unwrap(); } diff --git a/nac3core/src/codegen/irrt/math.rs b/nac3core/src/codegen/irrt/math.rs index 33445b2..e430080 100644 --- a/nac3core/src/codegen/irrt/math.rs +++ b/nac3core/src/codegen/irrt/math.rs @@ -1,10 +1,10 @@ use inkwell::{ - values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, + values::{BasicValueEnum, FloatValue, IntValue}, IntPredicate, }; -use itertools::Either; use crate::codegen::{ + expr::infer_and_call_function, macros::codegen_unreachable, {CodeGenContext, CodeGenerator}, }; @@ -18,18 +18,16 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>( exp: IntValue<'ctx>, signed: bool, ) -> IntValue<'ctx> { - let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) { + let base_type = base.get_type(); + + let symbol = match (base_type.get_bit_width(), exp.get_type().get_bit_width(), signed) { (32, 32, true) => "__nac3_int_exp_int32_t", (64, 64, true) => "__nac3_int_exp_int64_t", (32, 32, false) => "__nac3_int_exp_uint32_t", (64, 64, false) => "__nac3_int_exp_uint64_t", _ => codegen_unreachable!(ctx), }; - let base_type = base.get_type(); - let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { - let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); - ctx.module.add_function(symbol, fn_type, None) - }); + // throw exception when exp < 0 let ge_zero = ctx .builder @@ -48,12 +46,17 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>( [None, None, None], ctx.current_loc, ); - ctx.builder - .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() + + infer_and_call_function( + ctx, + symbol, + Some(base_type.into()), + &[base.into(), exp.into()], + Some("call_int_pow"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() } /// Generates a call to `isinf` in IR. Returns an `i1` representing the result. @@ -67,20 +70,17 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( assert_eq!(v.get_type(), llvm_f64); - let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { - let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_isinf", fn_type, None) - }); - - let ret = ctx - .builder - .build_call(intrinsic_fn, &[v.into()], "isinf") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - generator.bool_to_i1(ctx, ret) + infer_and_call_function( + ctx, + "__nac3_isinf", + Some(llvm_i32.into()), + &[v.into()], + Some("isinf"), + None, + ) + .map(BasicValueEnum::into_int_value) + .map(|ret| generator.bool_to_i1(ctx, ret)) + .unwrap() } /// Generates a call to `isnan` in IR. Returns an `i1` representing the result. @@ -94,20 +94,17 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( assert_eq!(v.get_type(), llvm_f64); - let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { - let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_isnan", fn_type, None) - }); - - let ret = ctx - .builder - .build_call(intrinsic_fn, &[v.into()], "isnan") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - generator.bool_to_i1(ctx, ret) + infer_and_call_function( + ctx, + "__nac3_isnan", + Some(llvm_i32.into()), + &[v.into()], + Some("isnan"), + None, + ) + .map(BasicValueEnum::into_int_value) + .map(|ret| generator.bool_to_i1(ctx, ret)) + .unwrap() } /// Generates a call to `gamma` in IR. Returns an `f64` representing the result. @@ -116,17 +113,16 @@ pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> assert_eq!(v.get_type(), llvm_f64); - let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_gamma", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "gamma") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + "__nac3_gamma", + Some(llvm_f64.into()), + &[v.into()], + Some("gamma"), + None, + ) + .map(BasicValueEnum::into_float_value) + .unwrap() } /// Generates a call to `gammaln` in IR. Returns an `f64` representing the result. @@ -135,17 +131,16 @@ pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) - assert_eq!(v.get_type(), llvm_f64); - let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_gammaln", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "gammaln") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + "__nac3_gammaln", + Some(llvm_f64.into()), + &[v.into()], + Some("gammaln"), + None, + ) + .map(BasicValueEnum::into_float_value) + .unwrap() } /// Generates a call to `j0` in IR. Returns an `f64` representing the result. @@ -154,15 +149,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo assert_eq!(v.get_type(), llvm_f64); - let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_j0", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "j0") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) + infer_and_call_function(ctx, "__nac3_j0", Some(llvm_f64.into()), &[v.into()], Some("j0"), None) + .map(BasicValueEnum::into_float_value) .unwrap() } diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index 5f291c8..06f38f7 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -1,13 +1,11 @@ use inkwell::{ - types::BasicTypeEnum, values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, }; use crate::codegen::{ - expr::{create_and_call_function, infer_and_call_function}, + expr::infer_and_call_function, irrt::get_usize_dependent_function_name, - types::ProxyType, values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor}, CodeGenContext, CodeGenerator, }; @@ -21,24 +19,17 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = ctx.get_size_type(); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - assert_eq!( - BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); + assert_eq!(shape.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[ - (llvm_usize.into(), shape.size(ctx, generator).into()), - (llvm_pusize.into(), shape.base_ptr(ctx, generator).into()), - ], + &[shape.size(ctx, generator).into(), shape.base_ptr(ctx, generator).into()], None, None, ); @@ -55,29 +46,22 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = ctx.get_size_type(); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - assert_eq!( - BasicTypeEnum::try_from(ndarray_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - assert_eq!( - BasicTypeEnum::try_from(output_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); + assert_eq!(ndarray_shape.element_type(ctx, generator), llvm_usize.into()); + assert_eq!(output_shape.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_usize.into()), &[ - (llvm_usize.into(), ndarray_shape.size(ctx, generator).into()), - (llvm_pusize.into(), ndarray_shape.base_ptr(ctx, generator).into()), - (llvm_usize.into(), output_shape.size(ctx, generator).into()), - (llvm_pusize.into(), output_shape.base_ptr(ctx, generator).into()), + ndarray_shape.size(ctx, generator).into(), + ndarray_shape.base_ptr(ctx, generator).into(), + output_shape.size(ctx, generator).into(), + output_shape.base_ptr(ctx, generator).into(), ], None, None, @@ -93,15 +77,14 @@ pub fn call_nac3_ndarray_size<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], + &[ndarray.as_abi_value(ctx).into()], Some("size"), None, ) @@ -118,15 +101,14 @@ pub fn call_nac3_ndarray_nbytes<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], + &[ndarray.as_abi_value(ctx).into()], Some("nbytes"), None, ) @@ -143,15 +125,14 @@ pub fn call_nac3_ndarray_len<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], + &[ndarray.as_abi_value(ctx).into()], Some("len"), None, ) @@ -167,15 +148,14 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); - let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_i1.into()), - &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], + &[ndarray.as_abi_value(ctx).into()], Some("is_c_contiguous"), None, ) @@ -194,20 +174,16 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type(); assert_eq!(index.get_type(), llvm_usize); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_pi8.into()), - &[ - (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), - (llvm_usize.into(), index.into()), - ], + &[ndarray.as_abi_value(ctx).into(), index.into()], Some("pelement"), None, ) @@ -229,24 +205,16 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = ctx.get_size_type(); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let llvm_ndarray = ndarray.get_type(); - assert_eq!( - BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); + assert_eq!(indices.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_pi8.into()), - &[ - (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), - (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), - ], + &[ndarray.as_abi_value(ctx).into(), indices.base_ptr(ctx, generator).into()], Some("pelement"), None, ) @@ -261,18 +229,9 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) { - let llvm_ndarray = ndarray.get_type(); - let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape"); - create_and_call_function( - ctx, - &name, - None, - &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], - None, - None, - ); + infer_and_call_function(ctx, &name, None, &[ndarray.as_abi_value(ctx).into()], None, None); } /// Generates a call to `__nac3_ndarray_copy_data`. diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index e4424df..d44870d 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -1,13 +1,8 @@ -use inkwell::{ - types::BasicTypeEnum, - values::{BasicValueEnum, IntValue}, - AddressSpace, -}; +use inkwell::values::{BasicValueEnum, IntValue}; use crate::codegen::{ - expr::{create_and_call_function, infer_and_call_function}, + expr::infer_and_call_function, irrt::get_usize_dependent_function_name, - types::ProxyType, values::{ ndarray::{NDArrayValue, NDIterValue}, ProxyValue, TypedArrayLikeAccessor, @@ -26,23 +21,19 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = ctx.get_size_type(); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - assert_eq!( - BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); + assert_eq!(indices.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize"); - create_and_call_function( + infer_and_call_function( ctx, &name, None, &[ - (iter.get_type().as_abi_type().into(), iter.as_abi_value(ctx).into()), - (ndarray.get_type().as_abi_type().into(), ndarray.as_abi_value(ctx).into()), - (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), + iter.as_abi_value(ctx).into(), + ndarray.as_abi_value(ctx).into(), + indices.base_ptr(ctx, generator).into(), ], None, None, diff --git a/nac3core/src/codegen/irrt/ndarray/matmul.rs b/nac3core/src/codegen/irrt/ndarray/matmul.rs index 0df774f..d2e73ae 100644 --- a/nac3core/src/codegen/irrt/ndarray/matmul.rs +++ b/nac3core/src/codegen/irrt/ndarray/matmul.rs @@ -1,4 +1,4 @@ -use inkwell::{types::BasicTypeEnum, values::IntValue}; +use inkwell::values::IntValue; use crate::codegen::{ expr::infer_and_call_function, irrt::get_usize_dependent_function_name, @@ -22,26 +22,12 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized ) { let llvm_usize = ctx.get_size_type(); - assert_eq!( - BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - assert_eq!( - BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - assert_eq!( - BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - assert_eq!( - BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - assert_eq!( - BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); + assert_eq!(a_shape.element_type(ctx, generator), llvm_usize.into()); + assert_eq!(b_shape.element_type(ctx, generator), llvm_usize.into()); + assert_eq!(final_ndims.get_type(), llvm_usize); + assert_eq!(new_a_shape.element_type(ctx, generator), llvm_usize.into()); + assert_eq!(new_b_shape.element_type(ctx, generator), llvm_usize.into()); + assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes"); diff --git a/nac3core/src/codegen/irrt/range.rs b/nac3core/src/codegen/irrt/range.rs index 3b6bc31..d624929 100644 --- a/nac3core/src/codegen/irrt/range.rs +++ b/nac3core/src/codegen/irrt/range.rs @@ -1,10 +1,9 @@ use inkwell::{ - values::{BasicValueEnum, CallSiteValue, IntValue}, + values::{BasicValueEnum, IntValue}, IntPredicate, }; -use itertools::Either; -use crate::codegen::{CodeGenContext, CodeGenerator}; +use crate::codegen::{expr::infer_and_call_function, CodeGenContext, CodeGenerator}; /// Invokes the `__nac3_range_slice_len` in IRRT. /// @@ -23,16 +22,10 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( const SYMBOL: &str = "__nac3_range_slice_len"; let llvm_i32 = ctx.ctx.i32_type(); - assert_eq!(start.get_type(), llvm_i32); assert_eq!(end.get_type(), llvm_i32); assert_eq!(step.get_type(), llvm_i32); - let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false); - ctx.module.add_function(SYMBOL, fn_t, None) - }); - // assert step != 0, throw exception if not let not_zero = ctx .builder @@ -47,10 +40,14 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( ctx.current_loc, ); - ctx.builder - .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + SYMBOL, + Some(llvm_i32.into()), + &[start.into(), end.into(), step.into()], + Some("calc_len"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() } diff --git a/nac3core/src/codegen/irrt/slice.rs b/nac3core/src/codegen/irrt/slice.rs index 35e2151..cc1f28d 100644 --- a/nac3core/src/codegen/irrt/slice.rs +++ b/nac3core/src/codegen/irrt/slice.rs @@ -1,10 +1,9 @@ -use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue}; -use itertools::Either; +use inkwell::values::{BasicValueEnum, IntValue}; use nac3parser::ast::Expr; use crate::{ - codegen::{CodeGenContext, CodeGenerator}, + codegen::{expr::infer_and_call_function, CodeGenContext, CodeGenerator}, typecheck::typedef::Type, }; @@ -17,23 +16,26 @@ pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>( length: IntValue<'ctx>, ) -> Result>, String> { const SYMBOL: &str = "__nac3_slice_index_bound"; - let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let i32_t = ctx.ctx.i32_type(); - let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false); - ctx.module.add_function(SYMBOL, fn_t, None) - }); + + let llvm_i32 = ctx.ctx.i32_type(); + assert_eq!(length.get_type(), llvm_i32); let i = if let Some(v) = generator.gen_expr(ctx, i)? { v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? } else { return Ok(None); }; + Ok(Some( - ctx.builder - .build_call(func, &[i.into(), length.into()], "bounded_ind") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(), + infer_and_call_function( + ctx, + SYMBOL, + Some(llvm_i32.into()), + &[i, length.into()], + Some("bounded_ind"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap(), )) } diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs index e2fd8c0..e015570 100644 --- a/nac3core/src/codegen/irrt/string.rs +++ b/nac3core/src/codegen/irrt/string.rs @@ -1,8 +1,10 @@ -use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; -use itertools::Either; +use inkwell::{ + values::{BasicValueEnum, IntValue, PointerValue}, + AddressSpace, +}; use super::get_usize_dependent_function_name; -use crate::codegen::CodeGenContext; +use crate::codegen::{expr::infer_and_call_function, CodeGenContext}; /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. pub fn call_string_eq<'ctx>( @@ -13,33 +15,23 @@ pub fn call_string_eq<'ctx>( str2_len: IntValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let llvm_usize = ctx.get_size_type(); + assert_eq!(str1_ptr.get_type(), llvm_pi8); + assert_eq!(str1_len.get_type(), llvm_usize); + assert_eq!(str2_ptr.get_type(), llvm_pi8); + assert_eq!(str2_len.get_type(), llvm_usize); let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq"); - let func = ctx.module.get_function(&func_name).unwrap_or_else(|| { - ctx.module.add_function( - &func_name, - llvm_i1.fn_type( - &[ - str1_ptr.get_type().into(), - str1_len.get_type().into(), - str2_ptr.get_type().into(), - str2_len.get_type().into(), - ], - false, - ), - None, - ) - }); - - ctx.builder - .build_call( - func, - &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], - "str_eq_call", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + &func_name, + Some(llvm_i1.into()), + &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], + Some("str_eq_call"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() }