[core] codegen/irrt: Refactor IRRT to use more create/infer fns

This commit is contained in:
David Mak 2025-02-04 17:13:33 +08:00
parent 6bcdc3ce00
commit f52ba9f151
11 changed files with 270 additions and 420 deletions

View File

@ -15,7 +15,7 @@ use pyo3::{
use super::{symbol_resolver::InnerResolver, timeline::TimeFns}; use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
expr::{destructure_range, gen_call}, expr::{create_fn_and_call, destructure_range, gen_call, infer_and_call_function},
llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave}, llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
type_aligned_alloca, type_aligned_alloca,
@ -914,47 +914,14 @@ fn rpc_codegen_callback_fn<'ctx>(
} }
// call // call
if is_async { infer_and_call_function(
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| { ctx,
ctx.module.add_function( if is_async { "rpc_send_async" } else { "rpc_send" },
"rpc_send_async",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false,
),
None, None,
)
});
ctx.builder
.build_call(
rpc_send_async,
&[service_id.into(), tag_ptr.into(), args_ptr.into()], &[service_id.into(), tag_ptr.into(), args_ptr.into()],
"rpc.send", Some("rpc.send"),
)
.unwrap();
} else {
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false,
),
None, None,
) );
});
ctx.builder
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
.unwrap();
}
// reclaim stack space used by arguments // reclaim stack space used by arguments
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
@ -1168,29 +1135,22 @@ fn polymorphic_print<'ctx>(
debug_assert!(!fmt.is_empty()); debug_assert!(!fmt.is_empty());
debug_assert_eq!(fmt.as_bytes().last().unwrap(), &0u8); debug_assert_eq!(fmt.as_bytes().last().unwrap(), &0u8);
let fn_name = if as_rtio { "rtio_log" } else { "core_log" };
let print_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| {
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let fn_t = if as_rtio {
let llvm_void = ctx.ctx.void_type();
llvm_void.fn_type(&[llvm_pi8.into()], true)
} else {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
llvm_i32.fn_type(&[llvm_pi8.into()], true) let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
};
ctx.module.add_function(fn_name, fn_t, None)
});
let fmt = ctx.gen_string(generator, fmt); let fmt = ctx.gen_string(generator, fmt);
let fmt = unsafe { fmt.get_field_at_index_unchecked(0) }.into_pointer_value(); let fmt = unsafe { fmt.get_field_at_index_unchecked(0) }.into_pointer_value();
ctx.builder create_fn_and_call(
.build_call( ctx,
print_fn, if as_rtio { "rtio_log" } else { "core_log" },
if as_rtio { None } else { Some(llvm_i32.into()) },
&[llvm_pi8.into()],
&once(fmt.into()).chain(args).map(BasicValueEnum::into).collect_vec(), &once(fmt.into()).chain(args).map(BasicValueEnum::into).collect_vec(),
"", true,
) None,
.unwrap(); None,
);
}; };
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();

View File

@ -1,11 +1,6 @@
use itertools::Either;
use nac3core::{ use nac3core::{
codegen::CodeGenContext, codegen::{expr::infer_and_call_function, CodeGenContext},
inkwell::{ inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering},
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
},
}; };
/// Functions for manipulating the timeline. /// Functions for manipulating the timeline.
@ -288,36 +283,27 @@ pub struct ExternTimeFns {}
impl TimeFns for ExternTimeFns { impl TimeFns for ExternTimeFns {
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> { fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| { infer_and_call_function(
ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None) ctx,
}); "now_mu",
ctx.builder Some(ctx.ctx.i64_type().into()),
.build_call(now_mu, &[], "now_mu") &[],
.map(CallSiteValue::try_as_basic_value) Some("now_mu"),
.map(Either::unwrap_left) None,
)
.unwrap() .unwrap()
} }
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
let at_mu = ctx.module.get_function("at_mu").unwrap_or_else(|| { assert_eq!(t.get_type(), ctx.ctx.i64_type().into());
ctx.module.add_function(
"at_mu", infer_and_call_function(ctx, "at_mu", None, &[t], Some("at_mu"), None);
ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false),
None,
)
});
ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap();
} }
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) { fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| { assert_eq!(dt.get_type(), ctx.ctx.i64_type().into());
ctx.module.add_function(
"delay_mu", infer_and_call_function(ctx, "delay_mu", None, &[dt], Some("delay_mu"), None);
ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false),
None,
)
});
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu").unwrap();
} }
} }

