forked from M-Labs/nac3
core/ndstrides: add ArrayWriter & make_shape_writer
This commit is contained in:
parent
fd3d02bff0
commit
e5fe86cc93
@ -47,6 +47,7 @@ pub mod model;
|
||||
pub mod numpy;
|
||||
pub mod stmt;
|
||||
pub mod structure;
|
||||
pub mod util;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
@ -75,6 +75,7 @@ impl<Element: Model> PtrModel<Element> {
|
||||
|
||||
impl<'ctx, Element: Model> Ptr<'ctx, Element> {
|
||||
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`].
|
||||
#[must_use]
|
||||
pub fn offset(
|
||||
&self,
|
||||
tyctx: TypeContext<'ctx>,
|
||||
|
@ -508,8 +508,12 @@ where
|
||||
I: Clone,
|
||||
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
|
||||
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
|
||||
BodyFn:
|
||||
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>,
|
||||
BodyFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
I,
|
||||
) -> Result<(), String>,
|
||||
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||
{
|
||||
let label = label.unwrap_or("for");
|
||||
@ -589,7 +593,7 @@ where
|
||||
BodyFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks,
|
||||
BreakContinueHooks<'ctx>,
|
||||
IntValue<'ctx>,
|
||||
) -> Result<(), String>,
|
||||
{
|
||||
|
17
nac3core/src/codegen/util/array_writer.rs
Normal file
17
nac3core/src/codegen/util/array_writer.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
|
||||
|
||||
/// A closure containing details on how to write to/initialize an array.
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub struct ArrayWriter<'ctx, G: CodeGenerator + ?Sized, Len: IntKind, Item: Model> {
|
||||
/// Number of items to write
|
||||
pub len: Int<'ctx, Len>,
|
||||
/// Implementation to write to an array given its base pointer.
|
||||
pub write: Box<
|
||||
dyn Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, '_>,
|
||||
Ptr<'ctx, Item>, // Base pointer
|
||||
) -> Result<(), String>
|
||||
+ 'ctx,
|
||||
>,
|
||||
}
|
42
nac3core/src/codegen/util/control.rs
Normal file
42
nac3core/src/codegen/util/control.rs
Normal file
@ -0,0 +1,42 @@
|
||||
use crate::codegen::{
|
||||
model::*,
|
||||
stmt::{gen_for_callback_incrementing, BreakContinueHooks},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
// TODO: Document
|
||||
// TODO: Rename function
|
||||
/// Only allows positive steps
|
||||
pub fn gen_model_for<'ctx, 'a, G, F, I>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
start: Int<'ctx, I>,
|
||||
stop: Int<'ctx, I>,
|
||||
step: Int<'ctx, I>,
|
||||
body: F,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
F: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
Int<'ctx, I>,
|
||||
) -> Result<(), String>,
|
||||
I: IntKind,
|
||||
{
|
||||
let int_model = IntModel(I::default());
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
start.value,
|
||||
(stop.value, false),
|
||||
|g, ctx, hooks, i| {
|
||||
let i = int_model.believe_value(i);
|
||||
body(g, ctx, hooks, i)
|
||||
},
|
||||
step.value,
|
||||
)
|
||||
}
|
3
nac3core/src/codegen/util/mod.rs
Normal file
3
nac3core/src/codegen/util/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod array_writer;
|
||||
pub mod control;
|
||||
pub mod shape;
|
127
nac3core/src/codegen/util/shape.rs
Normal file
127
nac3core/src/codegen/util/shape.rs
Normal file
@ -0,0 +1,127 @@
|
||||
use inkwell::values::BasicValueEnum;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
classes::{ListValue, UntypedArrayLikeAccessor},
|
||||
model::*,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
use super::{array_writer::ArrayWriter, control::gen_model_for};
|
||||
|
||||
// TODO: Generalize to complex iterables under a common interface
|
||||
/// Create an [`ArrayWriter`] from a NumPy-like `shape` argument input.
|
||||
/// * `shape` - The `shape` parameter.
|
||||
/// * `shape_ty` - The element type of the `NDArray`.
|
||||
///
|
||||
/// The `shape` argument type may only be one of the following:
|
||||
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||
///
|
||||
/// The `int32` values will be sign-extended to `SizeT`
|
||||
pub fn make_shape_writer<'ctx, G>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
shape: BasicValueEnum<'ctx>,
|
||||
shape_ty: Type,
|
||||
) -> ArrayWriter<'ctx, G, SizeT, IntModel<SizeT>>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
match &*ctx.unifier.get_ty(shape_ty) {
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||
|
||||
// TODO: Remove ListValue with Model
|
||||
|
||||
let shape = ListValue::from_ptr_val(shape.into_pointer_value(), tyctx.size_type, None);
|
||||
let len =
|
||||
sizet_model.check_value(tyctx, ctx.ctx, shape.load_size(ctx, Some("len"))).unwrap();
|
||||
|
||||
ArrayWriter {
|
||||
len,
|
||||
write: Box::new(move |generator, ctx, dst_array| {
|
||||
gen_model_for(
|
||||
generator,
|
||||
ctx,
|
||||
sizet_model.constant(tyctx, ctx.ctx, 0),
|
||||
len,
|
||||
sizet_model.constant(tyctx, ctx.ctx, 1),
|
||||
|generator, ctx, _hooks, i| {
|
||||
let dim =
|
||||
shape.data().get(ctx, generator, &i.value, None).into_int_value();
|
||||
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, dim, "");
|
||||
|
||||
dst_array.offset(tyctx, ctx, i.value, "pdim").store(ctx, dim);
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
}),
|
||||
}
|
||||
}
|
||||
TypeEnum::TTuple { ty: tuple_types } => {
|
||||
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||
|
||||
let ndims = tuple_types.len();
|
||||
|
||||
// A tuple has to be a StructValue
|
||||
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
|
||||
let shape = shape.into_struct_value();
|
||||
|
||||
ArrayWriter {
|
||||
len: sizet_model.constant(tyctx, ctx.ctx, ndims as u64),
|
||||
write: Box::new(move |_generator, ctx, dst_array| {
|
||||
for axis in 0..ndims {
|
||||
let dim = ctx
|
||||
.builder
|
||||
.build_extract_value(shape, axis as u32, format!("dim{axis}").as_str())
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, dim, "");
|
||||
|
||||
dst_array
|
||||
.offset(
|
||||
tyctx,
|
||||
ctx,
|
||||
sizet_model.constant(tyctx, ctx.ctx, axis as u64).value,
|
||||
"pdim",
|
||||
)
|
||||
.store(ctx, dim);
|
||||
}
|
||||
Ok(())
|
||||
}),
|
||||
}
|
||||
}
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||
|
||||
// The value has to be an integer
|
||||
let shape_int = shape.into_int_value();
|
||||
|
||||
ArrayWriter {
|
||||
len: sizet_model.constant(tyctx, ctx.ctx, 1),
|
||||
write: Box::new(move |_generator, ctx, dst_array| {
|
||||
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, shape_int, "");
|
||||
|
||||
// Set shape[0] = shape_int
|
||||
dst_array
|
||||
.offset(tyctx, ctx, sizet_model.constant(tyctx, ctx.ctx, 0).value, "pdim")
|
||||
.store(ctx, dim);
|
||||
|
||||
Ok(())
|
||||
}),
|
||||
}
|
||||
}
|
||||
_ => panic!("encountered shape type"),
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user