[core] codegen/ndarray: Reimplement np_{eye,identity}
Based on fa047d50
: core/ndstrides: implement np_identity() and np_eye()
This commit is contained in:
parent
1eb462a5c2
commit
6bd81ce1ac
@ -18,10 +18,7 @@ use super::{
|
|||||||
llvm_intrinsics::{self, call_memcpy_generic},
|
llvm_intrinsics::{self, call_memcpy_generic},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||||
types::ndarray::{
|
types::ndarray::{factory::ndarray_zero_value, NDArrayType},
|
||||||
factory::{ndarray_one_value, ndarray_zero_value},
|
|
||||||
NDArrayType,
|
|
||||||
},
|
|
||||||
values::{
|
values::{
|
||||||
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
|
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue,
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue,
|
||||||
@ -406,55 +403,6 @@ where
|
|||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
|
|
||||||
///
|
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
|
||||||
fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
nrows: IntValue<'ctx>,
|
|
||||||
ncols: IntValue<'ctx>,
|
|
||||||
offset: IntValue<'ctx>,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap();
|
|
||||||
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap();
|
|
||||||
|
|
||||||
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?;
|
|
||||||
|
|
||||||
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| {
|
|
||||||
let (row, col) = unsafe {
|
|
||||||
(
|
|
||||||
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None),
|
|
||||||
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None),
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
let col_with_offset = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_add(
|
|
||||||
col,
|
|
||||||
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let is_on_diag =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap();
|
|
||||||
|
|
||||||
let zero = ndarray_zero_value(generator, ctx, elem_ty);
|
|
||||||
let one = ndarray_one_value(generator, ctx, elem_ty);
|
|
||||||
|
|
||||||
let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap();
|
|
||||||
|
|
||||||
Ok(value)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(ndarray)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copies a slice of an [`NDArrayValue`] to another.
|
/// Copies a slice of an [`NDArrayValue`] to another.
|
||||||
///
|
///
|
||||||
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
|
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
|
||||||
@ -1304,15 +1252,27 @@ pub fn gen_ndarray_eye<'ctx>(
|
|||||||
))
|
))
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
call_ndarray_eye_impl(
|
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
generator,
|
|
||||||
context,
|
let llvm_usize = generator.get_size_type(context.ctx);
|
||||||
context.primitives.float,
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||||
nrows_arg.into_int_value(),
|
|
||||||
ncols_arg.into_int_value(),
|
let nrows = context
|
||||||
offset_arg.into_int_value(),
|
.builder
|
||||||
)
|
.build_int_s_extend_or_bit_cast(nrows_arg.into_int_value(), llvm_usize, "")
|
||||||
.map(NDArrayValue::into)
|
.unwrap();
|
||||||
|
let ncols = context
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend_or_bit_cast(ncols_arg.into_int_value(), llvm_usize, "")
|
||||||
|
.unwrap();
|
||||||
|
let offset = context
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2))
|
||||||
|
.construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None);
|
||||||
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.identity`.
|
/// Generates LLVM IR for `ndarray.identity`.
|
||||||
@ -1326,20 +1286,21 @@ pub fn gen_ndarray_identity<'ctx>(
|
|||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
assert_eq!(args.len(), 1);
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(context.ctx);
|
|
||||||
|
|
||||||
let n_ty = fun.0.args[0].ty;
|
let n_ty = fun.0.args[0].ty;
|
||||||
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
|
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
|
||||||
|
|
||||||
call_ndarray_eye_impl(
|
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
generator,
|
|
||||||
context,
|
let llvm_usize = generator.get_size_type(context.ctx);
|
||||||
context.primitives.float,
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||||
n_arg.into_int_value(),
|
|
||||||
n_arg.into_int_value(),
|
let n = context
|
||||||
llvm_usize.const_zero(),
|
.builder
|
||||||
)
|
.build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "")
|
||||||
.map(NDArrayValue::into)
|
.unwrap();
|
||||||
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2))
|
||||||
|
.construct_numpy_identity(generator, context, dtype, n, None);
|
||||||
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.copy`.
|
/// Generates LLVM IR for `ndarray.copy`.
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
use inkwell::values::{BasicValueEnum, IntValue};
|
use inkwell::{
|
||||||
|
values::{BasicValueEnum, IntValue},
|
||||||
|
IntPredicate,
|
||||||
|
};
|
||||||
|
|
||||||
use super::NDArrayType;
|
use super::NDArrayType;
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -36,7 +39,7 @@ pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Get the one value in `np.ones()` of a `dtype`.
|
/// Get the one value in `np.ones()` of a `dtype`.
|
||||||
pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
dtype: Type,
|
dtype: Type,
|
||||||
@ -139,4 +142,89 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
||||||
self.construct_numpy_full(generator, ctx, shape, fill_value, name)
|
self.construct_numpy_full(generator, ctx, shape, fill_value, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create an ndarray like `np.eye`.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn construct_numpy_eye<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
nrows: IntValue<'ctx>,
|
||||||
|
ncols: IntValue<'ctx>,
|
||||||
|
offset: IntValue<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
assert_eq!(
|
||||||
|
ctx.get_llvm_type(generator, dtype),
|
||||||
|
self.dtype,
|
||||||
|
"Expected LLVM dtype={} but got {}",
|
||||||
|
self.dtype.print_to_string(),
|
||||||
|
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||||
|
);
|
||||||
|
assert_eq!(nrows.get_type(), self.llvm_usize);
|
||||||
|
assert_eq!(ncols.get_type(), self.llvm_usize);
|
||||||
|
assert_eq!(offset.get_type(), self.llvm_usize);
|
||||||
|
|
||||||
|
let ndzero = ndarray_zero_value(generator, ctx, dtype);
|
||||||
|
let ndone = ndarray_one_value(generator, ctx, dtype);
|
||||||
|
|
||||||
|
let ndarray = self.construct_dyn_shape(generator, ctx, &[nrows, ncols], name);
|
||||||
|
|
||||||
|
// Create data and make the matrix like look np.eye()
|
||||||
|
unsafe {
|
||||||
|
ndarray.create_data(generator, ctx);
|
||||||
|
}
|
||||||
|
ndarray
|
||||||
|
.foreach(generator, ctx, |generator, ctx, _, nditer| {
|
||||||
|
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
||||||
|
// and this loop would not execute.
|
||||||
|
|
||||||
|
let indices = nditer.get_indices();
|
||||||
|
|
||||||
|
let row_i = unsafe {
|
||||||
|
indices.get_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), None)
|
||||||
|
};
|
||||||
|
let col_i = unsafe {
|
||||||
|
indices.get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&self.llvm_usize.const_int(1, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let be_one = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
ctx.builder.build_int_add(row_i, offset, "").unwrap(),
|
||||||
|
col_i,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let value = ctx.builder.build_select(be_one, ndone, ndzero, "value").unwrap();
|
||||||
|
|
||||||
|
let p = nditer.get_pointer(ctx);
|
||||||
|
ctx.builder.build_store(p, value).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an ndarray like `np.identity`.
|
||||||
|
pub fn construct_numpy_identity<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
let offset = self.llvm_usize.const_zero();
|
||||||
|
self.construct_numpy_eye(generator, ctx, dtype, size, size, offset, name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user