forked from M-Labs/nac3
core/codegen: add ArrayWriter & parse_input_shape_arg
This commit is contained in:
parent
19c2beffbb
commit
f78a60a644
|
@ -47,6 +47,7 @@ pub mod model;
|
||||||
pub mod numpy;
|
pub mod numpy;
|
||||||
pub mod stmt;
|
pub mod stmt;
|
||||||
pub mod structs;
|
pub mod structs;
|
||||||
|
pub mod util;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
use inkwell::{types::IntType, values::IntValue};
|
||||||
|
|
||||||
|
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
pub type ArrayWriterWrite<'ctx, G, N, E> = Box<
|
||||||
|
dyn Fn(&mut G, &mut CodeGenContext<'ctx, '_>, &ArraySlice<'ctx, N, E>) -> Result<(), String>
|
||||||
|
+ 'ctx,
|
||||||
|
>;
|
||||||
|
|
||||||
|
// TODO: Document
|
||||||
|
pub struct ArrayWriter<'ctx, G: CodeGenerator + ?Sized, N, E: Model<'ctx>>
|
||||||
|
where
|
||||||
|
N: Model<'ctx, Value = IntValue<'ctx>, Type = IntType<'ctx>>,
|
||||||
|
{
|
||||||
|
pub count: Instance<'ctx, N>,
|
||||||
|
pub write: ArrayWriterWrite<'ctx, G, N, E>,
|
||||||
|
}
|
|
@ -0,0 +1,2 @@
|
||||||
|
pub mod array_writer;
|
||||||
|
pub mod shape;
|
|
@ -0,0 +1,144 @@
|
||||||
|
use inkwell::values::BasicValueEnum;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
classes::{ListValue, UntypedArrayLikeAccessor},
|
||||||
|
model::*,
|
||||||
|
stmt::gen_for_callback_incrementing,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::array_writer::ArrayWriter;
|
||||||
|
|
||||||
|
/// TODO: UPDATE DOCUMENTATION
|
||||||
|
/// LLVM-typed implementation for generating a [`ArrayWriter`] that sets a list of ints.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||||
|
///
|
||||||
|
/// ### Notes on `shape`
|
||||||
|
///
|
||||||
|
/// Just like numpy, the `shape` argument can be:
|
||||||
|
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||||
|
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
///
|
||||||
|
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
|
||||||
|
/// learn how `shape` gets from being a Python user expression to here.
|
||||||
|
pub fn parse_input_shape_arg<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: BasicValueEnum<'ctx>,
|
||||||
|
shape_ty: Type,
|
||||||
|
) -> ArrayWriter<'ctx, G, SizeTModel<'ctx>, SizeTModel<'ctx>>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let sizet = generator.get_sizet(ctx.ctx);
|
||||||
|
|
||||||
|
match &*ctx.unifier.get_ty(shape_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
|
||||||
|
|
||||||
|
// A list has to be a PointerValue
|
||||||
|
let shape_list = ListValue::from_ptr_val(shape.into_pointer_value(), sizet.0, None);
|
||||||
|
|
||||||
|
// Create `ArrayWriter`
|
||||||
|
let ndims =
|
||||||
|
sizet.review_value(ctx.ctx, shape_list.load_size(ctx, Some("count"))).unwrap();
|
||||||
|
ArrayWriter {
|
||||||
|
count: ndims,
|
||||||
|
write: Box::new(move |generator, ctx, dst_array| {
|
||||||
|
// Basically iterate through the list and write to `dst_slice` accordingly
|
||||||
|
let init_val = sizet.constant(ctx.ctx, 0).value;
|
||||||
|
let max_val = (ndims.value, false);
|
||||||
|
let incr_val = sizet.constant(ctx.ctx, 1).value;
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
init_val,
|
||||||
|
max_val,
|
||||||
|
|generator, ctx, _hooks, axis| {
|
||||||
|
let axis = sizet.review_value(ctx.ctx, axis).unwrap();
|
||||||
|
|
||||||
|
// TODO: Remove ProxyValue ListValue
|
||||||
|
|
||||||
|
// Get the dimension at `axis`
|
||||||
|
let dim: Int<'ctx> = shape_list
|
||||||
|
.data()
|
||||||
|
.get(ctx, generator, &axis.value, None)
|
||||||
|
.into_int_value()
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let dim = dim.s_extend_or_bit_cast(ctx, sizet, "dim_casted");
|
||||||
|
dst_array.ix(generator, ctx, axis, "dim").store(ctx, dim);
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
incr_val,
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TTuple { ty: tuple_types } => {
|
||||||
|
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||||
|
|
||||||
|
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
||||||
|
let ndims = tuple_types.len();
|
||||||
|
|
||||||
|
// A tuple has to be a StructValue
|
||||||
|
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
|
||||||
|
let shape_tuple = shape.into_struct_value();
|
||||||
|
|
||||||
|
ArrayWriter {
|
||||||
|
count: sizet.constant(ctx.ctx, ndims as u64),
|
||||||
|
write: Box::new(move |generator, ctx, dst_array| {
|
||||||
|
for axis in 0..ndims {
|
||||||
|
// Get the dimension at `axis`
|
||||||
|
let dim: Int<'ctx> = ctx
|
||||||
|
.builder
|
||||||
|
.build_extract_value(
|
||||||
|
shape_tuple,
|
||||||
|
axis as u32,
|
||||||
|
format!("dim{axis}").as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value()
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let dim = dim.s_extend_or_bit_cast(ctx, sizet, "dim_casted");
|
||||||
|
dst_array
|
||||||
|
.ix(generator, ctx, sizet.constant(ctx.ctx, axis as u64), "dim")
|
||||||
|
.store(ctx, dim);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
|
||||||
|
// The value has to be an integer
|
||||||
|
let shape_int: Int<'ctx> = shape.into_int_value().into();
|
||||||
|
|
||||||
|
ArrayWriter {
|
||||||
|
count: sizet.constant(ctx.ctx, 1),
|
||||||
|
write: Box::new(move |generator, ctx, dst_array| {
|
||||||
|
// Cast `shape_int` to SizeT
|
||||||
|
let dim = shape_int.s_extend_or_bit_cast(ctx, sizet, "dim_casted");
|
||||||
|
|
||||||
|
// Set shape[0] = shape_int
|
||||||
|
dst_array.ix(generator, ctx, sizet.constant(ctx.ctx, 0), "dim").store(ctx, dim);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("parse_input_shape_arg encountered unknown type"),
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue