forked from M-Labs/nac3
core: Replace ndarray_init_dims IRRT impl with IR impl
Implementation of that function in IR allows for more flexibility in terms of different integer type widths.
This commit is contained in:
parent
f682e9bf7a
commit
3d2abf73c8
@ -224,24 +224,6 @@ uint64_t __nac3_ndarray_calc_size64(
|
|||||||
return num_elems;
|
return num_elems;
|
||||||
}
|
}
|
||||||
|
|
||||||
void __nac3_ndarray_init_dims(
|
|
||||||
uint32_t *ndarray_dims,
|
|
||||||
const int32_t *shape_data,
|
|
||||||
uint32_t shape_len
|
|
||||||
) {
|
|
||||||
__builtin_memcpy(ndarray_dims, shape_data, shape_len * sizeof(int32_t));
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_init_dims64(
|
|
||||||
uint64_t *ndarray_dims,
|
|
||||||
const int32_t *shape_data,
|
|
||||||
uint64_t shape_len
|
|
||||||
) {
|
|
||||||
for (uint64_t i = 0; i < shape_len; ++i) {
|
|
||||||
ndarray_dims[i] = (uint64_t) shape_data[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_nd_indices(
|
void __nac3_ndarray_calc_nd_indices(
|
||||||
uint32_t index,
|
uint32_t index,
|
||||||
const uint32_t* dims,
|
const uint32_t* dims,
|
||||||
|
@ -617,60 +617,6 @@ pub fn call_ndarray_calc_size<'ctx>(
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_init_dims`.
|
|
||||||
///
|
|
||||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
|
||||||
/// `NDArray`.
|
|
||||||
/// * `shape` - LLVM pointer to the `shape` of the `NDArray`. This value must be the LLVM
|
|
||||||
/// representation of a `list`.
|
|
||||||
pub fn call_ndarray_init_dims<'ctx>(
|
|
||||||
generator: &dyn CodeGenerator,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
shape: ListValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let llvm_void = ctx.ctx.void_type();
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
|
|
||||||
let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() {
|
|
||||||
32 => "__nac3_ndarray_init_dims",
|
|
||||||
64 => "__nac3_ndarray_init_dims64",
|
|
||||||
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
|
||||||
};
|
|
||||||
let ndarray_init_dims_fn = ctx.module.get_function(ndarray_init_dims_fn_name).unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_void.fn_type(
|
|
||||||
&[
|
|
||||||
llvm_pusize.into(),
|
|
||||||
llvm_pi32.into(),
|
|
||||||
llvm_usize.into(),
|
|
||||||
],
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
let ndarray_dims = ndarray.get_dims();
|
|
||||||
let shape_data = shape.get_data();
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(
|
|
||||||
ndarray_init_dims_fn,
|
|
||||||
&[
|
|
||||||
ndarray_dims.get_ptr(ctx).into(),
|
|
||||||
shape_data.get_ptr(ctx).into(),
|
|
||||||
ndarray_num_dims.into(),
|
|
||||||
],
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_nd_indices`.
|
/// Generates a call to `__nac3_ndarray_calc_nd_indices`.
|
||||||
///
|
///
|
||||||
/// * `index` - The index to compute the multidimensional index for.
|
/// * `index` - The index to compute the multidimensional index for.
|
||||||
|
@ -10,7 +10,6 @@ use crate::{
|
|||||||
irrt::{
|
irrt::{
|
||||||
call_ndarray_calc_nd_indices,
|
call_ndarray_calc_nd_indices,
|
||||||
call_ndarray_calc_size,
|
call_ndarray_calc_size,
|
||||||
call_ndarray_init_dims,
|
|
||||||
},
|
},
|
||||||
llvm_intrinsics::call_memcpy_generic,
|
llvm_intrinsics::call_memcpy_generic,
|
||||||
stmt::gen_for_callback
|
stmt::gen_for_callback
|
||||||
@ -78,6 +77,161 @@ pub fn unpack_ndarray_tvars(
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates an `NDArray` instance from a dynamic shape.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
/// * `shape` - The shape of the `NDArray`.
|
||||||
|
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
|
||||||
|
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
|
||||||
|
fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: &V,
|
||||||
|
shape_len_fn: LenFn,
|
||||||
|
shape_data_fn: DataFn,
|
||||||
|
) -> Result<NDArrayValue<'ctx>, String>
|
||||||
|
where
|
||||||
|
LenFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
|
||||||
|
DataFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>,
|
||||||
|
{
|
||||||
|
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||||
|
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||||
|
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||||
|
assert!(llvm_ndarray_data_t.is_sized());
|
||||||
|
|
||||||
|
// Assert that all dimensions are non-negative
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|generator, ctx| {
|
||||||
|
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
||||||
|
|
||||||
|
Ok(i)
|
||||||
|
},
|
||||||
|
|generator, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||||
|
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||||
|
|
||||||
|
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
|
||||||
|
},
|
||||||
|
|generator, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||||
|
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||||
|
|
||||||
|
let shape_dim_gez = ctx.builder
|
||||||
|
.build_int_compare(IntPredicate::SGE, shape_dim, shape_dim.get_type().const_zero(), "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
shape_dim_gez,
|
||||||
|
"0:ValueError",
|
||||||
|
"negative dimensions not supported",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|_, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
||||||
|
ctx.builder.build_store(i_addr, i).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let ndarray = generator.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_ndarray_t.into(),
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
|
||||||
|
|
||||||
|
let num_dims = shape_len_fn(generator, ctx, shape)?;
|
||||||
|
ndarray.store_ndims(ctx, generator, num_dims);
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
|
||||||
|
// Copy the dimension sizes from shape to ndarray.dims
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|generator, ctx| {
|
||||||
|
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
||||||
|
|
||||||
|
Ok(i)
|
||||||
|
},
|
||||||
|
|generator, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||||
|
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||||
|
|
||||||
|
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
|
||||||
|
},
|
||||||
|
|generator, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||||
|
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||||
|
let shape_dim = ctx.builder
|
||||||
|
.build_int_z_extend(shape_dim, llvm_usize, "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let ndarray_pdim = ndarray.get_dims().ptr_offset(ctx, generator, i, None);
|
||||||
|
|
||||||
|
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|_, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
||||||
|
ctx.builder.build_store(i_addr, i).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray.load_ndims(ctx),
|
||||||
|
ndarray.get_dims().get_ptr(ctx),
|
||||||
|
);
|
||||||
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates an `NDArray` instance from a constant shape.
|
/// Creates an `NDArray` instance from a constant shape.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
@ -205,98 +359,18 @@ fn call_ndarray_empty_impl<'ctx>(
|
|||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ListValue<'ctx>,
|
shape: ListValue<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let ndarray_ty = make_ndarray_ty(
|
create_ndarray_dyn_shape(
|
||||||
&mut ctx.unifier,
|
|
||||||
&ctx.primitives,
|
|
||||||
Some(elem_ty),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
|
||||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
|
||||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
|
||||||
assert!(llvm_ndarray_data_t.is_sized());
|
|
||||||
|
|
||||||
// Assert that all dimensions are non-negative
|
|
||||||
gen_for_callback(
|
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|generator, ctx| {
|
elem_ty,
|
||||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
&shape,
|
||||||
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
|
|_, ctx, shape| {
|
||||||
|
Ok(shape.load_size(ctx, None))
|
||||||
Ok(i)
|
|
||||||
},
|
},
|
||||||
|_, ctx, i_addr| {
|
|generator, ctx, shape, idx| {
|
||||||
let i = ctx.builder
|
Ok(shape.get_data().get(ctx, generator, idx, None).into_int_value())
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let shape_len = shape.load_size(ctx, None);
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
|
|
||||||
},
|
},
|
||||||
|generator, ctx, i_addr| {
|
)
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let shape_dim = shape.get_data().get(ctx, generator, i, None).into_int_value();
|
|
||||||
|
|
||||||
let shape_dim_gez = ctx.builder
|
|
||||||
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_i32.const_zero(), "")
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
shape_dim_gez,
|
|
||||||
"0:ValueError",
|
|
||||||
"negative dimensions not supported",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
|_, ctx, i_addr| {
|
|
||||||
let i = ctx.builder
|
|
||||||
.build_load(i_addr, "")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
|
|
||||||
ctx.builder.build_store(i_addr, i).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let ndarray = generator.gen_var_alloc(
|
|
||||||
ctx,
|
|
||||||
llvm_ndarray_t.into(),
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
|
|
||||||
|
|
||||||
let num_dims = shape.load_size(ctx, None);
|
|
||||||
ndarray.store_ndims(ctx, generator, num_dims);
|
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
|
||||||
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
|
|
||||||
|
|
||||||
call_ndarray_init_dims(generator, ctx, ndarray, shape);
|
|
||||||
|
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ndarray.load_ndims(ctx),
|
|
||||||
ndarray.get_dims().get_ptr(ctx),
|
|
||||||
);
|
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
|
||||||
|
|
||||||
Ok(ndarray)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
||||||
|
Loading…
Reference in New Issue
Block a user