forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: implement ndarray np_{empty,ndarray,zeros,ones,full}

This commit is contained in:
lyken 2024-07-26 16:23:11 +08:00
parent fc9d47fb54
commit 85ef06f1e2
8 changed files with 284 additions and 5 deletions

View File

@ -0,0 +1,38 @@
#pragma once
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
namespace ndarray {
namespace fill {
/**
* Fill an ndarray with a value.
*
* @param pvalue Pointer to the fill value, and the fill value should be of `ndarray->itemsize` bytes.
*/
template <typename SizeT>
void fill_generic(NDArray<SizeT>* ndarray, const uint8_t* pvalue) {
const SizeT size = ndarray::basic::size(ndarray);
for (SizeT i = 0; i < size; i++) {
uint8_t* pelement = ndarray::basic::get_nth_pelement(
ndarray, i); // No need for checked_get_nth_pelement
ndarray::basic::set_pelement_value(ndarray, pelement, pvalue);
}
}
} // namespace fill
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::fill;
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
fill_generic(ndarray, pvalue);
}
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
fill_generic(ndarray, pvalue);
}
}

View File

@ -6,4 +6,5 @@
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/fill.hpp>
#include <irrt/utils.hpp>

View File

@ -0,0 +1,21 @@
use crate::codegen::{
irrt::util::get_sized_dependent_function_name, model::*, structs::ndarray::NpArray,
CodeGenContext, CodeGenerator,
};
pub fn call_nac3_ndarray_fill_generic<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
fill_value_ptr: Pointer<'ctx, ByteModel>,
) {
let sizet = generator.get_sizet(ctx.ctx);
FunctionBuilder::begin(
ctx,
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_fill_generic"),
)
.arg("ndarray", ndarray_ptr)
.arg("pvalue", fill_value_ptr)
.returning_void();
}

View File

@ -1,2 +1,3 @@
pub mod allocation;
pub mod basic;
pub mod fill;

View File

@ -45,12 +45,13 @@ pub mod irrt;
pub mod llvm_intrinsics;
pub mod model;
pub mod numpy;
pub mod numpy_new;
pub mod stmt;
pub mod structs;
pub mod util;
#[cfg(test)]
mod test;
pub mod util;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};

View File

