forked from M-Labs/nac3
WIP: core/ndstrides: more iter and less builtin
This commit is contained in:
parent
15dfb2eaa0
commit
5dce27e87d
@ -21,40 +21,61 @@ namespace {
|
||||
* - If shape contains zeroes, there are no enumerations.
|
||||
*/
|
||||
template <typename SizeT>
|
||||
struct IndicesIter {
|
||||
struct NDIter {
|
||||
SizeT ndims;
|
||||
SizeT* shape;
|
||||
SizeT* strides;
|
||||
SizeT size; // Product of shape
|
||||
|
||||
SizeT* indices; // The current indices
|
||||
SizeT nth; // The nth (0-based) index of the current indices.
|
||||
uint8_t* element; // The current element
|
||||
/**
|
||||
* @brief The current indices.
|
||||
*
|
||||
* Must be allocated by the caller.
|
||||
*/
|
||||
SizeT* indices;
|
||||
|
||||
// A convenient constructor for internal C++ IRRT.
|
||||
IndicesIter(SizeT ndims, SizeT* shape, SizeT* strides, SizeT *indices, uint8_t* element) {
|
||||
/**
|
||||
* @brief The nth (0-based) index of the current indices.
|
||||
*/
|
||||
SizeT nth;
|
||||
|
||||
/**
|
||||
* @brief Pointer to the current element.
|
||||
*/
|
||||
uint8_t* element;
|
||||
|
||||
/**
|
||||
* @brief The product of shape.
|
||||
*/
|
||||
SizeT size;
|
||||
|
||||
// TODO:: There is something called backstrides to speedup iteration.
|
||||
// See https://ajcr.net/stride-guide-part-1/, and https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
|
||||
// Maybe LLVM is clever and knows how to optimize.
|
||||
|
||||
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element,
|
||||
SizeT* indices) {
|
||||
this->ndims = ndims;
|
||||
this->shape = shape;
|
||||
this->strides = strides;
|
||||
|
||||
this->indices = indices;
|
||||
this->element = element;
|
||||
this->initialize();
|
||||
}
|
||||
|
||||
void initialize() {
|
||||
reset();
|
||||
|
||||
// Compute size and backstrides
|
||||
this->size = 1;
|
||||
for (SizeT i = 0; i < ndims; i++) {
|
||||
this->size *= shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
void reset() {
|
||||
for (SizeT axis = 0; axis < ndims; axis++) indices[axis] = 0;
|
||||
nth = 0;
|
||||
}
|
||||
|
||||
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
|
||||
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides,
|
||||
ndarray->data, indices);
|
||||
}
|
||||
|
||||
bool has_next() { return nth < size; }
|
||||
|
||||
void next() {
|
||||
@ -63,7 +84,11 @@ struct IndicesIter {
|
||||
indices[axis]++;
|
||||
if (indices[axis] >= shape[axis]) {
|
||||
indices[axis] = 0;
|
||||
|
||||
// TODO: Can be optimized with backstrides.
|
||||
element -= strides[axis] * shape[axis];
|
||||
} else {
|
||||
element += strides[axis];
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -73,32 +98,21 @@ struct IndicesIter {
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
void __call_nac3_ndarray_indices_iter_initialize(IndicesIter<int32_t>* iter,
|
||||
int32_t ndims, int32_t* shape,
|
||||
int32_t* indices) {
|
||||
iter->initialize(ndims, shape, indices);
|
||||
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray,
|
||||
int32_t* indices) {
|
||||
iter->initialize_by_ndarray(ndarray, indices);
|
||||
}
|
||||
|
||||
void __call_nac3_ndarray_indices_iter_initialize64(IndicesIter<int64_t>* iter,
|
||||
int64_t ndims,
|
||||
int64_t* shape,
|
||||
int64_t* indices) {
|
||||
iter->initialize(ndims, shape, indices);
|
||||
void __nac3_nditer_initialize64(NDIter<int64_t>* iter,
|
||||
NDArray<int64_t>* ndarray, int64_t* indices) {
|
||||
iter->initialize_by_ndarray(ndarray, indices);
|
||||
}
|
||||
|
||||
bool __call_nac3_ndarray_indices_iter_has_next(IndicesIter<int32_t>* iter) {
|
||||
iter->has_next();
|
||||
}
|
||||
bool __nac3_nditer_has_next(NDIter<int32_t>* iter) { return iter->has_next(); }
|
||||
|
||||
bool __call_nac3_ndarray_indices_iter_has_next64(IndicesIter<int64_t>* iter) {
|
||||
iter->has_next();
|
||||
}
|
||||
bool __nac3_nditer_has_next64(NDIter<int64_t>* iter) { return iter->has_next(); }
|
||||
|
||||
bool __call_nac3_ndarray_indices_iter_next(IndicesIter<int32_t>* iter) {
|
||||
iter->next();
|
||||
}
|
||||
void __nac3_nditer_next(NDIter<int32_t>* iter) { iter->next(); }
|
||||
|
||||
bool __call_nac3_ndarray_indices_iter_next64(IndicesIter<int64_t>* iter) {
|
||||
iter->next();
|
||||
}
|
||||
void __nac3_nditer_next64(NDIter<int64_t>* iter) { iter->next(); }
|
||||
}
|
@ -123,7 +123,9 @@ void matmul_at_least_2d(NDArray<SizeT>* a_ndarray, NDArray<SizeT>* b_ndarray,
|
||||
SizeT* indices =
|
||||
(SizeT*)__builtin_alloca(sizeof(SizeT) * dst_ndarray->ndims);
|
||||
SizeT* mat_indices = indices + u;
|
||||
IndicesIter<SizeT> iter(u, dst_ndarray->shape, indices);
|
||||
NDIter<SizeT> iter;
|
||||
iter.initialize(u, dst_ndarray->shape, dst_ndarray->strides,
|
||||
dst_ndarray->data, indices);
|
||||
|
||||
for (; iter.has_next(); iter.next()) {
|
||||
for (SizeT i = 0; i < dst_mat_shape[0]; i++) {
|
||||
|
@ -1565,7 +1565,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
ctx,
|
||||
&[left, right],
|
||||
out,
|
||||
|generator, ctx, _i, scalars| {
|
||||
|generator, ctx, scalars| {
|
||||
let left = scalars[0];
|
||||
let right = scalars[1];
|
||||
gen_binop_expr_with_values(
|
||||
|
@ -6,7 +6,7 @@ mod test;
|
||||
use super::model::*;
|
||||
use super::object::ndarray::broadcast::ShapeEntry;
|
||||
use super::object::ndarray::indexing::NDIndex;
|
||||
use super::structure::{List, NDArray};
|
||||
use super::structure::{List, NDArray, NDIter};
|
||||
use super::{
|
||||
classes::{
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
|
||||
@ -1239,7 +1239,31 @@ pub fn call_nac3_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
|
||||
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
|
||||
}
|
||||
|
||||
pub fn call_nac3_ndarray_indices_iter_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
iter: Ptr<'ctx, StructModel<NDIter>>,
|
||||
ndarray: Ptr<'ctx, StructModel<NDArray>>,
|
||||
indices: Ptr<'ctx, IntModel<SizeT>>,
|
||||
) {
|
||||
}
|
||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
|
||||
CallFunction::begin(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void();
|
||||
}
|
||||
|
||||
pub fn call_nac3_nditer_has_next<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
iter: Ptr<'ctx, StructModel<NDIter>>,
|
||||
) -> Int<'ctx, Bool> {
|
||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_next");
|
||||
CallFunction::begin(generator, ctx, &name).arg(iter).returning_auto("has_next")
|
||||
}
|
||||
|
||||
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
iter: Ptr<'ctx, StructModel<NDIter>>,
|
||||
) {
|
||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
|
||||
CallFunction::begin(generator, ctx, &name).arg(iter).returning_void();
|
||||
}
|
||||
|
@ -219,4 +219,4 @@ impl<'ctx, S: StructKind<'ctx>> Ptr<'ctx, StructModel<S>> {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add an opaque struct type?
|
||||
// TODO: Add an opaque struct type?
|
||||
|
@ -127,35 +127,6 @@ pub fn gen_ndarray_ones<'ctx>(
|
||||
Ok(ndarray.instance.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `np.full`.
|
||||
pub fn gen_ndarray_full<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 2);
|
||||
|
||||
// Parse argument #1 shape
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||
|
||||
// Parse argument #2 fill_value
|
||||
let fill_value_ty = fun.0.args[1].ty;
|
||||
let fill_value = args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
|
||||
|
||||
// Implementation
|
||||
let ndarray_ty = fun.0.ret;
|
||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
|
||||
|
||||
ndarray.fill(generator, ctx, fill_value);
|
||||
|
||||
Ok(ndarray.instance.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `np.broadcast_to`.
|
||||
pub fn gen_ndarray_broadcast_to<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -345,89 +316,3 @@ pub fn gen_ndarray_strides<'ctx>(
|
||||
Ok(strides.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `np.transpose`.
|
||||
pub fn gen_ndarray_transpose<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
// TODO: The implementation will be changed once default values start working again.
|
||||
// Read the comment on this function in BuiltinBuilder.
|
||||
|
||||
// TODO: Change axes values to `SizeT`
|
||||
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
// Parse argument #1 ndarray
|
||||
let ndarray_ty = fun.0.args[0].ty;
|
||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||
|
||||
// Implementation
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let has_axes = args.len() >= 2;
|
||||
let transposed_ndarray = if has_axes {
|
||||
// Parse argument #2 axes
|
||||
let in_axes_ty = fun.0.args[1].ty;
|
||||
let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?;
|
||||
let in_axes = AnyObject { ty: in_axes_ty, value: in_axes };
|
||||
|
||||
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes);
|
||||
|
||||
ndarray.transpose(generator, ctx, Some(axes))
|
||||
} else {
|
||||
// axes is not given
|
||||
ndarray.transpose(generator, ctx, None)
|
||||
};
|
||||
|
||||
Ok(transposed_ndarray.instance.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
pub fn gen_ndarray_array<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert!(matches!(args.len(), 1..=3));
|
||||
|
||||
let object_ty = fun.0.args[0].ty;
|
||||
let object = args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
|
||||
let object = AnyObject { ty: object_ty, value: object };
|
||||
|
||||
let copy_arg = if let Some(arg) =
|
||||
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
||||
{
|
||||
let copy_ty = fun.0.args[1].ty;
|
||||
arg.1.clone().to_basic_value_enum(ctx, generator, copy_ty)?
|
||||
} else {
|
||||
ctx.gen_symbol_val(
|
||||
generator,
|
||||
fun.0.args[1].default_value.as_ref().unwrap(),
|
||||
fun.0.args[1].ty,
|
||||
)
|
||||
};
|
||||
|
||||
// The argument `ndmin` is completely ignored. We don't need to know its LLVM value.
|
||||
// We simply make the output ndarray's ndims correct with `atleast_nd`.
|
||||
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let output_ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
|
||||
let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
|
||||
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
|
||||
|
||||
let ndarray = NDArrayObject::from_np_array(generator, ctx, object, copy);
|
||||
debug_assert!(ndarray.ndims <= output_ndims); // Sanity check on `ndims`
|
||||
|
||||
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
|
||||
debug_assert!(ctx.unifier.unioned(ndarray.dtype, dtype)); // Sanity check on `dtype`
|
||||
|
||||
Ok(ndarray.instance.value.as_basic_value_enum())
|
||||
}
|
||||
|
@ -1,7 +1,10 @@
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::{
|
||||
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
|
||||
irrt::{
|
||||
call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to,
|
||||
call_nac3_ndarray_util_assert_shape_no_negative,
|
||||
},
|
||||
model::*,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
@ -29,6 +32,8 @@ impl<'ctx> StructKind<'ctx> for ShapeEntry {
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Create a broadcast view on this ndarray with a target shape.
|
||||
///
|
||||
/// The input shape will be checked to make sure that it contains no negative values.
|
||||
///
|
||||
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
|
||||
/// The caller has to figure this out for this function.
|
||||
/// * `target_shape` - An array pointer pointing to the target shape.
|
||||
@ -40,6 +45,14 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
target_ndims: u64,
|
||||
target_shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
) -> Self {
|
||||
let target_ndims_llvm = IntModel(SizeT).constant(generator, ctx.ctx, target_ndims);
|
||||
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||
generator,
|
||||
ctx,
|
||||
target_ndims_llvm,
|
||||
target_shape,
|
||||
);
|
||||
|
||||
let broadcast_ndarray = NDArrayObject::alloca(
|
||||
generator,
|
||||
ctx,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use inkwell::values::BasicValueEnum;
|
||||
use inkwell::{values::BasicValueEnum, IntPredicate};
|
||||
|
||||
use super::{scalar::ScalarObject, NDArrayObject};
|
||||
use crate::{
|
||||
@ -154,10 +154,15 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
// Create data and set elements
|
||||
ndarray.create_data(generator, ctx);
|
||||
ndarray
|
||||
.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
|
||||
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||
// Get the index of the current element, convert that index to float, and write it.
|
||||
// This is how we get [0.0, 1.0, 2.0, ...].
|
||||
let index = nditer.get_index(generator, ctx);
|
||||
let pelement = nditer.get_pointer(generator, ctx);
|
||||
|
||||
let val = ctx
|
||||
.builder
|
||||
.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val")
|
||||
.build_unsigned_int_to_float(index.value, ctx.ctx.f64_type(), "val")
|
||||
.unwrap();
|
||||
ctx.builder.build_store(pelement, val).unwrap();
|
||||
Ok(())
|
||||
@ -172,23 +177,45 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
rows: Int<'ctx, SizeT>,
|
||||
cols: Int<'ctx, SizeT>,
|
||||
num_rows: Int<'ctx, SizeT>,
|
||||
num_cols: Int<'ctx, SizeT>,
|
||||
diagonal: Int<'ctx, SizeT>,
|
||||
) -> Self {
|
||||
let ndzero = ndarray_zero_value(generator, ctx, dtype);
|
||||
let ndone = ndarray_one_value(generator, ctx, dtype);
|
||||
|
||||
let ndarray = NDArrayObject::alloca_dynamic_shape(
|
||||
generator,
|
||||
ctx,
|
||||
dtype,
|
||||
&[rows, cols],
|
||||
&[num_rows, num_cols],
|
||||
"eye_ndarray",
|
||||
);
|
||||
|
||||
ndarray
|
||||
.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
|
||||
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
||||
// and this loop would not execute.
|
||||
|
||||
todo!()
|
||||
|
||||
// Load up `row_i` and `col_i` from indices.
|
||||
let row_i = nditer
|
||||
.get_indices()
|
||||
.offset_const(generator, ctx, 0, "")
|
||||
.load(generator, ctx, "row_i");
|
||||
let col_i = nditer
|
||||
.get_indices()
|
||||
.offset_const(generator, ctx, 1, "")
|
||||
.load(generator, ctx, "col_i");
|
||||
|
||||
// Write to element
|
||||
let be_one =
|
||||
row_i.add(ctx, diagonal, "").compare(ctx, IntPredicate::EQ, col_i, "write_one");
|
||||
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
|
||||
|
||||
let p = nditer.get_pointer(generator, ctx);
|
||||
ctx.builder.build_store(p, value).unwrap();
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
todo!()
|
||||
|
@ -480,7 +480,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||
pextremum_index.store(ctx, zero);
|
||||
|
||||
let first_scalar = self.get_nth(generator, ctx, zero);
|
||||
let first_scalar = self.get_nth_scalar(generator, ctx, zero);
|
||||
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
|
||||
|
||||
// Find extremum
|
||||
@ -491,7 +491,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
// Worth reading on "Notes" in <https://numpy.org/doc/stable/reference/generated/numpy.min.html#numpy.min>
|
||||
// on how `NaN` values have to be handled.
|
||||
|
||||
let scalar = self.get_nth(generator, ctx, i);
|
||||
let scalar = self.get_nth_scalar(generator, ctx, i);
|
||||
|
||||
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
|
||||
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
|
||||
|
@ -1,17 +1,17 @@
|
||||
use inkwell::values::BasicValueEnum;
|
||||
use itertools::Itertools;
|
||||
use util::gen_for_model_auto;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
model::*,
|
||||
object::ndarray::{NDArrayObject, ScalarObject},
|
||||
stmt::gen_for_callback,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
|
||||
use super::{scalar::ScalarOrNDArray, NDArrayOut};
|
||||
use super::{nditer::NDIterHandle, scalar::ScalarOrNDArray, NDArrayOut};
|
||||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// TODO: Document me. Has complex behavior.
|
||||
@ -28,12 +28,9 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
MappingFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
&[ScalarObject<'ctx>],
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
// Broadcast inputs
|
||||
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
|
||||
|
||||
@ -66,19 +63,60 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
};
|
||||
|
||||
// Map element-wise and store results into `mapped_ndarray`.
|
||||
let start = sizet_model.const_0(generator, ctx.ctx);
|
||||
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
|
||||
let step = sizet_model.const_1(generator, ctx.ctx);
|
||||
gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
|
||||
let elements =
|
||||
ndarrays.iter().map(|ndarray| ndarray.get_nth(generator, ctx, i)).collect_vec();
|
||||
let nditer = NDIterHandle::new(generator, ctx, out_ndarray);
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
Some("broadcast_starmap"),
|
||||
|generator, ctx| {
|
||||
// Create NDIters for all broadcasted input ndarrays.
|
||||
let other_nditers = broadcast_result
|
||||
.ndarrays
|
||||
.iter()
|
||||
.map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray))
|
||||
.collect_vec();
|
||||
Ok((nditer, other_nditers))
|
||||
},
|
||||
|generator, ctx, (out_nditer, _in_nditers)| {
|
||||
// We can simply use `out_nditer`'s `has_next()`.
|
||||
// `in_nditers`' `has_next()`s should return the same value.
|
||||
Ok(out_nditer.has_next(generator, ctx).value)
|
||||
},
|
||||
|generator, ctx, _hooks, (out_nditer, in_nditers)| {
|
||||
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
|
||||
// and write to `out_ndarray`.
|
||||
|
||||
let ret = mapping(generator, ctx, i, &elements)?;
|
||||
let in_scalars =
|
||||
in_nditers.iter().map(|nditer| nditer.get_scalar(generator, ctx)).collect_vec();
|
||||
|
||||
let pret = out_ndarray.get_nth_pointer(generator, ctx, i, "pret");
|
||||
ctx.builder.build_store(pret, ret).unwrap();
|
||||
Ok(())
|
||||
})?;
|
||||
let result = mapping(generator, ctx, &in_scalars)?;
|
||||
|
||||
let p = out_nditer.get_pointer(generator, ctx);
|
||||
ctx.builder.build_store(p, result).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|generator, ctx, (out_nditer, in_nditers)| {
|
||||
// Advance all iterators
|
||||
out_nditer.next(generator, ctx);
|
||||
in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx));
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
// let start = sizet_model.const_0(generator, ctx.ctx);
|
||||
// let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
|
||||
// let step = sizet_model.const_1(generator, ctx.ctx);
|
||||
// gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
|
||||
// let elements =
|
||||
// ndarrays.iter().map(|ndarray| ndarray.get_nth_scalar(generator, ctx, i)).collect_vec();
|
||||
|
||||
// let ret = mapping(generator, ctx, i, &elements)?;
|
||||
|
||||
// let pret = out_ndarray.get_nth_pointer(generator, ctx, i, "pret");
|
||||
// ctx.builder.build_store(pret, ret).unwrap();
|
||||
// Ok(())
|
||||
// })?;
|
||||
|
||||
Ok(out_ndarray)
|
||||
}
|
||||
@ -95,7 +133,6 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
Mapping: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
@ -104,7 +141,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
ctx,
|
||||
&[*self],
|
||||
out,
|
||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -124,20 +161,16 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
MappingFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
&[ScalarObject<'ctx>],
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
// Check if all inputs are ScalarObjects
|
||||
let all_scalars: Option<Vec<_>> =
|
||||
inputs.iter().map(ScalarObject::try_from).try_collect().ok();
|
||||
|
||||
if let Some(scalars) = all_scalars {
|
||||
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
|
||||
let scalar =
|
||||
ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
|
||||
ScalarObject { value: mapping(generator, ctx, &scalars)?, dtype: ret_dtype };
|
||||
Ok(ScalarOrNDArray::Scalar(scalar))
|
||||
} else {
|
||||
// Promote all input to ndarrays and map through them.
|
||||
@ -165,7 +198,6 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
Mapping: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
@ -174,7 +206,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
ctx,
|
||||
&[*self],
|
||||
ret_dtype,
|
||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ pub mod functions;
|
||||
pub mod indexing;
|
||||
pub mod mapping;
|
||||
pub mod nalgebra;
|
||||
pub mod nditer;
|
||||
pub mod product;
|
||||
pub mod scalar;
|
||||
pub mod shape_util;
|
||||
@ -16,10 +17,11 @@ use crate::{
|
||||
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
|
||||
call_nac3_ndarray_resolve_and_check_new_shape, call_nac3_ndarray_set_strides_by_shape,
|
||||
call_nac3_ndarray_size, call_nac3_ndarray_transpose,
|
||||
call_nac3_ndarray_util_assert_output_shape_same,
|
||||
call_nac3_ndarray_util_assert_output_shape_same, call_nac3_nditer_has_next,
|
||||
call_nac3_nditer_initialize, call_nac3_nditer_next,
|
||||
},
|
||||
model::*,
|
||||
stmt::BreakContinueHooks,
|
||||
stmt::{gen_for_callback, BreakContinueHooks},
|
||||
structure::{NDArray, SimpleNDArray},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
@ -36,6 +38,7 @@ use inkwell::{
|
||||
values::{BasicValue, BasicValueEnum, PointerValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
use nditer::NDIterHandle;
|
||||
use scalar::{ScalarObject, ScalarOrNDArray};
|
||||
use util::{call_memcpy_model, gen_for_model_auto};
|
||||
|
||||
@ -250,7 +253,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Get the n-th (0-based) scalar.
|
||||
///
|
||||
/// There is no out-of-bounds check.
|
||||
pub fn get_nth<G: CodeGenerator + ?Sized>(
|
||||
pub fn get_nth_scalar<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -264,7 +267,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Set the n-th (0-based) scalar.
|
||||
///
|
||||
/// There is no out-of-bounds check.
|
||||
pub fn set_nth<G: CodeGenerator + ?Sized>(
|
||||
pub fn set_nth_scalar<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -466,7 +469,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
// NOTE: `np.size(self) == 0` here is never possible.
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||
ScalarOrNDArray::Scalar(self.get_nth(generator, ctx, zero))
|
||||
ScalarOrNDArray::Scalar(self.get_nth_scalar(generator, ctx, zero))
|
||||
} else {
|
||||
ScalarOrNDArray::NDArray(*self)
|
||||
}
|
||||
@ -543,40 +546,9 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
self.copy_strides_from_array(generator, ctx, src_strides);
|
||||
}
|
||||
|
||||
/// Iterate through every element pointer in the ndarray in its flatten view.
|
||||
/// Iterate through every element in the ndarray.
|
||||
///
|
||||
/// `body` also access to [`BreakContinueHooks`] to short-circuit and an element's
|
||||
/// index. The given element pointer also has been casted to the LLVM type of this ndarray's `dtype`.
|
||||
pub fn foreach_pointer<'a, G, F>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
body: F,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
F: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
Int<'ctx, SizeT>,
|
||||
Ptr<'ctx, IntModel<SizeT>>,
|
||||
PointerValue<'ctx>,
|
||||
) -> Result<(), String>,
|
||||
{
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
let start = sizet_model.const_0(generator, ctx.ctx);
|
||||
let stop = self.size(generator, ctx);
|
||||
let step = sizet_model.const_1(generator, ctx.ctx);
|
||||
|
||||
gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, hooks, i| {
|
||||
let pelement = self.get_nth_pointer(generator, ctx, i, "element");
|
||||
body(generator, ctx, hooks, i, pelement)
|
||||
})
|
||||
}
|
||||
|
||||
/// Iterate through every scalar in this ndarray.
|
||||
/// `body` also access to [`BreakContinueHooks`] to short-circuit.
|
||||
pub fn foreach<'a, G, F>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
@ -589,15 +561,18 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
NDIterHandle<'ctx>,
|
||||
) -> Result<(), String>,
|
||||
{
|
||||
self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
|
||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||
let scalar = ScalarObject { dtype: self.dtype, value };
|
||||
body(generator, ctx, hooks, i, scalar)
|
||||
})
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
Some("ndarray_foreach"),
|
||||
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|
||||
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|
||||
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|
||||
|generator, ctx, nditer| Ok(nditer.next(generator, ctx)),
|
||||
)
|
||||
}
|
||||
|
||||
/// Make sure the ndarray is at least `ndmin`-dimensional.
|
||||
@ -632,8 +607,9 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
fill_value: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
self.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, _i, pelement| {
|
||||
ctx.builder.build_store(pelement, fill_value).unwrap();
|
||||
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||
let p = nditer.get_pointer(generator, ctx);
|
||||
ctx.builder.build_store(p, fill_value).unwrap();
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
|
83
nac3core/src/codegen/object/ndarray/nditer.rs
Normal file
83
nac3core/src/codegen/object/ndarray/nditer.rs
Normal file
@ -0,0 +1,83 @@
|
||||
use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
|
||||
|
||||
use crate::codegen::{
|
||||
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
|
||||
model::*,
|
||||
object::ndarray::scalar::ScalarObject,
|
||||
structure::NDIter,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
use super::NDArrayObject;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NDIterHandle<'ctx> {
|
||||
ndarray: NDArrayObject<'ctx>,
|
||||
instance: Ptr<'ctx, StructModel<NDIter>>,
|
||||
indices: Ptr<'ctx, IntModel<SizeT>>,
|
||||
}
|
||||
|
||||
impl<'ctx> NDIterHandle<'ctx> {
|
||||
pub fn new<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayObject<'ctx>,
|
||||
) -> Self {
|
||||
let nditer = StructModel(NDIter).alloca(generator, ctx, "nditer");
|
||||
let ndims = ndarray.get_ndims(generator, ctx.ctx);
|
||||
let indices = IntModel(SizeT).array_alloca(generator, ctx, ndims.value, "indices");
|
||||
call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices);
|
||||
NDIterHandle { ndarray, instance: nditer, indices }
|
||||
}
|
||||
|
||||
pub fn has_next<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Int<'ctx, Bool> {
|
||||
call_nac3_nditer_has_next(generator, ctx, self.instance)
|
||||
}
|
||||
|
||||
pub fn next<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) {
|
||||
call_nac3_nditer_next(generator, ctx, self.instance)
|
||||
}
|
||||
|
||||
pub fn get_pointer<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype);
|
||||
|
||||
let p = self.instance.get(generator, ctx, |f| f.element, "element");
|
||||
ctx.builder
|
||||
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element")
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn get_scalar<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> ScalarObject<'ctx> {
|
||||
let p = self.get_pointer(generator, ctx);
|
||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||
ScalarObject { dtype: self.ndarray.dtype, value }
|
||||
}
|
||||
|
||||
pub fn get_index<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Int<'ctx, SizeT> {
|
||||
self.instance.get(generator, ctx, |f| f.nth, "index")
|
||||
}
|
||||
|
||||
pub fn get_indices(&self) -> Ptr<'ctx, IntModel<SizeT>> {
|
||||
self.indices
|
||||
}
|
||||
}
|
@ -187,28 +187,36 @@ impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for SimpleNDArray<Item> {
|
||||
}
|
||||
|
||||
/// An IRRT helper structure used when iterating through an ndarray.
|
||||
/// Fields of [`IndicesIter`]
|
||||
pub struct IndicesIterFields<'ctx, F: FieldTraversal<'ctx>> {
|
||||
/// Fields of [`NDIter`]
|
||||
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
|
||||
pub ndims: F::Out<IntModel<SizeT>>,
|
||||
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
|
||||
pub strides: F::Out<PtrModel<IntModel<SizeT>>>,
|
||||
|
||||
pub indices: F::Out<PtrModel<IntModel<SizeT>>>,
|
||||
pub size: F::Out<IntModel<SizeT>>,
|
||||
pub nth: F::Out<IntModel<SizeT>>,
|
||||
pub element: F::Out<PtrModel<IntModel<Byte>>>,
|
||||
|
||||
pub size: F::Out<IntModel<SizeT>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct IndicesIter;
|
||||
pub struct NDIter;
|
||||
|
||||
impl<'ctx> StructKind<'ctx> for IndicesIter {
|
||||
type Fields<F: FieldTraversal<'ctx>> = IndicesIterFields<'ctx, F>;
|
||||
impl<'ctx> StructKind<'ctx> for NDIter {
|
||||
type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>;
|
||||
|
||||
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
||||
Self::Fields {
|
||||
ndims: traversal.add_auto("ndims"),
|
||||
shape: traversal.add_auto("shape"),
|
||||
strides: traversal.add_auto("strides"),
|
||||
|
||||
indices: traversal.add_auto("indices"),
|
||||
size: traversal.add_auto("size"),
|
||||
nth: traversal.add_auto("nth"),
|
||||
element: traversal.add_auto("element"),
|
||||
|
||||
size: traversal.add_auto("size"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,9 @@
|
||||
use std::iter::once;
|
||||
|
||||
use helper::{create_ndims, debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
|
||||
use helper::{
|
||||
create_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields,
|
||||
PrimDefDetails,
|
||||
};
|
||||
use indexmap::IndexMap;
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
@ -9,6 +12,7 @@ use inkwell::{
|
||||
IntPredicate,
|
||||
};
|
||||
use itertools::Either;
|
||||
use numpy::unpack_ndarray_var_tys;
|
||||
use strum::IntoEnumIterator;
|
||||
|
||||
use crate::{
|
||||
@ -16,15 +20,17 @@ use crate::{
|
||||
builtin_fns,
|
||||
classes::{ProxyValue, RangeValue},
|
||||
extern_fns::{self, call_np_linalg_det, call_np_linalg_matrix_power},
|
||||
irrt, llvm_intrinsics,
|
||||
model::{IntModel, SizeT},
|
||||
irrt::{self, call_nac3_ndarray_util_assert_shape_no_negative},
|
||||
llvm_intrinsics,
|
||||
model::*,
|
||||
numpy::*,
|
||||
numpy_new::{self, gen_ndarray_transpose},
|
||||
numpy_new::{self},
|
||||
object::{
|
||||
ndarray::{
|
||||
functions::{FloorOrCeil, MinOrMax},
|
||||
nalgebra::perform_nalgebra_call,
|
||||
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
|
||||
shape_util::parse_numpy_int_sequence,
|
||||
NDArrayObject,
|
||||
},
|
||||
tuple::TupleObject,
|
||||
@ -1104,7 +1110,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
generator,
|
||||
ctx,
|
||||
ret_dtype,
|
||||
|generator, ctx, _i, scalar| {
|
||||
|generator, ctx, scalar| {
|
||||
let result = match prim {
|
||||
PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx),
|
||||
PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx),
|
||||
@ -1175,9 +1181,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
generator,
|
||||
ctx,
|
||||
ret_int_dtype,
|
||||
|generator, ctx, _i, scalar| {
|
||||
Ok(scalar.round(generator, ctx, ret_int_dtype).value)
|
||||
},
|
||||
|generator, ctx, scalar| Ok(scalar.round(generator, ctx, ret_int_dtype).value),
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
}),
|
||||
@ -1241,7 +1245,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
generator,
|
||||
ctx,
|
||||
int_sized,
|
||||
|generator, ctx, _i, scalar| {
|
||||
|generator, ctx, scalar| {
|
||||
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
|
||||
},
|
||||
)?;
|
||||
@ -1291,13 +1295,24 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
self.ndarray_float,
|
||||
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||
Box::new(move |ctx, obj, fun, args, generator| {
|
||||
// Parse argument `shape`.
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape_arg };
|
||||
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => numpy_new::gen_ndarray_empty,
|
||||
PrimDef::FunNpZeros => numpy_new::gen_ndarray_zeros,
|
||||
PrimDef::FunNpOnes => numpy_new::gen_ndarray_ones,
|
||||
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => NDArrayObject::from_np_empty,
|
||||
PrimDef::FunNpZeros => NDArrayObject::from_np_zero,
|
||||
PrimDef::FunNpOnes => NDArrayObject::from_np_ones,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
||||
|
||||
let ndarray = func(generator, ctx, dtype, ndims, shape);
|
||||
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
@ -1356,8 +1371,47 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, obj, fun, args, generator| {
|
||||
numpy_new::gen_ndarray_array(ctx, &obj, fun, &args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
assert!(obj.is_none());
|
||||
assert!(matches!(args.len(), 1..=3));
|
||||
|
||||
let object_ty = fun.0.args[0].ty;
|
||||
let object =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
|
||||
let object = AnyObject { ty: object_ty, value: object };
|
||||
|
||||
let copy_arg = if let Some(arg) = args
|
||||
.iter()
|
||||
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
||||
{
|
||||
let copy_ty = fun.0.args[1].ty;
|
||||
arg.1.clone().to_basic_value_enum(ctx, generator, copy_ty)?
|
||||
} else {
|
||||
ctx.gen_symbol_val(
|
||||
generator,
|
||||
fun.0.args[1].default_value.as_ref().unwrap(),
|
||||
fun.0.args[1].ty,
|
||||
)
|
||||
};
|
||||
|
||||
// The argument `ndmin` is completely ignored. We don't need to know its LLVM value.
|
||||
// We simply make the output ndarray's ndims correct with `atleast_nd`.
|
||||
|
||||
let (dtype, ndims) =
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let output_ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
|
||||
let copy =
|
||||
IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
|
||||
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
|
||||
|
||||
let ndarray =
|
||||
NDArrayObject::from_np_array(generator, ctx, object, copy);
|
||||
debug_assert!(ndarray.ndims <= output_ndims); // Sanity check on `ndims`
|
||||
|
||||
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
|
||||
debug_assert!(ctx.unifier.unioned(ndarray.dtype, dtype)); // Sanity check on `dtype`
|
||||
|
||||
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
@ -1375,8 +1429,31 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
// type variable
|
||||
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
||||
Box::new(move |ctx, obj, fun, args, generator| {
|
||||
numpy_new::gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 2);
|
||||
|
||||
// Parse argument #1 shape
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||
|
||||
// Parse argument #2 fill_value
|
||||
let fill_value_ty = fun.0.args[1].ty;
|
||||
let fill_value =
|
||||
args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
|
||||
let fill_value = ScalarObject { dtype: fill_value_ty, value: fill_value };
|
||||
|
||||
// Implementation
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
|
||||
let ndarray = NDArrayObject::from_np_full(
|
||||
generator, ctx, dtype, ndims, shape, fill_value,
|
||||
);
|
||||
|
||||
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
@ -1483,12 +1560,38 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"),
|
||||
],
|
||||
Box::new(move |ctx, obj, fun, args, generator| {
|
||||
let f = match prim {
|
||||
PrimDef::FunNpBroadcastTo => numpy_new::gen_ndarray_broadcast_to,
|
||||
PrimDef::FunNpReshape => numpy_new::gen_ndarray_reshape,
|
||||
// Parse argument #1 ndarray
|
||||
let input_ty = fun.0.args[0].ty;
|
||||
let input =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
|
||||
let input = AnyObject { ty: input_ty, value: input };
|
||||
|
||||
// Parse argument #2 shape
|
||||
let shape_ty = fun.0.args[1].ty;
|
||||
let shape =
|
||||
args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||
let (_, return_shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||
|
||||
// Implementation
|
||||
// Turn any input to an ndarray for a convenient implementation.
|
||||
let ndarray = split_scalar_or_ndarray(generator, ctx, input)
|
||||
.as_ndarray(generator, ctx);
|
||||
|
||||
let (_, return_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let return_ndims = extract_ndims(&ctx.unifier, return_ndims);
|
||||
|
||||
let result = match prim {
|
||||
PrimDef::FunNpBroadcastTo => {
|
||||
ndarray.broadcast_to(generator, ctx, return_ndims, return_shape)
|
||||
}
|
||||
PrimDef::FunNpReshape => {
|
||||
ndarray.reshape_or_copy(generator, ctx, return_ndims, return_shape)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
f(ctx, &obj, fun, &args, generator).map(Some)
|
||||
|
||||
Ok(Some(result.instance.value.as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
@ -1529,7 +1632,39 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_transpose(ctx, &obj, fun, &args, generator).map(Some)
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
// Parse argument #1 ndarray
|
||||
let ndarray_ty = fun.0.args[0].ty;
|
||||
let ndarray = args[0]
|
||||
.1
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||
|
||||
// Implementation
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let has_axes = args.len() >= 2;
|
||||
let transposed_ndarray = if has_axes {
|
||||
// Parse argument #2 axes
|
||||
let in_axes_ty = fun.0.args[1].ty;
|
||||
let in_axes = args[1]
|
||||
.1
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, in_axes_ty)?;
|
||||
let in_axes = AnyObject { ty: in_axes_ty, value: in_axes };
|
||||
|
||||
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes);
|
||||
|
||||
ndarray.transpose(generator, ctx, Some(axes))
|
||||
} else {
|
||||
// axes is not given
|
||||
ndarray.transpose(generator, ctx, None)
|
||||
};
|
||||
|
||||
Ok(Some(transposed_ndarray.instance.value.as_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
@ -1641,7 +1776,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
move |_generator, ctx, _i, scalar| {
|
||||
move |_generator, ctx, scalar| {
|
||||
let result = scalar.np_floor_or_ceil(ctx, kind);
|
||||
Ok(result.value)
|
||||
},
|
||||
@ -1670,7 +1805,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
|_generator, ctx, _i, scalar| {
|
||||
|_generator, ctx, scalar| {
|
||||
let result = scalar.np_round(ctx);
|
||||
Ok(result.value)
|
||||
},
|
||||
@ -1883,7 +2018,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
ctx,
|
||||
&[x1, x2],
|
||||
common_ty,
|
||||
|_generator, ctx, _i, scalars| {
|
||||
|_generator, ctx, scalars| {
|
||||
let x1 = scalars[0];
|
||||
let x2 = scalars[1];
|
||||
|
||||
@ -1930,7 +2065,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
generator,
|
||||
ctx,
|
||||
num_ty.ty,
|
||||
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value),
|
||||
|_generator, ctx, scalar| Ok(scalar.abs(ctx).value),
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
},
|
||||
@ -1966,7 +2101,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.bool,
|
||||
|generator, ctx, _i, scalar| {
|
||||
|generator, ctx, scalar| {
|
||||
let n = scalar.into_float64(ctx);
|
||||
let n = function(generator, ctx, n);
|
||||
Ok(n.as_basic_value_enum())
|
||||
@ -2035,7 +2170,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
|_generator, ctx, _i, scalar| {
|
||||
|_generator, ctx, scalar| {
|
||||
let n = scalar.into_float64(ctx);
|
||||
let n = match prim {
|
||||
PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None),
|
||||
@ -2160,7 +2295,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
ctx,
|
||||
&[x1, x2],
|
||||
ret_dtype,
|
||||
|_generator, ctx, _i, scalars| {
|
||||
|_generator, ctx, scalars| {
|
||||
let x1 = scalars[0];
|
||||
let x2 = scalars[1];
|
||||
|
||||
@ -2406,7 +2541,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
);
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||
x2_ndarray.set_nth(
|
||||
x2_ndarray.set_nth_scalar(
|
||||
generator,
|
||||
ctx,
|
||||
zero,
|
||||
@ -2451,7 +2586,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||
let determinant = out.get_nth(generator, ctx, zero);
|
||||
let determinant = out.get_nth_scalar(generator, ctx, zero);
|
||||
|
||||
Ok(Some(determinant.value))
|
||||
}),
|
||||
|
Loading…
Reference in New Issue
Block a user