forked from M-Labs/nac3
WIP: core/ndstrides: implement np.newaxis and ... + IRRT_DEBUG_ASSERT
This commit is contained in:
parent
c28166efb8
commit
f0519e7019
|
@ -24,6 +24,12 @@ fn compile_irrt_cpp() {
|
|||
let out_dir = get_out_dir();
|
||||
let irrt_dir = get_irrt_dir();
|
||||
|
||||
let (opt_flag, debug_assert_flag) = match env::var("PROFILE").as_deref() {
|
||||
Ok("debug") => ("-O0", "-DIRRT_DEBUG_ASSERT=true"),
|
||||
Ok("release") => ("-O3", "-DIRRT_DEBUG_ASSERT=false"),
|
||||
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
||||
};
|
||||
|
||||
/*
|
||||
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
||||
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
||||
|
@ -36,11 +42,8 @@ fn compile_irrt_cpp() {
|
|||
"-fno-discard-value-names",
|
||||
"-fno-exceptions",
|
||||
"-fno-rtti",
|
||||
match env::var("PROFILE").as_deref() {
|
||||
Ok("debug") => "-O0",
|
||||
Ok("release") => "-O3",
|
||||
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
||||
},
|
||||
opt_flag,
|
||||
debug_assert_flag,
|
||||
"-emit-llvm",
|
||||
"-S",
|
||||
"-Wall",
|
||||
|
|
|
@ -71,13 +71,6 @@ namespace util {
|
|||
template <typename SizeT>
|
||||
SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes,
|
||||
const NDIndex* indexes) {
|
||||
if (num_indexes > ndims) {
|
||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
||||
"too many indices for array: array is {0}-dimensional, "
|
||||
"but {1} were indexed",
|
||||
ndims, num_indexes, NO_PARAM);
|
||||
}
|
||||
|
||||
// There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis.
|
||||
SizeT num_ellipsis = 0;
|
||||
|
||||
|
@ -101,6 +94,14 @@ SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes,
|
|||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
|
||||
if (num_indexes - num_ellipsis > ndims) {
|
||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
||||
"too many indices for array: array is {0}-dimensional, "
|
||||
"but {1} were indexed",
|
||||
ndims, num_indexes, NO_PARAM);
|
||||
}
|
||||
|
||||
return ndims;
|
||||
}
|
||||
} // namespace util
|
||||
|
@ -139,9 +140,51 @@ SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes,
|
|||
template <typename SizeT>
|
||||
void index(SizeT num_indexes, const NDIndex* indexes,
|
||||
const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
SizeT expected_dst_ndarray_ndims =
|
||||
util::validate_and_deduce_ndims_after_indexing(src_ndarray->ndims,
|
||||
num_indexes, indexes);
|
||||
// First, validate `indexes`.
|
||||
|
||||
// Expected value of `dst_ndarray->ndims`.
|
||||
SizeT expected_dst_ndims = src_ndarray->ndims;
|
||||
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
|
||||
SizeT num_indexed = 0;
|
||||
// There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis.
|
||||
SizeT num_ellipsis = 0;
|
||||
|
||||
for (SizeT i = 0; i < num_indexes; i++) {
|
||||
if (indexes[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
|
||||
expected_dst_ndims--;
|
||||
num_indexed++;
|
||||
} else if (indexes[i].type == ND_INDEX_TYPE_SLICE) {
|
||||
num_indexed++;
|
||||
} else if (indexes[i].type == ND_INDEX_TYPE_NEWAXIS) {
|
||||
expected_dst_ndims++;
|
||||
} else if (indexes[i].type == ND_INDEX_TYPE_ELLIPSIS) {
|
||||
num_ellipsis++;
|
||||
if (num_ellipsis > 1) {
|
||||
raise_exception(
|
||||
SizeT, EXN_INDEX_ERROR,
|
||||
"an index can only have a single ellipsis ('...')",
|
||||
NO_PARAM, NO_PARAM, NO_PARAM);
|
||||
}
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
|
||||
if (IRRT_DEBUG_ASSERT) {
|
||||
if (expected_dst_ndims != dst_ndarray->ndims) {
|
||||
raise_exception(
|
||||
SizeT, EXN_ASSERTION_ERROR,
|
||||
"IRRT expected_dst_ndims is {0}, but dst_ndarray->ndims is {1}",
|
||||
expected_dst_ndims, dst_ndarray->ndims, NO_PARAM);
|
||||
}
|
||||
}
|
||||
|
||||
if (src_ndarray->ndims - num_indexed < 0) {
|
||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
||||
"too many indices for array: array is {0}-dimensional, "
|
||||
"but {1} were indexed",
|
||||
src_ndarray->ndims, num_indexes, NO_PARAM);
|
||||
}
|
||||
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
|
@ -188,7 +231,7 @@ void index(SizeT num_indexes, const NDIndex* indexes,
|
|||
dst_axis++;
|
||||
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
|
||||
// The number of ':' entries this '...' implies.
|
||||
SizeT ellipsis_size = src_ndarray->ndims - (num_indexes - 1);
|
||||
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
|
||||
|
||||
for (SizeT j = 0; j < ellipsis_size; j++) {
|
||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||
|
@ -206,6 +249,19 @@ void index(SizeT num_indexes, const NDIndex* indexes,
|
|||
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
|
||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||
}
|
||||
|
||||
if (IRRT_DEBUG_ASSERT) {
|
||||
if (dst_ndarray->ndims != dst_axis) {
|
||||
raise_exception(SizeT, EXN_ASSERTION_ERROR,
|
||||
"IRRT dst_ndarray->ndims ({0}) != dst_axis ({1})",
|
||||
dst_ndarray->ndims, dst_axis, NO_PARAM);
|
||||
}
|
||||
if (src_ndarray->ndims != src_axis) {
|
||||
raise_exception(SizeT, EXN_ASSERTION_ERROR,
|
||||
"IRRT src_ndarray->ndims ({0}) != src_axis ({1})",
|
||||
src_ndarray->ndims, src_axis, NO_PARAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace indexing
|
||||
} // namespace ndarray
|
||||
|
|
|
@ -66,7 +66,7 @@ void resolve_and_check_new_shape(SizeT size, SizeT new_ndims,
|
|||
bool can_reshape;
|
||||
if (neg1_exists) {
|
||||
// Let `x` be the unknown dimension
|
||||
// solve `x * <new_size> = <size>`
|
||||
// Solve `x * <new_size> = <size>`
|
||||
if (new_size == 0 && size == 0) {
|
||||
// `x` has infinitely many solutions
|
||||
can_reshape = false;
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
#ifndef IRRT_DEBUG_ASSERT
|
||||
#error IRRT_DEBUG_ASSERT flag is missing!! Please define it to 'false' or 'true'.
|
||||
#endif
|
||||
|
||||
#include <irrt/core.hpp>
|
||||
#include <irrt/exception.hpp>
|
||||
#include <irrt/int_defs.hpp>
|
||||
|
|
|
@ -2942,3 +2942,43 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||
_ => unimplemented!(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Generate LLVM IR for an [`ExprKind::Slice`]
|
||||
pub fn gen_slice<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
lower: &Option<Box<Expr<Option<Type>>>>,
|
||||
upper: &Option<Box<Expr<Option<Type>>>>,
|
||||
step: &Option<Box<Expr<Option<Type>>>>,
|
||||
) -> Result<
|
||||
(
|
||||
Option<Instance<'ctx, IntModel<Int32>>>,
|
||||
Option<Instance<'ctx, IntModel<Int32>>>,
|
||||
Option<Instance<'ctx, IntModel<Int32>>>,
|
||||
),
|
||||
String,
|
||||
> {
|
||||
let i32_model = IntModel(Int32); // TODO: Switch to usize
|
||||
|
||||
let mut help = |value_expr: &Option<Box<Expr<Option<Type>>>>| -> Result<_, String> {
|
||||
Ok(match value_expr {
|
||||
None => None,
|
||||
Some(value_expr) => {
|
||||
let value_expr = generator
|
||||
.gen_expr(ctx, value_expr)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
|
||||
|
||||
let value_expr = i32_model.check_value(generator, ctx.ctx, value_expr).unwrap();
|
||||
|
||||
Some(value_expr)
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
let lower = help(lower)?;
|
||||
let upper = help(upper)?;
|
||||
let step = help(step)?;
|
||||
|
||||
Ok((lower, upper, step))
|
||||
}
|
||||
|
|
|
@ -21,11 +21,7 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
|
|||
type Type: BasicType<'ctx>;
|
||||
|
||||
/// Return the [`BasicType`] of this model.
|
||||
fn get_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> Self::Type;
|
||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type;
|
||||
|
||||
/// Check if a [`BasicType`] is the same type of this model.
|
||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
||||
|
|
|
@ -97,11 +97,7 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for IntModel<N> {
|
|||
type Type = IntType<'ctx>;
|
||||
|
||||
#[must_use]
|
||||
fn get_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> Self::Type {
|
||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
||||
self.0.get_int_type(generator, ctx)
|
||||
}
|
||||
|
||||
|
|
|
@ -17,11 +17,7 @@ impl<'ctx, Element: Model<'ctx>> Model<'ctx> for PtrModel<Element> {
|
|||
type Value = PointerValue<'ctx>;
|
||||
type Type = PointerType<'ctx>;
|
||||
|
||||
fn get_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> Self::Type {
|
||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
||||
self.0.get_type(generator, ctx).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
|
|
|
@ -111,11 +111,7 @@ impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for StructModel<S> {
|
|||
type Value = StructValue<'ctx>;
|
||||
type Type = StructType<'ctx>;
|
||||
|
||||
fn get_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> Self::Type {
|
||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
||||
self.0.get_struct_type(generator, ctx)
|
||||
}
|
||||
|
||||
|
|
|
@ -4,15 +4,12 @@ use inkwell::values::{BasicValue, BasicValueEnum};
|
|||
use nac3parser::ast::StrRef;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
structure::{
|
||||
codegen::structure::{
|
||||
ndarray::{
|
||||
scalar::split_scalar_or_ndarray, shape_util::parse_numpy_int_sequence,
|
||||
NDArrayObject,
|
||||
scalar::split_scalar_or_ndarray, shape_util::parse_numpy_int_sequence, NDArrayObject,
|
||||
},
|
||||
tuple::TupleObject,
|
||||
},
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{
|
||||
numpy::{extract_ndims, unpack_ndarray_var_tys},
|
||||
|
|
|
@ -258,11 +258,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
|
||||
pub mod util {
|
||||
use itertools::Itertools;
|
||||
use nac3parser::ast::{Expr, ExprKind};
|
||||
use nac3parser::ast::{Constant, Expr, ExprKind};
|
||||
|
||||
use crate::{
|
||||
codegen::{model::*, CodeGenContext, CodeGenerator},
|
||||
typecheck::typedef::Type,
|
||||
codegen::{expr::gen_slice, model::*, CodeGenContext, CodeGenerator},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
use super::{RustNDIndex, RustUserSlice};
|
||||
|
@ -308,45 +308,36 @@ pub mod util {
|
|||
for index_expr in index_exprs {
|
||||
// NOTE: Currently nac3core's slices do not have an object representation,
|
||||
// so the code/implementation looks awkward - we have to do pattern matching on the expression
|
||||
let ndindex =
|
||||
if let ExprKind::Slice { lower: start, upper: stop, step } = &index_expr.node {
|
||||
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
|
||||
// Handle slices
|
||||
|
||||
// Helper function here to deduce code duplication
|
||||
type ValueExpr = Option<Box<Expr<Option<Type>>>>;
|
||||
let mut help = |value_expr: &ValueExpr| -> Result<_, String> {
|
||||
Ok(match value_expr {
|
||||
None => None,
|
||||
Some(value_expr) => {
|
||||
let value_expr = generator
|
||||
.gen_expr(ctx, value_expr)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
|
||||
|
||||
let value_expr =
|
||||
i32_model.check_value(generator, ctx.ctx, value_expr).unwrap();
|
||||
|
||||
Some(value_expr)
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
let start = help(start)?;
|
||||
let stop = help(stop)?;
|
||||
let step = help(step)?;
|
||||
|
||||
RustNDIndex::Slice(RustUserSlice { start, stop, step })
|
||||
let (lower, upper, step) = gen_slice(generator, ctx, lower, upper, step)?;
|
||||
RustNDIndex::Slice(RustUserSlice { start: lower, stop: upper, step })
|
||||
} else if let ExprKind::Constant { value: Constant::Ellipsis, .. } = &index_expr.node {
|
||||
// Handle '...'
|
||||
RustNDIndex::Ellipsis
|
||||
} else {
|
||||
todo!("implement me");
|
||||
|
||||
// Anything else that is not a slice (might be illegal values),
|
||||
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
|
||||
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
|
||||
match &*ctx.unifier.get_ty(index_expr.custom.unwrap()) {
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// Handle `np.newaxis` / `None`
|
||||
RustNDIndex::NewAxis
|
||||
}
|
||||
_ => {
|
||||
// Treat and handle everything else as a single element index.
|
||||
let index =
|
||||
generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.int32, // Must be int32, this checks for illegal values
|
||||
)?;
|
||||
let index = i32_model.check_value(generator, ctx.ctx, index).unwrap();
|
||||
|
||||
RustNDIndex::SingleElement(index)
|
||||
}
|
||||
}
|
||||
};
|
||||
rust_ndindexes.push(ndindex);
|
||||
}
|
||||
|
|
|
@ -13,8 +13,6 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
assert!(a.ndims >= 2);
|
||||
assert!(b.ndims >= 2);
|
||||
|
||||
|
||||
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -621,7 +621,8 @@ impl Unifier {
|
|||
}
|
||||
|
||||
pub fn unify_call(
|
||||
&mut self, call: &Call,
|
||||
&mut self,
|
||||
call: &Call,
|
||||
b: Type,
|
||||
signature: &FunSignature,
|
||||
) -> Result<(), TypeError> {
|
||||
|
|
Loading…
Reference in New Issue