From f78a60a644a129e545ec742b1cec711c3095be41 Mon Sep 17 00:00:00 2001 From: lyken Date: Fri, 26 Jul 2024 15:26:39 +0800 Subject: [PATCH] core/codegen: add ArrayWriter & parse_input_shape_arg --- nac3core/src/codegen/mod.rs | 1 + nac3core/src/codegen/util/array_writer.rs | 17 +++ nac3core/src/codegen/util/mod.rs | 2 + nac3core/src/codegen/util/shape.rs | 144 ++++++++++++++++++++++ 4 files changed, 164 insertions(+) create mode 100644 nac3core/src/codegen/util/array_writer.rs create mode 100644 nac3core/src/codegen/util/mod.rs create mode 100644 nac3core/src/codegen/util/shape.rs diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 7ce06b35..09d65494 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -47,6 +47,7 @@ pub mod model; pub mod numpy; pub mod stmt; pub mod structs; +pub mod util; #[cfg(test)] mod test; diff --git a/nac3core/src/codegen/util/array_writer.rs b/nac3core/src/codegen/util/array_writer.rs new file mode 100644 index 00000000..ab88b7d5 --- /dev/null +++ b/nac3core/src/codegen/util/array_writer.rs @@ -0,0 +1,17 @@ +use inkwell::{types::IntType, values::IntValue}; + +use crate::codegen::{model::*, CodeGenContext, CodeGenerator}; + +pub type ArrayWriterWrite<'ctx, G, N, E> = Box< + dyn Fn(&mut G, &mut CodeGenContext<'ctx, '_>, &ArraySlice<'ctx, N, E>) -> Result<(), String> + + 'ctx, +>; + +// TODO: Document +pub struct ArrayWriter<'ctx, G: CodeGenerator + ?Sized, N, E: Model<'ctx>> +where + N: Model<'ctx, Value = IntValue<'ctx>, Type = IntType<'ctx>>, +{ + pub count: Instance<'ctx, N>, + pub write: ArrayWriterWrite<'ctx, G, N, E>, +} diff --git a/nac3core/src/codegen/util/mod.rs b/nac3core/src/codegen/util/mod.rs new file mode 100644 index 00000000..b3ffc4ec --- /dev/null +++ b/nac3core/src/codegen/util/mod.rs @@ -0,0 +1,2 @@ +pub mod array_writer; +pub mod shape; diff --git a/nac3core/src/codegen/util/shape.rs b/nac3core/src/codegen/util/shape.rs new file mode 100644 index 00000000..b37e13d3 --- /dev/null +++ b/nac3core/src/codegen/util/shape.rs @@ -0,0 +1,144 @@ +use inkwell::values::BasicValueEnum; + +use crate::{ + codegen::{ + classes::{ListValue, UntypedArrayLikeAccessor}, + model::*, + stmt::gen_for_callback_incrementing, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +use super::array_writer::ArrayWriter; + +/// TODO: UPDATE DOCUMENTATION +/// LLVM-typed implementation for generating a [`ArrayWriter`] that sets a list of ints. +/// +/// * `elem_ty` - The element type of the `NDArray`. +/// * `shape` - The `shape` parameter used to construct the `NDArray`. +/// +/// ### Notes on `shape` +/// +/// Just like numpy, the `shape` argument can be: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` +/// +/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to +/// learn how `shape` gets from being a Python user expression to here. +pub fn parse_input_shape_arg<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: BasicValueEnum<'ctx>, + shape_ty: Type, +) -> ArrayWriter<'ctx, G, SizeTModel<'ctx>, SizeTModel<'ctx>> +where + G: CodeGenerator + ?Sized, +{ + let sizet = generator.get_sizet(ctx.ctx); + + match &*ctx.unifier.get_ty(shape_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of ints; e.g., `np.empty([600, 800, 3])` + + // A list has to be a PointerValue + let shape_list = ListValue::from_ptr_val(shape.into_pointer_value(), sizet.0, None); + + // Create `ArrayWriter` + let ndims = + sizet.review_value(ctx.ctx, shape_list.load_size(ctx, Some("count"))).unwrap(); + ArrayWriter { + count: ndims, + write: Box::new(move |generator, ctx, dst_array| { + // Basically iterate through the list and write to `dst_slice` accordingly + let init_val = sizet.constant(ctx.ctx, 0).value; + let max_val = (ndims.value, false); + let incr_val = sizet.constant(ctx.ctx, 1).value; + gen_for_callback_incrementing( + generator, + ctx, + init_val, + max_val, + |generator, ctx, _hooks, axis| { + let axis = sizet.review_value(ctx.ctx, axis).unwrap(); + + // TODO: Remove ProxyValue ListValue + + // Get the dimension at `axis` + let dim: Int<'ctx> = shape_list + .data() + .get(ctx, generator, &axis.value, None) + .into_int_value() + .into(); + + let dim = dim.s_extend_or_bit_cast(ctx, sizet, "dim_casted"); + dst_array.ix(generator, ctx, axis, "dim").store(ctx, dim); + Ok(()) + }, + incr_val, + ) + }), + } + } + TypeEnum::TTuple { ty: tuple_types } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + // Get the length/size of the tuple, which also happens to be the value of `ndims`. + let ndims = tuple_types.len(); + + // A tuple has to be a StructValue + // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. + let shape_tuple = shape.into_struct_value(); + + ArrayWriter { + count: sizet.constant(ctx.ctx, ndims as u64), + write: Box::new(move |generator, ctx, dst_array| { + for axis in 0..ndims { + // Get the dimension at `axis` + let dim: Int<'ctx> = ctx + .builder + .build_extract_value( + shape_tuple, + axis as u32, + format!("dim{axis}").as_str(), + ) + .unwrap() + .into_int_value() + .into(); + + let dim = dim.s_extend_or_bit_cast(ctx, sizet, "dim_casted"); + dst_array + .ix(generator, ctx, sizet.constant(ctx.ctx, axis as u64), "dim") + .store(ctx, dim); + } + Ok(()) + }), + } + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + + // The value has to be an integer + let shape_int: Int<'ctx> = shape.into_int_value().into(); + + ArrayWriter { + count: sizet.constant(ctx.ctx, 1), + write: Box::new(move |generator, ctx, dst_array| { + // Cast `shape_int` to SizeT + let dim = shape_int.s_extend_or_bit_cast(ctx, sizet, "dim_casted"); + + // Set shape[0] = shape_int + dst_array.ix(generator, ctx, sizet.constant(ctx.ctx, 0), "dim").store(ctx, dim); + + Ok(()) + }), + } + } + _ => panic!("parse_input_shape_arg encountered unknown type"), + } +}