forked from M-Labs/nac3
core: Replace ndarray_init_dims IRRT impl with IR impl
Implementing that function in IR allows more flexibility when handling integer types of different widths.
This commit is contained in:
parent
f682e9bf7a
commit
3d2abf73c8
|
@ -224,24 +224,6 @@ uint64_t __nac3_ndarray_calc_size64(
|
|||
return num_elems;
|
||||
}
|
||||
|
||||
// Initializes the dimension sizes of an ndarray (32-bit size type) from a shape buffer.
//
// Copies `shape_len` 32-bit entries from `shape_data` into `ndarray_dims` as a raw byte copy.
//
// * `ndarray_dims` - Destination buffer receiving the dimension sizes.
// * `shape_data` - Source buffer holding the requested shape.
// * `shape_len` - Number of entries to copy.
void __nac3_ndarray_init_dims(
    uint32_t *ndarray_dims,
    const int32_t *shape_data,
    uint32_t shape_len
) {
    // Source and destination elements are both 32 bits wide, so a byte-wise copy transfers every
    // entry unchanged.
    __builtin_memcpy(ndarray_dims, shape_data, sizeof(int32_t) * shape_len);
}
|
||||
|
||||
// Initializes the dimension sizes of an ndarray (64-bit size type) from a shape buffer.
//
// Each of the `shape_len` signed 32-bit entries in `shape_data` is converted to `uint64_t` and
// stored at the same index of `ndarray_dims`.
//
// * `ndarray_dims` - Destination buffer receiving the dimension sizes.
// * `shape_data` - Source buffer holding the requested shape.
// * `shape_len` - Number of entries to convert.
void __nac3_ndarray_init_dims64(
    uint64_t *ndarray_dims,
    const int32_t *shape_data,
    uint64_t shape_len
) {
    uint64_t idx = 0;
    while (idx < shape_len) {
        // Read the 32-bit entry first, then widen it explicitly to the 64-bit element type.
        const int32_t dim = shape_data[idx];
        ndarray_dims[idx] = (uint64_t) dim;
        idx += 1;
    }
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices(
|
||||
uint32_t index,
|
||||
const uint32_t* dims,
|
||||
|
|
|
@ -617,60 +617,6 @@ pub fn call_ndarray_calc_size<'ctx>(
|
|||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_init_dims`.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// * `shape` - LLVM pointer to the `shape` of the `NDArray`. This value must be the LLVM
|
||||
/// representation of a `list`.
|
||||
pub fn call_ndarray_init_dims<'ctx>(
|
||||
generator: &dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
shape: ListValue<'ctx>,
|
||||
) {
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_init_dims",
|
||||
64 => "__nac3_ndarray_init_dims64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||
};
|
||||
let ndarray_init_dims_fn = ctx.module.get_function(ndarray_init_dims_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_void.fn_type(
|
||||
&[
|
||||
llvm_pusize.into(),
|
||||
llvm_pi32.into(),
|
||||
llvm_usize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_dims = ndarray.get_dims();
|
||||
let shape_data = shape.get_data();
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_init_dims_fn,
|
||||
&[
|
||||
ndarray_dims.get_ptr(ctx).into(),
|
||||
shape_data.get_ptr(ctx).into(),
|
||||
ndarray_num_dims.into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_nd_indices`.
|
||||
///
|
||||
/// * `index` - The index to compute the multidimensional index for.
|
||||
|
|
|
@ -10,7 +10,6 @@ use crate::{
|
|||
irrt::{
|
||||
call_ndarray_calc_nd_indices,
|
||||
call_ndarray_calc_size,
|
||||
call_ndarray_init_dims,
|
||||
},
|
||||
llvm_intrinsics::call_memcpy_generic,
|
||||
stmt::gen_for_callback
|
||||
|
@ -78,6 +77,161 @@ pub fn unpack_ndarray_tvars(
|
|||
.unwrap()
|
||||
}
|
||||
|
||||
/// Creates an `NDArray` instance from a dynamic shape.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `shape` - The shape of the `NDArray`.
|
||||
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
|
||||
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
|
||||
fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: &V,
|
||||
shape_len_fn: LenFn,
|
||||
shape_data_fn: DataFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
LenFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
|
||||
DataFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>,
|
||||
{
|
||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||
assert!(llvm_ndarray_data_t.is_sized());
|
||||
|
||||
// Assert that all dimensions are non-negative
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
||||
|
||||
Ok(i)
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||
|
||||
let shape_dim_gez = ctx.builder
|
||||
.build_int_compare(IntPredicate::SGE, shape_dim, shape_dim.get_type().const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
shape_dim_gez,
|
||||
"0:ValueError",
|
||||
"negative dimensions not supported",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
||||
ctx.builder.build_store(i_addr, i).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
let ndarray = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_ndarray_t.into(),
|
||||
None,
|
||||
)?;
|
||||
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
|
||||
|
||||
let num_dims = shape_len_fn(generator, ctx, shape)?;
|
||||
ndarray.store_ndims(ctx, generator, num_dims);
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
|
||||
|
||||
// Copy the dimension sizes from shape to ndarray.dims
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
||||
|
||||
Ok(i)
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||
let shape_dim = ctx.builder
|
||||
.build_int_z_extend(shape_dim, llvm_usize, "")
|
||||
.unwrap();
|
||||
|
||||
let ndarray_pdim = ndarray.get_dims().ptr_offset(ctx, generator, i, None);
|
||||
|
||||
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
||||
ctx.builder.build_store(i_addr, i).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray.load_ndims(ctx),
|
||||
ndarray.get_dims().get_ptr(ctx),
|
||||
);
|
||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Creates an `NDArray` instance from a constant shape.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
|
@ -205,98 +359,18 @@ fn call_ndarray_empty_impl<'ctx>(
|
|||
elem_ty: Type,
|
||||
shape: ListValue<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let ndarray_ty = make_ndarray_ty(
|
||||
&mut ctx.unifier,
|
||||
&ctx.primitives,
|
||||
Some(elem_ty),
|
||||
None,
|
||||
);
|
||||
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||
assert!(llvm_ndarray_data_t.is_sized());
|
||||
|
||||
// Assert that all dimensions are non-negative
|
||||
gen_for_callback(
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
||||
|
||||
Ok(i)
|
||||
elem_ty,
|
||||
&shape,
|
||||
|_, ctx, shape| {
|
||||
Ok(shape.load_size(ctx, None))
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let shape_len = shape.load_size(ctx, None);
|
||||
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
|
||||
|generator, ctx, shape, idx| {
|
||||
Ok(shape.get_data().get(ctx, generator, idx, None).into_int_value())
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let shape_dim = shape.get_data().get(ctx, generator, i, None).into_int_value();
|
||||
|
||||
let shape_dim_gez = ctx.builder
|
||||
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_i32.const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
shape_dim_gez,
|
||||
"0:ValueError",
|
||||
"negative dimensions not supported",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
||||
ctx.builder.build_store(i_addr, i).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
let ndarray = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_ndarray_t.into(),
|
||||
None,
|
||||
)?;
|
||||
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
|
||||
|
||||
let num_dims = shape.load_size(ctx, None);
|
||||
ndarray.store_ndims(ctx, generator, num_dims);
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
|
||||
|
||||
call_ndarray_init_dims(generator, ctx, ndarray, shape);
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray.load_ndims(ctx),
|
||||
ndarray.get_dims().get_ptr(ctx),
|
||||
);
|
||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||
|
||||
Ok(ndarray)
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
||||
|
|
Loading…
Reference in New Issue