forked from M-Labs/nac3
core: irrt proper ndarray subscript & more
Details: - improve irrt model - len() on ndarrays
This commit is contained in:
parent
29734ce3af
commit
0946bd86ea
@ -18,6 +18,7 @@ namespace {
|
||||
ErrorId value_error;
|
||||
ErrorId assertion_error;
|
||||
ErrorId runtime_error;
|
||||
ErrorId type_error;
|
||||
};
|
||||
|
||||
struct ErrorContext {
|
||||
|
@ -111,6 +111,21 @@ namespace { namespace ndarray { namespace basic {
|
||||
void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement, const uint8_t* pvalue) {
|
||||
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
|
||||
}
|
||||
|
||||
template <typename SizeT>
|
||||
void len(ErrorContext* errctx, NDArray<SizeT>* ndarray, SliceIndex* dst_length) {
|
||||
// Error if the ndarray is "unsized" (i.e, ndims == 0)
|
||||
if (ndarray->ndims == 0) {
|
||||
// Error copied from python by doing `len(np.zeros(()))`
|
||||
errctx->set_error(
|
||||
errctx->error_ids->type_error,
|
||||
"len() of unsized object"
|
||||
);
|
||||
return; // Terminate
|
||||
}
|
||||
|
||||
*dst_length = (SliceIndex) ndarray->shape[0];
|
||||
}
|
||||
} } }
|
||||
|
||||
extern "C" {
|
||||
@ -132,6 +147,14 @@ extern "C" {
|
||||
return nbytes(ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_len(ErrorContext* errctx, NDArray<int32_t>* ndarray, SliceIndex* dst_len) {
|
||||
return len(errctx, ndarray, dst_len);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_len64(ErrorContext* errctx, NDArray<int64_t>* ndarray, SliceIndex* dst_len) {
|
||||
return len(errctx, ndarray, dst_len);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx, int32_t ndims, int32_t* shape) {
|
||||
util::assert_shape_no_negative(errctx, ndims, shape);
|
||||
}
|
||||
|
@ -6,6 +6,8 @@
|
||||
#include <irrt/error_context.hpp>
|
||||
|
||||
namespace {
|
||||
typedef uint32_t NumNDSubscriptsType;
|
||||
|
||||
typedef uint8_t NDSubscriptType;
|
||||
|
||||
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0;
|
||||
@ -30,16 +32,25 @@ namespace {
|
||||
namespace { namespace ndarray { namespace subscript {
|
||||
namespace util {
|
||||
template<typename SizeT>
|
||||
SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_subscripts, const NDSubscript* subscripts) {
|
||||
irrt_assert(num_subscripts <= ndims);
|
||||
void deduce_ndims_after_slicing(ErrorContext* errctx, SizeT* result, SizeT ndims, SizeT num_ndsubscripts, const NDSubscript* ndsubscripts) {
|
||||
if (num_ndsubscripts > ndims) {
|
||||
// Error copied from python by doing `np.zeros((3, 4))[:, :, :]`
|
||||
errctx->set_error(
|
||||
errctx->error_ids->index_error,
|
||||
"too many indices for array: array is {0}-dimensional, but {1} were indexed",
|
||||
ndims, num_ndsubscripts
|
||||
);
|
||||
return; // Terminate
|
||||
}
|
||||
|
||||
SizeT final_ndims = ndims;
|
||||
for (SizeT i = 0; i < num_subscripts; i++) {
|
||||
if (subscripts[i].type == INPUT_SUBSCRIPT_TYPE_INDEX) {
|
||||
for (SizeT i = 0; i < num_ndsubscripts; i++) {
|
||||
if (ndsubscripts[i].type == INPUT_SUBSCRIPT_TYPE_INDEX) {
|
||||
final_ndims--; // An index demotes the rank by 1
|
||||
}
|
||||
}
|
||||
return final_ndims;
|
||||
|
||||
*result = final_ndims;
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,7 +72,7 @@ namespace { namespace ndarray { namespace subscript {
|
||||
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize`
|
||||
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
|
||||
template <typename SizeT>
|
||||
void subscript(ErrorContext* errctx, SizeT num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
void subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
|
||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
||||
|
||||
@ -142,11 +153,19 @@ namespace { namespace ndarray { namespace subscript {
|
||||
extern "C" {
|
||||
using namespace ndarray::subscript;
|
||||
|
||||
void __nac3_ndarray_subscript(ErrorContext* errctx, int32_t num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
|
||||
void __nac3_ndarray_subscript_deduce_ndims_after_slicing(ErrorContext* errctx, int32_t* result, int32_t ndims, int32_t num_ndsubscripts, const NDSubscript* ndsubscripts) {
|
||||
ndarray::subscript::util::deduce_ndims_after_slicing(errctx, result, ndims, num_ndsubscripts, ndsubscripts);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_subscript_deduce_ndims_after_slicing64(ErrorContext* errctx, int64_t* result, int64_t ndims, int64_t num_ndsubscripts, const NDSubscript* ndsubscripts) {
|
||||
ndarray::subscript::util::deduce_ndims_after_slicing(errctx, result, ndims, num_ndsubscripts, ndsubscripts);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
|
||||
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_subscript64(ErrorContext* errctx, int64_t num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
|
||||
void __nac3_ndarray_subscript64(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
|
||||
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||
}
|
||||
}
|
@ -4,11 +4,18 @@ use crate::{
|
||||
codegen::{
|
||||
classes::{
|
||||
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType,
|
||||
ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
||||
ProxyValue, RangeValue, UntypedArrayLikeAccessor,
|
||||
},
|
||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||
gen_in_range_check, get_llvm_abi_type, get_llvm_type,
|
||||
irrt::*,
|
||||
irrt::{
|
||||
numpy::{
|
||||
ndarray::{self, alloca_ndarray_and_init},
|
||||
slice::{RustUserSlice, SliceIndexModel},
|
||||
subscript::{call_nac3_ndarray_subscript, RustNDSubscript},
|
||||
},
|
||||
*,
|
||||
},
|
||||
llvm_intrinsics::{
|
||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||
call_memcpy_generic,
|
||||
@ -18,14 +25,10 @@ use crate::{
|
||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||
gen_var,
|
||||
},
|
||||
CodeGenContext, CodeGenTask, CodeGenerator,
|
||||
CodeGenContext, CodeGenTask, CodeGenerator, Int32,
|
||||
},
|
||||
symbol_resolver::{SymbolValue, ValueEnum},
|
||||
toplevel::{
|
||||
helper::PrimDef,
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
DefinitionId, TopLevelDef,
|
||||
},
|
||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
||||
typecheck::{
|
||||
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||
@ -34,15 +37,19 @@ use crate::{
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
types::{AnyType, BasicType, BasicTypeEnum},
|
||||
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
|
||||
values::{
|
||||
AnyValue, BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue,
|
||||
},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
use itertools::{chain, izip, Either, Itertools};
|
||||
use nac3parser::ast::{
|
||||
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||
Unaryop,
|
||||
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Located, Location, Operator,
|
||||
StrRef, Unaryop,
|
||||
};
|
||||
|
||||
use super::{irrt::numpy::ndarray::NpArray, IntModel, Model, Pointer, PointerModel, StructModel};
|
||||
|
||||
pub fn get_subst_key(
|
||||
unifier: &mut Unifier,
|
||||
obj: Option<Type>,
|
||||
@ -2130,322 +2137,151 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
|
||||
|
||||
/// Generates code for a subscript expression on an `ndarray`.
|
||||
///
|
||||
/// * `ty` - The `Type` of the `NDArray` elements.
|
||||
/// * `elem_ty` - The `Type` of the `NDArray` elements.
|
||||
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
||||
/// * `v` - The `NDArray` value.
|
||||
/// * `src_ndarray` - The `NDArray` value.
|
||||
/// * `slice` - The slice expression used to subscript into the `ndarray`.
|
||||
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
elem_ty: Type,
|
||||
ndims: Type,
|
||||
v: NDArrayValue<'ctx>,
|
||||
src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
slice: &Expr<Option<Type>>,
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
|
||||
|
||||
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
|
||||
let sizet = src_ndarray.element.0.sizet;
|
||||
debug_assert_eq!(sizet.0, generator.get_size_type(ctx.ctx)); // If the ndarray's size_type somehow isn't that of `generator.get_size_type()`... there would be a bug
|
||||
|
||||
let slice_index_model = SliceIndexModel::default();
|
||||
|
||||
// Annoying notes about `slice`
|
||||
// - `my_array[5]`
|
||||
// - slice is a `Constant`
|
||||
// - `my_array[:5]`
|
||||
// - slice is a `Slice`
|
||||
// - `my_array[:]`
|
||||
// - slice is a `Slice`, but lower upper step would all be `Option::None`
|
||||
// - `my_array[:, :]`
|
||||
// - slice is now a `Tuple` of two `Slice`-s
|
||||
//
|
||||
// In summary:
|
||||
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
|
||||
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
|
||||
//
|
||||
// So we first "flatten" out the slice expression
|
||||
let subscript_exprs = match &slice.node {
|
||||
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
|
||||
_ => vec![slice],
|
||||
};
|
||||
|
||||
// Process all subscript expressions in subscripts
|
||||
let mut rust_ndsubscripts: Vec<RustNDSubscript> = Vec::with_capacity(subscript_exprs.len()); // Not using iterators here because `?` is used here.
|
||||
for subscript_expr in subscript_exprs {
|
||||
// NOTE: Currently nac3core's slices do not have an object representation,
|
||||
// so the code/implementation looks awkward - we have to do pattern matching on the expression
|
||||
let ndsubscript =
|
||||
if let ExprKind::Slice { lower: start, upper: stop, step } = &subscript_expr.node {
|
||||
// Helper function here to deduce code duplication
|
||||
type ValueExpr = Option<Box<Located<ExprKind<Option<Type>>, Option<Type>>>>;
|
||||
let mut help = |value_expr: &ValueExpr| -> Result<_, String> {
|
||||
Ok(match value_expr {
|
||||
None => None,
|
||||
Some(value_expr) => Some(
|
||||
slice_index_model.check_llvm_value(
|
||||
generator
|
||||
.gen_expr(ctx, value_expr)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?
|
||||
.as_any_value_enum(),
|
||||
),
|
||||
),
|
||||
})
|
||||
};
|
||||
|
||||
let start = help(start)?;
|
||||
let stop = help(stop)?;
|
||||
let step = help(step)?;
|
||||
|
||||
RustNDSubscript::Slice(RustUserSlice { start, stop, step })
|
||||
} else {
|
||||
// Anything else that is not a slice (might be illegal values),
|
||||
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
|
||||
|
||||
let index = slice_index_model.check_llvm_value(
|
||||
generator
|
||||
.gen_expr(ctx, subscript_expr)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?
|
||||
.as_any_value_enum(),
|
||||
);
|
||||
|
||||
RustNDSubscript::Index(index)
|
||||
};
|
||||
rust_ndsubscripts.push(ndsubscript);
|
||||
}
|
||||
|
||||
// Extract the `ndims` from a `Type` to `i128`
|
||||
// We *HAVE* to know this statically, this is used to determine
|
||||
// whether the subscript returns a scalar or an ndarray
|
||||
let TypeEnum::TLiteral { values: ndims_values, .. } = &*ctx.unifier.get_ty_immutable(ndims)
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
assert_eq!(ndims_values.len(), 1);
|
||||
let src_ndims = i128::try_from(ndims_values[0].clone()).unwrap();
|
||||
|
||||
let ndims = values
|
||||
.iter()
|
||||
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|val| {
|
||||
format!(
|
||||
"Expected non-negative literal for ndarray.ndims, got {}",
|
||||
i128::try_from(val).unwrap()
|
||||
)
|
||||
})?;
|
||||
|
||||
assert!(!ndims.is_empty());
|
||||
|
||||
// The number of dimensions subscripted by the index expression.
|
||||
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
|
||||
// dimension will remove a dimension.
|
||||
let subscripted_dims = match &slice.node {
|
||||
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
|
||||
if let ExprKind::Slice { .. } = &value_subexpr.node {
|
||||
acc
|
||||
} else {
|
||||
acc + 1
|
||||
}
|
||||
}),
|
||||
|
||||
ExprKind::Slice { .. } => 0,
|
||||
_ => 1,
|
||||
};
|
||||
|
||||
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
|
||||
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
|
||||
None,
|
||||
);
|
||||
let ndarray_ty =
|
||||
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
||||
|
||||
// Check that len is non-zero
|
||||
let len = v.load_ndims(ctx);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(),
|
||||
"0:IndexError",
|
||||
"too many indices for array: array is {0}-dimensional but 1 were indexed",
|
||||
[Some(len), None, None],
|
||||
slice.location,
|
||||
);
|
||||
|
||||
// Normalizes a possibly-negative index to its corresponding positive index
|
||||
let normalize_index = |generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
index: IntValue<'ctx>,
|
||||
dim: u64| {
|
||||
gen_if_else_expr_callback(
|
||||
// Check for "too many indices for array: array is ..." error
|
||||
if src_ndims < rust_ndsubscripts.len() as i128 {
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "")
|
||||
.unwrap())
|
||||
},
|
||||
|_, _| Ok(Some(index)),
|
||||
|generator, ctx| {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
ctx.ctx.bool_type().const_int(1, false),
|
||||
"0:IndexError",
|
||||
"too many indices for array: array is {0}-dimensional, but {1} were indexed",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let len = unsafe {
|
||||
v.dim_sizes().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(dim, true),
|
||||
None,
|
||||
)
|
||||
};
|
||||
// Statically determine `dst_ndims`
|
||||
let dst_ndims =
|
||||
RustNDSubscript::deduce_ndims_after_slicing(&rust_ndsubscripts, src_ndims as i32);
|
||||
|
||||
let index = ctx
|
||||
.builder
|
||||
.build_int_add(
|
||||
len,
|
||||
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
// Prepare dst_ndarray
|
||||
let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let dst_ndarray = alloca_ndarray_and_init(
|
||||
generator,
|
||||
ctx,
|
||||
elem_llvm_ty,
|
||||
ndarray::NDArrayInitMode::NDims { ndims: sizet.constant(dst_ndims as u64) },
|
||||
"subndarray",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
|
||||
},
|
||||
)
|
||||
.map(|v| v.map(BasicValueEnum::into_int_value))
|
||||
};
|
||||
// Prepare the subscripts
|
||||
let ndsubscript_array = RustNDSubscript::alloca_subscripts(ctx, &rust_ndsubscripts);
|
||||
|
||||
// Converts a slice expression into a slice-range tuple
|
||||
let expr_to_slice = |generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
node: &ExprKind<Option<Type>>,
|
||||
dim: u64| {
|
||||
match node {
|
||||
ExprKind::Constant { value: Constant::Int(v), .. } => {
|
||||
let Some(index) =
|
||||
normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
// NOTE: IRRT does check for indexing errors
|
||||
call_nac3_ndarray_subscript(
|
||||
generator,
|
||||
ctx,
|
||||
ndsubscript_array.num_elements.signed_cast_to_fixed(ctx, Int32, "num_ndsubscripts"),
|
||||
ndsubscript_array.pointer,
|
||||
src_ndarray,
|
||||
dst_ndarray,
|
||||
);
|
||||
|
||||
Ok(Some((index, index, llvm_i32.const_int(1, true))))
|
||||
}
|
||||
|
||||
ExprKind::Slice { lower, upper, step } => {
|
||||
let dim_sz = unsafe {
|
||||
v.dim_sizes().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(dim, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz)
|
||||
}
|
||||
|
||||
_ => {
|
||||
let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) };
|
||||
let index = index
|
||||
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
|
||||
.into_int_value();
|
||||
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(Some((index, index, llvm_i32.const_int(1, true))))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let make_indices_arr = |generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>|
|
||||
-> Result<_, String> {
|
||||
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
|
||||
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
|
||||
let index_addr = generator.gen_array_var_alloc(
|
||||
ctx,
|
||||
llvm_int_ty,
|
||||
llvm_usize.const_int(elts.len() as u64, false),
|
||||
None,
|
||||
)?;
|
||||
|
||||
for (i, elt) in elts.iter().enumerate() {
|
||||
let Some(index) = generator.gen_expr(ctx, elt)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let index = index
|
||||
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
|
||||
.into_int_value();
|
||||
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let store_ptr = unsafe {
|
||||
index_addr.ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(i as u64, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(store_ptr, index).unwrap();
|
||||
}
|
||||
|
||||
Some(index_addr)
|
||||
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
|
||||
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
|
||||
let index_addr = generator.gen_array_var_alloc(
|
||||
ctx,
|
||||
llvm_int_ty,
|
||||
llvm_usize.const_int(1u64, false),
|
||||
None,
|
||||
)?;
|
||||
|
||||
let index =
|
||||
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
|
||||
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
|
||||
|
||||
let store_ptr = unsafe {
|
||||
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
ctx.builder.build_store(store_ptr, index).unwrap();
|
||||
|
||||
Some(index_addr)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
};
|
||||
|
||||
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
|
||||
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
|
||||
|
||||
v.data().get(ctx, generator, &index_addr, None).into()
|
||||
// ...and return the result, with two cases
|
||||
let result_llvm_value = if dst_ndims == 0 {
|
||||
// 1) ndims == 0 (this happens when you do `np.zerps((3, 4))[1, 1]`), return *THE ELEMENT*
|
||||
let element_ptr = dst_ndarray.gep(ctx, |f| f.data).load(ctx, "pelement"); // `*data` points to the first element by definition
|
||||
element_ptr.cast_opaque_to(ctx, elem_llvm_ty, "").load_opaque(ctx, "element")
|
||||
} else {
|
||||
match &slice.node {
|
||||
ExprKind::Tuple { elts, .. } => {
|
||||
let slices = elts
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
|
||||
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
if slices.len() < elts.len() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
|
||||
|
||||
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
|
||||
}
|
||||
|
||||
ExprKind::Slice { .. } => {
|
||||
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
|
||||
}
|
||||
|
||||
_ => {
|
||||
// Accessing an element from a multi-dimensional `ndarray`
|
||||
|
||||
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
|
||||
|
||||
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||
// elements over
|
||||
let subscripted_ndarray =
|
||||
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
||||
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
|
||||
|
||||
let num_dims = v.load_ndims(ctx);
|
||||
ndarray.store_ndims(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let v_dims_src_ptr = unsafe {
|
||||
v.dim_sizes().ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
ndarray.dim_sizes().base_ptr(ctx, generator),
|
||||
v_dims_src_ptr,
|
||||
ctx.builder
|
||||
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
|
||||
.map(Into::into)
|
||||
.unwrap(),
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
||||
(None, None),
|
||||
);
|
||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||
|
||||
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
ndarray.data().base_ptr(ctx, generator),
|
||||
v_data_src_ptr,
|
||||
ctx.builder
|
||||
.build_int_mul(
|
||||
ndarray_num_elems,
|
||||
llvm_ndarray_data_t.size_of().unwrap(),
|
||||
"",
|
||||
)
|
||||
.map(Into::into)
|
||||
.unwrap(),
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
ndarray.as_base_value().into()
|
||||
}
|
||||
}
|
||||
}))
|
||||
// 2) ndims > 0 (other cases), return subndarray
|
||||
dst_ndarray.value.as_basic_value_enum()
|
||||
};
|
||||
Ok(Some(ValueEnum::Dynamic(result_llvm_value)))
|
||||
}
|
||||
|
||||
/// See [`CodeGenerator::gen_expr`].
|
||||
@ -3088,17 +2924,26 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
}
|
||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
||||
let (elem_ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
||||
|
||||
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
||||
.into_pointer_value()
|
||||
let ndarray_ptr = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||
let sizet = IntModel(generator.get_size_type(ctx.ctx));
|
||||
let ndarray_ptr_model = PointerModel(StructModel(NpArray { sizet }));
|
||||
|
||||
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
|
||||
ndarray_ptr_model.check_llvm_value(v.as_any_value_enum())
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
let v = NDArrayValue::from_ptr_val(v, usize, None);
|
||||
|
||||
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
|
||||
return gen_ndarray_subscript_expr(
|
||||
generator,
|
||||
ctx,
|
||||
*elem_ty,
|
||||
*ndims,
|
||||
ndarray_ptr,
|
||||
slice,
|
||||
);
|
||||
}
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let index: u32 =
|
||||
|
@ -7,7 +7,7 @@ pub struct StrFields<'ctx> {
|
||||
pub length: Field<IntModel<'ctx>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Str<'ctx> {
|
||||
pub sizet: IntModel<'ctx>,
|
||||
}
|
||||
@ -33,8 +33,10 @@ pub struct ErrorIdsFields {
|
||||
pub value_error: Field<FixedIntModel<ErrorId>>,
|
||||
pub assertion_error: Field<FixedIntModel<ErrorId>>,
|
||||
pub runtime_error: Field<FixedIntModel<ErrorId>>,
|
||||
pub type_error: Field<FixedIntModel<ErrorId>>,
|
||||
}
|
||||
#[derive(Debug, Clone)]
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ErrorIds;
|
||||
|
||||
impl<'ctx> IsStruct<'ctx> for ErrorIds {
|
||||
@ -50,6 +52,7 @@ impl<'ctx> IsStruct<'ctx> for ErrorIds {
|
||||
value_error: builder.add_field_auto("value_error"),
|
||||
assertion_error: builder.add_field_auto("assertion_error"),
|
||||
runtime_error: builder.add_field_auto("runtime_error"),
|
||||
type_error: builder.add_field_auto("type_error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -61,7 +64,8 @@ pub struct ErrorContextFields {
|
||||
pub param2: Field<FixedIntModel<Int64>>,
|
||||
pub param3: Field<FixedIntModel<Int64>>,
|
||||
}
|
||||
#[derive(Debug, Clone)]
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ErrorContext;
|
||||
|
||||
impl<'ctx> IsStruct<'ctx> for ErrorContext {
|
||||
@ -92,46 +96,47 @@ fn build_error_ids<'ctx>(ctx: &CodeGenContext<'ctx, '_>) -> Pointer<'ctx, Struct
|
||||
let get_string_id =
|
||||
|string_id| i32_model.constant(ctx.ctx, ctx.resolver.get_string_id(string_id) as u64);
|
||||
|
||||
error_ids.gep(ctx, |f| f.index_error).store(ctx, &get_string_id("0:IndexError"));
|
||||
error_ids.gep(ctx, |f| f.value_error).store(ctx, &get_string_id("0:ValueError"));
|
||||
error_ids.gep(ctx, |f| f.assertion_error).store(ctx, &get_string_id("0:AssertionError"));
|
||||
error_ids.gep(ctx, |f| f.runtime_error).store(ctx, &get_string_id("0:RuntimeError"));
|
||||
error_ids.gep(ctx, |f| f.index_error).store(ctx, get_string_id("0:IndexError"));
|
||||
error_ids.gep(ctx, |f| f.value_error).store(ctx, get_string_id("0:ValueError"));
|
||||
error_ids.gep(ctx, |f| f.assertion_error).store(ctx, get_string_id("0:AssertionError"));
|
||||
error_ids.gep(ctx, |f| f.runtime_error).store(ctx, get_string_id("0:RuntimeError"));
|
||||
error_ids.gep(ctx, |f| f.type_error).store(ctx, get_string_id("0:TypeError"));
|
||||
|
||||
error_ids
|
||||
}
|
||||
|
||||
pub fn call_nac3_error_context_initialize<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
perrctx: &Pointer<'ctx, StructModel<ErrorContext>>,
|
||||
perror_ids: &Pointer<'ctx, StructModel<ErrorIds>>,
|
||||
perrctx: Pointer<'ctx, StructModel<ErrorContext>>,
|
||||
perror_ids: Pointer<'ctx, StructModel<ErrorIds>>,
|
||||
) {
|
||||
FunctionBuilder::begin(ctx, "__nac3_error_context_initialize")
|
||||
.arg("errctx", &PointerModel(StructModel(ErrorContext)), perrctx)
|
||||
.arg("error_ids", &PointerModel(StructModel(ErrorIds)), perror_ids)
|
||||
.arg("errctx", PointerModel(StructModel(ErrorContext)), perrctx)
|
||||
.arg("error_ids", PointerModel(StructModel(ErrorIds)), perror_ids)
|
||||
.returning_void();
|
||||
}
|
||||
|
||||
pub fn call_nac3_error_context_has_no_error<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
errctx: &Pointer<'ctx, StructModel<ErrorContext>>,
|
||||
errctx: Pointer<'ctx, StructModel<ErrorContext>>,
|
||||
) -> FixedInt<'ctx, Bool> {
|
||||
FunctionBuilder::begin(ctx, "__nac3_error_context_has_no_error")
|
||||
.arg("errctx", &PointerModel(StructModel(ErrorContext)), errctx)
|
||||
.returning("has_error", &FixedIntModel(Bool))
|
||||
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
|
||||
.returning("has_error", FixedIntModel(Bool))
|
||||
}
|
||||
|
||||
pub fn call_nac3_error_context_get_error_str<'ctx>(
|
||||
sizet: IntModel<'ctx>,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
errctx: &Pointer<'ctx, StructModel<ErrorContext>>,
|
||||
dst_str: &Pointer<'ctx, StructModel<Str<'ctx>>>,
|
||||
errctx: Pointer<'ctx, StructModel<ErrorContext>>,
|
||||
dst_str: Pointer<'ctx, StructModel<Str<'ctx>>>,
|
||||
) {
|
||||
FunctionBuilder::begin(
|
||||
ctx,
|
||||
&get_sized_dependent_function_name(sizet, "__nac3_error_context_get_error_str"),
|
||||
)
|
||||
.arg("errctx", &PointerModel(StructModel(ErrorContext)), errctx)
|
||||
.arg("dst_str", &PointerModel(StructModel(Str { sizet })), dst_str)
|
||||
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
|
||||
.arg("dst_str", PointerModel(StructModel(Str { sizet })), dst_str)
|
||||
.returning_void();
|
||||
}
|
||||
|
||||
@ -140,20 +145,20 @@ pub fn prepare_error_context<'ctx>(
|
||||
) -> Pointer<'ctx, StructModel<ErrorContext>> {
|
||||
let error_ids = build_error_ids(ctx);
|
||||
let errctx_ptr = StructModel(ErrorContext).alloca(ctx, "errctx");
|
||||
call_nac3_error_context_initialize(ctx, &errctx_ptr, &error_ids);
|
||||
call_nac3_error_context_initialize(ctx, errctx_ptr, error_ids);
|
||||
errctx_ptr
|
||||
}
|
||||
|
||||
pub fn check_error_context<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
errctx_ptr: &Pointer<'ctx, StructModel<ErrorContext>>,
|
||||
errctx_ptr: Pointer<'ctx, StructModel<ErrorContext>>,
|
||||
) {
|
||||
let sizet = IntModel(generator.get_size_type(ctx.ctx));
|
||||
|
||||
let has_error = call_nac3_error_context_has_no_error(ctx, errctx_ptr);
|
||||
let pstr = StructModel(Str { sizet }).alloca(ctx, "error_str");
|
||||
call_nac3_error_context_get_error_str(sizet, ctx, errctx_ptr, &pstr);
|
||||
call_nac3_error_context_get_error_str(sizet, ctx, errctx_ptr, pstr);
|
||||
|
||||
let error_id = errctx_ptr.gep(ctx, |f| f.error_id).load(ctx, "error_id");
|
||||
let error_str = pstr.load(ctx, "error_str");
|
||||
@ -176,7 +181,7 @@ pub fn call_nac3_dummy_raise<G: CodeGenerator + ?Sized>(
|
||||
) {
|
||||
let errctx = prepare_error_context(ctx);
|
||||
FunctionBuilder::begin(ctx, "__nac3_error_dummy_raise")
|
||||
.arg("errctx", &PointerModel(StructModel(ErrorContext)), &errctx)
|
||||
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
|
||||
.returning_void();
|
||||
check_error_context(generator, ctx, &errctx);
|
||||
check_error_context(generator, ctx, errctx);
|
||||
}
|
||||
|
@ -1,5 +1,4 @@
|
||||
pub mod ndarray;
|
||||
pub mod shape;
|
||||
|
||||
pub use ndarray::*;
|
||||
pub use shape::*;
|
||||
pub mod slice;
|
||||
pub mod subscript;
|
||||
|
@ -9,10 +9,13 @@ use crate::codegen::{
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
use super::Producer;
|
||||
use super::{
|
||||
shape::Producer,
|
||||
slice::{SliceIndex, SliceIndexModel},
|
||||
};
|
||||
|
||||
pub struct NpArrayFields<'ctx> {
|
||||
pub data: Field<OpaquePointerModel>,
|
||||
pub data: Field<PointerModel<ByteModel>>,
|
||||
pub itemsize: Field<IntModel<'ctx>>,
|
||||
pub ndims: Field<IntModel<'ctx>>,
|
||||
pub shape: Field<PointerModel<IntModel<'ctx>>>,
|
||||
@ -63,7 +66,7 @@ pub fn alloca_ndarray<'ctx, G>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_type: BasicTypeEnum<'ctx>,
|
||||
ndims: &Int<'ctx>,
|
||||
ndims: Int<'ctx>,
|
||||
name: &str,
|
||||
) -> Result<Pointer<'ctx, StructModel<NpArray<'ctx>>>, String>
|
||||
where
|
||||
@ -78,19 +81,16 @@ where
|
||||
ndarray_ptr.gep(ctx, |f| f.ndims).store(ctx, ndims);
|
||||
|
||||
// Set itemsize
|
||||
let itemsize = elem_type.size_of().unwrap();
|
||||
let itemsize =
|
||||
ctx.builder.build_int_s_extend_or_bit_cast(itemsize, sizet.0, "itemsize").unwrap();
|
||||
ndarray_ptr.gep(ctx, |f| f.itemsize).store(ctx, &Int(itemsize));
|
||||
let itemsize = Int(elem_type.size_of().unwrap());
|
||||
ndarray_ptr.gep(ctx, |f| f.itemsize).store(ctx, itemsize.signed_cast_to_int(ctx, sizet, ""));
|
||||
|
||||
// Allocate and set shape
|
||||
let shape_ptr = ctx.builder.build_array_alloca(sizet.0, ndims.0, "shape").unwrap();
|
||||
ndarray_ptr.gep(ctx, |f| f.shape).store(ctx, &Pointer { element: sizet, value: shape_ptr });
|
||||
// .store(ctx, &Pointer { addressee_optic: IntLens(sizet), address: shape_ptr });
|
||||
let shape_array = sizet.array_alloca(ctx, ndims, "shape");
|
||||
ndarray_ptr.gep(ctx, |f| f.shape).store(ctx, shape_array.pointer);
|
||||
|
||||
// Allocate and set strides
|
||||
let strides_ptr = ctx.builder.build_array_alloca(sizet.0, ndims.0, "strides").unwrap();
|
||||
ndarray_ptr.gep(ctx, |f| f.strides).store(ctx, &Pointer { element: sizet, value: strides_ptr });
|
||||
let strides_array = sizet.array_alloca(ctx, ndims, "strides");
|
||||
ndarray_ptr.gep(ctx, |f| f.strides).store(ctx, strides_array.pointer);
|
||||
|
||||
Ok(ndarray_ptr)
|
||||
}
|
||||
@ -115,12 +115,12 @@ where
|
||||
// It is implemented verbosely in order to make the initialization modes super clear in their intent.
|
||||
match init_mode {
|
||||
NDArrayInitMode::NDims { ndims } => {
|
||||
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, &ndims, name)?;
|
||||
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
|
||||
Ok(ndarray_ptr)
|
||||
}
|
||||
NDArrayInitMode::Shape { shape } => {
|
||||
let ndims = shape.count;
|
||||
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, &ndims, name)?;
|
||||
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
|
||||
|
||||
// Fill `ndarray.shape`
|
||||
(shape.write_to_array)(generator, ctx, &ndarray_ptr.shape_slice(ctx))?;
|
||||
@ -129,8 +129,8 @@ where
|
||||
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||
generator,
|
||||
ctx,
|
||||
&ndims,
|
||||
&ndarray_ptr.gep(ctx, |f| f.shape).load(ctx, "shape"),
|
||||
ndims,
|
||||
ndarray_ptr.gep(ctx, |f| f.shape).load(ctx, "shape"),
|
||||
);
|
||||
|
||||
// NOTE: DO NOT DO `set_strides_by_shape` HERE.
|
||||
@ -140,7 +140,7 @@ where
|
||||
}
|
||||
NDArrayInitMode::ShapeAndAllocaData { shape } => {
|
||||
let ndims = shape.count;
|
||||
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, &ndims, name)?;
|
||||
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
|
||||
|
||||
// Fill `ndarray.shape`
|
||||
(shape.write_to_array)(generator, ctx, &ndarray_ptr.shape_slice(ctx))?;
|
||||
@ -149,26 +149,22 @@ where
|
||||
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||
generator,
|
||||
ctx,
|
||||
&ndims,
|
||||
&ndarray_ptr.gep(ctx, |f| f.shape).load(ctx, "shape"),
|
||||
ndims,
|
||||
ndarray_ptr.gep(ctx, |f| f.shape).load(ctx, "shape"),
|
||||
);
|
||||
|
||||
// Now we populate `ndarray.data` by alloca-ing.
|
||||
// But first, we need to know the size of the ndarray to know how many elements to alloca,
|
||||
// since calculating nbytes of an ndarray requires `ndarray.shape` to be set.
|
||||
let ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, &ndarray_ptr);
|
||||
let ndarray_nbytes = call_nac3_ndarray_nbytes(ctx, ndarray_ptr);
|
||||
|
||||
// Alloca `data` and assign it to `ndarray.data`
|
||||
let data_ptr = OpaquePointer(
|
||||
ctx.builder
|
||||
.build_array_alloca(ctx.ctx.i8_type(), ndarray_nbytes.0, "data")
|
||||
.unwrap(),
|
||||
);
|
||||
ndarray_ptr.gep(ctx, |f| f.data).store(ctx, &data_ptr);
|
||||
let data_array = FixedIntModel(Byte).array_alloca(ctx, ndarray_nbytes, "data");
|
||||
ndarray_ptr.gep(ctx, |f| f.data).store(ctx, data_array.pointer);
|
||||
|
||||
// Finally, do `set_strides_by_shape`
|
||||
// Check out https://ajcr.net/stride-guide-part-1/ to see what numpy "strides" are.
|
||||
call_nac3_ndarray_set_strides_by_shape(generator, ctx, &ndarray_ptr);
|
||||
call_nac3_ndarray_set_strides_by_shape(ctx, ndarray_ptr);
|
||||
|
||||
Ok(ndarray_ptr)
|
||||
}
|
||||
@ -178,8 +174,8 @@ where
|
||||
fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndims: &Int<'ctx>,
|
||||
shape_ptr: &Pointer<'ctx, IntModel<'ctx>>,
|
||||
ndims: Int<'ctx>,
|
||||
shape_ptr: Pointer<'ctx, IntModel<'ctx>>,
|
||||
) {
|
||||
let sizet = IntModel(generator.get_size_type(ctx.ctx));
|
||||
|
||||
@ -188,53 +184,71 @@ fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Siz
|
||||
ctx,
|
||||
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_util_assert_shape_no_negative"),
|
||||
)
|
||||
.arg("errctx", &PointerModel(StructModel(ErrorContext)), &errctx)
|
||||
.arg("ndims", &sizet, ndims)
|
||||
.arg("shape", &PointerModel(sizet), shape_ptr)
|
||||
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
|
||||
.arg("ndims", sizet, ndims)
|
||||
.arg("shape", PointerModel(sizet), shape_ptr)
|
||||
.returning_void();
|
||||
check_error_context(generator, ctx, &errctx);
|
||||
check_error_context(generator, ctx, errctx);
|
||||
}
|
||||
|
||||
fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray_ptr: &Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
ndarray_ptr: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
) {
|
||||
let sizet = IntModel(generator.get_size_type(ctx.ctx));
|
||||
let sizet = ndarray_ptr.element.0.sizet;
|
||||
|
||||
FunctionBuilder::begin(
|
||||
ctx,
|
||||
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_set_strides_by_shape"),
|
||||
)
|
||||
.arg("ndarray", &PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
|
||||
.arg("ndarray", PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
|
||||
.returning_void();
|
||||
}
|
||||
|
||||
fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
pub fn call_nac3_ndarray_nbytes<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray_ptr: &Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
ndarray_ptr: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
) -> Int<'ctx> {
|
||||
let sizet = IntModel(generator.get_size_type(ctx.ctx));
|
||||
let sizet = ndarray_ptr.element.0.sizet;
|
||||
|
||||
FunctionBuilder::begin(ctx, &get_sized_dependent_function_name(sizet, "__nac3_ndarray_nbytes"))
|
||||
.arg("ndarray", &PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
|
||||
.returning("nbytes", &sizet)
|
||||
.arg("ndarray", PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
|
||||
.returning("nbytes", sizet)
|
||||
}
|
||||
|
||||
pub fn call_nac3_ndarray_fill_generic<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
pub fn call_nac3_ndarray_fill_generic<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray_ptr: &Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
fill_value_ptr: &OpaquePointer<'ctx>,
|
||||
ndarray_ptr: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
fill_value_ptr: Pointer<'ctx, ByteModel>,
|
||||
) {
|
||||
let sizet = IntModel(generator.get_size_type(ctx.ctx));
|
||||
let sizet = ndarray_ptr.element.0.sizet;
|
||||
|
||||
FunctionBuilder::begin(
|
||||
ctx,
|
||||
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_fill_generic"),
|
||||
)
|
||||
.arg("ndarray", &PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
|
||||
.arg("pvalue", &OpaquePointerModel, fill_value_ptr)
|
||||
.arg("ndarray", PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
|
||||
.arg("pvalue", PointerModel(FixedIntModel(Byte)), fill_value_ptr)
|
||||
.returning_void();
|
||||
}
|
||||
|
||||
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray_ptr: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
) -> SliceIndex<'ctx> {
|
||||
let sizet = ndarray_ptr.element.0.sizet;
|
||||
let slice_index_model = SliceIndexModel::default();
|
||||
|
||||
let dst_len = slice_index_model.alloca(ctx, "dst_len");
|
||||
|
||||
let errctx = prepare_error_context(ctx);
|
||||
FunctionBuilder::begin(ctx, &get_sized_dependent_function_name(sizet, "__nac3_ndarray_len"))
|
||||
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
|
||||
.arg("ndarray", PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
|
||||
.arg("dst_len", PointerModel(slice_index_model), dst_len)
|
||||
.returning_void();
|
||||
check_error_context(generator, ctx, errctx);
|
||||
|
||||
dst_len.load(ctx, "len")
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ where
|
||||
.unwrap();
|
||||
|
||||
// Write
|
||||
dst_array.ix(generator, ctx, axis, "dim").store(ctx, &Int(dim));
|
||||
dst_array.ix(generator, ctx, axis, "dim").store(ctx, Int(dim));
|
||||
Ok(())
|
||||
},
|
||||
incr_val,
|
||||
@ -127,7 +127,7 @@ where
|
||||
// Write
|
||||
dst_array
|
||||
.ix(generator, ctx, sizet.constant(axis as u64), "dim")
|
||||
.store(ctx, &Int(dim));
|
||||
.store(ctx, Int(dim));
|
||||
}
|
||||
Ok(())
|
||||
}),
|
||||
@ -151,7 +151,7 @@ where
|
||||
.unwrap();
|
||||
|
||||
// Set shape[0] = shape_int
|
||||
dst_array.ix(generator, ctx, sizet.constant(0), "dim").store(ctx, &Int(dim));
|
||||
dst_array.ix(generator, ctx, sizet.constant(0), "dim").store(ctx, Int(dim));
|
||||
|
||||
Ok(())
|
||||
}),
|
||||
|
86
nac3core/src/codegen/irrt/numpy/slice.rs
Normal file
86
nac3core/src/codegen/irrt/numpy/slice.rs
Normal file
@ -0,0 +1,86 @@
|
||||
use crate::codegen::{model::*, CodeGenContext};
|
||||
|
||||
// nac3core's slicing index/length values are always int32_t
|
||||
pub type SliceIndexInt = Int32;
|
||||
pub type SliceIndexModel = FixedIntModel<SliceIndexInt>;
|
||||
pub type SliceIndex<'ctx> = FixedInt<'ctx, SliceIndexInt>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UserSliceFields {
|
||||
pub start_defined: Field<BoolModel>,
|
||||
pub start: Field<SliceIndexModel>,
|
||||
pub stop_defined: Field<BoolModel>,
|
||||
pub stop: Field<SliceIndexModel>,
|
||||
pub step_defined: Field<BoolModel>,
|
||||
pub step: Field<SliceIndexModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct UserSlice;
|
||||
|
||||
impl<'ctx> IsStruct<'ctx> for UserSlice {
|
||||
type Fields = UserSliceFields;
|
||||
|
||||
fn struct_name(&self) -> &'static str {
|
||||
"UserSlice"
|
||||
}
|
||||
|
||||
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
|
||||
Self::Fields {
|
||||
start_defined: builder.add_field_auto("start_defined"),
|
||||
start: builder.add_field_auto("start"),
|
||||
stop_defined: builder.add_field_auto("stop_defined"),
|
||||
stop: builder.add_field_auto("stop"),
|
||||
step_defined: builder.add_field_auto("step_defined"),
|
||||
step: builder.add_field_auto("step"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RustUserSlice<'ctx> {
|
||||
pub start: Option<SliceIndex<'ctx>>,
|
||||
pub stop: Option<SliceIndex<'ctx>>,
|
||||
pub step: Option<SliceIndex<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> RustUserSlice<'ctx> {
|
||||
// Set the values of an LLVM UserSlice
|
||||
// in the format of Python's `slice()`
|
||||
pub fn write_to_user_slice(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dst_slice_ptr: Pointer<'ctx, StructModel<UserSlice>>,
|
||||
) {
|
||||
// TODO: make this neater, with a helper lambda?
|
||||
|
||||
let bool_model = BoolModel::default();
|
||||
|
||||
let false_ = bool_model.constant(ctx.ctx, 0);
|
||||
let true_ = bool_model.constant(ctx.ctx, 1);
|
||||
|
||||
match self.start {
|
||||
Some(start) => {
|
||||
dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_);
|
||||
dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start);
|
||||
}
|
||||
None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_),
|
||||
}
|
||||
|
||||
match self.stop {
|
||||
Some(stop) => {
|
||||
dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_);
|
||||
dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop);
|
||||
}
|
||||
None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_),
|
||||
}
|
||||
|
||||
match self.step {
|
||||
Some(step) => {
|
||||
dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_);
|
||||
dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step);
|
||||
}
|
||||
None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_),
|
||||
}
|
||||
}
|
||||
}
|
181
nac3core/src/codegen/irrt/numpy/subscript.rs
Normal file
181
nac3core/src/codegen/irrt/numpy/subscript.rs
Normal file
@ -0,0 +1,181 @@
|
||||
use crate::codegen::{
|
||||
irrt::{
|
||||
error_context::{check_error_context, prepare_error_context, ErrorContext},
|
||||
util::{get_sized_dependent_function_name, FunctionBuilder},
|
||||
},
|
||||
model::*,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
use super::{
|
||||
ndarray::NpArray,
|
||||
slice::{RustUserSlice, SliceIndex, SliceIndexModel, UserSlice},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct NDSubscriptFields {
|
||||
pub type_: Field<ByteModel>, // Defined to be uint8_t in IRRT
|
||||
pub data: Field<PointerModel<ByteModel>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct NDSubscript;
|
||||
|
||||
impl<'ctx> IsStruct<'ctx> for NDSubscript {
|
||||
type Fields = NDSubscriptFields;
|
||||
|
||||
fn struct_name(&self) -> &'static str {
|
||||
"NDSubscript"
|
||||
}
|
||||
|
||||
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
|
||||
Self::Fields { type_: builder.add_field_auto("type"), data: builder.add_field_auto("data") }
|
||||
}
|
||||
}
|
||||
|
||||
// An enum variant to store the content
|
||||
// and type of an NDSubscript in high level.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RustNDSubscript<'ctx> {
|
||||
Index(SliceIndex<'ctx>),
|
||||
Slice(RustUserSlice<'ctx>),
|
||||
}
|
||||
|
||||
impl<'ctx> RustNDSubscript<'ctx> {
|
||||
fn irrt_subscript_id(&self) -> u64 {
|
||||
// Defined in IRRT
|
||||
match self {
|
||||
RustNDSubscript::Index(_) => 0,
|
||||
RustNDSubscript::Slice(_) => 1,
|
||||
}
|
||||
}
|
||||
|
||||
fn write_to_ndsubscript(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dst_ndsubscript_ptr: Pointer<'ctx, StructModel<NDSubscript>>,
|
||||
) {
|
||||
let byte_model = ByteModel::default();
|
||||
let slice_index_model = SliceIndexModel::default();
|
||||
let user_slice_model = StructModel(UserSlice);
|
||||
|
||||
// Set `dst_ndsubscript_ptr->type`
|
||||
dst_ndsubscript_ptr
|
||||
.gep(ctx, |f| f.type_)
|
||||
.store(ctx, byte_model.constant(ctx.ctx, self.irrt_subscript_id()));
|
||||
|
||||
// Set `dst_ndsubscript_ptr->data`
|
||||
let data = match self {
|
||||
RustNDSubscript::Index(in_index) => {
|
||||
let index_ptr = slice_index_model.alloca(ctx, "index");
|
||||
index_ptr.store(ctx, *in_index);
|
||||
index_ptr.cast_to(ctx, FixedIntModel(Byte), "")
|
||||
}
|
||||
RustNDSubscript::Slice(in_rust_slice) => {
|
||||
let user_slice_ptr = user_slice_model.alloca(ctx, "user_slice");
|
||||
in_rust_slice.write_to_user_slice(ctx, user_slice_ptr);
|
||||
user_slice_ptr.cast_to(ctx, FixedIntModel(Byte), "")
|
||||
}
|
||||
};
|
||||
dst_ndsubscript_ptr.gep(ctx, |f| f.data).store(ctx, data);
|
||||
}
|
||||
|
||||
// Allocate an array of subscripts onto the stack and return its stack pointer
|
||||
pub fn alloca_subscripts(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
subscripts: &[RustNDSubscript<'ctx>],
|
||||
) -> ArraySlice<'ctx, StructModel<NDSubscript>> {
|
||||
let index_model = Int32Model::default();
|
||||
|
||||
let ndsubscript_model = StructModel(NDSubscript);
|
||||
let ndsubscript_array = ndsubscript_model.array_alloca(
|
||||
ctx,
|
||||
index_model.constant(ctx.ctx, subscripts.len() as u64).to_int(),
|
||||
"ndsubscripts",
|
||||
);
|
||||
|
||||
for (i, rust_ndsubscript) in subscripts.iter().enumerate() {
|
||||
let ndsubscript_ptr = ndsubscript_array.ix_unchecked(
|
||||
ctx,
|
||||
index_model.constant(ctx.ctx, i as u64).to_int(),
|
||||
"",
|
||||
);
|
||||
rust_ndsubscript.write_to_ndsubscript(ctx, ndsubscript_ptr);
|
||||
}
|
||||
|
||||
ndsubscript_array
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn deduce_ndims_after_slicing(slices: &[RustNDSubscript], original_ndims: i32) -> i32 {
|
||||
let mut final_ndims: i32 = original_ndims;
|
||||
for slice in slices {
|
||||
match slice {
|
||||
RustNDSubscript::Index(_) => {
|
||||
// Index subscripts demotes the rank by 1
|
||||
final_ndims -= 1;
|
||||
}
|
||||
RustNDSubscript::Slice(_) => {
|
||||
// Nothing
|
||||
}
|
||||
}
|
||||
}
|
||||
final_ndims
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_nac3_ndarray_subscript_deduce_ndims_after_slicing<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
sizet: IntModel<'ctx>,
|
||||
ndims: Int<'ctx>,
|
||||
num_ndsubscripts: Int<'ctx>,
|
||||
ndsubscripts: Pointer<'ctx, StructModel<NDSubscript>>,
|
||||
) -> Int<'ctx> {
|
||||
let result = sizet.alloca(ctx, "result");
|
||||
|
||||
let errctx_ptr = prepare_error_context(ctx);
|
||||
FunctionBuilder::begin(
|
||||
ctx,
|
||||
&get_sized_dependent_function_name(
|
||||
sizet,
|
||||
"__nac3_ndarray_subscript_deduce_ndims_after_slicing",
|
||||
),
|
||||
)
|
||||
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx_ptr)
|
||||
.arg("result", PointerModel(sizet), result)
|
||||
.arg("ndims", sizet, ndims)
|
||||
.arg("num_ndsubscripts", sizet, num_ndsubscripts)
|
||||
.arg("ndsubscripts", PointerModel(StructModel(NDSubscript)), ndsubscripts)
|
||||
.returning_void();
|
||||
check_error_context(generator, ctx, errctx_ptr);
|
||||
|
||||
result.load(ctx, "final_ndims")
|
||||
}
|
||||
|
||||
pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
num_subscripts: FixedInt<'ctx, Int32>,
|
||||
subscripts: Pointer<'ctx, StructModel<NDSubscript>>,
|
||||
src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
dst_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
) {
|
||||
let sizet = src_ndarray.element.0.sizet;
|
||||
assert!(sizet.same_as(dst_ndarray.element.0.sizet)); // SizeT of src_ndarray and dst_ndarray must match
|
||||
|
||||
let errctx_ptr = prepare_error_context(ctx);
|
||||
|
||||
FunctionBuilder::begin(
|
||||
ctx,
|
||||
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_subscript"),
|
||||
)
|
||||
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx_ptr)
|
||||
.arg("num_subscripts", FixedIntModel(Int32), num_subscripts)
|
||||
.arg("subscripts", PointerModel(StructModel(NDSubscript)), subscripts)
|
||||
.arg("src_ndarray", PointerModel(StructModel(NpArray { sizet })), src_ndarray)
|
||||
.arg("dst_ndarray", PointerModel(StructModel(NpArray { sizet })), dst_ndarray)
|
||||
.returning_void();
|
||||
|
||||
check_error_context(generator, ctx, errctx_ptr);
|
||||
}
|
@ -45,13 +45,13 @@ impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> {
|
||||
|
||||
// The name is for self-documentation
|
||||
#[must_use]
|
||||
pub fn arg<M: Model<'ctx>>(mut self, _name: &'static str, model: &M, value: &M::Value) -> Self {
|
||||
pub fn arg<M: Model<'ctx>>(mut self, _name: &'static str, model: M, value: M::Value) -> Self {
|
||||
self.arguments
|
||||
.push((model.get_llvm_type(self.ctx.ctx).into(), value.get_llvm_value().into()));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn returning<M: Model<'ctx>>(self, name: &'static str, return_model: &M) -> M::Value {
|
||||
pub fn returning<M: Model<'ctx>>(self, name: &'static str, return_model: M) -> M::Value {
|
||||
let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip();
|
||||
|
||||
let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| {
|
||||
|
@ -23,7 +23,7 @@ use inkwell::{
|
||||
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
use irrt::{error_context::Str, numpy::NpArray};
|
||||
use irrt::{error_context::Str, numpy::ndarray::NpArray};
|
||||
use itertools::Itertools;
|
||||
use model::*;
|
||||
use nac3parser::ast::{Location, Stmt, StrRef};
|
||||
|
@ -16,13 +16,13 @@ use inkwell::{
|
||||
|
||||
use crate::codegen::CodeGenContext;
|
||||
|
||||
use super::Pointer;
|
||||
use super::{slice::ArraySlice, Int, Pointer};
|
||||
|
||||
pub trait ModelValue<'ctx> {
|
||||
pub trait ModelValue<'ctx>: Clone + Copy {
|
||||
fn get_llvm_value(&self) -> BasicValueEnum<'ctx>;
|
||||
}
|
||||
|
||||
pub trait Model<'ctx>: Clone {
|
||||
pub trait Model<'ctx>: Clone + Copy {
|
||||
type Value: ModelValue<'ctx>;
|
||||
|
||||
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
|
||||
@ -30,8 +30,26 @@ pub trait Model<'ctx>: Clone {
|
||||
|
||||
fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> {
|
||||
Pointer {
|
||||
element: self.clone(),
|
||||
element: *self,
|
||||
value: ctx.builder.build_alloca(self.get_llvm_type(ctx.ctx), name).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
fn array_alloca(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
count: Int<'ctx>,
|
||||
name: &str,
|
||||
) -> ArraySlice<'ctx, Self> {
|
||||
ArraySlice {
|
||||
num_elements: count,
|
||||
pointer: Pointer {
|
||||
element: *self,
|
||||
value: ctx
|
||||
.builder
|
||||
.build_array_alloca(self.get_llvm_type(ctx.ctx), count.0, name)
|
||||
.unwrap(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ use crate::codegen::CodeGenContext;
|
||||
|
||||
use super::{Model, ModelValue, Pointer};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Field<E> {
|
||||
pub gep_index: u64,
|
||||
pub name: &'static str,
|
||||
@ -17,7 +17,7 @@ pub struct Field<E> {
|
||||
}
|
||||
|
||||
// Like [`Field<E>`] but element must be [`BasicTypeEnum<'ctx>`]
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct FieldLLVM<'ctx> {
|
||||
gep_index: u64,
|
||||
name: &'ctx str,
|
||||
@ -57,7 +57,7 @@ impl<'ctx> FieldBuilder<'ctx> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait IsStruct<'ctx>: Clone {
|
||||
pub trait IsStruct<'ctx>: Clone + Copy {
|
||||
type Fields;
|
||||
|
||||
fn struct_name(&self) -> &'static str;
|
||||
@ -79,11 +79,10 @@ pub trait IsStruct<'ctx>: Clone {
|
||||
}
|
||||
}
|
||||
|
||||
// To play nice with Rust's trait resolution
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct StructModel<S>(pub S);
|
||||
|
||||
// TODO: enrich it
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Struct<'ctx, S> {
|
||||
pub structure: S,
|
||||
pub value: StructValue<'ctx>,
|
||||
@ -104,7 +103,7 @@ impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel<S> {
|
||||
|
||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
// TODO: check structure
|
||||
Struct { structure: self.0.clone(), value: value.into_struct_value() }
|
||||
Struct { structure: self.0, value: value.into_struct_value() }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,8 @@ use inkwell::{
|
||||
values::{AnyValueEnum, BasicValue, BasicValueEnum, IntValue},
|
||||
};
|
||||
|
||||
use crate::codegen::CodeGenContext;
|
||||
|
||||
use super::core::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
@ -32,21 +34,85 @@ impl<'ctx> Model<'ctx> for IntModel<'ctx> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> Int<'ctx> {
|
||||
#[must_use]
|
||||
pub fn signed_cast_to_int(
|
||||
self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
target_int: IntModel<'ctx>,
|
||||
name: &str,
|
||||
) -> Int<'ctx> {
|
||||
Int(ctx.builder.build_int_s_extend_or_bit_cast(self.0, target_int.0, name).unwrap())
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn signed_cast_to_fixed<T: IsFixedInt>(
|
||||
self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
target_fixed: T,
|
||||
name: &str,
|
||||
) -> FixedInt<'ctx, T> {
|
||||
FixedInt {
|
||||
int: target_fixed,
|
||||
value: ctx
|
||||
.builder
|
||||
.build_int_s_extend_or_bit_cast(self.0, T::get_int_type(ctx.ctx), name)
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> IntModel<'ctx> {
|
||||
#[must_use]
|
||||
pub fn constant(&self, value: u64) -> Int<'ctx> {
|
||||
Int(self.0.const_int(value, false))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn same_as(&self, other: IntModel<'ctx>) -> bool {
|
||||
// TODO: or `self.0 == other.0` would also work?
|
||||
self.0.get_bit_width() == other.0.get_bit_width()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct FixedIntModel<T>(pub T);
|
||||
|
||||
impl<T: IsFixedInt> FixedIntModel<T> {
|
||||
pub fn to_int_model(self, ctx: &Context) -> IntModel<'_> {
|
||||
IntModel(T::get_int_type(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct FixedInt<'ctx, T: IsFixedInt> {
|
||||
pub int: T,
|
||||
pub value: IntValue<'ctx>,
|
||||
}
|
||||
|
||||
pub trait IsFixedInt: Clone + Default {
|
||||
impl<'ctx, T: IsFixedInt> FixedInt<'ctx, T> {
|
||||
pub fn to_int(self) -> Int<'ctx> {
|
||||
Int(self.value)
|
||||
}
|
||||
|
||||
pub fn signed_cast_to_fixed<R: IsFixedInt>(
|
||||
self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
target_fixed_int: R,
|
||||
name: &str,
|
||||
) -> FixedInt<'ctx, R> {
|
||||
FixedInt {
|
||||
int: target_fixed_int,
|
||||
value: ctx
|
||||
.builder
|
||||
.build_int_s_extend_or_bit_cast(self.value, R::get_int_type(ctx.ctx), name)
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default instance is to enable `FieldBuilder::add_field_auto`
|
||||
pub trait IsFixedInt: Clone + Copy + Default {
|
||||
fn get_int_type(ctx: &Context) -> IntType<'_>;
|
||||
fn get_bit_width() -> u32; // This is required, instead of only relying on get_int_type
|
||||
}
|
||||
@ -67,18 +133,19 @@ impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel<T> {
|
||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
let value = value.into_int_value();
|
||||
assert_eq!(value.get_type().get_bit_width(), T::get_bit_width());
|
||||
FixedInt { int: self.0.clone(), value }
|
||||
FixedInt { int: self.0, value }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, T: IsFixedInt> FixedIntModel<T> {
|
||||
pub fn constant(&self, ctx: &'ctx Context, value: u64) -> FixedInt<'ctx, T> {
|
||||
FixedInt { int: self.0.clone(), value: T::get_int_type(ctx).const_int(value, false) }
|
||||
FixedInt { int: self.0, value: T::get_int_type(ctx).const_int(value, false) }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct Bool;
|
||||
pub type BoolModel = FixedIntModel<Bool>;
|
||||
|
||||
impl IsFixedInt for Bool {
|
||||
fn get_int_type(ctx: &Context) -> IntType<'_> {
|
||||
@ -90,8 +157,9 @@ impl IsFixedInt for Bool {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct Byte;
|
||||
pub type ByteModel = FixedIntModel<Byte>;
|
||||
|
||||
impl IsFixedInt for Byte {
|
||||
fn get_int_type(ctx: &Context) -> IntType<'_> {
|
||||
@ -103,8 +171,9 @@ impl IsFixedInt for Byte {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct Int32;
|
||||
pub type Int32Model = FixedIntModel<Int32>;
|
||||
|
||||
impl IsFixedInt for Int32 {
|
||||
fn get_int_type(ctx: &Context) -> IntType<'_> {
|
||||
@ -116,8 +185,9 @@ impl IsFixedInt for Int32 {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct Int64;
|
||||
pub type Int64Model = FixedIntModel<Int64>;
|
||||
|
||||
impl IsFixedInt for Int64 {
|
||||
fn get_int_type(ctx: &Context) -> IntType<'_> {
|
||||
|
@ -9,12 +9,13 @@ use crate::codegen::CodeGenContext;
|
||||
|
||||
use super::core::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Pointer<'ctx, E: Model<'ctx>> {
|
||||
pub element: E,
|
||||
pub value: PointerValue<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct PointerModel<E>(pub E);
|
||||
|
||||
impl<'ctx, E: Model<'ctx>> ModelValue<'ctx> for Pointer<'ctx, E> {
|
||||
@ -24,7 +25,7 @@ impl<'ctx, E: Model<'ctx>> ModelValue<'ctx> for Pointer<'ctx, E> {
|
||||
}
|
||||
|
||||
impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> {
|
||||
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, val: &E::Value) {
|
||||
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, val: E::Value) {
|
||||
ctx.builder.build_store(self.value, val.get_llvm_value()).unwrap();
|
||||
}
|
||||
|
||||
@ -32,6 +33,30 @@ impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> {
|
||||
let val = ctx.builder.build_load(self.value, name).unwrap();
|
||||
self.element.check_llvm_value(val.as_any_value_enum())
|
||||
}
|
||||
|
||||
pub fn to_opaque(self) -> OpaquePointer<'ctx> {
|
||||
OpaquePointer(self.value)
|
||||
}
|
||||
|
||||
pub fn cast_opaque_to(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
element_type: BasicTypeEnum<'ctx>,
|
||||
name: &str,
|
||||
) -> OpaquePointer<'ctx> {
|
||||
self.to_opaque().cast_opaque_to(ctx, element_type, name)
|
||||
}
|
||||
|
||||
pub fn cast_to<R: Model<'ctx>>(
|
||||
self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
element_model: R,
|
||||
name: &str,
|
||||
) -> Pointer<'ctx, R> {
|
||||
let casted_ptr =
|
||||
self.to_opaque().cast_opaque_to(ctx, element_model.get_llvm_type(ctx.ctx), name).0;
|
||||
Pointer { element: element_model, value: casted_ptr }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel<E> {
|
||||
@ -43,13 +68,15 @@ impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel<E> {
|
||||
|
||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
// TODO: Check get_element_type()? for LLVM 14 at least...
|
||||
Pointer { element: self.0.clone(), value: value.into_pointer_value() }
|
||||
Pointer { element: self.0, value: value.into_pointer_value() }
|
||||
}
|
||||
}
|
||||
|
||||
// A pointer of which the element's model is unknown.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct OpaquePointer<'ctx>(pub PointerValue<'ctx>);
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct OpaquePointerModel;
|
||||
|
||||
impl<'ctx> ModelValue<'ctx> for OpaquePointer<'ctx> {
|
||||
@ -74,19 +101,39 @@ impl<'ctx> Model<'ctx> for OpaquePointerModel {
|
||||
}
|
||||
|
||||
impl<'ctx> OpaquePointer<'ctx> {
|
||||
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>) {
|
||||
pub fn load_opaque(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> BasicValueEnum<'ctx> {
|
||||
ctx.builder.build_load(self.0, name).unwrap()
|
||||
}
|
||||
|
||||
pub fn store_opaque(&self, ctx: &CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>) {
|
||||
ctx.builder.build_store(self.0, value).unwrap();
|
||||
}
|
||||
|
||||
pub fn from_ptr(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) -> Self {
|
||||
let ptr = ctx
|
||||
.builder
|
||||
.build_pointer_cast(
|
||||
ptr,
|
||||
ctx.ctx.i8_type().ptr_type(AddressSpace::default()),
|
||||
"opaque.from_ptr",
|
||||
)
|
||||
.unwrap();
|
||||
OpaquePointer(ptr)
|
||||
#[must_use]
|
||||
pub fn cast_opaque_to(
|
||||
self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
element_llvm_type: BasicTypeEnum<'ctx>,
|
||||
name: &str,
|
||||
) -> OpaquePointer<'ctx> {
|
||||
OpaquePointer(
|
||||
ctx.builder
|
||||
.build_pointer_cast(
|
||||
self.0,
|
||||
element_llvm_type.ptr_type(AddressSpace::default()),
|
||||
name,
|
||||
)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn cast_to<E: Model<'ctx>>(
|
||||
self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
element_model: E,
|
||||
name: &str,
|
||||
) -> Pointer<'ctx, E> {
|
||||
let ptr = self.cast_opaque_to(ctx, element_model.get_llvm_type(ctx.ctx), name).0;
|
||||
Pointer { element: element_model, value: ptr }
|
||||
}
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ impl<'ctx, E: Model<'ctx>> ArraySlice<'ctx, E> {
|
||||
) -> Pointer<'ctx, E> {
|
||||
let element_addr =
|
||||
unsafe { ctx.builder.build_in_bounds_gep(self.pointer.value, &[idx.0], name).unwrap() };
|
||||
Pointer { value: element_addr, element: self.pointer.element.clone() }
|
||||
Pointer { value: element_addr, element: self.pointer.element }
|
||||
}
|
||||
|
||||
pub fn ix<G: CodeGenerator + ?Sized>(
|
||||
|
@ -9,8 +9,10 @@ use crate::{
|
||||
|
||||
use super::{
|
||||
irrt::numpy::{
|
||||
alloca_ndarray_and_init, call_nac3_ndarray_fill_generic, parse_input_shape_arg,
|
||||
NDArrayInitMode, NpArray,
|
||||
ndarray::{
|
||||
alloca_ndarray_and_init, call_nac3_ndarray_fill_generic, NDArrayInitMode, NpArray,
|
||||
},
|
||||
shape::parse_input_shape_arg,
|
||||
},
|
||||
model::*,
|
||||
CodeGenContext, CodeGenerator,
|
||||
@ -59,13 +61,11 @@ where
|
||||
// Allocate fill_value on the stack and give the corresponding stack pointer
|
||||
// to call_nac3_ndarray_fill_generic
|
||||
let fill_value_ptr = ctx.builder.build_alloca(fill_value.get_type(), "fill_value_ptr").unwrap();
|
||||
ctx.builder.build_store(fill_value_ptr, fill_value).unwrap();
|
||||
let fill_value_ptr = OpaquePointer(fill_value_ptr);
|
||||
fill_value_ptr.store_opaque(ctx, fill_value);
|
||||
|
||||
// Opaque-ize fill_value_ptr (turning it into `i8*`) before passing
|
||||
// to call_nac3_ndarray_fill_generic
|
||||
let fill_value_ptr = OpaquePointer::from_ptr(ctx, fill_value_ptr);
|
||||
|
||||
call_nac3_ndarray_fill_generic(generator, ctx, &ndarray_ptr, &fill_value_ptr);
|
||||
let fill_value_ptr = fill_value_ptr.cast_to(ctx, FixedIntModel(Byte), "");
|
||||
call_nac3_ndarray_fill_generic(ctx, ndarray_ptr, fill_value_ptr);
|
||||
|
||||
Ok(ndarray_ptr)
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ use indexmap::IndexMap;
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
types::{BasicMetadataTypeEnum, BasicType},
|
||||
values::{BasicMetadataValueEnum, BasicValue, CallSiteValue},
|
||||
values::{AnyValue, BasicMetadataValueEnum, BasicValue, CallSiteValue},
|
||||
IntPredicate,
|
||||
};
|
||||
use itertools::Either;
|
||||
@ -14,10 +14,17 @@ use strum::IntoEnumIterator;
|
||||
use crate::{
|
||||
codegen::{
|
||||
builtin_fns,
|
||||
classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor},
|
||||
classes::{ProxyValue, RangeValue},
|
||||
expr::destructure_range,
|
||||
irrt::*,
|
||||
numpy::*,
|
||||
irrt::{
|
||||
calculate_len_for_slice_range,
|
||||
numpy::ndarray::{call_nac3_ndarray_len, NpArray},
|
||||
},
|
||||
model::*,
|
||||
numpy::{
|
||||
gen_ndarray_array, gen_ndarray_copy, gen_ndarray_eye, gen_ndarray_fill,
|
||||
gen_ndarray_identity,
|
||||
},
|
||||
numpy_new,
|
||||
stmt::exn_constructor,
|
||||
},
|
||||
@ -1265,7 +1272,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
// type variable
|
||||
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
||||
Box::new(move |ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
||||
numpy_new::gen_ndarray_full(ctx, &obj, fun, &args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
@ -1457,51 +1464,19 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}
|
||||
}
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
// Parse `arg`
|
||||
let sizet = IntModel(generator.get_size_type(ctx.ctx));
|
||||
|
||||
let arg = NDArrayValue::from_ptr_val(
|
||||
arg.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
let ndarray_ptr_model =
|
||||
PointerModel(StructModel(NpArray { sizet }));
|
||||
let ndarray_ptr =
|
||||
ndarray_ptr_model.check_llvm_value(arg.as_any_value_enum());
|
||||
|
||||
let ndims = arg.dim_sizes().size(ctx, generator);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::NE,
|
||||
ndims,
|
||||
llvm_usize.const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap(),
|
||||
"0:TypeError",
|
||||
&format!("{name}() of unsized object", name = prim.name()),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let len = unsafe {
|
||||
arg.dim_sizes().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_zero(),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
if len.get_type().get_bit_width() == 32 {
|
||||
Some(len.into())
|
||||
} else {
|
||||
Some(
|
||||
ctx.builder
|
||||
.build_int_truncate(len, llvm_i32, "len")
|
||||
.map(Into::into)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
// Calculate len
|
||||
// NOTE: Unsized object is asserted in IRRT
|
||||
let len = call_nac3_ndarray_len(generator, ctx, ndarray_ptr);
|
||||
let len = len.signed_cast_to_fixed(ctx, Int32, "len_i32");
|
||||
Some(len.value.as_basic_value_enum())
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user