forked from M-Labs/nac3
core/numpy: Add more helper functions
This commit is contained in:
parent
b6ff75dcaf
commit
2cf79510c2
|
@ -33,6 +33,30 @@ use crate::{
|
|||
typecheck::typedef::{FunSignature, Type},
|
||||
};
|
||||
|
||||
/// Creates an uninitialized `NDArray` instance.
|
||||
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_ndarray_t = ctx.get_llvm_type(generator, ndarray_ty)
|
||||
.into_pointer_type()
|
||||
.get_element_type()
|
||||
.into_struct_type();
|
||||
|
||||
let ndarray = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_ndarray_t.into(),
|
||||
None,
|
||||
)?;
|
||||
|
||||
Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None))
|
||||
}
|
||||
|
||||
/// Creates an `NDArray` instance from a dynamic shape.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
|
@ -52,15 +76,8 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|||
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
|
||||
DataFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>,
|
||||
{
|
||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||
assert!(llvm_ndarray_data_t.is_sized());
|
||||
|
||||
// Assert that all dimensions are non-negative
|
||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||
gen_for_callback_incrementing(
|
||||
|
@ -92,12 +109,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let ndarray = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_ndarray_t.into(),
|
||||
None,
|
||||
)?;
|
||||
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
|
||||
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
||||
|
||||
let num_dims = shape_len_fn(generator, ctx, shape)?;
|
||||
ndarray.store_ndims(ctx, generator, num_dims);
|
||||
|
@ -130,13 +142,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
||||
(None, None),
|
||||
);
|
||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
@ -151,15 +157,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||
elem_ty: Type,
|
||||
shape: &[IntValue<'ctx>],
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||
assert!(llvm_ndarray_data_t.is_sized());
|
||||
|
||||
for shape_dim in shape {
|
||||
let shape_dim_gez = ctx.builder
|
||||
.build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "")
|
||||
|
@ -177,12 +176,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||
// TODO: Disallow dim_sz > u32_MAX
|
||||
}
|
||||
|
||||
let ndarray = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_ndarray_t.into(),
|
||||
None,
|
||||
)?;
|
||||
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
|
||||
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
||||
|
||||
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
|
||||
ndarray.store_ndims(ctx, generator, num_dims);
|
||||
|
@ -200,6 +194,21 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap();
|
||||
}
|
||||
|
||||
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Initializes the `data` field of [`NDArrayValue`] based on the `ndims` and `dim_sz` fields.
|
||||
fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> NDArrayValue<'ctx> {
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||
assert!(llvm_ndarray_data_t.is_sized());
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
|
@ -208,7 +217,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||
);
|
||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||
|
||||
Ok(ndarray)
|
||||
ndarray
|
||||
}
|
||||
|
||||
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
|
|
@ -671,14 +671,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
|
|||
|
||||
let start = start_fn(generator, ctx)?;
|
||||
let stop = stop_fn(generator, ctx)?;
|
||||
let stop = if stop.get_type().get_bit_width() != start.get_type().get_bit_width() {
|
||||
if is_unsigned {
|
||||
ctx.builder.build_int_z_extend(stop, start.get_type(), "").unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap()
|
||||
}
|
||||
} else {
|
||||
let stop = if stop.get_type().get_bit_width() == start.get_type().get_bit_width() {
|
||||
stop
|
||||
} else if is_unsigned {
|
||||
ctx.builder.build_int_z_extend(stop, start.get_type(), "").unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap()
|
||||
};
|
||||
|
||||
let incr = ctx.builder.build_int_compare(
|
||||
|
@ -703,14 +701,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
|
|||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let stop = stop_fn(generator, ctx)?;
|
||||
let stop = if stop.get_type().get_bit_width() != i.get_type().get_bit_width() {
|
||||
if is_unsigned {
|
||||
ctx.builder.build_int_z_extend(stop, i.get_type(), "").unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap()
|
||||
}
|
||||
} else {
|
||||
let stop = if stop.get_type().get_bit_width() == i.get_type().get_bit_width() {
|
||||
stop
|
||||
} else if is_unsigned {
|
||||
ctx.builder.build_int_z_extend(stop, i.get_type(), "").unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap()
|
||||
};
|
||||
|
||||
let i_lt_end = ctx.builder
|
||||
|
@ -742,14 +738,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
|
|||
.unwrap();
|
||||
|
||||
let incr_val = step_fn(generator, ctx)?;
|
||||
let incr_val = if incr_val.get_type().get_bit_width() != i.get_type().get_bit_width() {
|
||||
if is_unsigned {
|
||||
ctx.builder.build_int_z_extend(incr_val, i.get_type(), "").unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_s_extend(incr_val, i.get_type(), "").unwrap()
|
||||
}
|
||||
} else {
|
||||
let incr_val = if incr_val.get_type().get_bit_width() == i.get_type().get_bit_width() {
|
||||
incr_val
|
||||
} else if is_unsigned {
|
||||
ctx.builder.build_int_z_extend(incr_val, i.get_type(), "").unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_s_extend(incr_val, i.get_type(), "").unwrap()
|
||||
};
|
||||
|
||||
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
|
||||
|
|
Loading…
Reference in New Issue