From f0519e7019bac015424c606a41ec1a76fca2b4fa Mon Sep 17 00:00:00 2001 From: lyken Date: Mon, 12 Aug 2024 10:46:09 +0800 Subject: [PATCH] WIP: core/ndstrides: implement np.newaxis and ... + IRRT_DEBUG_ASSERT --- nac3core/build.rs | 13 ++-- nac3core/irrt/irrt/ndarray/indexing.hpp | 78 ++++++++++++++++--- nac3core/irrt/irrt/ndarray/reshape.hpp | 2 +- nac3core/irrt/irrt_everything.hpp | 4 + nac3core/src/codegen/expr.rs | 40 ++++++++++ nac3core/src/codegen/model/core.rs | 6 +- nac3core/src/codegen/model/int.rs | 6 +- nac3core/src/codegen/model/ptr.rs | 6 +- nac3core/src/codegen/model/structure.rs | 6 +- nac3core/src/codegen/numpy_new.rs | 11 +-- .../src/codegen/structure/ndarray/indexing.rs | 73 ++++++++--------- .../src/codegen/structure/ndarray/product.rs | 2 - nac3core/src/typecheck/typedef/mod.rs | 3 +- 13 files changed, 162 insertions(+), 88 deletions(-) diff --git a/nac3core/build.rs b/nac3core/build.rs index 345b042b..2f92281e 100644 --- a/nac3core/build.rs +++ b/nac3core/build.rs @@ -24,6 +24,12 @@ fn compile_irrt_cpp() { let out_dir = get_out_dir(); let irrt_dir = get_irrt_dir(); + let (opt_flag, debug_assert_flag) = match env::var("PROFILE").as_deref() { + Ok("debug") => ("-O0", "-DIRRT_DEBUG_ASSERT=true"), + Ok("release") => ("-O3", "-DIRRT_DEBUG_ASSERT=false"), + flavor => panic!("Unknown or missing build flavor {flavor:?}"), + }; + /* * HACK: Sadly, clang doesn't let us emit generic LLVM bitcode. * Compiling for WASM32 and filtering the output with regex is the closest we can get. @@ -36,11 +42,8 @@ fn compile_irrt_cpp() { "-fno-discard-value-names", "-fno-exceptions", "-fno-rtti", - match env::var("PROFILE").as_deref() { - Ok("debug") => "-O0", - Ok("release") => "-O3", - flavor => panic!("Unknown or missing build flavor {flavor:?}"), - }, + opt_flag, + debug_assert_flag, "-emit-llvm", "-S", "-Wall", diff --git a/nac3core/irrt/irrt/ndarray/indexing.hpp b/nac3core/irrt/irrt/ndarray/indexing.hpp index c7328242..762866fe 100644 --- a/nac3core/irrt/irrt/ndarray/indexing.hpp +++ b/nac3core/irrt/irrt/ndarray/indexing.hpp @@ -71,13 +71,6 @@ namespace util { template SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes, const NDIndex* indexes) { - if (num_indexes > ndims) { - raise_exception(SizeT, EXN_INDEX_ERROR, - "too many indices for array: array is {0}-dimensional, " - "but {1} were indexed", - ndims, num_indexes, NO_PARAM); - } - // There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis. SizeT num_ellipsis = 0; @@ -101,6 +94,14 @@ SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes, __builtin_unreachable(); } } + + if (num_indexes - num_ellipsis > ndims) { + raise_exception(SizeT, EXN_INDEX_ERROR, + "too many indices for array: array is {0}-dimensional, " + "but {1} were indexed", + ndims, num_indexes, NO_PARAM); + } + return ndims; } } // namespace util @@ -139,9 +140,51 @@ SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes, template void index(SizeT num_indexes, const NDIndex* indexes, const NDArray* src_ndarray, NDArray* dst_ndarray) { - SizeT expected_dst_ndarray_ndims = - util::validate_and_deduce_ndims_after_indexing(src_ndarray->ndims, - num_indexes, indexes); + // First, validate `indexes`. + + // Expected value of `dst_ndarray->ndims`. + SizeT expected_dst_ndims = src_ndarray->ndims; + // To check for "too many indices for array: array is ?-dimensional, but ? were indexed" + SizeT num_indexed = 0; + // There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis. + SizeT num_ellipsis = 0; + + for (SizeT i = 0; i < num_indexes; i++) { + if (indexes[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) { + expected_dst_ndims--; + num_indexed++; + } else if (indexes[i].type == ND_INDEX_TYPE_SLICE) { + num_indexed++; + } else if (indexes[i].type == ND_INDEX_TYPE_NEWAXIS) { + expected_dst_ndims++; + } else if (indexes[i].type == ND_INDEX_TYPE_ELLIPSIS) { + num_ellipsis++; + if (num_ellipsis > 1) { + raise_exception( + SizeT, EXN_INDEX_ERROR, + "an index can only have a single ellipsis ('...')", + NO_PARAM, NO_PARAM, NO_PARAM); + } + } else { + __builtin_unreachable(); + } + } + + if (IRRT_DEBUG_ASSERT) { + if (expected_dst_ndims != dst_ndarray->ndims) { + raise_exception( + SizeT, EXN_ASSERTION_ERROR, + "IRRT expected_dst_ndims is {0}, but dst_ndarray->ndims is {1}", + expected_dst_ndims, dst_ndarray->ndims, NO_PARAM); + } + } + + if (src_ndarray->ndims - num_indexed < 0) { + raise_exception(SizeT, EXN_INDEX_ERROR, + "too many indices for array: array is {0}-dimensional, " + "but {1} were indexed", + src_ndarray->ndims, num_indexes, NO_PARAM); + } dst_ndarray->data = src_ndarray->data; dst_ndarray->itemsize = src_ndarray->itemsize; @@ -188,7 +231,7 @@ void index(SizeT num_indexes, const NDIndex* indexes, dst_axis++; } else if (index->type == ND_INDEX_TYPE_ELLIPSIS) { // The number of ':' entries this '...' implies. - SizeT ellipsis_size = src_ndarray->ndims - (num_indexes - 1); + SizeT ellipsis_size = src_ndarray->ndims - num_indexed; for (SizeT j = 0; j < ellipsis_size; j++) { dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; @@ -206,6 +249,19 @@ void index(SizeT num_indexes, const NDIndex* indexes, dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; } + + if (IRRT_DEBUG_ASSERT) { + if (dst_ndarray->ndims != dst_axis) { + raise_exception(SizeT, EXN_ASSERTION_ERROR, + "IRRT dst_ndarray->ndims ({0}) != dst_axis ({1})", + dst_ndarray->ndims, dst_axis, NO_PARAM); + } + if (src_ndarray->ndims != src_axis) { + raise_exception(SizeT, EXN_ASSERTION_ERROR, + "IRRT src_ndarray->ndims ({0}) != src_axis ({1})", + src_ndarray->ndims, src_axis, NO_PARAM); + } + } } } // namespace indexing } // namespace ndarray diff --git a/nac3core/irrt/irrt/ndarray/reshape.hpp b/nac3core/irrt/irrt/ndarray/reshape.hpp index 1e32c324..8947a0c3 100644 --- a/nac3core/irrt/irrt/ndarray/reshape.hpp +++ b/nac3core/irrt/irrt/ndarray/reshape.hpp @@ -66,7 +66,7 @@ void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, bool can_reshape; if (neg1_exists) { // Let `x` be the unknown dimension - // solve `x * = ` + // Solve `x * = ` if (new_size == 0 && size == 0) { // `x` has infinitely many solutions can_reshape = false; diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index 707f1af7..73940b52 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -1,5 +1,9 @@ #pragma once +#ifndef IRRT_DEBUG_ASSERT +#error IRRT_DEBUG_ASSERT flag is missing!! Please define it to 'false' or 'true'. +#endif + #include #include #include diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0480e96b..6dc827ab 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2942,3 +2942,43 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( _ => unimplemented!(), })) } + +/// Generate LLVM IR for an [`ExprKind::Slice`] +pub fn gen_slice<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lower: &Option>>>, + upper: &Option>>>, + step: &Option>>>, +) -> Result< + ( + Option>>, + Option>>, + Option>>, + ), + String, +> { + let i32_model = IntModel(Int32); // TODO: Switch to usize + + let mut help = |value_expr: &Option>>>| -> Result<_, String> { + Ok(match value_expr { + None => None, + Some(value_expr) => { + let value_expr = generator + .gen_expr(ctx, value_expr)? + .unwrap() + .to_basic_value_enum(ctx, generator, ctx.primitives.int32)?; + + let value_expr = i32_model.check_value(generator, ctx.ctx, value_expr).unwrap(); + + Some(value_expr) + } + }) + }; + + let lower = help(lower)?; + let upper = help(upper)?; + let step = help(step)?; + + Ok((lower, upper, step)) +} diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs index 8e47cf02..1afe80f2 100644 --- a/nac3core/src/codegen/model/core.rs +++ b/nac3core/src/codegen/model/core.rs @@ -21,11 +21,7 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy { type Type: BasicType<'ctx>; /// Return the [`BasicType`] of this model. - fn get_type( - &self, - generator: &G, - ctx: &'ctx Context, - ) -> Self::Type; + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type; /// Check if a [`BasicType`] is the same type of this model. fn check_type, G: CodeGenerator + ?Sized>( diff --git a/nac3core/src/codegen/model/int.rs b/nac3core/src/codegen/model/int.rs index 9f0d9f5f..cf51f673 100644 --- a/nac3core/src/codegen/model/int.rs +++ b/nac3core/src/codegen/model/int.rs @@ -97,11 +97,7 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for IntModel { type Type = IntType<'ctx>; #[must_use] - fn get_type( - &self, - generator: &G, - ctx: &'ctx Context, - ) -> Self::Type { + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { self.0.get_int_type(generator, ctx) } diff --git a/nac3core/src/codegen/model/ptr.rs b/nac3core/src/codegen/model/ptr.rs index fb2d48a2..cc810e13 100644 --- a/nac3core/src/codegen/model/ptr.rs +++ b/nac3core/src/codegen/model/ptr.rs @@ -17,11 +17,7 @@ impl<'ctx, Element: Model<'ctx>> Model<'ctx> for PtrModel { type Value = PointerValue<'ctx>; type Type = PointerType<'ctx>; - fn get_type( - &self, - generator: &G, - ctx: &'ctx Context, - ) -> Self::Type { + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { self.0.get_type(generator, ctx).ptr_type(AddressSpace::default()) } diff --git a/nac3core/src/codegen/model/structure.rs b/nac3core/src/codegen/model/structure.rs index cd4b9995..ddc4f8a1 100644 --- a/nac3core/src/codegen/model/structure.rs +++ b/nac3core/src/codegen/model/structure.rs @@ -111,11 +111,7 @@ impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for StructModel { type Value = StructValue<'ctx>; type Type = StructType<'ctx>; - fn get_type( - &self, - generator: &G, - ctx: &'ctx Context, - ) -> Self::Type { + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { self.0.get_struct_type(generator, ctx) } diff --git a/nac3core/src/codegen/numpy_new.rs b/nac3core/src/codegen/numpy_new.rs index 0f3df368..6faf6f49 100644 --- a/nac3core/src/codegen/numpy_new.rs +++ b/nac3core/src/codegen/numpy_new.rs @@ -4,14 +4,11 @@ use inkwell::values::{BasicValue, BasicValueEnum}; use nac3parser::ast::StrRef; use crate::{ - codegen::{ - structure::{ - ndarray::{ - scalar::split_scalar_or_ndarray, shape_util::parse_numpy_int_sequence, - NDArrayObject, - }, - tuple::TupleObject, + codegen::structure::{ + ndarray::{ + scalar::split_scalar_or_ndarray, shape_util::parse_numpy_int_sequence, NDArrayObject, }, + tuple::TupleObject, }, symbol_resolver::ValueEnum, toplevel::{ diff --git a/nac3core/src/codegen/structure/ndarray/indexing.rs b/nac3core/src/codegen/structure/ndarray/indexing.rs index 5d767140..16888351 100644 --- a/nac3core/src/codegen/structure/ndarray/indexing.rs +++ b/nac3core/src/codegen/structure/ndarray/indexing.rs @@ -258,11 +258,11 @@ impl<'ctx> NDArrayObject<'ctx> { pub mod util { use itertools::Itertools; - use nac3parser::ast::{Expr, ExprKind}; + use nac3parser::ast::{Constant, Expr, ExprKind}; use crate::{ - codegen::{model::*, CodeGenContext, CodeGenerator}, - typecheck::typedef::Type, + codegen::{expr::gen_slice, model::*, CodeGenContext, CodeGenerator}, + typecheck::typedef::{Type, TypeEnum}, }; use super::{RustNDIndex, RustUserSlice}; @@ -308,46 +308,37 @@ pub mod util { for index_expr in index_exprs { // NOTE: Currently nac3core's slices do not have an object representation, // so the code/implementation looks awkward - we have to do pattern matching on the expression - let ndindex = - if let ExprKind::Slice { lower: start, upper: stop, step } = &index_expr.node { - // Helper function here to deduce code duplication - type ValueExpr = Option>>>; - let mut help = |value_expr: &ValueExpr| -> Result<_, String> { - Ok(match value_expr { - None => None, - Some(value_expr) => { - let value_expr = generator - .gen_expr(ctx, value_expr)? - .unwrap() - .to_basic_value_enum(ctx, generator, ctx.primitives.int32)?; + let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node { + // Handle slices - let value_expr = - i32_model.check_value(generator, ctx.ctx, value_expr).unwrap(); + // Helper function here to deduce code duplication + let (lower, upper, step) = gen_slice(generator, ctx, lower, upper, step)?; + RustNDIndex::Slice(RustUserSlice { start: lower, stop: upper, step }) + } else if let ExprKind::Constant { value: Constant::Ellipsis, .. } = &index_expr.node { + // Handle '...' + RustNDIndex::Ellipsis + } else { + match &*ctx.unifier.get_ty(index_expr.custom.unwrap()) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => + { + // Handle `np.newaxis` / `None` + RustNDIndex::NewAxis + } + _ => { + // Treat and handle everything else as a single element index. + let index = + generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum( + ctx, + generator, + ctx.primitives.int32, // Must be int32, this checks for illegal values + )?; + let index = i32_model.check_value(generator, ctx.ctx, index).unwrap(); - Some(value_expr) - } - }) - }; - - let start = help(start)?; - let stop = help(stop)?; - let step = help(step)?; - - RustNDIndex::Slice(RustUserSlice { start, stop, step }) - } else { - todo!("implement me"); - - // Anything else that is not a slice (might be illegal values), - // For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error - let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum( - ctx, - generator, - ctx.primitives.int32, - )?; - let index = i32_model.check_value(generator, ctx.ctx, index).unwrap(); - - RustNDIndex::SingleElement(index) - }; + RustNDIndex::SingleElement(index) + } + } + }; rust_ndindexes.push(ndindex); } Ok(rust_ndindexes) diff --git a/nac3core/src/codegen/structure/ndarray/product.rs b/nac3core/src/codegen/structure/ndarray/product.rs index 96d337ec..e505d253 100644 --- a/nac3core/src/codegen/structure/ndarray/product.rs +++ b/nac3core/src/codegen/structure/ndarray/product.rs @@ -13,8 +13,6 @@ impl<'ctx> NDArrayObject<'ctx> { assert!(a.ndims >= 2); assert!(b.ndims >= 2); - - todo!() } } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index e93656e3..e125ff80 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -621,7 +621,8 @@ impl Unifier { } pub fn unify_call( - &mut self, call: &Call, + &mut self, + call: &Call, b: Type, signature: &FunSignature, ) -> Result<(), TypeError> {