From e5fe86cc933835b11c8f7abcd8076e59bca05a43 Mon Sep 17 00:00:00 2001 From: lyken Date: Sun, 28 Jul 2024 15:43:38 +0800 Subject: [PATCH] core/ndstrides: add ArrayWriter & make_shape_writer --- nac3core/src/codegen/mod.rs | 1 + nac3core/src/codegen/model/ptr.rs | 1 + nac3core/src/codegen/stmt.rs | 10 +- nac3core/src/codegen/util/array_writer.rs | 17 +++ nac3core/src/codegen/util/control.rs | 42 +++++++ nac3core/src/codegen/util/mod.rs | 3 + nac3core/src/codegen/util/shape.rs | 127 ++++++++++++++++++++++ 7 files changed, 198 insertions(+), 3 deletions(-) create mode 100644 nac3core/src/codegen/util/array_writer.rs create mode 100644 nac3core/src/codegen/util/control.rs create mode 100644 nac3core/src/codegen/util/mod.rs create mode 100644 nac3core/src/codegen/util/shape.rs diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 244b8f8d..c78710a8 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -47,6 +47,7 @@ pub mod model; pub mod numpy; pub mod stmt; pub mod structure; +pub mod util; #[cfg(test)] mod test; diff --git a/nac3core/src/codegen/model/ptr.rs b/nac3core/src/codegen/model/ptr.rs index 6fe5deef..d4fe8f95 100644 --- a/nac3core/src/codegen/model/ptr.rs +++ b/nac3core/src/codegen/model/ptr.rs @@ -75,6 +75,7 @@ impl PtrModel { impl<'ctx, Element: Model> Ptr<'ctx, Element> { /// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`]. + #[must_use] pub fn offset( &self, tyctx: TypeContext<'ctx>, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 88f80c97..f7ab5ee7 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -508,8 +508,12 @@ where I: Clone, InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result, String>, - BodyFn: - FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>, + BodyFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + I, + ) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, { let label = label.unwrap_or("for"); @@ -589,7 +593,7 @@ where BodyFn: FnOnce( &mut G, &mut CodeGenContext<'ctx, 'a>, - BreakContinueHooks, + BreakContinueHooks<'ctx>, IntValue<'ctx>, ) -> Result<(), String>, { diff --git a/nac3core/src/codegen/util/array_writer.rs b/nac3core/src/codegen/util/array_writer.rs new file mode 100644 index 00000000..616d06b9 --- /dev/null +++ b/nac3core/src/codegen/util/array_writer.rs @@ -0,0 +1,17 @@ +use crate::codegen::{model::*, CodeGenContext, CodeGenerator}; + +/// A closure containing details on how to write to/initialize an array. +#[allow(clippy::type_complexity)] +pub struct ArrayWriter<'ctx, G: CodeGenerator + ?Sized, Len: IntKind, Item: Model> { + /// Number of items to write + pub len: Int<'ctx, Len>, + /// Implementation to write to an array given its base pointer. + pub write: Box< + dyn Fn( + &mut G, + &mut CodeGenContext<'ctx, '_>, + Ptr<'ctx, Item>, // Base pointer + ) -> Result<(), String> + + 'ctx, + >, +} diff --git a/nac3core/src/codegen/util/control.rs b/nac3core/src/codegen/util/control.rs new file mode 100644 index 00000000..03c662da --- /dev/null +++ b/nac3core/src/codegen/util/control.rs @@ -0,0 +1,42 @@ +use crate::codegen::{ + model::*, + stmt::{gen_for_callback_incrementing, BreakContinueHooks}, + CodeGenContext, CodeGenerator, +}; + +// TODO: Document +// TODO: Rename function +/// Only allows positive steps +pub fn gen_model_for<'ctx, 'a, G, F, I>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + start: Int<'ctx, I>, + stop: Int<'ctx, I>, + step: Int<'ctx, I>, + body: F, +) -> Result<(), String> +where + G: CodeGenerator + ?Sized, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + Int<'ctx, I>, + ) -> Result<(), String>, + I: IntKind, +{ + let int_model = IntModel(I::default()); + + gen_for_callback_incrementing( + generator, + ctx, + None, + start.value, + (stop.value, false), + |g, ctx, hooks, i| { + let i = int_model.believe_value(i); + body(g, ctx, hooks, i) + }, + step.value, + ) +} diff --git a/nac3core/src/codegen/util/mod.rs b/nac3core/src/codegen/util/mod.rs new file mode 100644 index 00000000..10107074 --- /dev/null +++ b/nac3core/src/codegen/util/mod.rs @@ -0,0 +1,3 @@ +pub mod array_writer; +pub mod control; +pub mod shape; diff --git a/nac3core/src/codegen/util/shape.rs b/nac3core/src/codegen/util/shape.rs new file mode 100644 index 00000000..5a6862c4 --- /dev/null +++ b/nac3core/src/codegen/util/shape.rs @@ -0,0 +1,127 @@ +use inkwell::values::BasicValueEnum; + +use crate::{ + codegen::{ + classes::{ListValue, UntypedArrayLikeAccessor}, + model::*, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +use super::{array_writer::ArrayWriter, control::gen_model_for}; + +// TODO: Generalize to complex iterables under a common interface +/// Create an [`ArrayWriter`] from a NumPy-like `shape` argument input. +/// * `shape` - The `shape` parameter. +/// * `shape_ty` - The element type of the `NDArray`. +/// +/// The `shape` argument type may only be one of the following: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` +/// +/// The `int32` values will be sign-extended to `SizeT` +pub fn make_shape_writer<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: BasicValueEnum<'ctx>, + shape_ty: Type, +) -> ArrayWriter<'ctx, G, SizeT, IntModel> +where + G: CodeGenerator + ?Sized, +{ + let tyctx = generator.type_context(ctx.ctx); + let sizet_model = IntModel(SizeT); + + match &*ctx.unifier.get_ty(shape_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` + + // TODO: Remove ListValue with Model + + let shape = ListValue::from_ptr_val(shape.into_pointer_value(), tyctx.size_type, None); + let len = + sizet_model.check_value(tyctx, ctx.ctx, shape.load_size(ctx, Some("len"))).unwrap(); + + ArrayWriter { + len, + write: Box::new(move |generator, ctx, dst_array| { + gen_model_for( + generator, + ctx, + sizet_model.constant(tyctx, ctx.ctx, 0), + len, + sizet_model.constant(tyctx, ctx.ctx, 1), + |generator, ctx, _hooks, i| { + let dim = + shape.data().get(ctx, generator, &i.value, None).into_int_value(); + let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, dim, ""); + + dst_array.offset(tyctx, ctx, i.value, "pdim").store(ctx, dim); + Ok(()) + }, + ) + }), + } + } + TypeEnum::TTuple { ty: tuple_types } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + let ndims = tuple_types.len(); + + // A tuple has to be a StructValue + // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. + let shape = shape.into_struct_value(); + + ArrayWriter { + len: sizet_model.constant(tyctx, ctx.ctx, ndims as u64), + write: Box::new(move |_generator, ctx, dst_array| { + for axis in 0..ndims { + let dim = ctx + .builder + .build_extract_value(shape, axis as u32, format!("dim{axis}").as_str()) + .unwrap() + .into_int_value(); + let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, dim, ""); + + dst_array + .offset( + tyctx, + ctx, + sizet_model.constant(tyctx, ctx.ctx, axis as u64).value, + "pdim", + ) + .store(ctx, dim); + } + Ok(()) + }), + } + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + + // The value has to be an integer + let shape_int = shape.into_int_value(); + + ArrayWriter { + len: sizet_model.constant(tyctx, ctx.ctx, 1), + write: Box::new(move |_generator, ctx, dst_array| { + let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, shape_int, ""); + + // Set shape[0] = shape_int + dst_array + .offset(tyctx, ctx, sizet_model.constant(tyctx, ctx.ctx, 0).value, "pdim") + .store(ctx, dim); + + Ok(()) + }), + } + } + _ => panic!("encountered shape type"), + } +}