forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: implement gen_foreach_ndarray_elements & np_{empty,ndarray,zeros,ones,full}

This commit is contained in:
lyken 2024-07-28 17:06:37 +08:00
parent 5b9ac9b09c
commit 2211c4d852
7 changed files with 306 additions and 4 deletions

View File

@ -303,4 +303,14 @@ void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
uint8_t* __nac3_ndarray_get_nth_pelement(NDArray<int32_t>* ndarray,
int32_t index) {
return get_nth_pelement(ndarray, index);
}
uint8_t* __nac3_ndarray_get_nth_pelement64(NDArray<int64_t>* ndarray,
int64_t index) {
return get_nth_pelement(ndarray, index);
}
}

View File

@ -133,3 +133,21 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
.arg("dst_ndarray", dst_ndarray)
.returning("is_c_contiguous")
}
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NpArray>>,
index: Int<'ctx, SizeT>,
) -> Ptr<'ctx, IntModel<Byte>> {
let tyctx = generator.type_context(ctx.ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_get_nth_pelement"),
)
.arg("ndarray", pndarray)
.arg("index", index)
.returning("pelement")
}

View File

@ -45,6 +45,7 @@ pub mod irrt;
pub mod llvm_intrinsics;
pub mod model;
pub mod numpy;
pub mod numpy_new;
pub mod stmt;
pub mod structure;
pub mod util;

View File

@ -0,0 +1,49 @@
use crate::codegen::{
irrt::ndarray::basic::{call_nac3_ndarray_get_nth_pelement, call_nac3_ndarray_size},
model::*,
stmt::BreakContinueHooks,
structure::ndarray::NpArray,
util::control::gen_model_for,
CodeGenContext, CodeGenerator,
};
/// Iterate through all elements in an ndarray.
///
/// `body` is given the index of an element and an opaque pointer (as an `uint8_t*`, you might want to cast it) to the element.
///
/// Short-circuiting is possible with the given [`BreakContinueHooks`].
pub fn gen_foreach_ndarray_elements<'ctx, G, F>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NpArray>>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: Fn(
&mut G,
&mut CodeGenContext<'ctx, '_>,
BreakContinueHooks<'ctx>,
Int<'ctx, SizeT>,
Ptr<'ctx, IntModel<Byte>>,
) -> Result<(), String>,
{
// TODO: Make this more efficient - use a special NDArray iterator?
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let size = call_nac3_ndarray_size(generator, ctx, pndarray);
gen_model_for(
generator,
ctx,
sizet_model.const_0(tyctx, ctx.ctx),
size,
sizet_model.const_1(tyctx, ctx.ctx),
|generator, ctx, hooks, index| {
let pelement = call_nac3_ndarray_get_nth_pelement(generator, ctx, pndarray, index);
body(generator, ctx, hooks, index, pelement)
},
)
}

View File

