forked from M-Labs/nac3
core/ndstrides: implement gen_foreach_ndarray_elements & np_{empty,ndarray,zeros,ones,full}
This commit is contained in:
parent
5b9ac9b09c
commit
2211c4d852
|
@ -303,4 +303,14 @@ void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray,
|
||||||
NDArray<int64_t>* dst_ndarray) {
|
NDArray<int64_t>* dst_ndarray) {
|
||||||
copy_data(src_ndarray, dst_ndarray);
|
copy_data(src_ndarray, dst_ndarray);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint8_t* __nac3_ndarray_get_nth_pelement(NDArray<int32_t>* ndarray,
|
||||||
|
int32_t index) {
|
||||||
|
return get_nth_pelement(ndarray, index);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t* __nac3_ndarray_get_nth_pelement64(NDArray<int64_t>* ndarray,
|
||||||
|
int64_t index) {
|
||||||
|
return get_nth_pelement(ndarray, index);
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -133,3 +133,21 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
.arg("dst_ndarray", dst_ndarray)
|
.arg("dst_ndarray", dst_ndarray)
|
||||||
.returning("is_c_contiguous")
|
.returning("is_c_contiguous")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
pndarray: Ptr<'ctx, StructModel<NpArray>>,
|
||||||
|
index: Int<'ctx, SizeT>,
|
||||||
|
) -> Ptr<'ctx, IntModel<Byte>> {
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
|
||||||
|
CallFunction::begin(
|
||||||
|
tyctx,
|
||||||
|
ctx,
|
||||||
|
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_get_nth_pelement"),
|
||||||
|
)
|
||||||
|
.arg("ndarray", pndarray)
|
||||||
|
.arg("index", index)
|
||||||
|
.returning("pelement")
|
||||||
|
}
|
||||||
|
|
|
@ -45,6 +45,7 @@ pub mod irrt;
|
||||||
pub mod llvm_intrinsics;
|
pub mod llvm_intrinsics;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod numpy;
|
pub mod numpy;
|
||||||
|
pub mod numpy_new;
|
||||||
pub mod stmt;
|
pub mod stmt;
|
||||||
pub mod structure;
|
pub mod structure;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
use crate::codegen::{
|
||||||
|
irrt::ndarray::basic::{call_nac3_ndarray_get_nth_pelement, call_nac3_ndarray_size},
|
||||||
|
model::*,
|
||||||
|
stmt::BreakContinueHooks,
|
||||||
|
structure::ndarray::NpArray,
|
||||||
|
util::control::gen_model_for,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Iterate through all elements in an ndarray.
|
||||||
|
///
|
||||||
|
/// `body` is given the index of an element and an opaque pointer (as an `uint8_t*`, you might want to cast it) to the element.
|
||||||
|
///
|
||||||
|
/// Short-circuiting is possible with the given [`BreakContinueHooks`].
|
||||||
|
pub fn gen_foreach_ndarray_elements<'ctx, G, F>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
pndarray: Ptr<'ctx, StructModel<NpArray>>,
|
||||||
|
body: F,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
F: Fn(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, '_>,
|
||||||
|
BreakContinueHooks<'ctx>,
|
||||||
|
Int<'ctx, SizeT>,
|
||||||
|
Ptr<'ctx, IntModel<Byte>>,
|
||||||
|
) -> Result<(), String>,
|
||||||
|
{
|
||||||
|
// TODO: Make this more efficient - use a special NDArray iterator?
|
||||||
|
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
|
||||||
|
let sizet_model = IntModel(SizeT);
|
||||||
|
let size = call_nac3_ndarray_size(generator, ctx, pndarray);
|
||||||
|
|
||||||
|
gen_model_for(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
sizet_model.const_0(tyctx, ctx.ctx),
|
||||||
|
size,
|
||||||
|
sizet_model.const_1(tyctx, ctx.ctx),
|
||||||
|
|generator, ctx, hooks, index| {
|
||||||
|
let pelement = call_nac3_ndarray_get_nth_pelement(generator, ctx, pndarray, index);
|
||||||
|
body(generator, ctx, hooks, index, pelement)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
|
@ -0,0 +1,219 @@
|
||||||
|
use inkwell::{
|
||||||
|
types::BasicType,
|
||||||
|
values::{BasicValue, BasicValueEnum, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
use nac3parser::ast::StrRef;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
irrt::ndarray::allocation::{
|
||||||
|
alloca_ndarray, init_ndarray_data_by_alloca, init_ndarray_shape,
|
||||||
|
},
|
||||||
|
model::*,
|
||||||
|
structure::ndarray::NpArray,
|
||||||
|
util::shape::make_shape_writer,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
symbol_resolver::ValueEnum,
|
||||||
|
toplevel::DefinitionId,
|
||||||
|
typecheck::typedef::{FunSignature, Type},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::control::gen_foreach_ndarray_elements;
|
||||||
|
|
||||||
|
/// Helper function to create an ndarray with uninitialized values
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The [`Type`] of the ndarray elements
|
||||||
|
/// * `shape` - The user input shape argument
|
||||||
|
/// * `shape_ty` - The [`Type`] of the shape argument
|
||||||
|
/// * `name` - LLVM IR name of the returned ndarray
|
||||||
|
fn create_empty_ndarray<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: BasicValueEnum<'ctx>,
|
||||||
|
shape_ty: Type,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
let sizet_model = IntModel(SizeT);
|
||||||
|
|
||||||
|
let shape_writer = make_shape_writer(generator, ctx, shape, shape_ty);
|
||||||
|
let ndims = shape_writer.len;
|
||||||
|
|
||||||
|
let ndarray = alloca_ndarray(generator, ctx, ndims, name)?;
|
||||||
|
init_ndarray_shape(generator, ctx, ndarray, &shape_writer)?;
|
||||||
|
|
||||||
|
let itemsize = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
|
||||||
|
let itemsize = sizet_model.check_value(tyctx, ctx.ctx, itemsize).unwrap();
|
||||||
|
ndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize);
|
||||||
|
|
||||||
|
// Needs `itemsize` and `shape` initialized
|
||||||
|
init_ndarray_data_by_alloca(generator, ctx, ndarray);
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to create an ndarray full of a value.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The [`Type`] of the ndarray elements and the fill value
|
||||||
|
/// * `shape` - The user input shape argument
|
||||||
|
/// * `shape_ty` - The [`Type`] of the shape argument
|
||||||
|
/// * `fill_value` - The user specified fill value
|
||||||
|
/// * `name` - LLVM IR name of the returned ndarray
|
||||||
|
fn create_full_ndarray<'ctx, G>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: BasicValueEnum<'ctx>,
|
||||||
|
shape_ty: Type,
|
||||||
|
fill_value: BasicValueEnum<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
{
|
||||||
|
let pndarray = create_empty_ndarray(generator, ctx, elem_ty, shape, shape_ty, name)?;
|
||||||
|
gen_foreach_ndarray_elements(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
pndarray,
|
||||||
|
|_generator, ctx, _hooks, _i, pelement| {
|
||||||
|
// Cannot use Model here, fill_value's type is not statically known.
|
||||||
|
let pfill_value_ty = fill_value.get_type().ptr_type(AddressSpace::default());
|
||||||
|
let pelement =
|
||||||
|
ctx.builder.build_pointer_cast(pelement.value, pfill_value_ty, "pelement").unwrap();
|
||||||
|
ctx.builder.build_store(pelement, fill_value).unwrap();
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
Ok(pndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.empty`.
|
||||||
|
pub fn gen_ndarray_empty<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
// Parse arguments
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
let ndarray_ptr = create_empty_ndarray(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
context.primitives.float,
|
||||||
|
shape,
|
||||||
|
shape_ty,
|
||||||
|
"ndarray",
|
||||||
|
)?;
|
||||||
|
Ok(ndarray_ptr.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.zeros`.
|
||||||
|
pub fn gen_ndarray_zeros<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
// Parse arguments
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
// NOTE: Currently nac3's `np.zeros` is always `float64`.
|
||||||
|
let float64_ty = context.primitives.float;
|
||||||
|
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
|
||||||
|
|
||||||
|
let ndarray_ptr = create_full_ndarray(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
float64_ty, // `elem_ty` is always `float64`
|
||||||
|
shape,
|
||||||
|
shape_ty,
|
||||||
|
float64_llvm_type.const_zero().as_basic_value_enum(),
|
||||||
|
"ndarray",
|
||||||
|
)?;
|
||||||
|
Ok(ndarray_ptr.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.ones`.
|
||||||
|
pub fn gen_ndarray_ones<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
// Parse arguments
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
// NOTE: Currently nac3's `np.ones` is always `float64`.
|
||||||
|
let float64_ty = context.primitives.float;
|
||||||
|
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
|
||||||
|
|
||||||
|
let ndarray_ptr = create_full_ndarray(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
float64_ty, // `elem_ty` is always `float64`
|
||||||
|
shape,
|
||||||
|
shape_ty,
|
||||||
|
float64_llvm_type.const_float(1.0).as_basic_value_enum(),
|
||||||
|
"ndarray",
|
||||||
|
)?;
|
||||||
|
Ok(ndarray_ptr.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.full`.
|
||||||
|
pub fn gen_ndarray_full<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 2);
|
||||||
|
|
||||||
|
// Parse argument #1 shape
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
// Parse argument #2 fill_value
|
||||||
|
let fill_value_ty = fun.0.args[1].ty;
|
||||||
|
let fill_value_arg =
|
||||||
|
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
let ndarray_ptr = create_full_ndarray(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
fill_value_ty,
|
||||||
|
shape_arg,
|
||||||
|
shape_ty,
|
||||||
|
fill_value_arg,
|
||||||
|
"ndarray",
|
||||||
|
)?;
|
||||||
|
Ok(ndarray_ptr.value)
|
||||||
|
}
|
|
@ -0,0 +1,2 @@
|
||||||
|
pub mod control;
|
||||||
|
pub mod factory;
|
|
@ -20,6 +20,7 @@ use crate::{
|
||||||
irrt::*,
|
irrt::*,
|
||||||
model::*,
|
model::*,
|
||||||
numpy::*,
|
numpy::*,
|
||||||
|
numpy_new,
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
structure::ndarray::NpArray,
|
structure::ndarray::NpArray,
|
||||||
},
|
},
|
||||||
|
@ -1205,9 +1206,11 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
let func = match prim {
|
let func = match prim {
|
||||||
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
|
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => {
|
||||||
PrimDef::FunNpZeros => gen_ndarray_zeros,
|
numpy_new::factory::gen_ndarray_empty
|
||||||
PrimDef::FunNpOnes => gen_ndarray_ones,
|
}
|
||||||
|
PrimDef::FunNpZeros => numpy_new::factory::gen_ndarray_zeros,
|
||||||
|
PrimDef::FunNpOnes => numpy_new::factory::gen_ndarray_ones,
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
@ -1275,7 +1278,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
// type variable
|
// type variable
|
||||||
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
numpy_new::factory::gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue