forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: implement np.newaxis and ... + IRRT_DEBUG_ASSERT

This commit is contained in:
lyken 2024-08-12 10:46:09 +08:00
parent c28166efb8
commit f0519e7019
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
13 changed files with 162 additions and 88 deletions

View File

@ -24,6 +24,12 @@ fn compile_irrt_cpp() {
let out_dir = get_out_dir();
let irrt_dir = get_irrt_dir();
let (opt_flag, debug_assert_flag) = match env::var("PROFILE").as_deref() {
Ok("debug") => ("-O0", "-DIRRT_DEBUG_ASSERT=true"),
Ok("release") => ("-O3", "-DIRRT_DEBUG_ASSERT=false"),
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
};
/*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
@ -36,11 +42,8 @@ fn compile_irrt_cpp() {
"-fno-discard-value-names",
"-fno-exceptions",
"-fno-rtti",
match env::var("PROFILE").as_deref() {
Ok("debug") => "-O0",
Ok("release") => "-O3",
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
},
opt_flag,
debug_assert_flag,
"-emit-llvm",
"-S",
"-Wall",

View File

@ -71,13 +71,6 @@ namespace util {
template <typename SizeT>
SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes,
const NDIndex* indexes) {
if (num_indexes > ndims) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"too many indices for array: array is {0}-dimensional, "
"but {1} were indexed",
ndims, num_indexes, NO_PARAM);
}
// There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis.
SizeT num_ellipsis = 0;
@ -101,6 +94,14 @@ SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes,
__builtin_unreachable();
}
}
if (num_indexes - num_ellipsis > ndims) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"too many indices for array: array is {0}-dimensional, "
"but {1} were indexed",
ndims, num_indexes, NO_PARAM);
}
return ndims;
}
} // namespace util
@ -139,9 +140,51 @@ SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes,
template <typename SizeT>
void index(SizeT num_indexes, const NDIndex* indexes,
const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
SizeT expected_dst_ndarray_ndims =
util::validate_and_deduce_ndims_after_indexing(src_ndarray->ndims,
num_indexes, indexes);
// First, validate `indexes`.
// Expected value of `dst_ndarray->ndims`.
SizeT expected_dst_ndims = src_ndarray->ndims;
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
SizeT num_indexed = 0;
// There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis.
SizeT num_ellipsis = 0;
for (SizeT i = 0; i < num_indexes; i++) {
if (indexes[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
expected_dst_ndims--;
num_indexed++;
} else if (indexes[i].type == ND_INDEX_TYPE_SLICE) {
num_indexed++;
} else if (indexes[i].type == ND_INDEX_TYPE_NEWAXIS) {
expected_dst_ndims++;
} else if (indexes[i].type == ND_INDEX_TYPE_ELLIPSIS) {
num_ellipsis++;
if (num_ellipsis > 1) {
raise_exception(
SizeT, EXN_INDEX_ERROR,
"an index can only have a single ellipsis ('...')",
NO_PARAM, NO_PARAM, NO_PARAM);
}
} else {
__builtin_unreachable();
}
}
if (IRRT_DEBUG_ASSERT) {
if (expected_dst_ndims != dst_ndarray->ndims) {
raise_exception(
SizeT, EXN_ASSERTION_ERROR,
"IRRT expected_dst_ndims is {0}, but dst_ndarray->ndims is {1}",
expected_dst_ndims, dst_ndarray->ndims, NO_PARAM);
}
}
if (src_ndarray->ndims - num_indexed < 0) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"too many indices for array: array is {0}-dimensional, "
"but {1} were indexed",
src_ndarray->ndims, num_indexes, NO_PARAM);
}
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
@ -188,7 +231,7 @@ void index(SizeT num_indexes, const NDIndex* indexes,
dst_axis++;
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
// The number of ':' entries this '...' implies.
SizeT ellipsis_size = src_ndarray->ndims - (num_indexes - 1);
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
for (SizeT j = 0; j < ellipsis_size; j++) {
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
@ -206,6 +249,19 @@ void index(SizeT num_indexes, const NDIndex* indexes,
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
if (IRRT_DEBUG_ASSERT) {
if (dst_ndarray->ndims != dst_axis) {
raise_exception(SizeT, EXN_ASSERTION_ERROR,
"IRRT dst_ndarray->ndims ({0}) != dst_axis ({1})",
dst_ndarray->ndims, dst_axis, NO_PARAM);
}
if (src_ndarray->ndims != src_axis) {
raise_exception(SizeT, EXN_ASSERTION_ERROR,
"IRRT src_ndarray->ndims ({0}) != src_axis ({1})",
src_ndarray->ndims, src_axis, NO_PARAM);
}
}
}
} // namespace indexing
} // namespace ndarray

View File

@ -66,7 +66,7 @@ void resolve_and_check_new_shape(SizeT size, SizeT new_ndims,
bool can_reshape;
if (neg1_exists) {
// Let `x` be the unknown dimension
// solve `x * <new_size> = <size>`
// Solve `x * <new_size> = <size>`
if (new_size == 0 && size == 0) {
// `x` has infinitely many solutions
can_reshape = false;

View File

@ -1,5 +1,9 @@
#pragma once
#ifndef IRRT_DEBUG_ASSERT
#error IRRT_DEBUG_ASSERT flag is missing!! Please define it to 'false' or 'true'.
#endif
#include <irrt/core.hpp>
#include <irrt/exception.hpp>
#include <irrt/int_defs.hpp>

View File

@ -2942,3 +2942,43 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
_ => unimplemented!(),
}))
}
/// Generate LLVM IR for an [`ExprKind::Slice`]
pub fn gen_slice<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lower: &Option<Box<Expr<Option<Type>>>>,
upper: &Option<Box<Expr<Option<Type>>>>,
step: &Option<Box<Expr<Option<Type>>>>,
) -> Result<
(
Option<Instance<'ctx, IntModel<Int32>>>,
Option<Instance<'ctx, IntModel<Int32>>>,
Option<Instance<'ctx, IntModel<Int32>>>,
),
String,
> {
let i32_model = IntModel(Int32); // TODO: Switch to usize
let mut help = |value_expr: &Option<Box<Expr<Option<Type>>>>| -> Result<_, String> {
Ok(match value_expr {
None => None,
Some(value_expr) => {
let value_expr = generator
.gen_expr(ctx, value_expr)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
let value_expr = i32_model.check_value(generator, ctx.ctx, value_expr).unwrap();
Some(value_expr)
}
})
};
let lower = help(lower)?;
let upper = help(upper)?;
let step = help(step)?;
Ok((lower, upper, step))
}

View File

@ -21,11 +21,7 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
type Type: BasicType<'ctx>;
/// Return the [`BasicType`] of this model.
fn get_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type;
/// Check if a [`BasicType`] is the same type of this model.
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(

View File

@ -97,11 +97,7 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for IntModel<N> {
type Type = IntType<'ctx>;
#[must_use]
fn get_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_int_type(generator, ctx)
}

View File

@ -17,11 +17,7 @@ impl<'ctx, Element: Model<'ctx>> Model<'ctx> for PtrModel<Element> {
type Value = PointerValue<'ctx>;
type Type = PointerType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_type(generator, ctx).ptr_type(AddressSpace::default())
}

View File

@ -111,11 +111,7 @@ impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for StructModel<S> {
type Value = StructValue<'ctx>;
type Type = StructType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_struct_type(generator, ctx)
}

View File

@ -4,15 +4,12 @@ use inkwell::values::{BasicValue, BasicValueEnum};
use nac3parser::ast::StrRef;
use crate::{
codegen::{
structure::{
codegen::structure::{
ndarray::{
scalar::split_scalar_or_ndarray, shape_util::parse_numpy_int_sequence,
NDArrayObject,
scalar::split_scalar_or_ndarray, shape_util::parse_numpy_int_sequence, NDArrayObject,
},
tuple::TupleObject,
},
},
symbol_resolver::ValueEnum,
toplevel::{
numpy::{extract_ndims, unpack_ndarray_var_tys},

View File

@ -258,11 +258,11 @@ impl<'ctx> NDArrayObject<'ctx> {
pub mod util {
use itertools::Itertools;
use nac3parser::ast::{Expr, ExprKind};
use nac3parser::ast::{Constant, Expr, ExprKind};
use crate::{
codegen::{model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::Type,
codegen::{expr::gen_slice, model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
};
use super::{RustNDIndex, RustUserSlice};
@ -308,45 +308,36 @@ pub mod util {
for index_expr in index_exprs {
// NOTE: Currently nac3core's slices do not have an object representation,
// so the code/implementation looks awkward - we have to do pattern matching on the expression
let ndindex =
if let ExprKind::Slice { lower: start, upper: stop, step } = &index_expr.node {
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
// Handle slices
// Helper function here to deduce code duplication
type ValueExpr = Option<Box<Expr<Option<Type>>>>;
let mut help = |value_expr: &ValueExpr| -> Result<_, String> {
Ok(match value_expr {
None => None,
Some(value_expr) => {
let value_expr = generator
.gen_expr(ctx, value_expr)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
let value_expr =
i32_model.check_value(generator, ctx.ctx, value_expr).unwrap();
Some(value_expr)
}
})
};
let start = help(start)?;
let stop = help(stop)?;
let step = help(step)?;
RustNDIndex::Slice(RustUserSlice { start, stop, step })
let (lower, upper, step) = gen_slice(generator, ctx, lower, upper, step)?;
RustNDIndex::Slice(RustUserSlice { start: lower, stop: upper, step })
} else if let ExprKind::Constant { value: Constant::Ellipsis, .. } = &index_expr.node {
// Handle '...'
RustNDIndex::Ellipsis
} else {
todo!("implement me");
// Anything else that is not a slice (might be illegal values),
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
match &*ctx.unifier.get_ty(index_expr.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
{
// Handle `np.newaxis` / `None`
RustNDIndex::NewAxis
}
_ => {
// Treat and handle everything else as a single element index.
let index =
generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
ctx,
generator,
ctx.primitives.int32,
ctx.primitives.int32, // Must be int32, this checks for illegal values
)?;
let index = i32_model.check_value(generator, ctx.ctx, index).unwrap();
RustNDIndex::SingleElement(index)
}
}
};
rust_ndindexes.push(ndindex);
}

View File

@ -13,8 +13,6 @@ impl<'ctx> NDArrayObject<'ctx> {
assert!(a.ndims >= 2);
assert!(b.ndims >= 2);
todo!()
}
}

View File

@ -621,7 +621,8 @@ impl Unifier {
}
pub fn unify_call(
&mut self, call: &Call,
&mut self,
call: &Call,
b: Type,
signature: &FunSignature,
) -> Result<(), TypeError> {