1
0
forked from M-Labs/nac3

core: Replace ndarray_init_dims IRRT impl with IR impl

Implementation of that function in IR allows for more flexibility in
terms of different integer type widths.
This commit is contained in:
David Mak 2024-03-06 16:08:36 +08:00
parent f682e9bf7a
commit 3d2abf73c8
3 changed files with 163 additions and 161 deletions

View File

@ -224,24 +224,6 @@ uint64_t __nac3_ndarray_calc_size64(
return num_elems; return num_elems;
} }
// Initializes the dimension array of an ndarray for 32-bit size types.
//
// Copies the first `shape_len` entries of `shape_data` into `ndarray_dims`,
// reinterpreting each signed 32-bit entry as an unsigned 32-bit value.
// Nothing is written when `shape_len` is 0.
void __nac3_ndarray_init_dims(
    uint32_t *ndarray_dims,
    const int32_t *shape_data,
    uint32_t shape_len
) {
    for (uint32_t dim = 0; dim < shape_len; ++dim) {
        ndarray_dims[dim] = (uint32_t) shape_data[dim];
    }
}
// Initializes the dimension array of an ndarray for 64-bit size types.
//
// Converts each of the first `shape_len` signed 32-bit entries of
// `shape_data` to `uint64_t` and stores it at the same index of
// `ndarray_dims`. Nothing is written when `shape_len` is 0.
void __nac3_ndarray_init_dims64(
    uint64_t *ndarray_dims,
    const int32_t *shape_data,
    uint64_t shape_len
) {
    const int32_t *src = shape_data;
    const int32_t *const src_end = shape_data + shape_len;
    uint64_t *dst = ndarray_dims;

    while (src != src_end) {
        *dst = (uint64_t) *src;
        ++dst;
        ++src;
    }
}
void __nac3_ndarray_calc_nd_indices( void __nac3_ndarray_calc_nd_indices(
uint32_t index, uint32_t index,
const uint32_t* dims, const uint32_t* dims,

View File

@ -617,60 +617,6 @@ pub fn call_ndarray_calc_size<'ctx>(
.unwrap() .unwrap()
} }
/// Emits a call to the IRRT routine that initializes the dimensions of an `NDArray`.
///
/// The callee is `__nac3_ndarray_init_dims` when the size type is 32 bits wide and
/// `__nac3_ndarray_init_dims64` when it is 64 bits wide. If the current module does not yet
/// contain the function, a declaration taking `(usize*, i32*, usize)` and returning `void` is
/// added first.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
///   `NDArray`.
/// * `shape` - LLVM pointer to the `shape` of the `NDArray`. This value must be the LLVM
///   representation of a `list`.
///
/// # Panics
///
/// Panics if the size type reported by `generator` is neither 32 nor 64 bits wide, or if the
/// builder fails to emit the call.
pub fn call_ndarray_init_dims<'ctx>(
    generator: &dyn CodeGenerator,
    ctx: &mut CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
    shape: ListValue<'ctx>,
) {
    let size_ty = generator.get_size_type(ctx.ctx);

    // Pick the IRRT symbol matching the width of the size type.
    let irrt_fn_name = match size_ty.get_bit_width() {
        32 => "__nac3_ndarray_init_dims",
        64 => "__nac3_ndarray_init_dims64",
        bw => unreachable!("Unsupported size type bit width: {}", bw),
    };

    // Reuse an existing declaration, or declare the function on first use.
    let irrt_fn = match ctx.module.get_function(irrt_fn_name) {
        Some(declared) => declared,
        None => {
            let dims_ptr_ty = size_ty.ptr_type(AddressSpace::default());
            let shape_ptr_ty = ctx.ctx.i32_type().ptr_type(AddressSpace::default());
            let signature = ctx.ctx.void_type().fn_type(
                &[dims_ptr_ty.into(), shape_ptr_ty.into(), size_ty.into()],
                false,
            );

            ctx.module.add_function(irrt_fn_name, signature, None)
        }
    };

    let dims = ndarray.get_dims();
    let shape_items = shape.get_data();
    let num_dims = ndarray.load_ndims(ctx);

    let dims_ptr = dims.get_ptr(ctx);
    let shape_ptr = shape_items.get_ptr(ctx);

    ctx.builder
        .build_call(irrt_fn, &[dims_ptr.into(), shape_ptr.into(), num_dims.into()], "")
        .unwrap();
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. /// Generates a call to `__nac3_ndarray_calc_nd_indices`.
/// ///
/// * `index` - The index to compute the multidimensional index for. /// * `index` - The index to compute the multidimensional index for.

View File

@ -10,7 +10,6 @@ use crate::{
irrt::{ irrt::{
call_ndarray_calc_nd_indices, call_ndarray_calc_nd_indices,
call_ndarray_calc_size, call_ndarray_calc_size,
call_ndarray_init_dims,
}, },
llvm_intrinsics::call_memcpy_generic, llvm_intrinsics::call_memcpy_generic,
stmt::gen_for_callback stmt::gen_for_callback
@ -78,6 +77,161 @@ pub fn unpack_ndarray_tvars(
.unwrap() .unwrap()
} }
/// Creates an `NDArray` instance from a dynamic shape.
///
/// The generated code performs the following steps in order:
///
/// 1. Loops over every dimension of `shape` and asserts that it is non-negative, raising
///    `ValueError` otherwise.
/// 2. Allocates the `NDArray`, stores its number of dimensions and creates its `dims` array.
/// 3. Loops over every dimension of `shape` again, zero-extends it to the size type and stores it
///    into `ndarray.dims`.
/// 4. Computes the number of elements via [`call_ndarray_calc_size`] and creates the `data`
///    storage of the `NDArray` with that many elements.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`.
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. The
///   returned integer may be at most as wide as the size type; it is zero-extended to the size
///   type before use. This function is invoked several times (once per loop-condition emission and
///   once when storing the number of dimensions).
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. The
///   returned integer may be at most as wide as the size type.
///
/// # Errors
///
/// Returns any error produced by `shape_len_fn`, `shape_data_fn`, or by the code generator while
/// allocating variables or emitting the loops.
fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
    generator: &mut dyn CodeGenerator,
    ctx: &mut CodeGenContext<'ctx, 'a>,
    elem_ty: Type,
    shape: &V,
    shape_len_fn: LenFn,
    shape_data_fn: DataFn,
) -> Result<NDArrayValue<'ctx>, String>
where
    LenFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
    DataFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>,
{
    let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);

    let llvm_usize = generator.get_size_type(ctx.ctx);

    let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
    let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
    let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
    assert!(llvm_ndarray_data_t.is_sized());

    // Assert that all dimensions are non-negative
    gen_for_callback(
        generator,
        ctx,
        |generator, ctx| {
            let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
            ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();

            Ok(i)
        },
        |generator, ctx, i_addr| {
            let i = ctx.builder
                .build_load(i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();

            let shape_len = shape_len_fn(generator, ctx, shape)?;
            debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());

            // `i` is of the size type, whereas `shape_len` may be narrower. Both operands of an
            // integer comparison must have the same type, so widen the length first.
            let shape_len = ctx.builder
                .build_int_z_extend(shape_len, llvm_usize, "")
                .unwrap();

            Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
        },
        |generator, ctx, i_addr| {
            let i = ctx.builder
                .build_load(i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();

            let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
            debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());

            // Compare in the dimension's own (signed) type, before any zero-extension would hide
            // the sign bit.
            let shape_dim_gez = ctx.builder
                .build_int_compare(IntPredicate::SGE, shape_dim, shape_dim.get_type().const_zero(), "")
                .unwrap();

            ctx.make_assert(
                generator,
                shape_dim_gez,
                "0:ValueError",
                "negative dimensions not supported",
                [None, None, None],
                ctx.current_loc,
            );

            Ok(())
        },
        |_, ctx, i_addr| {
            let i = ctx.builder
                .build_load(i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();
            let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
            ctx.builder.build_store(i_addr, i).unwrap();

            Ok(())
        },
    )?;

    let ndarray = generator.gen_var_alloc(
        ctx,
        llvm_ndarray_t.into(),
        None,
    )?;
    let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);

    let num_dims = shape_len_fn(generator, ctx, shape)?;
    debug_assert!(num_dims.get_type().get_bit_width() <= llvm_usize.get_bit_width());

    // The number of dimensions is handled as a size-type value everywhere below, so widen a
    // narrower length before storing it.
    let num_dims = ctx.builder
        .build_int_z_extend(num_dims, llvm_usize, "")
        .unwrap();
    ndarray.store_ndims(ctx, generator, num_dims);

    let ndarray_num_dims = ndarray.load_ndims(ctx);
    ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);

    // Copy the dimension sizes from shape to ndarray.dims
    gen_for_callback(
        generator,
        ctx,
        |generator, ctx| {
            let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
            ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();

            Ok(i)
        },
        |generator, ctx, i_addr| {
            let i = ctx.builder
                .build_load(i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();

            let shape_len = shape_len_fn(generator, ctx, shape)?;
            debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());

            // Widen the length to the type of the loop counter before comparing.
            let shape_len = ctx.builder
                .build_int_z_extend(shape_len, llvm_usize, "")
                .unwrap();

            Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
        },
        |generator, ctx, i_addr| {
            let i = ctx.builder
                .build_load(i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();

            let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
            debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());

            // The first loop asserted that the dimension is non-negative, so zero-extension
            // yields the same value as sign-extension would.
            let shape_dim = ctx.builder
                .build_int_z_extend(shape_dim, llvm_usize, "")
                .unwrap();

            let ndarray_pdim = ndarray.get_dims().ptr_offset(ctx, generator, i, None);

            ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();

            Ok(())
        },
        |_, ctx, i_addr| {
            let i = ctx.builder
                .build_load(i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();
            let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
            ctx.builder.build_store(i_addr, i).unwrap();

            Ok(())
        },
    )?;

    let ndarray_num_elems = call_ndarray_calc_size(
        generator,
        ctx,
        ndarray.load_ndims(ctx),
        ndarray.get_dims().get_ptr(ctx),
    );
    ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);

    Ok(ndarray)
}
/// Creates an `NDArray` instance from a constant shape. /// Creates an `NDArray` instance from a constant shape.
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
@ -205,98 +359,18 @@ fn call_ndarray_empty_impl<'ctx>(
elem_ty: Type, elem_ty: Type,
shape: ListValue<'ctx>, shape: ListValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let ndarray_ty = make_ndarray_ty( create_ndarray_dyn_shape(
&mut ctx.unifier,
&ctx.primitives,
Some(elem_ty),
None,
);
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
// Assert that all dimensions are non-negative
gen_for_callback(
generator, generator,
ctx, ctx,
|generator, ctx| { elem_ty,
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; &shape,
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap(); |_, ctx, shape| {
Ok(shape.load_size(ctx, None))
Ok(i)
}, },
|_, ctx, i_addr| { |generator, ctx, shape, idx| {
let i = ctx.builder Ok(shape.get_data().get(ctx, generator, idx, None).into_int_value())
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let shape_len = shape.load_size(ctx, None);
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
}, },
|generator, ctx, i_addr| { )
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let shape_dim = shape.get_data().get(ctx, generator, i, None).into_int_value();
let shape_dim_gez = ctx.builder
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_i32.const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
)?;
let ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None,
)?;
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
let num_dims = shape.load_size(ctx, None);
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
call_ndarray_init_dims(generator, ctx, ndarray, shape);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
ndarray.load_ndims(ctx),
ndarray.get_dims().get_ptr(ctx),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
Ok(ndarray)
} }
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as