forked from M-Labs/nac3
core/ndstrides: implement gen_foreach_ndarray_elements & np_{empty,ndarray,zeros,ones,full}
This commit is contained in:
parent
5b9ac9b09c
commit
2211c4d852
|
@ -303,4 +303,14 @@ void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray,
|
|||
NDArray<int64_t>* dst_ndarray) {
|
||||
copy_data(src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
uint8_t* __nac3_ndarray_get_nth_pelement(NDArray<int32_t>* ndarray,
|
||||
int32_t index) {
|
||||
return get_nth_pelement(ndarray, index);
|
||||
}
|
||||
|
||||
uint8_t* __nac3_ndarray_get_nth_pelement64(NDArray<int64_t>* ndarray,
|
||||
int64_t index) {
|
||||
return get_nth_pelement(ndarray, index);
|
||||
}
|
||||
}
|
|
@ -133,3 +133,21 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
|
|||
.arg("dst_ndarray", dst_ndarray)
|
||||
.returning("is_c_contiguous")
|
||||
}
|
||||
|
||||
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
pndarray: Ptr<'ctx, StructModel<NpArray>>,
|
||||
index: Int<'ctx, SizeT>,
|
||||
) -> Ptr<'ctx, IntModel<Byte>> {
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
|
||||
CallFunction::begin(
|
||||
tyctx,
|
||||
ctx,
|
||||
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_get_nth_pelement"),
|
||||
)
|
||||
.arg("ndarray", pndarray)
|
||||
.arg("index", index)
|
||||
.returning("pelement")
|
||||
}
|
||||
|
|
|
@ -45,6 +45,7 @@ pub mod irrt;
|
|||
pub mod llvm_intrinsics;
|
||||
pub mod model;
|
||||
pub mod numpy;
|
||||
pub mod numpy_new;
|
||||
pub mod stmt;
|
||||
pub mod structure;
|
||||
pub mod util;
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
use crate::codegen::{
|
||||
irrt::ndarray::basic::{call_nac3_ndarray_get_nth_pelement, call_nac3_ndarray_size},
|
||||
model::*,
|
||||
stmt::BreakContinueHooks,
|
||||
structure::ndarray::NpArray,
|
||||
util::control::gen_model_for,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Iterate through all elements in an ndarray.
|
||||
///
|
||||
/// `body` is given the index of an element and an opaque pointer (as an `uint8_t*`, you might want to cast it) to the element.
|
||||
///
|
||||
/// Short-circuiting is possible with the given [`BreakContinueHooks`].
|
||||
pub fn gen_foreach_ndarray_elements<'ctx, G, F>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
pndarray: Ptr<'ctx, StructModel<NpArray>>,
|
||||
body: F,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
F: Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, '_>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
Int<'ctx, SizeT>,
|
||||
Ptr<'ctx, IntModel<Byte>>,
|
||||
) -> Result<(), String>,
|
||||
{
|
||||
// TODO: Make this more efficient - use a special NDArray iterator?
|
||||
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let size = call_nac3_ndarray_size(generator, ctx, pndarray);
|
||||
|
||||
gen_model_for(
|
||||
generator,
|
||||
ctx,
|
||||
sizet_model.const_0(tyctx, ctx.ctx),
|
||||
size,
|
||||
sizet_model.const_1(tyctx, ctx.ctx),
|
||||
|generator, ctx, hooks, index| {
|
||||
let pelement = call_nac3_ndarray_get_nth_pelement(generator, ctx, pndarray, index);
|
||||
body(generator, ctx, hooks, index, pelement)
|
||||
},
|
||||
)
|
||||
}
|
|
@ -0,0 +1,219 @@
|
|||
use inkwell::{
|
||||
types::BasicType,
|
||||
values::{BasicValue, BasicValueEnum, PointerValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use nac3parser::ast::StrRef;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt::ndarray::allocation::{
|
||||
alloca_ndarray, init_ndarray_data_by_alloca, init_ndarray_shape,
|
||||
},
|
||||
model::*,
|
||||
structure::ndarray::NpArray,
|
||||
util::shape::make_shape_writer,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::DefinitionId,
|
||||
typecheck::typedef::{FunSignature, Type},
|
||||
};
|
||||
|
||||
use super::control::gen_foreach_ndarray_elements;
|
||||
|
||||
/// Helper function to create an ndarray with uninitialized values
|
||||
///
|
||||
/// * `elem_ty` - The [`Type`] of the ndarray elements
|
||||
/// * `shape` - The user input shape argument
|
||||
/// * `shape_ty` - The [`Type`] of the shape argument
|
||||
/// * `name` - LLVM IR name of the returned ndarray
|
||||
fn create_empty_ndarray<'ctx, G>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
shape: BasicValueEnum<'ctx>,
|
||||
shape_ty: Type,
|
||||
name: &str,
|
||||
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
let shape_writer = make_shape_writer(generator, ctx, shape, shape_ty);
|
||||
let ndims = shape_writer.len;
|
||||
|
||||
let ndarray = alloca_ndarray(generator, ctx, ndims, name)?;
|
||||
init_ndarray_shape(generator, ctx, ndarray, &shape_writer)?;
|
||||
|
||||
let itemsize = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
|
||||
let itemsize = sizet_model.check_value(tyctx, ctx.ctx, itemsize).unwrap();
|
||||
ndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize);
|
||||
|
||||
// Needs `itemsize` and `shape` initialized
|
||||
init_ndarray_data_by_alloca(generator, ctx, ndarray);
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Helper function to create an ndarray full of a value.
|
||||
///
|
||||
/// * `elem_ty` - The [`Type`] of the ndarray elements and the fill value
|
||||
/// * `shape` - The user input shape argument
|
||||
/// * `shape_ty` - The [`Type`] of the shape argument
|
||||
/// * `fill_value` - The user specified fill value
|
||||
/// * `name` - LLVM IR name of the returned ndarray
|
||||
fn create_full_ndarray<'ctx, G>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
shape: BasicValueEnum<'ctx>,
|
||||
shape_ty: Type,
|
||||
fill_value: BasicValueEnum<'ctx>,
|
||||
name: &str,
|
||||
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
let pndarray = create_empty_ndarray(generator, ctx, elem_ty, shape, shape_ty, name)?;
|
||||
gen_foreach_ndarray_elements(
|
||||
generator,
|
||||
ctx,
|
||||
pndarray,
|
||||
|_generator, ctx, _hooks, _i, pelement| {
|
||||
// Cannot use Model here, fill_value's type is not statically known.
|
||||
let pfill_value_ty = fill_value.get_type().ptr_type(AddressSpace::default());
|
||||
let pelement =
|
||||
ctx.builder.build_pointer_cast(pelement.value, pfill_value_ty, "pelement").unwrap();
|
||||
ctx.builder.build_store(pelement, fill_value).unwrap();
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
Ok(pndarray)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `np.empty`.
|
||||
pub fn gen_ndarray_empty<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
// Parse arguments
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
// Implementation
|
||||
let ndarray_ptr = create_empty_ndarray(
|
||||
generator,
|
||||
context,
|
||||
context.primitives.float,
|
||||
shape,
|
||||
shape_ty,
|
||||
"ndarray",
|
||||
)?;
|
||||
Ok(ndarray_ptr.value)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `np.zeros`.
|
||||
pub fn gen_ndarray_zeros<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
// Parse arguments
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
// Implementation
|
||||
// NOTE: Currently nac3's `np.zeros` is always `float64`.
|
||||
let float64_ty = context.primitives.float;
|
||||
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
|
||||
|
||||
let ndarray_ptr = create_full_ndarray(
|
||||
generator,
|
||||
context,
|
||||
float64_ty, // `elem_ty` is always `float64`
|
||||
shape,
|
||||
shape_ty,
|
||||
float64_llvm_type.const_zero().as_basic_value_enum(),
|
||||
"ndarray",
|
||||
)?;
|
||||
Ok(ndarray_ptr.value)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `np.ones`.
|
||||
pub fn gen_ndarray_ones<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
// Parse arguments
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
// Implementation
|
||||
// NOTE: Currently nac3's `np.ones` is always `float64`.
|
||||
let float64_ty = context.primitives.float;
|
||||
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
|
||||
|
||||
let ndarray_ptr = create_full_ndarray(
|
||||
generator,
|
||||
context,
|
||||
float64_ty, // `elem_ty` is always `float64`
|
||||
shape,
|
||||
shape_ty,
|
||||
float64_llvm_type.const_float(1.0).as_basic_value_enum(),
|
||||
"ndarray",
|
||||
)?;
|
||||
Ok(ndarray_ptr.value)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `np.full`.
|
||||
pub fn gen_ndarray_full<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 2);
|
||||
|
||||
// Parse argument #1 shape
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
// Parse argument #2 fill_value
|
||||
let fill_value_ty = fun.0.args[1].ty;
|
||||
let fill_value_arg =
|
||||
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
||||
|
||||
// Implementation
|
||||
let ndarray_ptr = create_full_ndarray(
|
||||
generator,
|
||||
context,
|
||||
fill_value_ty,
|
||||
shape_arg,
|
||||
shape_ty,
|
||||
fill_value_arg,
|
||||
"ndarray",
|
||||
)?;
|
||||
Ok(ndarray_ptr.value)
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
pub mod control;
|
||||
pub mod factory;
|
|
@ -20,6 +20,7 @@ use crate::{
|
|||
irrt::*,
|
||||
model::*,
|
||||
numpy::*,
|
||||
numpy_new,
|
||||
stmt::exn_constructor,
|
||||
structure::ndarray::NpArray,
|
||||
},
|
||||
|
@ -1205,9 +1206,11 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||
Box::new(move |ctx, obj, fun, args, generator| {
|
||||
let func = match prim {
|
||||
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
|
||||
PrimDef::FunNpZeros => gen_ndarray_zeros,
|
||||
PrimDef::FunNpOnes => gen_ndarray_ones,
|
||||
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => {
|
||||
numpy_new::factory::gen_ndarray_empty
|
||||
}
|
||||
PrimDef::FunNpZeros => numpy_new::factory::gen_ndarray_zeros,
|
||||
PrimDef::FunNpOnes => numpy_new::factory::gen_ndarray_ones,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
||||
|
@ -1275,7 +1278,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
// type variable
|
||||
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
||||
Box::new(move |ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
||||
numpy_new::factory::gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue