forked from M-Labs/nac3
[core] codegen/ndarray: Reimplement np_{eye,identity}
Based on fa047d50: core/ndstrides: implement np_identity() and np_eye()
This commit is contained in:
parent
fadadd7505
commit
acb437919d
@ -18,10 +18,7 @@ use super::{
|
||||
llvm_intrinsics::{self, call_memcpy_generic},
|
||||
macros::codegen_unreachable,
|
||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||
types::ndarray::{
|
||||
factory::{ndarray_one_value, ndarray_zero_value},
|
||||
NDArrayType,
|
||||
},
|
||||
types::ndarray::{factory::ndarray_zero_value, NDArrayType},
|
||||
values::{
|
||||
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue,
|
||||
@ -406,55 +403,6 @@ where
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
nrows: IntValue<'ctx>,
|
||||
ncols: IntValue<'ctx>,
|
||||
offset: IntValue<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap();
|
||||
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap();
|
||||
|
||||
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?;
|
||||
|
||||
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| {
|
||||
let (row, col) = unsafe {
|
||||
(
|
||||
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None),
|
||||
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None),
|
||||
)
|
||||
};
|
||||
|
||||
let col_with_offset = ctx
|
||||
.builder
|
||||
.build_int_add(
|
||||
col,
|
||||
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let is_on_diag =
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap();
|
||||
|
||||
let zero = ndarray_zero_value(generator, ctx, elem_ty);
|
||||
let one = ndarray_one_value(generator, ctx, elem_ty);
|
||||
|
||||
let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap();
|
||||
|
||||
Ok(value)
|
||||
})?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Copies a slice of an [`NDArrayValue`] to another.
|
||||
///
|
||||
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
|
||||
@ -1304,15 +1252,27 @@ pub fn gen_ndarray_eye<'ctx>(
|
||||
))
|
||||
}?;
|
||||
|
||||
call_ndarray_eye_impl(
|
||||
generator,
|
||||
context,
|
||||
context.primitives.float,
|
||||
nrows_arg.into_int_value(),
|
||||
ncols_arg.into_int_value(),
|
||||
offset_arg.into_int_value(),
|
||||
)
|
||||
.map(NDArrayValue::into)
|
||||
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||
|
||||
let llvm_usize = generator.get_size_type(context.ctx);
|
||||
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||
|
||||
let nrows = context
|
||||
.builder
|
||||
.build_int_s_extend_or_bit_cast(nrows_arg.into_int_value(), llvm_usize, "")
|
||||
.unwrap();
|
||||
let ncols = context
|
||||
.builder
|
||||
.build_int_s_extend_or_bit_cast(ncols_arg.into_int_value(), llvm_usize, "")
|
||||
.unwrap();
|
||||
let offset = context
|
||||
.builder
|
||||
.build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "")
|
||||
.unwrap();
|
||||
|
||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2))
|
||||
.construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None);
|
||||
Ok(ndarray.as_base_value())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.identity`.
|
||||
@ -1326,20 +1286,21 @@ pub fn gen_ndarray_identity<'ctx>(
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
let llvm_usize = generator.get_size_type(context.ctx);
|
||||
|
||||
let n_ty = fun.0.args[0].ty;
|
||||
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
|
||||
|
||||
call_ndarray_eye_impl(
|
||||
generator,
|
||||
context,
|
||||
context.primitives.float,
|
||||
n_arg.into_int_value(),
|
||||
n_arg.into_int_value(),
|
||||
llvm_usize.const_zero(),
|
||||
)
|
||||
.map(NDArrayValue::into)
|
||||
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||
|
||||
let llvm_usize = generator.get_size_type(context.ctx);
|
||||
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||
|
||||
let n = context
|
||||
.builder
|
||||
.build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "")
|
||||
.unwrap();
|
||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2))
|
||||
.construct_numpy_identity(generator, context, dtype, n, None);
|
||||
Ok(ndarray.as_base_value())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.copy`.
|
||||
|
@ -1,4 +1,7 @@
|
||||
use inkwell::values::{BasicValueEnum, IntValue};
|
||||
use inkwell::{
|
||||
values::{BasicValueEnum, IntValue},
|
||||
IntPredicate,
|
||||
};
|
||||
|
||||
use super::NDArrayType;
|
||||
use crate::{
|
||||
@ -36,7 +39,7 @@ pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}
|
||||
|
||||
/// Get the one value in `np.ones()` of a `dtype`.
|
||||
pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
@ -143,4 +146,91 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
||||
self.construct_numpy_full(generator, ctx, shape, fill_value, name)
|
||||
}
|
||||
|
||||
/// Create an ndarray like
|
||||
/// [`np.eye`](https://numpy.org/doc/stable/reference/generated/numpy.eye.html).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn construct_numpy_eye<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
nrows: IntValue<'ctx>,
|
||||
ncols: IntValue<'ctx>,
|
||||
offset: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert_eq!(
|
||||
ctx.get_llvm_type(generator, dtype),
|
||||
self.dtype,
|
||||
"Expected LLVM dtype={} but got {}",
|
||||
self.dtype.print_to_string(),
|
||||
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||
);
|
||||
assert_eq!(nrows.get_type(), self.llvm_usize);
|
||||
assert_eq!(ncols.get_type(), self.llvm_usize);
|
||||
assert_eq!(offset.get_type(), self.llvm_usize);
|
||||
|
||||
let ndzero = ndarray_zero_value(generator, ctx, dtype);
|
||||
let ndone = ndarray_one_value(generator, ctx, dtype);
|
||||
|
||||
let ndarray = self.construct_dyn_shape(generator, ctx, &[nrows, ncols], name);
|
||||
|
||||
// Create data and make the matrix like look np.eye()
|
||||
unsafe {
|
||||
ndarray.create_data(generator, ctx);
|
||||
}
|
||||
ndarray
|
||||
.foreach(generator, ctx, |generator, ctx, _, nditer| {
|
||||
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
||||
// and this loop would not execute.
|
||||
|
||||
let indices = nditer.get_indices();
|
||||
|
||||
let row_i = unsafe {
|
||||
indices.get_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), None)
|
||||
};
|
||||
let col_i = unsafe {
|
||||
indices.get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&self.llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
let be_one = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
ctx.builder.build_int_add(row_i, offset, "").unwrap(),
|
||||
col_i,
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let value = ctx.builder.build_select(be_one, ndone, ndzero, "value").unwrap();
|
||||
|
||||
let p = nditer.get_pointer(ctx);
|
||||
ctx.builder.build_store(p, value).unwrap();
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an ndarray like
|
||||
/// [`np.identity`](https://numpy.org/doc/stable/reference/generated/numpy.identity.html).
|
||||
pub fn construct_numpy_identity<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
size: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let offset = self.llvm_usize.const_zero();
|
||||
self.construct_numpy_eye(generator, ctx, dtype, size, size, offset, name)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user