forked from M-Labs/nac3
core/ndstrides: implement np_identity() and np_eye()
This commit is contained in:
parent
0ec4d13735
commit
a004703ec2
|
@ -1905,15 +1905,23 @@ pub fn gen_ndarray_eye<'ctx>(
|
||||||
))
|
))
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
call_ndarray_eye_impl(
|
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
generator,
|
|
||||||
context,
|
let nrows = Int(Int32)
|
||||||
context.primitives.float,
|
.check_value(generator, context.ctx, nrows_arg)
|
||||||
nrows_arg.into_int_value(),
|
.unwrap()
|
||||||
ncols_arg.into_int_value(),
|
.s_extend_or_bit_cast(generator, context, SizeT);
|
||||||
offset_arg.into_int_value(),
|
let ncols = Int(Int32)
|
||||||
)
|
.check_value(generator, context.ctx, ncols_arg)
|
||||||
.map(NDArrayValue::into)
|
.unwrap()
|
||||||
|
.s_extend_or_bit_cast(generator, context, SizeT);
|
||||||
|
let offset = Int(Int32)
|
||||||
|
.check_value(generator, context.ctx, offset_arg)
|
||||||
|
.unwrap()
|
||||||
|
.s_extend_or_bit_cast(generator, context, SizeT);
|
||||||
|
|
||||||
|
let ndarray = NDArrayObject::make_np_eye(generator, context, dtype, nrows, ncols, offset);
|
||||||
|
Ok(ndarray.instance.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.identity`.
|
/// Generates LLVM IR for `ndarray.identity`.
|
||||||
|
@ -1927,20 +1935,15 @@ pub fn gen_ndarray_identity<'ctx>(
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
assert_eq!(args.len(), 1);
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(context.ctx);
|
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
|
|
||||||
let n_ty = fun.0.args[0].ty;
|
let n_ty = fun.0.args[0].ty;
|
||||||
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
|
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
|
||||||
|
|
||||||
call_ndarray_eye_impl(
|
let n = Int(Int32).check_value(generator, context.ctx, n_arg).unwrap();
|
||||||
generator,
|
let n = n.s_extend_or_bit_cast(generator, context, SizeT);
|
||||||
context,
|
let ndarray = NDArrayObject::make_np_identity(generator, context, dtype, n);
|
||||||
context.primitives.float,
|
Ok(ndarray.instance.value)
|
||||||
n_arg.into_int_value(),
|
|
||||||
n_arg.into_int_value(),
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
)
|
|
||||||
.map(NDArrayValue::into)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.copy`.
|
/// Generates LLVM IR for `ndarray.copy`.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use inkwell::values::BasicValueEnum;
|
use inkwell::{values::BasicValueEnum, IntPredicate};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
|
@ -123,4 +123,54 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
||||||
NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create an ndarray like `np.eye`.
|
||||||
|
pub fn make_np_eye<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
nrows: Instance<'ctx, Int<SizeT>>,
|
||||||
|
ncols: Instance<'ctx, Int<SizeT>>,
|
||||||
|
offset: Instance<'ctx, Int<SizeT>>,
|
||||||
|
) -> Self {
|
||||||
|
let ndzero = ndarray_zero_value(generator, ctx, dtype);
|
||||||
|
let ndone = ndarray_one_value(generator, ctx, dtype);
|
||||||
|
|
||||||
|
let ndarray = NDArrayObject::alloca_dynamic_shape(generator, ctx, dtype, &[nrows, ncols]);
|
||||||
|
|
||||||
|
// Create data and make the matrix like look np.eye()
|
||||||
|
ndarray.create_data(generator, ctx);
|
||||||
|
ndarray
|
||||||
|
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||||
|
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
||||||
|
// and this loop would not execute.
|
||||||
|
|
||||||
|
// Load up `row_i` and `col_i` from indices.
|
||||||
|
let row_i = nditer.get_indices().get_index_const(generator, ctx, 0);
|
||||||
|
let col_i = nditer.get_indices().get_index_const(generator, ctx, 1);
|
||||||
|
|
||||||
|
let be_one = row_i.add(ctx, offset).compare(ctx, IntPredicate::EQ, col_i);
|
||||||
|
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
|
||||||
|
|
||||||
|
let p = nditer.get_pointer(generator, ctx);
|
||||||
|
ctx.builder.build_store(p, value).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an ndarray like `np.identity`.
|
||||||
|
pub fn make_np_identity<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
size: Instance<'ctx, Int<SizeT>>,
|
||||||
|
) -> Self {
|
||||||
|
// Convenient implementation
|
||||||
|
let offset = Int(SizeT).const_0(generator, ctx.ctx);
|
||||||
|
NDArrayObject::make_np_eye(generator, ctx, dtype, size, size, offset)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue