forked from M-Labs/nac3
core: Add gen_for_callback_incrementing
Simplifies generation of monotonically increasing for loops.
This commit is contained in:
parent
50264e8750
commit
cfbc37c1ed
@ -8,7 +8,7 @@ use crate::codegen::{
|
|||||||
CodeGenerator,
|
CodeGenerator,
|
||||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
|
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
|
||||||
llvm_intrinsics::call_int_umin,
|
llvm_intrinsics::call_int_umin,
|
||||||
stmt::gen_for_callback,
|
stmt::gen_for_callback_incrementing,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(not(debug_assertions))]
|
#[cfg(not(debug_assertions))]
|
||||||
@ -940,30 +940,15 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
|
|
||||||
gen_for_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|generator, ctx| {
|
|
||||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
|
||||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
|
||||||
|
|
||||||
Ok(i)
|
|
||||||
},
|
|
||||||
|_, ctx, i_addr| {
|
|
||||||
let indices_len = indices.load_size(ctx, None);
|
let indices_len = indices.load_size(ctx, None);
|
||||||
let ndarray_len = self.0.load_ndims(ctx);
|
let ndarray_len = self.0.load_ndims(ctx);
|
||||||
|
|
||||||
let len = call_int_umin(ctx, indices_len, ndarray_len, None);
|
let len = call_int_umin(ctx, indices_len, ndarray_len, None);
|
||||||
|
gen_for_callback_incrementing(
|
||||||
let i = ctx.builder.build_load(i_addr, "")
|
generator,
|
||||||
.map(BasicValueEnum::into_int_value)
|
ctx,
|
||||||
.unwrap();
|
llvm_usize.const_zero(),
|
||||||
Ok(ctx.builder.build_int_compare(IntPredicate::SLT, i, len, "").unwrap())
|
(len, false),
|
||||||
},
|
|generator, ctx, i| {
|
||||||
|generator, ctx, i_addr| {
|
|
||||||
let i = ctx.builder.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let (dim_idx, dim_sz) = unsafe {
|
let (dim_idx, dim_sz) = unsafe {
|
||||||
(
|
(
|
||||||
indices.data().get_unchecked(ctx, i, None).into_int_value(),
|
indices.data().get_unchecked(ctx, i, None).into_int_value(),
|
||||||
@ -989,16 +974,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
|_, ctx, i_addr| {
|
llvm_usize.const_int(1, false),
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
|
||||||
ctx.builder.build_store(i_addr, i).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
).unwrap();
|
).unwrap();
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
|
@ -14,7 +14,7 @@ use crate::{
|
|||||||
call_ndarray_calc_size,
|
call_ndarray_calc_size,
|
||||||
},
|
},
|
||||||
llvm_intrinsics::call_memcpy_generic,
|
llvm_intrinsics::call_memcpy_generic,
|
||||||
stmt::gen_for_callback
|
stmt::gen_for_callback_incrementing,
|
||||||
},
|
},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
@ -52,30 +52,13 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
|
|||||||
assert!(llvm_ndarray_data_t.is_sized());
|
assert!(llvm_ndarray_data_t.is_sized());
|
||||||
|
|
||||||
// Assert that all dimensions are non-negative
|
// Assert that all dimensions are non-negative
|
||||||
gen_for_callback(
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||||
|
gen_for_callback_incrementing(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|generator, ctx| {
|
llvm_usize.const_zero(),
|
||||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
(shape_len, false),
|
||||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
|generator, ctx, i| {
|
||||||
|
|
||||||
Ok(i)
|
|
||||||
},
|
|
||||||
|generator, ctx, i_addr| {
|
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
|
||||||
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
|
|
||||||
},
|
|
||||||
|generator, ctx, i_addr| {
|
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||||
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||||
|
|
||||||
@ -94,16 +77,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
|_, ctx, i_addr| {
|
llvm_usize.const_int(1, false),
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
|
||||||
ctx.builder.build_store(i_addr, i).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let ndarray = generator.gen_var_alloc(
|
let ndarray = generator.gen_var_alloc(
|
||||||
@ -120,30 +94,13 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
|
|||||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
|
||||||
// Copy the dimension sizes from shape to ndarray.dims
|
// Copy the dimension sizes from shape to ndarray.dims
|
||||||
gen_for_callback(
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||||
|
gen_for_callback_incrementing(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|generator, ctx| {
|
llvm_usize.const_zero(),
|
||||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
(shape_len, false),
|
||||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
|generator, ctx, i| {
|
||||||
|
|
||||||
Ok(i)
|
|
||||||
},
|
|
||||||
|generator, ctx, i_addr| {
|
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
|
||||||
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
|
|
||||||
},
|
|
||||||
|generator, ctx, i_addr| {
|
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||||
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||||
let shape_dim = ctx.builder
|
let shape_dim = ctx.builder
|
||||||
@ -158,16 +115,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
|_, ctx, i_addr| {
|
llvm_usize.const_int(1, false),
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
|
||||||
ctx.builder.build_store(i_addr, i).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
@ -342,28 +290,12 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
|||||||
ndarray.dim_sizes().as_ptr_value(ctx),
|
ndarray.dim_sizes().as_ptr_value(ctx),
|
||||||
);
|
);
|
||||||
|
|
||||||
gen_for_callback(
|
gen_for_callback_incrementing(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|generator, ctx| {
|
llvm_usize.const_zero(),
|
||||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
(ndarray_num_elems, false),
|
||||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
|generator, ctx, i| {
|
||||||
|
|
||||||
Ok(i)
|
|
||||||
},
|
|
||||||
|_, ctx, i_addr| {
|
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, "").unwrap())
|
|
||||||
},
|
|
||||||
|generator, ctx, i_addr| {
|
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let elem = unsafe {
|
let elem = unsafe {
|
||||||
ndarray.data().ptr_to_data_flattened_unchecked(ctx, i, None)
|
ndarray.data().ptr_to_data_flattened_unchecked(ctx, i, None)
|
||||||
};
|
};
|
||||||
@ -373,16 +305,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
|_, ctx, i_addr| {
|
llvm_usize.const_int(1, false),
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
|
||||||
ctx.builder.build_store(i_addr, i).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -536,6 +536,84 @@ pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the
|
||||||
|
/// following C code:
|
||||||
|
///
|
||||||
|
/// ```c
|
||||||
|
/// for (int x = init_val; x /* < or <= ; see `max_val` */ max_val; x += incr_val) {
|
||||||
|
/// body(x);
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
|
||||||
|
/// as the type of the loop variable.
|
||||||
|
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
|
||||||
|
/// value should be treated as inclusive (as opposed to exclusive).
|
||||||
|
/// * `body` - A lambda containing IR statements within the loop body.
|
||||||
|
/// * `incr_val` - The value to increment the loop variable on each iteration.
|
||||||
|
pub fn gen_for_callback_incrementing<'ctx, 'a, BodyFn>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
init_val: IntValue<'ctx>,
|
||||||
|
max_val: (IntValue<'ctx>, bool),
|
||||||
|
body: BodyFn,
|
||||||
|
incr_val: IntValue<'ctx>,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
|
||||||
|
{
|
||||||
|
let init_val_t = init_val.get_type();
|
||||||
|
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|generator, ctx| {
|
||||||
|
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
|
||||||
|
ctx.builder.build_store(i_addr, init_val).unwrap();
|
||||||
|
|
||||||
|
Ok(i_addr)
|
||||||
|
},
|
||||||
|
|_, ctx, i_addr| {
|
||||||
|
let cmp_op = if max_val.1 {
|
||||||
|
IntPredicate::ULE
|
||||||
|
} else {
|
||||||
|
IntPredicate::ULT
|
||||||
|
};
|
||||||
|
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let max_val = ctx.builder
|
||||||
|
.build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap())
|
||||||
|
},
|
||||||
|
|generator, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
body(generator, ctx, i)
|
||||||
|
},
|
||||||
|
|_, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let incr_val = ctx.builder
|
||||||
|
.build_int_z_extend_or_bit_cast(incr_val, init_val_t, "")
|
||||||
|
.unwrap();
|
||||||
|
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
|
||||||
|
ctx.builder.build_store(i_addr, i).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_while`].
|
/// See [`CodeGenerator::gen_while`].
|
||||||
pub fn gen_while<G: CodeGenerator>(
|
pub fn gen_while<G: CodeGenerator>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
Loading…
Reference in New Issue
Block a user