core: irrt proper ndarray subscript & more

Details:
- improve irrt model
- len() on ndarrays
This commit is contained in:
lyken 2024-07-15 21:05:01 +08:00
parent 29734ce3af
commit 0946bd86ea
19 changed files with 778 additions and 496 deletions

View File

@ -18,6 +18,7 @@ namespace {
ErrorId value_error; ErrorId value_error;
ErrorId assertion_error; ErrorId assertion_error;
ErrorId runtime_error; ErrorId runtime_error;
ErrorId type_error;
}; };
struct ErrorContext { struct ErrorContext {

View File

@ -111,6 +111,21 @@ namespace { namespace ndarray { namespace basic {
void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement, const uint8_t* pvalue) { void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement, const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, ndarray->itemsize); __builtin_memcpy(pelement, pvalue, ndarray->itemsize);
} }
template <typename SizeT>
void len(ErrorContext* errctx, NDArray<SizeT>* ndarray, SliceIndex* dst_length) {
// Error if the ndarray is "unsized" (i.e, ndims == 0)
if (ndarray->ndims == 0) {
// Error copied from python by doing `len(np.zeros(()))`
errctx->set_error(
errctx->error_ids->type_error,
"len() of unsized object"
);
return; // Terminate
}
*dst_length = (SliceIndex) ndarray->shape[0];
}
} } } } } }
extern "C" { extern "C" {
@ -132,6 +147,14 @@ extern "C" {
return nbytes(ndarray); return nbytes(ndarray);
} }
void __nac3_ndarray_len(ErrorContext* errctx, NDArray<int32_t>* ndarray, SliceIndex* dst_len) {
return len(errctx, ndarray, dst_len);
}
void __nac3_ndarray_len64(ErrorContext* errctx, NDArray<int64_t>* ndarray, SliceIndex* dst_len) {
return len(errctx, ndarray, dst_len);
}
void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx, int32_t ndims, int32_t* shape) { void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx, int32_t ndims, int32_t* shape) {
util::assert_shape_no_negative(errctx, ndims, shape); util::assert_shape_no_negative(errctx, ndims, shape);
} }

View File

