forked from M-Labs/nac3
WIP: core/ndstrides: implement np.newaxis and ... + IRRT_DEBUG_ASSERT
This commit is contained in:
parent
c28166efb8
commit
f0519e7019
|
@ -24,6 +24,12 @@ fn compile_irrt_cpp() {
|
||||||
let out_dir = get_out_dir();
|
let out_dir = get_out_dir();
|
||||||
let irrt_dir = get_irrt_dir();
|
let irrt_dir = get_irrt_dir();
|
||||||
|
|
||||||
|
let (opt_flag, debug_assert_flag) = match env::var("PROFILE").as_deref() {
|
||||||
|
Ok("debug") => ("-O0", "-DIRRT_DEBUG_ASSERT=true"),
|
||||||
|
Ok("release") => ("-O3", "-DIRRT_DEBUG_ASSERT=false"),
|
||||||
|
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
||||||
|
};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
||||||
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
||||||
|
@ -36,11 +42,8 @@ fn compile_irrt_cpp() {
|
||||||
"-fno-discard-value-names",
|
"-fno-discard-value-names",
|
||||||
"-fno-exceptions",
|
"-fno-exceptions",
|
||||||
"-fno-rtti",
|
"-fno-rtti",
|
||||||
match env::var("PROFILE").as_deref() {
|
opt_flag,
|
||||||
Ok("debug") => "-O0",
|
debug_assert_flag,
|
||||||
Ok("release") => "-O3",
|
|
||||||
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
|
||||||
},
|
|
||||||
"-emit-llvm",
|
"-emit-llvm",
|
||||||
"-S",
|
"-S",
|
||||||
"-Wall",
|
"-Wall",
|
||||||
|
|
|
@ -71,13 +71,6 @@ namespace util {
|
||||||
template <typename SizeT>
|
template <typename SizeT>
|
||||||
SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes,
|
SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes,
|
||||||
const NDIndex* indexes) {
|
const NDIndex* indexes) {
|
||||||
if (num_indexes > ndims) {
|
|
||||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
|
||||||
"too many indices for array: array is {0}-dimensional, "
|
|
||||||
"but {1} were indexed",
|
|
||||||
ndims, num_indexes, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
// There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis.
|
// There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis.
|
||||||
SizeT num_ellipsis = 0;
|
SizeT num_ellipsis = 0;
|
||||||
|
|
||||||
|
@ -101,6 +94,14 @@ SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes,
|
||||||
__builtin_unreachable();
|
__builtin_unreachable();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (num_indexes - num_ellipsis > ndims) {
|
||||||
|
raise_exception(SizeT, EXN_INDEX_ERROR,
|
||||||
|
"too many indices for array: array is {0}-dimensional, "
|
||||||
|
"but {1} were indexed",
|
||||||
|
ndims, num_indexes, NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
return ndims;
|
return ndims;
|
||||||
}
|
}
|
||||||
} // namespace util
|
} // namespace util
|
||||||
|
@ -139,9 +140,51 @@ SizeT validate_and_deduce_ndims_after_indexing(SizeT ndims, SizeT num_indexes,
|
||||||
template <typename SizeT>
|
template <typename SizeT>
|
||||||
void index(SizeT num_indexes, const NDIndex* indexes,
|
void index(SizeT num_indexes, const NDIndex* indexes,
|
||||||
const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||||
SizeT expected_dst_ndarray_ndims =
|
// First, validate `indexes`.
|
||||||
util::validate_and_deduce_ndims_after_indexing(src_ndarray->ndims,
|
|
||||||
num_indexes, indexes);
|
// Expected value of `dst_ndarray->ndims`.
|
||||||
|
SizeT expected_dst_ndims = src_ndarray->ndims;
|
||||||
|
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
|
||||||
|
SizeT num_indexed = 0;
|
||||||
|
// There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis.
|
||||||
|
SizeT num_ellipsis = 0;
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < num_indexes; i++) {
|
||||||
|
if (indexes[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
|
||||||
|
expected_dst_ndims--;
|
||||||
|
num_indexed++;
|
||||||
|
} else if (indexes[i].type == ND_INDEX_TYPE_SLICE) {
|
||||||
|
num_indexed++;
|
||||||
|
} else if (indexes[i].type == ND_INDEX_TYPE_NEWAXIS) {
|
||||||
|
expected_dst_ndims++;
|
||||||
|
} else if (indexes[i].type == ND_INDEX_TYPE_ELLIPSIS) {
|
||||||
|
num_ellipsis++;
|
||||||
|
if (num_ellipsis > 1) {
|
||||||
|
raise_exception(
|
||||||
|
SizeT, EXN_INDEX_ERROR,
|
||||||
|
"an index can only have a single ellipsis ('...')",
|
||||||
|
NO_PARAM, NO_PARAM, NO_PARAM);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (IRRT_DEBUG_ASSERT) {
|
||||||
|
if (expected_dst_ndims != dst_ndarray->ndims) {
|
||||||
|
raise_exception(
|
||||||
|
SizeT, EXN_ASSERTION_ERROR,
|
||||||
|
"IRRT expected_dst_ndims is {0}, but dst_ndarray->ndims is {1}",
|
||||||
|
expected_dst_ndims, dst_ndarray->ndims, NO_PARAM);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (src_ndarray->ndims - num_indexed < 0) {
|
||||||
|
raise_exception(SizeT, EXN_INDEX_ERROR,
|
||||||
|
"too many indices for array: array is {0}-dimensional, "
|
||||||
|
"but {1} were indexed",
|
||||||
|
src_ndarray->ndims, num_indexes, NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
dst_ndarray->data = src_ndarray->data;
|
dst_ndarray->data = src_ndarray->data;
|
||||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||||
|
@ -188,7 +231,7 @@ void index(SizeT num_indexes, const NDIndex* indexes,
|
||||||
dst_axis++;
|
dst_axis++;
|
||||||
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
|
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
|
||||||
// The number of ':' entries this '...' implies.
|
// The number of ':' entries this '...' implies.
|
||||||
SizeT ellipsis_size = src_ndarray->ndims - (num_indexes - 1);
|
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
|
||||||
|
|
||||||
for (SizeT j = 0; j < ellipsis_size; j++) {
|
for (SizeT j = 0; j < ellipsis_size; j++) {
|
||||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||||
|
@ -206,6 +249,19 @@ void index(SizeT num_indexes, const NDIndex* indexes,
|
||||||
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
|
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
|
||||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (IRRT_DEBUG_ASSERT) {
|
||||||
|
if (dst_ndarray->ndims != dst_axis) {
|
||||||
|
raise_exception(SizeT, EXN_ASSERTION_ERROR,
|
||||||
|
"IRRT dst_ndarray->ndims ({0}) != dst_axis ({1})",
|
||||||
|
dst_ndarray->ndims, dst_axis, NO_PARAM);
|
||||||
|
}
|
||||||
|
if (src_ndarray->ndims != src_axis) {
|
||||||
|
raise_exception(SizeT, EXN_ASSERTION_ERROR,
|
||||||
|
"IRRT src_ndarray->ndims ({0}) != src_axis ({1})",
|
||||||
|
src_ndarray->ndims, src_axis, NO_PARAM);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // namespace indexing
|
} // namespace indexing
|
||||||
} // namespace ndarray
|
} // namespace ndarray
|
||||||
|
|
|
@ -66,7 +66,7 @@ void resolve_and_check_new_shape(SizeT size, SizeT new_ndims,
|
||||||
bool can_reshape;
|
bool can_reshape;
|
||||||
if (neg1_exists) {
|
if (neg1_exists) {
|
||||||
// Let `x` be the unknown dimension
|
// Let `x` be the unknown dimension
|
||||||
// solve `x * <new_size> = <size>`
|
// Solve `x * <new_size> = <size>`
|
||||||
if (new_size == 0 && size == 0) {
|
if (new_size == 0 && size == 0) {
|
||||||
// `x` has infinitely many solutions
|
// `x` has infinitely many solutions
|
||||||
can_reshape = false;
|
can_reshape = false;
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#ifndef IRRT_DEBUG_ASSERT
|
||||||
|
#error IRRT_DEBUG_ASSERT flag is missing!! Please define it to 'false' or 'true'.
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <irrt/core.hpp>
|
#include <irrt/core.hpp>
|
||||||
#include <irrt/exception.hpp>
|
#include <irrt/exception.hpp>
|
||||||
#include <irrt/int_defs.hpp>
|
#include <irrt/int_defs.hpp>
|
||||||
|
|
|
@ -2942,3 +2942,43 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
_ => unimplemented!(),
|
_ => unimplemented!(),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate LLVM IR for an [`ExprKind::Slice`]
|
||||||
|
pub fn gen_slice<'ctx, G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
lower: &Option<Box<Expr<Option<Type>>>>,
|
||||||
|
upper: &Option<Box<Expr<Option<Type>>>>,
|
||||||
|
step: &Option<Box<Expr<Option<Type>>>>,
|
||||||
|
) -> Result<
|
||||||
|
(
|
||||||
|
Option<Instance<'ctx, IntModel<Int32>>>,
|
||||||
|
Option<Instance<'ctx, IntModel<Int32>>>,
|
||||||
|
Option<Instance<'ctx, IntModel<Int32>>>,
|
||||||
|
),
|
||||||
|
String,
|
||||||
|
> {
|
||||||
|
let i32_model = IntModel(Int32); // TODO: Switch to usize
|
||||||
|
|
||||||
|
let mut help = |value_expr: &Option<Box<Expr<Option<Type>>>>| -> Result<_, String> {
|
||||||
|
Ok(match value_expr {
|
||||||
|
None => None,
|
||||||
|
Some(value_expr) => {
|
||||||
|
let value_expr = generator
|
||||||
|
.gen_expr(ctx, value_expr)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
|
||||||
|
|
||||||
|
let value_expr = i32_model.check_value(generator, ctx.ctx, value_expr).unwrap();
|
||||||
|
|
||||||
|
Some(value_expr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
let lower = help(lower)?;
|
||||||
|
let upper = help(upper)?;
|
||||||
|
let step = help(step)?;
|
||||||
|
|
||||||
|
Ok((lower, upper, step))
|
||||||
|
}
|
||||||
|
|
|
@ -21,11 +21,7 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
|
||||||
type Type: BasicType<'ctx>;
|
type Type: BasicType<'ctx>;
|
||||||
|
|
||||||
/// Return the [`BasicType`] of this model.
|
/// Return the [`BasicType`] of this model.
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(
|
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type;
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Self::Type;
|
|
||||||
|
|
||||||
/// Check if a [`BasicType`] is the same type of this model.
|
/// Check if a [`BasicType`] is the same type of this model.
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
||||||
|
|
|
@ -97,11 +97,7 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for IntModel<N> {
|
||||||
type Type = IntType<'ctx>;
|
type Type = IntType<'ctx>;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(
|
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Self::Type {
|
|
||||||
self.0.get_int_type(generator, ctx)
|
self.0.get_int_type(generator, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,11 +17,7 @@ impl<'ctx, Element: Model<'ctx>> Model<'ctx> for PtrModel<Element> {
|
||||||
type Value = PointerValue<'ctx>;
|
type Value = PointerValue<'ctx>;
|
||||||
type Type = PointerType<'ctx>;
|
type Type = PointerType<'ctx>;
|
||||||
|
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(
|
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Self::Type {
|
|
||||||
self.0.get_type(generator, ctx).ptr_type(AddressSpace::default())
|
self.0.get_type(generator, ctx).ptr_type(AddressSpace::default())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -111,11 +111,7 @@ impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for StructModel<S> {
|
||||||
type Value = StructValue<'ctx>;
|
type Value = StructValue<'ctx>;
|
||||||
type Type = StructType<'ctx>;
|
type Type = StructType<'ctx>;
|
||||||
|
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(
|
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Self::Type {
|
|
||||||
self.0.get_struct_type(generator, ctx)
|
self.0.get_struct_type(generator, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,15 +4,12 @@ use inkwell::values::{BasicValue, BasicValueEnum};
|
||||||
use nac3parser::ast::StrRef;
|
use nac3parser::ast::StrRef;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::structure::{
|
||||||
structure::{
|
|
||||||
ndarray::{
|
ndarray::{
|
||||||
scalar::split_scalar_or_ndarray, shape_util::parse_numpy_int_sequence,
|
scalar::split_scalar_or_ndarray, shape_util::parse_numpy_int_sequence, NDArrayObject,
|
||||||
NDArrayObject,
|
|
||||||
},
|
},
|
||||||
tuple::TupleObject,
|
tuple::TupleObject,
|
||||||
},
|
},
|
||||||
},
|
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
numpy::{extract_ndims, unpack_ndarray_var_tys},
|
numpy::{extract_ndims, unpack_ndarray_var_tys},
|
||||||
|
|
|
@ -258,11 +258,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
|
||||||
pub mod util {
|
pub mod util {
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use nac3parser::ast::{Expr, ExprKind};
|
use nac3parser::ast::{Constant, Expr, ExprKind};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{model::*, CodeGenContext, CodeGenerator},
|
codegen::{expr::gen_slice, model::*, CodeGenContext, CodeGenerator},
|
||||||
typecheck::typedef::Type,
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{RustNDIndex, RustUserSlice};
|
use super::{RustNDIndex, RustUserSlice};
|
||||||
|
@ -308,45 +308,36 @@ pub mod util {
|
||||||
for index_expr in index_exprs {
|
for index_expr in index_exprs {
|
||||||
// NOTE: Currently nac3core's slices do not have an object representation,
|
// NOTE: Currently nac3core's slices do not have an object representation,
|
||||||
// so the code/implementation looks awkward - we have to do pattern matching on the expression
|
// so the code/implementation looks awkward - we have to do pattern matching on the expression
|
||||||
let ndindex =
|
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
|
||||||
if let ExprKind::Slice { lower: start, upper: stop, step } = &index_expr.node {
|
// Handle slices
|
||||||
|
|
||||||
// Helper function here to deduce code duplication
|
// Helper function here to deduce code duplication
|
||||||
type ValueExpr = Option<Box<Expr<Option<Type>>>>;
|
let (lower, upper, step) = gen_slice(generator, ctx, lower, upper, step)?;
|
||||||
let mut help = |value_expr: &ValueExpr| -> Result<_, String> {
|
RustNDIndex::Slice(RustUserSlice { start: lower, stop: upper, step })
|
||||||
Ok(match value_expr {
|
} else if let ExprKind::Constant { value: Constant::Ellipsis, .. } = &index_expr.node {
|
||||||
None => None,
|
// Handle '...'
|
||||||
Some(value_expr) => {
|
RustNDIndex::Ellipsis
|
||||||
let value_expr = generator
|
|
||||||
.gen_expr(ctx, value_expr)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
|
|
||||||
|
|
||||||
let value_expr =
|
|
||||||
i32_model.check_value(generator, ctx.ctx, value_expr).unwrap();
|
|
||||||
|
|
||||||
Some(value_expr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
};
|
|
||||||
|
|
||||||
let start = help(start)?;
|
|
||||||
let stop = help(stop)?;
|
|
||||||
let step = help(step)?;
|
|
||||||
|
|
||||||
RustNDIndex::Slice(RustUserSlice { start, stop, step })
|
|
||||||
} else {
|
} else {
|
||||||
todo!("implement me");
|
match &*ctx.unifier.get_ty(index_expr.custom.unwrap()) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
// Anything else that is not a slice (might be illegal values),
|
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
||||||
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
|
{
|
||||||
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
|
// Handle `np.newaxis` / `None`
|
||||||
|
RustNDIndex::NewAxis
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Treat and handle everything else as a single element index.
|
||||||
|
let index =
|
||||||
|
generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
ctx.primitives.int32,
|
ctx.primitives.int32, // Must be int32, this checks for illegal values
|
||||||
)?;
|
)?;
|
||||||
let index = i32_model.check_value(generator, ctx.ctx, index).unwrap();
|
let index = i32_model.check_value(generator, ctx.ctx, index).unwrap();
|
||||||
|
|
||||||
RustNDIndex::SingleElement(index)
|
RustNDIndex::SingleElement(index)
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
rust_ndindexes.push(ndindex);
|
rust_ndindexes.push(ndindex);
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,8 +13,6 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
assert!(a.ndims >= 2);
|
assert!(a.ndims >= 2);
|
||||||
assert!(b.ndims >= 2);
|
assert!(b.ndims >= 2);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -621,7 +621,8 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unify_call(
|
pub fn unify_call(
|
||||||
&mut self, call: &Call,
|
&mut self,
|
||||||
|
call: &Call,
|
||||||
b: Type,
|
b: Type,
|
||||||
signature: &FunSignature,
|
signature: &FunSignature,
|
||||||
) -> Result<(), TypeError> {
|
) -> Result<(), TypeError> {
|
||||||
|
|
Loading…
Reference in New Issue