forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: more iter and less builtin

This commit is contained in:
lyken 2024-08-15 00:33:23 +08:00
parent 15dfb2eaa0
commit 5dce27e87d
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
14 changed files with 478 additions and 279 deletions

View File

@ -21,40 +21,61 @@ namespace {
* - If shape contains zeroes, there are no enumerations.
*/
template <typename SizeT>
struct IndicesIter {
struct NDIter {
SizeT ndims;
SizeT* shape;
SizeT* strides;
SizeT size; // Product of shape
SizeT* indices; // The current indices
SizeT nth; // The nth (0-based) index of the current indices.
uint8_t* element; // The current element
/**
* @brief The current indices.
*
* Must be allocated by the caller.
*/
SizeT* indices;
// A convenient constructor for internal C++ IRRT.
IndicesIter(SizeT ndims, SizeT* shape, SizeT* strides, SizeT *indices, uint8_t* element) {
/**
* @brief The nth (0-based) index of the current indices.
*/
SizeT nth;
/**
* @brief Pointer to the current element.
*/
uint8_t* element;
/**
* @brief The product of shape.
*/
SizeT size;
// TODO:: There is something called backstrides to speedup iteration.
// See https://ajcr.net/stride-guide-part-1/, and https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
// Maybe LLVM is clever and knows how to optimize.
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element,
SizeT* indices) {
this->ndims = ndims;
this->shape = shape;
this->strides = strides;
this->indices = indices;
this->element = element;
this->initialize();
}
void initialize() {
reset();
// Compute size and backstrides
this->size = 1;
for (SizeT i = 0; i < ndims; i++) {
this->size *= shape[i];
}
}
void reset() {
for (SizeT axis = 0; axis < ndims; axis++) indices[axis] = 0;
nth = 0;
}
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides,
ndarray->data, indices);
}
bool has_next() { return nth < size; }
void next() {
@ -63,7 +84,11 @@ struct IndicesIter {
indices[axis]++;
if (indices[axis] >= shape[axis]) {
indices[axis] = 0;
// TODO: Can be optimized with backstrides.
element -= strides[axis] * shape[axis];
} else {
element += strides[axis];
break;
}
}
@ -73,32 +98,21 @@ struct IndicesIter {
} // namespace
extern "C" {
void __call_nac3_ndarray_indices_iter_initialize(IndicesIter<int32_t>* iter,
int32_t ndims, int32_t* shape,
int32_t* indices) {
iter->initialize(ndims, shape, indices);
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray,
int32_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
void __call_nac3_ndarray_indices_iter_initialize64(IndicesIter<int64_t>* iter,
int64_t ndims,
int64_t* shape,
int64_t* indices) {
iter->initialize(ndims, shape, indices);
void __nac3_nditer_initialize64(NDIter<int64_t>* iter,
NDArray<int64_t>* ndarray, int64_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
bool __call_nac3_ndarray_indices_iter_has_next(IndicesIter<int32_t>* iter) {
iter->has_next();
}
bool __nac3_nditer_has_next(NDIter<int32_t>* iter) { return iter->has_next(); }
bool __call_nac3_ndarray_indices_iter_has_next64(IndicesIter<int64_t>* iter) {
iter->has_next();
}
bool __nac3_nditer_has_next64(NDIter<int64_t>* iter) { return iter->has_next(); }
bool __call_nac3_ndarray_indices_iter_next(IndicesIter<int32_t>* iter) {
iter->next();
}
void __nac3_nditer_next(NDIter<int32_t>* iter) { iter->next(); }
bool __call_nac3_ndarray_indices_iter_next64(IndicesIter<int64_t>* iter) {
iter->next();
}
void __nac3_nditer_next64(NDIter<int64_t>* iter) { iter->next(); }
}

View File

@ -123,7 +123,9 @@ void matmul_at_least_2d(NDArray<SizeT>* a_ndarray, NDArray<SizeT>* b_ndarray,
SizeT* indices =
(SizeT*)__builtin_alloca(sizeof(SizeT) * dst_ndarray->ndims);
SizeT* mat_indices = indices + u;
IndicesIter<SizeT> iter(u, dst_ndarray->shape, indices);
NDIter<SizeT> iter;
iter.initialize(u, dst_ndarray->shape, dst_ndarray->strides,
dst_ndarray->data, indices);
for (; iter.has_next(); iter.next()) {
for (SizeT i = 0; i < dst_mat_shape[0]; i++) {

View File

@ -1565,7 +1565,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
ctx,
&[left, right],
out,
|generator, ctx, _i, scalars| {
|generator, ctx, scalars| {
let left = scalars[0];
let right = scalars[1];
gen_binop_expr_with_values(

View File

@ -6,7 +6,7 @@ mod test;
use super::model::*;
use super::object::ndarray::broadcast::ShapeEntry;
use super::object::ndarray::indexing::NDIndex;
use super::structure::{List, NDArray};
use super::structure::{List, NDArray, NDIter};
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
@ -1239,7 +1239,31 @@ pub fn call_nac3_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
}
pub fn call_nac3_ndarray_indices_iter_initialize<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Ptr<'ctx, StructModel<NDIter>>,
ndarray: Ptr<'ctx, StructModel<NDArray>>,
indices: Ptr<'ctx, IntModel<SizeT>>,
) {
}
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
CallFunction::begin(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void();
}
pub fn call_nac3_nditer_has_next<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Ptr<'ctx, StructModel<NDIter>>,
) -> Int<'ctx, Bool> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_next");
CallFunction::begin(generator, ctx, &name).arg(iter).returning_auto("has_next")
}
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Ptr<'ctx, StructModel<NDIter>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
CallFunction::begin(generator, ctx, &name).arg(iter).returning_void();
}

View File

@ -219,4 +219,4 @@ impl<'ctx, S: StructKind<'ctx>> Ptr<'ctx, StructModel<S>> {
}
}
// TODO: Add an opaque struct type?
// TODO: Add an opaque struct type?

View File

@ -127,35 +127,6 @@ pub fn gen_ndarray_ones<'ctx>(
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.full`.
pub fn gen_ndarray_full<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 shape
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Parse argument #2 fill_value
let fill_value_ty = fun.0.args[1].ty;
let fill_value = args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
// Implementation
let ndarray_ty = fun.0.ret;
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
ndarray.fill(generator, ctx, fill_value);
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.broadcast_to`.
pub fn gen_ndarray_broadcast_to<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
@ -345,89 +316,3 @@ pub fn gen_ndarray_strides<'ctx>(
Ok(strides.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.transpose`.
pub fn gen_ndarray_transpose<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
// TODO: The implementation will be changed once default values start working again.
// Read the comment on this function in BuiltinBuilder.
// TODO: Change axes values to `SizeT`
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse argument #1 ndarray
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Implementation
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let has_axes = args.len() >= 2;
let transposed_ndarray = if has_axes {
// Parse argument #2 axes
let in_axes_ty = fun.0.args[1].ty;
let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?;
let in_axes = AnyObject { ty: in_axes_ty, value: in_axes };
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes);
ndarray.transpose(generator, ctx, Some(axes))
} else {
// axes is not given
ndarray.transpose(generator, ctx, None)
};
Ok(transposed_ndarray.instance.value.as_basic_value_enum())
}
pub fn gen_ndarray_array<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
let object_ty = fun.0.args[0].ty;
let object = args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
let object = AnyObject { ty: object_ty, value: object };
let copy_arg = if let Some(arg) =
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
{
let copy_ty = fun.0.args[1].ty;
arg.1.clone().to_basic_value_enum(ctx, generator, copy_ty)?
} else {
ctx.gen_symbol_val(
generator,
fun.0.args[1].default_value.as_ref().unwrap(),
fun.0.args[1].ty,
)
};
// The argument `ndmin` is completely ignored. We don't need to know its LLVM value.
// We simply make the output ndarray's ndims correct with `atleast_nd`.
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let output_ndims = extract_ndims(&ctx.unifier, ndims);
let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
let ndarray = NDArrayObject::from_np_array(generator, ctx, object, copy);
debug_assert!(ndarray.ndims <= output_ndims); // Sanity check on `ndims`
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
debug_assert!(ctx.unifier.unioned(ndarray.dtype, dtype)); // Sanity check on `dtype`
Ok(ndarray.instance.value.as_basic_value_enum())
}

View File

@ -1,7 +1,10 @@
use itertools::Itertools;
use crate::codegen::{
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
irrt::{
call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to,
call_nac3_ndarray_util_assert_shape_no_negative,
},
model::*,
CodeGenContext, CodeGenerator,
};
@ -29,6 +32,8 @@ impl<'ctx> StructKind<'ctx> for ShapeEntry {
impl<'ctx> NDArrayObject<'ctx> {
/// Create a broadcast view on this ndarray with a target shape.
///
/// The input shape will be checked to make sure that it contains no negative values.
///
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
/// The caller has to figure this out for this function.
/// * `target_shape` - An array pointer pointing to the target shape.
@ -40,6 +45,14 @@ impl<'ctx> NDArrayObject<'ctx> {
target_ndims: u64,
target_shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let target_ndims_llvm = IntModel(SizeT).constant(generator, ctx.ctx, target_ndims);
call_nac3_ndarray_util_assert_shape_no_negative(
generator,
ctx,
target_ndims_llvm,
target_shape,
);
let broadcast_ndarray = NDArrayObject::alloca(
generator,
ctx,

View File

@ -1,4 +1,4 @@
use inkwell::values::BasicValueEnum;
use inkwell::{values::BasicValueEnum, IntPredicate};
use super::{scalar::ScalarObject, NDArrayObject};
use crate::{
@ -154,10 +154,15 @@ impl<'ctx> NDArrayObject<'ctx> {
// Create data and set elements
ndarray.create_data(generator, ctx);
ndarray
.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
// Get the index of the current element, convert that index to float, and write it.
// This is how we get [0.0, 1.0, 2.0, ...].
let index = nditer.get_index(generator, ctx);
let pelement = nditer.get_pointer(generator, ctx);
let val = ctx
.builder
.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val")
.build_unsigned_int_to_float(index.value, ctx.ctx.f64_type(), "val")
.unwrap();
ctx.builder.build_store(pelement, val).unwrap();
Ok(())
@ -172,23 +177,45 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
rows: Int<'ctx, SizeT>,
cols: Int<'ctx, SizeT>,
num_rows: Int<'ctx, SizeT>,
num_cols: Int<'ctx, SizeT>,
diagonal: Int<'ctx, SizeT>,
) -> Self {
let ndzero = ndarray_zero_value(generator, ctx, dtype);
let ndone = ndarray_one_value(generator, ctx, dtype);
let ndarray = NDArrayObject::alloca_dynamic_shape(
generator,
ctx,
dtype,
&[rows, cols],
&[num_rows, num_cols],
"eye_ndarray",
);
ndarray
.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
// and this loop would not execute.
todo!()
// Load up `row_i` and `col_i` from indices.
let row_i = nditer
.get_indices()
.offset_const(generator, ctx, 0, "")
.load(generator, ctx, "row_i");
let col_i = nditer
.get_indices()
.offset_const(generator, ctx, 1, "")
.load(generator, ctx, "col_i");
// Write to element
let be_one =
row_i.add(ctx, diagonal, "").compare(ctx, IntPredicate::EQ, col_i, "write_one");
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap();
Ok(())
})
.unwrap();
todo!()

View File

@ -480,7 +480,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let zero = sizet_model.const_0(generator, ctx.ctx);
pextremum_index.store(ctx, zero);
let first_scalar = self.get_nth(generator, ctx, zero);
let first_scalar = self.get_nth_scalar(generator, ctx, zero);
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
// Find extremum
@ -491,7 +491,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// Worth reading on "Notes" in <https://numpy.org/doc/stable/reference/generated/numpy.min.html#numpy.min>
// on how `NaN` values have to be handled.
let scalar = self.get_nth(generator, ctx, i);
let scalar = self.get_nth_scalar(generator, ctx, i);
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };

View File

@ -1,17 +1,17 @@
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use util::gen_for_model_auto;
use crate::{
codegen::{
model::*,
object::ndarray::{NDArrayObject, ScalarObject},
stmt::gen_for_callback,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
use super::{scalar::ScalarOrNDArray, NDArrayOut};
use super::{nditer::NDIterHandle, scalar::ScalarOrNDArray, NDArrayOut};
impl<'ctx> NDArrayObject<'ctx> {
/// TODO: Document me. Has complex behavior.
@ -28,12 +28,9 @@ impl<'ctx> NDArrayObject<'ctx> {
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
&[ScalarObject<'ctx>],
) -> Result<BasicValueEnum<'ctx>, String>,
{
let sizet_model = IntModel(SizeT);
// Broadcast inputs
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
@ -66,19 +63,60 @@ impl<'ctx> NDArrayObject<'ctx> {
};
// Map element-wise and store results into `mapped_ndarray`.
let start = sizet_model.const_0(generator, ctx.ctx);
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
let step = sizet_model.const_1(generator, ctx.ctx);
gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
let elements =
ndarrays.iter().map(|ndarray| ndarray.get_nth(generator, ctx, i)).collect_vec();
let nditer = NDIterHandle::new(generator, ctx, out_ndarray);
gen_for_callback(
generator,
ctx,
Some("broadcast_starmap"),
|generator, ctx| {
// Create NDIters for all broadcasted input ndarrays.
let other_nditers = broadcast_result
.ndarrays
.iter()
.map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray))
.collect_vec();
Ok((nditer, other_nditers))
},
|generator, ctx, (out_nditer, _in_nditers)| {
// We can simply use `out_nditer`'s `has_next()`.
// `in_nditers`' `has_next()`s should return the same value.
Ok(out_nditer.has_next(generator, ctx).value)
},
|generator, ctx, _hooks, (out_nditer, in_nditers)| {
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
// and write to `out_ndarray`.
let ret = mapping(generator, ctx, i, &elements)?;
let in_scalars =
in_nditers.iter().map(|nditer| nditer.get_scalar(generator, ctx)).collect_vec();
let pret = out_ndarray.get_nth_pointer(generator, ctx, i, "pret");
ctx.builder.build_store(pret, ret).unwrap();
Ok(())
})?;
let result = mapping(generator, ctx, &in_scalars)?;
let p = out_nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, result).unwrap();
Ok(())
},
|generator, ctx, (out_nditer, in_nditers)| {
// Advance all iterators
out_nditer.next(generator, ctx);
in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx));
Ok(())
},
)?;
// let start = sizet_model.const_0(generator, ctx.ctx);
// let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
// let step = sizet_model.const_1(generator, ctx.ctx);
// gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
// let elements =
// ndarrays.iter().map(|ndarray| ndarray.get_nth_scalar(generator, ctx, i)).collect_vec();
// let ret = mapping(generator, ctx, i, &elements)?;
// let pret = out_ndarray.get_nth_pointer(generator, ctx, i, "pret");
// ctx.builder.build_store(pret, ret).unwrap();
// Ok(())
// })?;
Ok(out_ndarray)
}
@ -95,7 +133,6 @@ impl<'ctx> NDArrayObject<'ctx> {
Mapping: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
@ -104,7 +141,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx,
&[*self],
out,
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
)
}
}
@ -124,20 +161,16 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
&[ScalarObject<'ctx>],
) -> Result<BasicValueEnum<'ctx>, String>,
{
let sizet_model = IntModel(SizeT);
// Check if all inputs are ScalarObjects
let all_scalars: Option<Vec<_>> =
inputs.iter().map(ScalarObject::try_from).try_collect().ok();
if let Some(scalars) = all_scalars {
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
let scalar =
ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
ScalarObject { value: mapping(generator, ctx, &scalars)?, dtype: ret_dtype };
Ok(ScalarOrNDArray::Scalar(scalar))
} else {
// Promote all input to ndarrays and map through them.
@ -165,7 +198,6 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
Mapping: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
@ -174,7 +206,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
ctx,
&[*self],
ret_dtype,
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
)
}
}

View File

@ -5,6 +5,7 @@ pub mod functions;
pub mod indexing;
pub mod mapping;
pub mod nalgebra;
pub mod nditer;
pub mod product;
pub mod scalar;
pub mod shape_util;
@ -16,10 +17,11 @@ use crate::{
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
call_nac3_ndarray_resolve_and_check_new_shape, call_nac3_ndarray_set_strides_by_shape,
call_nac3_ndarray_size, call_nac3_ndarray_transpose,
call_nac3_ndarray_util_assert_output_shape_same,
call_nac3_ndarray_util_assert_output_shape_same, call_nac3_nditer_has_next,
call_nac3_nditer_initialize, call_nac3_nditer_next,
},
model::*,
stmt::BreakContinueHooks,
stmt::{gen_for_callback, BreakContinueHooks},
structure::{NDArray, SimpleNDArray},
CodeGenContext, CodeGenerator,
},
@ -36,6 +38,7 @@ use inkwell::{
values::{BasicValue, BasicValueEnum, PointerValue},
AddressSpace, IntPredicate,
};
use nditer::NDIterHandle;
use scalar::{ScalarObject, ScalarOrNDArray};
use util::{call_memcpy_model, gen_for_model_auto};
@ -250,7 +253,7 @@ impl<'ctx> NDArrayObject<'ctx> {
/// Get the n-th (0-based) scalar.
///
/// There is no out-of-bounds check.
pub fn get_nth<G: CodeGenerator + ?Sized>(
pub fn get_nth_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -264,7 +267,7 @@ impl<'ctx> NDArrayObject<'ctx> {
/// Set the n-th (0-based) scalar.
///
/// There is no out-of-bounds check.
pub fn set_nth<G: CodeGenerator + ?Sized>(
pub fn set_nth_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -466,7 +469,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// NOTE: `np.size(self) == 0` here is never possible.
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
ScalarOrNDArray::Scalar(self.get_nth(generator, ctx, zero))
ScalarOrNDArray::Scalar(self.get_nth_scalar(generator, ctx, zero))
} else {
ScalarOrNDArray::NDArray(*self)
}
@ -543,40 +546,9 @@ impl<'ctx> NDArrayObject<'ctx> {
self.copy_strides_from_array(generator, ctx, src_strides);
}
/// Iterate through every element pointer in the ndarray in its flatten view.
/// Iterate through every element in the ndarray.
///
/// `body` also access to [`BreakContinueHooks`] to short-circuit and an element's
/// index. The given element pointer also has been casted to the LLVM type of this ndarray's `dtype`.
pub fn foreach_pointer<'a, G, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Int<'ctx, SizeT>,
Ptr<'ctx, IntModel<SizeT>>,
PointerValue<'ctx>,
) -> Result<(), String>,
{
let sizet_model = IntModel(SizeT);
let start = sizet_model.const_0(generator, ctx.ctx);
let stop = self.size(generator, ctx);
let step = sizet_model.const_1(generator, ctx.ctx);
gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, hooks, i| {
let pelement = self.get_nth_pointer(generator, ctx, i, "element");
body(generator, ctx, hooks, i, pelement)
})
}
/// Iterate through every scalar in this ndarray.
/// `body` also access to [`BreakContinueHooks`] to short-circuit.
pub fn foreach<'a, G, F>(
&self,
generator: &mut G,
@ -589,15 +561,18 @@ impl<'ctx> NDArrayObject<'ctx> {
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
NDIterHandle<'ctx>,
) -> Result<(), String>,
{
self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
let value = ctx.builder.build_load(p, "value").unwrap();
let scalar = ScalarObject { dtype: self.dtype, value };
body(generator, ctx, hooks, i, scalar)
})
gen_for_callback(
generator,
ctx,
Some("ndarray_foreach"),
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|generator, ctx, nditer| Ok(nditer.next(generator, ctx)),
)
}
/// Make sure the ndarray is at least `ndmin`-dimensional.
@ -632,8 +607,9 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
fill_value: BasicValueEnum<'ctx>,
) {
self.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, _i, pelement| {
ctx.builder.build_store(pelement, fill_value).unwrap();
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, fill_value).unwrap();
Ok(())
})
.unwrap();

View File

@ -0,0 +1,83 @@
use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
use crate::codegen::{
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
model::*,
object::ndarray::scalar::ScalarObject,
structure::NDIter,
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
#[derive(Debug, Clone)]
pub struct NDIterHandle<'ctx> {
ndarray: NDArrayObject<'ctx>,
instance: Ptr<'ctx, StructModel<NDIter>>,
indices: Ptr<'ctx, IntModel<SizeT>>,
}
impl<'ctx> NDIterHandle<'ctx> {
pub fn new<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
) -> Self {
let nditer = StructModel(NDIter).alloca(generator, ctx, "nditer");
let ndims = ndarray.get_ndims(generator, ctx.ctx);
let indices = IntModel(SizeT).array_alloca(generator, ctx, ndims.value, "indices");
call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices);
NDIterHandle { ndarray, instance: nditer, indices }
}
pub fn has_next<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
call_nac3_nditer_has_next(generator, ctx, self.instance)
}
pub fn next<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_nditer_next(generator, ctx, self.instance)
}
pub fn get_pointer<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype);
let p = self.instance.get(generator, ctx, |f| f.element, "element");
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element")
.unwrap()
}
pub fn get_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ScalarObject<'ctx> {
let p = self.get_pointer(generator, ctx);
let value = ctx.builder.build_load(p, "value").unwrap();
ScalarObject { dtype: self.ndarray.dtype, value }
}
pub fn get_index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
self.instance.get(generator, ctx, |f| f.nth, "index")
}
pub fn get_indices(&self) -> Ptr<'ctx, IntModel<SizeT>> {
self.indices
}
}

View File

@ -187,28 +187,36 @@ impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for SimpleNDArray<Item> {
}
/// An IRRT helper structure used when iterating through an ndarray.
/// Fields of [`IndicesIter`]
pub struct IndicesIterFields<'ctx, F: FieldTraversal<'ctx>> {
/// Fields of [`NDIter`]
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Out<IntModel<SizeT>>,
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
pub strides: F::Out<PtrModel<IntModel<SizeT>>>,
pub indices: F::Out<PtrModel<IntModel<SizeT>>>,
pub size: F::Out<IntModel<SizeT>>,
pub nth: F::Out<IntModel<SizeT>>,
pub element: F::Out<PtrModel<IntModel<Byte>>>,
pub size: F::Out<IntModel<SizeT>>,
}
#[derive(Debug, Clone, Copy, Default)]
pub struct IndicesIter;
pub struct NDIter;
impl<'ctx> StructKind<'ctx> for IndicesIter {
type Fields<F: FieldTraversal<'ctx>> = IndicesIterFields<'ctx, F>;
impl<'ctx> StructKind<'ctx> for NDIter {
type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
ndims: traversal.add_auto("ndims"),
shape: traversal.add_auto("shape"),
strides: traversal.add_auto("strides"),
indices: traversal.add_auto("indices"),
size: traversal.add_auto("size"),
nth: traversal.add_auto("nth"),
element: traversal.add_auto("element"),
size: traversal.add_auto("size"),
}
}
}

View File

@ -1,6 +1,9 @@
use std::iter::once;
use helper::{create_ndims, debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
use helper::{
create_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields,
PrimDefDetails,
};
use indexmap::IndexMap;
use inkwell::{
attributes::{Attribute, AttributeLoc},
@ -9,6 +12,7 @@ use inkwell::{
IntPredicate,
};
use itertools::Either;
use numpy::unpack_ndarray_var_tys;
use strum::IntoEnumIterator;
use crate::{
@ -16,15 +20,17 @@ use crate::{
builtin_fns,
classes::{ProxyValue, RangeValue},
extern_fns::{self, call_np_linalg_det, call_np_linalg_matrix_power},
irrt, llvm_intrinsics,
model::{IntModel, SizeT},
irrt::{self, call_nac3_ndarray_util_assert_shape_no_negative},
llvm_intrinsics,
model::*,
numpy::*,
numpy_new::{self, gen_ndarray_transpose},
numpy_new::{self},
object::{
ndarray::{
functions::{FloorOrCeil, MinOrMax},
nalgebra::perform_nalgebra_call,
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
shape_util::parse_numpy_int_sequence,
NDArrayObject,
},
tuple::TupleObject,
@ -1104,7 +1110,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
ret_dtype,
|generator, ctx, _i, scalar| {
|generator, ctx, scalar| {
let result = match prim {
PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx),
PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx),
@ -1175,9 +1181,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
ret_int_dtype,
|generator, ctx, _i, scalar| {
Ok(scalar.round(generator, ctx, ret_int_dtype).value)
},
|generator, ctx, scalar| Ok(scalar.round(generator, ctx, ret_int_dtype).value),
)?;
Ok(Some(result.to_basic_value_enum()))
}),
@ -1241,7 +1245,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
int_sized,
|generator, ctx, _i, scalar| {
|generator, ctx, scalar| {
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
},
)?;
@ -1291,13 +1295,24 @@ impl<'a> BuiltinBuilder<'a> {
self.ndarray_float,
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| {
// Parse argument `shape`.
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape_arg };
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndims = extract_ndims(&ctx.unifier, ndims);
let func = match prim {
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => numpy_new::gen_ndarray_empty,
PrimDef::FunNpZeros => numpy_new::gen_ndarray_zeros,
PrimDef::FunNpOnes => numpy_new::gen_ndarray_ones,
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => NDArrayObject::from_np_empty,
PrimDef::FunNpZeros => NDArrayObject::from_np_zero,
PrimDef::FunNpOnes => NDArrayObject::from_np_ones,
_ => unreachable!(),
};
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
let ndarray = func(generator, ctx, dtype, ndims, shape);
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
}),
)
}
@ -1356,8 +1371,47 @@ impl<'a> BuiltinBuilder<'a> {
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| {
numpy_new::gen_ndarray_array(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
let object_ty = fun.0.args[0].ty;
let object =
args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
let object = AnyObject { ty: object_ty, value: object };
let copy_arg = if let Some(arg) = args
.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
{
let copy_ty = fun.0.args[1].ty;
arg.1.clone().to_basic_value_enum(ctx, generator, copy_ty)?
} else {
ctx.gen_symbol_val(
generator,
fun.0.args[1].default_value.as_ref().unwrap(),
fun.0.args[1].ty,
)
};
// The argument `ndmin` is completely ignored. We don't need to know its LLVM value.
// We simply make the output ndarray's ndims correct with `atleast_nd`.
let (dtype, ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let output_ndims = extract_ndims(&ctx.unifier, ndims);
let copy =
IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
let ndarray =
NDArrayObject::from_np_array(generator, ctx, object, copy);
debug_assert!(ndarray.ndims <= output_ndims); // Sanity check on `ndims`
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
debug_assert!(ctx.unifier.unioned(ndarray.dtype, dtype)); // Sanity check on `dtype`
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
},
)))),
loc: None,
@ -1375,8 +1429,31 @@ impl<'a> BuiltinBuilder<'a> {
// type variable
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
Box::new(move |ctx, obj, fun, args, generator| {
numpy_new::gen_ndarray_full(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 shape
let shape_ty = fun.0.args[0].ty;
let shape =
args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
// Parse argument #2 fill_value
let fill_value_ty = fun.0.args[1].ty;
let fill_value =
args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
let fill_value = ScalarObject { dtype: fill_value_ty, value: fill_value };
// Implementation
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndims = extract_ndims(&ctx.unifier, ndims);
let ndarray = NDArrayObject::from_np_full(
generator, ctx, dtype, ndims, shape, fill_value,
);
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
}),
)
}
@ -1483,12 +1560,38 @@ impl<'a> BuiltinBuilder<'a> {
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"),
],
Box::new(move |ctx, obj, fun, args, generator| {
let f = match prim {
PrimDef::FunNpBroadcastTo => numpy_new::gen_ndarray_broadcast_to,
PrimDef::FunNpReshape => numpy_new::gen_ndarray_reshape,
// Parse argument #1 ndarray
let input_ty = fun.0.args[0].ty;
let input =
args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
let input = AnyObject { ty: input_ty, value: input };
// Parse argument #2 shape
let shape_ty = fun.0.args[1].ty;
let shape =
args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
let (_, return_shape) = parse_numpy_int_sequence(generator, ctx, shape);
// Implementation
// Turn any input to an ndarray for a convenient implementation.
let ndarray = split_scalar_or_ndarray(generator, ctx, input)
.as_ndarray(generator, ctx);
let (_, return_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let return_ndims = extract_ndims(&ctx.unifier, return_ndims);
let result = match prim {
PrimDef::FunNpBroadcastTo => {
ndarray.broadcast_to(generator, ctx, return_ndims, return_shape)
}
PrimDef::FunNpReshape => {
ndarray.reshape_or_copy(generator, ctx, return_ndims, return_shape)
}
_ => unreachable!(),
};
f(ctx, &obj, fun, &args, generator).map(Some)
Ok(Some(result.instance.value.as_basic_value_enum()))
}),
)
}
@ -1529,7 +1632,39 @@ impl<'a> BuiltinBuilder<'a> {
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| {
gen_ndarray_transpose(ctx, &obj, fun, &args, generator).map(Some)
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse argument #1 ndarray
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0]
.1
.clone()
.to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Implementation
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let has_axes = args.len() >= 2;
let transposed_ndarray = if has_axes {
// Parse argument #2 axes
let in_axes_ty = fun.0.args[1].ty;
let in_axes = args[1]
.1
.clone()
.to_basic_value_enum(ctx, generator, in_axes_ty)?;
let in_axes = AnyObject { ty: in_axes_ty, value: in_axes };
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes);
ndarray.transpose(generator, ctx, Some(axes))
} else {
// axes is not given
ndarray.transpose(generator, ctx, None)
};
Ok(Some(transposed_ndarray.instance.value.as_basic_value_enum()))
},
)))),
loc: None,
@ -1641,7 +1776,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
ctx.primitives.float,
move |_generator, ctx, _i, scalar| {
move |_generator, ctx, scalar| {
let result = scalar.np_floor_or_ceil(ctx, kind);
Ok(result.value)
},
@ -1670,7 +1805,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
ctx.primitives.float,
|_generator, ctx, _i, scalar| {
|_generator, ctx, scalar| {
let result = scalar.np_round(ctx);
Ok(result.value)
},
@ -1883,7 +2018,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx,
&[x1, x2],
common_ty,
|_generator, ctx, _i, scalars| {
|_generator, ctx, scalars| {
let x1 = scalars[0];
let x2 = scalars[1];
@ -1930,7 +2065,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
num_ty.ty,
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value),
|_generator, ctx, scalar| Ok(scalar.abs(ctx).value),
)?;
Ok(Some(result.to_basic_value_enum()))
},
@ -1966,7 +2101,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
ctx.primitives.bool,
|generator, ctx, _i, scalar| {
|generator, ctx, scalar| {
let n = scalar.into_float64(ctx);
let n = function(generator, ctx, n);
Ok(n.as_basic_value_enum())
@ -2035,7 +2170,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
ctx.primitives.float,
|_generator, ctx, _i, scalar| {
|_generator, ctx, scalar| {
let n = scalar.into_float64(ctx);
let n = match prim {
PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None),
@ -2160,7 +2295,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx,
&[x1, x2],
ret_dtype,
|_generator, ctx, _i, scalars| {
|_generator, ctx, scalars| {
let x1 = scalars[0];
let x2 = scalars[1];
@ -2406,7 +2541,7 @@ impl<'a> BuiltinBuilder<'a> {
);
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
x2_ndarray.set_nth(
x2_ndarray.set_nth_scalar(
generator,
ctx,
zero,
@ -2451,7 +2586,7 @@ impl<'a> BuiltinBuilder<'a> {
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
let determinant = out.get_nth(generator, ctx, zero);
let determinant = out.get_nth_scalar(generator, ctx, zero);
Ok(Some(determinant.value))
}),