@ -6,6 +6,8 @@
#include <irrt/error_context.hpp> #include <irrt/error_context.hpp>
namespace { namespace {
typedef uint32_t NumNDSubscriptsType;
typedef uint8_t NDSubscriptType; typedef uint8_t NDSubscriptType;
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0; const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0;
@ -30,16 +32,25 @@ namespace {
namespace { namespace ndarray { namespace subscript { namespace { namespace ndarray { namespace subscript {
namespace util { namespace util {
template<typename SizeT> template<typename SizeT>
SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_subscripts, const NDSubscript* subscripts) { void deduce_ndims_after_slicing(ErrorContext* errctx, SizeT* result, SizeT ndims, SizeT num_ndsubscripts, const NDSubscript* ndsubscripts) {
irrt_assert(num_subscripts <= ndims); if (num_ndsubscripts > ndims) {
// Error copied from python by doing `np.zeros((3, 4))[:, :, :]`
errctx->set_error(
errctx->error_ids->index_error,
"too many indices for array: array is {0}-dimensional, but {1} were indexed",
ndims, num_ndsubscripts
);
return; // Terminate
}
SizeT final_ndims = ndims; SizeT final_ndims = ndims;
for (SizeT i = 0; i < num_subscripts; i++) { for (SizeT i = 0; i < num_ndsubscripts; i++) {
if (subscripts[i].type == INPUT_SUBSCRIPT_TYPE_INDEX) { if (ndsubscripts[i].type == INPUT_SUBSCRIPT_TYPE_INDEX) {
final_ndims--; // An index demotes the rank by 1 final_ndims--; // An index demotes the rank by 1
} }
} }
return final_ndims;
*result = final_ndims;
} }
} }
@ -61,7 +72,7 @@ namespace { namespace ndarray { namespace subscript {
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize` // - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize`
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values // - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
template <typename SizeT> template <typename SizeT>
void subscript(ErrorContext* errctx, SizeT num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) { void subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// REFERENCE CODE (check out `_index_helper` in `__getitem__`): // REFERENCE CODE (check out `_index_helper` in `__getitem__`):
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652 // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
@ -142,11 +153,19 @@ namespace { namespace ndarray { namespace subscript {
extern "C" { extern "C" {
using namespace ndarray::subscript; using namespace ndarray::subscript;
void __nac3_ndarray_subscript(ErrorContext* errctx, int32_t num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) { void __nac3_ndarray_subscript_deduce_ndims_after_slicing(ErrorContext* errctx, int32_t* result, int32_t ndims, int32_t num_ndsubscripts, const NDSubscript* ndsubscripts) {
ndarray::subscript::util::deduce_ndims_after_slicing(errctx, result, ndims, num_ndsubscripts, ndsubscripts);
}
void __nac3_ndarray_subscript_deduce_ndims_after_slicing64(ErrorContext* errctx, int64_t* result, int64_t ndims, int64_t num_ndsubscripts, const NDSubscript* ndsubscripts) {
ndarray::subscript::util::deduce_ndims_after_slicing(errctx, result, ndims, num_ndsubscripts, ndsubscripts);
}
void __nac3_ndarray_subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray); subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
} }
void __nac3_ndarray_subscript64(ErrorContext* errctx, int64_t num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) { void __nac3_ndarray_subscript64(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray); subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
} }
} }

View File

@ -4,11 +4,18 @@ use crate::{
codegen::{ codegen::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType,
ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
}, },
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type, gen_in_range_check, get_llvm_abi_type, get_llvm_type,
irrt::*, irrt::{
numpy::{
ndarray::{self, alloca_ndarray_and_init},
slice::{RustUserSlice, SliceIndexModel},
subscript::{call_nac3_ndarray_subscript, RustNDSubscript},
},
*,
},
llvm_intrinsics::{ llvm_intrinsics::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_memcpy_generic, call_memcpy_generic,
@ -18,14 +25,10 @@ use crate::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var, gen_var,
}, },
CodeGenContext, CodeGenTask, CodeGenerator, CodeGenContext, CodeGenTask, CodeGenerator, Int32,
}, },
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{ toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
typecheck::{ typecheck::{
magic_methods::{Binop, BinopVariant, HasOpInfo}, magic_methods::{Binop, BinopVariant, HasOpInfo},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
@ -34,15 +37,19 @@ use crate::{
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
types::{AnyType, BasicType, BasicTypeEnum}, types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue}, values::{
AnyValue, BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue,
},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::{chain, izip, Either, Itertools}; use itertools::{chain, izip, Either, Itertools};
use nac3parser::ast::{ use nac3parser::ast::{
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Located, Location, Operator,
Unaryop, StrRef, Unaryop,
}; };
use super::{irrt::numpy::ndarray::NpArray, IntModel, Model, Pointer, PointerModel, StructModel};
pub fn get_subst_key( pub fn get_subst_key(
unifier: &mut Unifier, unifier: &mut Unifier,
obj: Option<Type>, obj: Option<Type>,
@ -2130,322 +2137,151 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
/// Generates code for a subscript expression on an `ndarray`. /// Generates code for a subscript expression on an `ndarray`.
/// ///
/// * `ty` - The `Type` of the `NDArray` elements. /// * `elem_ty` - The `Type` of the `NDArray` elements.
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. /// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `v` - The `NDArray` value. /// * `src_ndarray` - The `NDArray` value.
/// * `slice` - The slice expression used to subscript into the `ndarray`. /// * `slice` - The slice expression used to subscript into the `ndarray`.
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type, elem_ty: Type,
ndims: Type, ndims: Type,
v: NDArrayValue<'ctx>, src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
slice: &Expr<Option<Type>>, slice: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
let llvm_i1 = ctx.ctx.bool_type(); // TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { let sizet = src_ndarray.element.0.sizet;
debug_assert_eq!(sizet.0, generator.get_size_type(ctx.ctx)); // If the ndarray's size_type somehow isn't that of `generator.get_size_type()`... there would be a bug
let slice_index_model = SliceIndexModel::default();
// Annoying notes about `slice`
// - `my_array[5]`
// - slice is a `Constant`
// - `my_array[:5]`
// - slice is a `Slice`
// - `my_array[:]`
// - slice is a `Slice`, but lower upper step would all be `Option::None`
// - `my_array[:, :]`
// - slice is now a `Tuple` of two `Slice`-s
//
// In summary:
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
//
// So we first "flatten" out the slice expression
let subscript_exprs = match &slice.node {
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
_ => vec![slice],
};
// Process all subscript expressions in subscripts
let mut rust_ndsubscripts: Vec<RustNDSubscript> = Vec::with_capacity(subscript_exprs.len()); // Not using iterators here because `?` is used here.
for subscript_expr in subscript_exprs {
// NOTE: Currently nac3core's slices do not have an object representation,
// so the code/implementation looks awkward - we have to do pattern matching on the expression
let ndsubscript =
if let ExprKind::Slice { lower: start, upper: stop, step } = &subscript_expr.node {
// Helper function here to deduce code duplication
type ValueExpr = Option<Box<Located<ExprKind<Option<Type>>, Option<Type>>>>;
let mut help = |value_expr: &ValueExpr| -> Result<_, String> {
Ok(match value_expr {
None => None,
Some(value_expr) => Some(
slice_index_model.check_llvm_value(
generator
.gen_expr(ctx, value_expr)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?
.as_any_value_enum(),
),
),
})
};
let start = help(start)?;
let stop = help(stop)?;
let step = help(step)?;
RustNDSubscript::Slice(RustUserSlice { start, stop, step })
} else {
// Anything else that is not a slice (might be illegal values),
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
let index = slice_index_model.check_llvm_value(
generator
.gen_expr(ctx, subscript_expr)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?
.as_any_value_enum(),
);
RustNDSubscript::Index(index)
};
rust_ndsubscripts.push(ndsubscript);
}
// Extract the `ndims` from a `Type` to `i128`
// We *HAVE* to know this statically, this is used to determine
// whether the subscript returns a scalar or an ndarray
let TypeEnum::TLiteral { values: ndims_values, .. } = &*ctx.unifier.get_ty_immutable(ndims)
else {
unreachable!() unreachable!()
}; };
assert_eq!(ndims_values.len(), 1);
let src_ndims = i128::try_from(ndims_values[0].clone()).unwrap();
let ndims = values // Check for "too many indices for array: array is ..." error
.iter() if src_ndims < rust_ndsubscripts.len() as i128 {
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) ctx.make_assert(
.collect::<Result<Vec<_>, _>>()
.map_err(|val| {
format!(
"Expected non-negative literal for ndarray.ndims, got {}",
i128::try_from(val).unwrap()
)
})?;
assert!(!ndims.is_empty());
// The number of dimensions subscripted by the index expression.
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
// dimension will remove a dimension.
let subscripted_dims = match &slice.node {
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
if let ExprKind::Slice { .. } = &value_subexpr.node {
acc
} else {
acc + 1
}
}),
ExprKind::Slice { .. } => 0,
_ => 1,
};
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None,
);
let ndarray_ty =
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
// Check that len is non-zero
let len = v.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(),
"0:IndexError",
"too many indices for array: array is {0}-dimensional but 1 were indexed",
[Some(len), None, None],
slice.location,
);
// Normalizes a possibly-negative index to its corresponding positive index
let normalize_index = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
dim: u64| {
gen_if_else_expr_callback(
generator, generator,
ctx, ctx.ctx.bool_type().const_int(1, false),
|_, ctx| { "0:IndexError",
Ok(ctx "too many indices for array: array is {0}-dimensional, but {1} were indexed",
.builder [None, None, None],
.build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "") ctx.current_loc,
.unwrap()) );
}, }
|_, _| Ok(Some(index)),
|generator, ctx| {
let llvm_i32 = ctx.ctx.i32_type();
let len = unsafe { // Statically determine `dst_ndims`
v.dim_sizes().get_typed_unchecked( let dst_ndims =
ctx, RustNDSubscript::deduce_ndims_after_slicing(&rust_ndsubscripts, src_ndims as i32);
generator,
&llvm_usize.const_int(dim, true),
None,
)
};
let index = ctx // Prepare dst_ndarray
.builder let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty);
.build_int_add( let dst_ndarray = alloca_ndarray_and_init(
len, generator,
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(), ctx,
"", elem_llvm_ty,
) ndarray::NDArrayInitMode::NDims { ndims: sizet.constant(dst_ndims as u64) },
.unwrap(); "subndarray",
)
.unwrap();
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) // Prepare the subscripts
}, let ndsubscript_array = RustNDSubscript::alloca_subscripts(ctx, &rust_ndsubscripts);
)
.map(|v| v.map(BasicValueEnum::into_int_value))
};
// Converts a slice expression into a slice-range tuple // NOTE: IRRT does check for indexing errors
let expr_to_slice = |generator: &mut G, call_nac3_ndarray_subscript(
ctx: &mut CodeGenContext<'ctx, '_>, generator,
node: &ExprKind<Option<Type>>, ctx,
dim: u64| { ndsubscript_array.num_elements.signed_cast_to_fixed(ctx, Int32, "num_ndsubscripts"),
match node { ndsubscript_array.pointer,
ExprKind::Constant { value: Constant::Int(v), .. } => { src_ndarray,
let Some(index) = dst_ndarray,
normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)? );
else {
return Ok(None);
};
Ok(Some((index, index, llvm_i32.const_int(1, true)))) // ...and return the result, with two cases
} let result_llvm_value = if dst_ndims == 0 {
// 1) ndims == 0 (this happens when you do `np.zerps((3, 4))[1, 1]`), return *THE ELEMENT*
ExprKind::Slice { lower, upper, step } => { let element_ptr = dst_ndarray.gep(ctx, |f| f.data).load(ctx, "pelement"); // `*data` points to the first element by definition
let dim_sz = unsafe { element_ptr.cast_opaque_to(ctx, elem_llvm_ty, "").load_opaque(ctx, "element")
v.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, false),
None,
)
};
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz)
}
_ => {
let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) };
let index = index
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
return Ok(None);
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
}
};
let make_indices_arr = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>|
-> Result<_, String> {
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(elts.len() as u64, false),
None,
)?;
for (i, elt) in elts.iter().enumerate() {
let Some(index) = generator.gen_expr(ctx, elt)? else {
return Ok(None);
};
let index = index
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
return Ok(None);
};
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
None,
)
};
ctx.builder.build_store(store_ptr, index).unwrap();
}
Some(index_addr)
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(1u64, false),
None,
)?;
let index =
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder.build_store(store_ptr, index).unwrap();
Some(index_addr)
} else {
None
})
};
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
v.data().get(ctx, generator, &index_addr, None).into()
} else { } else {
match &slice.node { // 2) ndims > 0 (other cases), return subndarray
ExprKind::Tuple { elts, .. } => { dst_ndarray.value.as_basic_value_enum()
let slices = elts };
.iter() Ok(Some(ValueEnum::Dynamic(result_llvm_value)))
.enumerate()
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
.collect::<Result<Vec<_>, _>>()?;
if slices.len() < elts.len() {
return Ok(None);
}
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
}
ExprKind::Slice { .. } => {
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
return Ok(None);
};
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
}
_ => {
// Accessing an element from a multi-dimensional `ndarray`
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray =
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
.unwrap(),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr,
ctx.builder
.build_int_mul(
ndarray_num_elems,
llvm_ndarray_data_t.size_of().unwrap(),
"",
)
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
ndarray.as_base_value().into()
}
}
}))
} }
/// See [`CodeGenerator::gen_expr`]. /// See [`CodeGenerator::gen_expr`].
@ -3088,17 +2924,26 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
} }
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); let (elem_ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
let v = if let Some(v) = generator.gen_expr(ctx, value)? { let ndarray_ptr = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? let sizet = IntModel(generator.get_size_type(ctx.ctx));
.into_pointer_value() let ndarray_ptr_model = PointerModel(StructModel(NpArray { sizet }));
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
ndarray_ptr_model.check_llvm_value(v.as_any_value_enum())
} else { } else {
return Ok(None); return Ok(None);
}; };
let v = NDArrayValue::from_ptr_val(v, usize, None);
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); return gen_ndarray_subscript_expr(
generator,
ctx,
*elem_ty,
*ndims,
ndarray_ptr,
slice,
);
} }
TypeEnum::TTuple { .. } => { TypeEnum::TTuple { .. } => {
let index: u32 = let index: u32 =

View File

@ -7,7 +7,7 @@ pub struct StrFields<'ctx> {
pub length: Field<IntModel<'ctx>>, pub length: Field<IntModel<'ctx>>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, Copy)]
pub struct Str<'ctx> { pub struct Str<'ctx> {
pub sizet: IntModel<'ctx>, pub sizet: IntModel<'ctx>,
} }
@ -33,8 +33,10 @@ pub struct ErrorIdsFields {
pub value_error: Field<FixedIntModel<ErrorId>>, pub value_error: Field<FixedIntModel<ErrorId>>,
pub assertion_error: Field<FixedIntModel<ErrorId>>, pub assertion_error: Field<FixedIntModel<ErrorId>>,
pub runtime_error: Field<FixedIntModel<ErrorId>>, pub runtime_error: Field<FixedIntModel<ErrorId>>,
pub type_error: Field<FixedIntModel<ErrorId>>,
} }
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub struct ErrorIds; pub struct ErrorIds;
impl<'ctx> IsStruct<'ctx> for ErrorIds { impl<'ctx> IsStruct<'ctx> for ErrorIds {
@ -50,6 +52,7 @@ impl<'ctx> IsStruct<'ctx> for ErrorIds {
value_error: builder.add_field_auto("value_error"), value_error: builder.add_field_auto("value_error"),
assertion_error: builder.add_field_auto("assertion_error"), assertion_error: builder.add_field_auto("assertion_error"),
runtime_error: builder.add_field_auto("runtime_error"), runtime_error: builder.add_field_auto("runtime_error"),
type_error: builder.add_field_auto("type_error"),
} }
} }
} }
@ -61,7 +64,8 @@ pub struct ErrorContextFields {
pub param2: Field<FixedIntModel<Int64>>, pub param2: Field<FixedIntModel<Int64>>,
pub param3: Field<FixedIntModel<Int64>>, pub param3: Field<FixedIntModel<Int64>>,
} }
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub struct ErrorContext; pub struct ErrorContext;
impl<'ctx> IsStruct<'ctx> for ErrorContext { impl<'ctx> IsStruct<'ctx> for ErrorContext {
@ -92,46 +96,47 @@ fn build_error_ids<'ctx>(ctx: &CodeGenContext<'ctx, '_>) -> Pointer<'ctx, Struct
let get_string_id = let get_string_id =
|string_id| i32_model.constant(ctx.ctx, ctx.resolver.get_string_id(string_id) as u64); |string_id| i32_model.constant(ctx.ctx, ctx.resolver.get_string_id(string_id) as u64);
error_ids.gep(ctx, |f| f.index_error).store(ctx, &get_string_id("0:IndexError")); error_ids.gep(ctx, |f| f.index_error).store(ctx, get_string_id("0:IndexError"));
error_ids.gep(ctx, |f| f.value_error).store(ctx, &get_string_id("0:ValueError")); error_ids.gep(ctx, |f| f.value_error).store(ctx, get_string_id("0:ValueError"));
error_ids.gep(ctx, |f| f.assertion_error).store(ctx, &get_string_id("0:AssertionError")); error_ids.gep(ctx, |f| f.assertion_error).store(ctx, get_string_id("0:AssertionError"));
error_ids.gep(ctx, |f| f.runtime_error).store(ctx, &get_string_id("0:RuntimeError")); error_ids.gep(ctx, |f| f.runtime_error).store(ctx, get_string_id("0:RuntimeError"));
error_ids.gep(ctx, |f| f.type_error).store(ctx, get_string_id("0:TypeError"));
error_ids error_ids
} }
pub fn call_nac3_error_context_initialize<'ctx>( pub fn call_nac3_error_context_initialize<'ctx>(
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
perrctx: &Pointer<'ctx, StructModel<ErrorContext>>, perrctx: Pointer<'ctx, StructModel<ErrorContext>>,
perror_ids: &Pointer<'ctx, StructModel<ErrorIds>>, perror_ids: Pointer<'ctx, StructModel<ErrorIds>>,
) { ) {
FunctionBuilder::begin(ctx, "__nac3_error_context_initialize") FunctionBuilder::begin(ctx, "__nac3_error_context_initialize")
.arg("errctx", &PointerModel(StructModel(ErrorContext)), perrctx) .arg("errctx", PointerModel(StructModel(ErrorContext)), perrctx)
.arg("error_ids", &PointerModel(StructModel(ErrorIds)), perror_ids) .arg("error_ids", PointerModel(StructModel(ErrorIds)), perror_ids)
.returning_void(); .returning_void();
} }
pub fn call_nac3_error_context_has_no_error<'ctx>( pub fn call_nac3_error_context_has_no_error<'ctx>(
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
errctx: &Pointer<'ctx, StructModel<ErrorContext>>, errctx: Pointer<'ctx, StructModel<ErrorContext>>,
) -> FixedInt<'ctx, Bool> { ) -> FixedInt<'ctx, Bool> {
FunctionBuilder::begin(ctx, "__nac3_error_context_has_no_error") FunctionBuilder::begin(ctx, "__nac3_error_context_has_no_error")
.arg("errctx", &PointerModel(StructModel(ErrorContext)), errctx) .arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
.returning("has_error", &FixedIntModel(Bool)) .returning("has_error", FixedIntModel(Bool))
} }
pub fn call_nac3_error_context_get_error_str<'ctx>( pub fn call_nac3_error_context_get_error_str<'ctx>(
sizet: IntModel<'ctx>, sizet: IntModel<'ctx>,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
errctx: &Pointer<'ctx, StructModel<ErrorContext>>, errctx: Pointer<'ctx, StructModel<ErrorContext>>,
dst_str: &Pointer<'ctx, StructModel<Str<'ctx>>>, dst_str: Pointer<'ctx, StructModel<Str<'ctx>>>,
) { ) {
FunctionBuilder::begin( FunctionBuilder::begin(
ctx, ctx,
&get_sized_dependent_function_name(sizet, "__nac3_error_context_get_error_str"), &get_sized_dependent_function_name(sizet, "__nac3_error_context_get_error_str"),
) )
.arg("errctx", &PointerModel(StructModel(ErrorContext)), errctx) .arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
.arg("dst_str", &PointerModel(StructModel(Str { sizet })), dst_str) .arg("dst_str", PointerModel(StructModel(Str { sizet })), dst_str)
.returning_void(); .returning_void();
} }
@ -140,20 +145,20 @@ pub fn prepare_error_context<'ctx>(
) -> Pointer<'ctx, StructModel<ErrorContext>> { ) -> Pointer<'ctx, StructModel<ErrorContext>> {
let error_ids = build_error_ids(ctx); let error_ids = build_error_ids(ctx);
let errctx_ptr = StructModel(ErrorContext).alloca(ctx, "errctx"); let errctx_ptr = StructModel(ErrorContext).alloca(ctx, "errctx");
call_nac3_error_context_initialize(ctx, &errctx_ptr, &error_ids); call_nac3_error_context_initialize(ctx, errctx_ptr, error_ids);
errctx_ptr errctx_ptr
} }
pub fn check_error_context<'ctx, G: CodeGenerator + ?Sized>( pub fn check_error_context<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
errctx_ptr: &Pointer<'ctx, StructModel<ErrorContext>>, errctx_ptr: Pointer<'ctx, StructModel<ErrorContext>>,
) { ) {
let sizet = IntModel(generator.get_size_type(ctx.ctx)); let sizet = IntModel(generator.get_size_type(ctx.ctx));
let has_error = call_nac3_error_context_has_no_error(ctx, errctx_ptr); let has_error = call_nac3_error_context_has_no_error(ctx, errctx_ptr);
let pstr = StructModel(Str { sizet }).alloca(ctx, "error_str"); let pstr = StructModel(Str { sizet }).alloca(ctx, "error_str");
call_nac3_error_context_get_error_str(sizet, ctx, errctx_ptr, &pstr); call_nac3_error_context_get_error_str(sizet, ctx, errctx_ptr, pstr);
let error_id = errctx_ptr.gep(ctx, |f| f.error_id).load(ctx, "error_id"); let error_id = errctx_ptr.gep(ctx, |f| f.error_id).load(ctx, "error_id");
let error_str = pstr.load(ctx, "error_str"); let error_str = pstr.load(ctx, "error_str");
@ -176,7 +181,7 @@ pub fn call_nac3_dummy_raise<G: CodeGenerator + ?Sized>(
) { ) {
let errctx = prepare_error_context(ctx); let errctx = prepare_error_context(ctx);
FunctionBuilder::begin(ctx, "__nac3_error_dummy_raise") FunctionBuilder::begin(ctx, "__nac3_error_dummy_raise")
.arg("errctx", &PointerModel(StructModel(ErrorContext)), &errctx) .arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
.returning_void(); .returning_void();
check_error_context(generator, ctx, &errctx); check_error_context(generator, ctx, errctx);
} }

