core/numpy: Add more helper functions
This commit is contained in:
parent
b6ff75dcaf
commit
2cf79510c2
@ -33,6 +33,30 @@ use crate::{
|
|||||||
typecheck::typedef::{FunSignature, Type},
|
typecheck::typedef::{FunSignature, Type},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Creates an uninitialized `NDArray` instance.
|
||||||
|
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
|
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let llvm_ndarray_t = ctx.get_llvm_type(generator, ndarray_ty)
|
||||||
|
.into_pointer_type()
|
||||||
|
.get_element_type()
|
||||||
|
.into_struct_type();
|
||||||
|
|
||||||
|
let ndarray = generator.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_ndarray_t.into(),
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None))
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates an `NDArray` instance from a dynamic shape.
|
/// Creates an `NDArray` instance from a dynamic shape.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
@ -52,15 +76,8 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|||||||
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
|
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
|
||||||
DataFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>,
|
DataFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>,
|
||||||
{
|
{
|
||||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
|
||||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
|
||||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
|
||||||
assert!(llvm_ndarray_data_t.is_sized());
|
|
||||||
|
|
||||||
// Assert that all dimensions are non-negative
|
// Assert that all dimensions are non-negative
|
||||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||||
gen_for_callback_incrementing(
|
gen_for_callback_incrementing(
|
||||||
@ -92,12 +109,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|||||||
llvm_usize.const_int(1, false),
|
llvm_usize.const_int(1, false),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let ndarray = generator.gen_var_alloc(
|
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
||||||
ctx,
|
|
||||||
llvm_ndarray_t.into(),
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
|
|
||||||
|
|
||||||
let num_dims = shape_len_fn(generator, ctx, shape)?;
|
let num_dims = shape_len_fn(generator, ctx, shape)?;
|
||||||
ndarray.store_ndims(ctx, generator, num_dims);
|
ndarray.store_ndims(ctx, generator, num_dims);
|
||||||
@ -130,13 +142,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|||||||
llvm_usize.const_int(1, false),
|
llvm_usize.const_int(1, false),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
|
||||||
(None, None),
|
|
||||||
);
|
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
|
||||||
|
|
||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
@ -151,15 +157,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: &[IntValue<'ctx>],
|
shape: &[IntValue<'ctx>],
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
|
||||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
|
||||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
|
||||||
assert!(llvm_ndarray_data_t.is_sized());
|
|
||||||
|
|
||||||
for shape_dim in shape {
|
for shape_dim in shape {
|
||||||
let shape_dim_gez = ctx.builder
|
let shape_dim_gez = ctx.builder
|
||||||
.build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "")
|
.build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "")
|
||||||
@ -177,12 +176,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
// TODO: Disallow dim_sz > u32_MAX
|
// TODO: Disallow dim_sz > u32_MAX
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndarray = generator.gen_var_alloc(
|
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
||||||
ctx,
|
|
||||||
llvm_ndarray_t.into(),
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
|
|
||||||
|
|
||||||
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
|
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
|
||||||
ndarray.store_ndims(ctx, generator, num_dims);
|
ndarray.store_ndims(ctx, generator, num_dims);
|
||||||
@ -200,6 +194,21 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap();
|
ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initializes the `data` field of [`NDArrayValue`] based on the `ndims` and `dim_sz` fields.
|
||||||
|
fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> NDArrayValue<'ctx> {
|
||||||
|
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||||
|
assert!(llvm_ndarray_data_t.is_sized());
|
||||||
|
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
@ -208,7 +217,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
);
|
);
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
|
|
||||||
Ok(ndarray)
|
ndarray
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
@ -671,14 +671,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
|
|||||||
|
|
||||||
let start = start_fn(generator, ctx)?;
|
let start = start_fn(generator, ctx)?;
|
||||||
let stop = stop_fn(generator, ctx)?;
|
let stop = stop_fn(generator, ctx)?;
|
||||||
let stop = if stop.get_type().get_bit_width() != start.get_type().get_bit_width() {
|
let stop = if stop.get_type().get_bit_width() == start.get_type().get_bit_width() {
|
||||||
if is_unsigned {
|
stop
|
||||||
|
} else if is_unsigned {
|
||||||
ctx.builder.build_int_z_extend(stop, start.get_type(), "").unwrap()
|
ctx.builder.build_int_z_extend(stop, start.get_type(), "").unwrap()
|
||||||
} else {
|
} else {
|
||||||
ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap()
|
ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap()
|
||||||
}
|
|
||||||
} else {
|
|
||||||
stop
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let incr = ctx.builder.build_int_compare(
|
let incr = ctx.builder.build_int_compare(
|
||||||
@ -703,14 +701,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
|
|||||||
.map(BasicValueEnum::into_int_value)
|
.map(BasicValueEnum::into_int_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let stop = stop_fn(generator, ctx)?;
|
let stop = stop_fn(generator, ctx)?;
|
||||||
let stop = if stop.get_type().get_bit_width() != i.get_type().get_bit_width() {
|
let stop = if stop.get_type().get_bit_width() == i.get_type().get_bit_width() {
|
||||||
if is_unsigned {
|
stop
|
||||||
|
} else if is_unsigned {
|
||||||
ctx.builder.build_int_z_extend(stop, i.get_type(), "").unwrap()
|
ctx.builder.build_int_z_extend(stop, i.get_type(), "").unwrap()
|
||||||
} else {
|
} else {
|
||||||
ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap()
|
ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap()
|
||||||
}
|
|
||||||
} else {
|
|
||||||
stop
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let i_lt_end = ctx.builder
|
let i_lt_end = ctx.builder
|
||||||
@ -742,14 +738,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let incr_val = step_fn(generator, ctx)?;
|
let incr_val = step_fn(generator, ctx)?;
|
||||||
let incr_val = if incr_val.get_type().get_bit_width() != i.get_type().get_bit_width() {
|
let incr_val = if incr_val.get_type().get_bit_width() == i.get_type().get_bit_width() {
|
||||||
if is_unsigned {
|
incr_val
|
||||||
|
} else if is_unsigned {
|
||||||
ctx.builder.build_int_z_extend(incr_val, i.get_type(), "").unwrap()
|
ctx.builder.build_int_z_extend(incr_val, i.get_type(), "").unwrap()
|
||||||
} else {
|
} else {
|
||||||
ctx.builder.build_int_s_extend(incr_val, i.get_type(), "").unwrap()
|
ctx.builder.build_int_s_extend(incr_val, i.get_type(), "").unwrap()
|
||||||
}
|
|
||||||
} else {
|
|
||||||
incr_val
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
|
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
|
||||||
|
Loading…
Reference in New Issue
Block a user