From 2211c4d852db1db1948ca5d1c6bce7e4e7392345 Mon Sep 17 00:00:00 2001 From: lyken Date: Sun, 28 Jul 2024 17:06:37 +0800 Subject: [PATCH] core/ndstrides: implement gen_foreach_ndarray_elements & np_{empty,ndarray,zeros,ones,full} --- nac3core/irrt/irrt/ndarray/basic.hpp | 10 + nac3core/src/codegen/irrt/ndarray/basic.rs | 18 ++ nac3core/src/codegen/mod.rs | 1 + nac3core/src/codegen/numpy_new/control.rs | 49 +++++ nac3core/src/codegen/numpy_new/factory.rs | 219 +++++++++++++++++++++ nac3core/src/codegen/numpy_new/mod.rs | 2 + nac3core/src/toplevel/builtins.rs | 11 +- 7 files changed, 306 insertions(+), 4 deletions(-) create mode 100644 nac3core/src/codegen/numpy_new/control.rs create mode 100644 nac3core/src/codegen/numpy_new/factory.rs create mode 100644 nac3core/src/codegen/numpy_new/mod.rs diff --git a/nac3core/irrt/irrt/ndarray/basic.hpp b/nac3core/irrt/irrt/ndarray/basic.hpp index 3e91ea71..9e0565b9 100644 --- a/nac3core/irrt/irrt/ndarray/basic.hpp +++ b/nac3core/irrt/irrt/ndarray/basic.hpp @@ -303,4 +303,14 @@ void __nac3_ndarray_copy_data64(NDArray* src_ndarray, NDArray* dst_ndarray) { copy_data(src_ndarray, dst_ndarray); } + +uint8_t* __nac3_ndarray_get_nth_pelement(NDArray* ndarray, + int32_t index) { + return get_nth_pelement(ndarray, index); +} + +uint8_t* __nac3_ndarray_get_nth_pelement64(NDArray* ndarray, + int64_t index) { + return get_nth_pelement(ndarray, index); +} } \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index a4796ef7..df3577ea 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -133,3 +133,21 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( .arg("dst_ndarray", dst_ndarray) .returning("is_c_contiguous") } + +pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + pndarray: Ptr<'ctx, StructModel>, + index: Int<'ctx, SizeT>, +) -> Ptr<'ctx, IntModel> { + let tyctx = generator.type_context(ctx.ctx); + + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_get_nth_pelement"), + ) + .arg("ndarray", pndarray) + .arg("index", index) + .returning("pelement") +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 432073ae..27e45b88 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -45,6 +45,7 @@ pub mod irrt; pub mod llvm_intrinsics; pub mod model; pub mod numpy; +pub mod numpy_new; pub mod stmt; pub mod structure; pub mod util; diff --git a/nac3core/src/codegen/numpy_new/control.rs b/nac3core/src/codegen/numpy_new/control.rs new file mode 100644 index 00000000..ee99f3ac --- /dev/null +++ b/nac3core/src/codegen/numpy_new/control.rs @@ -0,0 +1,49 @@ +use crate::codegen::{ + irrt::ndarray::basic::{call_nac3_ndarray_get_nth_pelement, call_nac3_ndarray_size}, + model::*, + stmt::BreakContinueHooks, + structure::ndarray::NpArray, + util::control::gen_model_for, + CodeGenContext, CodeGenerator, +}; + +/// Iterate through all elements in an ndarray. +/// +/// `body` is given the index of an element and an opaque pointer (as an `uint8_t*`, you might want to cast it) to the element. +/// +/// Short-circuiting is possible with the given [`BreakContinueHooks`]. +pub fn gen_foreach_ndarray_elements<'ctx, G, F>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + pndarray: Ptr<'ctx, StructModel>, + body: F, +) -> Result<(), String> +where + G: CodeGenerator + ?Sized, + F: Fn( + &mut G, + &mut CodeGenContext<'ctx, '_>, + BreakContinueHooks<'ctx>, + Int<'ctx, SizeT>, + Ptr<'ctx, IntModel>, + ) -> Result<(), String>, +{ + // TODO: Make this more efficient - use a special NDArray iterator? + + let tyctx = generator.type_context(ctx.ctx); + + let sizet_model = IntModel(SizeT); + let size = call_nac3_ndarray_size(generator, ctx, pndarray); + + gen_model_for( + generator, + ctx, + sizet_model.const_0(tyctx, ctx.ctx), + size, + sizet_model.const_1(tyctx, ctx.ctx), + |generator, ctx, hooks, index| { + let pelement = call_nac3_ndarray_get_nth_pelement(generator, ctx, pndarray, index); + body(generator, ctx, hooks, index, pelement) + }, + ) +} diff --git a/nac3core/src/codegen/numpy_new/factory.rs b/nac3core/src/codegen/numpy_new/factory.rs new file mode 100644 index 00000000..817fd5b5 --- /dev/null +++ b/nac3core/src/codegen/numpy_new/factory.rs @@ -0,0 +1,219 @@ +use inkwell::{ + types::BasicType, + values::{BasicValue, BasicValueEnum, PointerValue}, + AddressSpace, +}; +use nac3parser::ast::StrRef; + +use crate::{ + codegen::{ + irrt::ndarray::allocation::{ + alloca_ndarray, init_ndarray_data_by_alloca, init_ndarray_shape, + }, + model::*, + structure::ndarray::NpArray, + util::shape::make_shape_writer, + CodeGenContext, CodeGenerator, + }, + symbol_resolver::ValueEnum, + toplevel::DefinitionId, + typecheck::typedef::{FunSignature, Type}, +}; + +use super::control::gen_foreach_ndarray_elements; + +/// Helper function to create an ndarray with uninitialized values +/// +/// * `elem_ty` - The [`Type`] of the ndarray elements +/// * `shape` - The user input shape argument +/// * `shape_ty` - The [`Type`] of the shape argument +/// * `name` - LLVM IR name of the returned ndarray +fn create_empty_ndarray<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + shape: BasicValueEnum<'ctx>, + shape_ty: Type, + name: &str, +) -> Result>, String> +where + G: CodeGenerator + ?Sized, +{ + let tyctx = generator.type_context(ctx.ctx); + let sizet_model = IntModel(SizeT); + + let shape_writer = make_shape_writer(generator, ctx, shape, shape_ty); + let ndims = shape_writer.len; + + let ndarray = alloca_ndarray(generator, ctx, ndims, name)?; + init_ndarray_shape(generator, ctx, ndarray, &shape_writer)?; + + let itemsize = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); + let itemsize = sizet_model.check_value(tyctx, ctx.ctx, itemsize).unwrap(); + ndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize); + + // Needs `itemsize` and `shape` initialized + init_ndarray_data_by_alloca(generator, ctx, ndarray); + + Ok(ndarray) +} + +/// Helper function to create an ndarray full of a value. +/// +/// * `elem_ty` - The [`Type`] of the ndarray elements and the fill value +/// * `shape` - The user input shape argument +/// * `shape_ty` - The [`Type`] of the shape argument +/// * `fill_value` - The user specified fill value +/// * `name` - LLVM IR name of the returned ndarray +fn create_full_ndarray<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + shape: BasicValueEnum<'ctx>, + shape_ty: Type, + fill_value: BasicValueEnum<'ctx>, + name: &str, +) -> Result>, String> +where + G: CodeGenerator + ?Sized, +{ + let pndarray = create_empty_ndarray(generator, ctx, elem_ty, shape, shape_ty, name)?; + gen_foreach_ndarray_elements( + generator, + ctx, + pndarray, + |_generator, ctx, _hooks, _i, pelement| { + // Cannot use Model here, fill_value's type is not statically known. + let pfill_value_ty = fill_value.get_type().ptr_type(AddressSpace::default()); + let pelement = + ctx.builder.build_pointer_cast(pelement.value, pfill_value_ty, "pelement").unwrap(); + ctx.builder.build_store(pelement, fill_value).unwrap(); + Ok(()) + }, + )?; + Ok(pndarray) +} + +/// Generates LLVM IR for `np.empty`. +pub fn gen_ndarray_empty<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + // Parse arguments + let shape_ty = fun.0.args[0].ty; + let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + // Implementation + let ndarray_ptr = create_empty_ndarray( + generator, + context, + context.primitives.float, + shape, + shape_ty, + "ndarray", + )?; + Ok(ndarray_ptr.value) +} + +/// Generates LLVM IR for `np.zeros`. +pub fn gen_ndarray_zeros<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + // Parse arguments + let shape_ty = fun.0.args[0].ty; + let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + // Implementation + // NOTE: Currently nac3's `np.zeros` is always `float64`. + let float64_ty = context.primitives.float; + let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type(); + + let ndarray_ptr = create_full_ndarray( + generator, + context, + float64_ty, // `elem_ty` is always `float64` + shape, + shape_ty, + float64_llvm_type.const_zero().as_basic_value_enum(), + "ndarray", + )?; + Ok(ndarray_ptr.value) +} + +/// Generates LLVM IR for `np.ones`. +pub fn gen_ndarray_ones<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + // Parse arguments + let shape_ty = fun.0.args[0].ty; + let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + // Implementation + // NOTE: Currently nac3's `np.ones` is always `float64`. + let float64_ty = context.primitives.float; + let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type(); + + let ndarray_ptr = create_full_ndarray( + generator, + context, + float64_ty, // `elem_ty` is always `float64` + shape, + shape_ty, + float64_llvm_type.const_float(1.0).as_basic_value_enum(), + "ndarray", + )?; + Ok(ndarray_ptr.value) +} + +/// Generates LLVM IR for `np.full`. +pub fn gen_ndarray_full<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 2); + + // Parse argument #1 shape + let shape_ty = fun.0.args[0].ty; + let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + // Parse argument #2 fill_value + let fill_value_ty = fun.0.args[1].ty; + let fill_value_arg = + args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; + + // Implementation + let ndarray_ptr = create_full_ndarray( + generator, + context, + fill_value_ty, + shape_arg, + shape_ty, + fill_value_arg, + "ndarray", + )?; + Ok(ndarray_ptr.value) +} diff --git a/nac3core/src/codegen/numpy_new/mod.rs b/nac3core/src/codegen/numpy_new/mod.rs new file mode 100644 index 00000000..0f1a26b8 --- /dev/null +++ b/nac3core/src/codegen/numpy_new/mod.rs @@ -0,0 +1,2 @@ +pub mod control; +pub mod factory; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e6de64ee..eb9ed31b 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -20,6 +20,7 @@ use crate::{ irrt::*, model::*, numpy::*, + numpy_new, stmt::exn_constructor, structure::ndarray::NpArray, }, @@ -1205,9 +1206,11 @@ impl<'a> BuiltinBuilder<'a> { &[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], Box::new(move |ctx, obj, fun, args, generator| { let func = match prim { - PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, - PrimDef::FunNpZeros => gen_ndarray_zeros, - PrimDef::FunNpOnes => gen_ndarray_ones, + PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => { + numpy_new::factory::gen_ndarray_empty + } + PrimDef::FunNpZeros => numpy_new::factory::gen_ndarray_zeros, + PrimDef::FunNpOnes => numpy_new::factory::gen_ndarray_ones, _ => unreachable!(), }; func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum())) @@ -1275,7 +1278,7 @@ impl<'a> BuiltinBuilder<'a> { // type variable &[(self.list_int32, "shape"), (tv.ty, "fill_value")], Box::new(move |ctx, obj, fun, args, generator| { - gen_ndarray_full(ctx, &obj, fun, &args, generator) + numpy_new::factory::gen_ndarray_full(ctx, &obj, fun, &args, generator) .map(|val| Some(val.as_basic_value_enum())) }), )