View File

@ -1,5 +1,4 @@
pub mod ndarray; pub mod ndarray;
pub mod shape; pub mod shape;
pub mod slice;
pub use ndarray::*; pub mod subscript;
pub use shape::*;

View File

@ -9,10 +9,13 @@ use crate::codegen::{
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
use super::Producer; use super::{
shape::Producer,
slice::{SliceIndex, SliceIndexModel},
};
pub struct NpArrayFields<'ctx> { pub struct NpArrayFields<'ctx> {
pub data: Field<OpaquePointerModel>, pub data: Field<PointerModel<ByteModel>>,
pub itemsize: Field<IntModel<'ctx>>, pub itemsize: Field<IntModel<'ctx>>,
pub ndims: Field<IntModel<'ctx>>, pub ndims: Field<IntModel<'ctx>>,
pub shape: Field<PointerModel<IntModel<'ctx>>>, pub shape: Field<PointerModel<IntModel<'ctx>>>,
@ -63,7 +66,7 @@ pub fn alloca_ndarray<'ctx, G>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_type: BasicTypeEnum<'ctx>, elem_type: BasicTypeEnum<'ctx>,
ndims: &Int<'ctx>, ndims: Int<'ctx>,
name: &str, name: &str,
) -> Result<Pointer<'ctx, StructModel<NpArray<'ctx>>>, String> ) -> Result<Pointer<'ctx, StructModel<NpArray<'ctx>>>, String>
where where
@ -78,19 +81,16 @@ where
ndarray_ptr.gep(ctx, |f| f.ndims).store(ctx, ndims); ndarray_ptr.gep(ctx, |f| f.ndims).store(ctx, ndims);
// Set itemsize // Set itemsize
let itemsize = elem_type.size_of().unwrap(); let itemsize = Int(elem_type.size_of().unwrap());
let itemsize = ndarray_ptr.gep(ctx, |f| f.itemsize).store(ctx, itemsize.signed_cast_to_int(ctx, sizet, ""));
ctx.builder.build_int_s_extend_or_bit_cast(itemsize, sizet.0, "itemsize").unwrap();
ndarray_ptr.gep(ctx, |f| f.itemsize).store(ctx, &Int(itemsize));
// Allocate and set shape // Allocate and set shape
let shape_ptr = ctx.builder.build_array_alloca(sizet.0, ndims.0, "shape").unwrap(); let shape_array = sizet.array_alloca(ctx, ndims, "shape");
ndarray_ptr.gep(ctx, |f| f.shape).store(ctx, &Pointer { element: sizet, value: shape_ptr }); ndarray_ptr.gep(ctx, |f| f.shape).store(ctx, shape_array.pointer);
// .store(ctx, &Pointer { addressee_optic: IntLens(sizet), address: shape_ptr });
// Allocate and set strides // Allocate and set strides
let strides_ptr = ctx.builder.build_array_alloca(sizet.0, ndims.0, "strides").unwrap(); let strides_array = sizet.array_alloca(ctx, ndims, "strides");
ndarray_ptr.gep(ctx, |f| f.strides).store(ctx, &Pointer { element: sizet, value: strides_ptr }); ndarray_ptr.gep(ctx, |f| f.strides).store(ctx, strides_array.pointer);
Ok(ndarray_ptr) Ok(ndarray_ptr)
} }
@ -115,12 +115,12 @@ where
// It is implemented verbosely in order to make the initialization modes super clear in their intent. // It is implemented verbosely in order to make the initialization modes super clear in their intent.
match init_mode { match init_mode {
NDArrayInitMode::NDims { ndims } => { NDArrayInitMode::NDims { ndims } => {
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, &ndims, name)?; let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
Ok(ndarray_ptr) Ok(ndarray_ptr)
} }
NDArrayInitMode::Shape { shape } => { NDArrayInitMode::Shape { shape } => {
let ndims = shape.count; let ndims = shape.count;
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, &ndims, name)?; let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
// Fill `ndarray.shape` // Fill `ndarray.shape`
(shape.write_to_array)(generator, ctx, &ndarray_ptr.shape_slice(ctx))?; (shape.write_to_array)(generator, ctx, &ndarray_ptr.shape_slice(ctx))?;
@ -129,8 +129,8 @@ where
call_nac3_ndarray_util_assert_shape_no_negative( call_nac3_ndarray_util_assert_shape_no_negative(
generator, generator,
ctx, ctx,
&ndims, ndims,
&ndarray_ptr.gep(ctx, |f| f.shape).load(ctx, "shape"), ndarray_ptr.gep(ctx, |f| f.shape).load(ctx, "shape"),
); );
// NOTE: DO NOT DO `set_strides_by_shape` HERE. // NOTE: DO NOT DO `set_strides_by_shape` HERE.
@ -140,7 +140,7 @@ where
} }
NDArrayInitMode::ShapeAndAllocaData { shape } => { NDArrayInitMode::ShapeAndAllocaData { shape } => {
let ndims = shape.count; let ndims = shape.count;
let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, &ndims, name)?; let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
// Fill `ndarray.shape` // Fill `ndarray.shape`
(shape.write_to_array)(generator, ctx, &ndarray_ptr.shape_slice(ctx))?; (shape.write_to_array)(generator, ctx, &ndarray_ptr.shape_slice(ctx))?;
@ -149,26 +149,22 @@ where
call_nac3_ndarray_util_assert_shape_no_negative( call_nac3_ndarray_util_assert_shape_no_negative(
generator, generator,
ctx, ctx,
&ndims, ndims,
&ndarray_ptr.gep(ctx, |f| f.shape).load(ctx, "shape"), ndarray_ptr.gep(ctx, |f| f.shape).load(ctx, "shape"),
); );
// Now we populate `ndarray.data` by alloca-ing. // Now we populate `ndarray.data` by alloca-ing.
// But first, we need to know the size of the ndarray to know how many elements to alloca, // But first, we need to know the size of the ndarray to know how many elements to alloca,
// since calculating nbytes of an ndarray requires `ndarray.shape` to be set. // since calculating nbytes of an ndarray requires `ndarray.shape` to be set.
let ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, &ndarray_ptr); let ndarray_nbytes = call_nac3_ndarray_nbytes(ctx, ndarray_ptr);
// Alloca `data` and assign it to `ndarray.data` // Alloca `data` and assign it to `ndarray.data`
let data_ptr = OpaquePointer( let data_array = FixedIntModel(Byte).array_alloca(ctx, ndarray_nbytes, "data");
ctx.builder ndarray_ptr.gep(ctx, |f| f.data).store(ctx, data_array.pointer);
.build_array_alloca(ctx.ctx.i8_type(), ndarray_nbytes.0, "data")
.unwrap(),
);
ndarray_ptr.gep(ctx, |f| f.data).store(ctx, &data_ptr);
// Finally, do `set_strides_by_shape` // Finally, do `set_strides_by_shape`
// Check out https://ajcr.net/stride-guide-part-1/ to see what numpy "strides" are. // Check out https://ajcr.net/stride-guide-part-1/ to see what numpy "strides" are.
call_nac3_ndarray_set_strides_by_shape(generator, ctx, &ndarray_ptr); call_nac3_ndarray_set_strides_by_shape(ctx, ndarray_ptr);
Ok(ndarray_ptr) Ok(ndarray_ptr)
} }
@ -178,8 +174,8 @@ where
fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ndims: &Int<'ctx>, ndims: Int<'ctx>,
shape_ptr: &Pointer<'ctx, IntModel<'ctx>>, shape_ptr: Pointer<'ctx, IntModel<'ctx>>,
) { ) {
let sizet = IntModel(generator.get_size_type(ctx.ctx)); let sizet = IntModel(generator.get_size_type(ctx.ctx));
@ -188,53 +184,71 @@ fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Siz
ctx, ctx,
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_util_assert_shape_no_negative"), &get_sized_dependent_function_name(sizet, "__nac3_ndarray_util_assert_shape_no_negative"),
) )
.arg("errctx", &PointerModel(StructModel(ErrorContext)), &errctx) .arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
.arg("ndims", &sizet, ndims) .arg("ndims", sizet, ndims)
.arg("shape", &PointerModel(sizet), shape_ptr) .arg("shape", PointerModel(sizet), shape_ptr)
.returning_void(); .returning_void();
check_error_context(generator, ctx, &errctx); check_error_context(generator, ctx, errctx);
} }
fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: &Pointer<'ctx, StructModel<NpArray<'ctx>>>, ndarray_ptr: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
) { ) {
let sizet = IntModel(generator.get_size_type(ctx.ctx)); let sizet = ndarray_ptr.element.0.sizet;
FunctionBuilder::begin( FunctionBuilder::begin(
ctx, ctx,
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_set_strides_by_shape"), &get_sized_dependent_function_name(sizet, "__nac3_ndarray_set_strides_by_shape"),
) )
.arg("ndarray", &PointerModel(StructModel(NpArray { sizet })), ndarray_ptr) .arg("ndarray", PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
.returning_void(); .returning_void();
} }
fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_nbytes<'ctx>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: &Pointer<'ctx, StructModel<NpArray<'ctx>>>, ndarray_ptr: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
) -> Int<'ctx> { ) -> Int<'ctx> {
let sizet = IntModel(generator.get_size_type(ctx.ctx)); let sizet = ndarray_ptr.element.0.sizet;
FunctionBuilder::begin(ctx, &get_sized_dependent_function_name(sizet, "__nac3_ndarray_nbytes")) FunctionBuilder::begin(ctx, &get_sized_dependent_function_name(sizet, "__nac3_ndarray_nbytes"))
.arg("ndarray", &PointerModel(StructModel(NpArray { sizet })), ndarray_ptr) .arg("ndarray", PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
.returning("nbytes", &sizet) .returning("nbytes", sizet)
} }
pub fn call_nac3_ndarray_fill_generic<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_fill_generic<'ctx>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: &Pointer<'ctx, StructModel<NpArray<'ctx>>>, ndarray_ptr: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
fill_value_ptr: &OpaquePointer<'ctx>, fill_value_ptr: Pointer<'ctx, ByteModel>,
) { ) {
let sizet = IntModel(generator.get_size_type(ctx.ctx)); let sizet = ndarray_ptr.element.0.sizet;
FunctionBuilder::begin( FunctionBuilder::begin(
ctx, ctx,
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_fill_generic"), &get_sized_dependent_function_name(sizet, "__nac3_ndarray_fill_generic"),
) )
.arg("ndarray", &PointerModel(StructModel(NpArray { sizet })), ndarray_ptr) .arg("ndarray", PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
.arg("pvalue", &OpaquePointerModel, fill_value_ptr) .arg("pvalue", PointerModel(FixedIntModel(Byte)), fill_value_ptr)
.returning_void(); .returning_void();
} }
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
) -> SliceIndex<'ctx> {
let sizet = ndarray_ptr.element.0.sizet;
let slice_index_model = SliceIndexModel::default();
let dst_len = slice_index_model.alloca(ctx, "dst_len");
let errctx = prepare_error_context(ctx);
FunctionBuilder::begin(ctx, &get_sized_dependent_function_name(sizet, "__nac3_ndarray_len"))
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
.arg("ndarray", PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
.arg("dst_len", PointerModel(slice_index_model), dst_len)
.returning_void();
check_error_context(generator, ctx, errctx);
dst_len.load(ctx, "len")
}