@ -0,0 +1,213 @@
use inkwell::{
types::BasicType,
values::{BasicValue, BasicValueEnum, PointerValue},
};
use nac3parser::ast::StrRef;
use crate::{
codegen::{
irrt::ndarray::{
allocation::{alloca_ndarray, init_ndarray_data_by_alloca, init_ndarray_shape},
fill::call_nac3_ndarray_fill_generic,
},
model::*,
structs::ndarray::NpArray,
util::shape::parse_input_shape_arg,
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
toplevel::DefinitionId,
typecheck::typedef::{FunSignature, Type},
};
/// Helper function to create an ndarray with uninitialized values
///
/// * `elem_ty` - The [`Type`] of the ndarray elements
/// * `shape` - The user input shape argument
/// * `shape_ty` - The [`Type`] of the shape argument
/// * `name` - LLVM IR name of the returned ndarray
fn create_empty_ndarray<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
name: &str,
) -> Result<Pointer<'ctx, StructModel<NpArray<'ctx>>>, String>
where
G: CodeGenerator + ?Sized,
{
let sizet = generator.get_sizet(ctx.ctx);
let shape_writer = parse_input_shape_arg(generator, ctx, shape, shape_ty);
let ndims = shape_writer.count;
let ndarray = alloca_ndarray(generator, ctx, ndims, name)?;
init_ndarray_shape(generator, ctx, ndarray, &shape_writer)?;
let itemsize = sizet
.review_value(ctx.ctx, ctx.get_llvm_type(generator, elem_ty).size_of().unwrap())
.unwrap();
ndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize);
init_ndarray_data_by_alloca(generator, ctx, ndarray); // Needs `itemsize` and `shape` initialized first
Ok(ndarray)
}
/// Helper function to create an ndarray full of a value.
///
/// * `elem_ty` - The [`Type`] of the ndarray elements and the fill value
/// * `shape` - The user input shape argument
/// * `shape_ty` - The [`Type`] of the shape argument
/// * `fill_value` - The user specified fill value
/// * `name` - LLVM IR name of the returned ndarray
fn create_full_ndarray<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
fill_value: BasicValueEnum<'ctx>,
name: &str,
) -> Result<Pointer<'ctx, StructModel<NpArray<'ctx>>>, String>
where
G: CodeGenerator + ?Sized,
{
let byte_model = NIntModel(Byte);
let fill_value_model = OpaqueModel(fill_value.get_type());
// Caller has to put fill_value on the stack and pass its address
let fill_value_ptr = fill_value_model.alloca(ctx, "fill_value_ptr");
fill_value_ptr.store(ctx, fill_value_model.believe_value(fill_value));
let fill_value_ptr = fill_value_ptr.cast_to(ctx, byte_model, "fill_value_bytes_ptr");
let ndarray_ptr = create_empty_ndarray(generator, ctx, elem_ty, shape, shape_ty, name)?;
call_nac3_ndarray_fill_generic(generator, ctx, ndarray_ptr, fill_value_ptr);
Ok(ndarray_ptr)
}
/// Generates LLVM IR for `np.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Implementation
let ndarray_ptr = create_empty_ndarray(
generator,
context,
context.primitives.float,
shape,
shape_ty,
"ndarray",
)?;
Ok(ndarray_ptr.value)
}
/// Generates LLVM IR for `np.zeros`.
pub fn gen_ndarray_zeros<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Implementation
// NOTE: Currently nac3's `np.zeros` is always `float64`.
let float64_ty = context.primitives.float;
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
let ndarray_ptr = create_full_ndarray(
generator,
context,
float64_ty, // `elem_ty` is always `float64`
shape,
shape_ty,
float64_llvm_type.const_zero().as_basic_value_enum(),
"ndarray",
)?;
Ok(ndarray_ptr.value)
}
/// Generates LLVM IR for `np.ones`.
pub fn gen_ndarray_ones<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Implementation
// NOTE: Currently nac3's `np.ones` is always `float64`.
let float64_ty = context.primitives.float;
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
let ndarray_ptr = create_full_ndarray(
generator,
context,
float64_ty, // `elem_ty` is always `float64`
shape,
shape_ty,
float64_llvm_type.const_float(1.0).as_basic_value_enum(),
"ndarray",
)?;
Ok(ndarray_ptr.value)
}
/// Generates LLVM IR for `np.full`.
pub fn gen_ndarray_full<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 shape
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Parse argument #2 fill_value
let fill_value_ty = fun.0.args[1].ty;
let fill_value_arg =
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
// Implementation
let ndarray_ptr = create_full_ndarray(
generator,
context,
fill_value_ty,
shape_arg,
shape_ty,
fill_value_arg,
"ndarray",
)?;
Ok(ndarray_ptr.value)
}

View File

@ -0,0 +1 @@
pub mod factory;

View File

@ -20,6 +20,7 @@ use crate::{
irrt::*,
model::*,
numpy::*,
numpy_new,
stmt::exn_constructor,
structs::ndarray::NpArray,
},
@ -1205,9 +1206,11 @@ impl<'a> BuiltinBuilder<'a> {
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| {
let func = match prim {
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
PrimDef::FunNpZeros => gen_ndarray_zeros,
PrimDef::FunNpOnes => gen_ndarray_ones,
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => {
numpy_new::factory::gen_ndarray_empty
}
PrimDef::FunNpZeros => numpy_new::factory::gen_ndarray_zeros,
PrimDef::FunNpOnes => numpy_new::factory::gen_ndarray_ones,
_ => unreachable!(),
};
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
@ -1275,7 +1278,7 @@ impl<'a> BuiltinBuilder<'a> {
// type variable
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
Box::new(move |ctx, obj, fun, args, generator| {
gen_ndarray_full(ctx, &obj, fun, &args, generator)
numpy_new::factory::gen_ndarray_full(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
)