diff --git a/nac3core/irrt/irrt/ndarray/fill.hpp b/nac3core/irrt/irrt/ndarray/fill.hpp new file mode 100644 index 00000000..e17a1ad1 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/fill.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +namespace { +namespace ndarray { +namespace fill { + +/** + * Fill an ndarray with a value. + * + * @param pvalue Pointer to the fill value, and the fill value should be of `ndarray->itemsize` bytes. + */ +template +void fill_generic(NDArray* ndarray, const uint8_t* pvalue) { + const SizeT size = ndarray::basic::size(ndarray); + for (SizeT i = 0; i < size; i++) { + uint8_t* pelement = ndarray::basic::get_nth_pelement( + ndarray, i); // No need for checked_get_nth_pelement + ndarray::basic::set_pelement_value(ndarray, pelement, pvalue); + } +} +} // namespace fill +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::fill; + +void __nac3_ndarray_fill_generic(NDArray* ndarray, uint8_t* pvalue) { + fill_generic(ndarray, pvalue); +} + +void __nac3_ndarray_fill_generic64(NDArray* ndarray, uint8_t* pvalue) { + fill_generic(ndarray, pvalue); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index f6558051..3af426bc 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -6,4 +6,5 @@ #include #include #include +#include #include \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/fill.rs b/nac3core/src/codegen/irrt/ndarray/fill.rs new file mode 100644 index 00000000..39d3c5af --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/fill.rs @@ -0,0 +1,21 @@ +use crate::codegen::{ + irrt::util::get_sized_dependent_function_name, model::*, structs::ndarray::NpArray, + CodeGenContext, CodeGenerator, +}; + +pub fn call_nac3_ndarray_fill_generic<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray_ptr: Pointer<'ctx, StructModel>>, + fill_value_ptr: Pointer<'ctx, ByteModel>, +) { + let sizet = generator.get_sizet(ctx.ctx); + + FunctionBuilder::begin( + ctx, + &get_sized_dependent_function_name(sizet, "__nac3_ndarray_fill_generic"), + ) + .arg("ndarray", ndarray_ptr) + .arg("pvalue", fill_value_ptr) + .returning_void(); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 8648d99f..2f43ca6c 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,2 +1,3 @@ pub mod allocation; pub mod basic; +pub mod fill; diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index ddbe6d94..c6c10379 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -45,12 +45,13 @@ pub mod irrt; pub mod llvm_intrinsics; pub mod model; pub mod numpy; +pub mod numpy_new; pub mod stmt; pub mod structs; -pub mod util; #[cfg(test)] mod test; +pub mod util; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; diff --git a/nac3core/src/codegen/numpy_new/factory.rs b/nac3core/src/codegen/numpy_new/factory.rs new file mode 100644 index 00000000..4ef7e859 --- /dev/null +++ b/nac3core/src/codegen/numpy_new/factory.rs @@ -0,0 +1,213 @@ +use inkwell::{ + types::BasicType, + values::{BasicValue, BasicValueEnum, PointerValue}, +}; +use nac3parser::ast::StrRef; + +use crate::{ + codegen::{ + irrt::ndarray::{ + allocation::{alloca_ndarray, init_ndarray_data_by_alloca, init_ndarray_shape}, + fill::call_nac3_ndarray_fill_generic, + }, + model::*, + structs::ndarray::NpArray, + util::shape::parse_input_shape_arg, + CodeGenContext, CodeGenerator, + }, + symbol_resolver::ValueEnum, + toplevel::DefinitionId, + typecheck::typedef::{FunSignature, Type}, +}; + +/// Helper function to create an ndarray with uninitialized values +/// +/// * `elem_ty` - The [`Type`] of the ndarray elements +/// * `shape` - The user input shape argument +/// * `shape_ty` - The [`Type`] of the shape argument +/// * `name` - LLVM IR name of the returned ndarray +fn create_empty_ndarray<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + shape: BasicValueEnum<'ctx>, + shape_ty: Type, + name: &str, +) -> Result>>, String> +where + G: CodeGenerator + ?Sized, +{ + let sizet = generator.get_sizet(ctx.ctx); + + let shape_writer = parse_input_shape_arg(generator, ctx, shape, shape_ty); + let ndims = shape_writer.count; + + let ndarray = alloca_ndarray(generator, ctx, ndims, name)?; + init_ndarray_shape(generator, ctx, ndarray, &shape_writer)?; + + let itemsize = sizet + .review_value(ctx.ctx, ctx.get_llvm_type(generator, elem_ty).size_of().unwrap()) + .unwrap(); + ndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize); + + init_ndarray_data_by_alloca(generator, ctx, ndarray); // Needs `itemsize` and `shape` initialized first + + Ok(ndarray) +} + +/// Helper function to create an ndarray full of a value. +/// +/// * `elem_ty` - The [`Type`] of the ndarray elements and the fill value +/// * `shape` - The user input shape argument +/// * `shape_ty` - The [`Type`] of the shape argument +/// * `fill_value` - The user specified fill value +/// * `name` - LLVM IR name of the returned ndarray +fn create_full_ndarray<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + shape: BasicValueEnum<'ctx>, + shape_ty: Type, + fill_value: BasicValueEnum<'ctx>, + name: &str, +) -> Result>>, String> +where + G: CodeGenerator + ?Sized, +{ + let byte_model = NIntModel(Byte); + let fill_value_model = OpaqueModel(fill_value.get_type()); + + // Caller has to put fill_value on the stack and pass its address + let fill_value_ptr = fill_value_model.alloca(ctx, "fill_value_ptr"); + fill_value_ptr.store(ctx, fill_value_model.believe_value(fill_value)); + let fill_value_ptr = fill_value_ptr.cast_to(ctx, byte_model, "fill_value_bytes_ptr"); + + let ndarray_ptr = create_empty_ndarray(generator, ctx, elem_ty, shape, shape_ty, name)?; + call_nac3_ndarray_fill_generic(generator, ctx, ndarray_ptr, fill_value_ptr); + + Ok(ndarray_ptr) +} + +/// Generates LLVM IR for `np.empty`. +pub fn gen_ndarray_empty<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + // Parse arguments + let shape_ty = fun.0.args[0].ty; + let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + // Implementation + let ndarray_ptr = create_empty_ndarray( + generator, + context, + context.primitives.float, + shape, + shape_ty, + "ndarray", + )?; + Ok(ndarray_ptr.value) +} + +/// Generates LLVM IR for `np.zeros`. +pub fn gen_ndarray_zeros<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + // Parse arguments + let shape_ty = fun.0.args[0].ty; + let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + // Implementation + // NOTE: Currently nac3's `np.zeros` is always `float64`. + let float64_ty = context.primitives.float; + let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type(); + + let ndarray_ptr = create_full_ndarray( + generator, + context, + float64_ty, // `elem_ty` is always `float64` + shape, + shape_ty, + float64_llvm_type.const_zero().as_basic_value_enum(), + "ndarray", + )?; + Ok(ndarray_ptr.value) +} + +/// Generates LLVM IR for `np.ones`. +pub fn gen_ndarray_ones<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + // Parse arguments + let shape_ty = fun.0.args[0].ty; + let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + // Implementation + // NOTE: Currently nac3's `np.ones` is always `float64`. + let float64_ty = context.primitives.float; + let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type(); + + let ndarray_ptr = create_full_ndarray( + generator, + context, + float64_ty, // `elem_ty` is always `float64` + shape, + shape_ty, + float64_llvm_type.const_float(1.0).as_basic_value_enum(), + "ndarray", + )?; + Ok(ndarray_ptr.value) +} + +/// Generates LLVM IR for `np.full`. +pub fn gen_ndarray_full<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 2); + + // Parse argument #1 shape + let shape_ty = fun.0.args[0].ty; + let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + // Parse argument #2 fill_value + let fill_value_ty = fun.0.args[1].ty; + let fill_value_arg = + args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; + + // Implementation + let ndarray_ptr = create_full_ndarray( + generator, + context, + fill_value_ty, + shape_arg, + shape_ty, + fill_value_arg, + "ndarray", + )?; + Ok(ndarray_ptr.value) +} diff --git a/nac3core/src/codegen/numpy_new/mod.rs b/nac3core/src/codegen/numpy_new/mod.rs new file mode 100644 index 00000000..a106d20e --- /dev/null +++ b/nac3core/src/codegen/numpy_new/mod.rs @@ -0,0 +1 @@ +pub mod factory; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 8809fd1a..2ca4ee75 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -20,6 +20,7 @@ use crate::{ irrt::*, model::*, numpy::*, + numpy_new, stmt::exn_constructor, structs::ndarray::NpArray, }, @@ -1205,9 +1206,11 @@ impl<'a> BuiltinBuilder<'a> { &[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], Box::new(move |ctx, obj, fun, args, generator| { let func = match prim { - PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, - PrimDef::FunNpZeros => gen_ndarray_zeros, - PrimDef::FunNpOnes => gen_ndarray_ones, + PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => { + numpy_new::factory::gen_ndarray_empty + } + PrimDef::FunNpZeros => numpy_new::factory::gen_ndarray_zeros, + PrimDef::FunNpOnes => numpy_new::factory::gen_ndarray_ones, _ => unreachable!(), }; func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum())) @@ -1275,7 +1278,7 @@ impl<'a> BuiltinBuilder<'a> { // type variable &[(self.list_int32, "shape"), (tv.ty, "fill_value")], Box::new(move |ctx, obj, fun, args, generator| { - gen_ndarray_full(ctx, &obj, fun, &args, generator) + numpy_new::factory::gen_ndarray_full(ctx, &obj, fun, &args, generator) .map(|val| Some(val.as_basic_value_enum())) }), )