View File

@ -85,7 +85,7 @@ where
.unwrap(); .unwrap();
// Write // Write
dst_array.ix(generator, ctx, axis, "dim").store(ctx, &Int(dim)); dst_array.ix(generator, ctx, axis, "dim").store(ctx, Int(dim));
Ok(()) Ok(())
}, },
incr_val, incr_val,
@ -127,7 +127,7 @@ where
// Write // Write
dst_array dst_array
.ix(generator, ctx, sizet.constant(axis as u64), "dim") .ix(generator, ctx, sizet.constant(axis as u64), "dim")
.store(ctx, &Int(dim)); .store(ctx, Int(dim));
} }
Ok(()) Ok(())
}), }),
@ -151,7 +151,7 @@ where
.unwrap(); .unwrap();
// Set shape[0] = shape_int // Set shape[0] = shape_int
dst_array.ix(generator, ctx, sizet.constant(0), "dim").store(ctx, &Int(dim)); dst_array.ix(generator, ctx, sizet.constant(0), "dim").store(ctx, Int(dim));
Ok(()) Ok(())
}), }),

View File

@ -0,0 +1,86 @@
use crate::codegen::{model::*, CodeGenContext};
// nac3core's slicing index/length values are always int32_t
pub type SliceIndexInt = Int32;
pub type SliceIndexModel = FixedIntModel<SliceIndexInt>;
pub type SliceIndex<'ctx> = FixedInt<'ctx, SliceIndexInt>;
#[derive(Debug, Clone)]
pub struct UserSliceFields {
pub start_defined: Field<BoolModel>,
pub start: Field<SliceIndexModel>,
pub stop_defined: Field<BoolModel>,
pub stop: Field<SliceIndexModel>,
pub step_defined: Field<BoolModel>,
pub step: Field<SliceIndexModel>,
}
#[derive(Debug, Clone, Copy, Default)]
pub struct UserSlice;
impl<'ctx> IsStruct<'ctx> for UserSlice {
type Fields = UserSliceFields;
fn struct_name(&self) -> &'static str {
"UserSlice"
}
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
Self::Fields {
start_defined: builder.add_field_auto("start_defined"),
start: builder.add_field_auto("start"),
stop_defined: builder.add_field_auto("stop_defined"),
stop: builder.add_field_auto("stop"),
step_defined: builder.add_field_auto("step_defined"),
step: builder.add_field_auto("step"),
}
}
}
#[derive(Debug, Clone)]
pub struct RustUserSlice<'ctx> {
pub start: Option<SliceIndex<'ctx>>,
pub stop: Option<SliceIndex<'ctx>>,
pub step: Option<SliceIndex<'ctx>>,
}
impl<'ctx> RustUserSlice<'ctx> {
// Set the values of an LLVM UserSlice
// in the format of Python's `slice()`
pub fn write_to_user_slice(
&self,
ctx: &CodeGenContext<'ctx, '_>,
dst_slice_ptr: Pointer<'ctx, StructModel<UserSlice>>,
) {
// TODO: make this neater, with a helper lambda?
let bool_model = BoolModel::default();
let false_ = bool_model.constant(ctx.ctx, 0);
let true_ = bool_model.constant(ctx.ctx, 1);
match self.start {
Some(start) => {
dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start);
}
None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_),
}
match self.stop {
Some(stop) => {
dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop);
}
None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_),
}
match self.step {
Some(step) => {
dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step);
}
None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_),
}
}
}

View File

@ -0,0 +1,181 @@
use crate::codegen::{
irrt::{
error_context::{check_error_context, prepare_error_context, ErrorContext},
util::{get_sized_dependent_function_name, FunctionBuilder},
},
model::*,
CodeGenContext, CodeGenerator,
};
use super::{
ndarray::NpArray,
slice::{RustUserSlice, SliceIndex, SliceIndexModel, UserSlice},
};
#[derive(Debug, Clone, Copy)]
pub struct NDSubscriptFields {
pub type_: Field<ByteModel>, // Defined to be uint8_t in IRRT
pub data: Field<PointerModel<ByteModel>>,
}
#[derive(Debug, Clone, Copy, Default)]
pub struct NDSubscript;
impl<'ctx> IsStruct<'ctx> for NDSubscript {
type Fields = NDSubscriptFields;
fn struct_name(&self) -> &'static str {
"NDSubscript"
}
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
Self::Fields { type_: builder.add_field_auto("type"), data: builder.add_field_auto("data") }
}
}
// An enum variant to store the content
// and type of an NDSubscript in high level.
#[derive(Debug, Clone)]
pub enum RustNDSubscript<'ctx> {
Index(SliceIndex<'ctx>),
Slice(RustUserSlice<'ctx>),
}
impl<'ctx> RustNDSubscript<'ctx> {
fn irrt_subscript_id(&self) -> u64 {
// Defined in IRRT
match self {
RustNDSubscript::Index(_) => 0,
RustNDSubscript::Slice(_) => 1,
}
}
fn write_to_ndsubscript(
&self,
ctx: &CodeGenContext<'ctx, '_>,
dst_ndsubscript_ptr: Pointer<'ctx, StructModel<NDSubscript>>,
) {
let byte_model = ByteModel::default();
let slice_index_model = SliceIndexModel::default();
let user_slice_model = StructModel(UserSlice);
// Set `dst_ndsubscript_ptr->type`
dst_ndsubscript_ptr
.gep(ctx, |f| f.type_)
.store(ctx, byte_model.constant(ctx.ctx, self.irrt_subscript_id()));
// Set `dst_ndsubscript_ptr->data`
let data = match self {
RustNDSubscript::Index(in_index) => {
let index_ptr = slice_index_model.alloca(ctx, "index");
index_ptr.store(ctx, *in_index);
index_ptr.cast_to(ctx, FixedIntModel(Byte), "")
}
RustNDSubscript::Slice(in_rust_slice) => {
let user_slice_ptr = user_slice_model.alloca(ctx, "user_slice");
in_rust_slice.write_to_user_slice(ctx, user_slice_ptr);
user_slice_ptr.cast_to(ctx, FixedIntModel(Byte), "")
}
};
dst_ndsubscript_ptr.gep(ctx, |f| f.data).store(ctx, data);
}
// Allocate an array of subscripts onto the stack and return its stack pointer
pub fn alloca_subscripts(
ctx: &CodeGenContext<'ctx, '_>,
subscripts: &[RustNDSubscript<'ctx>],
) -> ArraySlice<'ctx, StructModel<NDSubscript>> {
let index_model = Int32Model::default();
let ndsubscript_model = StructModel(NDSubscript);
let ndsubscript_array = ndsubscript_model.array_alloca(
ctx,
index_model.constant(ctx.ctx, subscripts.len() as u64).to_int(),
"ndsubscripts",
);
for (i, rust_ndsubscript) in subscripts.iter().enumerate() {
let ndsubscript_ptr = ndsubscript_array.ix_unchecked(
ctx,
index_model.constant(ctx.ctx, i as u64).to_int(),
"",
);
rust_ndsubscript.write_to_ndsubscript(ctx, ndsubscript_ptr);
}
ndsubscript_array
}
#[must_use]
pub fn deduce_ndims_after_slicing(slices: &[RustNDSubscript], original_ndims: i32) -> i32 {
let mut final_ndims: i32 = original_ndims;
for slice in slices {
match slice {
RustNDSubscript::Index(_) => {
// Index subscripts demotes the rank by 1
final_ndims -= 1;
}
RustNDSubscript::Slice(_) => {
// Nothing
}
}
}
final_ndims
}
}
pub fn call_nac3_ndarray_subscript_deduce_ndims_after_slicing<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
sizet: IntModel<'ctx>,
ndims: Int<'ctx>,
num_ndsubscripts: Int<'ctx>,
ndsubscripts: Pointer<'ctx, StructModel<NDSubscript>>,
) -> Int<'ctx> {
let result = sizet.alloca(ctx, "result");
let errctx_ptr = prepare_error_context(ctx);
FunctionBuilder::begin(
ctx,
&get_sized_dependent_function_name(
sizet,
"__nac3_ndarray_subscript_deduce_ndims_after_slicing",
),
)
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx_ptr)
.arg("result", PointerModel(sizet), result)
.arg("ndims", sizet, ndims)
.arg("num_ndsubscripts", sizet, num_ndsubscripts)
.arg("ndsubscripts", PointerModel(StructModel(NDSubscript)), ndsubscripts)
.returning_void();
check_error_context(generator, ctx, errctx_ptr);
result.load(ctx, "final_ndims")
}
pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
num_subscripts: FixedInt<'ctx, Int32>,
subscripts: Pointer<'ctx, StructModel<NDSubscript>>,
src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
dst_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
) {
let sizet = src_ndarray.element.0.sizet;
assert!(sizet.same_as(dst_ndarray.element.0.sizet)); // SizeT of src_ndarray and dst_ndarray must match
let errctx_ptr = prepare_error_context(ctx);
FunctionBuilder::begin(
ctx,
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_subscript"),
)
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx_ptr)
.arg("num_subscripts", FixedIntModel(Int32), num_subscripts)
.arg("subscripts", PointerModel(StructModel(NDSubscript)), subscripts)
.arg("src_ndarray", PointerModel(StructModel(NpArray { sizet })), src_ndarray)
.arg("dst_ndarray", PointerModel(StructModel(NpArray { sizet })), dst_ndarray)
.returning_void();
check_error_context(generator, ctx, errctx_ptr);
}

