2024-10-03 12:37:56 +08:00
|
|
|
use inkwell::{
|
[core] codegen/ndarray: Reimplement np_array()
Based on 8f0084ac: core/ndstrides: implement np_array()
It also checks for inconsistent dimensions if the input is a list.
e.g., rejecting `[[1.0, 2.0], [3.0]]`.
However, currently only `np_array(<input>, copy=False)` and `np_array
(<input>, copy=True)` are supported. In NumPy, copy could be false,
true, or None. Right now, NAC3's `np_array(<input>, copy=False)` behaves
like NumPy's `np.array(<input>, copy=None)`.
2024-08-20 14:51:40 +08:00
|
|
|
types::BasicType,
|
2024-10-03 12:37:56 +08:00
|
|
|
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
[core] codegen/ndarray: Reimplement np_array()
Based on 8f0084ac: core/ndstrides: implement np_array()
It also checks for inconsistent dimensions if the input is a list.
e.g., rejecting `[[1.0, 2.0], [3.0]]`.
However, currently only `np_array(<input>, copy=False)` and `np_array
(<input>, copy=True)` are supported. In NumPy, copy could be false,
true, or None. Right now, NAC3's `np_array(<input>, copy=False)` behaves
like NumPy's `np.array(<input>, copy=None)`.
2024-08-20 14:51:40 +08:00
|
|
|
IntPredicate, OptimizationLevel,
|
2024-10-03 12:37:56 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
use nac3parser::ast::{Operator, StrRef};
|
|
|
|
|
2024-10-17 15:57:33 +08:00
|
|
|
use super::{
|
|
|
|
expr::gen_binop_expr_with_values,
|
|
|
|
irrt::{
|
2024-11-22 16:38:57 +08:00
|
|
|
calculate_len_for_slice_range,
|
|
|
|
ndarray::{
|
|
|
|
call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index,
|
|
|
|
call_ndarray_calc_nd_indices, call_ndarray_calc_size,
|
|
|
|
},
|
2024-03-11 14:47:01 +08:00
|
|
|
},
|
2024-10-17 15:57:33 +08:00
|
|
|
llvm_intrinsics::{self, call_memcpy_generic},
|
|
|
|
macros::codegen_unreachable,
|
|
|
|
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
2024-12-17 18:01:12 +08:00
|
|
|
types::ndarray::{factory::ndarray_zero_value, NDArrayType},
|
2024-10-29 13:57:28 +08:00
|
|
|
values::{
|
2024-12-16 15:26:18 +08:00
|
|
|
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
|
2024-12-18 11:40:23 +08:00
|
|
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor,
|
|
|
|
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
|
|
|
UntypedArrayLikeMutator,
|
2024-10-29 13:57:28 +08:00
|
|
|
},
|
2024-10-17 15:57:33 +08:00
|
|
|
CodeGenContext, CodeGenerator,
|
|
|
|
};
|
|
|
|
use crate::{
|
2024-03-11 14:47:01 +08:00
|
|
|
symbol_resolver::ValueEnum,
|
[core] codegen/ndarray: Reimplement np_array()
Based on 8f0084ac: core/ndstrides: implement np_array()
It also checks for inconsistent dimensions if the input is a list.
e.g., rejecting `[[1.0, 2.0], [3.0]]`.
However, currently only `np_array(<input>, copy=False)` and `np_array
(<input>, copy=True)` are supported. In NumPy, copy could be false,
true, or None. Right now, NAC3's `np_array(<input>, copy=False)` behaves
like NumPy's `np.array(<input>, copy=None)`.
2024-08-20 14:51:40 +08:00
|
|
|
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId},
|
2024-06-27 13:01:26 +08:00
|
|
|
typecheck::{
|
|
|
|
magic_methods::Binop,
|
[core] codegen/ndarray: Reimplement np_array()
Based on 8f0084ac: core/ndstrides: implement np_array()
It also checks for inconsistent dimensions if the input is a list.
e.g., rejecting `[[1.0, 2.0], [3.0]]`.
However, currently only `np_array(<input>, copy=False)` and `np_array
(<input>, copy=True)` are supported. In NumPy, copy could be false,
true, or None. Right now, NAC3's `np_array(<input>, copy=False)` behaves
like NumPy's `np.array(<input>, copy=None)`.
2024-08-20 14:51:40 +08:00
|
|
|
typedef::{FunSignature, Type},
|
2024-06-27 13:01:26 +08:00
|
|
|
},
|
2024-03-11 14:47:01 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
/// Creates an `NDArray` instance from a dynamic shape.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// * `shape` - The shape of the `NDArray`.
|
|
|
|
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
|
|
|
|
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
|
2024-03-19 18:24:30 +08:00
|
|
|
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
|
|
elem_ty: Type,
|
|
|
|
shape: &V,
|
|
|
|
shape_len_fn: LenFn,
|
|
|
|
shape_data_fn: DataFn,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
|
|
|
|
DataFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
&V,
|
|
|
|
IntValue<'ctx>,
|
|
|
|
) -> Result<IntValue<'ctx>, String>,
|
2024-03-11 14:47:01 +08:00
|
|
|
{
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
2024-11-27 14:45:13 +08:00
|
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
// Assert that all dimensions are non-negative
|
2024-03-08 13:13:18 +08:00
|
|
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
|
|
|
gen_for_callback_incrementing(
|
2024-03-11 14:47:01 +08:00
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(shape_len, false),
|
2024-07-02 19:05:00 +08:00
|
|
|
|generator, ctx, _, i| {
|
2024-03-11 14:47:01 +08:00
|
|
|
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
|
|
|
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_dim_gez = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::SGE,
|
|
|
|
shape_dim,
|
|
|
|
shape_dim.get_type().const_zero(),
|
|
|
|
"",
|
|
|
|
)
|
2024-03-11 14:47:01 +08:00
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
|
|
|
shape_dim_gez,
|
|
|
|
"0:ValueError",
|
|
|
|
"negative dimensions not supported",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
2024-06-12 14:45:03 +08:00
|
|
|
|
2024-11-27 16:06:16 +08:00
|
|
|
// TODO: Disallow shape > u32_MAX
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_int(1, false),
|
2024-03-11 14:47:01 +08:00
|
|
|
)?;
|
|
|
|
|
|
|
|
let num_dims = shape_len_fn(generator, ctx, shape)?;
|
|
|
|
|
2024-11-27 14:45:13 +08:00
|
|
|
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None)
|
|
|
|
.construct_dyn_ndims(generator, ctx, num_dims, None);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
// Copy the dimension sizes from shape to ndarray.dims
|
2024-03-08 13:13:18 +08:00
|
|
|
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
|
|
|
gen_for_callback_incrementing(
|
2024-03-11 14:47:01 +08:00
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(shape_len, false),
|
2024-07-02 19:05:00 +08:00
|
|
|
|generator, ctx, _, i| {
|
2024-03-11 14:47:01 +08:00
|
|
|
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
|
|
|
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let ndarray_pdim =
|
2024-11-13 15:53:29 +08:00
|
|
|
unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) };
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_int(1, false),
|
2024-03-11 14:47:01 +08:00
|
|
|
)?;
|
|
|
|
|
2024-11-27 14:45:13 +08:00
|
|
|
unsafe { ndarray.create_data(generator, ctx) };
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
|
|
|
/// its input.
|
2024-03-19 18:24:30 +08:00
|
|
|
fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
|
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
|
|
ndarray: NDArrayValue<'ctx>,
|
|
|
|
value_fn: ValueFn,
|
|
|
|
) -> Result<(), String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
ValueFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
IntValue<'ctx>,
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-11 14:47:01 +08:00
|
|
|
{
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
2024-12-19 12:21:08 +08:00
|
|
|
let ndarray_num_elems = ndarray.size(generator, ctx);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-03-08 13:13:18 +08:00
|
|
|
gen_for_callback_incrementing(
|
2024-03-11 14:47:01 +08:00
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(ndarray_num_elems, false),
|
2024-07-02 19:05:00 +08:00
|
|
|
|generator, ctx, _, i| {
|
2024-06-12 14:45:03 +08:00
|
|
|
let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) };
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
let value = value_fn(generator, ctx, i)?;
|
|
|
|
ctx.builder.build_store(elem, value).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
2024-03-08 13:13:18 +08:00
|
|
|
llvm_usize.const_int(1, false),
|
2024-03-11 14:47:01 +08:00
|
|
|
)
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
|
|
|
|
/// as its input.
|
2024-04-29 23:21:57 +08:00
|
|
|
fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>(
|
2024-03-19 18:24:30 +08:00
|
|
|
generator: &mut G,
|
2024-04-29 23:21:57 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
2024-03-11 14:47:01 +08:00
|
|
|
ndarray: NDArrayValue<'ctx>,
|
|
|
|
value_fn: ValueFn,
|
|
|
|
) -> Result<(), String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
ValueFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
2024-12-17 18:03:03 +08:00
|
|
|
&TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>,
|
2024-06-12 14:45:03 +08:00
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-11 14:47:01 +08:00
|
|
|
{
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| {
|
|
|
|
let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray);
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
value_fn(generator, ctx, &indices)
|
|
|
|
})
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
2024-04-29 23:21:57 +08:00
|
|
|
fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>(
|
2024-03-27 17:06:58 +08:00
|
|
|
generator: &mut G,
|
2024-04-29 23:21:57 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
2024-03-27 17:06:58 +08:00
|
|
|
src: NDArrayValue<'ctx>,
|
|
|
|
dest: NDArrayValue<'ctx>,
|
|
|
|
map_fn: MapFn,
|
|
|
|
) -> Result<(), String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
MapFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
BasicValueEnum<'ctx>,
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-27 17:06:58 +08:00
|
|
|
{
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| {
|
|
|
|
let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) };
|
2024-03-27 17:06:58 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
map_fn(generator, ctx, elem)
|
|
|
|
})
|
2024-03-27 17:06:58 +08:00
|
|
|
}
|
|
|
|
|
2024-03-13 11:16:23 +08:00
|
|
|
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
|
|
|
|
/// the target `ndarray`.
|
|
|
|
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
target: NDArrayValue<'ctx>,
|
|
|
|
source: NDArrayValue<'ctx>,
|
|
|
|
) {
|
|
|
|
let array_ndims = source.load_ndims(ctx);
|
|
|
|
let broadcast_size = target.load_ndims(ctx);
|
|
|
|
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
|
|
|
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
|
|
|
|
"0:ValueError",
|
|
|
|
"operands cannot be broadcast together",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
|
|
|
|
/// with broadcast-compatible shapes.
|
2024-04-29 23:21:57 +08:00
|
|
|
fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>(
|
2024-03-13 11:16:23 +08:00
|
|
|
generator: &mut G,
|
2024-04-29 23:21:57 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
2024-03-13 11:16:23 +08:00
|
|
|
res: NDArrayValue<'ctx>,
|
2024-11-27 16:06:16 +08:00
|
|
|
(lhs_ty, lhs_val, lhs_scalar): (Type, BasicValueEnum<'ctx>, bool),
|
|
|
|
(rhs_ty, rhs_val, rhs_scalar): (Type, BasicValueEnum<'ctx>, bool),
|
2024-03-13 11:16:23 +08:00
|
|
|
value_fn: ValueFn,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
ValueFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-13 11:16:23 +08:00
|
|
|
{
|
2024-06-12 14:45:03 +08:00
|
|
|
assert!(
|
|
|
|
!(lhs_scalar && rhs_scalar),
|
|
|
|
"One of the operands must be a ndarray instance: `{}`, `{}`",
|
|
|
|
lhs_val.get_type(),
|
|
|
|
rhs_val.get_type()
|
|
|
|
);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
2024-12-19 12:21:08 +08:00
|
|
|
// Returns the element of an ndarray indexed by the given indices, performing int-promotion on
|
|
|
|
// `indices` where necessary.
|
|
|
|
//
|
|
|
|
// Required for compatibility with `NDArrayType::get_unchecked`.
|
|
|
|
let get_data_by_indices_compat =
|
|
|
|
|generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
ndarray: NDArrayValue<'ctx>,
|
|
|
|
indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| {
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
// Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT
|
|
|
|
let stackptr = llvm_intrinsics::call_stacksave(ctx, None);
|
|
|
|
let indices = if llvm_usize == ctx.ctx.i32_type() {
|
|
|
|
indices
|
|
|
|
} else {
|
|
|
|
let indices_usize = TypedArrayLikeAdapter::<G, IntValue<'ctx>>::from(
|
|
|
|
ArraySliceValue::from_ptr_val(
|
|
|
|
ctx.builder
|
|
|
|
.build_array_alloca(llvm_usize, indices.size(ctx, generator), "")
|
|
|
|
.unwrap(),
|
|
|
|
indices.size(ctx, generator),
|
|
|
|
None,
|
|
|
|
),
|
|
|
|
|_, _, val| val.into_int_value(),
|
|
|
|
|_, _, val| val.into(),
|
|
|
|
);
|
|
|
|
|
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
None,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(indices.size(ctx, generator), false),
|
|
|
|
|generator, ctx, _, i| {
|
|
|
|
let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) };
|
|
|
|
let idx = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_z_extend_or_bit_cast(idx, llvm_usize, "")
|
|
|
|
.unwrap();
|
|
|
|
unsafe {
|
|
|
|
indices_usize.set_typed_unchecked(ctx, generator, &i, idx);
|
|
|
|
}
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
)
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
indices_usize
|
|
|
|
};
|
|
|
|
|
|
|
|
let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) };
|
|
|
|
|
|
|
|
llvm_intrinsics::call_stackrestore(ctx, stackptr);
|
|
|
|
|
|
|
|
elem
|
|
|
|
};
|
|
|
|
|
2024-03-13 11:16:23 +08:00
|
|
|
// Assert that all ndarray operands are broadcastable to the target size
|
|
|
|
if !lhs_scalar {
|
2024-11-27 16:06:16 +08:00
|
|
|
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
|
|
|
.map_value(lhs_val.into_pointer_value(), None);
|
2024-03-13 11:16:23 +08:00
|
|
|
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
|
|
|
|
}
|
|
|
|
|
|
|
|
if !rhs_scalar {
|
2024-11-27 16:06:16 +08:00
|
|
|
let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
|
|
|
|
.map_value(rhs_val.into_pointer_value(), None);
|
2024-03-13 11:16:23 +08:00
|
|
|
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
|
|
|
|
}
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| {
|
|
|
|
let lhs_elem = if lhs_scalar {
|
|
|
|
lhs_val
|
|
|
|
} else {
|
2024-11-27 16:06:16 +08:00
|
|
|
let lhs = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
|
|
|
.map_value(lhs_val.into_pointer_value(), None);
|
2024-06-12 14:45:03 +08:00
|
|
|
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
2024-12-19 12:21:08 +08:00
|
|
|
get_data_by_indices_compat(generator, ctx, lhs, lhs_idx)
|
2024-06-12 14:45:03 +08:00
|
|
|
};
|
2024-03-13 11:16:23 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let rhs_elem = if rhs_scalar {
|
|
|
|
rhs_val
|
|
|
|
} else {
|
2024-11-27 16:06:16 +08:00
|
|
|
let rhs = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
|
|
|
|
.map_value(rhs_val.into_pointer_value(), None);
|
2024-06-12 14:45:03 +08:00
|
|
|
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
2024-12-19 12:21:08 +08:00
|
|
|
get_data_by_indices_compat(generator, ctx, rhs, rhs_idx)
|
2024-06-12 14:45:03 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
value_fn(generator, ctx, (lhs_elem, rhs_elem))
|
|
|
|
})?;
|
2024-03-13 11:16:23 +08:00
|
|
|
|
|
|
|
Ok(res)
|
|
|
|
}
|
|
|
|
|
2024-05-30 14:25:56 +08:00
|
|
|
/// Copies a slice of an [`NDArrayValue`] to another.
|
2024-03-11 14:47:01 +08:00
|
|
|
///
|
2024-11-27 16:06:16 +08:00
|
|
|
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
|
2024-08-21 11:10:52 +08:00
|
|
|
/// fields should be populated before calling this function.
|
2024-05-30 14:25:56 +08:00
|
|
|
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
2024-08-21 11:10:52 +08:00
|
|
|
/// dimensional slice in the destination array.
|
2024-05-30 14:25:56 +08:00
|
|
|
/// - `src_arr`: The [`NDArrayValue`] instance of the source array.
|
|
|
|
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
2024-08-21 11:10:52 +08:00
|
|
|
/// dimensional slice in the source array.
|
2024-05-30 14:25:56 +08:00
|
|
|
/// - `dim`: The index of the currently processing dimension.
|
|
|
|
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
2024-08-21 11:10:52 +08:00
|
|
|
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
|
2024-05-30 14:25:56 +08:00
|
|
|
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
2024-03-19 18:24:30 +08:00
|
|
|
generator: &mut G,
|
2024-03-11 14:47:01 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
2024-05-30 14:25:56 +08:00
|
|
|
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
|
|
|
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
|
|
|
dim: u64,
|
|
|
|
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
|
|
|
|
) -> Result<(), String> {
|
2024-03-11 14:47:01 +08:00
|
|
|
let llvm_i1 = ctx.ctx.bool_type();
|
2024-05-30 14:25:56 +08:00
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
2024-08-28 16:33:03 +08:00
|
|
|
assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type());
|
|
|
|
|
|
|
|
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
|
|
|
|
|
2024-05-30 14:25:56 +08:00
|
|
|
// If there are no (remaining) slice expressions, memcpy the entire dimension
|
|
|
|
if slices.is_empty() {
|
|
|
|
let stride = call_ndarray_calc_size(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-11-13 15:53:29 +08:00
|
|
|
&src_arr.shape(),
|
2024-05-30 14:25:56 +08:00
|
|
|
(Some(llvm_usize.const_int(dim, false)), None),
|
|
|
|
);
|
2024-07-22 01:46:50 +08:00
|
|
|
let stride =
|
|
|
|
ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap();
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap();
|
2024-05-30 14:25:56 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero());
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
return Ok(());
|
2024-05-30 14:25:56 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// The stride of elements in this dimension, i.e. the number of elements between arr[i] and
|
|
|
|
// arr[i + 1] in this dimension
|
|
|
|
let src_stride = call_ndarray_calc_size(
|
2024-03-11 14:47:01 +08:00
|
|
|
generator,
|
|
|
|
ctx,
|
2024-11-13 15:53:29 +08:00
|
|
|
&src_arr.shape(),
|
2024-05-30 14:25:56 +08:00
|
|
|
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
|
|
|
);
|
|
|
|
let dst_stride = call_ndarray_calc_size(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-11-13 15:53:29 +08:00
|
|
|
&dst_arr.shape(),
|
2024-05-30 14:25:56 +08:00
|
|
|
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
|
|
|
);
|
|
|
|
|
|
|
|
let (start, stop, step) = slices[0];
|
|
|
|
let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap();
|
|
|
|
let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap();
|
|
|
|
let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap();
|
|
|
|
|
|
|
|
let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap();
|
|
|
|
ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap();
|
|
|
|
|
|
|
|
gen_for_range_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-05-30 14:25:56 +08:00
|
|
|
false,
|
|
|
|
|_, _| Ok(start),
|
|
|
|
(|_, _| Ok(stop), true),
|
|
|
|
|_, _| Ok(step),
|
2024-07-25 15:54:39 +08:00
|
|
|
|generator, ctx, _, src_i| {
|
2024-05-30 14:25:56 +08:00
|
|
|
// Calculate the offset of the active slice
|
2024-06-12 14:45:03 +08:00
|
|
|
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
|
2024-08-28 16:33:03 +08:00
|
|
|
let src_data_offset = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_mul(
|
|
|
|
src_data_offset,
|
|
|
|
ctx.builder
|
2024-11-29 17:19:43 +08:00
|
|
|
.build_int_z_extend_or_bit_cast(sizeof_elem, src_data_offset.get_type(), "")
|
2024-08-28 16:33:03 +08:00
|
|
|
.unwrap(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap();
|
2024-06-12 14:45:03 +08:00
|
|
|
let dst_i =
|
|
|
|
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
|
|
|
let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap();
|
2024-08-28 16:33:03 +08:00
|
|
|
let dst_data_offset = ctx
|
|
|
|
.builder
|
|
|
|
.build_int_mul(
|
|
|
|
dst_data_offset,
|
|
|
|
ctx.builder
|
2024-11-29 17:19:43 +08:00
|
|
|
.build_int_z_extend_or_bit_cast(sizeof_elem, dst_data_offset.get_type(), "")
|
2024-08-28 16:33:03 +08:00
|
|
|
.unwrap(),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap();
|
2024-05-30 14:25:56 +08:00
|
|
|
|
|
|
|
let (src_ptr, dst_ptr) = unsafe {
|
|
|
|
(
|
|
|
|
ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(),
|
|
|
|
ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(),
|
|
|
|
)
|
|
|
|
};
|
|
|
|
|
|
|
|
ndarray_sliced_copyto_impl(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
(dst_arr, dst_ptr),
|
|
|
|
(src_arr, src_ptr),
|
|
|
|
dim + 1,
|
|
|
|
&slices[1..],
|
|
|
|
)?;
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let dst_i =
|
|
|
|
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
|
|
|
let dst_i_add1 =
|
|
|
|
ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap();
|
2024-05-30 14:25:56 +08:00
|
|
|
ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
2024-03-11 14:47:01 +08:00
|
|
|
},
|
|
|
|
)?;
|
|
|
|
|
2024-05-30 14:25:56 +08:00
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Copies a [`NDArrayValue`] using slices.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
2024-08-21 11:10:52 +08:00
|
|
|
/// this dimension. The `start`/`stop` values of each slice must be positive indices.
|
2024-05-30 14:25:56 +08:00
|
|
|
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
this: NDArrayValue<'ctx>,
|
|
|
|
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
2024-11-27 14:45:13 +08:00
|
|
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
2024-05-30 14:25:56 +08:00
|
|
|
|
2024-11-27 16:06:16 +08:00
|
|
|
let ndarray =
|
|
|
|
if slices.is_empty() {
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&this,
|
|
|
|
|_, ctx, shape| Ok(shape.load_ndims(ctx)),
|
|
|
|
|generator, ctx, shape, idx| unsafe {
|
|
|
|
Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
|
|
|
},
|
|
|
|
)?
|
|
|
|
} else {
|
|
|
|
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None)
|
|
|
|
.construct_dyn_ndims(generator, ctx, this.load_ndims(ctx), None);
|
|
|
|
|
|
|
|
// Populate the first slices.len() dimensions by computing the size of each dim slice
|
|
|
|
for (i, (start, stop, step)) in slices.iter().enumerate() {
|
|
|
|
// HACK: workaround calculate_len_for_slice_range requiring exclusive stop
|
|
|
|
let stop = ctx
|
|
|
|
.builder
|
|
|
|
.build_select(
|
|
|
|
ctx.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::SLT,
|
|
|
|
*step,
|
|
|
|
llvm_i32.const_zero(),
|
|
|
|
"is_neg",
|
|
|
|
)
|
|
|
|
.unwrap(),
|
|
|
|
ctx.builder
|
|
|
|
.build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one")
|
|
|
|
.unwrap(),
|
|
|
|
ctx.builder
|
|
|
|
.build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one")
|
|
|
|
.unwrap(),
|
|
|
|
"final_e",
|
|
|
|
)
|
|
|
|
.map(BasicValueEnum::into_int_value)
|
|
|
|
.unwrap();
|
2024-05-30 14:25:56 +08:00
|
|
|
|
2024-11-27 16:06:16 +08:00
|
|
|
let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step);
|
|
|
|
let slice_len =
|
|
|
|
ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();
|
2024-05-30 14:25:56 +08:00
|
|
|
|
|
|
|
unsafe {
|
2024-11-27 16:06:16 +08:00
|
|
|
ndarray.shape().set_typed_unchecked(
|
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(i as u64, false),
|
|
|
|
slice_len,
|
|
|
|
);
|
2024-05-30 14:25:56 +08:00
|
|
|
}
|
2024-11-27 16:06:16 +08:00
|
|
|
}
|
2024-05-30 14:25:56 +08:00
|
|
|
|
2024-11-27 16:06:16 +08:00
|
|
|
// Populate the rest by directly copying the dim size from the source array
|
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
None,
|
|
|
|
llvm_usize.const_int(slices.len() as u64, false),
|
|
|
|
(this.load_ndims(ctx), false),
|
|
|
|
|generator, ctx, _, idx| {
|
|
|
|
unsafe {
|
|
|
|
let shape = this.shape().get_typed_unchecked(ctx, generator, &idx, None);
|
|
|
|
ndarray.shape().set_typed_unchecked(ctx, generator, &idx, shape);
|
|
|
|
}
|
2024-05-30 14:25:56 +08:00
|
|
|
|
2024-11-27 16:06:16 +08:00
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
)
|
|
|
|
.unwrap();
|
|
|
|
|
2024-11-27 14:45:13 +08:00
|
|
|
unsafe { ndarray.create_data(generator, ctx) };
|
|
|
|
|
|
|
|
ndarray
|
2024-11-27 16:06:16 +08:00
|
|
|
};
|
2024-05-30 14:25:56 +08:00
|
|
|
|
|
|
|
ndarray_sliced_copyto_impl(
|
2024-03-11 14:47:01 +08:00
|
|
|
generator,
|
|
|
|
ctx,
|
2024-05-30 14:25:56 +08:00
|
|
|
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
|
|
|
(this, this.data().base_ptr(ctx, generator)),
|
|
|
|
0,
|
|
|
|
slices,
|
|
|
|
)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-05-30 14:25:56 +08:00
|
|
|
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
this: NDArrayValue<'ctx>,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
|
|
|
}
|
|
|
|
|
2024-04-29 23:21:57 +08:00
|
|
|
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(
|
2024-03-27 17:06:58 +08:00
|
|
|
generator: &mut G,
|
2024-04-29 23:21:57 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
2024-03-27 17:06:58 +08:00
|
|
|
elem_ty: Type,
|
|
|
|
res: Option<NDArrayValue<'ctx>>,
|
|
|
|
operand: NDArrayValue<'ctx>,
|
|
|
|
map_fn: MapFn,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
MapFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
BasicValueEnum<'ctx>,
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-27 17:06:58 +08:00
|
|
|
{
|
|
|
|
let res = res.unwrap_or_else(|| {
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&operand,
|
2024-06-12 14:45:03 +08:00
|
|
|
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
|
|
|
|generator, ctx, v, idx| unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
2024-03-27 17:06:58 +08:00
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.unwrap()
|
2024-03-27 17:06:58 +08:00
|
|
|
});
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| {
|
|
|
|
map_fn(generator, ctx, elem)
|
|
|
|
})?;
|
2024-03-27 17:06:58 +08:00
|
|
|
|
|
|
|
Ok(res)
|
|
|
|
}
|
|
|
|
|
2024-03-13 11:16:23 +08:00
|
|
|
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
|
|
|
|
///
|
2024-06-12 14:45:03 +08:00
|
|
|
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
|
|
|
|
/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple.
|
|
|
|
/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the
|
2024-03-13 11:16:23 +08:00
|
|
|
/// `value_fn` arguments tuple for all output elements.
|
|
|
|
///
|
|
|
|
/// The second element of the tuple indicates whether to treat the operand value as a `ndarray`
|
2024-06-12 14:45:03 +08:00
|
|
|
/// (which would be accessed by its broadcast index) or as a scalar value (which would be
|
2024-03-13 11:16:23 +08:00
|
|
|
/// broadcast to all elements).
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
2024-08-21 11:10:52 +08:00
|
|
|
/// written to a new `ndarray`.
|
2024-03-13 11:16:23 +08:00
|
|
|
/// * `value_fn` - Function mapping the two input elements into the result.
|
|
|
|
///
|
|
|
|
/// # Panic
|
|
|
|
///
|
|
|
|
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
|
2024-04-29 23:21:57 +08:00
|
|
|
pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>(
|
2024-03-13 11:16:23 +08:00
|
|
|
generator: &mut G,
|
2024-04-29 23:21:57 +08:00
|
|
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
2024-03-13 11:16:23 +08:00
|
|
|
elem_ty: Type,
|
|
|
|
res: Option<NDArrayValue<'ctx>>,
|
2024-08-28 16:33:03 +08:00
|
|
|
lhs: (Type, BasicValueEnum<'ctx>, bool),
|
|
|
|
rhs: (Type, BasicValueEnum<'ctx>, bool),
|
2024-03-13 11:16:23 +08:00
|
|
|
value_fn: ValueFn,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String>
|
2024-06-12 14:45:03 +08:00
|
|
|
where
|
|
|
|
G: CodeGenerator + ?Sized,
|
|
|
|
ValueFn: Fn(
|
|
|
|
&mut G,
|
|
|
|
&mut CodeGenContext<'ctx, 'a>,
|
|
|
|
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
2024-03-13 11:16:23 +08:00
|
|
|
{
|
2024-08-28 16:33:03 +08:00
|
|
|
let (lhs_ty, lhs_val, lhs_scalar) = lhs;
|
|
|
|
let (rhs_ty, rhs_val, rhs_scalar) = rhs;
|
2024-03-13 11:16:23 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
assert!(
|
|
|
|
!(lhs_scalar && rhs_scalar),
|
|
|
|
"One of the operands must be a ndarray instance: `{}`, `{}`",
|
|
|
|
lhs_val.get_type(),
|
|
|
|
rhs_val.get_type()
|
|
|
|
);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
|
|
|
let ndarray = res.unwrap_or_else(|| {
|
|
|
|
if lhs_scalar && rhs_scalar {
|
2024-11-27 16:06:16 +08:00
|
|
|
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
|
|
|
.map_value(lhs_val.into_pointer_value(), None);
|
|
|
|
let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
|
|
|
|
.map_value(rhs_val.into_pointer_value(), None);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
|
|
|
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
|
|
|
|
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&ndarray_dims,
|
2024-06-12 14:45:03 +08:00
|
|
|
|generator, ctx, v| Ok(v.size(ctx, generator)),
|
|
|
|
|generator, ctx, v, idx| unsafe {
|
|
|
|
Ok(v.get_typed_unchecked(ctx, generator, &idx, None))
|
2024-03-13 11:16:23 +08:00
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.unwrap()
|
2024-03-13 11:16:23 +08:00
|
|
|
} else {
|
2024-11-27 16:06:16 +08:00
|
|
|
let ndarray = NDArrayType::from_unifier_type(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-08-28 16:33:03 +08:00
|
|
|
if lhs_scalar { rhs_ty } else { lhs_ty },
|
2024-11-27 16:06:16 +08:00
|
|
|
)
|
|
|
|
.map_value(if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), None);
|
2024-03-13 11:16:23 +08:00
|
|
|
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&ndarray,
|
2024-06-12 14:45:03 +08:00
|
|
|
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
|
|
|
|generator, ctx, v, idx| unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
2024-03-13 11:16:23 +08:00
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.unwrap()
|
2024-03-13 11:16:23 +08:00
|
|
|
}
|
|
|
|
});
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| {
|
|
|
|
value_fn(generator, ctx, elems)
|
|
|
|
})?;
|
2024-03-13 11:16:23 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-04-19 19:00:07 +08:00
|
|
|
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
|
|
|
///
|
|
|
|
/// * `elem_ty` - The element type of the `NDArray`.
|
|
|
|
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
2024-08-21 11:10:52 +08:00
|
|
|
/// written to a new `ndarray`.
|
2024-04-19 19:00:07 +08:00
|
|
|
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
elem_ty: Type,
|
|
|
|
res: Option<NDArrayValue<'ctx>>,
|
|
|
|
lhs: NDArrayValue<'ctx>,
|
|
|
|
rhs: NDArrayValue<'ctx>,
|
|
|
|
) -> Result<NDArrayValue<'ctx>, String> {
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
if cfg!(debug_assertions) {
|
|
|
|
let lhs_ndims = lhs.load_ndims(ctx);
|
|
|
|
let rhs_ndims = rhs.load_ndims(ctx);
|
|
|
|
|
|
|
|
// lhs.ndims == 2
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder
|
|
|
|
.build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "")
|
|
|
|
.unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
// rhs.ndims == 2
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder
|
|
|
|
.build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "")
|
|
|
|
.unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
if let Some(res) = res {
|
|
|
|
let res_ndims = res.load_ndims(ctx);
|
|
|
|
let res_dim0 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
let res_dim1 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
res.shape().get_typed_unchecked(
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(1, false),
|
|
|
|
None,
|
|
|
|
)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
let lhs_dim0 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
let rhs_dim1 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
rhs.shape().get_typed_unchecked(
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(1, false),
|
|
|
|
None,
|
|
|
|
)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
// res.ndims == 2
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder
|
|
|
|
.build_int_compare(
|
|
|
|
IntPredicate::EQ,
|
|
|
|
res_ndims,
|
|
|
|
llvm_usize.const_int(2, false),
|
|
|
|
"",
|
|
|
|
)
|
|
|
|
.unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
// res.dims[0] == lhs.dims[0]
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
// res.dims[1] == rhs.dims[0]
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
|
|
let lhs_dim1 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
let rhs_dim0 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
// lhs.dims[1] == rhs.dims[0]
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(),
|
2024-04-19 19:00:07 +08:00
|
|
|
"0:ValueError",
|
|
|
|
"",
|
|
|
|
[None, None, None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
}
|
|
|
|
|
2024-06-06 12:16:09 +08:00
|
|
|
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
|
2024-04-19 19:00:07 +08:00
|
|
|
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
|
|
|
|
} else {
|
|
|
|
lhs
|
|
|
|
};
|
|
|
|
|
|
|
|
let ndarray = res.unwrap_or_else(|| {
|
|
|
|
create_ndarray_dyn_shape(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
elem_ty,
|
|
|
|
&(lhs, rhs),
|
2024-06-12 14:45:03 +08:00
|
|
|
|_, _, _| Ok(llvm_usize.const_int(2, false)),
|
2024-04-19 19:00:07 +08:00
|
|
|
|generator, ctx, (lhs, rhs), idx| {
|
|
|
|
gen_if_else_expr_callback(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
|_, ctx| {
|
2024-06-12 14:45:03 +08:00
|
|
|
Ok(ctx
|
|
|
|
.builder
|
|
|
|
.build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "")
|
|
|
|
.unwrap())
|
2024-04-19 19:00:07 +08:00
|
|
|
},
|
|
|
|
|generator, ctx| {
|
|
|
|
Ok(Some(unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
lhs.shape().get_typed_unchecked(
|
2024-04-19 19:00:07 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_zero(),
|
|
|
|
None,
|
|
|
|
)
|
|
|
|
}))
|
|
|
|
},
|
|
|
|
|generator, ctx| {
|
|
|
|
Ok(Some(unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
rhs.shape().get_typed_unchecked(
|
2024-04-19 19:00:07 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(1, false),
|
|
|
|
None,
|
|
|
|
)
|
|
|
|
}))
|
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
|
2024-04-19 19:00:07 +08:00
|
|
|
},
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
.unwrap()
|
2024-04-19 19:00:07 +08:00
|
|
|
});
|
|
|
|
|
|
|
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| {
|
|
|
|
llvm_intrinsics::call_expect(
|
|
|
|
ctx,
|
|
|
|
idx.size(ctx, generator).get_type().const_int(2, false),
|
|
|
|
idx.size(ctx, generator),
|
|
|
|
None,
|
|
|
|
);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let common_dim = {
|
|
|
|
let lhs_idx1 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
lhs.shape().get_typed_unchecked(
|
2024-04-19 19:00:07 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
|
|
|
&llvm_usize.const_int(1, false),
|
|
|
|
None,
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
|
|
|
};
|
|
|
|
let rhs_idx0 = unsafe {
|
2024-11-13 15:53:29 +08:00
|
|
|
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
2024-04-19 19:00:07 +08:00
|
|
|
};
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-12-19 12:21:08 +08:00
|
|
|
ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap()
|
2024-06-12 14:45:03 +08:00
|
|
|
};
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let idx0 = unsafe {
|
|
|
|
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-12-19 12:21:08 +08:00
|
|
|
ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap()
|
2024-06-12 14:45:03 +08:00
|
|
|
};
|
|
|
|
let idx1 = unsafe {
|
|
|
|
let idx1 =
|
|
|
|
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-12-19 12:21:08 +08:00
|
|
|
ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap()
|
2024-06-12 14:45:03 +08:00
|
|
|
};
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
|
|
|
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
|
|
|
|
ctx.builder.build_store(result_addr, result_identity).unwrap();
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
2024-07-25 15:54:39 +08:00
|
|
|
None,
|
2024-12-19 12:21:08 +08:00
|
|
|
llvm_usize.const_zero(),
|
2024-06-12 14:45:03 +08:00
|
|
|
(common_dim, false),
|
2024-07-02 19:05:00 +08:00
|
|
|
|generator, ctx, _, i| {
|
2024-06-12 14:45:03 +08:00
|
|
|
let ab_idx = generator.gen_array_var_alloc(
|
|
|
|
ctx,
|
2024-12-19 12:21:08 +08:00
|
|
|
llvm_usize.into(),
|
2024-06-12 14:45:03 +08:00
|
|
|
llvm_usize.const_int(2, false),
|
|
|
|
None,
|
|
|
|
)?;
|
|
|
|
|
|
|
|
let a = unsafe {
|
|
|
|
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
|
|
|
|
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
|
|
|
|
|
|
|
|
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
|
|
|
};
|
|
|
|
let b = unsafe {
|
|
|
|
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
|
|
|
|
ab_idx.set_unchecked(
|
2024-04-19 19:00:07 +08:00
|
|
|
ctx,
|
|
|
|
generator,
|
2024-06-12 14:45:03 +08:00
|
|
|
&llvm_usize.const_int(1, false),
|
|
|
|
idx1.into(),
|
|
|
|
);
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
|
|
|
};
|
2024-04-19 19:00:07 +08:00
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let a_mul_b = gen_binop_expr_with_values(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
(&Some(elem_ty), a),
|
2024-06-27 13:01:26 +08:00
|
|
|
Binop::normal(Operator::Mult),
|
2024-06-12 14:45:03 +08:00
|
|
|
(&Some(elem_ty), b),
|
|
|
|
ctx.current_loc,
|
|
|
|
)?
|
|
|
|
.unwrap()
|
|
|
|
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
|
|
|
|
|
|
|
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
|
|
|
let result = gen_binop_expr_with_values(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
(&Some(elem_ty), result),
|
2024-06-27 13:01:26 +08:00
|
|
|
Binop::normal(Operator::Add),
|
2024-06-12 14:45:03 +08:00
|
|
|
(&Some(elem_ty), a_mul_b),
|
|
|
|
ctx.current_loc,
|
|
|
|
)?
|
|
|
|
.unwrap()
|
|
|
|
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
|
|
|
ctx.builder.build_store(result_addr, result).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
)?;
|
|
|
|
|
|
|
|
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
|
|
|
Ok(result)
|
|
|
|
})?;
|
2024-04-19 19:00:07 +08:00
|
|
|
|
|
|
|
Ok(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-03-11 14:47:01 +08:00
|
|
|
/// Generates LLVM IR for `ndarray.empty`.
|
|
|
|
pub fn gen_ndarray_empty<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert_eq!(args.len(), 1);
|
|
|
|
|
|
|
|
let shape_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-12-16 15:26:18 +08:00
|
|
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
|
|
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
|
|
|
let ndims = extract_ndims(&context.unifier, ndims);
|
|
|
|
|
|
|
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
|
|
|
|
|
|
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
|
|
|
.construct_numpy_empty(generator, context, &shape, None);
|
|
|
|
Ok(ndarray.as_base_value())
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.zeros`.
|
|
|
|
pub fn gen_ndarray_zeros<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert_eq!(args.len(), 1);
|
|
|
|
|
|
|
|
let shape_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-12-16 15:26:18 +08:00
|
|
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
|
|
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
|
|
|
let ndims = extract_ndims(&context.unifier, ndims);
|
|
|
|
|
|
|
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
|
|
|
|
|
|
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
|
|
|
.construct_numpy_zeros(generator, context, dtype, &shape, None);
|
|
|
|
Ok(ndarray.as_base_value())
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.ones`.
|
|
|
|
pub fn gen_ndarray_ones<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert_eq!(args.len(), 1);
|
|
|
|
|
|
|
|
let shape_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-12-16 15:26:18 +08:00
|
|
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
|
|
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
|
|
|
let ndims = extract_ndims(&context.unifier, ndims);
|
|
|
|
|
|
|
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
|
|
|
|
|
|
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
|
|
|
.construct_numpy_ones(generator, context, dtype, &shape, None);
|
|
|
|
Ok(ndarray.as_base_value())
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.full`.
|
|
|
|
pub fn gen_ndarray_full<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert_eq!(args.len(), 2);
|
|
|
|
|
|
|
|
let shape_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
let fill_value_ty = fun.0.args[1].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let fill_value_arg =
|
|
|
|
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-12-16 15:26:18 +08:00
|
|
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
|
|
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
|
|
|
let ndims = extract_ndims(&context.unifier, ndims);
|
|
|
|
|
|
|
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
|
|
|
|
|
|
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
|
|
|
.construct_numpy_full(generator, context, &shape, fill_value_arg, None);
|
|
|
|
Ok(ndarray.as_base_value())
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
2024-06-11 15:29:32 +08:00
|
|
|
pub fn gen_ndarray_array<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert!(matches!(args.len(), 1..=3));
|
|
|
|
|
|
|
|
let obj_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?;
|
2024-06-11 15:29:32 +08:00
|
|
|
|
|
|
|
let copy_arg = if let Some(arg) =
|
2024-06-12 14:45:03 +08:00
|
|
|
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
|
|
|
{
|
2024-06-11 15:29:32 +08:00
|
|
|
let copy_ty = fun.0.args[1].ty;
|
|
|
|
arg.1.clone().to_basic_value_enum(context, generator, copy_ty)?
|
|
|
|
} else {
|
|
|
|
context.gen_symbol_val(
|
|
|
|
generator,
|
|
|
|
fun.0.args[1].default_value.as_ref().unwrap(),
|
|
|
|
fun.0.args[1].ty,
|
2024-06-12 14:45:03 +08:00
|
|
|
)
|
2024-06-11 15:29:32 +08:00
|
|
|
};
|
|
|
|
|
[core] codegen/ndarray: Reimplement np_array()
Based on 8f0084ac: core/ndstrides: implement np_array()
It also checks for inconsistent dimensions if the input is a list.
e.g., rejecting `[[1.0, 2.0], [3.0]]`.
However, currently only `np_array(<input>, copy=False)` and `np_array
(<input>, copy=True)` are supported. In NumPy, copy could be false,
true, or None. Right now, NAC3's `np_array(<input>, copy=False)` behaves
like NumPy's `np.array(<input>, copy=None)`.
2024-08-20 14:51:40 +08:00
|
|
|
// The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be
|
|
|
|
// the `ndims` of the function return type.
|
|
|
|
let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
|
|
|
let ndims = extract_ndims(&context.unifier, ndims);
|
2024-06-11 15:29:32 +08:00
|
|
|
|
[core] codegen/ndarray: Reimplement np_array()
Based on 8f0084ac: core/ndstrides: implement np_array()
It also checks for inconsistent dimensions if the input is a list.
e.g., rejecting `[[1.0, 2.0], [3.0]]`.
However, currently only `np_array(<input>, copy=False)` and `np_array
(<input>, copy=True)` are supported. In NumPy, copy could be false,
true, or None. Right now, NAC3's `np_array(<input>, copy=False)` behaves
like NumPy's `np.array(<input>, copy=None)`.
2024-08-20 14:51:40 +08:00
|
|
|
let copy = generator.bool_to_i1(context, copy_arg.into_int_value());
|
|
|
|
let ndarray = NDArrayType::from_unifier_type(generator, context, fun.0.ret)
|
|
|
|
.construct_numpy_array(generator, context, (obj_ty, obj_arg), copy, None)
|
|
|
|
.atleast_nd(generator, context, ndims);
|
|
|
|
|
|
|
|
Ok(ndarray.as_base_value())
|
2024-06-11 15:29:32 +08:00
|
|
|
}
|
|
|
|
|
2024-03-11 14:47:01 +08:00
|
|
|
/// Generates LLVM IR for `ndarray.eye`.
|
|
|
|
pub fn gen_ndarray_eye<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert!(matches!(args.len(), 1..=3));
|
|
|
|
|
|
|
|
let nrows_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let nrows_arg = args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
let ncols_ty = fun.0.args[1].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let ncols_arg = if let Some(arg) =
|
|
|
|
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
|
|
|
{
|
2024-04-01 16:22:40 +08:00
|
|
|
arg.1.clone().to_basic_value_enum(context, generator, ncols_ty)
|
|
|
|
} else {
|
|
|
|
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
|
|
|
|
}?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
|
|
|
let offset_ty = fun.0.args[2].ty;
|
2024-04-01 16:22:40 +08:00
|
|
|
let offset_arg = if let Some(arg) =
|
2024-06-12 14:45:03 +08:00
|
|
|
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
|
|
|
|
{
|
|
|
|
arg.1.clone().to_basic_value_enum(context, generator, offset_ty)
|
2024-04-01 16:22:40 +08:00
|
|
|
} else {
|
|
|
|
Ok(context.gen_symbol_val(
|
|
|
|
generator,
|
|
|
|
fun.0.args[2].default_value.as_ref().unwrap(),
|
2024-06-12 14:45:03 +08:00
|
|
|
offset_ty,
|
|
|
|
))
|
2024-04-01 16:22:40 +08:00
|
|
|
}?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-12-17 18:01:12 +08:00
|
|
|
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
|
|
|
|
|
|
|
let llvm_usize = generator.get_size_type(context.ctx);
|
|
|
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
|
|
|
|
|
|
|
let nrows = context
|
|
|
|
.builder
|
|
|
|
.build_int_s_extend_or_bit_cast(nrows_arg.into_int_value(), llvm_usize, "")
|
|
|
|
.unwrap();
|
|
|
|
let ncols = context
|
|
|
|
.builder
|
|
|
|
.build_int_s_extend_or_bit_cast(ncols_arg.into_int_value(), llvm_usize, "")
|
|
|
|
.unwrap();
|
|
|
|
let offset = context
|
|
|
|
.builder
|
|
|
|
.build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "")
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2))
|
|
|
|
.construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None);
|
|
|
|
Ok(ndarray.as_base_value())
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.identity`.
|
|
|
|
pub fn gen_ndarray_identity<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_none());
|
|
|
|
assert_eq!(args.len(), 1);
|
|
|
|
|
|
|
|
let n_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-12-17 18:01:12 +08:00
|
|
|
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
|
|
|
|
|
|
|
let llvm_usize = generator.get_size_type(context.ctx);
|
|
|
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
|
|
|
|
|
|
|
let n = context
|
|
|
|
.builder
|
|
|
|
.build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "")
|
|
|
|
.unwrap();
|
|
|
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2))
|
|
|
|
.construct_numpy_identity(generator, context, dtype, n, None);
|
|
|
|
Ok(ndarray.as_base_value())
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.copy`.
|
|
|
|
pub fn gen_ndarray_copy<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
_fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<PointerValue<'ctx>, String> {
|
|
|
|
assert!(obj.is_some());
|
|
|
|
assert!(args.is_empty());
|
|
|
|
|
|
|
|
let this_ty = obj.as_ref().unwrap().0;
|
2024-06-12 14:45:03 +08:00
|
|
|
let this_arg =
|
|
|
|
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-12-18 09:53:00 +08:00
|
|
|
let this = NDArrayType::from_unifier_type(generator, context, this_ty)
|
|
|
|
.map_value(this_arg.into_pointer_value(), None);
|
|
|
|
let ndarray = this.make_copy(generator, context);
|
|
|
|
Ok(ndarray.as_base_value())
|
2024-03-11 14:47:01 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Generates LLVM IR for `ndarray.fill`.
|
|
|
|
pub fn gen_ndarray_fill<'ctx>(
|
|
|
|
context: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
|
|
fun: (&FunSignature, DefinitionId),
|
|
|
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
|
|
generator: &mut dyn CodeGenerator,
|
|
|
|
) -> Result<(), String> {
|
|
|
|
assert!(obj.is_some());
|
|
|
|
assert_eq!(args.len(), 1);
|
|
|
|
|
|
|
|
let this_ty = obj.as_ref().unwrap().0;
|
2024-12-18 09:53:00 +08:00
|
|
|
let this_arg =
|
|
|
|
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
let value_ty = fun.0.args[0].ty;
|
2024-06-12 14:45:03 +08:00
|
|
|
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
|
2024-03-11 14:47:01 +08:00
|
|
|
|
2024-12-18 09:53:00 +08:00
|
|
|
let this = NDArrayType::from_unifier_type(generator, context, this_ty)
|
|
|
|
.map_value(this_arg.into_pointer_value(), None);
|
|
|
|
this.fill(generator, context, value_arg);
|
2024-03-11 14:47:01 +08:00
|
|
|
Ok(())
|
2024-06-12 14:45:03 +08:00
|
|
|
}
|
2024-07-31 13:16:42 +08:00
|
|
|
|
2024-07-31 15:53:51 +08:00
|
|
|
/// Generates LLVM IR for `ndarray.dot`.
|
|
|
|
/// Calculate inner product of two vectors or literals
|
|
|
|
/// For matrix multiplication use `np_matmul`
|
|
|
|
///
|
|
|
|
/// The input `NDArray` are flattened and treated as 1D
|
2024-07-31 18:02:54 +08:00
|
|
|
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
|
2024-07-31 15:53:51 +08:00
|
|
|
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|
|
|
generator: &mut G,
|
|
|
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
|
|
x1: (Type, BasicValueEnum<'ctx>),
|
|
|
|
x2: (Type, BasicValueEnum<'ctx>),
|
|
|
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
|
|
const FN_NAME: &str = "ndarray_dot";
|
|
|
|
let (x1_ty, x1) = x1;
|
2024-08-28 16:33:03 +08:00
|
|
|
let (x2_ty, x2) = x2;
|
2024-07-31 15:53:51 +08:00
|
|
|
|
|
|
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
|
|
|
|
|
|
match (x1, x2) {
|
|
|
|
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
2024-11-27 16:06:16 +08:00
|
|
|
let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None);
|
|
|
|
let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None);
|
2024-07-31 15:53:51 +08:00
|
|
|
|
2024-12-19 12:21:08 +08:00
|
|
|
let n1_sz = n1.size(generator, ctx);
|
|
|
|
let n2_sz = n2.size(generator, ctx);
|
2024-07-31 15:53:51 +08:00
|
|
|
|
|
|
|
ctx.make_assert(
|
|
|
|
generator,
|
|
|
|
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
|
|
|
|
"0:ValueError",
|
|
|
|
"shapes ({0}), ({1}) not aligned",
|
|
|
|
[Some(n1_sz), Some(n2_sz), None],
|
|
|
|
ctx.current_loc,
|
|
|
|
);
|
|
|
|
|
|
|
|
let identity =
|
|
|
|
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
|
|
|
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
|
|
|
|
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
|
|
|
|
|
|
|
|
gen_for_callback_incrementing(
|
|
|
|
generator,
|
|
|
|
ctx,
|
|
|
|
None,
|
|
|
|
llvm_usize.const_zero(),
|
|
|
|
(n1_sz, false),
|
|
|
|
|generator, ctx, _, idx| {
|
|
|
|
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
|
|
|
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
|
|
|
|
|
|
|
|
let product = match elem1 {
|
|
|
|
BasicValueEnum::IntValue(e1) => ctx
|
|
|
|
.builder
|
|
|
|
.build_int_mul(e1, elem2.into_int_value(), "")
|
|
|
|
.unwrap()
|
|
|
|
.as_basic_value_enum(),
|
|
|
|
BasicValueEnum::FloatValue(e1) => ctx
|
|
|
|
.builder
|
|
|
|
.build_float_mul(e1, elem2.into_float_value(), "")
|
|
|
|
.unwrap()
|
|
|
|
.as_basic_value_enum(),
|
2024-08-28 16:33:03 +08:00
|
|
|
_ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()),
|
2024-07-31 15:53:51 +08:00
|
|
|
};
|
|
|
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
|
|
|
let acc_val = match acc_val {
|
|
|
|
BasicValueEnum::IntValue(e1) => ctx
|
|
|
|
.builder
|
|
|
|
.build_int_add(e1, product.into_int_value(), "")
|
|
|
|
.unwrap()
|
|
|
|
.as_basic_value_enum(),
|
|
|
|
BasicValueEnum::FloatValue(e1) => ctx
|
|
|
|
.builder
|
|
|
|
.build_float_add(e1, product.into_float_value(), "")
|
|
|
|
.unwrap()
|
|
|
|
.as_basic_value_enum(),
|
2024-08-28 16:33:03 +08:00
|
|
|
_ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()),
|
2024-07-31 15:53:51 +08:00
|
|
|
};
|
|
|
|
ctx.builder.build_store(acc, acc_val).unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
},
|
|
|
|
llvm_usize.const_int(1, false),
|
|
|
|
)?;
|
|
|
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
|
|
|
Ok(acc_val)
|
|
|
|
}
|
|
|
|
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
|
|
|
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
|
|
|
}
|
|
|
|
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
|
|
|
|
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
|
|
|
}
|
2024-08-23 13:10:55 +08:00
|
|
|
_ => codegen_unreachable!(
|
|
|
|
ctx,
|
2024-07-31 15:53:51 +08:00
|
|
|
"{FN_NAME}() not supported for '{}'",
|
|
|
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
|
|
|
),
|
|
|
|
}
|
|
|
|
}
|