View File

@ -1,10 +1,9 @@
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, values::{BasicValueEnum, FloatValue, IntValue},
}; };
use itertools::Either;
use super::CodeGenContext; use super::{expr::infer_and_call_function, CodeGenContext};
/// Macro to generate extern function /// Macro to generate extern function
/// Both function return type and function parameter type are `FloatValue` /// Both function return type and function parameter type are `FloatValue`
@ -46,23 +45,22 @@ macro_rules! generate_extern_fn {
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
$(debug_assert_eq!($args.get_type(), llvm_f64);)* $(debug_assert_eq!($args.get_type(), llvm_f64);)*
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { infer_and_call_function(
let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false); ctx,
let func = ctx.module.add_function(FN_NAME, fn_type, None); FN_NAME,
Some(llvm_f64.into()),
&[$($args.into()),*],
name,
Some(&|func| {
for attr in [$($attributes),*] { for attr in [$($attributes),*] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
func })
}); )
.map(BasicValueEnum::into_float_value)
ctx.builder
.build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap() .unwrap()
} }
}; };
@ -112,24 +110,22 @@ pub fn call_ldexp<'ctx>(
debug_assert_eq!(arg.get_type(), llvm_f64); debug_assert_eq!(arg.get_type(), llvm_f64);
debug_assert_eq!(exp.get_type(), llvm_i32); debug_assert_eq!(exp.get_type(), llvm_i32);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { infer_and_call_function(
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false); ctx,
let func = ctx.module.add_function(FN_NAME, fn_type, None); FN_NAME,
Some(llvm_f64.into()),
&[arg.into(), exp.into()],
name,
Some(&|func| {
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
}),
func )
}); .map(BasicValueEnum::into_float_value)
ctx.builder
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap() .unwrap()
} }
@ -163,20 +159,22 @@ macro_rules! generate_linalg_extern_fn {
name: Option<&str>, name: Option<&str>,
){ ){
const FN_NAME: &str = $extern_fn; const FN_NAME: &str = $extern_fn;
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None); infer_and_call_function(
ctx,
FN_NAME,
None,
&[$($input_matrix.into(),)*],
name,
Some(&|func| {
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
func }),
}); );
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
} }
}; };
} }

View File

@ -1,13 +1,14 @@
use inkwell::{ use inkwell::{
types::BasicTypeEnum, types::BasicTypeEnum,
values::{BasicValueEnum, CallSiteValue, IntValue}, values::{BasicValueEnum, IntValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use itertools::Either;
use super::calculate_len_for_slice_range; use super::calculate_len_for_slice_range;
use crate::codegen::{ use crate::codegen::{
expr::infer_and_call_function,
macros::codegen_unreachable, macros::codegen_unreachable,
stmt::gen_if_callback,
values::{ArrayLikeValue, ListValue}, values::{ArrayLikeValue, ListValue},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -36,25 +37,6 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
assert_eq!(src_idx.2.get_type(), llvm_i32); assert_eq!(src_idx.2.get_type(), llvm_i32);
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8); let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8);
let slice_assign_fun = {
let ty_vec = vec![
llvm_i32.into(), // dest start idx
llvm_i32.into(), // dest end idx
llvm_i32.into(), // dest step
elem_ptr_type.into(), // dest arr ptr
llvm_i32.into(), // dest arr len
llvm_i32.into(), // src start idx
llvm_i32.into(), // src end idx
llvm_i32.into(), // src step
elem_ptr_type.into(), // src arr ptr
llvm_i32.into(), // src arr len
llvm_i32.into(), // size
];
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
let fn_t = llvm_i32.fn_type(ty_vec.as_slice(), false);
ctx.module.add_function(fun_symbol, fn_t, None)
})
};
let zero = llvm_i32.const_zero(); let zero = llvm_i32.const_zero();
let one = llvm_i32.const_int(1, false); let one = llvm_i32.const_int(1, false);
@ -127,7 +109,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
); );
let new_len = { let new_len = {
let args = vec![ let args = [
dest_idx.0.into(), // dest start idx dest_idx.0.into(), // dest start idx
dest_idx.1.into(), // dest end idx dest_idx.1.into(), // dest end idx
dest_idx.2.into(), // dest step dest_idx.2.into(), // dest step
@ -150,25 +132,35 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
} }
.into(), .into(),
]; ];
ctx.builder infer_and_call_function(
.build_call(slice_assign_fun, args.as_slice(), "slice_assign") ctx,
.map(CallSiteValue::try_as_basic_value) fun_symbol,
.map(|v| v.map_left(BasicValueEnum::into_int_value)) Some(llvm_i32.into()),
.map(Either::unwrap_left) &args,
Some("slice_assign"),
None,
)
.map(BasicValueEnum::into_int_value)
.unwrap() .unwrap()
}; };
// update length // update length
let need_update = gen_if_callback(
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); generator,
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); ctx,
let update_bb = ctx.ctx.append_basic_block(current, "update"); |_, ctx| {
let cont_bb = ctx.ctx.append_basic_block(current, "cont"); Ok(ctx
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); .builder
ctx.builder.position_at_end(update_bb); .build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update")
.unwrap())
},
|_, ctx| {
let new_len = let new_len =
ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap();
dest_arr.store_size(ctx, new_len); dest_arr.store_size(ctx, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap(); Ok(())
ctx.builder.position_at_end(cont_bb); },
|_, _| Ok(()),
)
.unwrap();
} }

View File

@ -1,10 +1,10 @@
use inkwell::{ use inkwell::{
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, values::{BasicValueEnum, FloatValue, IntValue},
IntPredicate, IntPredicate,
}; };
use itertools::Either;
use crate::codegen::{ use crate::codegen::{
expr::infer_and_call_function,
macros::codegen_unreachable, macros::codegen_unreachable,
{CodeGenContext, CodeGenerator}, {CodeGenContext, CodeGenerator},
}; };
@ -18,18 +18,16 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
exp: IntValue<'ctx>, exp: IntValue<'ctx>,
signed: bool, signed: bool,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) { let base_type = base.get_type();
let symbol = match (base_type.get_bit_width(), exp.get_type().get_bit_width(), signed) {
(32, 32, true) => "__nac3_int_exp_int32_t", (32, 32, true) => "__nac3_int_exp_int32_t",
(64, 64, true) => "__nac3_int_exp_int64_t", (64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t", (32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t", (64, 64, false) => "__nac3_int_exp_uint64_t",
_ => codegen_unreachable!(ctx), _ => codegen_unreachable!(ctx),
}; };
let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
ctx.module.add_function(symbol, fn_type, None)
});
// throw exception when exp < 0 // throw exception when exp < 0
let ge_zero = ctx let ge_zero = ctx
.builder .builder
@ -48,11 +46,16 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
[None, None, None], [None, None, None],
ctx.current_loc, ctx.current_loc,
); );
ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") infer_and_call_function(
.map(CallSiteValue::try_as_basic_value) ctx,
.map(|v| v.map_left(BasicValueEnum::into_int_value)) symbol,
.map(Either::unwrap_left) Some(base_type.into()),
&[base.into(), exp.into()],
Some("call_int_pow"),
None,
)
.map(BasicValueEnum::into_int_value)
.unwrap() .unwrap()
} }
@ -67,20 +70,17 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
assert_eq!(v.get_type(), llvm_f64); assert_eq!(v.get_type(), llvm_f64);
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { infer_and_call_function(
let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); ctx,
ctx.module.add_function("__nac3_isinf", fn_type, None) "__nac3_isinf",
}); Some(llvm_i32.into()),
&[v.into()],
let ret = ctx Some("isinf"),
.builder None,
.build_call(intrinsic_fn, &[v.into()], "isinf") )
.map(CallSiteValue::try_as_basic_value) .map(BasicValueEnum::into_int_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(|ret| generator.bool_to_i1(ctx, ret))
.map(Either::unwrap_left) .unwrap()
.unwrap();
generator.bool_to_i1(ctx, ret)
} }
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result. /// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
@ -94,20 +94,17 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
assert_eq!(v.get_type(), llvm_f64); assert_eq!(v.get_type(), llvm_f64);
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { infer_and_call_function(
let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); ctx,
ctx.module.add_function("__nac3_isnan", fn_type, None) "__nac3_isnan",
}); Some(llvm_i32.into()),
&[v.into()],
let ret = ctx Some("isnan"),
.builder None,
.build_call(intrinsic_fn, &[v.into()], "isnan") )
.map(CallSiteValue::try_as_basic_value) .map(BasicValueEnum::into_int_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(|ret| generator.bool_to_i1(ctx, ret))
.map(Either::unwrap_left) .unwrap()
.unwrap();
generator.bool_to_i1(ctx, ret)
} }
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result. /// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
@ -116,16 +113,15 @@ pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) ->
assert_eq!(v.get_type(), llvm_f64); assert_eq!(v.get_type(), llvm_f64);
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { infer_and_call_function(
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx,
ctx.module.add_function("__nac3_gamma", fn_type, None) "__nac3_gamma",
}); Some(llvm_f64.into()),
&[v.into()],
ctx.builder Some("gamma"),
.build_call(intrinsic_fn, &[v.into()], "gamma") None,
.map(CallSiteValue::try_as_basic_value) )
.map(|v| v.map_left(BasicValueEnum::into_float_value)) .map(BasicValueEnum::into_float_value)
.map(Either::unwrap_left)
.unwrap() .unwrap()
} }
@ -135,16 +131,15 @@ pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -
assert_eq!(v.get_type(), llvm_f64); assert_eq!(v.get_type(), llvm_f64);
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { infer_and_call_function(
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx,
ctx.module.add_function("__nac3_gammaln", fn_type, None) "__nac3_gammaln",
}); Some(llvm_f64.into()),
&[v.into()],
ctx.builder Some("gammaln"),
.build_call(intrinsic_fn, &[v.into()], "gammaln") None,
.map(CallSiteValue::try_as_basic_value) )
.map(|v| v.map_left(BasicValueEnum::into_float_value)) .map(BasicValueEnum::into_float_value)
.map(Either::unwrap_left)
.unwrap() .unwrap()
} }
@ -154,15 +149,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo
assert_eq!(v.get_type(), llvm_f64); assert_eq!(v.get_type(), llvm_f64);
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { infer_and_call_function(ctx, "__nac3_j0", Some(llvm_f64.into()), &[v.into()], Some("j0"), None)
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); .map(BasicValueEnum::into_float_value)
ctx.module.add_function("__nac3_j0", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap() .unwrap()
} }

View File

@ -1,13 +1,11 @@
use inkwell::{ use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, AddressSpace,
}; };
use crate::codegen::{ use crate::codegen::{
expr::{create_and_call_function, infer_and_call_function}, expr::infer_and_call_function,
irrt::get_usize_dependent_function_name, irrt::get_usize_dependent_function_name,
types::ProxyType,
values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor}, values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -21,24 +19,17 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator +
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = ctx.get_size_type();
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!( assert_eq!(shape.element_type(ctx, generator), llvm_usize.into());
BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let name = let name =
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative"); get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative");
create_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
Some(llvm_usize.into()), Some(llvm_usize.into()),
&[ &[shape.size(ctx, generator).into(), shape.base_ptr(ctx, generator).into()],
(llvm_usize.into(), shape.size(ctx, generator).into()),
(llvm_pusize.into(), shape.base_ptr(ctx, generator).into()),
],
None, None,
None, None,
); );
@ -55,29 +46,22 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = ctx.get_size_type();
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!( assert_eq!(ndarray_shape.element_type(ctx, generator), llvm_usize.into());
BasicTypeEnum::try_from(ndarray_shape.element_type(ctx, generator)).unwrap(), assert_eq!(output_shape.element_type(ctx, generator), llvm_usize.into());
llvm_usize.into()
);
assert_eq!(
BasicTypeEnum::try_from(output_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let name = let name =
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same"); get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same");
create_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
Some(llvm_usize.into()), Some(llvm_usize.into()),
&[ &[
(llvm_usize.into(), ndarray_shape.size(ctx, generator).into()), ndarray_shape.size(ctx, generator).into(),
(llvm_pusize.into(), ndarray_shape.base_ptr(ctx, generator).into()), ndarray_shape.base_ptr(ctx, generator).into(),
(llvm_usize.into(), output_shape.size(ctx, generator).into()), output_shape.size(ctx, generator).into(),
(llvm_pusize.into(), output_shape.base_ptr(ctx, generator).into()), output_shape.base_ptr(ctx, generator).into(),
], ],
None, None,
None, None,
@ -93,15 +77,14 @@ pub fn call_nac3_ndarray_size<'ctx>(
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_usize = ctx.get_size_type(); let llvm_usize = ctx.get_size_type();
let llvm_ndarray = ndarray.get_type();
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size");
create_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
Some(llvm_usize.into()), Some(llvm_usize.into()),
&[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], &[ndarray.as_abi_value(ctx).into()],
Some("size"), Some("size"),
None, None,
) )
@ -118,15 +101,14 @@ pub fn call_nac3_ndarray_nbytes<'ctx>(
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_usize = ctx.get_size_type(); let llvm_usize = ctx.get_size_type();
let llvm_ndarray = ndarray.get_type();
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes");
create_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
Some(llvm_usize.into()), Some(llvm_usize.into()),
&[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], &[ndarray.as_abi_value(ctx).into()],
Some("nbytes"), Some("nbytes"),
None, None,
) )
@ -143,15 +125,14 @@ pub fn call_nac3_ndarray_len<'ctx>(
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_usize = ctx.get_size_type(); let llvm_usize = ctx.get_size_type();
let llvm_ndarray = ndarray.get_type();
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len");
create_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
Some(llvm_usize.into()), Some(llvm_usize.into()),
&[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], &[ndarray.as_abi_value(ctx).into()],
Some("len"), Some("len"),
None, None,
) )
@ -167,15 +148,14 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx>(
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_ndarray = ndarray.get_type();
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous");
create_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
Some(llvm_i1.into()), Some(llvm_i1.into()),
&[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], &[ndarray.as_abi_value(ctx).into()],
Some("is_c_contiguous"), Some("is_c_contiguous"),
None, None,
) )
@ -194,20 +174,16 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx>(
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = ctx.get_size_type(); let llvm_usize = ctx.get_size_type();
let llvm_ndarray = ndarray.get_type();
assert_eq!(index.get_type(), llvm_usize); assert_eq!(index.get_type(), llvm_usize);
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement");
create_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
Some(llvm_pi8.into()), Some(llvm_pi8.into()),
&[ &[ndarray.as_abi_value(ctx).into(), index.into()],
(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()),
(llvm_usize.into(), index.into()),
],
Some("pelement"), Some("pelement"),
None, None,
) )
@ -229,24 +205,16 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = ctx.get_size_type(); let llvm_usize = ctx.get_size_type();
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let llvm_ndarray = ndarray.get_type();
assert_eq!( assert_eq!(indices.element_type(ctx, generator), llvm_usize.into());
BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices");
create_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
Some(llvm_pi8.into()), Some(llvm_pi8.into()),
&[ &[ndarray.as_abi_value(ctx).into(), indices.base_ptr(ctx, generator).into()],
(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()),
(llvm_pusize.into(), indices.base_ptr(ctx, generator).into()),
],
Some("pelement"), Some("pelement"),
None, None,
) )
@ -261,18 +229,9 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) { ) {
let llvm_ndarray = ndarray.get_type();
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape");
create_and_call_function( infer_and_call_function(ctx, &name, None, &[ndarray.as_abi_value(ctx).into()], None, None);
ctx,
&name,
None,
&[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())],
None,
None,
);
} }
/// Generates a call to `__nac3_ndarray_copy_data`. /// Generates a call to `__nac3_ndarray_copy_data`.