View File

@ -45,13 +45,13 @@ impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> {
// The name is for self-documentation // The name is for self-documentation
#[must_use] #[must_use]
pub fn arg<M: Model<'ctx>>(mut self, _name: &'static str, model: &M, value: &M::Value) -> Self { pub fn arg<M: Model<'ctx>>(mut self, _name: &'static str, model: M, value: M::Value) -> Self {
self.arguments self.arguments
.push((model.get_llvm_type(self.ctx.ctx).into(), value.get_llvm_value().into())); .push((model.get_llvm_type(self.ctx.ctx).into(), value.get_llvm_value().into()));
self self
} }
pub fn returning<M: Model<'ctx>>(self, name: &'static str, return_model: &M) -> M::Value { pub fn returning<M: Model<'ctx>>(self, name: &'static str, return_model: M) -> M::Value {
let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip(); let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip();
let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| { let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| {

View File

@ -23,7 +23,7 @@ use inkwell::{
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use irrt::{error_context::Str, numpy::NpArray}; use irrt::{error_context::Str, numpy::ndarray::NpArray};
use itertools::Itertools; use itertools::Itertools;
use model::*; use model::*;
use nac3parser::ast::{Location, Stmt, StrRef}; use nac3parser::ast::{Location, Stmt, StrRef};

View File

@ -16,13 +16,13 @@ use inkwell::{
use crate::codegen::CodeGenContext; use crate::codegen::CodeGenContext;
use super::Pointer; use super::{slice::ArraySlice, Int, Pointer};
pub trait ModelValue<'ctx> { pub trait ModelValue<'ctx>: Clone + Copy {
fn get_llvm_value(&self) -> BasicValueEnum<'ctx>; fn get_llvm_value(&self) -> BasicValueEnum<'ctx>;
} }
pub trait Model<'ctx>: Clone { pub trait Model<'ctx>: Clone + Copy {
type Value: ModelValue<'ctx>; type Value: ModelValue<'ctx>;
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>; fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
@ -30,8 +30,26 @@ pub trait Model<'ctx>: Clone {
fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> { fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> {
Pointer { Pointer {
element: self.clone(), element: *self,
value: ctx.builder.build_alloca(self.get_llvm_type(ctx.ctx), name).unwrap(), value: ctx.builder.build_alloca(self.get_llvm_type(ctx.ctx), name).unwrap(),
} }
} }
fn array_alloca(
&self,
ctx: &CodeGenContext<'ctx, '_>,
count: Int<'ctx>,
name: &str,
) -> ArraySlice<'ctx, Self> {
ArraySlice {
num_elements: count,
pointer: Pointer {
element: *self,
value: ctx
.builder
.build_array_alloca(self.get_llvm_type(ctx.ctx), count.0, name)
.unwrap(),
},
}
}
} }

View File

@ -9,7 +9,7 @@ use crate::codegen::CodeGenContext;
use super::{Model, ModelValue, Pointer}; use super::{Model, ModelValue, Pointer};
#[derive(Debug, Clone)] #[derive(Debug, Clone, Copy)]
pub struct Field<E> { pub struct Field<E> {
pub gep_index: u64, pub gep_index: u64,
pub name: &'static str, pub name: &'static str,
@ -17,7 +17,7 @@ pub struct Field<E> {
} }
// Like [`Field<E>`] but element must be [`BasicTypeEnum<'ctx>`] // Like [`Field<E>`] but element must be [`BasicTypeEnum<'ctx>`]
#[derive(Debug)] #[derive(Debug, Clone, Copy)]
struct FieldLLVM<'ctx> { struct FieldLLVM<'ctx> {
gep_index: u64, gep_index: u64,
name: &'ctx str, name: &'ctx str,
@ -57,7 +57,7 @@ impl<'ctx> FieldBuilder<'ctx> {
} }
} }
pub trait IsStruct<'ctx>: Clone { pub trait IsStruct<'ctx>: Clone + Copy {
type Fields; type Fields;
fn struct_name(&self) -> &'static str; fn struct_name(&self) -> &'static str;
@ -79,11 +79,10 @@ pub trait IsStruct<'ctx>: Clone {
} }
} }
// To play nice with Rust's trait resolution #[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
pub struct StructModel<S>(pub S); pub struct StructModel<S>(pub S);
// TODO: enrich it #[derive(Debug, Clone, Copy)]
pub struct Struct<'ctx, S> { pub struct Struct<'ctx, S> {
pub structure: S, pub structure: S,
pub value: StructValue<'ctx>, pub value: StructValue<'ctx>,
@ -104,7 +103,7 @@ impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel<S> {
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
// TODO: check structure // TODO: check structure
Struct { structure: self.0.clone(), value: value.into_struct_value() } Struct { structure: self.0, value: value.into_struct_value() }
} }
} }

