From acb437919d951e0c14a468a55d2908282cc16044 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 18:01:12 +0800 Subject: [PATCH] [core] codegen/ndarray: Reimplement np_{eye,identity} Based on fa047d50: core/ndstrides: implement np_identity() and np_eye() --- nac3core/src/codegen/numpy.rs | 107 ++++++------------ nac3core/src/codegen/types/ndarray/factory.rs | 94 ++++++++++++++- 2 files changed, 126 insertions(+), 75 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 09d848dc6..703d03ec2 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -18,10 +18,7 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::ndarray::{ - factory::{ndarray_one_value, ndarray_zero_value}, - NDArrayType, - }, + types::ndarray::{factory::ndarray_zero_value, NDArrayType}, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, @@ -406,55 +403,6 @@ where Ok(res) } -/// LLVM-typed implementation for generating the implementation for `ndarray.eye`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - nrows: IntValue<'ctx>, - ncols: IntValue<'ctx>, - offset: IntValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap(); - let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap(); - - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?; - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| { - let (row, col) = unsafe { - ( - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None), - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None), - ) - }; - - let col_with_offset = ctx - .builder - .build_int_add( - col, - ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(), - "", - ) - .unwrap(); - let is_on_diag = - ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap(); - - let zero = ndarray_zero_value(generator, ctx, elem_ty); - let one = ndarray_one_value(generator, ctx, elem_ty); - - let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap(); - - Ok(value) - })?; - - Ok(ndarray) -} - /// Copies a slice of an [`NDArrayValue`] to another. /// /// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` @@ -1304,15 +1252,27 @@ pub fn gen_ndarray_eye<'ctx>( )) }?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - nrows_arg.into_int_value(), - ncols_arg.into_int_value(), - offset_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let llvm_usize = generator.get_size_type(context.ctx); + let llvm_dtype = context.get_llvm_type(generator, dtype); + + let nrows = context + .builder + .build_int_s_extend_or_bit_cast(nrows_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let ncols = context + .builder + .build_int_s_extend_or_bit_cast(ncols_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let offset = context + .builder + .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") + .unwrap(); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.identity`. @@ -1326,20 +1286,21 @@ pub fn gen_ndarray_identity<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); - let n_ty = fun.0.args[0].ty; let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - n_arg.into_int_value(), - n_arg.into_int_value(), - llvm_usize.const_zero(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let llvm_usize = generator.get_size_type(context.ctx); + let llvm_dtype = context.get_llvm_type(generator, dtype); + + let n = context + .builder + .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(2)) + .construct_numpy_identity(generator, context, dtype, n, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.copy`. diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs index 13aae8cd5..300167f7e 100644 --- a/nac3core/src/codegen/types/ndarray/factory.rs +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -1,4 +1,7 @@ -use inkwell::values::{BasicValueEnum, IntValue}; +use inkwell::{ + values::{BasicValueEnum, IntValue}, + IntPredicate, +}; use super::NDArrayType; use crate::{ @@ -36,7 +39,7 @@ pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( } /// Get the one value in `np.ones()` of a `dtype`. -pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( +fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, dtype: Type, @@ -143,4 +146,91 @@ impl<'ctx> NDArrayType<'ctx> { let fill_value = ndarray_one_value(generator, ctx, dtype); self.construct_numpy_full(generator, ctx, shape, fill_value, name) } + + /// Create an ndarray like + /// [`np.eye`](https://numpy.org/doc/stable/reference/generated/numpy.eye.html). + #[allow(clippy::too_many_arguments)] + pub fn construct_numpy_eye( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + nrows: IntValue<'ctx>, + ncols: IntValue<'ctx>, + offset: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + assert_eq!(nrows.get_type(), self.llvm_usize); + assert_eq!(ncols.get_type(), self.llvm_usize); + assert_eq!(offset.get_type(), self.llvm_usize); + + let ndzero = ndarray_zero_value(generator, ctx, dtype); + let ndone = ndarray_one_value(generator, ctx, dtype); + + let ndarray = self.construct_dyn_shape(generator, ctx, &[nrows, ncols], name); + + // Create data and make the matrix like look np.eye() + unsafe { + ndarray.create_data(generator, ctx); + } + ndarray + .foreach(generator, ctx, |generator, ctx, _, nditer| { + // NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero + // and this loop would not execute. + + let indices = nditer.get_indices(); + + let row_i = unsafe { + indices.get_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), None) + }; + let col_i = unsafe { + indices.get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(1, false), + None, + ) + }; + + let be_one = ctx + .builder + .build_int_compare( + IntPredicate::EQ, + ctx.builder.build_int_add(row_i, offset, "").unwrap(), + col_i, + "", + ) + .unwrap(); + let value = ctx.builder.build_select(be_one, ndone, ndzero, "value").unwrap(); + + let p = nditer.get_pointer(ctx); + ctx.builder.build_store(p, value).unwrap(); + + Ok(()) + }) + .unwrap(); + + ndarray + } + + /// Create an ndarray like + /// [`np.identity`](https://numpy.org/doc/stable/reference/generated/numpy.identity.html). + pub fn construct_numpy_identity( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let offset = self.llvm_usize.const_zero(); + self.construct_numpy_eye(generator, ctx, dtype, size, size, offset, name) + } }