Compare commits

...

7 Commits

Author SHA1 Message Date
lyken 916a2b4993
core/ndstrides: implement np_reshape() 2024-08-30 14:41:54 +08:00
lyken c7c3cc21a8
core: categorize np_{transpose,reshape} as 'view functions' 2024-08-30 14:41:54 +08:00
lyken d2072d9248
core/ndstrides: implement np_size() 2024-08-30 14:41:00 +08:00
lyken be19165ead
core/ndstrides: implement np_shape() and np_strides()
These functions are not important, but they are handy for debugging.

`np.strides()` is not an actual NumPy function, but `ndarray.strides` is used.
2024-08-30 14:41:00 +08:00
lyken ee58cf3fc3
core/ndstrides: implement ndarray.fill() and .copy() 2024-08-30 14:41:00 +08:00
lyken 8fe8ccf200
core/ndstrides: implement np_identity() and np_eye() 2024-08-30 14:41:00 +08:00
lyken d222236492
core/ndstrides: implement np_array()
It also checks for inconsistent dimensions if the input is a list.
e.g., rejecting `[[1.0, 2.0], [3.0]]`.

However, currently only `np_array(<input>, copy=False)` and `np_array(<input>, copy=True)` are supported. In NumPy, copy could be false, true, or None. Right now, NAC3's `np_array(<input>, copy=False)` behaves like NumPy's `np.array(<input>, copy=None)`.
2024-08-30 14:40:15 +08:00
15 changed files with 942 additions and 466 deletions

View File

@ -8,4 +8,6 @@
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/ndarray/iter.hpp"
#include "irrt/ndarray/indexing.hpp"
#include "irrt/ndarray/indexing.hpp"
#include "irrt/ndarray/array.hpp"
#include "irrt/ndarray/reshape.hpp"

View File

@ -0,0 +1,134 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/list.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
namespace ndarray {
namespace array {
/**
* @brief In the context of `np.array(<list>)`, deduce the ndarray's shape produced by `<list>` and raise
* an exception if there is anything wrong with `<shape>` (e.g., inconsistent dimensions `np.array([[1.0, 2.0],
* [3.0]])`)
*
* If this function finds no issues with `<list>`, the deduced shape is written to `shape`. The caller has the
* responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because
* of implementation details.
*/
template<typename SizeT>
void set_and_validate_list_shape_helper(SizeT axis, List<SizeT>* list, SizeT ndims, SizeT* shape) {
if (shape[axis] == -1) {
// Dimension is unspecified. Set it.
shape[axis] = list->len;
} else {
// Dimension is specified. Check.
if (shape[axis] != list->len) {
// Mismatch, throw an error.
// NOTE: NumPy's error message is more complex and needs more PARAMS to display.
raise_exception(SizeT, EXN_VALUE_ERROR,
"The requested array has an inhomogenous shape "
"after {0} dimension(s).",
axis, shape[axis], list->len);
}
}
if (axis + 1 == ndims) {
// `list` has type `list[ItemType]`
// Do nothing
} else {
// `list` has type `list[list[...]]`
List<SizeT>** lists = (List<SizeT>**)(list->items);
for (SizeT i = 0; i < list->len; i++) {
set_and_validate_list_shape_helper<SizeT>(axis + 1, lists[i], ndims, shape);
}
}
}
/**
* @brief See `set_and_validate_list_shape_helper`.
*/
template<typename SizeT>
void set_and_validate_list_shape(List<SizeT>* list, SizeT ndims, SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
shape[axis] = -1; // Sentinel to say this dimension is unspecified.
}
set_and_validate_list_shape_helper<SizeT>(0, list, ndims, shape);
}
/**
* @brief In the context of `np.array(<list>)`, copied the contents stored in `list` to `ndarray`.
*
* `list` is assumed to be "legal". (i.e., no inconsistent dimensions)
*
* # Notes on `ndarray`
* The caller is responsible for allocating space for `ndarray`.
* Here is what this function expects from `ndarray` when called:
* - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values.
* - `ndarray->itemsize` has to be initialized.
* - `ndarray->ndims` has to be initialized.
* - `ndarray->shape` has to be initialized.
* - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous.
* When this function call ends:
* - `ndarray->data` is written with contents from `<list>`.
*/
template<typename SizeT>
void write_list_to_array_helper(SizeT axis, SizeT* index, List<SizeT>* list, NDArray<SizeT>* ndarray) {
debug_assert_eq(SizeT, list->len, ndarray->shape[axis]);
if (IRRT_DEBUG_ASSERT_BOOL) {
if (!ndarray::basic::is_c_contiguous(ndarray)) {
raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1],
NO_PARAM);
}
}
if (axis + 1 == ndarray->ndims) {
// `list` has type `list[scalar]`
// `ndarray` is contiguous, so we can do this, and this is fast.
uint8_t* dst = ndarray->data + (ndarray->itemsize * (*index));
__builtin_memcpy(dst, list->items, ndarray->itemsize * list->len);
*index += list->len;
} else {
// `list` has type `list[list[...]]`
List<SizeT>** lists = (List<SizeT>**)(list->items);
for (SizeT i = 0; i < list->len; i++) {
write_list_to_array_helper<SizeT>(axis + 1, index, lists[i], ndarray);
}
}
}
/**
* @brief See `write_list_to_array_helper`.
*/
template<typename SizeT>
void write_list_to_array(List<SizeT>* list, NDArray<SizeT>* ndarray) {
SizeT index = 0;
write_list_to_array_helper<SizeT>((SizeT)0, &index, list, ndarray);
}
} // namespace array
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::array;
void __nac3_ndarray_array_set_and_validate_list_shape(List<int32_t>* list, int32_t ndims, int32_t* shape) {
set_and_validate_list_shape(list, ndims, shape);
}
void __nac3_ndarray_array_set_and_validate_list_shape64(List<int64_t>* list, int64_t ndims, int64_t* shape) {
set_and_validate_list_shape(list, ndims, shape);
}
void __nac3_ndarray_array_write_list_to_array(List<int32_t>* list, NDArray<int32_t>* ndarray) {
write_list_to_array(list, ndarray);
}
void __nac3_ndarray_array_write_list_to_array64(List<int64_t>* list, NDArray<int64_t>* ndarray) {
write_list_to_array(list, ndarray);
}
}

