forked from M-Labs/nac3
core/ndstrides: implement ndarray np_{empty,ndarray,zeros,ones,full}
This commit is contained in:
parent
fc9d47fb54
commit
85ef06f1e2
|
@ -0,0 +1,38 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/ndarray/basic.hpp>
|
||||||
|
#include <irrt/ndarray/def.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray {
|
||||||
|
namespace fill {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fill an ndarray with a value.
|
||||||
|
*
|
||||||
|
* @param pvalue Pointer to the fill value, and the fill value should be of `ndarray->itemsize` bytes.
|
||||||
|
*/
|
||||||
|
template <typename SizeT>
|
||||||
|
void fill_generic(NDArray<SizeT>* ndarray, const uint8_t* pvalue) {
|
||||||
|
const SizeT size = ndarray::basic::size(ndarray);
|
||||||
|
for (SizeT i = 0; i < size; i++) {
|
||||||
|
uint8_t* pelement = ndarray::basic::get_nth_pelement(
|
||||||
|
ndarray, i); // No need for checked_get_nth_pelement
|
||||||
|
ndarray::basic::set_pelement_value(ndarray, pelement, pvalue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace fill
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
using namespace ndarray::fill;
|
||||||
|
|
||||||
|
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
|
||||||
|
fill_generic(ndarray, pvalue);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
|
||||||
|
fill_generic(ndarray, pvalue);
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,4 +6,5 @@
|
||||||
#include <irrt/int_defs.hpp>
|
#include <irrt/int_defs.hpp>
|
||||||
#include <irrt/ndarray/basic.hpp>
|
#include <irrt/ndarray/basic.hpp>
|
||||||
#include <irrt/ndarray/def.hpp>
|
#include <irrt/ndarray/def.hpp>
|
||||||
|
#include <irrt/ndarray/fill.hpp>
|
||||||
#include <irrt/utils.hpp>
|
#include <irrt/utils.hpp>
|
|
@ -0,0 +1,21 @@
|
||||||
|
use crate::codegen::{
|
||||||
|
irrt::util::get_sized_dependent_function_name, model::*, structs::ndarray::NpArray,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_fill_generic<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray_ptr: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||||
|
fill_value_ptr: Pointer<'ctx, ByteModel>,
|
||||||
|
) {
|
||||||
|
let sizet = generator.get_sizet(ctx.ctx);
|
||||||
|
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_fill_generic"),
|
||||||
|
)
|
||||||
|
.arg("ndarray", ndarray_ptr)
|
||||||
|
.arg("pvalue", fill_value_ptr)
|
||||||
|
.returning_void();
|
||||||
|
}
|
|
@ -1,2 +1,3 @@
|
||||||
pub mod allocation;
|
pub mod allocation;
|
||||||
pub mod basic;
|
pub mod basic;
|
||||||
|
pub mod fill;
|
||||||
|
|
|
@ -45,12 +45,13 @@ pub mod irrt;
|
||||||
pub mod llvm_intrinsics;
|
pub mod llvm_intrinsics;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod numpy;
|
pub mod numpy;
|
||||||
|
pub mod numpy_new;
|
||||||
pub mod stmt;
|
pub mod stmt;
|
||||||
pub mod structs;
|
pub mod structs;
|
||||||
pub mod util;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
pub mod util;
|
||||||
|
|
||||||
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
||||||
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
||||||
|
|
|
@ -0,0 +1,213 @@
|
||||||
|
use inkwell::{
|
||||||
|
types::BasicType,
|
||||||
|
values::{BasicValue, BasicValueEnum, PointerValue},
|
||||||
|
};
|
||||||
|
use nac3parser::ast::StrRef;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
irrt::ndarray::{
|
||||||
|
allocation::{alloca_ndarray, init_ndarray_data_by_alloca, init_ndarray_shape},
|
||||||
|
fill::call_nac3_ndarray_fill_generic,
|
||||||
|
},
|
||||||
|
model::*,
|
||||||
|
structs::ndarray::NpArray,
|
||||||
|
util::shape::parse_input_shape_arg,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
symbol_resolver::ValueEnum,
|
||||||
|
toplevel::DefinitionId,
|
||||||
|
typecheck::typedef::{FunSignature, Type},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Helper function to create an ndarray with uninitialized values
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The [`Type`] of the ndarray elements
|
||||||
|
/// * `shape` - The user input shape argument
|
||||||
|
/// * `shape_ty` - The [`Type`] of the shape argument
|
||||||
|
/// * `name` - LLVM IR name of the returned ndarray
|
||||||
|
fn create_empty_ndarray<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: BasicValueEnum<'ctx>,
|
||||||
|
shape_ty: Type,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Pointer<'ctx, StructModel<NpArray<'ctx>>>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let sizet = generator.get_sizet(ctx.ctx);
|
||||||
|
|
||||||
|
let shape_writer = parse_input_shape_arg(generator, ctx, shape, shape_ty);
|
||||||
|
let ndims = shape_writer.count;
|
||||||
|
|
||||||
|
let ndarray = alloca_ndarray(generator, ctx, ndims, name)?;
|
||||||
|
init_ndarray_shape(generator, ctx, ndarray, &shape_writer)?;
|
||||||
|
|
||||||
|
let itemsize = sizet
|
||||||
|
.review_value(ctx.ctx, ctx.get_llvm_type(generator, elem_ty).size_of().unwrap())
|
||||||
|
.unwrap();
|
||||||
|
ndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize);
|
||||||
|
|
||||||
|
init_ndarray_data_by_alloca(generator, ctx, ndarray); // Needs `itemsize` and `shape` initialized first
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to create an ndarray full of a value.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The [`Type`] of the ndarray elements and the fill value
|
||||||
|
/// * `shape` - The user input shape argument
|
||||||
|
/// * `shape_ty` - The [`Type`] of the shape argument
|
||||||
|
/// * `fill_value` - The user specified fill value
|
||||||
|
/// * `name` - LLVM IR name of the returned ndarray
|
||||||
|
fn create_full_ndarray<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: BasicValueEnum<'ctx>,
|
||||||
|
shape_ty: Type,
|
||||||
|
fill_value: BasicValueEnum<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Pointer<'ctx, StructModel<NpArray<'ctx>>>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let byte_model = NIntModel(Byte);
|
||||||
|
let fill_value_model = OpaqueModel(fill_value.get_type());
|
||||||
|
|
||||||
|
// Caller has to put fill_value on the stack and pass its address
|
||||||
|
let fill_value_ptr = fill_value_model.alloca(ctx, "fill_value_ptr");
|
||||||
|
fill_value_ptr.store(ctx, fill_value_model.believe_value(fill_value));
|
||||||
|
let fill_value_ptr = fill_value_ptr.cast_to(ctx, byte_model, "fill_value_bytes_ptr");
|
||||||
|
|
||||||
|
let ndarray_ptr = create_empty_ndarray(generator, ctx, elem_ty, shape, shape_ty, name)?;
|
||||||
|
call_nac3_ndarray_fill_generic(generator, ctx, ndarray_ptr, fill_value_ptr);
|
||||||
|
|
||||||
|
Ok(ndarray_ptr)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.empty`.
|
||||||
|
pub fn gen_ndarray_empty<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
// Parse arguments
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
let ndarray_ptr = create_empty_ndarray(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
context.primitives.float,
|
||||||
|
shape,
|
||||||
|
shape_ty,
|
||||||
|
"ndarray",
|
||||||
|
)?;
|
||||||
|
Ok(ndarray_ptr.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.zeros`.
|
||||||
|
pub fn gen_ndarray_zeros<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
// Parse arguments
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
// NOTE: Currently nac3's `np.zeros` is always `float64`.
|
||||||
|
let float64_ty = context.primitives.float;
|
||||||
|
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
|
||||||
|
|
||||||
|
let ndarray_ptr = create_full_ndarray(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
float64_ty, // `elem_ty` is always `float64`
|
||||||
|
shape,
|
||||||
|
shape_ty,
|
||||||
|
float64_llvm_type.const_zero().as_basic_value_enum(),
|
||||||
|
"ndarray",
|
||||||
|
)?;
|
||||||
|
Ok(ndarray_ptr.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.ones`.
|
||||||
|
pub fn gen_ndarray_ones<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
// Parse arguments
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
// NOTE: Currently nac3's `np.ones` is always `float64`.
|
||||||
|
let float64_ty = context.primitives.float;
|
||||||
|
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
|
||||||
|
|
||||||
|
let ndarray_ptr = create_full_ndarray(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
float64_ty, // `elem_ty` is always `float64`
|
||||||
|
shape,
|
||||||
|
shape_ty,
|
||||||
|
float64_llvm_type.const_float(1.0).as_basic_value_enum(),
|
||||||
|
"ndarray",
|
||||||
|
)?;
|
||||||
|
Ok(ndarray_ptr.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.full`.
|
||||||
|
pub fn gen_ndarray_full<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 2);
|
||||||
|
|
||||||
|
// Parse argument #1 shape
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
// Parse argument #2 fill_value
|
||||||
|
let fill_value_ty = fun.0.args[1].ty;
|
||||||
|
let fill_value_arg =
|
||||||
|
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
let ndarray_ptr = create_full_ndarray(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
fill_value_ty,
|
||||||
|
shape_arg,
|
||||||
|
shape_ty,
|
||||||
|
fill_value_arg,
|
||||||
|
"ndarray",
|
||||||
|
)?;
|
||||||
|
Ok(ndarray_ptr.value)
|
||||||
|
}
|
|
@ -0,0 +1 @@
|
||||||
|
pub mod factory;
|
|
@ -20,6 +20,7 @@ use crate::{
|
||||||
irrt::*,
|
irrt::*,
|
||||||
model::*,
|
model::*,
|
||||||
numpy::*,
|
numpy::*,
|
||||||
|
numpy_new,
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
structs::ndarray::NpArray,
|
structs::ndarray::NpArray,
|
||||||
},
|
},
|
||||||
|
@ -1205,9 +1206,11 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
let func = match prim {
|
let func = match prim {
|
||||||
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
|
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => {
|
||||||
PrimDef::FunNpZeros => gen_ndarray_zeros,
|
numpy_new::factory::gen_ndarray_empty
|
||||||
PrimDef::FunNpOnes => gen_ndarray_ones,
|
}
|
||||||
|
PrimDef::FunNpZeros => numpy_new::factory::gen_ndarray_zeros,
|
||||||
|
PrimDef::FunNpOnes => numpy_new::factory::gen_ndarray_ones,
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
@ -1275,7 +1278,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
// type variable
|
// type variable
|
||||||
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
numpy_new::factory::gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue