1
0
forked from M-Labs/nac3
nac3/nac3core/src/codegen/numpy.rs

2486 lines
90 KiB
Rust
Raw Normal View History

use inkwell::{
types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel,
};
use itertools::Itertools;
use nac3parser::ast::{Operator, StrRef};
use super::{
expr::gen_binop_expr_with_values,
irrt::{
2024-11-22 16:38:57 +08:00
calculate_len_for_slice_range,
ndarray::{
call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index,
call_ndarray_calc_nd_indices, call_ndarray_calc_size,
},
},
llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable,
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
types::{ndarray::NDArrayType, ListType, ProxyType},
values::{
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
},
CodeGenContext, CodeGenerator,
};
use crate::{
symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId},
typecheck::{
magic_methods::Binop,
typedef::{FunSignature, Type, TypeEnum},
},
};
/// Creates an `NDArray` instance from a dynamic shape.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`.
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
shape: &V,
shape_len_fn: LenFn,
shape_data_fn: DataFn,
) -> Result<NDArrayValue<'ctx>, String>
2024-06-12 14:45:03 +08:00
where
G: CodeGenerator + ?Sized,
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
DataFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
&V,
IntValue<'ctx>,
) -> Result<IntValue<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
// Assert that all dimensions are non-negative
let shape_len = shape_len_fn(generator, ctx, shape)?;
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_len, false),
|generator, ctx, _, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
2024-06-12 14:45:03 +08:00
let shape_dim_gez = ctx
.builder
.build_int_compare(
IntPredicate::SGE,
shape_dim,
shape_dim.get_type().const_zero(),
"",
)
.unwrap();
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
2024-06-12 14:45:03 +08:00
// TODO: Disallow shape > u32_MAX
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let num_dims = shape_len_fn(generator, ctx, shape)?;
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None)
.construct_dyn_ndims(generator, ctx, num_dims, None);
// Copy the dimension sizes from shape to ndarray.dims
let shape_len = shape_len_fn(generator, ctx, shape)?;
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_len, false),
|generator, ctx, _, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
2024-06-12 14:45:03 +08:00
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
2024-06-12 14:45:03 +08:00
let ndarray_pdim =
unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) };
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
unsafe { ndarray.create_data(generator, ctx) };
Ok(ndarray)
}
/// Creates an `NDArray` instance from a constant shape.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
2024-07-25 12:16:53 +08:00
pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: &[IntValue<'ctx>],
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
for &shape_dim in shape {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
2024-06-12 14:45:03 +08:00
let shape_dim_gez = ctx
.builder
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
// TODO: Disallow shape > u32_MAX
}
let llvm_dtype = ctx.get_llvm_type(generator, elem_ty);
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype, Some(shape.len() as u64))
.construct_dyn_shape(generator, ctx, shape, None);
unsafe { ndarray.create_data(generator, ctx) };
2024-05-29 14:19:12 +08:00
Ok(ndarray)
}
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> BasicValueEnum<'ctx> {
2024-06-12 14:45:03 +08:00
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
{
ctx.ctx.i32_type().const_zero().into()
2024-06-12 14:45:03 +08:00
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
{
ctx.ctx.i64_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
ctx.ctx.f64_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "").into()
} else {
2024-08-23 13:10:55 +08:00
codegen_unreachable!(ctx)
}
}
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> BasicValueEnum<'ctx> {
2024-06-12 14:45:03 +08:00
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
{
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
ctx.ctx.i32_type().const_int(1, is_signed).into()
2024-06-12 14:45:03 +08:00
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
{
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
ctx.ctx.i64_type().const_int(1, is_signed).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
ctx.ctx.f64_type().const_float(1.0).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1").into()
} else {
2024-08-23 13:10:55 +08:00
codegen_unreachable!(ctx)
}
}
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
///
/// ### Notes on `shape`
///
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
/// learn how `shape` gets from being a Python user expression to here.
fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
let ndims = shape_tuple.get_type().count_fields();
let shape = (0..ndims)
.map(|dim_i| {
ctx.builder
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
.map(BasicValueEnum::into_int_value)
.map(|v| {
ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()
})
.unwrap()
})
.collect_vec();
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
let shape_int =
ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
2024-08-23 13:10:55 +08:00
_ => codegen_unreachable!(ctx),
}
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
/// its input.
fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarray: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<(), String>
2024-06-12 14:45:03 +08:00
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
IntValue<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.shape().as_slice_value(ctx, generator),
(None, None),
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(ndarray_num_elems, false),
|generator, ctx, _, i| {
2024-06-12 14:45:03 +08:00
let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) };
let value = value_fn(generator, ctx, i)?;
ctx.builder.build_store(elem, value).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
/// as its input.
fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarray: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<(), String>
2024-06-12 14:45:03 +08:00
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
&TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
2024-06-12 14:45:03 +08:00
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| {
let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray);
2024-06-12 14:45:03 +08:00
value_fn(generator, ctx, &indices)
})
}
fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
src: NDArrayValue<'ctx>,
dest: NDArrayValue<'ctx>,
map_fn: MapFn,
) -> Result<(), String>
2024-06-12 14:45:03 +08:00
where
G: CodeGenerator + ?Sized,
MapFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BasicValueEnum<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
2024-06-12 14:45:03 +08:00
ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| {
let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) };
2024-06-12 14:45:03 +08:00
map_fn(generator, ctx, elem)
})
}
/// Generates the LLVM IR asserting that `source` can be broadcast to the shape of `target`.
///
/// Only the number of dimensions is checked: a `ValueError` is raised at runtime when `source`
/// has more dimensions than `target`. Per-dimension sizes are not compared by this function.
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    target: NDArrayValue<'ctx>,
    source: NDArrayValue<'ctx>,
) {
    let source_ndims = source.load_ndims(ctx);
    let target_ndims = target.load_ndims(ctx);

    let ndims_fit = ctx
        .builder
        .build_int_compare(IntPredicate::ULE, source_ndims, target_ndims, "")
        .unwrap();

    ctx.make_assert(
        generator,
        ndims_fit,
        "0:ValueError",
        "operands cannot be broadcast together",
        [None, None, None],
        ctx.current_loc,
    );
}
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
/// with broadcast-compatible shapes.
fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
res: NDArrayValue<'ctx>,
(lhs_ty, lhs_val, lhs_scalar): (Type, BasicValueEnum<'ctx>, bool),
(rhs_ty, rhs_val, rhs_scalar): (Type, BasicValueEnum<'ctx>, bool),
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
2024-06-12 14:45:03 +08:00
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String>,
{
2024-06-12 14:45:03 +08:00
assert!(
!(lhs_scalar && rhs_scalar),
"One of the operands must be a ndarray instance: `{}`, `{}`",
lhs_val.get_type(),
rhs_val.get_type()
);
// Assert that all ndarray operands are broadcastable to the target size
if !lhs_scalar {
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
.map_value(lhs_val.into_pointer_value(), None);
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
}
if !rhs_scalar {
let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
.map_value(rhs_val.into_pointer_value(), None);
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
}
2024-06-12 14:45:03 +08:00
ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| {
let lhs_elem = if lhs_scalar {
lhs_val
} else {
let lhs = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
.map_value(lhs_val.into_pointer_value(), None);
2024-06-12 14:45:03 +08:00
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
2024-06-12 14:45:03 +08:00
unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) }
};
2024-06-12 14:45:03 +08:00
let rhs_elem = if rhs_scalar {
rhs_val
} else {
let rhs = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
.map_value(rhs_val.into_pointer_value(), None);
2024-06-12 14:45:03 +08:00
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
2024-06-12 14:45:03 +08:00
unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) }
};
value_fn(generator, ctx, (lhs_elem, rhs_elem))
})?;
Ok(res)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
ctx.primitives.float,
ctx.primitives.bool,
ctx.primitives.str,
];
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
2024-06-12 14:45:03 +08:00
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
let value = ndarray_zero_value(generator, ctx, elem_ty);
2024-06-12 14:45:03 +08:00
Ok(value)
})?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
ctx.primitives.float,
ctx.primitives.bool,
ctx.primitives.str,
];
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
2024-06-12 14:45:03 +08:00
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
let value = ndarray_one_value(generator, ctx, elem_ty);
2024-06-12 14:45:03 +08:00
Ok(value)
})?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.full`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
fill_value: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
2024-06-12 14:45:03 +08:00
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
let value = if fill_value.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
2024-06-12 14:45:03 +08:00
let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?;
2024-06-12 14:45:03 +08:00
call_memcpy_generic(
ctx,
copy,
fill_value.into_pointer_value(),
fill_value.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
2024-06-12 14:45:03 +08:00
copy.into()
} else if fill_value.is_int_value() || fill_value.is_float_value() {
fill_value
} else {
2024-08-23 13:10:55 +08:00
codegen_unreachable!(ctx)
2024-06-12 14:45:03 +08:00
};
2024-06-12 14:45:03 +08:00
Ok(value)
})?;
Ok(ndarray)
}
/// Returns the number of dimensions for a multidimensional list as an [`IntValue`].
///
/// The result is a `usize` constant computed purely from the LLVM list type `ty`: each level of
/// list nesting contributes one dimension. Lists of `ndarray` are not supported yet and panic via
/// `todo!`.
fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ty: PointerType<'ctx>,
) -> IntValue<'ctx> {
    let llvm_usize = generator.get_size_type(ctx.ctx);
    let this_level = llvm_usize.const_int(1, false);

    let elem_ty = ListType::from_type(ty, llvm_usize).element_type();
    if let AnyTypeEnum::PointerType(elem_ptr_ty) = elem_ty {
        if ListType::is_representable(elem_ptr_ty, llvm_usize).is_ok() {
            // Nested list: one dimension for this level plus those of the inner list.
            return this_level.const_add(llvm_ndlist_get_ndims(generator, ctx, elem_ptr_ty));
        }

        if NDArrayType::is_representable(elem_ptr_ty, llvm_usize).is_ok() {
            todo!("Getting ndims for list[ndarray] not supported")
        }
    }

    this_level
}
/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`].
fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
src_lst: ListValue<'ctx>,
dim: u64,
) -> Result<(), String> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let list_elem_ty = src_lst.get_type().element_type();
match list_elem_ty {
AnyTypeEnum::PointerType(ptr_ty)
if ListType::is_representable(ptr_ty, llvm_usize).is_ok() =>
{
2024-06-12 14:45:03 +08:00
// The stride of elements in this dimension, i.e. the number of elements between arr[i]
// and arr[i + 1] in this dimension
let stride = call_ndarray_calc_size(
generator,
ctx,
&dst_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
gen_for_range_callback(
generator,
ctx,
None,
true,
|_, _| Ok(llvm_usize.const_zero()),
(|_, ctx| Ok(src_lst.load_size(ctx, None)), false),
|_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, _, i| {
2024-06-12 14:45:03 +08:00
let offset = ctx.builder.build_int_mul(stride, i, "").unwrap();
let offset = ctx
.builder
.build_int_mul(
offset,
ctx.builder
.build_int_truncate_or_bit_cast(
dst_arr.get_type().element_type().size_of().unwrap(),
offset.get_type(),
"",
)
.unwrap(),
"",
)
.unwrap();
2024-06-12 14:45:03 +08:00
let dst_ptr =
unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() };
let nested_lst_elem = ListValue::from_pointer_value(
2024-06-12 14:45:03 +08:00
unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) }
.into_pointer_value(),
llvm_usize,
None,
);
ndarray_from_ndlist_impl(
generator,
ctx,
(dst_arr, dst_ptr),
nested_lst_elem,
dim + 1,
)?;
Ok(())
},
)?;
}
AnyTypeEnum::PointerType(ptr_ty)
if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() =>
{
todo!("Not implemented for list[ndarray]")
}
_ => {
let lst_len = src_lst.load_size(ctx, None);
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
let sizeof_elem =
ctx.builder.build_int_z_extend_or_bit_cast(sizeof_elem, llvm_usize, "").unwrap();
2024-06-12 14:45:03 +08:00
let cpy_len = ctx
.builder
.build_int_mul(
ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(),
sizeof_elem,
"",
)
.unwrap();
call_memcpy_generic(
ctx,
dst_slice_ptr,
src_lst.data().base_ptr(ctx, generator),
cpy_len,
llvm_i1.const_zero(),
);
}
}
Ok(())
}
/// LLVM-typed implementation for `ndarray.array`.
fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
object: BasicValueEnum<'ctx>,
copy: IntValue<'ctx>,
ndmin: IntValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
2024-06-12 14:45:03 +08:00
let ndmin = ctx.builder.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "").unwrap();
// TODO(Derppening): Add assertions for sizes of different dimensions
// object is not a pointer - 0-dim NDArray
if !object.is_pointer_value() {
2024-06-12 14:45:03 +08:00
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[])?;
unsafe {
2024-06-12 14:45:03 +08:00
ndarray.data().set_unchecked(ctx, generator, &llvm_usize.const_zero(), object);
}
2024-06-12 14:45:03 +08:00
return Ok(ndarray);
}
let object = object.into_pointer_value();
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None);
let ndarray = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
2024-06-12 14:45:03 +08:00
let copy_nez = ctx
.builder
.build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "")
.unwrap();
2024-06-12 14:45:03 +08:00
let ndmin_gt_ndims = ctx
.builder
.build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "")
.unwrap();
2024-06-12 14:45:03 +08:00
Ok(ctx.builder.build_and(copy_nez, ndmin_gt_ndims, "").unwrap())
},
|generator, ctx| {
let ndarray = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&object,
|_, ctx, object| {
let ndims = object.load_ndims(ctx);
2024-06-12 14:45:03 +08:00
let ndmin_gt_ndims = ctx
.builder
.build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "")
.unwrap();
2024-06-12 14:45:03 +08:00
Ok(ctx
.builder
.build_select(ndmin_gt_ndims, ndmin, ndims, "")
.map(BasicValueEnum::into_int_value)
.unwrap())
},
|generator, ctx, object, idx| {
let ndims = object.load_ndims(ctx);
let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None);
// The number of dimensions to prepend 1's to
let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap();
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
2024-06-12 14:45:03 +08:00
Ok(ctx
.builder
.build_int_compare(IntPredicate::UGE, idx, offset, "")
.unwrap())
},
2024-06-12 14:45:03 +08:00
|_, _| Ok(Some(llvm_usize.const_int(1, false))),
|_, ctx| Ok(Some(ctx.builder.build_int_sub(idx, offset, "").unwrap())),
)?
.map(BasicValueEnum::into_int_value)
.unwrap())
},
)?;
ndarray_sliced_copyto_impl(
generator,
ctx,
(ndarray, ndarray.data().base_ptr(ctx, generator)),
(object, object.data().base_ptr(ctx, generator)),
0,
&[],
)?;
Ok(Some(ndarray.as_base_value()))
},
2024-06-12 14:45:03 +08:00
|_, _| Ok(Some(object.as_base_value())),
)?;
return Ok(NDArrayValue::from_pointer_value(
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
llvm_elem_ty,
None,
llvm_usize,
None,
2024-06-12 14:45:03 +08:00
));
}
// Remaining case: TList
assert!(ListValue::is_representable(object, llvm_usize).is_ok());
let object = ListValue::from_pointer_value(object, llvm_usize, None);
// The number of dimensions to prepend 1's to
let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type());
let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None);
let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap();
let ndarray = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&object,
|generator, ctx, object| {
let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type());
2024-06-12 14:45:03 +08:00
let ndmin_gt_ndims =
ctx.builder.build_int_compare(IntPredicate::UGT, ndmin, ndims, "").unwrap();
2024-06-12 14:45:03 +08:00
Ok(ctx
.builder
.build_select(ndmin_gt_ndims, ndmin, ndims, "")
.map(BasicValueEnum::into_int_value)
.unwrap())
},
|generator, ctx, object, idx| {
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
2024-06-12 14:45:03 +08:00
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, idx, offset, "").unwrap())
},
2024-06-12 14:45:03 +08:00
|_, _| Ok(Some(llvm_usize.const_int(1, false))),
|generator, ctx| {
let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| {
ctx.ctx.struct_type(
2024-06-12 14:45:03 +08:00
&[elem_ty.ptr_type(AddressSpace::default()).into(), llvm_usize.into()],
false,
)
};
let llvm_i8 = ctx.ctx.i8_type();
let llvm_list_i8 = make_llvm_list(llvm_i8.into());
let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default());
// Cast list to { i8*, usize } since we only care about the size
2024-06-12 14:45:03 +08:00
let lst = generator
.gen_var_alloc(
ctx,
ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(),
None,
)
.unwrap();
ctx.builder
.build_store(
lst,
ctx.builder
2024-08-20 20:16:36 +08:00
.build_bit_cast(object.as_base_value(), llvm_plist_i8, "")
2024-06-12 14:45:03 +08:00
.unwrap(),
)
.unwrap();
let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap();
gen_for_range_callback(
generator,
ctx,
None,
true,
|_, _| Ok(llvm_usize.const_zero()),
(|_, _| Ok(stop), false),
|_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, _, _| {
let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into())
.ptr_type(AddressSpace::default());
2024-06-12 14:45:03 +08:00
let this_dim = ctx
.builder
.build_load(lst, "")
.map(BasicValueEnum::into_pointer_value)
2024-08-20 20:16:36 +08:00
.map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap())
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let this_dim =
ListValue::from_pointer_value(this_dim, llvm_usize, None);
// TODO: Assert this_dim.sz != 0
let next_dim = unsafe {
2024-06-12 14:45:03 +08:00
this_dim.data().get_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
}
.into_pointer_value();
ctx.builder
.build_store(
lst,
2024-08-20 20:16:36 +08:00
ctx.builder
.build_bit_cast(next_dim, llvm_plist_i8, "")
.unwrap(),
2024-06-12 14:45:03 +08:00
)
.unwrap();
Ok(())
},
)?;
let lst = ListValue::from_pointer_value(
ctx.builder
.build_load(lst, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap(),
llvm_usize,
None,
);
Ok(Some(lst.load_size(ctx, None)))
},
2024-06-12 14:45:03 +08:00
)?
.map(BasicValueEnum::into_int_value)
.unwrap())
},
)?;
ndarray_from_ndlist_impl(
generator,
ctx,
(ndarray, ndarray.data().base_ptr(ctx, generator)),
object,
0,
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
///
/// * `elem_ty` - The element type of the `NDArray`.
fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
nrows: IntValue<'ctx>,
ncols: IntValue<'ctx>,
offset: IntValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap();
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap();
2024-06-12 14:45:03 +08:00
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?;
2024-06-12 14:45:03 +08:00
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| {
let (row, col) = unsafe {
(
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None),
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None),
)
};
2024-06-12 14:45:03 +08:00
let col_with_offset = ctx
.builder
.build_int_add(
col,
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(),
"",
)
.unwrap();
let is_on_diag =
ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap();
2024-06-12 14:45:03 +08:00
let zero = ndarray_zero_value(generator, ctx, elem_ty);
let one = ndarray_one_value(generator, ctx, elem_ty);
2024-06-12 14:45:03 +08:00
let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap();
2024-06-12 14:45:03 +08:00
Ok(value)
})?;
Ok(ndarray)
}
/// Copies a slice of an [`NDArrayValue`] to another.
///
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
2024-08-21 11:10:52 +08:00
/// fields should be populated before calling this function.
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
2024-08-21 11:10:52 +08:00
/// dimensional slice in the destination array.
/// - `src_arr`: The [`NDArrayValue`] instance of the source array.
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
2024-08-21 11:10:52 +08:00
/// dimensional slice in the source array.
/// - `dim`: The index of the currently processing dimension.
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
2024-08-21 11:10:52 +08:00
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
dim: u64,
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
) -> Result<(), String> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type());
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
// If there are no (remaining) slice expressions, memcpy the entire dimension
if slices.is_empty() {
let stride = call_ndarray_calc_size(
generator,
ctx,
&src_arr.shape(),
(Some(llvm_usize.const_int(dim, false)), None),
);
let stride =
ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap();
2024-06-12 14:45:03 +08:00
let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap();
2024-06-12 14:45:03 +08:00
call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero());
2024-06-12 14:45:03 +08:00
return Ok(());
}
// The stride of elements in this dimension, i.e. the number of elements between arr[i] and
// arr[i + 1] in this dimension
let src_stride = call_ndarray_calc_size(
generator,
ctx,
&src_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
let dst_stride = call_ndarray_calc_size(
generator,
ctx,
&dst_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
let (start, stop, step) = slices[0];
let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap();
let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap();
let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap();
let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap();
ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap();
gen_for_range_callback(
generator,
ctx,
None,
false,
|_, _| Ok(start),
(|_, _| Ok(stop), true),
|_, _| Ok(step),
|generator, ctx, _, src_i| {
// Calculate the offset of the active slice
2024-06-12 14:45:03 +08:00
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
let src_data_offset = ctx
.builder
.build_int_mul(
src_data_offset,
ctx.builder
.build_int_z_extend_or_bit_cast(sizeof_elem, src_data_offset.get_type(), "")
.unwrap(),
"",
)
.unwrap();
2024-06-12 14:45:03 +08:00
let dst_i =
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap();
let dst_data_offset = ctx
.builder
.build_int_mul(
dst_data_offset,
ctx.builder
.build_int_z_extend_or_bit_cast(sizeof_elem, dst_data_offset.get_type(), "")
.unwrap(),
"",
)
.unwrap();
let (src_ptr, dst_ptr) = unsafe {
(
ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(),
ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(),
)
};
ndarray_sliced_copyto_impl(
generator,
ctx,
(dst_arr, dst_ptr),
(src_arr, src_ptr),
dim + 1,
&slices[1..],
)?;
2024-06-12 14:45:03 +08:00
let dst_i =
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let dst_i_add1 =
ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap();
ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap();
Ok(())
},
)?;
Ok(())
}
/// Copies a [`NDArrayValue`] using slices.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
2024-08-21 11:10:52 +08:00
/// this dimension. The `start`/`stop` values of each slice must be positive indices.
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
this: NDArrayValue<'ctx>,
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray =
if slices.is_empty() {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&this,
|_, ctx, shape| Ok(shape.load_ndims(ctx)),
|generator, ctx, shape, idx| unsafe {
Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None))
},
)?
} else {
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None)
.construct_dyn_ndims(generator, ctx, this.load_ndims(ctx), None);
// Populate the first slices.len() dimensions by computing the size of each dim slice
for (i, (start, stop, step)) in slices.iter().enumerate() {
// HACK: workaround calculate_len_for_slice_range requiring exclusive stop
let stop = ctx
.builder
.build_select(
ctx.builder
.build_int_compare(
IntPredicate::SLT,
*step,
llvm_i32.const_zero(),
"is_neg",
)
.unwrap(),
ctx.builder
.build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one")
.unwrap(),
ctx.builder
.build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one")
.unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step);
let slice_len =
ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();
unsafe {
ndarray.shape().set_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
slice_len,
);
}
}
// Populate the rest by directly copying the dim size from the source array
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_int(slices.len() as u64, false),
(this.load_ndims(ctx), false),
|generator, ctx, _, idx| {
unsafe {
let shape = this.shape().get_typed_unchecked(ctx, generator, &idx, None);
ndarray.shape().set_typed_unchecked(ctx, generator, &idx, shape);
}
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
unsafe { ndarray.create_data(generator, ctx) };
ndarray
};
ndarray_sliced_copyto_impl(
generator,
ctx,
(ndarray, ndarray.data().base_ptr(ctx, generator)),
(this, this.data().base_ptr(ctx, generator)),
0,
slices,
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
///
/// * `elem_ty` - The element type of the `NDArray`.
fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
this: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
}
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
operand: NDArrayValue<'ctx>,
map_fn: MapFn,
) -> Result<NDArrayValue<'ctx>, String>
2024-06-12 14:45:03 +08:00
where
G: CodeGenerator + ?Sized,
MapFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BasicValueEnum<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
let res = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&operand,
2024-06-12 14:45:03 +08:00
|_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe {
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
},
2024-06-12 14:45:03 +08:00
)
.unwrap()
});
2024-06-12 14:45:03 +08:00
ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| {
map_fn(generator, ctx, elem)
})?;
Ok(res)
}
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
///
2024-06-12 14:45:03 +08:00
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple.
/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the
/// `value_fn` arguments tuple for all output elements.
///
/// The second element of the tuple indicates whether to treat the operand value as a `ndarray`
2024-06-12 14:45:03 +08:00
/// (which would be accessed by its broadcast index) or as a scalar value (which would be
/// broadcast to all elements).
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
2024-08-21 11:10:52 +08:00
/// written to a new `ndarray`.
/// * `value_fn` - Function mapping the two input elements into the result.
///
/// # Panic
///
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
lhs: (Type, BasicValueEnum<'ctx>, bool),
rhs: (Type, BasicValueEnum<'ctx>, bool),
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
2024-06-12 14:45:03 +08:00
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String>,
{
let (lhs_ty, lhs_val, lhs_scalar) = lhs;
let (rhs_ty, rhs_val, rhs_scalar) = rhs;
2024-06-12 14:45:03 +08:00
assert!(
!(lhs_scalar && rhs_scalar),
"One of the operands must be a ndarray instance: `{}`, `{}`",
lhs_val.get_type(),
rhs_val.get_type()
);
let ndarray = res.unwrap_or_else(|| {
if lhs_scalar && rhs_scalar {
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
.map_value(lhs_val.into_pointer_value(), None);
let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
.map_value(rhs_val.into_pointer_value(), None);
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&ndarray_dims,
2024-06-12 14:45:03 +08:00
|generator, ctx, v| Ok(v.size(ctx, generator)),
|generator, ctx, v, idx| unsafe {
Ok(v.get_typed_unchecked(ctx, generator, &idx, None))
},
2024-06-12 14:45:03 +08:00
)
.unwrap()
} else {
let ndarray = NDArrayType::from_unifier_type(
generator,
ctx,
if lhs_scalar { rhs_ty } else { lhs_ty },
)
.map_value(if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), None);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&ndarray,
2024-06-12 14:45:03 +08:00
|_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe {
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
},
2024-06-12 14:45:03 +08:00
)
.unwrap()
}
});
2024-06-12 14:45:03 +08:00
ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| {
value_fn(generator, ctx, elems)
})?;
Ok(ndarray)
}
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
2024-08-21 11:10:52 +08:00
/// written to a new `ndarray`.
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
if cfg!(debug_assertions) {
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
// lhs.ndims == 2
ctx.make_assert(
generator,
2024-06-12 14:45:03 +08:00
ctx.builder
.build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// rhs.ndims == 2
ctx.make_assert(
generator,
2024-06-12 14:45:03 +08:00
ctx.builder
.build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
if let Some(res) = res {
let res_ndims = res.load_ndims(ctx);
let res_dim0 = unsafe {
res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let res_dim1 = unsafe {
res.shape().get_typed_unchecked(
2024-06-12 14:45:03 +08:00
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
let lhs_dim0 = unsafe {
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let rhs_dim1 = unsafe {
rhs.shape().get_typed_unchecked(
2024-06-12 14:45:03 +08:00
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
// res.ndims == 2
ctx.make_assert(
generator,
2024-06-12 14:45:03 +08:00
ctx.builder
.build_int_compare(
IntPredicate::EQ,
res_ndims,
llvm_usize.const_int(2, false),
"",
)
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// res.dims[0] == lhs.dims[0]
ctx.make_assert(
generator,
2024-06-12 14:45:03 +08:00
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// res.dims[1] == rhs.dims[0]
ctx.make_assert(
generator,
2024-06-12 14:45:03 +08:00
ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
}
}
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let lhs_dim1 = unsafe {
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
};
let rhs_dim0 = unsafe {
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
// lhs.dims[1] == rhs.dims[0]
ctx.make_assert(
generator,
2024-06-12 14:45:03 +08:00
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
}
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
} else {
lhs
};
let ndarray = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&(lhs, rhs),
2024-06-12 14:45:03 +08:00
|_, _, _| Ok(llvm_usize.const_int(2, false)),
|generator, ctx, (lhs, rhs), idx| {
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
2024-06-12 14:45:03 +08:00
Ok(ctx
.builder
.build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "")
.unwrap())
},
|generator, ctx| {
Ok(Some(unsafe {
lhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
}))
},
|generator, ctx| {
Ok(Some(unsafe {
rhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
}))
},
2024-06-12 14:45:03 +08:00
)
.map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
},
2024-06-12 14:45:03 +08:00
)
.unwrap()
});
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
2024-06-12 14:45:03 +08:00
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| {
llvm_intrinsics::call_expect(
ctx,
idx.size(ctx, generator).get_type().const_int(2, false),
idx.size(ctx, generator),
None,
);
2024-06-12 14:45:03 +08:00
let common_dim = {
let lhs_idx1 = unsafe {
lhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
2024-06-12 14:45:03 +08:00
)
};
let rhs_idx0 = unsafe {
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
2024-06-12 14:45:03 +08:00
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
2024-06-12 14:45:03 +08:00
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
};
2024-06-12 14:45:03 +08:00
let idx0 = unsafe {
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
2024-06-12 14:45:03 +08:00
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
};
let idx1 = unsafe {
let idx1 =
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
2024-06-12 14:45:03 +08:00
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
};
2024-06-12 14:45:03 +08:00
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
ctx.builder.build_store(result_addr, result_identity).unwrap();
2024-06-12 14:45:03 +08:00
gen_for_callback_incrementing(
generator,
ctx,
None,
2024-06-12 14:45:03 +08:00
llvm_i32.const_zero(),
(common_dim, false),
|generator, ctx, _, i| {
2024-06-12 14:45:03 +08:00
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
let ab_idx = generator.gen_array_var_alloc(
ctx,
llvm_i32.into(),
llvm_usize.const_int(2, false),
None,
)?;
let a = unsafe {
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
};
let b = unsafe {
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
ab_idx.set_unchecked(
ctx,
generator,
2024-06-12 14:45:03 +08:00
&llvm_usize.const_int(1, false),
idx1.into(),
);
2024-06-12 14:45:03 +08:00
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
};
2024-06-12 14:45:03 +08:00
let a_mul_b = gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), a),
Binop::normal(Operator::Mult),
2024-06-12 14:45:03 +08:00
(&Some(elem_ty), b),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?;
let result = ctx.builder.build_load(result_addr, "").unwrap();
let result = gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), result),
Binop::normal(Operator::Add),
2024-06-12 14:45:03 +08:00
(&Some(elem_ty), a_mul_b),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?;
ctx.builder.build_store(result_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let result = ctx.builder.build_load(result_addr, "").unwrap();
Ok(result)
})?;
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty;
2024-06-12 14:45:03 +08:00
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.zeros`.
pub fn gen_ndarray_zeros<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty;
2024-06-12 14:45:03 +08:00
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.ones`.
pub fn gen_ndarray_ones<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty;
2024-06-12 14:45:03 +08:00
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.full`.
pub fn gen_ndarray_full<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
let shape_ty = fun.0.args[0].ty;
2024-06-12 14:45:03 +08:00
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
let fill_value_ty = fun.0.args[1].ty;
2024-06-12 14:45:03 +08:00
let fill_value_arg =
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg)
.map(NDArrayValue::into)
}
pub fn gen_ndarray_array<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
let obj_ty = fun.0.args[0].ty;
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let mut ty = *params.iter().next().unwrap().1;
while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty)
{
if *obj_id != PrimDef::List.id() {
break;
}
ty = *params.iter().next().unwrap().1;
}
ty
2024-06-12 14:45:03 +08:00
}
_ => obj_ty,
};
2024-06-12 14:45:03 +08:00
let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?;
let copy_arg = if let Some(arg) =
2024-06-12 14:45:03 +08:00
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
{
let copy_ty = fun.0.args[1].ty;
arg.1.clone().to_basic_value_enum(context, generator, copy_ty)?
} else {
context.gen_symbol_val(
generator,
fun.0.args[1].default_value.as_ref().unwrap(),
fun.0.args[1].ty,
2024-06-12 14:45:03 +08:00
)
};
let ndmin_arg = if let Some(arg) =
2024-06-12 14:45:03 +08:00
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
{
let ndmin_ty = fun.0.args[2].ty;
arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)?
} else {
context.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
fun.0.args[2].ty,
)
2024-06-12 14:45:03 +08:00
};
call_ndarray_array_impl(
generator,
context,
obj_elem_ty,
obj_arg,
copy_arg.into_int_value(),
ndmin_arg.into_int_value(),
2024-06-12 14:45:03 +08:00
)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.eye`.
pub fn gen_ndarray_eye<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
let nrows_ty = fun.0.args[0].ty;
2024-06-12 14:45:03 +08:00
let nrows_arg = args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)?;
let ncols_ty = fun.0.args[1].ty;
2024-06-12 14:45:03 +08:00
let ncols_arg = if let Some(arg) =
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
{
2024-04-01 16:22:40 +08:00
arg.1.clone().to_basic_value_enum(context, generator, ncols_ty)
} else {
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
}?;
let offset_ty = fun.0.args[2].ty;
2024-04-01 16:22:40 +08:00
let offset_arg = if let Some(arg) =
2024-06-12 14:45:03 +08:00
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
{
arg.1.clone().to_basic_value_enum(context, generator, offset_ty)
2024-04-01 16:22:40 +08:00
} else {
Ok(context.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
2024-06-12 14:45:03 +08:00
offset_ty,
))
2024-04-01 16:22:40 +08:00
}?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
nrows_arg.into_int_value(),
ncols_arg.into_int_value(),
offset_arg.into_int_value(),
2024-06-12 14:45:03 +08:00
)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.identity`.
pub fn gen_ndarray_identity<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let n_ty = fun.0.args[0].ty;
2024-06-12 14:45:03 +08:00
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
n_arg.into_int_value(),
n_arg.into_int_value(),
llvm_usize.const_zero(),
2024-06-12 14:45:03 +08:00
)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.copy`.
pub fn gen_ndarray_copy<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
_fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_some());
assert!(args.is_empty());
let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
2024-06-12 14:45:03 +08:00
let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty);
ndarray_copy_impl(
generator,
context,
this_elem_ty,
llvm_this_ty.map_value(this_arg.into_pointer_value(), None),
2024-06-12 14:45:03 +08:00
)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.fill`.
pub fn gen_ndarray_fill<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<(), String> {
assert!(obj.is_some());
assert_eq!(args.len(), 1);
let this_ty = obj.as_ref().unwrap().0;
2024-06-12 14:45:03 +08:00
let this_arg = obj
.as_ref()
.unwrap()
.1
.clone()
.to_basic_value_enum(context, generator, this_ty)?
.into_pointer_value();
let value_ty = fun.0.args[0].ty;
2024-06-12 14:45:03 +08:00
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty);
ndarray_fill_flattened(
generator,
context,
llvm_this_ty.map_value(this_arg, None),
|generator, ctx, _| {
let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
value_arg.into_pointer_value(),
value_arg.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg
} else {
2024-08-23 13:10:55 +08:00
codegen_unreachable!(ctx)
};
Ok(value)
2024-06-12 14:45:03 +08:00
},
)?;
Ok(())
2024-06-12 14:45:03 +08:00
}
/// Generates LLVM IR for `ndarray.transpose`.
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_transpose";
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = llvm_ndarray_ty.map_value(n1, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
// Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&n1,
|_, ctx, n| Ok(n.load_ndims(ctx)),
|generator, ctx, n, idx| {
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
let new_idx = ctx
.builder
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
.unwrap();
unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) }
},
)
.unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
ctx.builder.build_store(rem_idx, idx).unwrap();
// Incrementally calculate the new index in the transposed array
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
// The formula used for indexing is:
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1.load_ndims(ctx), false),
|generator, ctx, _, ndim| {
let ndim_rev =
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
let ndim_rev = ctx
.builder
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
.unwrap();
let dim = unsafe {
n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None)
};
let rem_idx_val =
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
let new_idx_val =
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
let add_component =
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
let rem_idx_val =
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
let new_idx_val =
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
2024-08-23 13:10:55 +08:00
codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
///
/// * `x1` - `NDArray` to reshape.
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
2024-08-21 11:10:52 +08:00
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
2024-08-20 11:29:03 +08:00
///
/// Note that unlike other generating functions, one of the dimensions in the shape can be negative.
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
shape: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_reshape";
let (x1_ty, x1) = x1;
let (_, shape) = shape;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
let n1 = llvm_ndarray_ty.map_value(n1, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
// Check for -1 in dimensions
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_list.load_size(ctx, None), false),
|generator, ctx, _, idx| {
let ele =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
ele,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_neg_value =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_neg_value = ctx
.builder
.build_int_add(
num_neg_value,
llvm_usize.const_int(1, false),
"",
)
.unwrap();
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
Ok(None)
},
|_, ctx| {
let acc_value =
ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_value =
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
ctx.builder.build_store(acc, acc_value).unwrap();
Ok(None)
},
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
// Generate the output shape by filling -1 with `rem`
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
let ndims = shape_tuple.get_type().count_fields();
// Check for -1 in dims
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_negs =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_negs = ctx
.builder
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
.unwrap();
ctx.builder.build_store(num_neg, num_negs).unwrap();
Ok(None)
},
|_, ctx| {
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(None)
},
)?;
}
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize);
// Reconstruct shape filling negatives with rem
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
let dim = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
let shape_int = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
shape_int,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(n_sz)),
|_, ctx| {
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
},
)?
.unwrap()
.into_int_value();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
2024-08-23 13:10:55 +08:00
_ => codegen_unreachable!(ctx),
}
.unwrap();
// Only allow one dimension to be negative
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"can only specify one unknown dimension",
[None, None, None],
ctx.current_loc,
);
// The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError",
2024-07-31 15:53:51 +08:00
"cannot reshape array of size {0} into provided shape of size {1}",
[Some(n_sz), Some(out_sz), None],
ctx.current_loc,
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
2024-08-23 13:10:55 +08:00
codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
2024-07-31 15:53:51 +08:00
/// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul`
///
/// The input `NDArray` are flattened and treated as 1D
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
2024-07-31 15:53:51 +08:00
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_dot";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
2024-07-31 15:53:51 +08:00
let llvm_usize = generator.get_size_type(ctx.ctx);
match (x1, x2) {
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None);
let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None);
2024-07-31 15:53:51 +08:00
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
2024-07-31 15:53:51 +08:00
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
"0:ValueError",
"shapes ({0}), ({1}) not aligned",
[Some(n1_sz), Some(n2_sz), None],
ctx.current_loc,
);
let identity =
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1_sz, false),
|generator, ctx, _, idx| {
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
let product = match elem1 {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_mul(e1, elem2.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_mul(e1, elem2.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()),
2024-07-31 15:53:51 +08:00
};
let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_add(e1, product.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_add(e1, product.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()),
2024-07-31 15:53:51 +08:00
};
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap();
Ok(acc_val)
}
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
2024-08-23 13:10:55 +08:00
_ => codegen_unreachable!(
ctx,
2024-07-31 15:53:51 +08:00
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
),
}
}