View File

@ -0,0 +1,99 @@
#pragma once
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
namespace ndarray {
namespace reshape {
/**
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
*
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
* modified to contain the resolved dimension.
*
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
*
* @param size The `.size` of `<ndarray>`
* @param new_ndims Number of elements in `new_shape`
* @param new_shape Target shape to reshape to
*/
template<typename SizeT>
void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) {
// Is there a -1 in `new_shape`?
bool neg1_exists = false;
// Location of -1, only initialized if `neg1_exists` is true
SizeT neg1_axis_i;
// The computed ndarray size of `new_shape`
SizeT new_size = 1;
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
SizeT dim = new_shape[axis_i];
if (dim < 0) {
if (dim == -1) {
if (neg1_exists) {
// Multiple `-1` found. Throw an error.
raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM,
NO_PARAM, NO_PARAM);
} else {
neg1_exists = true;
neg1_axis_i = axis_i;
}
} else {
// TODO: What? In `np.reshape` any negative dimensions is
// treated like its `-1`.
//
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
//
// It is not documented by numpy.
// Throw an error for now...
raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i,
NO_PARAM);
}
} else {
new_size *= dim;
}
}
bool can_reshape;
if (neg1_exists) {
// Let `x` be the unknown dimension
// Solve `x * <new_size> = <size>`
if (new_size == 0 && size == 0) {
// `x` has infinitely many solutions
can_reshape = false;
} else if (new_size == 0 && size != 0) {
// `x` has no solutions
can_reshape = false;
} else if (size % new_size != 0) {
// `x` has no integer solutions
can_reshape = false;
} else {
can_reshape = true;
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
}
} else {
can_reshape = (new_size == size);
}
if (!can_reshape) {
raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM,
NO_PARAM);
}
}
} // namespace reshape
} // namespace ndarray
} // namespace
extern "C" {
void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
}

View File

@ -8,7 +8,10 @@ use super::{
llvm_intrinsics,
macros::codegen_unreachable,
model::*,
object::ndarray::{indexing::NDIndex, nditer::NDIter, NDArray},
object::{
list::List,
ndarray::{indexing::NDIndex, nditer::NDIter, NDArray},
},
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
@ -1129,3 +1132,47 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
.arg(dst_ndarray)
.returning_void();
}
pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>>,
ndims: Instance<'ctx, Int<SizeT>>,
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_array_set_and_validate_list_shape",
);
FnCall::builder(generator, ctx, &name).arg(list).arg(ndims).arg(shape).returning_void();
}
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_array_write_list_to_array",
);
FnCall::builder(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
}
pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: Instance<'ctx, Int<SizeT>>,
new_ndims: Instance<'ctx, Int<SizeT>>,
new_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_reshape_resolve_and_check_new_shape",
);
FnCall::builder(generator, ctx, &name).arg(size).arg(new_ndims).arg(new_shape).returning_void();
}

View File

@ -182,6 +182,15 @@ impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Item>> {
Ptr(new_item).pointer_cast(generator, ctx, self.value)
}
/// Cast this pointer to `uint8_t*`
pub fn cast_to_pi8<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Ptr<Int<Byte>>> {
Ptr(Int(Byte)).pointer_cast(generator, ctx, self.value)
}
/// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`].
pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int<Bool>> {
let value = ctx.builder.build_is_null(self.value, "").unwrap();

View File

@ -13,6 +13,7 @@ use crate::{
},
llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable,
model::*,
object::{
any::AnyObject,
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
@ -22,13 +23,13 @@ use crate::{
},
symbol_resolver::ValueEnum,
toplevel::{
helper::{extract_ndims, PrimDef},
helper::extract_ndims,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId,
},
typecheck::{
magic_methods::Binop,
typedef::{FunSignature, Type, TypeEnum},
typedef::{FunSignature, Type},
},
};
use inkwell::{
@ -1840,26 +1841,6 @@ pub fn gen_ndarray_array<'ctx>(
assert!(matches!(args.len(), 1..=3));
let obj_ty = fun.0.args[0].ty;
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let mut ty = *params.iter().next().unwrap().1;
while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty)
{
if *obj_id != PrimDef::List.id() {
break;
}
ty = *params.iter().next().unwrap().1;
}
ty
}
_ => obj_ty,
};
let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?;
let copy_arg = if let Some(arg) =
@ -1875,28 +1856,18 @@ pub fn gen_ndarray_array<'ctx>(
)
};
let ndmin_arg = if let Some(arg) =
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
{
let ndmin_ty = fun.0.args[2].ty;
arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)?
} else {
context.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
fun.0.args[2].ty,
)
};
// The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be
// the `ndims` of the function return type.
let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let ndims = extract_ndims(&context.unifier, ndims);
call_ndarray_array_impl(
generator,
context,
obj_elem_ty,
obj_arg,
copy_arg.into_int_value(),
ndmin_arg.into_int_value(),
)
.map(NDArrayValue::into)
let object = AnyObject { value: obj_arg, ty: obj_ty };
// NAC3 booleans are i8.
let copy = Int(Bool).truncate(generator, context, copy_arg.into_int_value());
let ndarray = NDArrayObject::make_np_array(generator, context, object, copy)
.atleast_nd(generator, context, ndims);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.eye`.
@ -1935,15 +1906,23 @@ pub fn gen_ndarray_eye<'ctx>(
))
}?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
nrows_arg.into_int_value(),
ncols_arg.into_int_value(),
offset_arg.into_int_value(),
)
.map(NDArrayValue::into)
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let nrows = Int(Int32)
.check_value(generator, context.ctx, nrows_arg)
.unwrap()
.s_extend_or_bit_cast(generator, context, SizeT);
let ncols = Int(Int32)
.check_value(generator, context.ctx, ncols_arg)
.unwrap()
.s_extend_or_bit_cast(generator, context, SizeT);
let offset = Int(Int32)
.check_value(generator, context.ctx, offset_arg)
.unwrap()
.s_extend_or_bit_cast(generator, context, SizeT);
let ndarray = NDArrayObject::make_np_eye(generator, context, dtype, nrows, ncols, offset);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.identity`.
@ -1957,20 +1936,15 @@ pub fn gen_ndarray_identity<'ctx>(
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let n_ty = fun.0.args[0].ty;
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
n_arg.into_int_value(),
n_arg.into_int_value(),
llvm_usize.const_zero(),
)
.map(NDArrayValue::into)
let n = Int(Int32).check_value(generator, context.ctx, n_arg).unwrap();
let n = n.s_extend_or_bit_cast(generator, context, SizeT);
let ndarray = NDArrayObject::make_np_identity(generator, context, dtype, n);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.copy`.
@ -1984,20 +1958,14 @@ pub fn gen_ndarray_copy<'ctx>(
assert!(obj.is_some());
assert!(args.is_empty());
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
ndarray_copy_impl(
generator,
context,
this_elem_ty,
NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
let this = AnyObject { value: this_arg, ty: this_ty };
let this = NDArrayObject::from_object(generator, context, this);
let ndarray = this.make_copy(generator, context);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.fill`.
@ -2011,48 +1979,15 @@ pub fn gen_ndarray_fill<'ctx>(
assert!(obj.is_some());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj
.as_ref()
.unwrap()
.1
.clone()
.to_basic_value_enum(context, generator, this_ty)?
.into_pointer_value();
let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
let value_ty = fun.0.args[0].ty;
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
ndarray_fill_flattened(
generator,
context,
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None),
|generator, ctx, _| {
let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
value_arg.into_pointer_value(),
value_arg.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg
} else {
codegen_unreachable!(ctx)
};
Ok(value)
},
)?;
let this = AnyObject { value: this_arg, ty: this_ty };
let this = NDArrayObject::from_object(generator, context, this);
this.fill(generator, context, value_arg);
Ok(())
}
@ -2163,293 +2098,6 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
}
}
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
///
/// * `x1` - `NDArray` to reshape.
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
///
/// Note that unlike other generating functions, one of the dimensions in the shape can be negative.
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
shape: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_reshape";
let (x1_ty, x1) = x1;
let (_, shape) = shape;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
// Check for -1 in dimensions
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_list.load_size(ctx, None), false),
|generator, ctx, _, idx| {
let ele =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
ele,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_neg_value =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_neg_value = ctx
.builder
.build_int_add(
num_neg_value,
llvm_usize.const_int(1, false),
"",
)
.unwrap();
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
Ok(None)
},
|_, ctx| {
let acc_value =
ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_value =
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
ctx.builder.build_store(acc, acc_value).unwrap();
Ok(None)
},
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
// Generate the output shape by filling -1 with `rem`
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
let ndims = shape_tuple.get_type().count_fields();
// Check for -1 in dims
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_negs =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_negs = ctx
.builder
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
.unwrap();
ctx.builder.build_store(num_neg, num_negs).unwrap();
Ok(None)
},
|_, ctx| {
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(None)
},
)?;
}
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize);
// Reconstruct shape filling negatives with rem
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
let dim = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
let shape_int = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
shape_int,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(n_sz)),
|_, ctx| {
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
},
)?
.unwrap()
.into_int_value();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => codegen_unreachable!(ctx),
}
.unwrap();
// Only allow one dimension to be negative
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"can only specify one unknown dimension",
[None, None, None],
ctx.current_loc,
);
// The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError",
"cannot reshape array of size {0} into provided shape of size {1}",
[Some(n_sz), Some(out_sz), None],
ctx.current_loc,
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul`

View File

@ -31,6 +31,17 @@ impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for List<Item> {
}
}
impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Struct<List<Item>>>> {
/// Cast the items pointer to `uint8_t*`.
pub fn with_pi8_items<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>> {
self.pointer_cast(generator, ctx, Struct(List { item: Int(Byte) }))
}
}
/// A NAC3 Python List object.
#[derive(Debug, Clone, Copy)]
pub struct ListObject<'ctx> {

View File

@ -0,0 +1,184 @@
use super::NDArrayObject;
use crate::{
codegen::{
irrt::{
call_nac3_ndarray_array_set_and_validate_list_shape,
call_nac3_ndarray_array_write_list_to_array,
},
model::*,
object::{any::AnyObject, list::ListObject},
stmt::gen_if_else_expr_callback,
CodeGenContext, CodeGenerator,
},
toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims},
typecheck::typedef::{Type, TypeEnum},
};
/// Get the expected `dtype` and `ndims` of the ndarray returned by `np_array(list)`.
fn get_list_object_dtype_and_ndims<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> (Type, u64) {
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list.item_type);
let ndims = arraylike_get_ndims(&mut ctx.unifier, list.item_type);
let ndims = ndims + 1; // To count `list` itself.
(dtype, ndims)
}
impl<'ctx> NDArrayObject<'ctx> {
/// Implementation of `np_array(<list>, copy=True)`
fn make_np_array_list_copy_true_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> Self {
let (dtype, ndims_int) = get_list_object_dtype_and_ndims(ctx, list);
let list_value = list.instance.with_pi8_items(generator, ctx);
// Validate `list` has a consistent shape.
// Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`.
// If `list` has a consistent shape, deduce the shape and write it to `shape`.
let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int, false);
let shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
call_nac3_ndarray_array_set_and_validate_list_shape(
generator, ctx, list_value, ndims, shape,
);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int);
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.create_data(generator, ctx);
// Copy all contents from the list.
call_nac3_ndarray_array_write_list_to_array(generator, ctx, list_value, ndarray.instance);
ndarray
}
/// Implementation of `np_array(<list>, copy=None)`
fn make_np_array_list_copy_none_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> Self {
// np_array without copying is only possible `list` is not nested.
//
// If `list` is `list[T]`, we can create an ndarray with `data` set
// to the array pointer of `list`.
//
// If `list` is `list[list[T]]` or worse, copy.
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
if ndims == 1 {
// `list` is not nested
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, 1);
// Set data
let data = list.instance.get(generator, ctx, |f| f.items).cast_to_pi8(generator, ctx);
ndarray.instance.set(ctx, |f| f.data, data);
// ndarray->shape[0] = list->len;
let shape = ndarray.instance.get(generator, ctx, |f| f.shape);
let list_len = list.instance.get(generator, ctx, |f| f.len);
shape.set_index_const(ctx, 0, list_len);
// Set strides, the `data` is contiguous
ndarray.set_strides_contiguous(generator, ctx);
ndarray
} else {
// `list` is nested, copy
NDArrayObject::make_np_array_list_copy_true_impl(generator, ctx, list)
}
}
/// Implementation of `np_array(<list>, copy=copy)`
fn make_np_array_list_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
copy: Instance<'ctx, Int<Bool>>,
) -> Self {
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
let ndarray = gen_if_else_expr_callback(
generator,
ctx,
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray =
NDArrayObject::make_np_array_list_copy_true_impl(generator, ctx, list);
Ok(Some(ndarray.instance.value))
},
|generator, ctx| {
let ndarray =
NDArrayObject::make_np_array_list_copy_none_impl(generator, ctx, list);
Ok(Some(ndarray.instance.value))
},
)
.unwrap()
.unwrap();
NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims)
}
/// Implementation of `np_array(<ndarray>, copy=copy)`.
pub fn make_np_array_ndarray_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
copy: Instance<'ctx, Int<Bool>>,
) -> Self {
let ndarray_val = gen_if_else_expr_callback(
generator,
ctx,
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray = ndarray.make_copy(generator, ctx); // Force copy
Ok(Some(ndarray.instance.value))
},
|_generator, _ctx| {
// No need to copy. Return `ndarray` itself.
Ok(Some(ndarray.instance.value))
},
)
.unwrap()
.unwrap();
NDArrayObject::from_value_and_unpacked_types(
generator,
ctx,
ndarray_val,
ndarray.dtype,
ndarray.ndims,
)
}
/// Create a new ndarray like `np.array()`.
///
/// NOTE: The `ndmin` argument is not here. You may want to
/// do [`NDArrayObject::atleast_nd`] to achieve that.
pub fn make_np_array<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
copy: Instance<'ctx, Int<Bool>>,
) -> Self {
match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let list = ListObject::from_object(generator, ctx, object);
NDArrayObject::make_np_array_list_impl(generator, ctx, list, copy)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_object(generator, ctx, object);
NDArrayObject::make_np_array_ndarray_impl(generator, ctx, ndarray, copy)
}
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this
}
}
}

