forked from M-Labs/nac3
WIP: core/ndstrides: more iter and less builtin
This commit is contained in:
parent
15dfb2eaa0
commit
5dce27e87d
|
@ -21,40 +21,61 @@ namespace {
|
||||||
* - If shape contains zeroes, there are no enumerations.
|
* - If shape contains zeroes, there are no enumerations.
|
||||||
*/
|
*/
|
||||||
template <typename SizeT>
|
template <typename SizeT>
|
||||||
struct IndicesIter {
|
struct NDIter {
|
||||||
SizeT ndims;
|
SizeT ndims;
|
||||||
SizeT* shape;
|
SizeT* shape;
|
||||||
SizeT* strides;
|
SizeT* strides;
|
||||||
SizeT size; // Product of shape
|
|
||||||
|
|
||||||
SizeT* indices; // The current indices
|
/**
|
||||||
SizeT nth; // The nth (0-based) index of the current indices.
|
* @brief The current indices.
|
||||||
uint8_t* element; // The current element
|
*
|
||||||
|
* Must be allocated by the caller.
|
||||||
|
*/
|
||||||
|
SizeT* indices;
|
||||||
|
|
||||||
// A convenient constructor for internal C++ IRRT.
|
/**
|
||||||
IndicesIter(SizeT ndims, SizeT* shape, SizeT* strides, SizeT *indices, uint8_t* element) {
|
* @brief The nth (0-based) index of the current indices.
|
||||||
|
*/
|
||||||
|
SizeT nth;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Pointer to the current element.
|
||||||
|
*/
|
||||||
|
uint8_t* element;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The product of shape.
|
||||||
|
*/
|
||||||
|
SizeT size;
|
||||||
|
|
||||||
|
// TODO:: There is something called backstrides to speedup iteration.
|
||||||
|
// See https://ajcr.net/stride-guide-part-1/, and https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
|
||||||
|
// Maybe LLVM is clever and knows how to optimize.
|
||||||
|
|
||||||
|
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element,
|
||||||
|
SizeT* indices) {
|
||||||
this->ndims = ndims;
|
this->ndims = ndims;
|
||||||
this->shape = shape;
|
this->shape = shape;
|
||||||
this->strides = strides;
|
this->strides = strides;
|
||||||
|
|
||||||
this->indices = indices;
|
this->indices = indices;
|
||||||
this->element = element;
|
this->element = element;
|
||||||
this->initialize();
|
|
||||||
}
|
|
||||||
|
|
||||||
void initialize() {
|
|
||||||
reset();
|
|
||||||
|
|
||||||
|
// Compute size and backstrides
|
||||||
this->size = 1;
|
this->size = 1;
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
for (SizeT i = 0; i < ndims; i++) {
|
||||||
this->size *= shape[i];
|
this->size *= shape[i];
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
void reset() {
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) indices[axis] = 0;
|
for (SizeT axis = 0; axis < ndims; axis++) indices[axis] = 0;
|
||||||
nth = 0;
|
nth = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
|
||||||
|
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides,
|
||||||
|
ndarray->data, indices);
|
||||||
|
}
|
||||||
|
|
||||||
bool has_next() { return nth < size; }
|
bool has_next() { return nth < size; }
|
||||||
|
|
||||||
void next() {
|
void next() {
|
||||||
|
@ -63,7 +84,11 @@ struct IndicesIter {
|
||||||
indices[axis]++;
|
indices[axis]++;
|
||||||
if (indices[axis] >= shape[axis]) {
|
if (indices[axis] >= shape[axis]) {
|
||||||
indices[axis] = 0;
|
indices[axis] = 0;
|
||||||
|
|
||||||
|
// TODO: Can be optimized with backstrides.
|
||||||
|
element -= strides[axis] * shape[axis];
|
||||||
} else {
|
} else {
|
||||||
|
element += strides[axis];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -73,32 +98,21 @@ struct IndicesIter {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
void __call_nac3_ndarray_indices_iter_initialize(IndicesIter<int32_t>* iter,
|
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray,
|
||||||
int32_t ndims, int32_t* shape,
|
|
||||||
int32_t* indices) {
|
int32_t* indices) {
|
||||||
iter->initialize(ndims, shape, indices);
|
iter->initialize_by_ndarray(ndarray, indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
void __call_nac3_ndarray_indices_iter_initialize64(IndicesIter<int64_t>* iter,
|
void __nac3_nditer_initialize64(NDIter<int64_t>* iter,
|
||||||
int64_t ndims,
|
NDArray<int64_t>* ndarray, int64_t* indices) {
|
||||||
int64_t* shape,
|
iter->initialize_by_ndarray(ndarray, indices);
|
||||||
int64_t* indices) {
|
|
||||||
iter->initialize(ndims, shape, indices);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool __call_nac3_ndarray_indices_iter_has_next(IndicesIter<int32_t>* iter) {
|
bool __nac3_nditer_has_next(NDIter<int32_t>* iter) { return iter->has_next(); }
|
||||||
iter->has_next();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __call_nac3_ndarray_indices_iter_has_next64(IndicesIter<int64_t>* iter) {
|
bool __nac3_nditer_has_next64(NDIter<int64_t>* iter) { return iter->has_next(); }
|
||||||
iter->has_next();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __call_nac3_ndarray_indices_iter_next(IndicesIter<int32_t>* iter) {
|
void __nac3_nditer_next(NDIter<int32_t>* iter) { iter->next(); }
|
||||||
iter->next();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __call_nac3_ndarray_indices_iter_next64(IndicesIter<int64_t>* iter) {
|
void __nac3_nditer_next64(NDIter<int64_t>* iter) { iter->next(); }
|
||||||
iter->next();
|
|
||||||
}
|
|
||||||
}
|
}
|
|
@ -123,7 +123,9 @@ void matmul_at_least_2d(NDArray<SizeT>* a_ndarray, NDArray<SizeT>* b_ndarray,
|
||||||
SizeT* indices =
|
SizeT* indices =
|
||||||
(SizeT*)__builtin_alloca(sizeof(SizeT) * dst_ndarray->ndims);
|
(SizeT*)__builtin_alloca(sizeof(SizeT) * dst_ndarray->ndims);
|
||||||
SizeT* mat_indices = indices + u;
|
SizeT* mat_indices = indices + u;
|
||||||
IndicesIter<SizeT> iter(u, dst_ndarray->shape, indices);
|
NDIter<SizeT> iter;
|
||||||
|
iter.initialize(u, dst_ndarray->shape, dst_ndarray->strides,
|
||||||
|
dst_ndarray->data, indices);
|
||||||
|
|
||||||
for (; iter.has_next(); iter.next()) {
|
for (; iter.has_next(); iter.next()) {
|
||||||
for (SizeT i = 0; i < dst_mat_shape[0]; i++) {
|
for (SizeT i = 0; i < dst_mat_shape[0]; i++) {
|
||||||
|
|
|
@ -1565,7 +1565,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
ctx,
|
ctx,
|
||||||
&[left, right],
|
&[left, right],
|
||||||
out,
|
out,
|
||||||
|generator, ctx, _i, scalars| {
|
|generator, ctx, scalars| {
|
||||||
let left = scalars[0];
|
let left = scalars[0];
|
||||||
let right = scalars[1];
|
let right = scalars[1];
|
||||||
gen_binop_expr_with_values(
|
gen_binop_expr_with_values(
|
||||||
|
|
|
@ -6,7 +6,7 @@ mod test;
|
||||||
use super::model::*;
|
use super::model::*;
|
||||||
use super::object::ndarray::broadcast::ShapeEntry;
|
use super::object::ndarray::broadcast::ShapeEntry;
|
||||||
use super::object::ndarray::indexing::NDIndex;
|
use super::object::ndarray::indexing::NDIndex;
|
||||||
use super::structure::{List, NDArray};
|
use super::structure::{List, NDArray, NDIter};
|
||||||
use super::{
|
use super::{
|
||||||
classes::{
|
classes::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
|
||||||
|
@ -1239,7 +1239,31 @@ pub fn call_nac3_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
|
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_indices_iter_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
iter: Ptr<'ctx, StructModel<NDIter>>,
|
||||||
|
ndarray: Ptr<'ctx, StructModel<NDArray>>,
|
||||||
|
indices: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
) {
|
) {
|
||||||
|
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
|
||||||
|
CallFunction::begin(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_nditer_has_next<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
iter: Ptr<'ctx, StructModel<NDIter>>,
|
||||||
|
) -> Int<'ctx, Bool> {
|
||||||
|
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_next");
|
||||||
|
CallFunction::begin(generator, ctx, &name).arg(iter).returning_auto("has_next")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
iter: Ptr<'ctx, StructModel<NDIter>>,
|
||||||
|
) {
|
||||||
|
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
|
||||||
|
CallFunction::begin(generator, ctx, &name).arg(iter).returning_void();
|
||||||
}
|
}
|
|
@ -127,35 +127,6 @@ pub fn gen_ndarray_ones<'ctx>(
|
||||||
Ok(ndarray.instance.value.as_basic_value_enum())
|
Ok(ndarray.instance.value.as_basic_value_enum())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `np.full`.
|
|
||||||
pub fn gen_ndarray_full<'ctx>(
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
||||||
fun: (&FunSignature, DefinitionId),
|
|
||||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
||||||
generator: &mut dyn CodeGenerator,
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
assert!(obj.is_none());
|
|
||||||
assert_eq!(args.len(), 2);
|
|
||||||
|
|
||||||
// Parse argument #1 shape
|
|
||||||
let shape_ty = fun.0.args[0].ty;
|
|
||||||
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
|
||||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
|
||||||
|
|
||||||
// Parse argument #2 fill_value
|
|
||||||
let fill_value_ty = fun.0.args[1].ty;
|
|
||||||
let fill_value = args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
|
|
||||||
|
|
||||||
// Implementation
|
|
||||||
let ndarray_ty = fun.0.ret;
|
|
||||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
|
|
||||||
|
|
||||||
ndarray.fill(generator, ctx, fill_value);
|
|
||||||
|
|
||||||
Ok(ndarray.instance.value.as_basic_value_enum())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates LLVM IR for `np.broadcast_to`.
|
/// Generates LLVM IR for `np.broadcast_to`.
|
||||||
pub fn gen_ndarray_broadcast_to<'ctx>(
|
pub fn gen_ndarray_broadcast_to<'ctx>(
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -345,89 +316,3 @@ pub fn gen_ndarray_strides<'ctx>(
|
||||||
Ok(strides.value.as_basic_value_enum())
|
Ok(strides.value.as_basic_value_enum())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `np.transpose`.
|
|
||||||
pub fn gen_ndarray_transpose<'ctx>(
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
||||||
fun: (&FunSignature, DefinitionId),
|
|
||||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
||||||
generator: &mut dyn CodeGenerator,
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
// TODO: The implementation will be changed once default values start working again.
|
|
||||||
// Read the comment on this function in BuiltinBuilder.
|
|
||||||
|
|
||||||
// TODO: Change axes values to `SizeT`
|
|
||||||
|
|
||||||
assert!(obj.is_none());
|
|
||||||
assert_eq!(args.len(), 1);
|
|
||||||
|
|
||||||
// Parse argument #1 ndarray
|
|
||||||
let ndarray_ty = fun.0.args[0].ty;
|
|
||||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
|
||||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
|
||||||
|
|
||||||
// Implementation
|
|
||||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
|
||||||
|
|
||||||
let has_axes = args.len() >= 2;
|
|
||||||
let transposed_ndarray = if has_axes {
|
|
||||||
// Parse argument #2 axes
|
|
||||||
let in_axes_ty = fun.0.args[1].ty;
|
|
||||||
let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?;
|
|
||||||
let in_axes = AnyObject { ty: in_axes_ty, value: in_axes };
|
|
||||||
|
|
||||||
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes);
|
|
||||||
|
|
||||||
ndarray.transpose(generator, ctx, Some(axes))
|
|
||||||
} else {
|
|
||||||
// axes is not given
|
|
||||||
ndarray.transpose(generator, ctx, None)
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(transposed_ndarray.instance.value.as_basic_value_enum())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn gen_ndarray_array<'ctx>(
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
|
||||||
fun: (&FunSignature, DefinitionId),
|
|
||||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
|
||||||
generator: &mut dyn CodeGenerator,
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
assert!(obj.is_none());
|
|
||||||
assert!(matches!(args.len(), 1..=3));
|
|
||||||
|
|
||||||
let object_ty = fun.0.args[0].ty;
|
|
||||||
let object = args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
|
|
||||||
let object = AnyObject { ty: object_ty, value: object };
|
|
||||||
|
|
||||||
let copy_arg = if let Some(arg) =
|
|
||||||
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
|
||||||
{
|
|
||||||
let copy_ty = fun.0.args[1].ty;
|
|
||||||
arg.1.clone().to_basic_value_enum(ctx, generator, copy_ty)?
|
|
||||||
} else {
|
|
||||||
ctx.gen_symbol_val(
|
|
||||||
generator,
|
|
||||||
fun.0.args[1].default_value.as_ref().unwrap(),
|
|
||||||
fun.0.args[1].ty,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
// The argument `ndmin` is completely ignored. We don't need to know its LLVM value.
|
|
||||||
// We simply make the output ndarray's ndims correct with `atleast_nd`.
|
|
||||||
|
|
||||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
|
||||||
let output_ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
|
|
||||||
let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
|
|
||||||
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
|
|
||||||
|
|
||||||
let ndarray = NDArrayObject::from_np_array(generator, ctx, object, copy);
|
|
||||||
debug_assert!(ndarray.ndims <= output_ndims); // Sanity check on `ndims`
|
|
||||||
|
|
||||||
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray.dtype, dtype)); // Sanity check on `dtype`
|
|
||||||
|
|
||||||
Ok(ndarray.instance.value.as_basic_value_enum())
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
|
irrt::{
|
||||||
|
call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to,
|
||||||
|
call_nac3_ndarray_util_assert_shape_no_negative,
|
||||||
|
},
|
||||||
model::*,
|
model::*,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
@ -29,6 +32,8 @@ impl<'ctx> StructKind<'ctx> for ShapeEntry {
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
/// Create a broadcast view on this ndarray with a target shape.
|
/// Create a broadcast view on this ndarray with a target shape.
|
||||||
///
|
///
|
||||||
|
/// The input shape will be checked to make sure that it contains no negative values.
|
||||||
|
///
|
||||||
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
|
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
|
||||||
/// The caller has to figure this out for this function.
|
/// The caller has to figure this out for this function.
|
||||||
/// * `target_shape` - An array pointer pointing to the target shape.
|
/// * `target_shape` - An array pointer pointing to the target shape.
|
||||||
|
@ -40,6 +45,14 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
target_ndims: u64,
|
target_ndims: u64,
|
||||||
target_shape: Ptr<'ctx, IntModel<SizeT>>,
|
target_shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let target_ndims_llvm = IntModel(SizeT).constant(generator, ctx.ctx, target_ndims);
|
||||||
|
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
target_ndims_llvm,
|
||||||
|
target_shape,
|
||||||
|
);
|
||||||
|
|
||||||
let broadcast_ndarray = NDArrayObject::alloca(
|
let broadcast_ndarray = NDArrayObject::alloca(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use inkwell::values::BasicValueEnum;
|
use inkwell::{values::BasicValueEnum, IntPredicate};
|
||||||
|
|
||||||
use super::{scalar::ScalarObject, NDArrayObject};
|
use super::{scalar::ScalarObject, NDArrayObject};
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -154,10 +154,15 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
// Create data and set elements
|
// Create data and set elements
|
||||||
ndarray.create_data(generator, ctx);
|
ndarray.create_data(generator, ctx);
|
||||||
ndarray
|
ndarray
|
||||||
.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
|
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||||
|
// Get the index of the current element, convert that index to float, and write it.
|
||||||
|
// This is how we get [0.0, 1.0, 2.0, ...].
|
||||||
|
let index = nditer.get_index(generator, ctx);
|
||||||
|
let pelement = nditer.get_pointer(generator, ctx);
|
||||||
|
|
||||||
let val = ctx
|
let val = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val")
|
.build_unsigned_int_to_float(index.value, ctx.ctx.f64_type(), "val")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
ctx.builder.build_store(pelement, val).unwrap();
|
ctx.builder.build_store(pelement, val).unwrap();
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -172,23 +177,45 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
dtype: Type,
|
dtype: Type,
|
||||||
rows: Int<'ctx, SizeT>,
|
num_rows: Int<'ctx, SizeT>,
|
||||||
cols: Int<'ctx, SizeT>,
|
num_cols: Int<'ctx, SizeT>,
|
||||||
diagonal: Int<'ctx, SizeT>,
|
diagonal: Int<'ctx, SizeT>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let ndzero = ndarray_zero_value(generator, ctx, dtype);
|
||||||
|
let ndone = ndarray_one_value(generator, ctx, dtype);
|
||||||
|
|
||||||
let ndarray = NDArrayObject::alloca_dynamic_shape(
|
let ndarray = NDArrayObject::alloca_dynamic_shape(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
dtype,
|
dtype,
|
||||||
&[rows, cols],
|
&[num_rows, num_cols],
|
||||||
"eye_ndarray",
|
"eye_ndarray",
|
||||||
);
|
);
|
||||||
|
|
||||||
ndarray
|
ndarray
|
||||||
.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
|
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||||
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
||||||
// and this loop would not execute.
|
// and this loop would not execute.
|
||||||
|
|
||||||
todo!()
|
// Load up `row_i` and `col_i` from indices.
|
||||||
|
let row_i = nditer
|
||||||
|
.get_indices()
|
||||||
|
.offset_const(generator, ctx, 0, "")
|
||||||
|
.load(generator, ctx, "row_i");
|
||||||
|
let col_i = nditer
|
||||||
|
.get_indices()
|
||||||
|
.offset_const(generator, ctx, 1, "")
|
||||||
|
.load(generator, ctx, "col_i");
|
||||||
|
|
||||||
|
// Write to element
|
||||||
|
let be_one =
|
||||||
|
row_i.add(ctx, diagonal, "").compare(ctx, IntPredicate::EQ, col_i, "write_one");
|
||||||
|
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
|
||||||
|
|
||||||
|
let p = nditer.get_pointer(generator, ctx);
|
||||||
|
ctx.builder.build_store(p, value).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
todo!()
|
todo!()
|
||||||
|
|
|
@ -480,7 +480,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||||
pextremum_index.store(ctx, zero);
|
pextremum_index.store(ctx, zero);
|
||||||
|
|
||||||
let first_scalar = self.get_nth(generator, ctx, zero);
|
let first_scalar = self.get_nth_scalar(generator, ctx, zero);
|
||||||
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
|
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
|
||||||
|
|
||||||
// Find extremum
|
// Find extremum
|
||||||
|
@ -491,7 +491,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
// Worth reading on "Notes" in <https://numpy.org/doc/stable/reference/generated/numpy.min.html#numpy.min>
|
// Worth reading on "Notes" in <https://numpy.org/doc/stable/reference/generated/numpy.min.html#numpy.min>
|
||||||
// on how `NaN` values have to be handled.
|
// on how `NaN` values have to be handled.
|
||||||
|
|
||||||
let scalar = self.get_nth(generator, ctx, i);
|
let scalar = self.get_nth_scalar(generator, ctx, i);
|
||||||
|
|
||||||
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
|
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
|
||||||
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
|
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
use inkwell::values::BasicValueEnum;
|
use inkwell::values::BasicValueEnum;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use util::gen_for_model_auto;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
model::*,
|
model::*,
|
||||||
object::ndarray::{NDArrayObject, ScalarObject},
|
object::ndarray::{NDArrayObject, ScalarObject},
|
||||||
|
stmt::gen_for_callback,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
typecheck::typedef::Type,
|
typecheck::typedef::Type,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{scalar::ScalarOrNDArray, NDArrayOut};
|
use super::{nditer::NDIterHandle, scalar::ScalarOrNDArray, NDArrayOut};
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
/// TODO: Document me. Has complex behavior.
|
/// TODO: Document me. Has complex behavior.
|
||||||
|
@ -28,12 +28,9 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
MappingFn: FnOnce(
|
MappingFn: FnOnce(
|
||||||
&mut G,
|
&mut G,
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
Int<'ctx, SizeT>,
|
|
||||||
&[ScalarObject<'ctx>],
|
&[ScalarObject<'ctx>],
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
{
|
{
|
||||||
let sizet_model = IntModel(SizeT);
|
|
||||||
|
|
||||||
// Broadcast inputs
|
// Broadcast inputs
|
||||||
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
|
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
|
||||||
|
|
||||||
|
@ -66,19 +63,60 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Map element-wise and store results into `mapped_ndarray`.
|
// Map element-wise and store results into `mapped_ndarray`.
|
||||||
let start = sizet_model.const_0(generator, ctx.ctx);
|
let nditer = NDIterHandle::new(generator, ctx, out_ndarray);
|
||||||
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
|
gen_for_callback(
|
||||||
let step = sizet_model.const_1(generator, ctx.ctx);
|
generator,
|
||||||
gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
|
ctx,
|
||||||
let elements =
|
Some("broadcast_starmap"),
|
||||||
ndarrays.iter().map(|ndarray| ndarray.get_nth(generator, ctx, i)).collect_vec();
|
|generator, ctx| {
|
||||||
|
// Create NDIters for all broadcasted input ndarrays.
|
||||||
|
let other_nditers = broadcast_result
|
||||||
|
.ndarrays
|
||||||
|
.iter()
|
||||||
|
.map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray))
|
||||||
|
.collect_vec();
|
||||||
|
Ok((nditer, other_nditers))
|
||||||
|
},
|
||||||
|
|generator, ctx, (out_nditer, _in_nditers)| {
|
||||||
|
// We can simply use `out_nditer`'s `has_next()`.
|
||||||
|
// `in_nditers`' `has_next()`s should return the same value.
|
||||||
|
Ok(out_nditer.has_next(generator, ctx).value)
|
||||||
|
},
|
||||||
|
|generator, ctx, _hooks, (out_nditer, in_nditers)| {
|
||||||
|
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
|
||||||
|
// and write to `out_ndarray`.
|
||||||
|
|
||||||
let ret = mapping(generator, ctx, i, &elements)?;
|
let in_scalars =
|
||||||
|
in_nditers.iter().map(|nditer| nditer.get_scalar(generator, ctx)).collect_vec();
|
||||||
|
|
||||||
|
let result = mapping(generator, ctx, &in_scalars)?;
|
||||||
|
|
||||||
|
let p = out_nditer.get_pointer(generator, ctx);
|
||||||
|
ctx.builder.build_store(p, result).unwrap();
|
||||||
|
|
||||||
let pret = out_ndarray.get_nth_pointer(generator, ctx, i, "pret");
|
|
||||||
ctx.builder.build_store(pret, ret).unwrap();
|
|
||||||
Ok(())
|
Ok(())
|
||||||
})?;
|
},
|
||||||
|
|generator, ctx, (out_nditer, in_nditers)| {
|
||||||
|
// Advance all iterators
|
||||||
|
out_nditer.next(generator, ctx);
|
||||||
|
in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx));
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// let start = sizet_model.const_0(generator, ctx.ctx);
|
||||||
|
// let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
|
||||||
|
// let step = sizet_model.const_1(generator, ctx.ctx);
|
||||||
|
// gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
|
||||||
|
// let elements =
|
||||||
|
// ndarrays.iter().map(|ndarray| ndarray.get_nth_scalar(generator, ctx, i)).collect_vec();
|
||||||
|
|
||||||
|
// let ret = mapping(generator, ctx, i, &elements)?;
|
||||||
|
|
||||||
|
// let pret = out_ndarray.get_nth_pointer(generator, ctx, i, "pret");
|
||||||
|
// ctx.builder.build_store(pret, ret).unwrap();
|
||||||
|
// Ok(())
|
||||||
|
// })?;
|
||||||
|
|
||||||
Ok(out_ndarray)
|
Ok(out_ndarray)
|
||||||
}
|
}
|
||||||
|
@ -95,7 +133,6 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
Mapping: FnOnce(
|
Mapping: FnOnce(
|
||||||
&mut G,
|
&mut G,
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
Int<'ctx, SizeT>,
|
|
||||||
ScalarObject<'ctx>,
|
ScalarObject<'ctx>,
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
{
|
{
|
||||||
|
@ -104,7 +141,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
ctx,
|
ctx,
|
||||||
&[*self],
|
&[*self],
|
||||||
out,
|
out,
|
||||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -124,20 +161,16 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
MappingFn: FnOnce(
|
MappingFn: FnOnce(
|
||||||
&mut G,
|
&mut G,
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
Int<'ctx, SizeT>,
|
|
||||||
&[ScalarObject<'ctx>],
|
&[ScalarObject<'ctx>],
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
{
|
{
|
||||||
let sizet_model = IntModel(SizeT);
|
|
||||||
|
|
||||||
// Check if all inputs are ScalarObjects
|
// Check if all inputs are ScalarObjects
|
||||||
let all_scalars: Option<Vec<_>> =
|
let all_scalars: Option<Vec<_>> =
|
||||||
inputs.iter().map(ScalarObject::try_from).try_collect().ok();
|
inputs.iter().map(ScalarObject::try_from).try_collect().ok();
|
||||||
|
|
||||||
if let Some(scalars) = all_scalars {
|
if let Some(scalars) = all_scalars {
|
||||||
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
|
|
||||||
let scalar =
|
let scalar =
|
||||||
ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
|
ScalarObject { value: mapping(generator, ctx, &scalars)?, dtype: ret_dtype };
|
||||||
Ok(ScalarOrNDArray::Scalar(scalar))
|
Ok(ScalarOrNDArray::Scalar(scalar))
|
||||||
} else {
|
} else {
|
||||||
// Promote all input to ndarrays and map through them.
|
// Promote all input to ndarrays and map through them.
|
||||||
|
@ -165,7 +198,6 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
Mapping: FnOnce(
|
Mapping: FnOnce(
|
||||||
&mut G,
|
&mut G,
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
Int<'ctx, SizeT>,
|
|
||||||
ScalarObject<'ctx>,
|
ScalarObject<'ctx>,
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
{
|
{
|
||||||
|
@ -174,7 +206,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
ctx,
|
ctx,
|
||||||
&[*self],
|
&[*self],
|
||||||
ret_dtype,
|
ret_dtype,
|
||||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ pub mod functions;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
pub mod mapping;
|
pub mod mapping;
|
||||||
pub mod nalgebra;
|
pub mod nalgebra;
|
||||||
|
pub mod nditer;
|
||||||
pub mod product;
|
pub mod product;
|
||||||
pub mod scalar;
|
pub mod scalar;
|
||||||
pub mod shape_util;
|
pub mod shape_util;
|
||||||
|
@ -16,10 +17,11 @@ use crate::{
|
||||||
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
|
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
|
||||||
call_nac3_ndarray_resolve_and_check_new_shape, call_nac3_ndarray_set_strides_by_shape,
|
call_nac3_ndarray_resolve_and_check_new_shape, call_nac3_ndarray_set_strides_by_shape,
|
||||||
call_nac3_ndarray_size, call_nac3_ndarray_transpose,
|
call_nac3_ndarray_size, call_nac3_ndarray_transpose,
|
||||||
call_nac3_ndarray_util_assert_output_shape_same,
|
call_nac3_ndarray_util_assert_output_shape_same, call_nac3_nditer_has_next,
|
||||||
|
call_nac3_nditer_initialize, call_nac3_nditer_next,
|
||||||
},
|
},
|
||||||
model::*,
|
model::*,
|
||||||
stmt::BreakContinueHooks,
|
stmt::{gen_for_callback, BreakContinueHooks},
|
||||||
structure::{NDArray, SimpleNDArray},
|
structure::{NDArray, SimpleNDArray},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
|
@ -36,6 +38,7 @@ use inkwell::{
|
||||||
values::{BasicValue, BasicValueEnum, PointerValue},
|
values::{BasicValue, BasicValueEnum, PointerValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
|
use nditer::NDIterHandle;
|
||||||
use scalar::{ScalarObject, ScalarOrNDArray};
|
use scalar::{ScalarObject, ScalarOrNDArray};
|
||||||
use util::{call_memcpy_model, gen_for_model_auto};
|
use util::{call_memcpy_model, gen_for_model_auto};
|
||||||
|
|
||||||
|
@ -250,7 +253,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
/// Get the n-th (0-based) scalar.
|
/// Get the n-th (0-based) scalar.
|
||||||
///
|
///
|
||||||
/// There is no out-of-bounds check.
|
/// There is no out-of-bounds check.
|
||||||
pub fn get_nth<G: CodeGenerator + ?Sized>(
|
pub fn get_nth_scalar<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -264,7 +267,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
/// Set the n-th (0-based) scalar.
|
/// Set the n-th (0-based) scalar.
|
||||||
///
|
///
|
||||||
/// There is no out-of-bounds check.
|
/// There is no out-of-bounds check.
|
||||||
pub fn set_nth<G: CodeGenerator + ?Sized>(
|
pub fn set_nth_scalar<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -466,7 +469,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
// NOTE: `np.size(self) == 0` here is never possible.
|
// NOTE: `np.size(self) == 0` here is never possible.
|
||||||
let sizet_model = IntModel(SizeT);
|
let sizet_model = IntModel(SizeT);
|
||||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||||
ScalarOrNDArray::Scalar(self.get_nth(generator, ctx, zero))
|
ScalarOrNDArray::Scalar(self.get_nth_scalar(generator, ctx, zero))
|
||||||
} else {
|
} else {
|
||||||
ScalarOrNDArray::NDArray(*self)
|
ScalarOrNDArray::NDArray(*self)
|
||||||
}
|
}
|
||||||
|
@ -543,40 +546,9 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
self.copy_strides_from_array(generator, ctx, src_strides);
|
self.copy_strides_from_array(generator, ctx, src_strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Iterate through every element pointer in the ndarray in its flatten view.
|
/// Iterate through every element in the ndarray.
|
||||||
///
|
///
|
||||||
/// `body` also access to [`BreakContinueHooks`] to short-circuit and an element's
|
/// `body` also access to [`BreakContinueHooks`] to short-circuit.
|
||||||
/// index. The given element pointer also has been casted to the LLVM type of this ndarray's `dtype`.
|
|
||||||
pub fn foreach_pointer<'a, G, F>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
body: F,
|
|
||||||
) -> Result<(), String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
F: FnOnce(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
BreakContinueHooks<'ctx>,
|
|
||||||
Int<'ctx, SizeT>,
|
|
||||||
Ptr<'ctx, IntModel<SizeT>>,
|
|
||||||
PointerValue<'ctx>,
|
|
||||||
) -> Result<(), String>,
|
|
||||||
{
|
|
||||||
let sizet_model = IntModel(SizeT);
|
|
||||||
|
|
||||||
let start = sizet_model.const_0(generator, ctx.ctx);
|
|
||||||
let stop = self.size(generator, ctx);
|
|
||||||
let step = sizet_model.const_1(generator, ctx.ctx);
|
|
||||||
|
|
||||||
gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, hooks, i| {
|
|
||||||
let pelement = self.get_nth_pointer(generator, ctx, i, "element");
|
|
||||||
body(generator, ctx, hooks, i, pelement)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Iterate through every scalar in this ndarray.
|
|
||||||
pub fn foreach<'a, G, F>(
|
pub fn foreach<'a, G, F>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -589,15 +561,18 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
&mut G,
|
&mut G,
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
BreakContinueHooks<'ctx>,
|
BreakContinueHooks<'ctx>,
|
||||||
Int<'ctx, SizeT>,
|
NDIterHandle<'ctx>,
|
||||||
ScalarObject<'ctx>,
|
|
||||||
) -> Result<(), String>,
|
) -> Result<(), String>,
|
||||||
{
|
{
|
||||||
self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
|
gen_for_callback(
|
||||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
generator,
|
||||||
let scalar = ScalarObject { dtype: self.dtype, value };
|
ctx,
|
||||||
body(generator, ctx, hooks, i, scalar)
|
Some("ndarray_foreach"),
|
||||||
})
|
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|
||||||
|
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|
||||||
|
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|
||||||
|
|generator, ctx, nditer| Ok(nditer.next(generator, ctx)),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Make sure the ndarray is at least `ndmin`-dimensional.
|
/// Make sure the ndarray is at least `ndmin`-dimensional.
|
||||||
|
@ -632,8 +607,9 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
fill_value: BasicValueEnum<'ctx>,
|
fill_value: BasicValueEnum<'ctx>,
|
||||||
) {
|
) {
|
||||||
self.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, _i, pelement| {
|
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||||
ctx.builder.build_store(pelement, fill_value).unwrap();
|
let p = nditer.get_pointer(generator, ctx);
|
||||||
|
ctx.builder.build_store(p, fill_value).unwrap();
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
|
||||||
|
model::*,
|
||||||
|
object::ndarray::scalar::ScalarObject,
|
||||||
|
structure::NDIter,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::NDArrayObject;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct NDIterHandle<'ctx> {
|
||||||
|
ndarray: NDArrayObject<'ctx>,
|
||||||
|
instance: Ptr<'ctx, StructModel<NDIter>>,
|
||||||
|
indices: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDIterHandle<'ctx> {
|
||||||
|
pub fn new<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayObject<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
let nditer = StructModel(NDIter).alloca(generator, ctx, "nditer");
|
||||||
|
let ndims = ndarray.get_ndims(generator, ctx.ctx);
|
||||||
|
let indices = IntModel(SizeT).array_alloca(generator, ctx, ndims.value, "indices");
|
||||||
|
call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices);
|
||||||
|
NDIterHandle { ndarray, instance: nditer, indices }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_next<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Int<'ctx, Bool> {
|
||||||
|
call_nac3_nditer_has_next(generator, ctx, self.instance)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn next<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) {
|
||||||
|
call_nac3_nditer_next(generator, ctx, self.instance)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_pointer<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype);
|
||||||
|
|
||||||
|
let p = self.instance.get(generator, ctx, |f| f.element, "element");
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element")
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_scalar<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> ScalarObject<'ctx> {
|
||||||
|
let p = self.get_pointer(generator, ctx);
|
||||||
|
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||||
|
ScalarObject { dtype: self.ndarray.dtype, value }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_index<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Int<'ctx, SizeT> {
|
||||||
|
self.instance.get(generator, ctx, |f| f.nth, "index")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_indices(&self) -> Ptr<'ctx, IntModel<SizeT>> {
|
||||||
|
self.indices
|
||||||
|
}
|
||||||
|
}
|
|
@ -187,28 +187,36 @@ impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for SimpleNDArray<Item> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An IRRT helper structure used when iterating through an ndarray.
|
/// An IRRT helper structure used when iterating through an ndarray.
|
||||||
/// Fields of [`IndicesIter`]
|
/// Fields of [`NDIter`]
|
||||||
pub struct IndicesIterFields<'ctx, F: FieldTraversal<'ctx>> {
|
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
|
||||||
pub ndims: F::Out<IntModel<SizeT>>,
|
pub ndims: F::Out<IntModel<SizeT>>,
|
||||||
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
|
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
|
||||||
|
pub strides: F::Out<PtrModel<IntModel<SizeT>>>,
|
||||||
|
|
||||||
pub indices: F::Out<PtrModel<IntModel<SizeT>>>,
|
pub indices: F::Out<PtrModel<IntModel<SizeT>>>,
|
||||||
pub size: F::Out<IntModel<SizeT>>,
|
|
||||||
pub nth: F::Out<IntModel<SizeT>>,
|
pub nth: F::Out<IntModel<SizeT>>,
|
||||||
|
pub element: F::Out<PtrModel<IntModel<Byte>>>,
|
||||||
|
|
||||||
|
pub size: F::Out<IntModel<SizeT>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
pub struct IndicesIter;
|
pub struct NDIter;
|
||||||
|
|
||||||
impl<'ctx> StructKind<'ctx> for IndicesIter {
|
impl<'ctx> StructKind<'ctx> for NDIter {
|
||||||
type Fields<F: FieldTraversal<'ctx>> = IndicesIterFields<'ctx, F>;
|
type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>;
|
||||||
|
|
||||||
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
||||||
Self::Fields {
|
Self::Fields {
|
||||||
ndims: traversal.add_auto("ndims"),
|
ndims: traversal.add_auto("ndims"),
|
||||||
shape: traversal.add_auto("shape"),
|
shape: traversal.add_auto("shape"),
|
||||||
|
strides: traversal.add_auto("strides"),
|
||||||
|
|
||||||
indices: traversal.add_auto("indices"),
|
indices: traversal.add_auto("indices"),
|
||||||
size: traversal.add_auto("size"),
|
|
||||||
nth: traversal.add_auto("nth"),
|
nth: traversal.add_auto("nth"),
|
||||||
|
element: traversal.add_auto("element"),
|
||||||
|
|
||||||
|
size: traversal.add_auto("size"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
|
|
||||||
use helper::{create_ndims, debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
|
use helper::{
|
||||||
|
create_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields,
|
||||||
|
PrimDefDetails,
|
||||||
|
};
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
|
@ -9,6 +12,7 @@ use inkwell::{
|
||||||
IntPredicate,
|
IntPredicate,
|
||||||
};
|
};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
|
use numpy::unpack_ndarray_var_tys;
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -16,15 +20,17 @@ use crate::{
|
||||||
builtin_fns,
|
builtin_fns,
|
||||||
classes::{ProxyValue, RangeValue},
|
classes::{ProxyValue, RangeValue},
|
||||||
extern_fns::{self, call_np_linalg_det, call_np_linalg_matrix_power},
|
extern_fns::{self, call_np_linalg_det, call_np_linalg_matrix_power},
|
||||||
irrt, llvm_intrinsics,
|
irrt::{self, call_nac3_ndarray_util_assert_shape_no_negative},
|
||||||
model::{IntModel, SizeT},
|
llvm_intrinsics,
|
||||||
|
model::*,
|
||||||
numpy::*,
|
numpy::*,
|
||||||
numpy_new::{self, gen_ndarray_transpose},
|
numpy_new::{self},
|
||||||
object::{
|
object::{
|
||||||
ndarray::{
|
ndarray::{
|
||||||
functions::{FloorOrCeil, MinOrMax},
|
functions::{FloorOrCeil, MinOrMax},
|
||||||
nalgebra::perform_nalgebra_call,
|
nalgebra::perform_nalgebra_call,
|
||||||
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
|
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
|
||||||
|
shape_util::parse_numpy_int_sequence,
|
||||||
NDArrayObject,
|
NDArrayObject,
|
||||||
},
|
},
|
||||||
tuple::TupleObject,
|
tuple::TupleObject,
|
||||||
|
@ -1104,7 +1110,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ret_dtype,
|
ret_dtype,
|
||||||
|generator, ctx, _i, scalar| {
|
|generator, ctx, scalar| {
|
||||||
let result = match prim {
|
let result = match prim {
|
||||||
PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx),
|
PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx),
|
||||||
PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx),
|
PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx),
|
||||||
|
@ -1175,9 +1181,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ret_int_dtype,
|
ret_int_dtype,
|
||||||
|generator, ctx, _i, scalar| {
|
|generator, ctx, scalar| Ok(scalar.round(generator, ctx, ret_int_dtype).value),
|
||||||
Ok(scalar.round(generator, ctx, ret_int_dtype).value)
|
|
||||||
},
|
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
|
@ -1241,7 +1245,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
int_sized,
|
int_sized,
|
||||||
|generator, ctx, _i, scalar| {
|
|generator, ctx, scalar| {
|
||||||
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
|
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
@ -1291,13 +1295,24 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
self.ndarray_float,
|
self.ndarray_float,
|
||||||
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
|
// Parse argument `shape`.
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
let shape = AnyObject { ty: shape_ty, value: shape_arg };
|
||||||
|
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||||
|
|
||||||
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||||
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
|
|
||||||
let func = match prim {
|
let func = match prim {
|
||||||
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => numpy_new::gen_ndarray_empty,
|
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => NDArrayObject::from_np_empty,
|
||||||
PrimDef::FunNpZeros => numpy_new::gen_ndarray_zeros,
|
PrimDef::FunNpZeros => NDArrayObject::from_np_zero,
|
||||||
PrimDef::FunNpOnes => numpy_new::gen_ndarray_ones,
|
PrimDef::FunNpOnes => NDArrayObject::from_np_ones,
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
|
||||||
|
let ndarray = func(generator, ctx, dtype, ndims, shape);
|
||||||
|
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -1356,8 +1371,47 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
resolver: None,
|
resolver: None,
|
||||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|ctx, obj, fun, args, generator| {
|
|ctx, obj, fun, args, generator| {
|
||||||
numpy_new::gen_ndarray_array(ctx, &obj, fun, &args, generator)
|
assert!(obj.is_none());
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
assert!(matches!(args.len(), 1..=3));
|
||||||
|
|
||||||
|
let object_ty = fun.0.args[0].ty;
|
||||||
|
let object =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
|
||||||
|
let object = AnyObject { ty: object_ty, value: object };
|
||||||
|
|
||||||
|
let copy_arg = if let Some(arg) = args
|
||||||
|
.iter()
|
||||||
|
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
||||||
|
{
|
||||||
|
let copy_ty = fun.0.args[1].ty;
|
||||||
|
arg.1.clone().to_basic_value_enum(ctx, generator, copy_ty)?
|
||||||
|
} else {
|
||||||
|
ctx.gen_symbol_val(
|
||||||
|
generator,
|
||||||
|
fun.0.args[1].default_value.as_ref().unwrap(),
|
||||||
|
fun.0.args[1].ty,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
// The argument `ndmin` is completely ignored. We don't need to know its LLVM value.
|
||||||
|
// We simply make the output ndarray's ndims correct with `atleast_nd`.
|
||||||
|
|
||||||
|
let (dtype, ndims) =
|
||||||
|
unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||||
|
let output_ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
|
|
||||||
|
let copy =
|
||||||
|
IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
|
||||||
|
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
|
||||||
|
|
||||||
|
let ndarray =
|
||||||
|
NDArrayObject::from_np_array(generator, ctx, object, copy);
|
||||||
|
debug_assert!(ndarray.ndims <= output_ndims); // Sanity check on `ndims`
|
||||||
|
|
||||||
|
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
|
||||||
|
debug_assert!(ctx.unifier.unioned(ndarray.dtype, dtype)); // Sanity check on `dtype`
|
||||||
|
|
||||||
|
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
|
@ -1375,8 +1429,31 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
// type variable
|
// type variable
|
||||||
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
numpy_new::gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
assert!(obj.is_none());
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
assert_eq!(args.len(), 2);
|
||||||
|
|
||||||
|
// Parse argument #1 shape
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||||
|
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||||
|
|
||||||
|
// Parse argument #2 fill_value
|
||||||
|
let fill_value_ty = fun.0.args[1].ty;
|
||||||
|
let fill_value =
|
||||||
|
args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
|
||||||
|
let fill_value = ScalarObject { dtype: fill_value_ty, value: fill_value };
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||||
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
|
|
||||||
|
let ndarray = NDArrayObject::from_np_full(
|
||||||
|
generator, ctx, dtype, ndims, shape, fill_value,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -1483,12 +1560,38 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"),
|
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"),
|
||||||
],
|
],
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
let f = match prim {
|
// Parse argument #1 ndarray
|
||||||
PrimDef::FunNpBroadcastTo => numpy_new::gen_ndarray_broadcast_to,
|
let input_ty = fun.0.args[0].ty;
|
||||||
PrimDef::FunNpReshape => numpy_new::gen_ndarray_reshape,
|
let input =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
|
||||||
|
let input = AnyObject { ty: input_ty, value: input };
|
||||||
|
|
||||||
|
// Parse argument #2 shape
|
||||||
|
let shape_ty = fun.0.args[1].ty;
|
||||||
|
let shape =
|
||||||
|
args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||||
|
let (_, return_shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
// Turn any input to an ndarray for a convenient implementation.
|
||||||
|
let ndarray = split_scalar_or_ndarray(generator, ctx, input)
|
||||||
|
.as_ndarray(generator, ctx);
|
||||||
|
|
||||||
|
let (_, return_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||||
|
let return_ndims = extract_ndims(&ctx.unifier, return_ndims);
|
||||||
|
|
||||||
|
let result = match prim {
|
||||||
|
PrimDef::FunNpBroadcastTo => {
|
||||||
|
ndarray.broadcast_to(generator, ctx, return_ndims, return_shape)
|
||||||
|
}
|
||||||
|
PrimDef::FunNpReshape => {
|
||||||
|
ndarray.reshape_or_copy(generator, ctx, return_ndims, return_shape)
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
f(ctx, &obj, fun, &args, generator).map(Some)
|
|
||||||
|
Ok(Some(result.instance.value.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -1529,7 +1632,39 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
resolver: None,
|
resolver: None,
|
||||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|ctx, obj, fun, args, generator| {
|
|ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_transpose(ctx, &obj, fun, &args, generator).map(Some)
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
// Parse argument #1 ndarray
|
||||||
|
let ndarray_ty = fun.0.args[0].ty;
|
||||||
|
let ndarray = args[0]
|
||||||
|
.1
|
||||||
|
.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||||
|
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||||
|
|
||||||
|
// Implementation
|
||||||
|
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||||
|
|
||||||
|
let has_axes = args.len() >= 2;
|
||||||
|
let transposed_ndarray = if has_axes {
|
||||||
|
// Parse argument #2 axes
|
||||||
|
let in_axes_ty = fun.0.args[1].ty;
|
||||||
|
let in_axes = args[1]
|
||||||
|
.1
|
||||||
|
.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, in_axes_ty)?;
|
||||||
|
let in_axes = AnyObject { ty: in_axes_ty, value: in_axes };
|
||||||
|
|
||||||
|
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes);
|
||||||
|
|
||||||
|
ndarray.transpose(generator, ctx, Some(axes))
|
||||||
|
} else {
|
||||||
|
// axes is not given
|
||||||
|
ndarray.transpose(generator, ctx, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(transposed_ndarray.instance.value.as_basic_value_enum()))
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
|
@ -1641,7 +1776,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
move |_generator, ctx, _i, scalar| {
|
move |_generator, ctx, scalar| {
|
||||||
let result = scalar.np_floor_or_ceil(ctx, kind);
|
let result = scalar.np_floor_or_ceil(ctx, kind);
|
||||||
Ok(result.value)
|
Ok(result.value)
|
||||||
},
|
},
|
||||||
|
@ -1670,7 +1805,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
|_generator, ctx, _i, scalar| {
|
|_generator, ctx, scalar| {
|
||||||
let result = scalar.np_round(ctx);
|
let result = scalar.np_round(ctx);
|
||||||
Ok(result.value)
|
Ok(result.value)
|
||||||
},
|
},
|
||||||
|
@ -1883,7 +2018,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
ctx,
|
ctx,
|
||||||
&[x1, x2],
|
&[x1, x2],
|
||||||
common_ty,
|
common_ty,
|
||||||
|_generator, ctx, _i, scalars| {
|
|_generator, ctx, scalars| {
|
||||||
let x1 = scalars[0];
|
let x1 = scalars[0];
|
||||||
let x2 = scalars[1];
|
let x2 = scalars[1];
|
||||||
|
|
||||||
|
@ -1930,7 +2065,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
num_ty.ty,
|
num_ty.ty,
|
||||||
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value),
|
|_generator, ctx, scalar| Ok(scalar.abs(ctx).value),
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
},
|
},
|
||||||
|
@ -1966,7 +2101,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
ctx.primitives.bool,
|
||||||
|generator, ctx, _i, scalar| {
|
|generator, ctx, scalar| {
|
||||||
let n = scalar.into_float64(ctx);
|
let n = scalar.into_float64(ctx);
|
||||||
let n = function(generator, ctx, n);
|
let n = function(generator, ctx, n);
|
||||||
Ok(n.as_basic_value_enum())
|
Ok(n.as_basic_value_enum())
|
||||||
|
@ -2035,7 +2170,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
|_generator, ctx, _i, scalar| {
|
|_generator, ctx, scalar| {
|
||||||
let n = scalar.into_float64(ctx);
|
let n = scalar.into_float64(ctx);
|
||||||
let n = match prim {
|
let n = match prim {
|
||||||
PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None),
|
PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None),
|
||||||
|
@ -2160,7 +2295,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
ctx,
|
ctx,
|
||||||
&[x1, x2],
|
&[x1, x2],
|
||||||
ret_dtype,
|
ret_dtype,
|
||||||
|_generator, ctx, _i, scalars| {
|
|_generator, ctx, scalars| {
|
||||||
let x1 = scalars[0];
|
let x1 = scalars[0];
|
||||||
let x2 = scalars[1];
|
let x2 = scalars[1];
|
||||||
|
|
||||||
|
@ -2406,7 +2541,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
);
|
);
|
||||||
let sizet_model = IntModel(SizeT);
|
let sizet_model = IntModel(SizeT);
|
||||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||||
x2_ndarray.set_nth(
|
x2_ndarray.set_nth_scalar(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
zero,
|
zero,
|
||||||
|
@ -2451,7 +2586,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
|
|
||||||
let sizet_model = IntModel(SizeT);
|
let sizet_model = IntModel(SizeT);
|
||||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||||
let determinant = out.get_nth(generator, ctx, zero);
|
let determinant = out.get_nth_scalar(generator, ctx, zero);
|
||||||
|
|
||||||
Ok(Some(determinant.value))
|
Ok(Some(determinant.value))
|
||||||
}),
|
}),
|
||||||
|
|
Loading…
Reference in New Issue