View File

@ -1,13 +1,8 @@
use inkwell::{ use inkwell::values::{BasicValueEnum, IntValue};
types::BasicTypeEnum,
values::{BasicValueEnum, IntValue},
AddressSpace,
};
use crate::codegen::{ use crate::codegen::{
expr::{create_and_call_function, infer_and_call_function}, expr::infer_and_call_function,
irrt::get_usize_dependent_function_name, irrt::get_usize_dependent_function_name,
types::ProxyType,
values::{ values::{
ndarray::{NDArrayValue, NDIterValue}, ndarray::{NDArrayValue, NDIterValue},
ProxyValue, TypedArrayLikeAccessor, ProxyValue, TypedArrayLikeAccessor,
@ -26,23 +21,19 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = ctx.get_size_type();
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!( assert_eq!(indices.element_type(ctx, generator), llvm_usize.into());
BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize"); let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize");
create_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
None, None,
&[ &[
(iter.get_type().as_abi_type().into(), iter.as_abi_value(ctx).into()), iter.as_abi_value(ctx).into(),
(ndarray.get_type().as_abi_type().into(), ndarray.as_abi_value(ctx).into()), ndarray.as_abi_value(ctx).into(),
(llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), indices.base_ptr(ctx, generator).into(),
], ],
None, None,
None, None,

View File

@ -1,4 +1,4 @@
use inkwell::{types::BasicTypeEnum, values::IntValue}; use inkwell::values::IntValue;
use crate::codegen::{ use crate::codegen::{
expr::infer_and_call_function, irrt::get_usize_dependent_function_name, expr::infer_and_call_function, irrt::get_usize_dependent_function_name,
@ -22,26 +22,12 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = ctx.get_size_type();
assert_eq!( assert_eq!(a_shape.element_type(ctx, generator), llvm_usize.into());
BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), assert_eq!(b_shape.element_type(ctx, generator), llvm_usize.into());
llvm_usize.into() assert_eq!(final_ndims.get_type(), llvm_usize);
); assert_eq!(new_a_shape.element_type(ctx, generator), llvm_usize.into());
assert_eq!( assert_eq!(new_b_shape.element_type(ctx, generator), llvm_usize.into());
BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(), assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into());
llvm_usize.into()
);
assert_eq!(
BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
assert_eq!(
BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
assert_eq!(
BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes"); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes");

View File

@ -1,10 +1,9 @@
use inkwell::{ use inkwell::{
values::{BasicValueEnum, CallSiteValue, IntValue}, values::{BasicValueEnum, IntValue},
IntPredicate, IntPredicate,
}; };
use itertools::Either;
use crate::codegen::{CodeGenContext, CodeGenerator}; use crate::codegen::{expr::infer_and_call_function, CodeGenContext, CodeGenerator};
/// Invokes the `__nac3_range_slice_len` in IRRT. /// Invokes the `__nac3_range_slice_len` in IRRT.
/// ///
@ -23,16 +22,10 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
const SYMBOL: &str = "__nac3_range_slice_len"; const SYMBOL: &str = "__nac3_range_slice_len";
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
assert_eq!(start.get_type(), llvm_i32); assert_eq!(start.get_type(), llvm_i32);
assert_eq!(end.get_type(), llvm_i32); assert_eq!(end.get_type(), llvm_i32);
assert_eq!(step.get_type(), llvm_i32); assert_eq!(step.get_type(), llvm_i32);
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
// assert step != 0, throw exception if not // assert step != 0, throw exception if not
let not_zero = ctx let not_zero = ctx
.builder .builder
@ -47,10 +40,14 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
ctx.current_loc, ctx.current_loc,
); );
ctx.builder infer_and_call_function(
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") ctx,
.map(CallSiteValue::try_as_basic_value) SYMBOL,
.map(|v| v.map_left(BasicValueEnum::into_int_value)) Some(llvm_i32.into()),
.map(Either::unwrap_left) &[start.into(), end.into(), step.into()],
Some("calc_len"),
None,
)
.map(BasicValueEnum::into_int_value)
.unwrap() .unwrap()
} }

View File

@ -1,10 +1,9 @@
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue}; use inkwell::values::{BasicValueEnum, IntValue};
use itertools::Either;
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
use crate::{ use crate::{
codegen::{CodeGenContext, CodeGenerator}, codegen::{expr::infer_and_call_function, CodeGenContext, CodeGenerator},
typecheck::typedef::Type, typecheck::typedef::Type,
}; };
@ -17,23 +16,26 @@ pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
length: IntValue<'ctx>, length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> { ) -> Result<Option<IntValue<'ctx>>, String> {
const SYMBOL: &str = "__nac3_slice_index_bound"; const SYMBOL: &str = "__nac3_slice_index_bound";
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false); assert_eq!(length.get_type(), llvm_i32);
ctx.module.add_function(SYMBOL, fn_t, None)
});
let i = if let Some(v) = generator.gen_expr(ctx, i)? { let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else { } else {
return Ok(None); return Ok(None);
}; };
Ok(Some( Ok(Some(
ctx.builder infer_and_call_function(
.build_call(func, &[i.into(), length.into()], "bounded_ind") ctx,
.map(CallSiteValue::try_as_basic_value) SYMBOL,
.map(|v| v.map_left(BasicValueEnum::into_int_value)) Some(llvm_i32.into()),
.map(Either::unwrap_left) &[i, length.into()],
Some("bounded_ind"),
None,
)
.map(BasicValueEnum::into_int_value)
.unwrap(), .unwrap(),
)) ))
} }

