1
0
forked from M-Labs/nac3

[core] codegen/ndarray: Reimplement np_{eye,identity}

Based on fa047d50: core/ndstrides: implement np_identity() and np_eye()
This commit is contained in:
David Mak 2024-12-17 18:01:12 +08:00
parent fadadd7505
commit acb437919d
2 changed files with 126 additions and 75 deletions

View File

@ -18,10 +18,7 @@ use super::{
llvm_intrinsics::{self, call_memcpy_generic}, llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable, macros::codegen_unreachable,
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
types::ndarray::{ types::ndarray::{factory::ndarray_zero_value, NDArrayType},
factory::{ndarray_one_value, ndarray_zero_value},
NDArrayType,
},
values::{ values::{
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue,
@ -406,55 +403,6 @@ where
Ok(res) Ok(res)
} }
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
///
/// * `elem_ty` - The element type of the `NDArray`.
fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
nrows: IntValue<'ctx>,
ncols: IntValue<'ctx>,
offset: IntValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap();
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap();
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?;
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| {
let (row, col) = unsafe {
(
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None),
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None),
)
};
let col_with_offset = ctx
.builder
.build_int_add(
col,
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(),
"",
)
.unwrap();
let is_on_diag =
ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap();
let zero = ndarray_zero_value(generator, ctx, elem_ty);
let one = ndarray_one_value(generator, ctx, elem_ty);
let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap();
Ok(value)
})?;
Ok(ndarray)
}
/// Copies a slice of an [`NDArrayValue`] to another. /// Copies a slice of an [`NDArrayValue`] to another.
/// ///
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` /// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
@ -1304,15 +1252,27 @@ pub fn gen_ndarray_eye<'ctx>(
)) ))
}?; }?;
call_ndarray_eye_impl( let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
generator,
context, let llvm_usize = generator.get_size_type(context.ctx);
context.primitives.float, let llvm_dtype = context.get_llvm_type(generator, dtype);
nrows_arg.into_int_value(),
ncols_arg.into_int_value(), let nrows = context
offset_arg.into_int_value(), .builder
) .build_int_s_extend_or_bit_cast(nrows_arg.into_int_value(), llvm_usize, "")
.map(NDArrayValue::into) .unwrap();
let ncols = context
.builder
.build_int_s_extend_or_bit_cast(ncols_arg.into_int_value(), llvm_usize, "")
.unwrap();
let offset = context
.builder
.build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "")
.unwrap();
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2))
.construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None);
Ok(ndarray.as_base_value())
} }
/// Generates LLVM IR for `ndarray.identity`. /// Generates LLVM IR for `ndarray.identity`.
@ -1326,20 +1286,21 @@ pub fn gen_ndarray_identity<'ctx>(
assert!(obj.is_none()); assert!(obj.is_none());
assert_eq!(args.len(), 1); assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let n_ty = fun.0.args[0].ty; let n_ty = fun.0.args[0].ty;
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
call_ndarray_eye_impl( let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
generator,
context, let llvm_usize = generator.get_size_type(context.ctx);
context.primitives.float, let llvm_dtype = context.get_llvm_type(generator, dtype);
n_arg.into_int_value(),
n_arg.into_int_value(), let n = context
llvm_usize.const_zero(), .builder
) .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "")
.map(NDArrayValue::into) .unwrap();
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2))
.construct_numpy_identity(generator, context, dtype, n, None);
Ok(ndarray.as_base_value())
} }
/// Generates LLVM IR for `ndarray.copy`. /// Generates LLVM IR for `ndarray.copy`.

View File

@ -1,4 +1,7 @@
use inkwell::values::{BasicValueEnum, IntValue}; use inkwell::{
values::{BasicValueEnum, IntValue},
IntPredicate,
};
use super::NDArrayType; use super::NDArrayType;
use crate::{ use crate::{
@ -36,7 +39,7 @@ pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
} }
/// Get the one value in `np.ones()` of a `dtype`. /// Get the one value in `np.ones()` of a `dtype`.
pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type, dtype: Type,
@ -143,4 +146,91 @@ impl<'ctx> NDArrayType<'ctx> {
let fill_value = ndarray_one_value(generator, ctx, dtype); let fill_value = ndarray_one_value(generator, ctx, dtype);
self.construct_numpy_full(generator, ctx, shape, fill_value, name) self.construct_numpy_full(generator, ctx, shape, fill_value, name)
} }
/// Create an ndarray like
/// [`np.eye`](https://numpy.org/doc/stable/reference/generated/numpy.eye.html).
#[allow(clippy::too_many_arguments)]
pub fn construct_numpy_eye<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
nrows: IntValue<'ctx>,
ncols: IntValue<'ctx>,
offset: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
assert_eq!(
ctx.get_llvm_type(generator, dtype),
self.dtype,
"Expected LLVM dtype={} but got {}",
self.dtype.print_to_string(),
ctx.get_llvm_type(generator, dtype).print_to_string(),
);
assert_eq!(nrows.get_type(), self.llvm_usize);
assert_eq!(ncols.get_type(), self.llvm_usize);
assert_eq!(offset.get_type(), self.llvm_usize);
let ndzero = ndarray_zero_value(generator, ctx, dtype);
let ndone = ndarray_one_value(generator, ctx, dtype);
let ndarray = self.construct_dyn_shape(generator, ctx, &[nrows, ncols], name);
// Create data and make the matrix like look np.eye()
unsafe {
ndarray.create_data(generator, ctx);
}
ndarray
.foreach(generator, ctx, |generator, ctx, _, nditer| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
// and this loop would not execute.
let indices = nditer.get_indices();
let row_i = unsafe {
indices.get_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), None)
};
let col_i = unsafe {
indices.get_typed_unchecked(
ctx,
generator,
&self.llvm_usize.const_int(1, false),
None,
)
};
let be_one = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
ctx.builder.build_int_add(row_i, offset, "").unwrap(),
col_i,
"",
)
.unwrap();
let value = ctx.builder.build_select(be_one, ndone, ndzero, "value").unwrap();
let p = nditer.get_pointer(ctx);
ctx.builder.build_store(p, value).unwrap();
Ok(())
})
.unwrap();
ndarray
}
/// Create an ndarray like
/// [`np.identity`](https://numpy.org/doc/stable/reference/generated/numpy.identity.html).
pub fn construct_numpy_identity<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
let offset = self.llvm_usize.const_zero();
self.construct_numpy_eye(generator, ctx, dtype, size, size, offset, name)
}
} }