2024-10-03 12:37:56 +08:00
|
|
|
use inkwell::{
|
|
|
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType},
|
|
|
|
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
|
|
|
AddressSpace, IntPredicate, OptimizationLevel,
|
|
|
|
};
|
|
|
|
|
|
|
|
use nac3parser::ast::{Operator, StrRef};
|
|
|
|
|
2024-10-17 15:57:33 +08:00
|
|
|
use super::{
|
|
|
|
expr::gen_binop_expr_with_values,
|
|
|
|
irrt::{
|
2024-11-22 16:38:57 +08:00
|
|
|
calculate_len_for_slice_range,
|
|
|
|
ndarray::{
|
|
|
|
call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index,
|
|
|
|
call_ndarray_calc_nd_indices, call_ndarray_calc_size,
|
|
|
|
},
|
2024-03-11 14:47:01 +08:00
|
|
|
},
|
2024-10-17 15:57:33 +08:00
|
|
|
llvm_intrinsics::{self, call_memcpy_generic},
|
|
|
|
macros::codegen_unreachable,
|
|
|
|
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
2024-12-12 11:30:14 +08:00
|
|
|
types::{ndarray::NDArrayType, ListType, ProxyType},
|
2024-10-29 13:57:28 +08:00
|
|
|
values::{
|
2024-12-12 11:30:14 +08:00
|
|
|
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue,
|
2024-10-29 13:57:28 +08:00
|
|
|
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
|
|
|
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
|
|
|
},
|
2024-10-17 15:57:33 +08:00
|
|
|
CodeGenContext, CodeGenerator,
|
|
|
|
};
|
|
|
|
use crate::{
|
2024-03-11 14:47:01 +08:00
|
|
|
symbol_resolver::ValueEnum,
|
|
|
|
toplevel::{
|
2024-08-28 16:33:03 +08:00
|
|
|
helper::{arraylike_flatten_element_type, PrimDef},
|
2024-03-26 19:14:56 +08:00
|
|
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
2024-06-12 14:45:03 +08:00
|
|
|
DefinitionId,
|
2024-03-11 14:47:01 +08:00
|
|
|
},
|
2024-06-27 13:01:26 +08:00
|
|
|
typecheck::{
|
|
|
|
magic_methods::Binop,
|
|
|
|
typedef::{FunSignature, Type, TypeEnum},
|
|
|
|
},
|
2024-03-11 14:47:01 +08:00
|
|
|
};
|
|
|
|
|
2024-05-29 14:19:12 +08:00
|
|
|
/// Creates an uninitialized `NDArray` instance.
|
|
|
|
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
2024-08-28 16:33:03 +08:00
|
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
2024-05-29 14:19:12 +08:00
|
|
|
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
|
|
|
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let llvm_ndarray_t = ctx
|
|
|
|
.get_llvm_type(generator, ndarray_ty)
|
2024-05-29 14:19:12 +08:00
|
|
|
.into_pointer_type()
|
|
|
|
.get_element_type()
|
|
|
|
.into_struct_type();
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
2024-05-29 14:19:12 +08:00
|
|
|
|
2024-08-28 16:33:03 +08:00
|
|
|
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None))
|
2024-05-29 14:19:12 +08:00
|
|
|
}
|
|
|
|
|
2024-03-11 14:47:01 +08:00
|
|
|
/// Creates an `NDArray` instance from a dynamic shape.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// * `shape` - The shape of the `NDArray`.
|
|
|
|
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
|
|
|
|
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
|
2024-03-19 18:24:30 +08:00
|
|
|
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
|
|
elem_ty: Type,
|
|
|
|
shape: &V,
|
|
|
|
shape_len_fn: LenFn,
|
|
|
|
shape_data_fn: DataFn,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
|
|
|
|
DataFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
&V,
|
|
|
|
IntValue<'ctx>,
|
|
|
|
) -> Result<IntValue<'ctx>, String>,
|
2024-03-11 14:47:01 +08:00
|
|
|
{
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
// Assert that all dimensions are non-negative
|
2024-03-08 13:13:18 +08:00
|
|
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
|
|
|
gen_for_callback_incrementing(
|
2024-03-11 14:47:01 +08:00
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(shape_len, false),
|
2024-07-02 19:05:00 +08:00
|
|
|
|generator, ctx, _, i| {
|
2024-03-11 14:47:01 +08:00
|
|
|
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
|
|
|
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_dim_gez = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::SGE,
|
|
|
|
shape_dim,
|
|
|
|
shape_dim.get_type().const_zero(),
|
|
|
|
"",
|
|
|
|
)
|
2024-03-11 14:47:01 +08:00
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
|
|
|
shape_dim_gez,
|
|
|
|
"0:ValueError",
|
|
|
|
"negative dimensions not supported",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
2024-06-12 14:45:03 +08:00
|
|
|
|
2024-04-19 19:00:07 +08:00
|
|
|
// TODO: Disallow dim_sz > u32_MAX
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_int(1, false),
|
2024-03-11 14:47:01 +08:00
|
|
|
)?;
|
|
|
|
|
2024-05-29 14:19:12 +08:00
|
|
|
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
let num_dims = shape_len_fn(generator, ctx, shape)?;
|
|
|
|
ndarray.store_ndims(ctx, generator, num_dims);
|
|
|
|
|
|
|
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
2024-11-13 15:53:29 +08:00
|
|
|
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
// Copy the dimension sizes from shape to ndarray.dims
|
2024-03-08 13:13:18 +08:00
|
|
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
|
|
|
gen_for_callback_incrementing(
|
2024-03-11 14:47:01 +08:00
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(shape_len, false),
|
2024-07-02 19:05:00 +08:00
|
|
|
|generator, ctx, _, i| {
|
2024-03-11 14:47:01 +08:00
|
|
|
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
|
|
|
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let ndarray_pdim =
|
2024-11-13 15:53:29 +08:00
|
|
|
unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) };
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_int(1, false),
|
2024-03-11 14:47:01 +08:00
|
|
|
)?;
|
|
|
|
|
2024-05-29 14:19:12 +08:00
|
|
|
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Creates an `NDArray` instance from a constant shape.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
2024-03-22 16:57:36 +08:00
|
|
|
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
|
2024-07-25 12:16:53 +08:00
|
|
|
pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
2024-03-19 18:24:30 +08:00
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
2024-03-22 16:57:36 +08:00
|
|
|
shape: &[IntValue<'ctx>],
|
2024-03-11 14:47:01 +08:00
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
2024-06-25 15:35:02 +08:00
|
|
|
for &shape_dim in shape {
|
|
|
|
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_dim_gez = ctx
|
|
|
|
.builder
|
2024-06-25 15:35:02 +08:00
|
|
|
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "")
|
2024-03-11 14:47:01 +08:00
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
|
|
|
shape_dim_gez,
|
|
|
|
"0:ValueError",
|
|
|
|
"negative dimensions not supported",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
|
|
|
// TODO: Disallow dim_sz > u32_MAX
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
2024-05-29 14:19:12 +08:00
|
|
|
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-03-22 16:57:36 +08:00
|
|
|
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
|
2024-03-11 14:47:01 +08:00
|
|
|
ndarray.store_ndims(ctx, generator, num_dims);
|
|
|
|
|
|
|
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
2024-11-13 15:53:29 +08:00
|
|
|
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-25 15:35:02 +08:00
|
|
|
for (i, &shape_dim) in shape.iter().enumerate() {
|
|
|
|
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
2024-03-22 16:57:36 +08:00
|
|
|
let ndarray_dim = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
ndarray.shape().ptr_offset_unchecked(
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(i as u64, true),
|
|
|
|
None,
|
|
|
|
)
|
2024-03-22 16:57:36 +08:00
|
|
|
};
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-25 15:35:02 +08:00
|
|
|
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
2024-05-29 14:19:12 +08:00
|
|
|
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
|
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Initializes the `data` field of [`NDArrayValue`] based on the `ndims` and `dim_sz` fields.
|
|
|
|
fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
ndarray: NDArrayValue<'ctx>,
|
|
|
|
) -> NDArrayValue<'ctx> {
|
|
|
|
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
|
|
|
assert!(llvm_ndarray_data_t.is_sized());
|
|
|
|
|
2024-03-11 14:47:01 +08:00
|
|
|
let ndarray_num_elems = call_ndarray_calc_size(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-11-13 15:53:29 +08:00
|
|
|
&ndarray.shape().as_slice_value(ctx, generator),
|
2024-05-27 15:58:06 +08:00
|
|
|
(None, None),
|
2024-03-11 14:47:01 +08:00
|
|
|
);
|
|
|
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
|
|
|
|
2024-05-29 14:19:12 +08:00
|
|
|
ndarray
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
2024-03-19 18:24:30 +08:00
|
|
|
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
) -> BasicValueEnum<'ctx> {
|
2024-06-12 14:45:03 +08:00
|
|
|
if [ctx.primitives.int32, ctx.primitives.uint32]
|
|
|
|
.iter()
|
|
|
|
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
|
|
{
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx.ctx.i32_type().const_zero().into()
|
2024-06-12 14:45:03 +08:00
|
|
|
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
|
|
|
.iter()
|
|
|
|
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
|
|
{
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx.ctx.i64_type().const_zero().into()
|
|
|
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
|
|
|
ctx.ctx.f64_type().const_zero().into()
|
|
|
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
|
|
|
ctx.ctx.bool_type().const_zero().into()
|
|
|
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
2024-08-12 20:17:41 +08:00
|
|
|
ctx.gen_string(generator, "").into()
|
2024-03-11 14:47:01 +08:00
|
|
|
} else {
|
2024-08-23 13:10:55 +08:00
|
|
|
codegen_unreachable!(ctx)
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-03-19 18:24:30 +08:00
|
|
|
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
) -> BasicValueEnum<'ctx> {
|
2024-06-12 14:45:03 +08:00
|
|
|
if [ctx.primitives.int32, ctx.primitives.uint32]
|
|
|
|
.iter()
|
|
|
|
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
|
|
{
|
2024-03-11 14:47:01 +08:00
|
|
|
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
|
|
|
|
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
2024-06-12 14:45:03 +08:00
|
|
|
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
|
|
|
.iter()
|
|
|
|
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
|
|
{
|
2024-03-11 14:47:01 +08:00
|
|
|
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
|
|
|
|
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
|
|
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
|
|
|
ctx.ctx.f64_type().const_float(1.0).into()
|
|
|
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
|
|
|
ctx.ctx.bool_type().const_int(1, false).into()
|
|
|
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
2024-08-12 20:17:41 +08:00
|
|
|
ctx.gen_string(generator, "1").into()
|
2024-03-11 14:47:01 +08:00
|
|
|
} else {
|
2024-08-23 13:10:55 +08:00
|
|
|
codegen_unreachable!(ctx)
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
2024-06-25 15:35:02 +08:00
|
|
|
///
|
|
|
|
/// ### Notes on `shape`
|
|
|
|
///
|
|
|
|
/// Just like numpy, the `shape` argument can be:
|
|
|
|
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
|
|
|
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
|
|
|
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
|
|
|
///
|
|
|
|
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
|
|
|
|
/// learn how `shape` gets from being a Python user expression to here.
|
2024-03-19 18:24:30 +08:00
|
|
|
fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
2024-06-25 15:35:02 +08:00
|
|
|
shape: BasicValueEnum<'ctx>,
|
2024-03-11 14:47:01 +08:00
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
2024-06-25 15:35:02 +08:00
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
match shape {
|
|
|
|
BasicValueEnum::PointerValue(shape_list_ptr)
|
2024-11-01 15:17:00 +08:00
|
|
|
if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
|
2024-06-25 15:35:02 +08:00
|
|
|
{
|
|
|
|
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
|
|
|
|
|
2024-11-01 15:17:00 +08:00
|
|
|
let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
|
2024-06-25 15:35:02 +08:00
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&shape_list,
|
|
|
|
|_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)),
|
|
|
|
|generator, ctx, shape_list, idx| {
|
|
|
|
Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value())
|
|
|
|
},
|
|
|
|
)
|
|
|
|
}
|
|
|
|
BasicValueEnum::StructValue(shape_tuple) => {
|
|
|
|
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
|
|
|
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
|
|
|
|
|
|
|
|
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
|
|
|
let ndims = shape_tuple.get_type().count_fields();
|
|
|
|
|
|
|
|
let mut shape = Vec::with_capacity(ndims as usize);
|
|
|
|
for dim_i in 0..ndims {
|
|
|
|
let dim = ctx
|
|
|
|
.builder
|
|
|
|
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
|
|
|
|
.unwrap()
|
|
|
|
.into_int_value();
|
|
|
|
|
|
|
|
shape.push(dim);
|
|
|
|
}
|
|
|
|
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
|
|
|
}
|
|
|
|
BasicValueEnum::IntValue(shape_int) => {
|
|
|
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
|
|
|
|
|
|
|
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
|
|
|
}
|
2024-08-23 13:10:55 +08:00
|
|
|
_ => codegen_unreachable!(ctx),
|
2024-06-25 15:35:02 +08:00
|
|
|
}
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
|
|
|
/// its input.
|
2024-03-19 18:24:30 +08:00
|
|
|
fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
|
|
ndarray: NDArrayValue<'ctx>,
|
|
|
|
value_fn: ValueFn,
|
|
|
|
) -> Result<(), String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
ValueFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
IntValue<'ctx>,
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-11 14:47:01 +08:00
|
|
|
{
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
let ndarray_num_elems = call_ndarray_calc_size(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-11-13 15:53:29 +08:00
|
|
|
&ndarray.shape().as_slice_value(ctx, generator),
|
2024-05-27 15:58:06 +08:00
|
|
|
(None, None),
|
2024-03-11 14:47:01 +08:00
|
|
|
);
|
|
|
|
|
2024-03-08 13:13:18 +08:00
|
|
|
gen_for_callback_incrementing(
|
2024-03-11 14:47:01 +08:00
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(ndarray_num_elems, false),
|
2024-07-02 19:05:00 +08:00
|
|
|
|generator, ctx, _, i| {
|
2024-06-12 14:45:03 +08:00
|
|
|
let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) };
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
let value = value_fn(generator, ctx, i)?;
|
|
|
|
ctx.builder.build_store(elem, value).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_int(1, false),
|
2024-03-11 14:47:01 +08:00
|
|
|
)
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
|
|
|
|
/// as its input.
|
2024-04-29 23:21:57 +08:00
|
|
|
fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>(
|
2024-03-19 18:24:30 +08:00
|
|
|
generator: &mut G,
|
2024-04-29 23:21:57 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
2024-03-11 14:47:01 +08:00
|
|
|
ndarray: NDArrayValue<'ctx>,
|
|
|
|
value_fn: ValueFn,
|
|
|
|
) -> Result<(), String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
ValueFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
&TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>,
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-11 14:47:01 +08:00
|
|
|
{
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| {
|
|
|
|
let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
value_fn(generator, ctx, &indices)
|
|
|
|
})
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
2024-04-29 23:21:57 +08:00
|
|
|
fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>(
|
2024-03-27 17:06:58 +08:00
|
|
|
generator: &mut G,
|
2024-04-29 23:21:57 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
2024-03-27 17:06:58 +08:00
|
|
|
src: NDArrayValue<'ctx>,
|
|
|
|
dest: NDArrayValue<'ctx>,
|
|
|
|
map_fn: MapFn,
|
|
|
|
) -> Result<(), String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
MapFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
BasicValueEnum<'ctx>,
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-27 17:06:58 +08:00
|
|
|
{
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| {
|
|
|
|
let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) };
|
2024-03-27 17:06:58 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
map_fn(generator, ctx, elem)
|
|
|
|
})
|
2024-03-27 17:06:58 +08:00
|
|
|
}
|
|
|
|
|
2024-03-13 11:16:23 +08:00
|
|
|
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
|
|
|
|
/// the target `ndarray`.
|
|
|
|
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
target: NDArrayValue<'ctx>,
|
|
|
|
source: NDArrayValue<'ctx>,
|
|
|
|
) {
|
|
|
|
let array_ndims = source.load_ndims(ctx);
|
|
|
|
let broadcast_size = target.load_ndims(ctx);
|
|
|
|
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
|
|
|
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
|
|
|
|
"0:ValueError",
|
|
|
|
"operands cannot be broadcast together",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
|
|
|
|
/// with broadcast-compatible shapes.
|
2024-04-29 23:21:57 +08:00
|
|
|
fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>(
|
2024-03-13 11:16:23 +08:00
|
|
|
generator: &mut G,
|
2024-04-29 23:21:57 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
2024-03-13 11:16:23 +08:00
|
|
|
res: NDArrayValue<'ctx>,
|
2024-08-28 16:33:03 +08:00
|
|
|
lhs: (Type, BasicValueEnum<'ctx>, bool),
|
|
|
|
rhs: (Type, BasicValueEnum<'ctx>, bool),
|
2024-03-13 11:16:23 +08:00
|
|
|
value_fn: ValueFn,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
ValueFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-13 11:16:23 +08:00
|
|
|
{
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
2024-08-28 16:33:03 +08:00
|
|
|
let (lhs_ty, lhs_val, lhs_scalar) = lhs;
|
|
|
|
let (rhs_ty, rhs_val, rhs_scalar) = rhs;
|
2024-03-13 11:16:23 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
assert!(
|
|
|
|
!(lhs_scalar && rhs_scalar),
|
|
|
|
"One of the operands must be a ndarray instance: `{}`, `{}`",
|
|
|
|
lhs_val.get_type(),
|
|
|
|
rhs_val.get_type()
|
|
|
|
);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
|
|
|
// Assert that all ndarray operands are broadcastable to the target size
|
|
|
|
if !lhs_scalar {
|
2024-08-28 16:33:03 +08:00
|
|
|
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty);
|
|
|
|
let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype);
|
|
|
|
let lhs_val = NDArrayValue::from_pointer_value(
|
|
|
|
lhs_val.into_pointer_value(),
|
|
|
|
llvm_lhs_elem_ty,
|
|
|
|
llvm_usize,
|
|
|
|
None,
|
|
|
|
);
|
2024-03-13 11:16:23 +08:00
|
|
|
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
|
|
|
|
}
|
|
|
|
|
|
|
|
if !rhs_scalar {
|
2024-08-28 16:33:03 +08:00
|
|
|
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty);
|
|
|
|
let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype);
|
|
|
|
let rhs_val = NDArrayValue::from_pointer_value(
|
|
|
|
rhs_val.into_pointer_value(),
|
|
|
|
llvm_rhs_elem_ty,
|
|
|
|
llvm_usize,
|
|
|
|
None,
|
|
|
|
);
|
2024-03-13 11:16:23 +08:00
|
|
|
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
|
|
|
|
}
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| {
|
|
|
|
let lhs_elem = if lhs_scalar {
|
|
|
|
lhs_val
|
|
|
|
} else {
|
2024-08-28 16:33:03 +08:00
|
|
|
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty);
|
|
|
|
let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype);
|
|
|
|
let lhs = NDArrayValue::from_pointer_value(
|
|
|
|
lhs_val.into_pointer_value(),
|
|
|
|
llvm_lhs_elem_ty,
|
|
|
|
llvm_usize,
|
|
|
|
None,
|
|
|
|
);
|
2024-06-12 14:45:03 +08:00
|
|
|
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) }
|
|
|
|
};
|
2024-03-13 11:16:23 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let rhs_elem = if rhs_scalar {
|
|
|
|
rhs_val
|
|
|
|
} else {
|
2024-08-28 16:33:03 +08:00
|
|
|
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty);
|
|
|
|
let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype);
|
|
|
|
let rhs = NDArrayValue::from_pointer_value(
|
|
|
|
rhs_val.into_pointer_value(),
|
|
|
|
llvm_rhs_elem_ty,
|
|
|
|
llvm_usize,
|
|
|
|
None,
|
|
|
|
);
|
2024-06-12 14:45:03 +08:00
|
|
|
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) }
|
|
|
|
};
|
|
|
|
|
|
|
|
value_fn(generator, ctx, (lhs_elem, rhs_elem))
|
|
|
|
})?;
|
2024-03-13 11:16:23 +08:00
|
|
|
|
|
|
|
Ok(res)
|
|
|
|
}
|
|
|
|
|
2024-03-11 14:47:01 +08:00
|
|
|
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
2024-03-19 18:24:30 +08:00
|
|
|
fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
2024-06-25 15:35:02 +08:00
|
|
|
shape: BasicValueEnum<'ctx>,
|
2024-03-11 14:47:01 +08:00
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
let supported_types = [
|
|
|
|
ctx.primitives.int32,
|
|
|
|
ctx.primitives.int64,
|
|
|
|
ctx.primitives.uint32,
|
|
|
|
ctx.primitives.uint64,
|
|
|
|
ctx.primitives.float,
|
|
|
|
ctx.primitives.bool,
|
|
|
|
ctx.primitives.str,
|
|
|
|
];
|
|
|
|
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
|
|
|
|
|
|
|
|
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
|
|
|
|
let value = ndarray_zero_value(generator, ctx, elem_ty);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(value)
|
|
|
|
})?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
|
|
|
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
2024-03-19 18:24:30 +08:00
|
|
|
fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
2024-06-25 15:35:02 +08:00
|
|
|
shape: BasicValueEnum<'ctx>,
|
2024-03-11 14:47:01 +08:00
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
let supported_types = [
|
|
|
|
ctx.primitives.int32,
|
|
|
|
ctx.primitives.int64,
|
|
|
|
ctx.primitives.uint32,
|
|
|
|
ctx.primitives.uint64,
|
|
|
|
ctx.primitives.float,
|
|
|
|
ctx.primitives.bool,
|
|
|
|
ctx.primitives.str,
|
|
|
|
];
|
|
|
|
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
|
|
|
|
|
|
|
|
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
|
|
|
|
let value = ndarray_one_value(generator, ctx, elem_ty);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(value)
|
|
|
|
})?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
|
|
|
/// LLVM-typed implementation for generating the implementation for `ndarray.full`.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
2024-03-19 18:24:30 +08:00
|
|
|
fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
2024-06-25 15:35:02 +08:00
|
|
|
shape: BasicValueEnum<'ctx>,
|
2024-03-11 14:47:01 +08:00
|
|
|
fill_value: BasicValueEnum<'ctx>,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
|
|
|
|
let value = if fill_value.is_pointer_value() {
|
|
|
|
let llvm_i1 = ctx.ctx.bool_type();
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
call_memcpy_generic(
|
|
|
|
ctx,
|
|
|
|
copy,
|
|
|
|
fill_value.into_pointer_value(),
|
|
|
|
fill_value.get_type().size_of().map(Into::into).unwrap(),
|
|
|
|
llvm_i1.const_zero(),
|
|
|
|
);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
copy.into()
|
|
|
|
} else if fill_value.is_int_value() || fill_value.is_float_value() {
|
|
|
|
fill_value
|
|
|
|
} else {
|
2024-08-23 13:10:55 +08:00
|
|
|
codegen_unreachable!(ctx)
|
2024-06-12 14:45:03 +08:00
|
|
|
};
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(value)
|
|
|
|
})?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-06-11 15:29:32 +08:00
|
|
|
/// Returns the number of dimensions for a multidimensional list as an [`IntValue`].
|
|
|
|
fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &G,
|
|
|
|
ctx: &CodeGenContext<'ctx, '_>,
|
|
|
|
ty: PointerType<'ctx>,
|
|
|
|
) -> IntValue<'ctx> {
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
let list_ty = ListType::from_type(ty, llvm_usize);
|
|
|
|
let list_elem_ty = list_ty.element_type();
|
|
|
|
|
|
|
|
let ndims = llvm_usize.const_int(1, false);
|
|
|
|
match list_elem_ty {
|
2024-11-01 15:17:00 +08:00
|
|
|
AnyTypeEnum::PointerType(ptr_ty)
|
|
|
|
if ListType::is_representable(ptr_ty, llvm_usize).is_ok() =>
|
|
|
|
{
|
2024-06-11 15:29:32 +08:00
|
|
|
ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty))
|
|
|
|
}
|
|
|
|
|
2024-11-01 15:17:00 +08:00
|
|
|
AnyTypeEnum::PointerType(ptr_ty)
|
|
|
|
if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() =>
|
|
|
|
{
|
2024-06-11 15:29:32 +08:00
|
|
|
todo!("Getting ndims for list[ndarray] not supported")
|
|
|
|
}
|
|
|
|
|
|
|
|
_ => ndims,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns the number of dimensions for an array-like object as an [`IntValue`].
|
|
|
|
fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
2024-08-28 16:33:03 +08:00
|
|
|
(ty, value): (Type, BasicValueEnum<'ctx>),
|
2024-06-11 15:29:32 +08:00
|
|
|
) -> IntValue<'ctx> {
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
match value {
|
2024-11-01 15:17:00 +08:00
|
|
|
BasicValueEnum::PointerValue(v)
|
|
|
|
if NDArrayValue::is_representable(v, llvm_usize).is_ok() =>
|
|
|
|
{
|
2024-08-28 16:33:03 +08:00
|
|
|
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
|
|
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
|
|
|
|
NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx)
|
2024-06-11 15:29:32 +08:00
|
|
|
}
|
|
|
|
|
2024-11-01 15:17:00 +08:00
|
|
|
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
|
2024-06-11 15:29:32 +08:00
|
|
|
llvm_ndlist_get_ndims(generator, ctx, v.get_type())
|
|
|
|
}
|
|
|
|
|
|
|
|
_ => llvm_usize.const_zero(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`].
|
|
|
|
fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
|
|
|
src_lst: ListValue<'ctx>,
|
|
|
|
dim: u64,
|
|
|
|
) -> Result<(), String> {
|
|
|
|
let llvm_i1 = ctx.ctx.bool_type();
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
let list_elem_ty = src_lst.get_type().element_type();
|
|
|
|
|
|
|
|
match list_elem_ty {
|
2024-11-01 15:17:00 +08:00
|
|
|
AnyTypeEnum::PointerType(ptr_ty)
|
|
|
|
if ListType::is_representable(ptr_ty, llvm_usize).is_ok() =>
|
|
|
|
{
|
2024-06-12 14:45:03 +08:00
|
|
|
// The stride of elements in this dimension, i.e. the number of elements between arr[i]
|
2024-06-11 15:29:32 +08:00
|
|
|
// and arr[i + 1] in this dimension
|
|
|
|
let stride = call_ndarray_calc_size(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-11-13 15:53:29 +08:00
|
|
|
&dst_arr.shape(),
|
2024-06-11 15:29:32 +08:00
|
|
|
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
|
|
|
);
|
|
|
|
|
|
|
|
gen_for_range_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-06-11 15:29:32 +08:00
|
|
|
true,
|
|
|
|
|_, _| Ok(llvm_usize.const_zero()),
|
|
|
|
(|_, ctx| Ok(src_lst.load_size(ctx, None)), false),
|
|
|
|
|_, _| Ok(llvm_usize.const_int(1, false)),
|
2024-07-25 15:54:39 +08:00
|
|
|
|generator, ctx, _, i| {
|
2024-06-12 14:45:03 +08:00
|
|
|
let offset = ctx.builder.build_int_mul(stride, i, "").unwrap();
|
2024-08-28 16:33:03 +08:00
|
|
|
let offset = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_mul(
|
|
|
|
offset,
|
|
|
|
ctx.builder
|
|
|
|
.build_int_truncate_or_bit_cast(
|
|
|
|
dst_arr.get_type().element_type().size_of().unwrap(),
|
|
|
|
offset.get_type(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap();
|
2024-06-11 15:29:32 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let dst_ptr =
|
|
|
|
unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() };
|
2024-06-11 15:29:32 +08:00
|
|
|
|
2024-11-01 15:17:00 +08:00
|
|
|
let nested_lst_elem = ListValue::from_pointer_value(
|
2024-06-12 14:45:03 +08:00
|
|
|
unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) }
|
|
|
|
.into_pointer_value(),
|
2024-06-11 15:29:32 +08:00
|
|
|
llvm_usize,
|
|
|
|
None,
|
|
|
|
);
|
|
|
|
|
|
|
|
ndarray_from_ndlist_impl(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
(dst_arr, dst_ptr),
|
|
|
|
nested_lst_elem,
|
|
|
|
dim + 1,
|
|
|
|
)?;
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
)?;
|
|
|
|
}
|
|
|
|
|
2024-11-01 15:17:00 +08:00
|
|
|
AnyTypeEnum::PointerType(ptr_ty)
|
|
|
|
if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() =>
|
|
|
|
{
|
2024-06-11 15:29:32 +08:00
|
|
|
todo!("Not implemented for list[ndarray]")
|
|
|
|
}
|
|
|
|
|
|
|
|
_ => {
|
|
|
|
let lst_len = src_lst.load_size(ctx, None);
|
2024-08-28 16:33:03 +08:00
|
|
|
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
|
2024-11-29 17:19:43 +08:00
|
|
|
let sizeof_elem =
|
|
|
|
ctx.builder.build_int_z_extend_or_bit_cast(sizeof_elem, llvm_usize, "").unwrap();
|
2024-07-04 12:24:52 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let cpy_len = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_mul(
|
|
|
|
ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(),
|
|
|
|
sizeof_elem,
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap();
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
call_memcpy_generic(
|
|
|
|
ctx,
|
|
|
|
dst_slice_ptr,
|
|
|
|
src_lst.data().base_ptr(ctx, generator),
|
|
|
|
cpy_len,
|
|
|
|
llvm_i1.const_zero(),
|
|
|
|
);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
/// LLVM-typed implementation for `ndarray.array`.
|
|
|
|
fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
object: BasicValueEnum<'ctx>,
|
|
|
|
copy: IntValue<'ctx>,
|
|
|
|
ndmin: IntValue<'ctx>,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
let llvm_i1 = ctx.ctx.bool_type();
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let ndmin = ctx.builder.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "").unwrap();
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
// TODO(Derppening): Add assertions for sizes of different dimensions
|
|
|
|
|
|
|
|
// object is not a pointer - 0-dim NDArray
|
|
|
|
if !object.is_pointer_value() {
|
2024-06-12 14:45:03 +08:00
|
|
|
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[])?;
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
unsafe {
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray.data().set_unchecked(ctx, generator, &llvm_usize.const_zero(), object);
|
2024-06-11 15:29:32 +08:00
|
|
|
}
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
return Ok(ndarray);
|
2024-06-11 15:29:32 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
let object = object.into_pointer_value();
|
|
|
|
|
|
|
|
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
|
2024-11-01 15:17:00 +08:00
|
|
|
if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
|
2024-08-28 16:33:03 +08:00
|
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
|
|
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None);
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
let ndarray = gen_if_else_expr_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
|_, ctx| {
|
2024-06-12 14:45:03 +08:00
|
|
|
let copy_nez = ctx
|
|
|
|
.builder
|
2024-06-11 15:29:32 +08:00
|
|
|
.build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "")
|
|
|
|
.unwrap();
|
2024-06-12 14:45:03 +08:00
|
|
|
let ndmin_gt_ndims = ctx
|
|
|
|
.builder
|
2024-06-11 15:29:32 +08:00
|
|
|
.build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "")
|
|
|
|
.unwrap();
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(ctx.builder.build_and(copy_nez, ndmin_gt_ndims, "").unwrap())
|
2024-06-11 15:29:32 +08:00
|
|
|
},
|
|
|
|
|generator, ctx| {
|
|
|
|
let ndarray = create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&object,
|
|
|
|
|_, ctx, object| {
|
|
|
|
let ndims = object.load_ndims(ctx);
|
2024-06-12 14:45:03 +08:00
|
|
|
let ndmin_gt_ndims = ctx
|
|
|
|
.builder
|
2024-06-11 15:29:32 +08:00
|
|
|
.build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "")
|
|
|
|
.unwrap();
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(ctx
|
|
|
|
.builder
|
2024-06-11 15:29:32 +08:00
|
|
|
.build_select(ndmin_gt_ndims, ndmin, ndims, "")
|
|
|
|
.map(BasicValueEnum::into_int_value)
|
|
|
|
.unwrap())
|
|
|
|
},
|
|
|
|
|generator, ctx, object, idx| {
|
|
|
|
let ndims = object.load_ndims(ctx);
|
|
|
|
let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None);
|
|
|
|
// The number of dimensions to prepend 1's to
|
|
|
|
let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap();
|
|
|
|
|
|
|
|
Ok(gen_if_else_expr_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
|_, ctx| {
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(ctx
|
|
|
|
.builder
|
2024-06-11 15:29:32 +08:00
|
|
|
.build_int_compare(IntPredicate::UGE, idx, offset, "")
|
|
|
|
.unwrap())
|
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
|_, _| Ok(Some(llvm_usize.const_int(1, false))),
|
|
|
|
|_, ctx| Ok(Some(ctx.builder.build_int_sub(idx, offset, "").unwrap())),
|
|
|
|
)?
|
|
|
|
.map(BasicValueEnum::into_int_value)
|
|
|
|
.unwrap())
|
2024-06-11 15:29:32 +08:00
|
|
|
},
|
|
|
|
)?;
|
|
|
|
|
|
|
|
ndarray_sliced_copyto_impl(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
|
|
|
(object, object.data().base_ptr(ctx, generator)),
|
|
|
|
0,
|
|
|
|
&[],
|
|
|
|
)?;
|
|
|
|
|
|
|
|
Ok(Some(ndarray.as_base_value()))
|
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
|_, _| Ok(Some(object.as_base_value())),
|
2024-06-11 15:29:32 +08:00
|
|
|
)?;
|
|
|
|
|
2024-11-01 15:17:00 +08:00
|
|
|
return Ok(NDArrayValue::from_pointer_value(
|
2024-06-11 15:29:32 +08:00
|
|
|
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
|
2024-08-28 16:33:03 +08:00
|
|
|
llvm_elem_ty,
|
2024-06-11 15:29:32 +08:00
|
|
|
llvm_usize,
|
|
|
|
None,
|
2024-06-12 14:45:03 +08:00
|
|
|
));
|
2024-06-11 15:29:32 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Remaining case: TList
|
2024-11-01 15:17:00 +08:00
|
|
|
assert!(ListValue::is_representable(object, llvm_usize).is_ok());
|
|
|
|
let object = ListValue::from_pointer_value(object, llvm_usize, None);
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
// The number of dimensions to prepend 1's to
|
|
|
|
let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type());
|
|
|
|
let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None);
|
|
|
|
let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap();
|
|
|
|
|
|
|
|
let ndarray = create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&object,
|
|
|
|
|generator, ctx, object| {
|
|
|
|
let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type());
|
2024-06-12 14:45:03 +08:00
|
|
|
let ndmin_gt_ndims =
|
|
|
|
ctx.builder.build_int_compare(IntPredicate::UGT, ndmin, ndims, "").unwrap();
|
2024-06-11 15:29:32 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(ctx
|
|
|
|
.builder
|
2024-06-11 15:29:32 +08:00
|
|
|
.build_select(ndmin_gt_ndims, ndmin, ndims, "")
|
|
|
|
.map(BasicValueEnum::into_int_value)
|
|
|
|
.unwrap())
|
|
|
|
},
|
|
|
|
|generator, ctx, object, idx| {
|
|
|
|
Ok(gen_if_else_expr_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
|_, ctx| {
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, idx, offset, "").unwrap())
|
2024-06-11 15:29:32 +08:00
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
|_, _| Ok(Some(llvm_usize.const_int(1, false))),
|
2024-06-11 15:29:32 +08:00
|
|
|
|generator, ctx| {
|
|
|
|
let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| {
|
|
|
|
ctx.ctx.struct_type(
|
2024-06-12 14:45:03 +08:00
|
|
|
&[elem_ty.ptr_type(AddressSpace::default()).into(), llvm_usize.into()],
|
2024-06-11 15:29:32 +08:00
|
|
|
false,
|
|
|
|
)
|
|
|
|
};
|
|
|
|
|
|
|
|
let llvm_i8 = ctx.ctx.i8_type();
|
|
|
|
let llvm_list_i8 = make_llvm_list(llvm_i8.into());
|
|
|
|
let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default());
|
|
|
|
|
|
|
|
// Cast list to { i8*, usize } since we only care about the size
|
2024-06-12 14:45:03 +08:00
|
|
|
let lst = generator
|
|
|
|
.gen_var_alloc(
|
|
|
|
ctx,
|
|
|
|
ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(),
|
|
|
|
None,
|
|
|
|
)
|
|
|
|
.unwrap();
|
|
|
|
ctx.builder
|
|
|
|
.build_store(
|
|
|
|
lst,
|
|
|
|
ctx.builder
|
2024-08-20 20:16:36 +08:00
|
|
|
.build_bit_cast(object.as_base_value(), llvm_plist_i8, "")
|
2024-06-12 14:45:03 +08:00
|
|
|
.unwrap(),
|
|
|
|
)
|
|
|
|
.unwrap();
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap();
|
|
|
|
gen_for_range_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-06-11 15:29:32 +08:00
|
|
|
true,
|
|
|
|
|_, _| Ok(llvm_usize.const_zero()),
|
|
|
|
(|_, _| Ok(stop), false),
|
|
|
|
|_, _| Ok(llvm_usize.const_int(1, false)),
|
2024-07-25 15:54:39 +08:00
|
|
|
|generator, ctx, _, _| {
|
2024-06-11 15:29:32 +08:00
|
|
|
let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into())
|
|
|
|
.ptr_type(AddressSpace::default());
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let this_dim = ctx
|
|
|
|
.builder
|
2024-06-11 15:29:32 +08:00
|
|
|
.build_load(lst, "")
|
|
|
|
.map(BasicValueEnum::into_pointer_value)
|
2024-08-20 20:16:36 +08:00
|
|
|
.map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap())
|
2024-06-11 15:29:32 +08:00
|
|
|
.map(BasicValueEnum::into_pointer_value)
|
|
|
|
.unwrap();
|
2024-11-01 15:17:00 +08:00
|
|
|
let this_dim =
|
|
|
|
ListValue::from_pointer_value(this_dim, llvm_usize, None);
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
// TODO: Assert this_dim.sz != 0
|
|
|
|
|
|
|
|
let next_dim = unsafe {
|
2024-06-12 14:45:03 +08:00
|
|
|
this_dim.data().get_unchecked(
|
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_zero(),
|
|
|
|
None,
|
|
|
|
)
|
|
|
|
}
|
|
|
|
.into_pointer_value();
|
|
|
|
ctx.builder
|
|
|
|
.build_store(
|
|
|
|
lst,
|
2024-08-20 20:16:36 +08:00
|
|
|
ctx.builder
|
|
|
|
.build_bit_cast(next_dim, llvm_plist_i8, "")
|
|
|
|
.unwrap(),
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.unwrap();
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
)?;
|
|
|
|
|
2024-11-01 15:17:00 +08:00
|
|
|
let lst = ListValue::from_pointer_value(
|
2024-06-11 15:29:32 +08:00
|
|
|
ctx.builder
|
|
|
|
.build_load(lst, "")
|
|
|
|
.map(BasicValueEnum::into_pointer_value)
|
|
|
|
.unwrap(),
|
|
|
|
llvm_usize,
|
|
|
|
None,
|
|
|
|
);
|
|
|
|
|
|
|
|
Ok(Some(lst.load_size(ctx, None)))
|
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)?
|
|
|
|
.map(BasicValueEnum::into_int_value)
|
|
|
|
.unwrap())
|
2024-06-11 15:29:32 +08:00
|
|
|
},
|
|
|
|
)?;
|
|
|
|
|
|
|
|
ndarray_from_ndlist_impl(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
|
|
|
object,
|
|
|
|
0,
|
|
|
|
)?;
|
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-03-11 14:47:01 +08:00
|
|
|
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
2024-03-19 18:24:30 +08:00
|
|
|
fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
nrows: IntValue<'ctx>,
|
|
|
|
ncols: IntValue<'ctx>,
|
|
|
|
offset: IntValue<'ctx>,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
2024-03-22 16:10:42 +08:00
|
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
2024-03-11 14:47:01 +08:00
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap();
|
|
|
|
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap();
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| {
|
|
|
|
let (row, col) = unsafe {
|
|
|
|
(
|
|
|
|
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None),
|
|
|
|
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None),
|
|
|
|
)
|
|
|
|
};
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let col_with_offset = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_add(
|
|
|
|
col,
|
|
|
|
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap();
|
|
|
|
let is_on_diag =
|
|
|
|
ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap();
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let zero = ndarray_zero_value(generator, ctx, elem_ty);
|
|
|
|
let one = ndarray_one_value(generator, ctx, elem_ty);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap();
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(value)
|
|
|
|
})?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-05-30 14:25:56 +08:00
|
|
|
/// Copies a slice of an [`NDArrayValue`] to another.
|
2024-03-11 14:47:01 +08:00
|
|
|
///
|
2024-05-30 14:25:56 +08:00
|
|
|
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz`
|
2024-08-21 11:10:52 +08:00
|
|
|
/// fields should be populated before calling this function.
|
2024-05-30 14:25:56 +08:00
|
|
|
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
2024-08-21 11:10:52 +08:00
|
|
|
/// dimensional slice in the destination array.
|
2024-05-30 14:25:56 +08:00
|
|
|
/// - `src_arr`: The [`NDArrayValue`] instance of the source array.
|
|
|
|
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
2024-08-21 11:10:52 +08:00
|
|
|
/// dimensional slice in the source array.
|
2024-05-30 14:25:56 +08:00
|
|
|
/// - `dim`: The index of the currently processing dimension.
|
|
|
|
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
2024-08-21 11:10:52 +08:00
|
|
|
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
|
2024-05-30 14:25:56 +08:00
|
|
|
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
2024-03-19 18:24:30 +08:00
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
2024-05-30 14:25:56 +08:00
|
|
|
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
|
|
|
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
|
|
|
dim: u64,
|
|
|
|
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
|
|
|
|
) -> Result<(), String> {
|
2024-03-11 14:47:01 +08:00
|
|
|
let llvm_i1 = ctx.ctx.bool_type();
|
2024-05-30 14:25:56 +08:00
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
2024-08-28 16:33:03 +08:00
|
|
|
assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type());
|
|
|
|
|
|
|
|
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
|
|
|
|
|
2024-05-30 14:25:56 +08:00
|
|
|
// If there are no (remaining) slice expressions, memcpy the entire dimension
|
|
|
|
if slices.is_empty() {
|
|
|
|
let stride = call_ndarray_calc_size(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-11-13 15:53:29 +08:00
|
|
|
&src_arr.shape(),
|
2024-05-30 14:25:56 +08:00
|
|
|
(Some(llvm_usize.const_int(dim, false)), None),
|
|
|
|
);
|
2024-07-22 01:46:50 +08:00
|
|
|
let stride =
|
|
|
|
ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap();
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap();
|
2024-05-30 14:25:56 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero());
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
return Ok(());
|
2024-05-30 14:25:56 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// The stride of elements in this dimension, i.e. the number of elements between arr[i] and
|
|
|
|
// arr[i + 1] in this dimension
|
|
|
|
let src_stride = call_ndarray_calc_size(
|
2024-03-11 14:47:01 +08:00
|
|
|
generator,
|
|
|
|
ctx,
|
2024-11-13 15:53:29 +08:00
|
|
|
&src_arr.shape(),
|
2024-05-30 14:25:56 +08:00
|
|
|
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
|
|
|
);
|
|
|
|
let dst_stride = call_ndarray_calc_size(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-11-13 15:53:29 +08:00
|
|
|
&dst_arr.shape(),
|
2024-05-30 14:25:56 +08:00
|
|
|
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
|
|
|
);
|
|
|
|
|
|
|
|
let (start, stop, step) = slices[0];
|
|
|
|
let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap();
|
|
|
|
let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap();
|
|
|
|
let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap();
|
|
|
|
|
|
|
|
let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap();
|
|
|
|
ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap();
|
|
|
|
|
|
|
|
gen_for_range_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-05-30 14:25:56 +08:00
|
|
|
false,
|
|
|
|
|_, _| Ok(start),
|
|
|
|
(|_, _| Ok(stop), true),
|
|
|
|
|_, _| Ok(step),
|
2024-07-25 15:54:39 +08:00
|
|
|
|generator, ctx, _, src_i| {
|
2024-05-30 14:25:56 +08:00
|
|
|
// Calculate the offset of the active slice
|
2024-06-12 14:45:03 +08:00
|
|
|
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
|
2024-08-28 16:33:03 +08:00
|
|
|
let src_data_offset = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_mul(
|
|
|
|
src_data_offset,
|
|
|
|
ctx.builder
|
2024-11-29 17:19:43 +08:00
|
|
|
.build_int_z_extend_or_bit_cast(sizeof_elem, src_data_offset.get_type(), "")
|
2024-08-28 16:33:03 +08:00
|
|
|
.unwrap(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap();
|
2024-06-12 14:45:03 +08:00
|
|
|
let dst_i =
|
|
|
|
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
|
|
|
let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap();
|
2024-08-28 16:33:03 +08:00
|
|
|
let dst_data_offset = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_mul(
|
|
|
|
dst_data_offset,
|
|
|
|
ctx.builder
|
2024-11-29 17:19:43 +08:00
|
|
|
.build_int_z_extend_or_bit_cast(sizeof_elem, dst_data_offset.get_type(), "")
|
2024-08-28 16:33:03 +08:00
|
|
|
.unwrap(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap();
|
2024-05-30 14:25:56 +08:00
|
|
|
|
|
|
|
let (src_ptr, dst_ptr) = unsafe {
|
|
|
|
(
|
|
|
|
ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(),
|
|
|
|
ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(),
|
|
|
|
)
|
|
|
|
};
|
|
|
|
|
|
|
|
ndarray_sliced_copyto_impl(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
(dst_arr, dst_ptr),
|
|
|
|
(src_arr, src_ptr),
|
|
|
|
dim + 1,
|
|
|
|
&slices[1..],
|
|
|
|
)?;
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let dst_i =
|
|
|
|
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
|
|
|
let dst_i_add1 =
|
|
|
|
ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap();
|
2024-05-30 14:25:56 +08:00
|
|
|
ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
2024-03-11 14:47:01 +08:00
|
|
|
},
|
|
|
|
)?;
|
|
|
|
|
2024-05-30 14:25:56 +08:00
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Copies a [`NDArrayValue`] using slices.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
2024-08-21 11:10:52 +08:00
|
|
|
/// this dimension. The `start`/`stop` values of each slice must be positive indices.
|
2024-05-30 14:25:56 +08:00
|
|
|
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
this: NDArrayValue<'ctx>,
|
|
|
|
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
let ndarray = if slices.is_empty() {
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&this,
|
2024-06-12 14:45:03 +08:00
|
|
|
|_, ctx, shape| Ok(shape.load_ndims(ctx)),
|
|
|
|
|generator, ctx, shape, idx| unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
2024-05-30 14:25:56 +08:00
|
|
|
},
|
|
|
|
)?
|
|
|
|
} else {
|
|
|
|
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
|
|
|
|
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx));
|
|
|
|
|
|
|
|
let ndims = this.load_ndims(ctx);
|
2024-11-13 15:53:29 +08:00
|
|
|
ndarray.create_shape(ctx, llvm_usize, ndims);
|
2024-05-30 14:25:56 +08:00
|
|
|
|
|
|
|
// Populate the first slices.len() dimensions by computing the size of each dim slice
|
|
|
|
for (i, (start, stop, step)) in slices.iter().enumerate() {
|
|
|
|
// HACK: workaround calculate_len_for_slice_range requiring exclusive stop
|
2024-06-12 14:45:03 +08:00
|
|
|
let stop = ctx
|
|
|
|
.builder
|
2024-05-30 14:25:56 +08:00
|
|
|
.build_select(
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::SLT,
|
|
|
|
*step,
|
|
|
|
llvm_i32.const_zero(),
|
|
|
|
"is_neg",
|
|
|
|
)
|
|
|
|
.unwrap(),
|
|
|
|
ctx.builder
|
|
|
|
.build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one")
|
|
|
|
.unwrap(),
|
|
|
|
ctx.builder
|
|
|
|
.build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one")
|
|
|
|
.unwrap(),
|
2024-05-30 14:25:56 +08:00
|
|
|
"final_e",
|
|
|
|
)
|
|
|
|
.map(BasicValueEnum::into_int_value)
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step);
|
2024-06-12 14:45:03 +08:00
|
|
|
let slice_len =
|
|
|
|
ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();
|
2024-05-30 14:25:56 +08:00
|
|
|
|
|
|
|
unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
ndarray.shape().set_typed_unchecked(
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(i as u64, false),
|
|
|
|
slice_len,
|
|
|
|
);
|
2024-05-30 14:25:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Populate the rest by directly copying the dim size from the source array
|
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-05-30 14:25:56 +08:00
|
|
|
llvm_usize.const_int(slices.len() as u64, false),
|
|
|
|
(this.load_ndims(ctx), false),
|
2024-07-02 19:05:00 +08:00
|
|
|
|generator, ctx, _, idx| {
|
2024-05-30 14:25:56 +08:00
|
|
|
unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
let dim_sz = this.shape().get_typed_unchecked(ctx, generator, &idx, None);
|
|
|
|
ndarray.shape().set_typed_unchecked(ctx, generator, &idx, dim_sz);
|
2024-05-30 14:25:56 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.unwrap();
|
2024-05-30 14:25:56 +08:00
|
|
|
|
|
|
|
ndarray_init_data(generator, ctx, elem_ty, ndarray)
|
|
|
|
};
|
|
|
|
|
|
|
|
ndarray_sliced_copyto_impl(
|
2024-03-11 14:47:01 +08:00
|
|
|
generator,
|
|
|
|
ctx,
|
2024-05-30 14:25:56 +08:00
|
|
|
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
|
|
|
(this, this.data().base_ptr(ctx, generator)),
|
|
|
|
0,
|
|
|
|
slices,
|
|
|
|
)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-05-30 14:25:56 +08:00
|
|
|
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
this: NDArrayValue<'ctx>,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
|
|
|
}
|
|
|
|
|
2024-04-29 23:21:57 +08:00
|
|
|
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(
|
2024-03-27 17:06:58 +08:00
|
|
|
generator: &mut G,
|
2024-04-29 23:21:57 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
2024-03-27 17:06:58 +08:00
|
|
|
elem_ty: Type,
|
|
|
|
res: Option<NDArrayValue<'ctx>>,
|
|
|
|
operand: NDArrayValue<'ctx>,
|
|
|
|
map_fn: MapFn,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
MapFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
BasicValueEnum<'ctx>,
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-27 17:06:58 +08:00
|
|
|
{
|
|
|
|
let res = res.unwrap_or_else(|| {
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&operand,
|
2024-06-12 14:45:03 +08:00
|
|
|
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
|
|
|
|generator, ctx, v, idx| unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
2024-03-27 17:06:58 +08:00
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.unwrap()
|
2024-03-27 17:06:58 +08:00
|
|
|
});
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| {
|
|
|
|
map_fn(generator, ctx, elem)
|
|
|
|
})?;
|
2024-03-27 17:06:58 +08:00
|
|
|
|
|
|
|
Ok(res)
|
|
|
|
}
|
|
|
|
|
2024-03-13 11:16:23 +08:00
|
|
|
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
|
|
|
|
///
|
2024-06-12 14:45:03 +08:00
|
|
|
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
|
|
|
|
/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple.
|
|
|
|
/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the
|
2024-03-13 11:16:23 +08:00
|
|
|
/// `value_fn` arguments tuple for all output elements.
|
|
|
|
///
|
|
|
|
/// The second element of the tuple indicates whether to treat the operand value as a `ndarray`
|
2024-06-12 14:45:03 +08:00
|
|
|
/// (which would be accessed by its broadcast index) or as a scalar value (which would be
|
2024-03-13 11:16:23 +08:00
|
|
|
/// broadcast to all elements).
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
2024-08-21 11:10:52 +08:00
|
|
|
/// written to a new `ndarray`.
|
2024-03-13 11:16:23 +08:00
|
|
|
/// * `value_fn` - Function mapping the two input elements into the result.
|
|
|
|
///
|
|
|
|
/// # Panic
|
|
|
|
///
|
|
|
|
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
|
2024-04-29 23:21:57 +08:00
|
|
|
pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>(
|
2024-03-13 11:16:23 +08:00
|
|
|
generator: &mut G,
|
2024-04-29 23:21:57 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
2024-03-13 11:16:23 +08:00
|
|
|
elem_ty: Type,
|
|
|
|
res: Option<NDArrayValue<'ctx>>,
|
2024-08-28 16:33:03 +08:00
|
|
|
lhs: (Type, BasicValueEnum<'ctx>, bool),
|
|
|
|
rhs: (Type, BasicValueEnum<'ctx>, bool),
|
2024-03-13 11:16:23 +08:00
|
|
|
value_fn: ValueFn,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
ValueFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-13 11:16:23 +08:00
|
|
|
{
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
2024-08-28 16:33:03 +08:00
|
|
|
let (lhs_ty, lhs_val, lhs_scalar) = lhs;
|
|
|
|
let (rhs_ty, rhs_val, rhs_scalar) = rhs;
|
2024-03-13 11:16:23 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
assert!(
|
|
|
|
!(lhs_scalar && rhs_scalar),
|
|
|
|
"One of the operands must be a ndarray instance: `{}`, `{}`",
|
|
|
|
lhs_val.get_type(),
|
|
|
|
rhs_val.get_type()
|
|
|
|
);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
|
|
|
let ndarray = res.unwrap_or_else(|| {
|
|
|
|
if lhs_scalar && rhs_scalar {
|
2024-08-28 16:33:03 +08:00
|
|
|
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty);
|
|
|
|
let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype);
|
|
|
|
let lhs_val = NDArrayValue::from_pointer_value(
|
|
|
|
lhs_val.into_pointer_value(),
|
|
|
|
llvm_lhs_elem_ty,
|
|
|
|
llvm_usize,
|
|
|
|
None,
|
|
|
|
);
|
|
|
|
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty);
|
|
|
|
let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype);
|
|
|
|
let rhs_val = NDArrayValue::from_pointer_value(
|
|
|
|
rhs_val.into_pointer_value(),
|
|
|
|
llvm_rhs_elem_ty,
|
|
|
|
llvm_usize,
|
|
|
|
None,
|
|
|
|
);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
|
|
|
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
|
|
|
|
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&ndarray_dims,
|
2024-06-12 14:45:03 +08:00
|
|
|
|generator, ctx, v| Ok(v.size(ctx, generator)),
|
|
|
|
|generator, ctx, v, idx| unsafe {
|
|
|
|
Ok(v.get_typed_unchecked(ctx, generator, &idx, None))
|
2024-03-13 11:16:23 +08:00
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.unwrap()
|
2024-03-13 11:16:23 +08:00
|
|
|
} else {
|
2024-08-28 16:33:03 +08:00
|
|
|
let dtype = arraylike_flatten_element_type(
|
|
|
|
&mut ctx.unifier,
|
|
|
|
if lhs_scalar { rhs_ty } else { lhs_ty },
|
|
|
|
);
|
|
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
|
2024-11-01 15:17:00 +08:00
|
|
|
let ndarray = NDArrayValue::from_pointer_value(
|
2024-03-13 11:16:23 +08:00
|
|
|
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
|
2024-08-28 16:33:03 +08:00
|
|
|
llvm_elem_ty,
|
2024-03-13 11:16:23 +08:00
|
|
|
llvm_usize,
|
|
|
|
None,
|
|
|
|
);
|
|
|
|
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&ndarray,
|
2024-06-12 14:45:03 +08:00
|
|
|
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
|
|
|
|generator, ctx, v, idx| unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
2024-03-13 11:16:23 +08:00
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.unwrap()
|
2024-03-13 11:16:23 +08:00
|
|
|
}
|
|
|
|
});
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| {
|
|
|
|
value_fn(generator, ctx, elems)
|
|
|
|
})?;
|
2024-03-13 11:16:23 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-04-19 19:00:07 +08:00
|
|
|
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
2024-08-21 11:10:52 +08:00
|
|
|
/// written to a new `ndarray`.
|
2024-04-19 19:00:07 +08:00
|
|
|
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
res: Option<NDArrayValue<'ctx>>,
|
|
|
|
lhs: NDArrayValue<'ctx>,
|
|
|
|
rhs: NDArrayValue<'ctx>,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
if cfg!(debug_assertions) {
|
|
|
|
let lhs_ndims = lhs.load_ndims(ctx);
|
|
|
|
let rhs_ndims = rhs.load_ndims(ctx);
|
|
|
|
|
|
|
|
// lhs.ndims == 2
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder
|
|
|
|
.build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "")
|
|
|
|
.unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
// rhs.ndims == 2
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder
|
|
|
|
.build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "")
|
|
|
|
.unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
if let Some(res) = res {
|
|
|
|
let res_ndims = res.load_ndims(ctx);
|
|
|
|
let res_dim0 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
let res_dim1 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
res.shape().get_typed_unchecked(
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(1, false),
|
|
|
|
None,
|
|
|
|
)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
let lhs_dim0 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
let rhs_dim1 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
rhs.shape().get_typed_unchecked(
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(1, false),
|
|
|
|
None,
|
|
|
|
)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
// res.ndims == 2
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::EQ,
|
|
|
|
res_ndims,
|
|
|
|
llvm_usize.const_int(2, false),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
// res.dims[0] == lhs.dims[0]
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
// res.dims[1] == rhs.dims[0]
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
|
|
let lhs_dim1 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
let rhs_dim0 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
// lhs.dims[1] == rhs.dims[0]
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
2024-06-06 12:16:09 +08:00
|
|
|
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
|
2024-04-19 19:00:07 +08:00
|
|
|
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
|
|
|
|
} else {
|
|
|
|
lhs
|
|
|
|
};
|
|
|
|
|
|
|
|
let ndarray = res.unwrap_or_else(|| {
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&(lhs, rhs),
|
2024-06-12 14:45:03 +08:00
|
|
|
|_, _, _| Ok(llvm_usize.const_int(2, false)),
|
2024-04-19 19:00:07 +08:00
|
|
|
|generator, ctx, (lhs, rhs), idx| {
|
|
|
|
gen_if_else_expr_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
|_, ctx| {
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(ctx
|
|
|
|
.builder
|
|
|
|
.build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "")
|
|
|
|
.unwrap())
|
2024-04-19 19:00:07 +08:00
|
|
|
},
|
|
|
|
|generator, ctx| {
|
|
|
|
Ok(Some(unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
lhs.shape().get_typed_unchecked(
|
2024-04-19 19:00:07 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_zero(),
|
|
|
|
None,
|
|
|
|
)
|
|
|
|
}))
|
|
|
|
},
|
|
|
|
|generator, ctx| {
|
|
|
|
Ok(Some(unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
rhs.shape().get_typed_unchecked(
|
2024-04-19 19:00:07 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(1, false),
|
|
|
|
None,
|
|
|
|
)
|
|
|
|
}))
|
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
|
2024-04-19 19:00:07 +08:00
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.unwrap()
|
2024-04-19 19:00:07 +08:00
|
|
|
});
|
|
|
|
|
|
|
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| {
|
|
|
|
llvm_intrinsics::call_expect(
|
|
|
|
ctx,
|
|
|
|
idx.size(ctx, generator).get_type().const_int(2, false),
|
|
|
|
idx.size(ctx, generator),
|
|
|
|
None,
|
|
|
|
);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let common_dim = {
|
|
|
|
let lhs_idx1 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
lhs.shape().get_typed_unchecked(
|
2024-04-19 19:00:07 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(1, false),
|
|
|
|
None,
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
};
|
|
|
|
let rhs_idx0 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
|
|
|
|
};
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let idx0 = unsafe {
|
|
|
|
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
|
|
|
|
};
|
|
|
|
let idx1 = unsafe {
|
|
|
|
let idx1 =
|
|
|
|
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
|
|
|
|
};
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
|
|
|
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
|
|
|
|
ctx.builder.build_store(result_addr, result_identity).unwrap();
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-06-12 14:45:03 +08:00
|
|
|
llvm_i32.const_zero(),
|
|
|
|
(common_dim, false),
|
2024-07-02 19:05:00 +08:00
|
|
|
|generator, ctx, _, i| {
|
2024-06-12 14:45:03 +08:00
|
|
|
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
|
|
|
|
|
|
|
|
let ab_idx = generator.gen_array_var_alloc(
|
|
|
|
ctx,
|
|
|
|
llvm_i32.into(),
|
|
|
|
llvm_usize.const_int(2, false),
|
|
|
|
None,
|
|
|
|
)?;
|
|
|
|
|
|
|
|
let a = unsafe {
|
|
|
|
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
|
|
|
|
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
|
|
|
|
|
|
|
|
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
|
|
|
};
|
|
|
|
let b = unsafe {
|
|
|
|
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
|
|
|
|
ab_idx.set_unchecked(
|
2024-04-19 19:00:07 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
&llvm_usize.const_int(1, false),
|
|
|
|
idx1.into(),
|
|
|
|
);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
|
|
|
};
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let a_mul_b = gen_binop_expr_with_values(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
(&Some(elem_ty), a),
|
2024-06-27 13:01:26 +08:00
|
|
|
Binop::normal(Operator::Mult),
|
2024-06-12 14:45:03 +08:00
|
|
|
(&Some(elem_ty), b),
|
|
|
|
ctx.current_loc,
|
|
|
|
)?
|
|
|
|
.unwrap()
|
|
|
|
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
|
|
|
|
|
|
|
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
|
|
|
let result = gen_binop_expr_with_values(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
(&Some(elem_ty), result),
|
2024-06-27 13:01:26 +08:00
|
|
|
Binop::normal(Operator::Add),
|
2024-06-12 14:45:03 +08:00
|
|
|
(&Some(elem_ty), a_mul_b),
|
|
|
|
ctx.current_loc,
|
|
|
|
)?
|
|
|
|
.unwrap()
|
|
|
|
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
|
|
|
ctx.builder.build_store(result_addr, result).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
)?;
|
|
|
|
|
|
|
|
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
|
|
|
Ok(result)
|
|
|
|
})?;
|
2024-04-19 19:00:07 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-03-11 14:47:01 +08:00
|
|
|
/// Generates LLVM IR for `ndarray.empty`.
|
|
|
|
pub fn gen_ndarray_empty<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert_eq!(args.len(), 1);
|
|
|
|
|
|
|
|
let shape_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-25 15:35:02 +08:00
|
|
|
call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg)
|
|
|
|
.map(NDArrayValue::into)
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.zeros`.
|
|
|
|
pub fn gen_ndarray_zeros<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert_eq!(args.len(), 1);
|
|
|
|
|
|
|
|
let shape_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-25 15:35:02 +08:00
|
|
|
call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg)
|
|
|
|
.map(NDArrayValue::into)
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.ones`.
|
|
|
|
pub fn gen_ndarray_ones<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert_eq!(args.len(), 1);
|
|
|
|
|
|
|
|
let shape_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-25 15:35:02 +08:00
|
|
|
call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg)
|
|
|
|
.map(NDArrayValue::into)
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.full`.
|
|
|
|
pub fn gen_ndarray_full<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert_eq!(args.len(), 2);
|
|
|
|
|
|
|
|
let shape_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
let fill_value_ty = fun.0.args[1].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let fill_value_arg =
|
|
|
|
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-25 15:35:02 +08:00
|
|
|
call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg)
|
|
|
|
.map(NDArrayValue::into)
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
2024-06-11 15:29:32 +08:00
|
|
|
pub fn gen_ndarray_array<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert!(matches!(args.len(), 1..=3));
|
|
|
|
|
|
|
|
let obj_ty = fun.0.args[0].ty;
|
|
|
|
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
|
2024-06-12 15:01:01 +08:00
|
|
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
2024-06-11 15:29:32 +08:00
|
|
|
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
|
|
|
|
}
|
|
|
|
|
2024-07-02 11:05:05 +08:00
|
|
|
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
|
|
|
let mut ty = *params.iter().next().unwrap().1;
|
|
|
|
while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty)
|
|
|
|
{
|
|
|
|
if *obj_id != PrimDef::List.id() {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
|
|
|
ty = *params.iter().next().unwrap().1;
|
2024-06-11 15:29:32 +08:00
|
|
|
}
|
|
|
|
ty
|
2024-06-12 14:45:03 +08:00
|
|
|
}
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
_ => obj_ty,
|
|
|
|
};
|
2024-06-12 14:45:03 +08:00
|
|
|
let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?;
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
let copy_arg = if let Some(arg) =
|
2024-06-12 14:45:03 +08:00
|
|
|
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
|
|
|
{
|
2024-06-11 15:29:32 +08:00
|
|
|
let copy_ty = fun.0.args[1].ty;
|
|
|
|
arg.1.clone().to_basic_value_enum(context, generator, copy_ty)?
|
|
|
|
} else {
|
|
|
|
context.gen_symbol_val(
|
|
|
|
generator,
|
|
|
|
fun.0.args[1].default_value.as_ref().unwrap(),
|
|
|
|
fun.0.args[1].ty,
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
2024-06-11 15:29:32 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
let ndmin_arg = if let Some(arg) =
|
2024-06-12 14:45:03 +08:00
|
|
|
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
|
|
|
|
{
|
2024-06-11 15:29:32 +08:00
|
|
|
let ndmin_ty = fun.0.args[2].ty;
|
|
|
|
arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)?
|
|
|
|
} else {
|
|
|
|
context.gen_symbol_val(
|
|
|
|
generator,
|
|
|
|
fun.0.args[2].default_value.as_ref().unwrap(),
|
|
|
|
fun.0.args[2].ty,
|
|
|
|
)
|
2024-06-12 14:45:03 +08:00
|
|
|
};
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
call_ndarray_array_impl(
|
|
|
|
generator,
|
|
|
|
context,
|
|
|
|
obj_elem_ty,
|
|
|
|
obj_arg,
|
|
|
|
copy_arg.into_int_value(),
|
|
|
|
ndmin_arg.into_int_value(),
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.map(NDArrayValue::into)
|
2024-06-11 15:29:32 +08:00
|
|
|
}
|
|
|
|
|
2024-03-11 14:47:01 +08:00
|
|
|
/// Generates LLVM IR for `ndarray.eye`.
|
|
|
|
pub fn gen_ndarray_eye<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert!(matches!(args.len(), 1..=3));
|
|
|
|
|
|
|
|
let nrows_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let nrows_arg = args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
let ncols_ty = fun.0.args[1].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let ncols_arg = if let Some(arg) =
|
|
|
|
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
|
|
|
{
|
2024-04-01 16:22:40 +08:00
|
|
|
arg.1.clone().to_basic_value_enum(context, generator, ncols_ty)
|
|
|
|
} else {
|
|
|
|
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
|
|
|
|
}?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
let offset_ty = fun.0.args[2].ty;
|
2024-04-01 16:22:40 +08:00
|
|
|
let offset_arg = if let Some(arg) =
|
2024-06-12 14:45:03 +08:00
|
|
|
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
|
|
|
|
{
|
|
|
|
arg.1.clone().to_basic_value_enum(context, generator, offset_ty)
|
2024-04-01 16:22:40 +08:00
|
|
|
} else {
|
|
|
|
Ok(context.gen_symbol_val(
|
|
|
|
generator,
|
|
|
|
fun.0.args[2].default_value.as_ref().unwrap(),
|
2024-06-12 14:45:03 +08:00
|
|
|
offset_ty,
|
|
|
|
))
|
2024-04-01 16:22:40 +08:00
|
|
|
}?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
call_ndarray_eye_impl(
|
|
|
|
generator,
|
|
|
|
context,
|
|
|
|
context.primitives.float,
|
|
|
|
nrows_arg.into_int_value(),
|
|
|
|
ncols_arg.into_int_value(),
|
|
|
|
offset_arg.into_int_value(),
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.map(NDArrayValue::into)
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.identity`.
|
|
|
|
pub fn gen_ndarray_identity<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert_eq!(args.len(), 1);
|
|
|
|
|
|
|
|
let llvm_usize = generator.get_size_type(context.ctx);
|
|
|
|
|
|
|
|
let n_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
call_ndarray_eye_impl(
|
|
|
|
generator,
|
|
|
|
context,
|
|
|
|
context.primitives.float,
|
|
|
|
n_arg.into_int_value(),
|
|
|
|
n_arg.into_int_value(),
|
|
|
|
llvm_usize.const_zero(),
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.map(NDArrayValue::into)
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.copy`.
|
|
|
|
pub fn gen_ndarray_copy<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
_fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_some());
|
|
|
|
assert!(args.is_empty());
|
|
|
|
|
|
|
|
let llvm_usize = generator.get_size_type(context.ctx);
|
|
|
|
|
|
|
|
let this_ty = obj.as_ref().unwrap().0;
|
2024-03-26 19:14:56 +08:00
|
|
|
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
|
2024-06-12 14:45:03 +08:00
|
|
|
let this_arg =
|
|
|
|
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-08-28 16:33:03 +08:00
|
|
|
let llvm_elem_ty = context.get_llvm_type(generator, this_elem_ty);
|
|
|
|
|
2024-03-11 14:47:01 +08:00
|
|
|
ndarray_copy_impl(
|
|
|
|
generator,
|
|
|
|
context,
|
|
|
|
this_elem_ty,
|
2024-08-28 16:33:03 +08:00
|
|
|
NDArrayValue::from_pointer_value(
|
|
|
|
this_arg.into_pointer_value(),
|
|
|
|
llvm_elem_ty,
|
|
|
|
llvm_usize,
|
|
|
|
None,
|
|
|
|
),
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.map(NDArrayValue::into)
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.fill`.
|
|
|
|
pub fn gen_ndarray_fill<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<(), String> {
|
|
|
|
assert!(obj.is_some());
|
|
|
|
assert_eq!(args.len(), 1);
|
|
|
|
|
|
|
|
let llvm_usize = generator.get_size_type(context.ctx);
|
|
|
|
|
|
|
|
let this_ty = obj.as_ref().unwrap().0;
|
2024-08-28 16:33:03 +08:00
|
|
|
let this_elem_ty = arraylike_flatten_element_type(&mut context.unifier, this_ty);
|
2024-06-12 14:45:03 +08:00
|
|
|
let this_arg = obj
|
|
|
|
.as_ref()
|
|
|
|
.unwrap()
|
|
|
|
.1
|
|
|
|
.clone()
|
2024-03-11 14:47:01 +08:00
|
|
|
.to_basic_value_enum(context, generator, this_ty)?
|
|
|
|
.into_pointer_value();
|
|
|
|
let value_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-08-28 16:33:03 +08:00
|
|
|
let llvm_elem_ty = context.get_llvm_type(generator, this_elem_ty);
|
|
|
|
|
2024-03-11 14:47:01 +08:00
|
|
|
ndarray_fill_flattened(
|
|
|
|
generator,
|
|
|
|
context,
|
2024-08-28 16:33:03 +08:00
|
|
|
NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None),
|
2024-03-11 14:47:01 +08:00
|
|
|
|generator, ctx, _| {
|
|
|
|
let value = if value_arg.is_pointer_value() {
|
|
|
|
let llvm_i1 = ctx.ctx.bool_type();
|
|
|
|
|
|
|
|
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
|
|
|
|
|
|
|
|
call_memcpy_generic(
|
|
|
|
ctx,
|
|
|
|
copy,
|
|
|
|
value_arg.into_pointer_value(),
|
|
|
|
value_arg.get_type().size_of().map(Into::into).unwrap(),
|
|
|
|
llvm_i1.const_zero(),
|
|
|
|
);
|
|
|
|
|
|
|
|
copy.into()
|
|
|
|
} else if value_arg.is_int_value() || value_arg.is_float_value() {
|
|
|
|
value_arg
|
|
|
|
} else {
|
2024-08-23 13:10:55 +08:00
|
|
|
codegen_unreachable!(ctx)
|
2024-03-11 14:47:01 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
Ok(value)
|
2024-06-12 14:45:03 +08:00
|
|
|
},
|
2024-03-11 14:47:01 +08:00
|
|
|
)?;
|
|
|
|
|
|
|
|
Ok(())
|
2024-06-12 14:45:03 +08:00
|
|
|
}
|
2024-07-31 13:16:42 +08:00
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.transpose`.
|
|
|
|
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
x1: (Type, BasicValueEnum<'ctx>),
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
|
|
const FN_NAME: &str = "ndarray_transpose";
|
|
|
|
let (x1_ty, x1) = x1;
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
|
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
2024-08-28 16:33:03 +08:00
|
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
|
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
2024-11-13 15:53:29 +08:00
|
|
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
2024-07-31 13:16:42 +08:00
|
|
|
|
|
|
|
// Dimensions are reversed in the transposed array
|
|
|
|
let out = create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&n1,
|
|
|
|
|_, ctx, n| Ok(n.load_ndims(ctx)),
|
|
|
|
|generator, ctx, n, idx| {
|
|
|
|
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
|
|
|
|
let new_idx = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
|
|
|
|
.unwrap();
|
2024-11-13 15:53:29 +08:00
|
|
|
unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) }
|
2024-07-31 13:16:42 +08:00
|
|
|
},
|
|
|
|
)
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
None,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(n_sz, false),
|
|
|
|
|generator, ctx, _, idx| {
|
|
|
|
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
|
|
|
|
|
|
|
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
|
|
|
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
|
|
|
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
|
|
|
|
ctx.builder.build_store(rem_idx, idx).unwrap();
|
|
|
|
|
|
|
|
// Incrementally calculate the new index in the transposed array
|
|
|
|
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
|
|
|
|
// The formula used for indexing is:
|
|
|
|
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
|
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
None,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(n1.load_ndims(ctx), false),
|
|
|
|
|generator, ctx, _, ndim| {
|
|
|
|
let ndim_rev =
|
|
|
|
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
|
|
|
|
let ndim_rev = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
|
|
|
|
.unwrap();
|
|
|
|
let dim = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None)
|
2024-07-31 13:16:42 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
let rem_idx_val =
|
|
|
|
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
|
|
|
|
let new_idx_val =
|
|
|
|
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
|
|
|
|
|
|
|
|
let add_component =
|
|
|
|
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
|
|
|
|
let rem_idx_val =
|
|
|
|
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
|
|
|
|
|
|
|
|
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
|
|
|
|
let new_idx_val =
|
|
|
|
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
|
|
|
|
|
|
|
|
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
|
|
|
|
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
)?;
|
|
|
|
|
|
|
|
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
|
|
|
|
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
)?;
|
|
|
|
|
|
|
|
Ok(out.as_base_value().into())
|
|
|
|
} else {
|
2024-08-23 13:10:55 +08:00
|
|
|
codegen_unreachable!(
|
|
|
|
ctx,
|
2024-07-31 13:16:42 +08:00
|
|
|
"{FN_NAME}() not supported for '{}'",
|
|
|
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
|
|
|
)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
|
|
|
|
///
|
|
|
|
/// * `x1` - `NDArray` to reshape.
|
|
|
|
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
|
2024-08-21 11:10:52 +08:00
|
|
|
/// Just like numpy, the `shape` argument can be:
|
2024-07-31 13:16:42 +08:00
|
|
|
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
|
|
|
|
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
|
|
|
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
2024-08-20 11:29:03 +08:00
|
|
|
///
|
|
|
|
/// Note that unlike other generating functions, one of the dimensions in the shape can be negative.
|
2024-07-31 13:16:42 +08:00
|
|
|
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
x1: (Type, BasicValueEnum<'ctx>),
|
|
|
|
shape: (Type, BasicValueEnum<'ctx>),
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
|
|
const FN_NAME: &str = "ndarray_reshape";
|
|
|
|
let (x1_ty, x1) = x1;
|
|
|
|
let (_, shape) = shape;
|
|
|
|
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
|
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
2024-08-28 16:33:03 +08:00
|
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
|
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
|
2024-11-13 15:53:29 +08:00
|
|
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
2024-07-31 13:16:42 +08:00
|
|
|
|
|
|
|
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
|
|
|
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
|
|
|
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
|
|
|
|
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
|
|
|
|
|
|
|
|
let out = match shape {
|
|
|
|
BasicValueEnum::PointerValue(shape_list_ptr)
|
2024-11-01 15:17:00 +08:00
|
|
|
if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
|
2024-07-31 13:16:42 +08:00
|
|
|
{
|
|
|
|
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
|
|
|
|
|
2024-11-01 15:17:00 +08:00
|
|
|
let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
|
2024-07-31 13:16:42 +08:00
|
|
|
// Check for -1 in dimensions
|
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
None,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(shape_list.load_size(ctx, None), false),
|
|
|
|
|generator, ctx, _, idx| {
|
|
|
|
let ele =
|
|
|
|
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
|
|
|
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
|
|
|
|
|
|
|
|
gen_if_else_expr_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
|_, ctx| {
|
|
|
|
Ok(ctx
|
|
|
|
.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::SLT,
|
|
|
|
ele,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap())
|
|
|
|
},
|
|
|
|
|_, ctx| -> Result<Option<IntValue>, String> {
|
|
|
|
let num_neg_value =
|
|
|
|
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
|
|
|
let num_neg_value = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_add(
|
|
|
|
num_neg_value,
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap();
|
|
|
|
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
|
|
|
|
Ok(None)
|
|
|
|
},
|
|
|
|
|_, ctx| {
|
|
|
|
let acc_value =
|
|
|
|
ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
|
|
|
let acc_value =
|
|
|
|
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
|
|
|
|
ctx.builder.build_store(acc, acc_value).unwrap();
|
|
|
|
Ok(None)
|
|
|
|
},
|
|
|
|
)?;
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
)?;
|
|
|
|
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
|
|
|
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
|
|
|
// Generate the output shape by filling -1 with `rem`
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&shape_list,
|
|
|
|
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|
|
|
|
|generator, ctx, shape_list, idx| {
|
|
|
|
let dim =
|
|
|
|
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
|
|
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
|
|
|
|
|
|
|
Ok(gen_if_else_expr_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
|_, ctx| {
|
|
|
|
Ok(ctx
|
|
|
|
.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::SLT,
|
|
|
|
dim,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap())
|
|
|
|
},
|
|
|
|
|_, _| Ok(Some(rem)),
|
|
|
|
|_, _| Ok(Some(dim)),
|
|
|
|
)?
|
|
|
|
.unwrap()
|
|
|
|
.into_int_value())
|
|
|
|
},
|
|
|
|
)
|
|
|
|
}
|
|
|
|
BasicValueEnum::StructValue(shape_tuple) => {
|
|
|
|
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
|
|
|
|
|
|
|
let ndims = shape_tuple.get_type().count_fields();
|
|
|
|
// Check for -1 in dims
|
|
|
|
for dim_i in 0..ndims {
|
|
|
|
let dim = ctx
|
|
|
|
.builder
|
|
|
|
.build_extract_value(shape_tuple, dim_i, "")
|
|
|
|
.unwrap()
|
|
|
|
.into_int_value();
|
|
|
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
|
|
|
|
|
|
|
gen_if_else_expr_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
|_, ctx| {
|
|
|
|
Ok(ctx
|
|
|
|
.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::SLT,
|
|
|
|
dim,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap())
|
|
|
|
},
|
|
|
|
|_, ctx| -> Result<Option<IntValue>, String> {
|
|
|
|
let num_negs =
|
|
|
|
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
|
|
|
let num_negs = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
|
|
|
|
.unwrap();
|
|
|
|
ctx.builder.build_store(num_neg, num_negs).unwrap();
|
|
|
|
Ok(None)
|
|
|
|
},
|
|
|
|
|_, ctx| {
|
|
|
|
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
|
|
|
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
|
|
|
|
ctx.builder.build_store(acc, acc_val).unwrap();
|
|
|
|
Ok(None)
|
|
|
|
},
|
|
|
|
)?;
|
|
|
|
}
|
|
|
|
|
|
|
|
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
|
|
|
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
|
|
|
let mut shape = Vec::with_capacity(ndims as usize);
|
|
|
|
|
|
|
|
// Reconstruct shape filling negatives with rem
|
|
|
|
for dim_i in 0..ndims {
|
|
|
|
let dim = ctx
|
|
|
|
.builder
|
|
|
|
.build_extract_value(shape_tuple, dim_i, "")
|
|
|
|
.unwrap()
|
|
|
|
.into_int_value();
|
|
|
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
|
|
|
|
|
|
|
let dim = gen_if_else_expr_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
|_, ctx| {
|
|
|
|
Ok(ctx
|
|
|
|
.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::SLT,
|
|
|
|
dim,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap())
|
|
|
|
},
|
|
|
|
|_, _| Ok(Some(rem)),
|
|
|
|
|_, _| Ok(Some(dim)),
|
|
|
|
)?
|
|
|
|
.unwrap()
|
|
|
|
.into_int_value();
|
|
|
|
shape.push(dim);
|
|
|
|
}
|
|
|
|
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
|
|
|
}
|
|
|
|
BasicValueEnum::IntValue(shape_int) => {
|
|
|
|
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
|
|
|
let shape_int = gen_if_else_expr_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
|_, ctx| {
|
|
|
|
Ok(ctx
|
|
|
|
.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::SLT,
|
|
|
|
shape_int,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap())
|
|
|
|
},
|
|
|
|
|_, _| Ok(Some(n_sz)),
|
|
|
|
|_, ctx| {
|
|
|
|
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
|
|
|
|
},
|
|
|
|
)?
|
|
|
|
.unwrap()
|
|
|
|
.into_int_value();
|
|
|
|
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
|
|
|
}
|
2024-08-23 13:10:55 +08:00
|
|
|
_ => codegen_unreachable!(ctx),
|
2024-07-31 13:16:42 +08:00
|
|
|
}
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
// Only allow one dimension to be negative
|
|
|
|
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
|
|
|
ctx.builder
|
|
|
|
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
|
|
|
|
.unwrap(),
|
|
|
|
"0:ValueError",
|
|
|
|
"can only specify one unknown dimension",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
// The new shape must be compatible with the old shape
|
2024-11-13 15:53:29 +08:00
|
|
|
let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None));
|
2024-07-31 13:16:42 +08:00
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
|
|
|
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
|
|
|
"0:ValueError",
|
2024-07-31 15:53:51 +08:00
|
|
|
"cannot reshape array of size {0} into provided shape of size {1}",
|
2024-07-31 13:16:42 +08:00
|
|
|
[Some(n_sz), Some(out_sz), None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
None,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(n_sz, false),
|
|
|
|
|generator, ctx, _, idx| {
|
|
|
|
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
|
|
|
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
)?;
|
|
|
|
|
|
|
|
Ok(out.as_base_value().into())
|
|
|
|
} else {
|
2024-08-23 13:10:55 +08:00
|
|
|
codegen_unreachable!(
|
|
|
|
ctx,
|
2024-07-31 13:16:42 +08:00
|
|
|
"{FN_NAME}() not supported for '{}'",
|
|
|
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
|
|
|
)
|
|
|
|
}
|
|
|
|
}
|
2024-07-31 15:53:51 +08:00
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.dot`.
|
|
|
|
/// Calculate inner product of two vectors or literals
|
|
|
|
/// For matrix multiplication use `np_matmul`
|
|
|
|
///
|
|
|
|
/// The input `NDArray` are flattened and treated as 1D
|
2024-07-31 18:02:54 +08:00
|
|
|
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
|
2024-07-31 15:53:51 +08:00
|
|
|
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
x1: (Type, BasicValueEnum<'ctx>),
|
|
|
|
x2: (Type, BasicValueEnum<'ctx>),
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
|
|
const FN_NAME: &str = "ndarray_dot";
|
|
|
|
let (x1_ty, x1) = x1;
|
2024-08-28 16:33:03 +08:00
|
|
|
let (x2_ty, x2) = x2;
|
2024-07-31 15:53:51 +08:00
|
|
|
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
match (x1, x2) {
|
|
|
|
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
2024-08-28 16:33:03 +08:00
|
|
|
let n1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
|
|
|
|
let n2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty);
|
|
|
|
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
|
|
|
|
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
|
|
|
|
|
|
|
|
let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None);
|
|
|
|
let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None);
|
2024-07-31 15:53:51 +08:00
|
|
|
|
2024-11-13 15:53:29 +08:00
|
|
|
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
|
|
|
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
2024-07-31 15:53:51 +08:00
|
|
|
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
|
|
|
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
|
|
|
|
"0:ValueError",
|
|
|
|
"shapes ({0}), ({1}) not aligned",
|
|
|
|
[Some(n1_sz), Some(n2_sz), None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
let identity =
|
|
|
|
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
|
|
|
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
|
|
|
|
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
|
|
|
|
|
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
None,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(n1_sz, false),
|
|
|
|
|generator, ctx, _, idx| {
|
|
|
|
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
|
|
|
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
|
|
|
|
|
|
|
|
let product = match elem1 {
|
|
|
|
BasicValueEnum::IntValue(e1) => ctx
|
|
|
|
.builder
|
|
|
|
.build_int_mul(e1, elem2.into_int_value(), "")
|
|
|
|
.unwrap()
|
|
|
|
.as_basic_value_enum(),
|
|
|
|
BasicValueEnum::FloatValue(e1) => ctx
|
|
|
|
.builder
|
|
|
|
.build_float_mul(e1, elem2.into_float_value(), "")
|
|
|
|
.unwrap()
|
|
|
|
.as_basic_value_enum(),
|
2024-08-28 16:33:03 +08:00
|
|
|
_ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()),
|
2024-07-31 15:53:51 +08:00
|
|
|
};
|
|
|
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
|
|
|
let acc_val = match acc_val {
|
|
|
|
BasicValueEnum::IntValue(e1) => ctx
|
|
|
|
.builder
|
|
|
|
.build_int_add(e1, product.into_int_value(), "")
|
|
|
|
.unwrap()
|
|
|
|
.as_basic_value_enum(),
|
|
|
|
BasicValueEnum::FloatValue(e1) => ctx
|
|
|
|
.builder
|
|
|
|
.build_float_add(e1, product.into_float_value(), "")
|
|
|
|
.unwrap()
|
|
|
|
.as_basic_value_enum(),
|
2024-08-28 16:33:03 +08:00
|
|
|
_ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()),
|
2024-07-31 15:53:51 +08:00
|
|
|
};
|
|
|
|
ctx.builder.build_store(acc, acc_val).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
)?;
|
|
|
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
|
|
|
Ok(acc_val)
|
|
|
|
}
|
|
|
|
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
|
|
|
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
|
|
|
}
|
|
|
|
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
|
|
|
|
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
|
|
|
}
|
2024-08-23 13:10:55 +08:00
|
|
|
_ => codegen_unreachable!(
|
|
|
|
ctx,
|
2024-07-31 15:53:51 +08:00
|
|
|
"{FN_NAME}() not supported for '{}'",
|
|
|
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
|
|
|
),
|
|
|
|
}
|
|
|
|
}
|