forked from M-Labs/nac3
core/ndstrides: add ArrayWriter & make_shape_writer
This commit is contained in:
parent
fd3d02bff0
commit
e5fe86cc93
|
@ -47,6 +47,7 @@ pub mod model;
|
||||||
pub mod numpy;
|
pub mod numpy;
|
||||||
pub mod stmt;
|
pub mod stmt;
|
||||||
pub mod structure;
|
pub mod structure;
|
||||||
|
pub mod util;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
|
|
@ -75,6 +75,7 @@ impl<Element: Model> PtrModel<Element> {
|
||||||
|
|
||||||
impl<'ctx, Element: Model> Ptr<'ctx, Element> {
|
impl<'ctx, Element: Model> Ptr<'ctx, Element> {
|
||||||
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`].
|
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`].
|
||||||
|
#[must_use]
|
||||||
pub fn offset(
|
pub fn offset(
|
||||||
&self,
|
&self,
|
||||||
tyctx: TypeContext<'ctx>,
|
tyctx: TypeContext<'ctx>,
|
||||||
|
|
|
@ -508,8 +508,12 @@ where
|
||||||
I: Clone,
|
I: Clone,
|
||||||
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
|
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
|
||||||
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
|
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
|
||||||
BodyFn:
|
BodyFn: FnOnce(
|
||||||
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>,
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
BreakContinueHooks<'ctx>,
|
||||||
|
I,
|
||||||
|
) -> Result<(), String>,
|
||||||
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||||
{
|
{
|
||||||
let label = label.unwrap_or("for");
|
let label = label.unwrap_or("for");
|
||||||
|
@ -589,7 +593,7 @@ where
|
||||||
BodyFn: FnOnce(
|
BodyFn: FnOnce(
|
||||||
&mut G,
|
&mut G,
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
BreakContinueHooks,
|
BreakContinueHooks<'ctx>,
|
||||||
IntValue<'ctx>,
|
IntValue<'ctx>,
|
||||||
) -> Result<(), String>,
|
) -> Result<(), String>,
|
||||||
{
|
{
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
/// A closure containing details on how to write to/initialize an array.
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
pub struct ArrayWriter<'ctx, G: CodeGenerator + ?Sized, Len: IntKind, Item: Model> {
|
||||||
|
/// Number of items to write
|
||||||
|
pub len: Int<'ctx, Len>,
|
||||||
|
/// Implementation to write to an array given its base pointer.
|
||||||
|
pub write: Box<
|
||||||
|
dyn Fn(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, '_>,
|
||||||
|
Ptr<'ctx, Item>, // Base pointer
|
||||||
|
) -> Result<(), String>
|
||||||
|
+ 'ctx,
|
||||||
|
>,
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
use crate::codegen::{
|
||||||
|
model::*,
|
||||||
|
stmt::{gen_for_callback_incrementing, BreakContinueHooks},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: Document
|
||||||
|
// TODO: Rename function
|
||||||
|
/// Only allows positive steps
|
||||||
|
pub fn gen_model_for<'ctx, 'a, G, F, I>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
start: Int<'ctx, I>,
|
||||||
|
stop: Int<'ctx, I>,
|
||||||
|
step: Int<'ctx, I>,
|
||||||
|
body: F,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
F: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
BreakContinueHooks<'ctx>,
|
||||||
|
Int<'ctx, I>,
|
||||||
|
) -> Result<(), String>,
|
||||||
|
I: IntKind,
|
||||||
|
{
|
||||||
|
let int_model = IntModel(I::default());
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
start.value,
|
||||||
|
(stop.value, false),
|
||||||
|
|g, ctx, hooks, i| {
|
||||||
|
let i = int_model.believe_value(i);
|
||||||
|
body(g, ctx, hooks, i)
|
||||||
|
},
|
||||||
|
step.value,
|
||||||
|
)
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
pub mod array_writer;
|
||||||
|
pub mod control;
|
||||||
|
pub mod shape;
|
|
@ -0,0 +1,127 @@
|
||||||
|
use inkwell::values::BasicValueEnum;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
classes::{ListValue, UntypedArrayLikeAccessor},
|
||||||
|
model::*,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{array_writer::ArrayWriter, control::gen_model_for};
|
||||||
|
|
||||||
|
// TODO: Generalize to complex iterables under a common interface
|
||||||
|
/// Create an [`ArrayWriter`] from a NumPy-like `shape` argument input.
|
||||||
|
/// * `shape` - The `shape` parameter.
|
||||||
|
/// * `shape_ty` - The element type of the `NDArray`.
|
||||||
|
///
|
||||||
|
/// The `shape` argument type may only be one of the following:
|
||||||
|
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||||
|
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
///
|
||||||
|
/// The `int32` values will be sign-extended to `SizeT`
|
||||||
|
pub fn make_shape_writer<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: BasicValueEnum<'ctx>,
|
||||||
|
shape_ty: Type,
|
||||||
|
) -> ArrayWriter<'ctx, G, SizeT, IntModel<SizeT>>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
let sizet_model = IntModel(SizeT);
|
||||||
|
|
||||||
|
match &*ctx.unifier.get_ty(shape_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
|
||||||
|
// TODO: Remove ListValue with Model
|
||||||
|
|
||||||
|
let shape = ListValue::from_ptr_val(shape.into_pointer_value(), tyctx.size_type, None);
|
||||||
|
let len =
|
||||||
|
sizet_model.check_value(tyctx, ctx.ctx, shape.load_size(ctx, Some("len"))).unwrap();
|
||||||
|
|
||||||
|
ArrayWriter {
|
||||||
|
len,
|
||||||
|
write: Box::new(move |generator, ctx, dst_array| {
|
||||||
|
gen_model_for(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
sizet_model.constant(tyctx, ctx.ctx, 0),
|
||||||
|
len,
|
||||||
|
sizet_model.constant(tyctx, ctx.ctx, 1),
|
||||||
|
|generator, ctx, _hooks, i| {
|
||||||
|
let dim =
|
||||||
|
shape.data().get(ctx, generator, &i.value, None).into_int_value();
|
||||||
|
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, dim, "");
|
||||||
|
|
||||||
|
dst_array.offset(tyctx, ctx, i.value, "pdim").store(ctx, dim);
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TTuple { ty: tuple_types } => {
|
||||||
|
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||||
|
|
||||||
|
let ndims = tuple_types.len();
|
||||||
|
|
||||||
|
// A tuple has to be a StructValue
|
||||||
|
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
|
||||||
|
let shape = shape.into_struct_value();
|
||||||
|
|
||||||
|
ArrayWriter {
|
||||||
|
len: sizet_model.constant(tyctx, ctx.ctx, ndims as u64),
|
||||||
|
write: Box::new(move |_generator, ctx, dst_array| {
|
||||||
|
for axis in 0..ndims {
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_extract_value(shape, axis as u32, format!("dim{axis}").as_str())
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, dim, "");
|
||||||
|
|
||||||
|
dst_array
|
||||||
|
.offset(
|
||||||
|
tyctx,
|
||||||
|
ctx,
|
||||||
|
sizet_model.constant(tyctx, ctx.ctx, axis as u64).value,
|
||||||
|
"pdim",
|
||||||
|
)
|
||||||
|
.store(ctx, dim);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
|
||||||
|
// The value has to be an integer
|
||||||
|
let shape_int = shape.into_int_value();
|
||||||
|
|
||||||
|
ArrayWriter {
|
||||||
|
len: sizet_model.constant(tyctx, ctx.ctx, 1),
|
||||||
|
write: Box::new(move |_generator, ctx, dst_array| {
|
||||||
|
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, shape_int, "");
|
||||||
|
|
||||||
|
// Set shape[0] = shape_int
|
||||||
|
dst_array
|
||||||
|
.offset(tyctx, ctx, sizet_model.constant(tyctx, ctx.ctx, 0).value, "pdim")
|
||||||
|
.store(ctx, dim);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("encountered shape type"),
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue