diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 304cf587..918efcaa 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1906,15 +1906,23 @@ pub fn gen_ndarray_eye<'ctx>( )) }?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - nrows_arg.into_int_value(), - ncols_arg.into_int_value(), - offset_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let nrows = Int(Int32) + .check_value(generator, context.ctx, nrows_arg) + .unwrap() + .s_extend_or_bit_cast(generator, context, SizeT); + let ncols = Int(Int32) + .check_value(generator, context.ctx, ncols_arg) + .unwrap() + .s_extend_or_bit_cast(generator, context, SizeT); + let offset = Int(Int32) + .check_value(generator, context.ctx, offset_arg) + .unwrap() + .s_extend_or_bit_cast(generator, context, SizeT); + + let ndarray = NDArrayObject::make_np_eye(generator, context, dtype, nrows, ncols, offset); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.identity`. @@ -1928,20 +1936,15 @@ pub fn gen_ndarray_identity<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); let n_ty = fun.0.args[0].ty; let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - n_arg.into_int_value(), - n_arg.into_int_value(), - llvm_usize.const_zero(), - ) - .map(NDArrayValue::into) + let n = Int(Int32).check_value(generator, context.ctx, n_arg).unwrap(); + let n = n.s_extend_or_bit_cast(generator, context, SizeT); + let ndarray = NDArrayObject::make_np_identity(generator, context, dtype, n); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.copy`. diff --git a/nac3core/src/codegen/object/ndarray/factory.rs b/nac3core/src/codegen/object/ndarray/factory.rs index 9d32953b..c8d5517f 100644 --- a/nac3core/src/codegen/object/ndarray/factory.rs +++ b/nac3core/src/codegen/object/ndarray/factory.rs @@ -1,4 +1,4 @@ -use inkwell::values::BasicValueEnum; +use inkwell::{values::BasicValueEnum, IntPredicate}; use crate::{ codegen::{ @@ -123,4 +123,54 @@ impl<'ctx> NDArrayObject<'ctx> { let fill_value = ndarray_one_value(generator, ctx, dtype); NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value) } + + /// Create an ndarray like `np.eye`. + pub fn make_np_eye( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + nrows: Instance<'ctx, Int>, + ncols: Instance<'ctx, Int>, + offset: Instance<'ctx, Int>, + ) -> Self { + let ndzero = ndarray_zero_value(generator, ctx, dtype); + let ndone = ndarray_one_value(generator, ctx, dtype); + + let ndarray = NDArrayObject::alloca_dynamic_shape(generator, ctx, dtype, &[nrows, ncols]); + + // Create data and make the matrix like look np.eye() + ndarray.create_data(generator, ctx); + ndarray + .foreach(generator, ctx, |generator, ctx, _hooks, nditer| { + // NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero + // and this loop would not execute. + + // Load up `row_i` and `col_i` from indices. + let row_i = nditer.get_indices().get_index_const(generator, ctx, 0); + let col_i = nditer.get_indices().get_index_const(generator, ctx, 1); + + let be_one = row_i.add(ctx, offset).compare(ctx, IntPredicate::EQ, col_i); + let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap(); + + let p = nditer.get_pointer(generator, ctx); + ctx.builder.build_store(p, value).unwrap(); + + Ok(()) + }) + .unwrap(); + + ndarray + } + + /// Create an ndarray like `np.identity`. + pub fn make_np_identity( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + size: Instance<'ctx, Int>, + ) -> Self { + // Convenient implementation + let offset = Int(SizeT).const_0(generator, ctx.ctx); + NDArrayObject::make_np_eye(generator, ctx, dtype, size, size, offset) + } }