forked from M-Labs/nac3
1
0
Fork 0

core: Add gen_for_callback_incrementing

Simplifies generation of monotonically increasing for loops.
This commit is contained in:
David Mak 2024-03-08 13:13:18 +08:00
parent 50264e8750
commit cfbc37c1ed
3 changed files with 105 additions and 128 deletions

View File

@ -8,7 +8,7 @@ use crate::codegen::{
CodeGenerator,
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback,
stmt::gen_for_callback_incrementing,
};
#[cfg(not(debug_assertions))]
@ -940,30 +940,15 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
ctx.current_loc,
);
gen_for_callback(
let indices_len = indices.load_size(ctx, None);
let ndarray_len = self.0.load_ndims(ctx);
let len = call_int_umin(ctx, indices_len, ndarray_len, None);
gen_for_callback_incrementing(
generator,
ctx,
|generator, ctx| {
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
Ok(i)
},
|_, ctx, i_addr| {
let indices_len = indices.load_size(ctx, None);
let ndarray_len = self.0.load_ndims(ctx);
let len = call_int_umin(ctx, indices_len, ndarray_len, None);
let i = ctx.builder.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
Ok(ctx.builder.build_int_compare(IntPredicate::SLT, i, len, "").unwrap())
},
|generator, ctx, i_addr| {
let i = ctx.builder.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
llvm_usize.const_zero(),
(len, false),
|generator, ctx, i| {
let (dim_idx, dim_sz) = unsafe {
(
indices.data().get_unchecked(ctx, i, None).into_int_value(),
@ -989,16 +974,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
Ok(())
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
).unwrap();
unsafe {

View File

@ -14,7 +14,7 @@ use crate::{
call_ndarray_calc_size,
},
llvm_intrinsics::call_memcpy_generic,
stmt::gen_for_callback
stmt::gen_for_callback_incrementing,
},
symbol_resolver::ValueEnum,
toplevel::{
@ -52,30 +52,13 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
assert!(llvm_ndarray_data_t.is_sized());
// Assert that all dimensions are non-negative
gen_for_callback(
let shape_len = shape_len_fn(generator, ctx, shape)?;
gen_for_callback_incrementing(
generator,
ctx,
|generator, ctx| {
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
Ok(i)
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let shape_len = shape_len_fn(generator, ctx, shape)?;
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
llvm_usize.const_zero(),
(shape_len, false),
|generator, ctx, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
@ -94,16 +77,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
Ok(())
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let ndarray = generator.gen_var_alloc(
@ -120,30 +94,13 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
// Copy the dimension sizes from shape to ndarray.dims
gen_for_callback(
let shape_len = shape_len_fn(generator, ctx, shape)?;
gen_for_callback_incrementing(
generator,
ctx,
|generator, ctx| {
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
Ok(i)
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let shape_len = shape_len_fn(generator, ctx, shape)?;
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
llvm_usize.const_zero(),
(shape_len, false),
|generator, ctx, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
let shape_dim = ctx.builder
@ -158,16 +115,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
Ok(())
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let ndarray_num_elems = call_ndarray_calc_size(
@ -342,28 +290,12 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
ndarray.dim_sizes().as_ptr_value(ctx),
);
gen_for_callback(
gen_for_callback_incrementing(
generator,
ctx,
|generator, ctx| {
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
Ok(i)
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, "").unwrap())
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
llvm_usize.const_zero(),
(ndarray_num_elems, false),
|generator, ctx, i| {
let elem = unsafe {
ndarray.data().ptr_to_data_flattened_unchecked(ctx, i, None)
};
@ -373,16 +305,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
Ok(())
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)
}

View File

@ -536,6 +536,84 @@ pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
Ok(())
}
/// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the
/// following C code:
///
/// ```c
/// for (int x = init_val; x /* < or <= ; see `max_val` */ max_val; x += incr_val) {
/// body(x);
/// }
/// ```
///
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
/// as the type of the loop variable.
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
/// value should be treated as inclusive (as opposed to exclusive).
/// * `body` - A lambda containing IR statements within the loop body.
/// * `incr_val` - The value to increment the loop variable on each iteration.
pub fn gen_for_callback_incrementing<'ctx, 'a, BodyFn>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
init_val: IntValue<'ctx>,
max_val: (IntValue<'ctx>, bool),
body: BodyFn,
incr_val: IntValue<'ctx>,
) -> Result<(), String>
where
BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
{
let init_val_t = init_val.get_type();
gen_for_callback(
generator,
ctx,
|generator, ctx| {
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
ctx.builder.build_store(i_addr, init_val).unwrap();
Ok(i_addr)
},
|_, ctx, i_addr| {
let cmp_op = if max_val.1 {
IntPredicate::ULE
} else {
IntPredicate::ULT
};
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let max_val = ctx.builder
.build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "")
.unwrap();
Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap())
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
body(generator, ctx, i)
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let incr_val = ctx.builder
.build_int_z_extend_or_bit_cast(incr_val, init_val_t, "")
.unwrap();
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
)
}
/// See [`CodeGenerator::gen_while`].
pub fn gen_while<G: CodeGenerator>(
generator: &mut G,