View File

@ -1,4 +1,4 @@
use inkwell::values::BasicValueEnum;
use inkwell::{values::BasicValueEnum, IntPredicate};
use crate::{
codegen::{
@ -123,4 +123,54 @@ impl<'ctx> NDArrayObject<'ctx> {
let fill_value = ndarray_one_value(generator, ctx, dtype);
NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}
/// Create an ndarray like `np.eye`.
pub fn make_np_eye<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
nrows: Instance<'ctx, Int<SizeT>>,
ncols: Instance<'ctx, Int<SizeT>>,
offset: Instance<'ctx, Int<SizeT>>,
) -> Self {
let ndzero = ndarray_zero_value(generator, ctx, dtype);
let ndone = ndarray_one_value(generator, ctx, dtype);
let ndarray = NDArrayObject::alloca_dynamic_shape(generator, ctx, dtype, &[nrows, ncols]);
// Create data and make the matrix like look np.eye()
ndarray.create_data(generator, ctx);
ndarray
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
// and this loop would not execute.
// Load up `row_i` and `col_i` from indices.
let row_i = nditer.get_indices().get_index_const(generator, ctx, 0);
let col_i = nditer.get_indices().get_index_const(generator, ctx, 1);
let be_one = row_i.add(ctx, offset).compare(ctx, IntPredicate::EQ, col_i);
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap();
Ok(())
})
.unwrap();
ndarray
}
/// Create an ndarray like `np.identity`.
pub fn make_np_identity<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
size: Instance<'ctx, Int<SizeT>>,
) -> Self {
// Convenient implementation
let offset = Int(SizeT).const_0(generator, ctx.ctx);
NDArrayObject::make_np_eye(generator, ctx, dtype, size, size, offset)
}
}

View File

@ -1,3 +1,4 @@
pub mod array;
pub mod factory;
pub mod indexing;
pub mod nditer;
@ -26,7 +27,7 @@ use crate::{
typecheck::typedef::Type,
};
use super::any::AnyObject;
use super::{any::AnyObject, tuple::TupleObject};
/// Fields of [`NDArray`]
pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> {
@ -74,8 +75,19 @@ impl<'ctx> NDArrayObject<'ctx> {
) -> NDArrayObject<'ctx> {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims)
}
let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, object.value).unwrap();
/// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's
/// `dtype` and `ndims`.
pub fn from_value_and_unpacked_types<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
value: V,
dtype: Type,
ndims: u64,
) -> Self {
let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, value).unwrap();
NDArrayObject { dtype, ndims, instance: value }
}
@ -406,6 +418,8 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) {
// TODO: It is possible to optimize this by exploiting contiguous strides with memset.
// Probably best to implement in IRRT.
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap();
@ -413,6 +427,62 @@ impl<'ctx> NDArrayObject<'ctx> {
})
.unwrap();
}
/// Create the shape tuple of this ndarray like `np.shape(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_shape_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Return a tuple of SizeT
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.shape)
.get_index_const(generator, ctx, i64::try_from(i).unwrap())
.truncate_or_bit_cast(generator, ctx, Int32);
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::from_objects(generator, ctx, objects)
}
/// Create the strides tuple of this ndarray like `<ndarray>.strides`.
///
/// The returned integers in the tuple are in int32.
pub fn make_strides_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Return a tuple of SizeT.
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.strides)
.get_index_const(generator, ctx, i64::try_from(i).unwrap())
.truncate_or_bit_cast(generator, ctx, Int32);
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::from_objects(generator, ctx, objects)
}
}
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.