@ -0,0 +1,219 @@
use inkwell::{
types::BasicType,
values::{BasicValue, BasicValueEnum, PointerValue},
AddressSpace,
};
use nac3parser::ast::StrRef;
use crate::{
codegen::{
irrt::ndarray::allocation::{
alloca_ndarray, init_ndarray_data_by_alloca, init_ndarray_shape,
},
model::*,
structure::ndarray::NpArray,
util::shape::make_shape_writer,
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
toplevel::DefinitionId,
typecheck::typedef::{FunSignature, Type},
};
use super::control::gen_foreach_ndarray_elements;
/// Helper function to create an ndarray with uninitialized values
///
/// * `elem_ty` - The [`Type`] of the ndarray elements
/// * `shape` - The user input shape argument
/// * `shape_ty` - The [`Type`] of the shape argument
/// * `name` - LLVM IR name of the returned ndarray
fn create_empty_ndarray<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
name: &str,
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String>
where
G: CodeGenerator + ?Sized,
{
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let shape_writer = make_shape_writer(generator, ctx, shape, shape_ty);
let ndims = shape_writer.len;
let ndarray = alloca_ndarray(generator, ctx, ndims, name)?;
init_ndarray_shape(generator, ctx, ndarray, &shape_writer)?;
let itemsize = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
let itemsize = sizet_model.check_value(tyctx, ctx.ctx, itemsize).unwrap();
ndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize);
// Needs `itemsize` and `shape` initialized
init_ndarray_data_by_alloca(generator, ctx, ndarray);
Ok(ndarray)
}
/// Helper function to create an ndarray full of a value.
///
/// * `elem_ty` - The [`Type`] of the ndarray elements and the fill value
/// * `shape` - The user input shape argument
/// * `shape_ty` - The [`Type`] of the shape argument
/// * `fill_value` - The user specified fill value
/// * `name` - LLVM IR name of the returned ndarray
fn create_full_ndarray<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
fill_value: BasicValueEnum<'ctx>,
name: &str,
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String>
where
G: CodeGenerator + ?Sized,
{
let pndarray = create_empty_ndarray(generator, ctx, elem_ty, shape, shape_ty, name)?;
gen_foreach_ndarray_elements(
generator,
ctx,
pndarray,
|_generator, ctx, _hooks, _i, pelement| {
// Cannot use Model here, fill_value's type is not statically known.
let pfill_value_ty = fill_value.get_type().ptr_type(AddressSpace::default());
let pelement =
ctx.builder.build_pointer_cast(pelement.value, pfill_value_ty, "pelement").unwrap();
ctx.builder.build_store(pelement, fill_value).unwrap();
Ok(())
},
)?;
Ok(pndarray)
}
/// Generates LLVM IR for `np.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Implementation
let ndarray_ptr = create_empty_ndarray(
generator,
context,
context.primitives.float,
shape,
shape_ty,
"ndarray",
)?;
Ok(ndarray_ptr.value)
}
/// Generates LLVM IR for `np.zeros`.
pub fn gen_ndarray_zeros<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Implementation
// NOTE: Currently nac3's `np.zeros` is always `float64`.
let float64_ty = context.primitives.float;
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
let ndarray_ptr = create_full_ndarray(
generator,
context,
float64_ty, // `elem_ty` is always `float64`
shape,
shape_ty,
float64_llvm_type.const_zero().as_basic_value_enum(),
"ndarray",
)?;
Ok(ndarray_ptr.value)
}
/// Generates LLVM IR for `np.ones`.
pub fn gen_ndarray_ones<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Implementation
// NOTE: Currently nac3's `np.ones` is always `float64`.
let float64_ty = context.primitives.float;
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
let ndarray_ptr = create_full_ndarray(
generator,
context,
float64_ty, // `elem_ty` is always `float64`
shape,
shape_ty,
float64_llvm_type.const_float(1.0).as_basic_value_enum(),
"ndarray",
)?;
Ok(ndarray_ptr.value)
}
/// Generates LLVM IR for `np.full`.
pub fn gen_ndarray_full<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 shape
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Parse argument #2 fill_value
let fill_value_ty = fun.0.args[1].ty;
let fill_value_arg =
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
// Implementation
let ndarray_ptr = create_full_ndarray(
generator,
context,
fill_value_ty,
shape_arg,
shape_ty,
fill_value_arg,
"ndarray",
)?;
Ok(ndarray_ptr.value)
}

View File

@ -0,0 +1,2 @@
pub mod control;
pub mod factory;

View File

@ -20,6 +20,7 @@ use crate::{
irrt::*,
model::*,
numpy::*,
numpy_new,
stmt::exn_constructor,
structure::ndarray::NpArray,
},
@ -1205,9 +1206,11 @@ impl<'a> BuiltinBuilder<'a> {
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| {
let func = match prim {
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
PrimDef::FunNpZeros => gen_ndarray_zeros,
PrimDef::FunNpOnes => gen_ndarray_ones,
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => {
numpy_new::factory::gen_ndarray_empty
}
PrimDef::FunNpZeros => numpy_new::factory::gen_ndarray_zeros,
PrimDef::FunNpOnes => numpy_new::factory::gen_ndarray_ones,
_ => unreachable!(),
};
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
@ -1275,7 +1278,7 @@ impl<'a> BuiltinBuilder<'a> {
// type variable
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
Box::new(move |ctx, obj, fun, args, generator| {
gen_ndarray_full(ctx, &obj, fun, &args, generator)
numpy_new::factory::gen_ndarray_full(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
)