View File

@ -1,8 +1,10 @@
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; use inkwell::{
use itertools::Either; values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace,
};
use super::get_usize_dependent_function_name; use super::get_usize_dependent_function_name;
use crate::codegen::CodeGenContext; use crate::codegen::{expr::infer_and_call_function, CodeGenContext};
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
pub fn call_string_eq<'ctx>( pub fn call_string_eq<'ctx>(
@ -13,33 +15,23 @@ pub fn call_string_eq<'ctx>(
str2_len: IntValue<'ctx>, str2_len: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let llvm_usize = ctx.get_size_type();
assert_eq!(str1_ptr.get_type(), llvm_pi8);
assert_eq!(str1_len.get_type(), llvm_usize);
assert_eq!(str2_ptr.get_type(), llvm_pi8);
assert_eq!(str2_len.get_type(), llvm_usize);
let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq"); let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq");
let func = ctx.module.get_function(&func_name).unwrap_or_else(|| { infer_and_call_function(
ctx.module.add_function( ctx,
&func_name, &func_name,
llvm_i1.fn_type( Some(llvm_i1.into()),
&[ &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()],
str1_ptr.get_type().into(), Some("str_eq_call"),
str1_len.get_type().into(),
str2_ptr.get_type().into(),
str2_len.get_type().into(),
],
false,
),
None, None,
) )
}); .map(BasicValueEnum::into_int_value)
ctx.builder
.build_call(
func,
&[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()],
"str_eq_call",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap() .unwrap()
} }