View File

@ -1,4 +1,7 @@
use crate::codegen::{CodeGenContext, CodeGenerator};
use crate::codegen::{
irrt::call_nac3_ndarray_reshape_resolve_and_check_new_shape, model::*, CodeGenContext,
CodeGenerator,
};
use super::{indexing::RustNDIndex, NDArrayObject};
@ -26,4 +29,61 @@ impl<'ctx> NDArrayObject<'ctx> {
*self
}
}
/// Create a reshaped view on this ndarray like `np.reshape()`.
///
/// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be modified as a result.
///
/// If reshape without copying is impossible, this function will allocate a new ndarray and copy contents.
///
/// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`].
/// * `new_shape` - The target shape to do `np.reshape()`.
#[must_use]
pub fn reshape_or_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
new_ndims: u64,
new_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) -> Self {
// TODO: The current criterion for whether to do a full copy or not is by checking `is_c_contiguous`,
// but this is not optimal - there are cases when the ndarray is not contiguous but could be reshaped
// without copying data. Look into how numpy does it.
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims);
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
// Reolsve negative indices
let size = self.size(generator, ctx);
let dst_ndims = dst_ndarray.ndims_llvm(generator, ctx.ctx);
let dst_shape = dst_ndarray.instance.get(generator, ctx, |f| f.shape);
call_nac3_ndarray_reshape_resolve_and_check_new_shape(
generator, ctx, size, dst_ndims, dst_shape,
);
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
// Inserting into then_bb: reshape is possible without copying
ctx.builder.position_at_end(then_bb);
dst_ndarray.set_strides_contiguous(generator, ctx);
dst_ndarray.instance.set(ctx, |f| f.data, self.instance.get(generator, ctx, |f| f.data));
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Inserting into else_bb: reshape is impossible without copying
ctx.builder.position_at_end(else_bb);
dst_ndarray.create_data(generator, ctx);
dst_ndarray.copy_data_from(generator, ctx, *self);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition for continuation
ctx.builder.position_at_end(end_bb);
dst_ndarray
}
}

