forked from M-Labs/nac3
core: Add gen_for_callback_incrementing
Simplifies generation of monotonically increasing for loops.
This commit is contained in:
parent
50264e8750
commit
cfbc37c1ed
nac3core/src/codegen
@ -8,7 +8,7 @@ use crate::codegen::{
|
||||
CodeGenerator,
|
||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
|
||||
llvm_intrinsics::call_int_umin,
|
||||
stmt::gen_for_callback,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
};
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
@ -940,30 +940,15 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
||||
|
||||
Ok(i)
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let indices_len = indices.load_size(ctx, None);
|
||||
let ndarray_len = self.0.load_ndims(ctx);
|
||||
|
||||
let len = call_int_umin(ctx, indices_len, ndarray_len, None);
|
||||
|
||||
let i = ctx.builder.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::SLT, i, len, "").unwrap())
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
llvm_usize.const_zero(),
|
||||
(len, false),
|
||||
|generator, ctx, i| {
|
||||
let (dim_idx, dim_sz) = unsafe {
|
||||
(
|
||||
indices.data().get_unchecked(ctx, i, None).into_int_value(),
|
||||
@ -989,16 +974,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
||||
ctx.builder.build_store(i_addr, i).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
).unwrap();
|
||||
|
||||
unsafe {
|
||||
|
@ -14,7 +14,7 @@ use crate::{
|
||||
call_ndarray_calc_size,
|
||||
},
|
||||
llvm_intrinsics::call_memcpy_generic,
|
||||
stmt::gen_for_callback
|
||||
stmt::gen_for_callback_incrementing,
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{
|
||||
@ -52,30 +52,13 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
|
||||
assert!(llvm_ndarray_data_t.is_sized());
|
||||
|
||||
// Assert that all dimensions are non-negative
|
||||
gen_for_callback(
|
||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
||||
|
||||
Ok(i)
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
llvm_usize.const_zero(),
|
||||
(shape_len, false),
|
||||
|generator, ctx, i| {
|
||||
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||
|
||||
@ -94,16 +77,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
||||
ctx.builder.build_store(i_addr, i).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let ndarray = generator.gen_var_alloc(
|
||||
@ -120,30 +94,13 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
|
||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
||||
|
||||
// Copy the dimension sizes from shape to ndarray.dims
|
||||
gen_for_callback(
|
||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
||||
|
||||
Ok(i)
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
llvm_usize.const_zero(),
|
||||
(shape_len, false),
|
||||
|generator, ctx, i| {
|
||||
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||
let shape_dim = ctx.builder
|
||||
@ -158,16 +115,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
||||
ctx.builder.build_store(i_addr, i).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
@ -342,28 +290,12 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
||||
ndarray.dim_sizes().as_ptr_value(ctx),
|
||||
);
|
||||
|
||||
gen_for_callback(
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
||||
|
||||
Ok(i)
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, "").unwrap())
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
llvm_usize.const_zero(),
|
||||
(ndarray_num_elems, false),
|
||||
|generator, ctx, i| {
|
||||
let elem = unsafe {
|
||||
ndarray.data().ptr_to_data_flattened_unchecked(ctx, i, None)
|
||||
};
|
||||
@ -373,16 +305,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
||||
ctx.builder.build_store(i_addr, i).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -536,6 +536,84 @@ pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the
|
||||
/// following C code:
|
||||
///
|
||||
/// ```c
|
||||
/// for (int x = init_val; x /* < or <= ; see `max_val` */ max_val; x += incr_val) {
|
||||
/// body(x);
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
|
||||
/// as the type of the loop variable.
|
||||
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
|
||||
/// value should be treated as inclusive (as opposed to exclusive).
|
||||
/// * `body` - A lambda containing IR statements within the loop body.
|
||||
/// * `incr_val` - The value to increment the loop variable on each iteration.
|
||||
pub fn gen_for_callback_incrementing<'ctx, 'a, BodyFn>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
init_val: IntValue<'ctx>,
|
||||
max_val: (IntValue<'ctx>, bool),
|
||||
body: BodyFn,
|
||||
incr_val: IntValue<'ctx>,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
|
||||
{
|
||||
let init_val_t = init_val.get_type();
|
||||
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
|
||||
ctx.builder.build_store(i_addr, init_val).unwrap();
|
||||
|
||||
Ok(i_addr)
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let cmp_op = if max_val.1 {
|
||||
IntPredicate::ULE
|
||||
} else {
|
||||
IntPredicate::ULT
|
||||
};
|
||||
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let max_val = ctx.builder
|
||||
.build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "")
|
||||
.unwrap();
|
||||
|
||||
Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap())
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
body(generator, ctx, i)
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let incr_val = ctx.builder
|
||||
.build_int_z_extend_or_bit_cast(incr_val, init_val_t, "")
|
||||
.unwrap();
|
||||
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
|
||||
ctx.builder.build_store(i_addr, i).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// See [`CodeGenerator::gen_while`].
|
||||
pub fn gen_while<G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
|
Loading…
Reference in New Issue
Block a user