forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: add ArrayWriter & make_shape_writer

This commit is contained in:
lyken 2024-07-28 15:43:38 +08:00
parent fd3d02bff0
commit e5fe86cc93
7 changed files with 198 additions and 3 deletions

View File

@ -47,6 +47,7 @@ pub mod model;
pub mod numpy;
pub mod stmt;
pub mod structure;
pub mod util;
#[cfg(test)]
mod test;

View File

@ -75,6 +75,7 @@ impl<Element: Model> PtrModel<Element> {
impl<'ctx, Element: Model> Ptr<'ctx, Element> {
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`].
#[must_use]
pub fn offset(
&self,
tyctx: TypeContext<'ctx>,

View File

@ -508,8 +508,12 @@ where
I: Clone,
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn:
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>,
BodyFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
I,
) -> Result<(), String>,
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{
let label = label.unwrap_or("for");
@ -589,7 +593,7 @@ where
BodyFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks,
BreakContinueHooks<'ctx>,
IntValue<'ctx>,
) -> Result<(), String>,
{

View File

@ -0,0 +1,17 @@
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
/// A closure containing details on how to write to/initialize an array.
#[allow(clippy::type_complexity)]
pub struct ArrayWriter<'ctx, G: CodeGenerator + ?Sized, Len: IntKind, Item: Model> {
/// Number of items to write
pub len: Int<'ctx, Len>,
/// Implementation to write to an array given its base pointer.
pub write: Box<
dyn Fn(
&mut G,
&mut CodeGenContext<'ctx, '_>,
Ptr<'ctx, Item>, // Base pointer
) -> Result<(), String>
+ 'ctx,
>,
}

View File

@ -0,0 +1,42 @@
use crate::codegen::{
model::*,
stmt::{gen_for_callback_incrementing, BreakContinueHooks},
CodeGenContext, CodeGenerator,
};
// TODO: Document
// TODO: Rename function
/// Only allows positive steps
pub fn gen_model_for<'ctx, 'a, G, F, I>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
start: Int<'ctx, I>,
stop: Int<'ctx, I>,
step: Int<'ctx, I>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Int<'ctx, I>,
) -> Result<(), String>,
I: IntKind,
{
let int_model = IntModel(I::default());
gen_for_callback_incrementing(
generator,
ctx,
None,
start.value,
(stop.value, false),
|g, ctx, hooks, i| {
let i = int_model.believe_value(i);
body(g, ctx, hooks, i)
},
step.value,
)
}

View File

@ -0,0 +1,3 @@
pub mod array_writer;
pub mod control;
pub mod shape;

View File

@ -0,0 +1,127 @@
use inkwell::values::BasicValueEnum;
use crate::{
codegen::{
classes::{ListValue, UntypedArrayLikeAccessor},
model::*,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{Type, TypeEnum},
};
use super::{array_writer::ArrayWriter, control::gen_model_for};
// TODO: Generalize to complex iterables under a common interface
/// Create an [`ArrayWriter`] from a NumPy-like `shape` argument input.
/// * `shape` - The `shape` parameter.
/// * `shape_ty` - The element type of the `NDArray`.
///
/// The `shape` argument type may only be one of the following:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// The `int32` values will be sign-extended to `SizeT`
pub fn make_shape_writer<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
) -> ArrayWriter<'ctx, G, SizeT, IntModel<SizeT>>
where
G: CodeGenerator + ?Sized,
{
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
match &*ctx.unifier.get_ty(shape_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
// TODO: Remove ListValue with Model
let shape = ListValue::from_ptr_val(shape.into_pointer_value(), tyctx.size_type, None);
let len =
sizet_model.check_value(tyctx, ctx.ctx, shape.load_size(ctx, Some("len"))).unwrap();
ArrayWriter {
len,
write: Box::new(move |generator, ctx, dst_array| {
gen_model_for(
generator,
ctx,
sizet_model.constant(tyctx, ctx.ctx, 0),
len,
sizet_model.constant(tyctx, ctx.ctx, 1),
|generator, ctx, _hooks, i| {
let dim =
shape.data().get(ctx, generator, &i.value, None).into_int_value();
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, dim, "");
dst_array.offset(tyctx, ctx, i.value, "pdim").store(ctx, dim);
Ok(())
},
)
}),
}
}
TypeEnum::TTuple { ty: tuple_types } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
let ndims = tuple_types.len();
// A tuple has to be a StructValue
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
let shape = shape.into_struct_value();
ArrayWriter {
len: sizet_model.constant(tyctx, ctx.ctx, ndims as u64),
write: Box::new(move |_generator, ctx, dst_array| {
for axis in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape, axis as u32, format!("dim{axis}").as_str())
.unwrap()
.into_int_value();
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, dim, "");
dst_array
.offset(
tyctx,
ctx,
sizet_model.constant(tyctx, ctx.ctx, axis as u64).value,
"pdim",
)
.store(ctx, dim);
}
Ok(())
}),
}
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
{
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
// The value has to be an integer
let shape_int = shape.into_int_value();
ArrayWriter {
len: sizet_model.constant(tyctx, ctx.ctx, 1),
write: Box::new(move |_generator, ctx, dst_array| {
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, shape_int, "");
// Set shape[0] = shape_int
dst_array
.offset(tyctx, ctx, sizet_model.constant(tyctx, ctx.ctx, 0).value, "pdim")
.store(ctx, dim);
Ok(())
}),
}
}
_ => panic!("encountered shape type"),
}
}