View File

@ -1,6 +1,6 @@
use std::iter::once;
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
use helper::{debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDefDetails};
use indexmap::IndexMap;
use inkwell::{
attributes::{Attribute, AttributeLoc},
@ -9,13 +9,19 @@ use inkwell::{
IntPredicate,
};
use itertools::Either;
use numpy::unpack_ndarray_var_tys;
use strum::IntoEnumIterator;
use crate::{
codegen::{
builtin_fns,
classes::{ProxyValue, RangeValue},
model::*,
numpy::*,
object::{
any::AnyObject,
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
},
stmt::exn_constructor,
},
symbol_resolver::SymbolValue,
@ -511,6 +517,14 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpEye
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => {
self.build_ndarray_property_getter_function(prim)
}
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
self.build_ndarray_view_function(prim)
}
PrimDef::FunStr => self.build_str_function(),
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
@ -576,10 +590,6 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
self.build_np_sp_ndarray_function(prim)
}
PrimDef::FunNpDot
| PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr
@ -1385,6 +1395,149 @@ impl<'a> BuiltinBuilder<'a> {
}
}
fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides],
);
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.primitives.ndarray],
Some("T".into()),
None,
);
match prim {
PrimDef::FunNpSize => create_fn_by_codegen(
self.unifier,
&into_var_map([in_ndarray_ty]),
prim.name(),
self.primitives.int32,
&[(in_ndarray_ty.ty, "a")],
Box::new(|ctx, obj, fun, args, generator| {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let ndarray_ty = fun.0.args[0].ty;
let ndarray =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let size =
ndarray.size(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32);
Ok(Some(size.value.as_basic_value_enum()))
}),
),
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
// The function signatures of `np_shape` an `np_size` are the same.
// Mixed together for convenience.
// The return type is a tuple of variable length depending on the ndims of the input ndarray.
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding
create_fn_by_codegen(
self.unifier,
&into_var_map([in_ndarray_ty]),
prim.name(),
ret_ty,
&[(in_ndarray_ty.ty, "a")],
Box::new(move |ctx, obj, fun, args, generator| {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let ndarray_ty = fun.0.args[0].ty;
let ndarray =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let result_tuple = match prim {
PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx),
PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx),
_ => unreachable!(),
};
Ok(Some(result_tuple.value.as_basic_value_enum()))
}),
)
}
_ => unreachable!(),
}
}
/// Build np/sp functions that take as input `NDArray` only
fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.primitives.ndarray],
Some("T".into()),
None,
);
match prim {
PrimDef::FunNpTranspose => create_fn_by_codegen(
self.unifier,
&into_var_map([in_ndarray_ty]),
prim.name(),
in_ndarray_ty.ty,
&[(in_ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
}),
),
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.
//
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
PrimDef::FunNpReshape => {
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[
(in_ndarray_ty.ty, "x"),
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), // Handled by special folding
],
Box::new(move |ctx, _, fun, args, generator| {
let ndarray_ty = fun.0.args[0].ty;
let ndarray_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let shape_ty = fun.0.args[1].ty;
let shape_val =
args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let ndarray = AnyObject { value: ndarray_val, ty: ndarray_ty };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let shape = AnyObject { value: shape_val, ty: shape_ty };
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
// The ndims after reshaping is gotten from the return type of the call.
let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndims = extract_ndims(&ctx.unifier, ndims);
let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, shape);
Ok(Some(new_ndarray.instance.value.as_basic_value_enum()))
}),
)
}
_ => unreachable!(),
}
}
/// Build the `str()` function.
fn build_str_function(&mut self) -> TopLevelDef {
let prim = PrimDef::FunStr;
@ -1872,57 +2025,6 @@ impl<'a> BuiltinBuilder<'a> {
}
}
/// Build np/sp functions that take as input `NDArray` only
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
match prim {
PrimDef::FunNpTranspose => {
let ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.ndarray_num_ty],
Some("T".into()),
None,
);
create_fn_by_codegen(
self.unifier,
&into_var_map([ndarray_ty]),
prim.name(),
ndarray_ty.ty,
&[(ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
}),
)
}
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.
//
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
PrimDef::FunNpReshape => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_num_ty,
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
),
_ => unreachable!(),
}
}
/// Build `np_linalg` and `sp_linalg` functions
///
/// The input to these functions must be floating point `NDArray`