View File

@ -4,6 +4,8 @@ use inkwell::{
values::{AnyValueEnum, BasicValue, BasicValueEnum, IntValue}, values::{AnyValueEnum, BasicValue, BasicValueEnum, IntValue},
}; };
use crate::codegen::CodeGenContext;
use super::core::*; use super::core::*;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@ -32,21 +34,85 @@ impl<'ctx> Model<'ctx> for IntModel<'ctx> {
} }
} }
impl<'ctx> Int<'ctx> {
#[must_use]
pub fn signed_cast_to_int(
self,
ctx: &CodeGenContext<'ctx, '_>,
target_int: IntModel<'ctx>,
name: &str,
) -> Int<'ctx> {
Int(ctx.builder.build_int_s_extend_or_bit_cast(self.0, target_int.0, name).unwrap())
}
#[must_use]
pub fn signed_cast_to_fixed<T: IsFixedInt>(
self,
ctx: &CodeGenContext<'ctx, '_>,
target_fixed: T,
name: &str,
) -> FixedInt<'ctx, T> {
FixedInt {
int: target_fixed,
value: ctx
.builder
.build_int_s_extend_or_bit_cast(self.0, T::get_int_type(ctx.ctx), name)
.unwrap(),
}
}
}
impl<'ctx> IntModel<'ctx> { impl<'ctx> IntModel<'ctx> {
#[must_use] #[must_use]
pub fn constant(&self, value: u64) -> Int<'ctx> { pub fn constant(&self, value: u64) -> Int<'ctx> {
Int(self.0.const_int(value, false)) Int(self.0.const_int(value, false))
} }
#[must_use]
pub fn same_as(&self, other: IntModel<'ctx>) -> bool {
// TODO: or `self.0 == other.0` would also work?
self.0.get_bit_width() == other.0.get_bit_width()
}
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct FixedIntModel<T>(pub T); pub struct FixedIntModel<T>(pub T);
impl<T: IsFixedInt> FixedIntModel<T> {
pub fn to_int_model(self, ctx: &Context) -> IntModel<'_> {
IntModel(T::get_int_type(ctx))
}
}
#[derive(Debug, Clone, Copy)]
pub struct FixedInt<'ctx, T: IsFixedInt> { pub struct FixedInt<'ctx, T: IsFixedInt> {
pub int: T, pub int: T,
pub value: IntValue<'ctx>, pub value: IntValue<'ctx>,
} }
pub trait IsFixedInt: Clone + Default { impl<'ctx, T: IsFixedInt> FixedInt<'ctx, T> {
pub fn to_int(self) -> Int<'ctx> {
Int(self.value)
}
pub fn signed_cast_to_fixed<R: IsFixedInt>(
self,
ctx: &CodeGenContext<'ctx, '_>,
target_fixed_int: R,
name: &str,
) -> FixedInt<'ctx, R> {
FixedInt {
int: target_fixed_int,
value: ctx
.builder
.build_int_s_extend_or_bit_cast(self.value, R::get_int_type(ctx.ctx), name)
.unwrap(),
}
}
}
// Default instance is to enable `FieldBuilder::add_field_auto`
pub trait IsFixedInt: Clone + Copy + Default {
fn get_int_type(ctx: &Context) -> IntType<'_>; fn get_int_type(ctx: &Context) -> IntType<'_>;
fn get_bit_width() -> u32; // This is required, instead of only relying on get_int_type fn get_bit_width() -> u32; // This is required, instead of only relying on get_int_type
} }
@ -67,18 +133,19 @@ impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel<T> {
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
let value = value.into_int_value(); let value = value.into_int_value();
assert_eq!(value.get_type().get_bit_width(), T::get_bit_width()); assert_eq!(value.get_type().get_bit_width(), T::get_bit_width());
FixedInt { int: self.0.clone(), value } FixedInt { int: self.0, value }
} }
} }
impl<'ctx, T: IsFixedInt> FixedIntModel<T> { impl<'ctx, T: IsFixedInt> FixedIntModel<T> {
pub fn constant(&self, ctx: &'ctx Context, value: u64) -> FixedInt<'ctx, T> { pub fn constant(&self, ctx: &'ctx Context, value: u64) -> FixedInt<'ctx, T> {
FixedInt { int: self.0.clone(), value: T::get_int_type(ctx).const_int(value, false) } FixedInt { int: self.0, value: T::get_int_type(ctx).const_int(value, false) }
} }
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct Bool; pub struct Bool;
pub type BoolModel = FixedIntModel<Bool>;
impl IsFixedInt for Bool { impl IsFixedInt for Bool {
fn get_int_type(ctx: &Context) -> IntType<'_> { fn get_int_type(ctx: &Context) -> IntType<'_> {
@ -90,8 +157,9 @@ impl IsFixedInt for Bool {
} }
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct Byte; pub struct Byte;
pub type ByteModel = FixedIntModel<Byte>;
impl IsFixedInt for Byte { impl IsFixedInt for Byte {
fn get_int_type(ctx: &Context) -> IntType<'_> { fn get_int_type(ctx: &Context) -> IntType<'_> {
@ -103,8 +171,9 @@ impl IsFixedInt for Byte {
} }
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct Int32; pub struct Int32;
pub type Int32Model = FixedIntModel<Int32>;
impl IsFixedInt for Int32 { impl IsFixedInt for Int32 {
fn get_int_type(ctx: &Context) -> IntType<'_> { fn get_int_type(ctx: &Context) -> IntType<'_> {
@ -116,8 +185,9 @@ impl IsFixedInt for Int32 {
} }
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct Int64; pub struct Int64;
pub type Int64Model = FixedIntModel<Int64>;
impl IsFixedInt for Int64 { impl IsFixedInt for Int64 {
fn get_int_type(ctx: &Context) -> IntType<'_> { fn get_int_type(ctx: &Context) -> IntType<'_> {

View File

@ -9,12 +9,13 @@ use crate::codegen::CodeGenContext;
use super::core::*; use super::core::*;
#[derive(Debug, Clone, Copy)]
pub struct Pointer<'ctx, E: Model<'ctx>> { pub struct Pointer<'ctx, E: Model<'ctx>> {
pub element: E, pub element: E,
pub value: PointerValue<'ctx>, pub value: PointerValue<'ctx>,
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct PointerModel<E>(pub E); pub struct PointerModel<E>(pub E);
impl<'ctx, E: Model<'ctx>> ModelValue<'ctx> for Pointer<'ctx, E> { impl<'ctx, E: Model<'ctx>> ModelValue<'ctx> for Pointer<'ctx, E> {
@ -24,7 +25,7 @@ impl<'ctx, E: Model<'ctx>> ModelValue<'ctx> for Pointer<'ctx, E> {
} }
impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> { impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> {
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, val: &E::Value) { pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, val: E::Value) {
ctx.builder.build_store(self.value, val.get_llvm_value()).unwrap(); ctx.builder.build_store(self.value, val.get_llvm_value()).unwrap();
} }
@ -32,6 +33,30 @@ impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> {
let val = ctx.builder.build_load(self.value, name).unwrap(); let val = ctx.builder.build_load(self.value, name).unwrap();
self.element.check_llvm_value(val.as_any_value_enum()) self.element.check_llvm_value(val.as_any_value_enum())
} }
pub fn to_opaque(self) -> OpaquePointer<'ctx> {
OpaquePointer(self.value)
}
pub fn cast_opaque_to(
&self,
ctx: &CodeGenContext<'ctx, '_>,
element_type: BasicTypeEnum<'ctx>,
name: &str,
) -> OpaquePointer<'ctx> {
self.to_opaque().cast_opaque_to(ctx, element_type, name)
}
pub fn cast_to<R: Model<'ctx>>(
self,
ctx: &CodeGenContext<'ctx, '_>,
element_model: R,
name: &str,
) -> Pointer<'ctx, R> {
let casted_ptr =
self.to_opaque().cast_opaque_to(ctx, element_model.get_llvm_type(ctx.ctx), name).0;
Pointer { element: element_model, value: casted_ptr }
}
} }
impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel<E> { impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel<E> {
@ -43,13 +68,15 @@ impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel<E> {
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
// TODO: Check get_element_type()? for LLVM 14 at least... // TODO: Check get_element_type()? for LLVM 14 at least...
Pointer { element: self.0.clone(), value: value.into_pointer_value() } Pointer { element: self.0, value: value.into_pointer_value() }
} }
} }
// A pointer of which the element's model is unknown.
#[derive(Debug, Clone, Copy)]
pub struct OpaquePointer<'ctx>(pub PointerValue<'ctx>); pub struct OpaquePointer<'ctx>(pub PointerValue<'ctx>);
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct OpaquePointerModel; pub struct OpaquePointerModel;
impl<'ctx> ModelValue<'ctx> for OpaquePointer<'ctx> { impl<'ctx> ModelValue<'ctx> for OpaquePointer<'ctx> {
@ -74,19 +101,39 @@ impl<'ctx> Model<'ctx> for OpaquePointerModel {
} }
impl<'ctx> OpaquePointer<'ctx> { impl<'ctx> OpaquePointer<'ctx> {
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>) { pub fn load_opaque(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> BasicValueEnum<'ctx> {
ctx.builder.build_load(self.0, name).unwrap()
}
pub fn store_opaque(&self, ctx: &CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>) {
ctx.builder.build_store(self.0, value).unwrap(); ctx.builder.build_store(self.0, value).unwrap();
} }
pub fn from_ptr(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) -> Self { #[must_use]
let ptr = ctx pub fn cast_opaque_to(
.builder self,
.build_pointer_cast( ctx: &CodeGenContext<'ctx, '_>,
ptr, element_llvm_type: BasicTypeEnum<'ctx>,
ctx.ctx.i8_type().ptr_type(AddressSpace::default()), name: &str,
"opaque.from_ptr", ) -> OpaquePointer<'ctx> {
) OpaquePointer(
.unwrap(); ctx.builder
OpaquePointer(ptr) .build_pointer_cast(
self.0,
element_llvm_type.ptr_type(AddressSpace::default()),
name,
)
.unwrap(),
)
}
pub fn cast_to<E: Model<'ctx>>(
self,
ctx: &CodeGenContext<'ctx, '_>,
element_model: E,
name: &str,
) -> Pointer<'ctx, E> {
let ptr = self.cast_opaque_to(ctx, element_model.get_llvm_type(ctx.ctx), name).0;
Pointer { element: element_model, value: ptr }
} }
} }

View File

@ -16,7 +16,7 @@ impl<'ctx, E: Model<'ctx>> ArraySlice<'ctx, E> {
) -> Pointer<'ctx, E> { ) -> Pointer<'ctx, E> {
let element_addr = let element_addr =
unsafe { ctx.builder.build_in_bounds_gep(self.pointer.value, &[idx.0], name).unwrap() }; unsafe { ctx.builder.build_in_bounds_gep(self.pointer.value, &[idx.0], name).unwrap() };
Pointer { value: element_addr, element: self.pointer.element.clone() } Pointer { value: element_addr, element: self.pointer.element }
} }
pub fn ix<G: CodeGenerator + ?Sized>( pub fn ix<G: CodeGenerator + ?Sized>(

View File

@ -9,8 +9,10 @@ use crate::{
use super::{ use super::{
irrt::numpy::{ irrt::numpy::{
alloca_ndarray_and_init, call_nac3_ndarray_fill_generic, parse_input_shape_arg, ndarray::{
NDArrayInitMode, NpArray, alloca_ndarray_and_init, call_nac3_ndarray_fill_generic, NDArrayInitMode, NpArray,
},
shape::parse_input_shape_arg,
}, },
model::*, model::*,
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
@ -59,13 +61,11 @@ where
// Allocate fill_value on the stack and give the corresponding stack pointer // Allocate fill_value on the stack and give the corresponding stack pointer
// to call_nac3_ndarray_fill_generic // to call_nac3_ndarray_fill_generic
let fill_value_ptr = ctx.builder.build_alloca(fill_value.get_type(), "fill_value_ptr").unwrap(); let fill_value_ptr = ctx.builder.build_alloca(fill_value.get_type(), "fill_value_ptr").unwrap();
ctx.builder.build_store(fill_value_ptr, fill_value).unwrap(); let fill_value_ptr = OpaquePointer(fill_value_ptr);
fill_value_ptr.store_opaque(ctx, fill_value);
// Opaque-ize fill_value_ptr (turning it into `i8*`) before passing let fill_value_ptr = fill_value_ptr.cast_to(ctx, FixedIntModel(Byte), "");
// to call_nac3_ndarray_fill_generic call_nac3_ndarray_fill_generic(ctx, ndarray_ptr, fill_value_ptr);
let fill_value_ptr = OpaquePointer::from_ptr(ctx, fill_value_ptr);
call_nac3_ndarray_fill_generic(generator, ctx, &ndarray_ptr, &fill_value_ptr);
Ok(ndarray_ptr) Ok(ndarray_ptr)
} }

View File

@ -5,7 +5,7 @@ use indexmap::IndexMap;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
types::{BasicMetadataTypeEnum, BasicType}, types::{BasicMetadataTypeEnum, BasicType},
values::{BasicMetadataValueEnum, BasicValue, CallSiteValue}, values::{AnyValue, BasicMetadataValueEnum, BasicValue, CallSiteValue},
IntPredicate, IntPredicate,
}; };
use itertools::Either; use itertools::Either;
@ -14,10 +14,17 @@ use strum::IntoEnumIterator;
use crate::{ use crate::{
codegen::{ codegen::{
builtin_fns, builtin_fns,
classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor}, classes::{ProxyValue, RangeValue},
expr::destructure_range, expr::destructure_range,
irrt::*, irrt::{
numpy::*, calculate_len_for_slice_range,
numpy::ndarray::{call_nac3_ndarray_len, NpArray},
},
model::*,
numpy::{
gen_ndarray_array, gen_ndarray_copy, gen_ndarray_eye, gen_ndarray_fill,
gen_ndarray_identity,
},
numpy_new, numpy_new,
stmt::exn_constructor, stmt::exn_constructor,
}, },
@ -1265,7 +1272,7 @@ impl<'a> BuiltinBuilder<'a> {
// type variable // type variable
&[(self.list_int32, "shape"), (tv.ty, "fill_value")], &[(self.list_int32, "shape"), (tv.ty, "fill_value")],
Box::new(move |ctx, obj, fun, args, generator| { Box::new(move |ctx, obj, fun, args, generator| {
gen_ndarray_full(ctx, &obj, fun, &args, generator) numpy_new::gen_ndarray_full(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
}), }),
) )
@ -1457,51 +1464,19 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i32 = ctx.ctx.i32_type(); // Parse `arg`
let llvm_usize = generator.get_size_type(ctx.ctx); let sizet = IntModel(generator.get_size_type(ctx.ctx));
let arg = NDArrayValue::from_ptr_val( let ndarray_ptr_model =
arg.into_pointer_value(), PointerModel(StructModel(NpArray { sizet }));
llvm_usize, let ndarray_ptr =
None, ndarray_ptr_model.check_llvm_value(arg.as_any_value_enum());
);
let ndims = arg.dim_sizes().size(ctx, generator); // Calculate len
ctx.make_assert( // NOTE: Unsized object is asserted in IRRT
generator, let len = call_nac3_ndarray_len(generator, ctx, ndarray_ptr);
ctx.builder let len = len.signed_cast_to_fixed(ctx, Int32, "len_i32");
.build_int_compare( Some(len.value.as_basic_value_enum())
IntPredicate::NE,
ndims,
llvm_usize.const_zero(),
"",
)
.unwrap(),
"0:TypeError",
&format!("{name}() of unsized object", name = prim.name()),
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
if len.get_type().get_bit_width() == 32 {
Some(len.into())
} else {
Some(
ctx.builder
.build_int_truncate(len, llvm_i32, "len")
.map(Into::into)
.unwrap(),
)
}
} }
_ => unreachable!(), _ => unreachable!(),
} }