View File

@ -53,6 +53,15 @@ pub enum PrimDef {
FunNpEye,
FunNpIdentity,
// NumPy ndarray property getters
FunNpSize,
FunNpShape,
FunNpStrides,
// NumPy ndarray view functions
FunNpTranspose,
FunNpReshape,
// Miscellaneous NumPy & SciPy functions
FunNpRound,
FunNpFloor,
@ -100,8 +109,6 @@ pub enum PrimDef {
FunNpLdExp,
FunNpHypot,
FunNpNextAfter,
FunNpTranspose,
FunNpReshape,
// Linalg functions
FunNpDot,
@ -239,6 +246,15 @@ impl PrimDef {
PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None),
// NumPy NDArray property getters,
PrimDef::FunNpSize => fun("np_size", None),
PrimDef::FunNpShape => fun("np_shape", None),
PrimDef::FunNpStrides => fun("np_strides", None),
// NumPy NDArray view functions
PrimDef::FunNpTranspose => fun("np_transpose", None),
PrimDef::FunNpReshape => fun("np_reshape", None),
// Miscellaneous NumPy & SciPy functions
PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunNpFloor => fun("np_floor", None),
@ -286,8 +302,6 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpTranspose => fun("np_transpose", None),
PrimDef::FunNpReshape => fun("np_reshape", None),
// Linalg functions
PrimDef::FunNpDot => fun("np_dot", None),

View File

@ -1,7 +1,7 @@
use std::cmp::max;
use std::collections::{HashMap, HashSet};
use std::convert::{From, TryInto};
use std::iter::once;
use std::iter::{self, once};
use std::{cell::RefCell, sync::Arc};
use super::{
@ -1183,6 +1183,45 @@ impl<'a> Inferencer<'a> {
}));
}
if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 {
let ndarray = self.fold_expr(args.remove(0))?;
let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
// Make a tuple of size `ndims` full of int32 (TODO: Make it usize)
let ret_ty = TypeEnum::TTuple {
ty: iter::repeat(self.primitives.int32).take(ndims as usize).collect_vec(),
is_vararg_ctx: false,
};
let ret_ty = self.unifier.add_ty(ret_ty);
let func_ty = TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "a".into(),
default_value: None,
ty: ndarray.custom.unwrap(),
is_vararg: false,
}],
ret: ret_ty,
vars: VarMap::new(),
});
let func_ty = self.unifier.add_ty(func_ty);
return Ok(Some(Located {
location,
custom: Some(ret_ty),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(func_ty),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![ndarray],
keywords: vec![],
},
}));
}
if id == &"np_dot".into() {
let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?;

View File

@ -179,6 +179,15 @@ def patch(module):
module.np_identity = np.identity
module.np_array = np.array
# NumPy NDArray view functions
module.np_transpose = np.transpose
module.np_reshape = np.reshape
# NumPy NDArray property getters
module.np_size = np.size
module.np_shape = np.shape
module.np_strides = lambda ndarray: ndarray.strides
# NumPy Math functions
module.np_isnan = np.isnan
module.np_isinf = np.isinf
@ -218,8 +227,6 @@ def patch(module):
module.np_ldexp = np.ldexp
module.np_hypot = np.hypot
module.np_nextafter = np.nextafter
module.np_transpose = np.transpose
module.np_reshape = np.reshape
# SciPy Math functions
module.sp_spec_erf = special.erf