Compare commits

..

66 Commits

Author SHA1 Message Date
lyken b9e837109b
core/ndstrides: implement np_transpose() (no axes argument)
The IRRT implementation knows how to handle axes. But the argument is
not in NAC3 yet.
2024-08-30 14:50:48 +08:00
lyken d32268fb5d
core/ndstrides: implement broadcasting & np_broadcast_to() 2024-08-30 14:45:43 +08:00
lyken 916a2b4993
core/ndstrides: implement np_reshape() 2024-08-30 14:41:54 +08:00
lyken c7c3cc21a8
core: categorize np_{transpose,reshape} as 'view functions' 2024-08-30 14:41:54 +08:00
lyken d2072d9248
core/ndstrides: implement np_size() 2024-08-30 14:41:00 +08:00
lyken be19165ead
core/ndstrides: implement np_shape() and np_strides()
These functions are not important, but they are handy for debugging.

`np.strides()` is not an actual NumPy function, but `ndarray.strides` is used.
2024-08-30 14:41:00 +08:00
lyken ee58cf3fc3
core/ndstrides: implement ndarray.fill() and .copy() 2024-08-30 14:41:00 +08:00
lyken 8fe8ccf200
core/ndstrides: implement np_identity() and np_eye() 2024-08-30 14:41:00 +08:00
lyken d222236492
core/ndstrides: implement np_array()
It also checks for inconsistent dimensions if the input is a list.
e.g., rejecting `[[1.0, 2.0], [3.0]]`.

However, currently only `np_array(<input>, copy=False)` and `np_array(<input>, copy=True)` are supported. In NumPy, copy could be false, true, or None. Right now, NAC3's `np_array(<input>, copy=False)` behaves like NumPy's `np.array(<input>, copy=None)`.
2024-08-30 14:40:15 +08:00
lyken 13715dbda9
core/irrt: add List
Needed for implementing np_array()
2024-08-30 14:20:19 +08:00
lyken 7910de10a1
core/ndstrides: add NDArrayObject::atleast_nd 2024-08-30 14:18:34 +08:00
lyken 6edc3f895b
core/ndstrides: add NDArrayObject::make_copy 2024-08-30 14:18:17 +08:00
lyken a40cdde8d2
core/ndstrides: implement ndarray indexing
The functionality for `...` and `np.newaxis` is there in IRRT, but there
is no implementation of them for @kernel Python expressions because of
#486.
2024-08-30 14:12:54 +08:00
lyken 853fa39537
core/irrt: rename NDIndex to NDIndexInt
Unfortunately the name `NDIndex` is used in later commits. Renaming this
typedef to `NDIndexInt` to avoid amending. `NDIndexInt` will be removed
anyway when ndarray strides is completed.
2024-08-30 14:00:19 +08:00
lyken b6a1880226
core/irrt: add Slice and Range
Needed for implementing general ndarray indexing.

Currently IRRT slice and range have nothing to do with NAC3's slice
and range. The IRRT slice and range are currently there to implement
ndarray specific features. However, in the future their definitions may
be used to replace that of NAC3's. (NAC3's range is a [i32 x 3], IRRT's
range is a proper struct. NAC3 does not have a slice struct).
2024-08-30 13:57:10 +08:00
lyken d1c75c7444
core/ndstrides: implement len(ndarray) & refactor len() 2024-08-30 13:45:25 +08:00
lyken 58c5bc56b9
core/ndstrides: implement np_{zeros,ones,full,empty} 2024-08-30 13:44:12 +08:00
lyken ddc0e44c61
core/model: add util::gen_for_model 2024-08-30 13:42:39 +08:00
lyken 549536f72c
core/object: add ListObject and TupleObject
Needed for implementing other ndarray utils.
2024-08-30 13:41:31 +08:00
lyken 40c42b571a
core/ndstrides: implement ndarray iterator NDIter
A necessary utility to iterate through all elements in a possibly strided ndarray.
2024-08-30 13:39:10 +08:00
lyken 92e7103ec7
core/ndstrides: introduce NDArray
NDArray with strides.
2024-08-30 13:24:45 +08:00
lyken 9bc5e96dba
core/irrt: fix exception.hpp C++ castings 2024-08-30 13:15:07 +08:00
lyken 78639b1030
core/toplevel/helper: add {extract,create}_ndims 2024-08-30 13:05:16 +08:00
lyken 9723c17e24
core/object: introduce object
A small abstraction to simplify implementations.
2024-08-30 13:04:54 +08:00
lyken d1c7a8ee50
StructKind::{traverse -> iter}_fields 2024-08-30 12:51:17 +08:00
lyken e0524c19eb
Newline "Otherwise, it will be caught..." 2024-08-30 12:51:17 +08:00
lyken 32822f9052
gep_index must be u32 2024-08-30 12:51:17 +08:00
lyken 6283036815
FieldTraversal::{Out -> Output} 2024-08-30 12:51:17 +08:00
lyken f167f5f215
Ptr::copy_from to use SizeT 2024-08-30 12:51:17 +08:00
lyken baf8ee2b3d
Ptr::offset_const offset i64, can be negative 2024-08-30 12:51:17 +08:00
lyken d68760447f
Int::const_int to have sign_extend 2024-08-30 12:51:17 +08:00
lyken fdd194ee2a
FnCall::{begin -> builder} 2024-08-30 12:51:17 +08:00
lyken 5fca81c68e
CallFunction -> FnCall 2024-08-30 12:51:17 +08:00
lyken 0562e9a385
Instance add newline 2024-08-30 12:51:17 +08:00
lyken 36af473816
unsafe Model::believe_value 2024-08-30 12:51:17 +08:00
lyken 7c7e1b3ab8
Model::{sizeof -> size_of} 2024-08-30 12:51:17 +08:00
lyken dbcfc9538a
ArrayLen::{get_length -> length} 2024-08-30 12:51:17 +08:00
lyken 5c4ba09e2f
LenKind -> ArrayLen 2024-08-30 12:51:17 +08:00
lyken eb34b99ee9
core/model: renaming and add notes on upgrading Ptr to LLVM 15 2024-08-30 12:51:17 +08:00
lyken d397b9ceaa
core/model: introduce models 2024-08-30 12:51:17 +08:00
David Mak 71c3a65a31 [core] codegen/stmt: Fix obtaining return type of sret functions 2024-08-29 19:15:30 +08:00
David Mak 8c540d1033 [core] codegen/stmt: Add more casts for boolean types 2024-08-29 16:36:32 +08:00
David Mak 0cc60a3d33 [core] codegen/expr: Fix missing cast to i1 2024-08-29 16:36:32 +08:00
David Mak a59c26aa99 [artiq] Fix RPC of ndarrays from host 2024-08-29 16:08:45 +08:00
David Mak 02d93b11d1 [meta] Update dependencies 2024-08-29 14:32:21 +08:00
lyken 59cad5bfe1
standalone: clang-format demo.c 2024-08-29 10:37:24 +08:00
lyken 4318f8de84
standalone: improve src/assignment.py 2024-08-29 10:33:58 +08:00
David Mak 15ac00708a [core] Use quoted include paths instead of angled brackets
This is preferred for user-defined headers.
2024-08-28 16:37:03 +08:00
lyken c8dfdcfdea
standalone & artiq: remove class_names from resolver 2024-08-27 23:43:40 +08:00
Sébastien Bourdeauducq 600a5c8679 Revert "standalone: reformat demo.c"
This reverts commit 308edb8237.
2024-08-27 23:06:49 +08:00
lyken 22c4d25802 core/typecheck: add missing typecheck in matmul 2024-08-27 22:59:39 +08:00
lyken 308edb8237 standalone: reformat demo.c 2024-08-27 22:55:22 +08:00
lyken 9848795dcc core/irrt: add exceptions and debug utils 2024-08-27 22:55:22 +08:00
lyken 58222feed4 core/irrt: split into headers 2024-08-27 22:55:22 +08:00
lyken 518f21d174 core/irrt: build.rs capture IR defined constants 2024-08-27 22:55:22 +08:00
lyken e8e49684bf core/irrt: build.rs capture IR defined types 2024-08-27 22:55:22 +08:00
lyken b2900b4883 core/irrt: use +std=c++20 to compile
To explicitly set the C++ variant and avoid inconsistencies.
2024-08-27 22:55:22 +08:00
lyken c6dade1394 core/irrt: reformat 2024-08-27 22:55:22 +08:00
lyken 7e3fcc0845 add .clang-format 2024-08-27 22:55:22 +08:00
lyken d3b4c60d7f core/irrt: comment build.rs & move irrt to nac3core/irrt 2024-08-27 22:55:22 +08:00
abdul124 5b2b6db7ed core: improve error messages 2024-08-26 18:37:55 +08:00
abdul124 15e62f467e standalone: add tests for polymorphism 2024-08-26 18:37:55 +08:00
abdul124 2c88924ff7 core: add support for simple polymorphism 2024-08-26 18:37:55 +08:00
abdul124 a744b139ba core: allow Call and AnnAssign in init block 2024-08-26 18:37:55 +08:00
David Mak 2b2b2dbf8f [core] Fix resolution of exception names in raise short form
Previous implementation fails as `resolver.get_identifier_def` in ARTIQ
would return the exception __init__ function rather than the class.

We fix this by limiting the exception class resolution to only include
raise statements, and to force the exception name to always be treated
as a class.

Fixes #501.
2024-08-26 18:35:02 +08:00
David Mak d9f96dab33 [core] Add codegen_unreachable 2024-08-23 13:10:55 +08:00
62 changed files with 2067 additions and 1682 deletions

View File

@ -1,3 +1,32 @@
BasedOnStyle: Microsoft BasedOnStyle: LLVM
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -1
AlignEscapedNewlines: Left
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: Yes
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortFunctionsOnASingleLine: Inline
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: AfterColon
BreakInheritanceList: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ContinuationIndentWidth: 4
DerivePointerAlignment: false
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4 IndentWidth: 4
ReflowComments: false MaxEmptyLinesToKeep: 1
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SortUsingDeclarations: true
SpaceAfterTemplateKeyword: false
SpacesBeforeTrailingComments: 2
TabWidth: 4
UseTab: Never

62
Cargo.lock generated
View File

@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.13" version = "1.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72db2f7947ecee9b03b510377e8bb9077afa27176fdbff55c51027e976fdcc48" checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6"
dependencies = [ dependencies = [
"shlex", "shlex",
] ]
@ -161,7 +161,7 @@ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -310,9 +310,9 @@ dependencies = [
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "2.1.0" version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
[[package]] [[package]]
name = "fixedbitset" name = "fixedbitset"
@ -424,7 +424,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -510,9 +510,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.157" version = "0.2.158"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374af5f94e54fa97cf75e945cce8a6b201e88a1a07e688b47dfd2a59c66dbd86" checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
[[package]] [[package]]
name = "libloading" name = "libloading"
@ -752,7 +752,7 @@ dependencies = [
"phf_shared 0.11.2", "phf_shared 0.11.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -856,7 +856,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -869,14 +869,14 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.36" version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
] ]
@ -942,9 +942,9 @@ dependencies = [
[[package]] [[package]]
name = "redox_users" name = "redox_users"
version = "0.4.5" version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [ dependencies = [
"getrandom", "getrandom",
"libredox", "libredox",
@ -989,9 +989,9 @@ dependencies = [
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.34" version = "0.38.35"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"errno", "errno",
@ -1035,29 +1035,29 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.208" version = "1.0.209"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cff085d2cb684faa248efb494c39b68e522822ac0de72ccf08109abde717cfb2" checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.208" version = "1.0.209"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24008e81ff7613ed8e5ba0cfaf24e2c2f1e5b8a0495711e44fcd4882fca62bcf" checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.125" version = "1.0.127"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed" checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad"
dependencies = [ dependencies = [
"itoa", "itoa",
"memchr", "memchr",
@ -1147,7 +1147,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -1163,9 +1163,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.75" version = "2.0.76"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6af063034fc1935ede7be0122941bafa9bacb949334d090b77ca98b5817c7d9" checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1232,7 +1232,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -1310,9 +1310,9 @@ checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
[[package]] [[package]]
name = "unicode-xid" name = "unicode-xid"
version = "0.2.4" version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a"
[[package]] [[package]]
name = "unicode_names2" name = "unicode_names2"
@ -1510,5 +1510,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]

View File

@ -2,7 +2,7 @@ use nac3core::{
codegen::{ codegen::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
NDArrayValue, RangeValue, UntypedArrayLikeAccessor, NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
}, },
expr::{destructure_range, gen_call}, expr::{destructure_range, gen_call},
irrt::call_ndarray_calc_size, irrt::call_ndarray_calc_size,
@ -22,7 +22,7 @@ use inkwell::{
module::Linkage, module::Linkage,
types::{BasicType, IntType}, types::{BasicType, IntType},
values::{BasicValueEnum, PointerValue, StructValue}, values::{BasicValueEnum, PointerValue, StructValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use pyo3::{ use pyo3::{
@ -32,6 +32,7 @@ use pyo3::{
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use inkwell::values::IntValue;
use itertools::Itertools; use itertools::Itertools;
use std::{ use std::{
collections::{hash_map::DefaultHasher, HashMap}, collections::{hash_map::DefaultHasher, HashMap},
@ -486,13 +487,10 @@ fn format_rpc_arg<'ctx>(
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap();
ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap();
call_memcpy_generic( call_memcpy_generic(
ctx, ctx,
buffer.base_ptr(ctx, generator), buffer.base_ptr(ctx, generator),
ppdata, llvm_arg.ptr_to_data(ctx),
llvm_pdata_sizeof, llvm_pdata_sizeof,
llvm_i1.const_zero(), llvm_i1.const_zero(),
); );
@ -528,6 +526,298 @@ fn format_rpc_arg<'ctx>(
arg_slot arg_slot
} }
/// Formats an RPC return value to conform to the expected format required by NAC3.
fn format_rpc_ret<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
ret_ty: Type,
) -> Option<BasicValueEnum<'ctx>> {
// -- receive value:
// T result = {
// void *ret_ptr = alloca(sizeof(T));
// void *ptr = ret_ptr;
// loop: int size = rpc_recv(ptr);
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
// if(size) { ptr = alloca(size); goto loop; }
// else *(T*)ret_ptr
// }
let llvm_i8 = ctx.ctx.i8_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
});
if ctx.unifier.unioned(ret_ty, ctx.primitives.none) {
ctx.build_call_or_invoke(rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv");
return None;
}
let prehead_bb = ctx.builder.get_insert_block().unwrap();
let current_function = prehead_bb.get_parent().unwrap();
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
// Round `val` up to its modulo `power_of_two`
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
val: IntValue<'ctx>,
power_of_two: IntValue<'ctx>| {
debug_assert_eq!(
val.get_type().get_bit_width(),
power_of_two.get_type().get_bit_width()
);
let llvm_val_t = val.get_type();
let max_rem = ctx
.builder
.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")
.unwrap();
ctx.builder
.build_and(
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
ctx.builder.build_not(max_rem, "").unwrap(),
"",
)
.unwrap()
};
// Setup types
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
// Allocate the resulting ndarray
// A condition after format_rpc_ret ensures this will not be popped this off.
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result"));
// Setup ndims
let ndims =
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
} else {
unreachable!();
};
// Set `ndarray.ndims`
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
// Allocate `ndarray.shape` [size_t; ndims]
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
/*
ndarray now:
- .ndims: initialized
- .shape: allocated but uninitialized .shape
- .data: uninitialized
*/
let llvm_usize_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
.unwrap();
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_ret_ty.element_type().size_of().unwrap(),
llvm_usize,
"",
)
.unwrap();
let llvm_elem_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
.unwrap();
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
// (4 + 4 * ndims) bytes with 8-byte alignment
let sizeof_dims =
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
let unaligned_buffer_size =
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
let stackptr = call_stacksave(ctx, None);
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
let buffer = ctx
.builder
.build_array_alloca(
llvm_i8_8,
ctx.builder
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
.unwrap(),
"rpc.buffer",
)
.unwrap();
let buffer = ctx
.builder
.build_bitcast(buffer, llvm_pi8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
//
// The returned value is the number of bytes for `ndarray.data`.
let ndarray_nbytes = ctx
.build_call_or_invoke(
rpc_recv,
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims].
"rpc.size.next",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
// debug_assert(ndarray_nbytes > 0)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::UGT,
ndarray_nbytes,
ndarray_nbytes.get_type().const_zero(),
"",
)
.unwrap(),
"0:AssertionError",
"Unexpected RPC termination for ndarray - Expected data buffer next",
[None, None, None],
ctx.current_loc,
);
}
// Copy shape from the buffer to `ndarray.shape`.
let pbuffer_dims =
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
pbuffer_dims,
sizeof_dims,
llvm_i1.const_zero(),
);
// Restore stack from before allocation of buffer
call_stackrestore(ctx, stackptr);
// Allocate `ndarray.data`.
// `ndarray.shape` must be initialized beforehand in this implementation
// (for ndarray.create_data() to know how many elements to allocate)
let num_elements =
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let sizeof_data =
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::UGE,
sizeof_data,
ndarray_nbytes,
"",
).unwrap(),
"0:AssertionError",
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
[Some(sizeof_data), Some(ndarray_nbytes), None],
ctx.current_loc,
);
}
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
let ndarray_data_i8 =
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
// NOTE: Currently on `prehead_bb`
ctx.builder.build_unconditional_branch(head_bb).unwrap();
// Inserting into `head_bb`. Do `rpc_recv` for `data` recursively.
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.map(BasicValueEnum::into_int_value)
.unwrap();
let is_done = ctx
.builder
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
// Align the allocation to sizeof(T)
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
let alloc_ptr = ctx
.builder
.build_array_alloca(
llvm_elem_ty,
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
"rpc.alloc",
)
.unwrap();
let alloc_ptr =
ctx.builder.build_pointer_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
ndarray.as_base_value().into()
}
_ => {
let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
let slotgen = ctx.builder.build_bitcast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx
.builder
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr =
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr =
ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
ctx.builder.build_load(slot, "rpc.result").unwrap()
}
};
Some(result)
}
fn rpc_codegen_callback_fn<'ctx>( fn rpc_codegen_callback_fn<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: Option<(Type, ValueEnum<'ctx>)>,
@ -663,63 +953,14 @@ fn rpc_codegen_callback_fn<'ctx>(
// reclaim stack space used by arguments // reclaim stack space used by arguments
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
// -- receive value: let result = format_rpc_ret(generator, ctx, fun.0.ret);
// T result = {
// void *ret_ptr = alloca(sizeof(T));
// void *ptr = ret_ptr;
// loop: int size = rpc_recv(ptr);
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
// if(size) { ptr = alloca(size); goto loop; }
// else *(T*)ret_ptr
// }
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
});
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv"); // An RPC returning an NDArray would not touch here.
return Ok(None);
}
let prehead_bb = ctx.builder.get_insert_block().unwrap();
let current_function = prehead_bb.get_parent().unwrap();
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret);
let need_load = !ret_ty.is_pointer_type();
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap();
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap();
phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx
.builder
.build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
if need_load {
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
} }
Ok(Some(result))
Ok(result)
} }
pub fn attributes_writeback( pub fn attributes_writeback(

View File

@ -33,7 +33,6 @@ use inkwell::{
OptimizationLevel, OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3core::codegen::irrt::setup_irrt_exceptions;
use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions}; use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap}; use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap};
@ -449,7 +448,6 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(), pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(), primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(), global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
name_to_pyid: name_to_pyid.clone(), name_to_pyid: name_to_pyid.clone(),
module: module.clone(), module: module.clone(),
id_to_pyval: RwLock::default(), id_to_pyval: RwLock::default(),
@ -541,7 +539,6 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(), pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(), primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(), global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
id_to_pyval: RwLock::default(), id_to_pyval: RwLock::default(),
id_to_primitive: RwLock::default(), id_to_primitive: RwLock::default(),
field_to_val: RwLock::default(), field_to_val: RwLock::default(),
@ -560,8 +557,7 @@ impl Nac3 {
// Process IRRT // Process IRRT
let context = inkwell::context::Context::create(); let context = inkwell::context::Context::create();
let irrt = load_irrt(&context); let irrt = load_irrt(&context, resolver.as_ref());
setup_irrt_exceptions(&context, &irrt, resolver.as_ref());
let fun_signature = let fun_signature =
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() }; FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };

View File

@ -23,7 +23,7 @@ use nac3core::{
}, },
}; };
use nac3parser::ast::{self, StrRef}; use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock}; use parking_lot::RwLock;
use pyo3::{ use pyo3::{
types::{PyDict, PyTuple}, types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python, PyAny, PyObject, PyResult, Python,
@ -79,7 +79,6 @@ pub struct InnerResolver {
pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>, pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>,
pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>, pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>,
pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>, pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>,
pub class_names: Mutex<HashMap<StrRef, Type>>,
pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>, pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>, pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
pub primitive_ids: PrimitivePythonId, pub primitive_ids: PrimitivePythonId,

View File

@ -22,6 +22,7 @@ fn main() {
"--target=wasm32", "--target=wasm32",
"-x", "-x",
"c++", "c++",
"-std=c++20",
"-fno-discard-value-names", "-fno-discard-value-names",
"-fno-exceptions", "-fno-exceptions",
"-fno-rtti", "-fno-rtti",

View File

@ -1,15 +1,15 @@
#include <irrt/exception.hpp> #include "irrt/exception.hpp"
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
#include <irrt/list.hpp> #include "irrt/list.hpp"
#include <irrt/math_util.hpp> #include "irrt/math.hpp"
#include <irrt/ndarray/array.hpp> #include "irrt/ndarray.hpp"
#include <irrt/ndarray/basic.hpp> #include "irrt/range.hpp"
#include <irrt/ndarray/broadcast.hpp> #include "irrt/slice.hpp"
#include <irrt/ndarray/def.hpp> #include "irrt/ndarray/basic.hpp"
#include <irrt/ndarray/indexing.hpp> #include "irrt/ndarray/def.hpp"
#include <irrt/ndarray/iter.hpp> #include "irrt/ndarray/iter.hpp"
#include <irrt/ndarray/reshape.hpp> #include "irrt/ndarray/indexing.hpp"
#include <irrt/ndarray/transpose.hpp> #include "irrt/ndarray/array.hpp"
#include <irrt/original.hpp> #include "irrt/ndarray/reshape.hpp"
#include <irrt/range.hpp> #include "irrt/ndarray/broadcast.hpp"
#include <irrt/slice.hpp> #include "irrt/ndarray/transpose.hpp"

View File

@ -1,9 +1,9 @@
#pragma once #pragma once
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
template <typename SizeT> struct CSlice template<typename SizeT>
{ struct CSlice {
uint8_t *base; uint8_t* base;
SizeT len; SizeT len;
}; };

View File

@ -1,20 +0,0 @@
#pragma once
#include <irrt/int_types.hpp>
namespace cstr
{
/**
* @brief Implementation of `strlen()`.
*/
uint32_t length(const char *str)
{
uint32_t length = 0;
while (*str != '\0')
{
length++;
str++;
}
return length;
}
} // namespace cstr

View File

@ -7,17 +7,19 @@
#define IRRT_DEBUG_ASSERT_BOOL false #define IRRT_DEBUG_ASSERT_BOOL false
#endif #endif
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \ #define raise_debug_assert(SizeT, msg, param1, param2, param3) \
raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3); raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
#define debug_assert_eq(SizeT, lhs, rhs) \ #define debug_assert_eq(SizeT, lhs, rhs) \
if (IRRT_DEBUG_ASSERT_BOOL && (lhs) != (rhs)) \ if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
{ \ if ((lhs) != (rhs)) { \
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \ raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
} \
} }
#define debug_assert(SizeT, expr) \ #define debug_assert(SizeT, expr) \
if (IRRT_DEBUG_ASSERT_BOOL && !(expr)) \ if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
{ \ if (!(expr)) { \
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \ raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
} \
} }

View File

@ -1,8 +1,7 @@
#pragma once #pragma once
#include <irrt/cslice.hpp> #include "irrt/cslice.hpp"
#include <irrt/cstr_util.hpp> #include "irrt/int_types.hpp"
#include <irrt/int_types.hpp>
/** /**
* @brief The int type of ARTIQ exception IDs. * @brief The int type of ARTIQ exception IDs.
@ -13,12 +12,11 @@ typedef int32_t ExceptionId;
* Set of exceptions C++ IRRT can use. * Set of exceptions C++ IRRT can use.
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`. * Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
*/ */
extern "C" extern "C" {
{ ExceptionId EXN_INDEX_ERROR;
ExceptionId EXN_INDEX_ERROR; ExceptionId EXN_VALUE_ERROR;
ExceptionId EXN_VALUE_ERROR; ExceptionId EXN_ASSERTION_ERROR;
ExceptionId EXN_ASSERTION_ERROR; ExceptionId EXN_TYPE_ERROR;
ExceptionId EXN_TYPE_ERROR;
} }
/** /**
@ -27,15 +25,14 @@ extern "C"
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller * The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime. * must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
*/ */
extern "C" void __nac3_raise(void *err); extern "C" void __nac3_raise(void* err);
namespace namespace {
{
/** /**
* @brief NAC3's Exception struct * @brief NAC3's Exception struct
*/ */
template <typename SizeT> struct Exception template<typename SizeT>
{ struct Exception {
ExceptionId id; ExceptionId id;
CSlice<SizeT> filename; CSlice<SizeT> filename;
int32_t line; int32_t line;
@ -45,24 +42,32 @@ template <typename SizeT> struct Exception
int64_t params[3]; int64_t params[3];
}; };
const int64_t NO_PARAM = 0; constexpr int64_t NO_PARAM = 0;
template <typename SizeT> template<typename SizeT>
void _raise_exception_helper(ExceptionId id, const char *filename, int32_t line, const char *function, const char *msg, void _raise_exception_helper(ExceptionId id,
int64_t param0, int64_t param1, int64_t param2) const char* filename,
{ int32_t line,
const char* function,
const char* msg,
int64_t param0,
int64_t param1,
int64_t param2) {
Exception<SizeT> e = { Exception<SizeT> e = {
.id = id, .id = id,
.filename = {.base = (uint8_t *)filename, .len = (int32_t)cstr::length(filename)}, .filename = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(filename)),
.len = static_cast<int32_t>(__builtin_strlen(filename))},
.line = line, .line = line,
.column = 0, .column = 0,
.function = {.base = (uint8_t *)function, .len = (int32_t)cstr::length(function)}, .function = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(function)),
.msg = {.base = (uint8_t *)msg, .len = (int32_t)cstr::length(msg)}, .len = static_cast<int32_t>(__builtin_strlen(function))},
.msg = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(msg)),
.len = static_cast<int32_t>(__builtin_strlen(msg))},
}; };
e.params[0] = param0; e.params[0] = param0;
e.params[1] = param1; e.params[1] = param1;
e.params[2] = param2; e.params[2] = param2;
__nac3_raise((void *)&e); __nac3_raise(reinterpret_cast<void*>(&e));
__builtin_unreachable(); __builtin_unreachable();
} }
@ -75,6 +80,6 @@ void _raise_exception_helper(ExceptionId id, const char *filename, int32_t line,
* `param0` to `param2` are optional format arguments of `msg`. They should be set to * `param0` to `param2` are optional format arguments of `msg`. They should be set to
* `NO_PARAM` to indicate they are unused. * `NO_PARAM` to indicate they are unused.
*/ */
#define raise_exception(SizeT, id, msg, param0, param1, param2) \ #define raise_exception(SizeT, id, msg, param0, param1, param2) \
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2) _raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
} // namespace } // namespace

View File

@ -6,3 +6,8 @@ using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32); using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64); using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64); using uint64_t = unsigned _BitInt(64);
// NDArray indices are always `uint32_t`.
using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;

View File

@ -1,19 +1,90 @@
#pragma once #pragma once
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
#include <irrt/slice.hpp> #include "irrt/math_util.hpp"
#include "irrt/slice.hpp"
namespace namespace {
{
/** /**
* @brief A list in NAC3. * @brief A list in NAC3.
* *
* The `items` field is opaque. You must rely on external contexts to * The `items` field is opaque. You must rely on external contexts to
* know how to interpret it. * know how to interpret it.
*/ */
template <typename SizeT> struct List template<typename SizeT>
{ struct List {
uint8_t *items; uint8_t* items;
SizeT len; SizeT len;
}; };
} // namespace } // namespace
extern "C" {
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
uint8_t* dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
uint8_t* src_arr,
SliceIndex src_arr_len,
const SliceIndex size) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0)
return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca = (dest_arr == src_arr)
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca) {
uint8_t* tmp = reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
} // extern "C"

View File

@ -0,0 +1,93 @@
#pragma once
namespace {
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template<typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
} // namespace
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp); \
}
extern "C" {
// Putting semicolons here to make clang-format not reformat this into
// a stair shape.
DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t);
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
}

View File

@ -1,14 +1,13 @@
#pragma once #pragma once
namespace namespace {
{ template<typename T>
template <typename T> const T &max(const T &a, const T &b) const T& max(const T& a, const T& b) {
{
return a > b ? a : b; return a > b ? a : b;
} }
template <typename T> const T &min(const T &a, const T &b) template<typename T>
{ const T& min(const T& a, const T& b) {
return a > b ? b : a; return a > b ? b : a;
} }
} // namespace } // namespace

View File

@ -0,0 +1,151 @@
#pragma once
#include "irrt/int_types.hpp"
// TODO: To be deleted since NDArray with strides is done.
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims,
SizeT num_dims,
const NDIndexInt* indices,
SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims,
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C" {
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims,
uint64_t num_dims,
const NDIndexInt* indices,
uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims,
const uint32_t* rhs_dims,
uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims,
uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
}

View File

@ -1,38 +1,32 @@
#pragma once #pragma once
#include <irrt/debug.hpp> #include "irrt/debug.hpp"
#include <irrt/exception.hpp> #include "irrt/exception.hpp"
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
#include <irrt/list.hpp> #include "irrt/list.hpp"
#include <irrt/ndarray/basic.hpp> #include "irrt/ndarray/basic.hpp"
#include <irrt/ndarray/def.hpp> #include "irrt/ndarray/def.hpp"
namespace namespace {
{ namespace ndarray {
namespace ndarray namespace array {
{
namespace array
{
/** /**
* @brief In the context of `np.array(<list>)`, deduce the ndarray's shape produced by `<list>` and raise * @brief In the context of `np.array(<list>)`, deduce the ndarray's shape produced by `<list>` and raise
* an exception if there is anything wrong with `<shape>` (e.g., inconsistent dimensions `np.array([[1.0, 2.0], [3.0]])`) * an exception if there is anything wrong with `<shape>` (e.g., inconsistent dimensions `np.array([[1.0, 2.0],
* [3.0]])`)
* *
* If this function finds no issues with `<list>`, the deduced shape is written to `shape`. The caller has the responsibility to * If this function finds no issues with `<list>`, the deduced shape is written to `shape`. The caller has the
* allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because of implementation details. * responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because
* of implementation details.
*/ */
template <typename SizeT> template<typename SizeT>
void set_and_validate_list_shape_helper(SizeT axis, List<SizeT> *list, SizeT ndims, SizeT *shape) void set_and_validate_list_shape_helper(SizeT axis, List<SizeT>* list, SizeT ndims, SizeT* shape) {
{ if (shape[axis] == -1) {
if (shape[axis] == -1)
{
// Dimension is unspecified. Set it. // Dimension is unspecified. Set it.
shape[axis] = list->len; shape[axis] = list->len;
} } else {
else
{
// Dimension is specified. Check. // Dimension is specified. Check.
if (shape[axis] != list->len) if (shape[axis] != list->len) {
{
// Mismatch, throw an error. // Mismatch, throw an error.
// NOTE: NumPy's error message is more complex and needs more PARAMS to display. // NOTE: NumPy's error message is more complex and needs more PARAMS to display.
raise_exception(SizeT, EXN_VALUE_ERROR, raise_exception(SizeT, EXN_VALUE_ERROR,
@ -42,17 +36,13 @@ void set_and_validate_list_shape_helper(SizeT axis, List<SizeT> *list, SizeT ndi
} }
} }
if (axis + 1 == ndims) if (axis + 1 == ndims) {
{
// `list` has type `list[ItemType]` // `list` has type `list[ItemType]`
// Do nothing // Do nothing
} } else {
else
{
// `list` has type `list[list[...]]` // `list` has type `list[list[...]]`
List<SizeT> **lists = (List<SizeT> **)(list->items); List<SizeT>** lists = (List<SizeT>**)(list->items);
for (SizeT i = 0; i < list->len; i++) for (SizeT i = 0; i < list->len; i++) {
{
set_and_validate_list_shape_helper<SizeT>(axis + 1, lists[i], ndims, shape); set_and_validate_list_shape_helper<SizeT>(axis + 1, lists[i], ndims, shape);
} }
} }
@ -61,11 +51,10 @@ void set_and_validate_list_shape_helper(SizeT axis, List<SizeT> *list, SizeT ndi
/** /**
* @brief See `set_and_validate_list_shape_helper`. * @brief See `set_and_validate_list_shape_helper`.
*/ */
template <typename SizeT> void set_and_validate_list_shape(List<SizeT> *list, SizeT ndims, SizeT *shape) template<typename SizeT>
{ void set_and_validate_list_shape(List<SizeT>* list, SizeT ndims, SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) for (SizeT axis = 0; axis < ndims; axis++) {
{ shape[axis] = -1; // Sentinel to say this dimension is unspecified.
shape[axis] = -1; // Sentinel to say this dimension is unspecified.
} }
set_and_validate_list_shape_helper<SizeT>(0, list, ndims, shape); set_and_validate_list_shape_helper<SizeT>(0, list, ndims, shape);
} }
@ -86,34 +75,27 @@ template <typename SizeT> void set_and_validate_list_shape(List<SizeT> *list, Si
* When this function call ends: * When this function call ends:
* - `ndarray->data` is written with contents from `<list>`. * - `ndarray->data` is written with contents from `<list>`.
*/ */
template <typename SizeT> template<typename SizeT>
void write_list_to_array_helper(SizeT axis, SizeT *index, List<SizeT> *list, NDArray<SizeT> *ndarray) void write_list_to_array_helper(SizeT axis, SizeT* index, List<SizeT>* list, NDArray<SizeT>* ndarray) {
{
debug_assert_eq(SizeT, list->len, ndarray->shape[axis]); debug_assert_eq(SizeT, list->len, ndarray->shape[axis]);
if (IRRT_DEBUG_ASSERT_BOOL) if (IRRT_DEBUG_ASSERT_BOOL) {
{ if (!ndarray::basic::is_c_contiguous(ndarray)) {
if (!ndarray::basic::is_c_contiguous(ndarray))
{
raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1], raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1],
NO_PARAM); NO_PARAM);
} }
} }
if (axis + 1 == ndarray->ndims) if (axis + 1 == ndarray->ndims) {
{
// `list` has type `list[scalar]` // `list` has type `list[scalar]`
// `ndarray` is contiguous, so we can do this, and this is fast. // `ndarray` is contiguous, so we can do this, and this is fast.
uint8_t *dst = ndarray->data + (ndarray->itemsize * (*index)); uint8_t* dst = ndarray->data + (ndarray->itemsize * (*index));
__builtin_memcpy(dst, list->items, ndarray->itemsize * list->len); __builtin_memcpy(dst, list->items, ndarray->itemsize * list->len);
*index += list->len; *index += list->len;
} } else {
else
{
// `list` has type `list[list[...]]` // `list` has type `list[list[...]]`
List<SizeT> **lists = (List<SizeT> **)(list->items); List<SizeT>** lists = (List<SizeT>**)(list->items);
for (SizeT i = 0; i < list->len; i++) for (SizeT i = 0; i < list->len; i++) {
{
write_list_to_array_helper<SizeT>(axis + 1, index, lists[i], ndarray); write_list_to_array_helper<SizeT>(axis + 1, index, lists[i], ndarray);
} }
} }
@ -122,36 +104,31 @@ void write_list_to_array_helper(SizeT axis, SizeT *index, List<SizeT> *list, NDA
/** /**
* @brief See `write_list_to_array_helper`. * @brief See `write_list_to_array_helper`.
*/ */
template <typename SizeT> void write_list_to_array(List<SizeT> *list, NDArray<SizeT> *ndarray) template<typename SizeT>
{ void write_list_to_array(List<SizeT>* list, NDArray<SizeT>* ndarray) {
SizeT index = 0; SizeT index = 0;
write_list_to_array_helper<SizeT>((SizeT)0, &index, list, ndarray); write_list_to_array_helper<SizeT>((SizeT)0, &index, list, ndarray);
} }
} // namespace array } // namespace array
} // namespace ndarray } // namespace ndarray
} // namespace } // namespace
extern "C" extern "C" {
{ using namespace ndarray::array;
using namespace ndarray::array;
void __nac3_ndarray_array_set_and_validate_list_shape(List<int32_t> *list, int32_t ndims, int32_t *shape) void __nac3_ndarray_array_set_and_validate_list_shape(List<int32_t>* list, int32_t ndims, int32_t* shape) {
{ set_and_validate_list_shape(list, ndims, shape);
set_and_validate_list_shape(list, ndims, shape); }
}
void __nac3_ndarray_array_set_and_validate_list_shape64(List<int64_t> *list, int64_t ndims, int64_t *shape) void __nac3_ndarray_array_set_and_validate_list_shape64(List<int64_t>* list, int64_t ndims, int64_t* shape) {
{ set_and_validate_list_shape(list, ndims, shape);
set_and_validate_list_shape(list, ndims, shape); }
}
void __nac3_ndarray_array_write_list_to_array(List<int32_t> *list, NDArray<int32_t> *ndarray) void __nac3_ndarray_array_write_list_to_array(List<int32_t>* list, NDArray<int32_t>* ndarray) {
{ write_list_to_array(list, ndarray);
write_list_to_array(list, ndarray); }
}
void __nac3_ndarray_array_write_list_to_array64(List<int64_t> *list, NDArray<int64_t> *ndarray) void __nac3_ndarray_array_write_list_to_array64(List<int64_t>* list, NDArray<int64_t>* ndarray) {
{ write_list_to_array(list, ndarray);
write_list_to_array(list, ndarray); }
}
} }

View File

@ -1,28 +1,23 @@
#pragma once #pragma once
#include <irrt/debug.hpp> #include "irrt/debug.hpp"
#include <irrt/exception.hpp> #include "irrt/exception.hpp"
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
#include <irrt/ndarray/def.hpp> #include "irrt/ndarray/def.hpp"
namespace namespace {
{ namespace ndarray {
namespace ndarray namespace basic {
{
namespace basic
{
/** /**
* @brief Assert that `shape` does not contain negative dimensions. * @brief Assert that `shape` does not contain negative dimensions.
* *
* @param ndims Number of dimensions in `shape` * @param ndims Number of dimensions in `shape`
* @param shape The shape to check on * @param shape The shape to check on
*/ */
template <typename SizeT> void assert_shape_no_negative(SizeT ndims, const SizeT *shape) template<typename SizeT>
{ void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) for (SizeT axis = 0; axis < ndims; axis++) {
{ if (shape[axis] < 0) {
if (shape[axis] < 0)
{
raise_exception(SizeT, EXN_VALUE_ERROR, raise_exception(SizeT, EXN_VALUE_ERROR,
"negative dimensions are not allowed; axis {0} " "negative dimensions are not allowed; axis {0} "
"has dimension {1}", "has dimension {1}",
@ -34,21 +29,19 @@ template <typename SizeT> void assert_shape_no_negative(SizeT ndims, const SizeT
/** /**
* @brief Assert that two shapes are the same in the context of writing output to an ndarray. * @brief Assert that two shapes are the same in the context of writing output to an ndarray.
*/ */
template <typename SizeT> template<typename SizeT>
void assert_output_shape_same(SizeT ndarray_ndims, const SizeT *ndarray_shape, SizeT output_ndims, void assert_output_shape_same(SizeT ndarray_ndims,
const SizeT *output_shape) const SizeT* ndarray_shape,
{ SizeT output_ndims,
if (ndarray_ndims != output_ndims) const SizeT* output_shape) {
{ if (ndarray_ndims != output_ndims) {
// There is no corresponding NumPy error message like this. // There is no corresponding NumPy error message like this.
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}", raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
output_ndims, ndarray_ndims, NO_PARAM); output_ndims, ndarray_ndims, NO_PARAM);
} }
for (SizeT axis = 0; axis < ndarray_ndims; axis++) for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
{ if (ndarray_shape[axis] != output_shape[axis]) {
if (ndarray_shape[axis] != output_shape[axis])
{
// There is no corresponding NumPy error message like this. // There is no corresponding NumPy error message like this.
raise_exception(SizeT, EXN_VALUE_ERROR, raise_exception(SizeT, EXN_VALUE_ERROR,
"Mismatched dimensions on axis {0}, output has " "Mismatched dimensions on axis {0}, output has "
@ -64,8 +57,8 @@ void assert_output_shape_same(SizeT ndarray_ndims, const SizeT *ndarray_shape, S
* @param ndims Number of dimensions in `shape` * @param ndims Number of dimensions in `shape`
* @param shape The shape of the ndarray * @param shape The shape of the ndarray
*/ */
template <typename SizeT> SizeT calc_size_from_shape(SizeT ndims, const SizeT *shape) template<typename SizeT>
{ SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
SizeT size = 1; SizeT size = 1;
for (SizeT axis = 0; axis < ndims; axis++) for (SizeT axis = 0; axis < ndims; axis++)
size *= shape[axis]; size *= shape[axis];
@ -80,10 +73,9 @@ template <typename SizeT> SizeT calc_size_from_shape(SizeT ndims, const SizeT *s
* @param indices The returned indices indexing the ndarray with shape `shape`. * @param indices The returned indices indexing the ndarray with shape `shape`.
* @param nth The index of the element of interest. * @param nth The index of the element of interest.
*/ */
template <typename SizeT> void set_indices_by_nth(SizeT ndims, const SizeT *shape, SizeT *indices, SizeT nth) template<typename SizeT>
{ void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
for (SizeT i = 0; i < ndims; i++) for (SizeT i = 0; i < ndims; i++) {
{
SizeT axis = ndims - i - 1; SizeT axis = ndims - i - 1;
SizeT dim = shape[axis]; SizeT dim = shape[axis];
@ -97,8 +89,8 @@ template <typename SizeT> void set_indices_by_nth(SizeT ndims, const SizeT *shap
* *
* This function corresponds to `<an_ndarray>.size` * This function corresponds to `<an_ndarray>.size`
*/ */
template <typename SizeT> SizeT size(const NDArray<SizeT> *ndarray) template<typename SizeT>
{ SizeT size(const NDArray<SizeT>* ndarray) {
return calc_size_from_shape(ndarray->ndims, ndarray->shape); return calc_size_from_shape(ndarray->ndims, ndarray->shape);
} }
@ -107,8 +99,8 @@ template <typename SizeT> SizeT size(const NDArray<SizeT> *ndarray)
* *
* This function corresponds to `<an_ndarray>.nbytes`. * This function corresponds to `<an_ndarray>.nbytes`.
*/ */
template <typename SizeT> SizeT nbytes(const NDArray<SizeT> *ndarray) template<typename SizeT>
{ SizeT nbytes(const NDArray<SizeT>* ndarray) {
return size(ndarray) * ndarray->itemsize; return size(ndarray) * ndarray->itemsize;
} }
@ -119,15 +111,12 @@ template <typename SizeT> SizeT nbytes(const NDArray<SizeT> *ndarray)
* *
* @param dst_length The length. * @param dst_length The length.
*/ */
template <typename SizeT> SizeT len(const NDArray<SizeT> *ndarray) template<typename SizeT>
{ SizeT len(const NDArray<SizeT>* ndarray) {
// numpy prohibits `__len__` on unsized objects // numpy prohibits `__len__` on unsized objects
if (ndarray->ndims == 0) if (ndarray->ndims == 0) {
{
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM); raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
} } else {
else
{
return ndarray->shape[0]; return ndarray->shape[0];
} }
} }
@ -135,16 +124,21 @@ template <typename SizeT> SizeT len(const NDArray<SizeT> *ndarray)
/** /**
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous. * @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
* *
* You may want to see ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 * You may want to see ndarray's rules for C-contiguity:
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
*/ */
template <typename SizeT> bool is_c_contiguous(const NDArray<SizeT> *ndarray) template<typename SizeT>
{ bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
// References: // References:
// - tinynumpy's implementation: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102 // - tinynumpy's implementation:
// - ndarray's flags["C_CONTIGUOUS"]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
// - ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 // - ndarray's flags["C_CONTIGUOUS"]:
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
// - ndarray's rules for C-contiguity:
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
// From https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45: // From
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
// //
// The traditional rule is that for an array to be flagged as C contiguous, // The traditional rule is that for an array to be flagged as C contiguous,
// the following must hold: // the following must hold:
@ -160,21 +154,17 @@ template <typename SizeT> bool is_c_contiguous(const NDArray<SizeT> *ndarray)
// with shape[i] == 0. In the second case `strides == itemsize` will // with shape[i] == 0. In the second case `strides == itemsize` will
// can be true for all dimensions and both flags are set. // can be true for all dimensions and both flags are set.
if (ndarray->ndims == 0) if (ndarray->ndims == 0) {
{
return true; return true;
} }
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
{
return false; return false;
} }
for (SizeT i = 1; i < ndarray->ndims; i++) for (SizeT i = 1; i < ndarray->ndims; i++) {
{
SizeT axis_i = ndarray->ndims - i - 1; SizeT axis_i = ndarray->ndims - i - 1;
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
{
return false; return false;
} }
} }
@ -187,9 +177,9 @@ template <typename SizeT> bool is_c_contiguous(const NDArray<SizeT> *ndarray)
* *
* This function does no bound check. * This function does no bound check.
*/ */
template <typename SizeT> uint8_t *get_pelement_by_indices(const NDArray<SizeT> *ndarray, const SizeT *indices) template<typename SizeT>
{ uint8_t* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
uint8_t *element = ndarray->data; uint8_t* element = ndarray->data;
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++) for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
element += indices[dim_i] * ndarray->strides[dim_i]; element += indices[dim_i] * ndarray->strides[dim_i];
return element; return element;
@ -200,11 +190,10 @@ template <typename SizeT> uint8_t *get_pelement_by_indices(const NDArray<SizeT>
* *
* This function does no bound check. * This function does no bound check.
*/ */
template <typename SizeT> uint8_t *get_nth_pelement(const NDArray<SizeT> *ndarray, SizeT nth) template<typename SizeT>
{ uint8_t* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
uint8_t *element = ndarray->data; uint8_t* element = ndarray->data;
for (SizeT i = 0; i < ndarray->ndims; i++) for (SizeT i = 0; i < ndarray->ndims; i++) {
{
SizeT axis = ndarray->ndims - i - 1; SizeT axis = ndarray->ndims - i - 1;
SizeT dim = ndarray->shape[axis]; SizeT dim = ndarray->shape[axis];
element += ndarray->strides[axis] * (nth % dim); element += ndarray->strides[axis] * (nth % dim);
@ -218,11 +207,10 @@ template <typename SizeT> uint8_t *get_nth_pelement(const NDArray<SizeT> *ndarra
* *
* You might want to read https://ajcr.net/stride-guide-part-1/. * You might want to read https://ajcr.net/stride-guide-part-1/.
*/ */
template <typename SizeT> void set_strides_by_shape(NDArray<SizeT> *ndarray) template<typename SizeT>
{ void set_strides_by_shape(NDArray<SizeT>* ndarray) {
SizeT stride_product = 1; SizeT stride_product = 1;
for (SizeT i = 0; i < ndarray->ndims; i++) for (SizeT i = 0; i < ndarray->ndims; i++) {
{
SizeT axis = ndarray->ndims - i - 1; SizeT axis = ndarray->ndims - i - 1;
ndarray->strides[axis] = stride_product * ndarray->itemsize; ndarray->strides[axis] = stride_product * ndarray->itemsize;
stride_product *= ndarray->shape[axis]; stride_product *= ndarray->shape[axis];
@ -235,8 +223,8 @@ template <typename SizeT> void set_strides_by_shape(NDArray<SizeT> *ndarray)
* @param pelement Pointer to the element in `ndarray` to be set. * @param pelement Pointer to the element in `ndarray` to be set.
* @param pvalue Pointer to the value `pelement` will be set to. * @param pvalue Pointer to the value `pelement` will be set to.
*/ */
template <typename SizeT> void set_pelement_value(NDArray<SizeT> *ndarray, uint8_t *pelement, const uint8_t *pvalue) template<typename SizeT>
{ void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement, const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, ndarray->itemsize); __builtin_memcpy(pelement, pvalue, ndarray->itemsize);
} }
@ -245,127 +233,109 @@ template <typename SizeT> void set_pelement_value(NDArray<SizeT> *ndarray, uint8
* *
* Both ndarrays will be viewed in their flatten views when copying the elements. * Both ndarrays will be viewed in their flatten views when copying the elements.
*/ */
template <typename SizeT> void copy_data(const NDArray<SizeT> *src_ndarray, NDArray<SizeT> *dst_ndarray) template<typename SizeT>
{ void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// TODO: Make this faster with memcpy when we see a contiguous segment. // TODO: Make this faster with memcpy when we see a contiguous segment.
// TODO: Handle overlapping. // TODO: Handle overlapping.
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize); debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
for (SizeT i = 0; i < size(src_ndarray); i++) for (SizeT i = 0; i < size(src_ndarray); i++) {
{
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i); auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i); auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element); ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
} }
} }
} // namespace basic } // namespace basic
} // namespace ndarray } // namespace ndarray
} // namespace } // namespace
extern "C" extern "C" {
{ using namespace ndarray::basic;
using namespace ndarray::basic;
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t *shape) void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
{ assert_shape_no_negative(ndims, shape);
assert_shape_no_negative(ndims, shape); }
}
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t *shape) void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
{ assert_shape_no_negative(ndims, shape);
assert_shape_no_negative(ndims, shape); }
}
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims, const int32_t *ndarray_shape, void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
int32_t output_ndims, const int32_t *output_shape) const int32_t* ndarray_shape,
{ int32_t output_ndims,
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape); const int32_t* output_shape) {
} assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
}
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims, const int64_t *ndarray_shape, void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
int64_t output_ndims, const int64_t *output_shape) const int64_t* ndarray_shape,
{ int64_t output_ndims,
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape); const int64_t* output_shape) {
} assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
}
uint32_t __nac3_ndarray_size(NDArray<int32_t> *ndarray) uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
{ return size(ndarray);
return size(ndarray); }
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t> *ndarray) uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
{ return size(ndarray);
return size(ndarray); }
}
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t> *ndarray) uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
{ return nbytes(ndarray);
return nbytes(ndarray); }
}
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t> *ndarray) uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
{ return nbytes(ndarray);
return nbytes(ndarray); }
}
int32_t __nac3_ndarray_len(NDArray<int32_t> *ndarray) int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
{ return len(ndarray);
return len(ndarray); }
}
int64_t __nac3_ndarray_len64(NDArray<int64_t> *ndarray) int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
{ return len(ndarray);
return len(ndarray); }
}
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t> *ndarray) bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
{ return is_c_contiguous(ndarray);
return is_c_contiguous(ndarray); }
}
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t> *ndarray) bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
{ return is_c_contiguous(ndarray);
return is_c_contiguous(ndarray); }
}
uint8_t *__nac3_ndarray_get_nth_pelement(const NDArray<int32_t> *ndarray, int32_t nth) uint8_t* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
{ return get_nth_pelement(ndarray, nth);
return get_nth_pelement(ndarray, nth); }
}
uint8_t *__nac3_ndarray_get_nth_pelement64(const NDArray<int64_t> *ndarray, int64_t nth) uint8_t* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
{ return get_nth_pelement(ndarray, nth);
return get_nth_pelement(ndarray, nth); }
}
uint8_t *__nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t> *ndarray, int32_t *indices) uint8_t* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
{ return get_pelement_by_indices(ndarray, indices);
return get_pelement_by_indices(ndarray, indices); }
}
uint8_t *__nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t> *ndarray, int64_t *indices) uint8_t* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
{ return get_pelement_by_indices(ndarray, indices);
return get_pelement_by_indices(ndarray, indices); }
}
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t> *ndarray) void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
{ set_strides_by_shape(ndarray);
set_strides_by_shape(ndarray); }
}
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t> *ndarray) void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
{ set_strides_by_shape(ndarray);
set_strides_by_shape(ndarray); }
}
void __nac3_ndarray_copy_data(NDArray<int32_t> *src_ndarray, NDArray<int32_t> *dst_ndarray) void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
{ copy_data(src_ndarray, dst_ndarray);
copy_data(src_ndarray, dst_ndarray); }
}
void __nac3_ndarray_copy_data64(NDArray<int64_t> *src_ndarray, NDArray<int64_t> *dst_ndarray) void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
{ copy_data(src_ndarray, dst_ndarray);
copy_data(src_ndarray, dst_ndarray); }
}
} }

View File

@ -1,43 +1,35 @@
#pragma once #pragma once
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
#include <irrt/ndarray/def.hpp> #include "irrt/ndarray/def.hpp"
#include <irrt/slice.hpp> #include "irrt/slice.hpp"
namespace namespace {
{ template<typename SizeT>
template <typename SizeT> struct ShapeEntry struct ShapeEntry {
{
SizeT ndims; SizeT ndims;
SizeT *shape; SizeT* shape;
}; };
} // namespace } // namespace
namespace namespace {
{ namespace ndarray {
namespace ndarray namespace broadcast {
{
namespace broadcast
{
/** /**
* @brief Return true if `src_shape` can broadcast to `dst_shape`. * @brief Return true if `src_shape` can broadcast to `dst_shape`.
* *
* See https://numpy.org/doc/stable/user/basics.broadcasting.html * See https://numpy.org/doc/stable/user/basics.broadcasting.html
*/ */
template <typename SizeT> template<typename SizeT>
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT *target_shape, SizeT src_ndims, const SizeT *src_shape) bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) {
{ if (src_ndims > target_ndims) {
if (src_ndims > target_ndims)
{
return false; return false;
} }
for (SizeT i = 0; i < src_ndims; i++) for (SizeT i = 0; i < src_ndims; i++) {
{
SizeT target_dim = target_shape[target_ndims - i - 1]; SizeT target_dim = target_shape[target_ndims - i - 1];
SizeT src_dim = src_shape[src_ndims - i - 1]; SizeT src_dim = src_shape[src_ndims - i - 1];
if (!(src_dim == 1 || target_dim == src_dim)) if (!(src_dim == 1 || target_dim == src_dim)) {
{
return false; return false;
} }
} }
@ -55,11 +47,9 @@ bool can_broadcast_shape_to(SizeT target_ndims, const SizeT *target_shape, SizeT
* @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result * @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result
* of `np.broadcast_shapes` and write it here. * of `np.broadcast_shapes` and write it here.
*/ */
template <typename SizeT> template<typename SizeT>
void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT> *shapes, SizeT dst_ndims, SizeT *dst_shape) void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT>* shapes, SizeT dst_ndims, SizeT* dst_shape) {
{ for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++)
{
dst_shape[dst_axis] = 1; dst_shape[dst_axis] = 1;
} }
@ -67,8 +57,7 @@ void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT> *shapes, SizeT d
SizeT max_ndims_found = 0; SizeT max_ndims_found = 0;
#endif #endif
for (SizeT i = 0; i < num_shapes; i++) for (SizeT i = 0; i < num_shapes; i++) {
{
ShapeEntry<SizeT> entry = shapes[i]; ShapeEntry<SizeT> entry = shapes[i];
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
@ -78,24 +67,18 @@ void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT> *shapes, SizeT d
max_ndims_found = max(max_ndims_found, entry.ndims); max_ndims_found = max(max_ndims_found, entry.ndims);
#endif #endif
for (SizeT j = 0; j < entry.ndims; j++) for (SizeT j = 0; j < entry.ndims; j++) {
{
SizeT entry_axis = entry.ndims - j - 1; SizeT entry_axis = entry.ndims - j - 1;
SizeT dst_axis = dst_ndims - j - 1; SizeT dst_axis = dst_ndims - j - 1;
SizeT entry_dim = entry.shape[entry_axis]; SizeT entry_dim = entry.shape[entry_axis];
SizeT dst_dim = dst_shape[dst_axis]; SizeT dst_dim = dst_shape[dst_axis];
if (dst_dim == 1) if (dst_dim == 1) {
{
dst_shape[dst_axis] = entry_dim; dst_shape[dst_axis] = entry_dim;
} } else if (entry_dim == 1 || entry_dim == dst_dim) {
else if (entry_dim == 1 || entry_dim == dst_dim)
{
// Do nothing // Do nothing
} } else {
else
{
raise_exception(SizeT, EXN_VALUE_ERROR, raise_exception(SizeT, EXN_VALUE_ERROR,
"shape mismatch: objects cannot be broadcast " "shape mismatch: objects cannot be broadcast "
"to a single shape.", "to a single shape.",
@ -129,11 +112,10 @@ void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT> *shapes, SizeT d
* - `dst_ndarray->shape` is unchanged. * - `dst_ndarray->shape` is unchanged.
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works. * - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
*/ */
template <typename SizeT> void broadcast_to(const NDArray<SizeT> *src_ndarray, NDArray<SizeT> *dst_ndarray) template<typename SizeT>
{ void broadcast_to(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims, if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
src_ndarray->shape)) src_ndarray->shape)) {
{
raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM, raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM,
NO_PARAM); NO_PARAM);
} }
@ -141,48 +123,43 @@ template <typename SizeT> void broadcast_to(const NDArray<SizeT> *src_ndarray, N
dst_ndarray->data = src_ndarray->data; dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize; dst_ndarray->itemsize = src_ndarray->itemsize;
for (SizeT i = 0; i < dst_ndarray->ndims; i++) for (SizeT i = 0; i < dst_ndarray->ndims; i++) {
{
SizeT src_axis = src_ndarray->ndims - i - 1; SizeT src_axis = src_ndarray->ndims - i - 1;
SizeT dst_axis = dst_ndarray->ndims - i - 1; SizeT dst_axis = dst_ndarray->ndims - i - 1;
if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) {
{
// Freeze the steps in-place // Freeze the steps in-place
dst_ndarray->strides[dst_axis] = 0; dst_ndarray->strides[dst_axis] = 0;
} } else {
else
{
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
} }
} }
} }
} // namespace broadcast } // namespace broadcast
} // namespace ndarray } // namespace ndarray
} // namespace } // namespace
extern "C" extern "C" {
{ using namespace ndarray::broadcast;
using namespace ndarray::broadcast;
void __nac3_ndarray_broadcast_to(NDArray<int32_t> *src_ndarray, NDArray<int32_t> *dst_ndarray) void __nac3_ndarray_broadcast_to(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
{ broadcast_to(src_ndarray, dst_ndarray);
broadcast_to(src_ndarray, dst_ndarray); }
}
void __nac3_ndarray_broadcast_to64(NDArray<int64_t> *src_ndarray, NDArray<int64_t> *dst_ndarray) void __nac3_ndarray_broadcast_to64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
{ broadcast_to(src_ndarray, dst_ndarray);
broadcast_to(src_ndarray, dst_ndarray); }
}
void __nac3_ndarray_broadcast_shapes(int32_t num_shapes, const ShapeEntry<int32_t> *shapes, int32_t dst_ndims, void __nac3_ndarray_broadcast_shapes(int32_t num_shapes,
int32_t *dst_shape) const ShapeEntry<int32_t>* shapes,
{ int32_t dst_ndims,
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); int32_t* dst_shape) {
} broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
}
void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes, const ShapeEntry<int64_t> *shapes, int64_t dst_ndims, void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes,
int64_t *dst_shape) const ShapeEntry<int64_t>* shapes,
{ int64_t dst_ndims,
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); int64_t* dst_shape) {
} broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
}
} }

View File

@ -1,20 +1,20 @@
#pragma once #pragma once
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
namespace namespace {
{
/** /**
* @brief The NDArray object * @brief The NDArray object
* *
* Official numpy implementation: https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst * Official numpy implementation:
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
*/ */
template <typename SizeT> struct NDArray template<typename SizeT>
{ struct NDArray {
/** /**
* @brief The underlying data this `ndarray` is pointing to. * @brief The underlying data this `ndarray` is pointing to.
*/ */
uint8_t *data; uint8_t* data;
/** /**
* @brief The number of bytes of a single element in `data`. * @brief The number of bytes of a single element in `data`.
@ -31,7 +31,7 @@ template <typename SizeT> struct NDArray
* *
* Note that it may contain 0. * Note that it may contain 0.
*/ */
SizeT *shape; SizeT* shape;
/** /**
* @brief Array strides, with length equal to `ndims` * @brief Array strides, with length equal to `ndims`
@ -40,6 +40,6 @@ template <typename SizeT> struct NDArray
* *
* Note that `strides` can have negative values or contain 0. * Note that `strides` can have negative values or contain 0.
*/ */
SizeT *strides; SizeT* strides;
}; };
} // namespace } // namespace

View File

@ -1,14 +1,13 @@
#pragma once #pragma once
#include <irrt/exception.hpp> #include "irrt/exception.hpp"
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
#include <irrt/ndarray/basic.hpp> #include "irrt/ndarray/basic.hpp"
#include <irrt/ndarray/def.hpp> #include "irrt/ndarray/def.hpp"
#include <irrt/range.hpp> #include "irrt/range.hpp"
#include <irrt/slice.hpp> #include "irrt/slice.hpp"
namespace namespace {
{
typedef uint8_t NDIndexType; typedef uint8_t NDIndexType;
/** /**
@ -48,8 +47,7 @@ const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3;
* ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex. * ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex.
* ``` * ```
*/ */
struct NDIndex struct NDIndex {
{
/** /**
* @brief Enum tag to specify the type of index. * @brief Enum tag to specify the type of index.
* *
@ -62,16 +60,13 @@ struct NDIndex
* *
* Please see the comment of each enum constant. * Please see the comment of each enum constant.
*/ */
uint8_t *data; uint8_t* data;
}; };
} // namespace } // namespace
namespace namespace {
{ namespace ndarray {
namespace ndarray namespace indexing {
{
namespace indexing
{
/** /**
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) * @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
* *
@ -99,9 +94,8 @@ namespace indexing
* @param src_ndarray The NDArray to be indexed. * @param src_ndarray The NDArray to be indexed.
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above, * @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
*/ */
template <typename SizeT> template<typename SizeT>
void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_ndarray, NDArray<SizeT> *dst_ndarray) void index(SizeT num_indices, const NDIndex* indices, const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
{
// Validate `indices`. // Validate `indices`.
// Expected value of `dst_ndarray->ndims`. // Expected value of `dst_ndarray->ndims`.
@ -111,40 +105,28 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
// There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis. // There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis.
SizeT num_ellipsis = 0; SizeT num_ellipsis = 0;
for (SizeT i = 0; i < num_indices; i++) for (SizeT i = 0; i < num_indices; i++) {
{ if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT)
{
expected_dst_ndims--; expected_dst_ndims--;
num_indexed++; num_indexed++;
} } else if (indices[i].type == ND_INDEX_TYPE_SLICE) {
else if (indices[i].type == ND_INDEX_TYPE_SLICE)
{
num_indexed++; num_indexed++;
} } else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS) {
else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS)
{
expected_dst_ndims++; expected_dst_ndims++;
} } else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS) {
else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS)
{
num_ellipsis++; num_ellipsis++;
if (num_ellipsis > 1) if (num_ellipsis > 1) {
{
raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM, raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM,
NO_PARAM, NO_PARAM); NO_PARAM, NO_PARAM);
} }
} } else {
else
{
__builtin_unreachable(); __builtin_unreachable();
} }
} }
debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims); debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims);
if (src_ndarray->ndims - num_indexed < 0) if (src_ndarray->ndims - num_indexed < 0) {
{
raise_exception(SizeT, EXN_INDEX_ERROR, raise_exception(SizeT, EXN_INDEX_ERROR,
"too many indices for array: array is {0}-dimensional, " "too many indices for array: array is {0}-dimensional, "
"but {1} were indexed", "but {1} were indexed",
@ -154,20 +136,18 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
dst_ndarray->data = src_ndarray->data; dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize; dst_ndarray->itemsize = src_ndarray->itemsize;
// Reference code: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652 // Reference code:
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
SizeT src_axis = 0; SizeT src_axis = 0;
SizeT dst_axis = 0; SizeT dst_axis = 0;
for (int32_t i = 0; i < num_indices; i++) for (int32_t i = 0; i < num_indices; i++) {
{ const NDIndex* index = &indices[i];
const NDIndex *index = &indices[i]; if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) SizeT input = (SizeT) * ((int32_t*)index->data);
{
SizeT input = (SizeT) * ((int32_t *)index->data);
SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input); SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
if (k == -1) if (k == -1) {
{
raise_exception(SizeT, EXN_INDEX_ERROR, raise_exception(SizeT, EXN_INDEX_ERROR,
"index {0} is out of bounds for axis {1} " "index {0} is out of bounds for axis {1} "
"with size {2}", "with size {2}",
@ -177,10 +157,8 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
dst_ndarray->data += k * src_ndarray->strides[src_axis]; dst_ndarray->data += k * src_ndarray->strides[src_axis];
src_axis++; src_axis++;
} } else if (index->type == ND_INDEX_TYPE_SLICE) {
else if (index->type == ND_INDEX_TYPE_SLICE) Slice<int32_t>* slice = (Slice<int32_t>*)index->data;
{
Slice<int32_t> *slice = (Slice<int32_t> *)index->data;
Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]); Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
@ -190,36 +168,28 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
dst_axis++; dst_axis++;
src_axis++; src_axis++;
} } else if (index->type == ND_INDEX_TYPE_NEWAXIS) {
else if (index->type == ND_INDEX_TYPE_NEWAXIS)
{
dst_ndarray->strides[dst_axis] = 0; dst_ndarray->strides[dst_axis] = 0;
dst_ndarray->shape[dst_axis] = 1; dst_ndarray->shape[dst_axis] = 1;
dst_axis++; dst_axis++;
} } else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
else if (index->type == ND_INDEX_TYPE_ELLIPSIS)
{
// The number of ':' entries this '...' implies. // The number of ':' entries this '...' implies.
SizeT ellipsis_size = src_ndarray->ndims - num_indexed; SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
for (SizeT j = 0; j < ellipsis_size; j++) for (SizeT j = 0; j < ellipsis_size; j++) {
{
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_axis++; dst_axis++;
src_axis++; src_axis++;
} }
} } else {
else
{
__builtin_unreachable(); __builtin_unreachable();
} }
} }
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
{
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
} }
@ -227,23 +197,24 @@ void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_
debug_assert_eq(SizeT, src_ndarray->ndims, src_axis); debug_assert_eq(SizeT, src_ndarray->ndims, src_axis);
debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis); debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis);
} }
} // namespace indexing } // namespace indexing
} // namespace ndarray } // namespace ndarray
} // namespace } // namespace
extern "C" extern "C" {
{ using namespace ndarray::indexing;
using namespace ndarray::indexing;
void __nac3_ndarray_index(int32_t num_indices, NDIndex *indices, NDArray<int32_t> *src_ndarray, void __nac3_ndarray_index(int32_t num_indices,
NDArray<int32_t> *dst_ndarray) NDIndex* indices,
{ NDArray<int32_t>* src_ndarray,
index(num_indices, indices, src_ndarray, dst_ndarray); NDArray<int32_t>* dst_ndarray) {
} index(num_indices, indices, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_index64(int64_t num_indices, NDIndex *indices, NDArray<int64_t> *src_ndarray, void __nac3_ndarray_index64(int64_t num_indices,
NDArray<int64_t> *dst_ndarray) NDIndex* indices,
{ NDArray<int64_t>* src_ndarray,
index(num_indices, indices, src_ndarray, dst_ndarray); NDArray<int64_t>* dst_ndarray) {
} index(num_indices, indices, src_ndarray, dst_ndarray);
}
} }

View File

@ -1,35 +1,55 @@
#pragma once #pragma once
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
#include <irrt/ndarray/def.hpp> #include "irrt/ndarray/def.hpp"
namespace namespace {
{
/** /**
* @brief Helper struct to enumerate through an ndarray *efficiently*. * @brief Helper struct to enumerate through an ndarray *efficiently*.
* *
* Example usage (in pseudo-code):
* ```
* // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double`
* NDIter nditer;
* nditer.initialize(my_ndarray);
* while (nditer.has_element()) {
* // This body is run 6 (= my_ndarray.size) times.
*
* // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end
* print(nditer.indices);
*
* // 0 -> 1 -> 2 -> 3 -> 4 -> 5
* print(nditer.nth);
*
* // <1st element> -> <2nd element> -> ... -> <6th element> -> end
* print(*((double *) nditer.element))
*
* nditer.next(); // Go to next element.
* }
* ```
*
* Interesting cases: * Interesting cases:
* - If ndims == 0, there is one iteration. * - If `my_ndarray.ndims` == 0, there is one iteration.
* - If shape contains zeroes, there are no iterations. * - If `my_ndarray.shape` contains zeroes, there are no iterations.
*/ */
template <typename SizeT> struct NDIter template<typename SizeT>
{ struct NDIter {
// Information about the ndarray being iterated over. // Information about the ndarray being iterated over.
SizeT ndims; SizeT ndims;
SizeT *shape; SizeT* shape;
SizeT *strides; SizeT* strides;
/** /**
* @brief The current indices. * @brief The current indices.
* *
* Must be allocated by the caller. * Must be allocated by the caller.
*/ */
SizeT *indices; SizeT* indices;
/** /**
* @brief The nth (0-based) index of the current indices. * @brief The nth (0-based) index of the current indices.
* *
* Initially this is all 0s. * Initially this is 0.
*/ */
SizeT nth; SizeT nth;
@ -38,7 +58,7 @@ template <typename SizeT> struct NDIter
* *
* Initially this points to first element of the ndarray. * Initially this points to first element of the ndarray.
*/ */
uint8_t *element; uint8_t* element;
/** /**
* @brief Cache for the product of shape. * @brief Cache for the product of shape.
@ -47,11 +67,7 @@ template <typename SizeT> struct NDIter
*/ */
SizeT size; SizeT size;
// TODO:: Not implemented: There is something called backstrides to speedup iteration. void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element, SizeT* indices) {
// See https://ajcr.net/stride-guide-part-1/, and https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
void initialize(SizeT ndims, SizeT *shape, SizeT *strides, uint8_t *element, SizeT *indices)
{
this->ndims = ndims; this->ndims = ndims;
this->shape = shape; this->shape = shape;
this->strides = strides; this->strides = strides;
@ -61,41 +77,39 @@ template <typename SizeT> struct NDIter
// Compute size // Compute size
this->size = 1; this->size = 1;
for (SizeT i = 0; i < ndims; i++) for (SizeT i = 0; i < ndims; i++) {
{
this->size *= shape[i]; this->size *= shape[i];
} }
// `indices` starts on all 0s.
for (SizeT axis = 0; axis < ndims; axis++) for (SizeT axis = 0; axis < ndims; axis++)
indices[axis] = 0; indices[axis] = 0;
nth = 0; nth = 0;
} }
void initialize_by_ndarray(NDArray<SizeT> *ndarray, SizeT *indices) void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
{ // NOTE: ndarray->data is pointing to the first element, and `NDIter`'s `element` should also point to the first
// element as well.
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices); this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices);
} }
bool has_next() // Is the current iteration valid?
{ // If true, then `element`, `indices` and `nth` contain details about the current element.
return nth < size; bool has_element() { return nth < size; }
}
void next() // Go to the next element.
{ void next() {
for (SizeT i = 0; i < ndims; i++) for (SizeT i = 0; i < ndims; i++) {
{
SizeT axis = ndims - i - 1; SizeT axis = ndims - i - 1;
indices[axis]++; indices[axis]++;
if (indices[axis] >= shape[axis]) if (indices[axis] >= shape[axis]) {
{
indices[axis] = 0; indices[axis] = 0;
// TODO: Can be optimized with backstrides. // TODO: There is something called backstrides to speedup iteration.
// See https://ajcr.net/stride-guide-part-1/, and
// https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
element -= strides[axis] * (shape[axis] - 1); element -= strides[axis] * (shape[axis] - 1);
} } else {
else
{
element += strides[axis]; element += strides[axis];
break; break;
} }
@ -103,37 +117,30 @@ template <typename SizeT> struct NDIter
nth++; nth++;
} }
}; };
} // namespace } // namespace
extern "C" extern "C" {
{ void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray, int32_t* indices) {
void __nac3_nditer_initialize(NDIter<int32_t> *iter, NDArray<int32_t> *ndarray, int32_t *indices) iter->initialize_by_ndarray(ndarray, indices);
{ }
iter->initialize_by_ndarray(ndarray, indices);
}
void __nac3_nditer_initialize64(NDIter<int64_t> *iter, NDArray<int64_t> *ndarray, int64_t *indices) void __nac3_nditer_initialize64(NDIter<int64_t>* iter, NDArray<int64_t>* ndarray, int64_t* indices) {
{ iter->initialize_by_ndarray(ndarray, indices);
iter->initialize_by_ndarray(ndarray, indices); }
}
bool __nac3_nditer_has_next(NDIter<int32_t> *iter) bool __nac3_nditer_has_element(NDIter<int32_t>* iter) {
{ return iter->has_element();
return iter->has_next(); }
}
bool __nac3_nditer_has_next64(NDIter<int64_t> *iter) bool __nac3_nditer_has_element64(NDIter<int64_t>* iter) {
{ return iter->has_element();
return iter->has_next(); }
}
void __nac3_nditer_next(NDIter<int32_t> *iter) void __nac3_nditer_next(NDIter<int32_t>* iter) {
{ iter->next();
iter->next(); }
}
void __nac3_nditer_next64(NDIter<int64_t> *iter) void __nac3_nditer_next64(NDIter<int64_t>* iter) {
{ iter->next();
iter->next(); }
}
} }

View File

@ -1,14 +1,12 @@
#pragma once #pragma once
#include <irrt/int_types.hpp> #include "irrt/exception.hpp"
#include <irrt/ndarray/def.hpp> #include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace namespace {
{ namespace ndarray {
namespace ndarray namespace reshape {
{
namespace reshape
{
/** /**
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)` * @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
* *
@ -22,8 +20,8 @@ namespace reshape
* @param new_ndims Number of elements in `new_shape` * @param new_ndims Number of elements in `new_shape`
* @param new_shape Target shape to reshape to * @param new_shape Target shape to reshape to
*/ */
template <typename SizeT> void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT *new_shape) template<typename SizeT>
{ void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) {
// Is there a -1 in `new_shape`? // Is there a -1 in `new_shape`?
bool neg1_exists = false; bool neg1_exists = false;
// Location of -1, only initialized if `neg1_exists` is true // Location of -1, only initialized if `neg1_exists` is true
@ -31,27 +29,19 @@ template <typename SizeT> void resolve_and_check_new_shape(SizeT size, SizeT new
// The computed ndarray size of `new_shape` // The computed ndarray size of `new_shape`
SizeT new_size = 1; SizeT new_size = 1;
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
{
SizeT dim = new_shape[axis_i]; SizeT dim = new_shape[axis_i];
if (dim < 0) if (dim < 0) {
{ if (dim == -1) {
if (dim == -1) if (neg1_exists) {
{
if (neg1_exists)
{
// Multiple `-1` found. Throw an error. // Multiple `-1` found. Throw an error.
raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM, raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM,
NO_PARAM, NO_PARAM); NO_PARAM, NO_PARAM);
} } else {
else
{
neg1_exists = true; neg1_exists = true;
neg1_axis_i = axis_i; neg1_axis_i = axis_i;
} }
} } else {
else
{
// TODO: What? In `np.reshape` any negative dimensions is // TODO: What? In `np.reshape` any negative dimensions is
// treated like its `-1`. // treated like its `-1`.
// //
@ -63,63 +53,47 @@ template <typename SizeT> void resolve_and_check_new_shape(SizeT size, SizeT new
raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i, raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i,
NO_PARAM); NO_PARAM);
} }
} } else {
else
{
new_size *= dim; new_size *= dim;
} }
} }
bool can_reshape; bool can_reshape;
if (neg1_exists) if (neg1_exists) {
{
// Let `x` be the unknown dimension // Let `x` be the unknown dimension
// Solve `x * <new_size> = <size>` // Solve `x * <new_size> = <size>`
if (new_size == 0 && size == 0) if (new_size == 0 && size == 0) {
{
// `x` has infinitely many solutions // `x` has infinitely many solutions
can_reshape = false; can_reshape = false;
} } else if (new_size == 0 && size != 0) {
else if (new_size == 0 && size != 0)
{
// `x` has no solutions // `x` has no solutions
can_reshape = false; can_reshape = false;
} } else if (size % new_size != 0) {
else if (size % new_size != 0)
{
// `x` has no integer solutions // `x` has no integer solutions
can_reshape = false; can_reshape = false;
} } else {
else
{
can_reshape = true; can_reshape = true;
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
} }
} } else {
else
{
can_reshape = (new_size == size); can_reshape = (new_size == size);
} }
if (!can_reshape) if (!can_reshape) {
{
raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM, raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM,
NO_PARAM); NO_PARAM);
} }
} }
} // namespace reshape } // namespace reshape
} // namespace ndarray } // namespace ndarray
} // namespace } // namespace
extern "C" extern "C" {
{ void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) {
void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t *new_shape) ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
{ }
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
} void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t *new_shape) }
{
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
} }

View File

@ -1,8 +1,10 @@
#pragma once #pragma once
#include <irrt/int_types.hpp> #include "irrt/debug.hpp"
#include <irrt/ndarray/def.hpp> #include "irrt/exception.hpp"
#include <irrt/slice.hpp> #include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/slice.hpp"
/* /*
* Notes on `np.transpose(<array>, <axes>)` * Notes on `np.transpose(<array>, <axes>)`
@ -13,12 +15,9 @@
* Supporting it for now. * Supporting it for now.
*/ */
namespace namespace {
{ namespace ndarray {
namespace ndarray namespace transpose {
{
namespace transpose
{
/** /**
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`. * @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
* *
@ -30,30 +29,26 @@ namespace transpose
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown. * This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
* @param axes The user specified `<axes>`. * @param axes The user specified `<axes>`.
*/ */
template <typename SizeT> void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT *axes) template<typename SizeT>
{ void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) {
if (ndims != num_axes) if (ndims != num_axes) {
{
raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM); raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM);
} }
// TODO: Optimize this // TODO: Optimize this
bool *axe_specified = (bool *)__builtin_alloca(sizeof(bool) * ndims); bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
for (SizeT i = 0; i < ndims; i++) for (SizeT i = 0; i < ndims; i++)
axe_specified[i] = false; axe_specified[i] = false;
for (SizeT i = 0; i < ndims; i++) for (SizeT i = 0; i < ndims; i++) {
{
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]); SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
if (axis == -1) if (axis == -1) {
{
// TODO: numpy actually throws a `numpy.exceptions.AxisError` // TODO: numpy actually throws a `numpy.exceptions.AxisError`
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims, raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,
NO_PARAM); NO_PARAM);
} }
if (axe_specified[axis]) if (axe_specified[axis]) {
{
raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM); raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM);
} }
@ -88,9 +83,8 @@ template <typename SizeT> void assert_transpose_axes(SizeT ndims, SizeT num_axes
* @param num_axes Number of elements in axes. Unused if `axes` is nullptr. * @param num_axes Number of elements in axes. Unused if `axes` is nullptr.
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`. * @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`.
*/ */
template <typename SizeT> template<typename SizeT>
void transpose(const NDArray<SizeT> *src_ndarray, NDArray<SizeT> *dst_ndarray, SizeT num_axes, const SizeT *axes) void transpose(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
{
debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims); debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims);
const auto ndims = src_ndarray->ndims; const auto ndims = src_ndarray->ndims;
@ -101,8 +95,7 @@ void transpose(const NDArray<SizeT> *src_ndarray, NDArray<SizeT> *dst_ndarray, S
dst_ndarray->itemsize = src_ndarray->itemsize; dst_ndarray->itemsize = src_ndarray->itemsize;
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes. // Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
if (axes == nullptr) if (axes == nullptr) {
{
// `np.transpose(<array>, axes=None)` // `np.transpose(<array>, axes=None)`
/* /*
@ -113,19 +106,15 @@ void transpose(const NDArray<SizeT> *src_ndarray, NDArray<SizeT> *dst_ndarray, S
* This is a fast implementation to handle this special (but very common) case. * This is a fast implementation to handle this special (but very common) case.
*/ */
for (SizeT axis = 0; axis < ndims; axis++) for (SizeT axis = 0; axis < ndims; axis++) {
{
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1]; dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1]; dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
} }
} } else {
else
{
// `np.transpose(<array>, <axes>)` // `np.transpose(<array>, <axes>)`
// Permute strides and shape according to `axes`, while resolving negative indices in `axes` // Permute strides and shape according to `axes`, while resolving negative indices in `axes`
for (SizeT axis = 0; axis < ndims; axis++) for (SizeT axis = 0; axis < ndims; axis++) {
{
// `i` cannot be OUT_OF_BOUNDS because of assertions // `i` cannot be OUT_OF_BOUNDS because of assertions
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]); SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
@ -134,22 +123,23 @@ void transpose(const NDArray<SizeT> *src_ndarray, NDArray<SizeT> *dst_ndarray, S
} }
} }
} }
} // namespace transpose } // namespace transpose
} // namespace ndarray } // namespace ndarray
} // namespace } // namespace
extern "C" extern "C" {
{ using namespace ndarray::transpose;
using namespace ndarray::transpose; void __nac3_ndarray_transpose(const NDArray<int32_t>* src_ndarray,
void __nac3_ndarray_transpose(const NDArray<int32_t> *src_ndarray, NDArray<int32_t> *dst_ndarray, int32_t num_axes, NDArray<int32_t>* dst_ndarray,
const int32_t *axes) int32_t num_axes,
{ const int32_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes); transpose(src_ndarray, dst_ndarray, num_axes, axes);
} }
void __nac3_ndarray_transpose64(const NDArray<int64_t> *src_ndarray, NDArray<int64_t> *dst_ndarray, void __nac3_ndarray_transpose64(const NDArray<int64_t>* src_ndarray,
int64_t num_axes, const int64_t *axes) NDArray<int64_t>* dst_ndarray,
{ int64_t num_axes,
transpose(src_ndarray, dst_ndarray, num_axes, axes); const int64_t* axes) {
} transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
} }

View File

@ -1,372 +0,0 @@
#pragma once
#include <irrt/int_types.hpp>
#include <irrt/math_util.hpp>
// NDArray indices are always `uint32_t`.
using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;
namespace
{
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template <typename T> T __nac3_int_exp_impl(T base, T exp)
{
T res = 1;
/* repeated squaring method */
do
{
if (exp & 1)
{
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
template <typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT *list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx)
{
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i)
{
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template <typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT *dims, SizeT num_dims, NDIndexInt *idxs)
{
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++)
{
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template <typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT *dims, SizeT num_dims, const NDIndexInt *indices, SizeT num_indices)
{
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i)
{
SizeT ri = num_dims - i - 1;
if (ri < num_indices)
{
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template <typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT *lhs_dims, SizeT lhs_ndims, const SizeT *rhs_dims, SizeT rhs_ndims,
SizeT *out_dims)
{
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i)
{
const SizeT *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr)
{
*out_dim = *rhs_dim_sz;
}
else if (rhs_dim_sz == nullptr)
{
*out_dim = *lhs_dim_sz;
}
else if (*lhs_dim_sz == 1)
{
*out_dim = *rhs_dim_sz;
}
else if (*rhs_dim_sz == 1)
{
*out_dim = *lhs_dim_sz;
}
else if (*lhs_dim_sz == *rhs_dim_sz)
{
*out_dim = *lhs_dim_sz;
}
else
{
__builtin_unreachable();
}
}
}
template <typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT *src_dims, SizeT src_ndims, const NDIndexInt *in_idx,
NDIndexInt *out_idx)
{
for (SizeT i = 0; i < src_ndims; ++i)
{
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C"
{
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) \
{ \
return __nac3_int_exp_impl(base, exp); \
}
DEF_nac3_int_exp_(int32_t) DEF_nac3_int_exp_(int64_t) DEF_nac3_int_exp_(uint32_t) DEF_nac3_int_exp_(uint64_t)
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len)
{
if (i < 0)
{
i = len + i;
}
if (i < 0)
{
return 0;
}
else if (i > len)
{
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step)
{
SliceIndex diff = end - start;
if (diff > 0 && step > 0)
{
return ((diff - 1) / step) + 1;
}
else if (diff < 0 && step < 0)
{
return ((diff + 1) / step) + 1;
}
else
{
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start, SliceIndex dest_end, SliceIndex dest_step,
uint8_t *dest_arr, SliceIndex dest_arr_len, SliceIndex src_start,
SliceIndex src_end, SliceIndex src_step, uint8_t *src_arr,
SliceIndex src_arr_len, const SliceIndex size)
{
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0)
return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1)
{
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0)
{
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
}
if (dest_len > 0)
{
/* dropping */
__builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca = (dest_arr == src_arr) && !(max(dest_start, dest_end) < min(src_start, src_end) ||
max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca)
{
uint8_t *tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step)
{
/* for constant optimization */
if (size == 1)
{
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
}
else if (size == 4)
{
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
}
else if (size == 8)
{
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
}
else
{
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start)
{
__builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x)
{
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x)
{
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z)
{
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z))
{
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x)
{
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x))
{
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x)
{
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x))
{
return __builtin_nan("");
}
return j0(x);
}
uint32_t __nac3_ndarray_calc_size(const uint32_t *list_data, uint32_t list_len, uint32_t begin_idx,
uint32_t end_idx)
{
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t __nac3_ndarray_calc_size64(const uint64_t *list_data, uint64_t list_len, uint64_t begin_idx,
uint64_t end_idx)
{
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t *dims, uint32_t num_dims, NDIndexInt *idxs)
{
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t *dims, uint64_t num_dims, NDIndexInt *idxs)
{
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t __nac3_ndarray_flatten_index(const uint32_t *dims, uint32_t num_dims, const NDIndexInt *indices,
uint32_t num_indices)
{
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t __nac3_ndarray_flatten_index64(const uint64_t *dims, uint64_t num_dims, const NDIndexInt *indices,
uint64_t num_indices)
{
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t *lhs_dims, uint32_t lhs_ndims, const uint32_t *rhs_dims,
uint32_t rhs_ndims, uint32_t *out_dims)
{
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t *lhs_dims, uint64_t lhs_ndims, const uint64_t *rhs_dims,
uint64_t rhs_ndims, uint64_t *out_dims)
{
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t *src_dims, uint32_t src_ndims, const NDIndexInt *in_idx,
NDIndexInt *out_idx)
{
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t *src_dims, uint64_t src_ndims, const NDIndexInt *in_idx,
NDIndexInt *out_idx)
{
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
} // extern "C"

View File

@ -1,14 +1,12 @@
#pragma once #pragma once
#include <irrt/debug.hpp> #include "irrt/debug.hpp"
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
namespace namespace {
{ namespace range {
namespace range template<typename T>
{ T len(T start, T stop, T step) {
template <typename T> T len(T start, T stop, T step)
{
// Reference: // Reference:
// https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933 // https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
if (step > 0 && start < stop) if (step > 0 && start < stop)
@ -18,13 +16,13 @@ template <typename T> T len(T start, T stop, T step)
else else
return 0; return 0;
} }
} // namespace range } // namespace range
/** /**
* @brief A Python range. * @brief A Python range.
*/ */
template <typename T> struct Range template<typename T>
{ struct Range {
T start; T start;
T stop; T stop;
T step; T step;
@ -32,10 +30,18 @@ template <typename T> struct Range
/** /**
* @brief Calculate the `len()` of this range. * @brief Calculate the `len()` of this range.
*/ */
template <typename SizeT> T len() template<typename SizeT>
{ T len() {
debug_assert(SizeT, step != 0); debug_assert(SizeT, step != 0);
return range::len(start, stop, step); return range::len(start, stop, step);
} }
}; };
} // namespace } // namespace
extern "C" {
using namespace range;
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
return len(start, end, step);
}
}

View File

@ -1,29 +1,24 @@
#pragma once #pragma once
#include <irrt/debug.hpp> #include "irrt/debug.hpp"
#include <irrt/exception.hpp> #include "irrt/exception.hpp"
#include <irrt/int_types.hpp> #include "irrt/int_types.hpp"
#include <irrt/math_util.hpp> #include "irrt/math_util.hpp"
#include <irrt/range.hpp> #include "irrt/range.hpp"
namespace namespace {
{ namespace slice {
namespace slice
{
/** /**
* @brief Resolve a possibly negative index in a list of a known length. * @brief Resolve a possibly negative index in a list of a known length.
* *
* Returns -1 if the resolved index is out of the list's bounds. * Returns -1 if the resolved index is out of the list's bounds.
*/ */
template <typename T> T resolve_index_in_length(T length, T index) template<typename T>
{ T resolve_index_in_length(T length, T index) {
T resolved = index < 0 ? length + index : index; T resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length) if (0 <= resolved && resolved < length) {
{
return resolved; return resolved;
} } else {
else
{
return -1; return -1;
} }
} }
@ -33,51 +28,49 @@ template <typename T> T resolve_index_in_length(T length, T index)
* *
* This is equivalent to `range(*slice(start, stop, step).indices(length))` in Python. * This is equivalent to `range(*slice(start, stop, step).indices(length))` in Python.
*/ */
template <typename T> template<typename T>
void indices(bool start_defined, T start, bool stop_defined, T stop, bool step_defined, T step, T length, void indices(bool start_defined,
T *range_start, T *range_stop, T *range_step) T start,
{ bool stop_defined,
T stop,
bool step_defined,
T step,
T length,
T* range_start,
T* range_stop,
T* range_step) {
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388 // Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
*range_step = step_defined ? step : 1; *range_step = step_defined ? step : 1;
bool step_is_negative = *range_step < 0; bool step_is_negative = *range_step < 0;
T lower, upper; T lower, upper;
if (step_is_negative) if (step_is_negative) {
{
lower = -1; lower = -1;
upper = length - 1; upper = length - 1;
} } else {
else
{
lower = 0; lower = 0;
upper = length; upper = length;
} }
if (start_defined) if (start_defined) {
{
*range_start = start < 0 ? max(lower, start + length) : min(upper, start); *range_start = start < 0 ? max(lower, start + length) : min(upper, start);
} } else {
else
{
*range_start = step_is_negative ? upper : lower; *range_start = step_is_negative ? upper : lower;
} }
if (stop_defined) if (stop_defined) {
{
*range_stop = stop < 0 ? max(lower, stop + length) : min(upper, stop); *range_stop = stop < 0 ? max(lower, stop + length) : min(upper, stop);
} } else {
else
{
*range_stop = step_is_negative ? lower : upper; *range_stop = step_is_negative ? lower : upper;
} }
} }
} // namespace slice } // namespace slice
/** /**
* @brief A Python-like slice with **unresolved** indices. * @brief A Python-like slice with **unresolved** indices.
*/ */
template <typename T> struct Slice template<typename T>
{ struct Slice {
bool start_defined; bool start_defined;
T start; T start;
@ -87,32 +80,25 @@ template <typename T> struct Slice
bool step_defined; bool step_defined;
T step; T step;
Slice() Slice() { this->reset(); }
{
this->reset();
}
void reset() void reset() {
{
this->start_defined = false; this->start_defined = false;
this->stop_defined = false; this->stop_defined = false;
this->step_defined = false; this->step_defined = false;
} }
void set_start(T start) void set_start(T start) {
{
this->start_defined = true; this->start_defined = true;
this->start = start; this->start = start;
} }
void set_stop(T stop) void set_stop(T stop) {
{
this->stop_defined = true; this->stop_defined = true;
this->stop = stop; this->stop = stop;
} }
void set_step(T step) void set_step(T step) {
{
this->step_defined = true; this->step_defined = true;
this->step = step; this->step = step;
} }
@ -122,8 +108,8 @@ template <typename T> struct Slice
* *
* In Python, this would be `range(*slice(start, stop, step).indices(length))`. * In Python, this would be `range(*slice(start, stop, step).indices(length))`.
*/ */
template <typename SizeT> Range<T> indices(T length) template<typename SizeT>
{ Range<T> indices(T length) {
// Reference: // Reference:
// https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388 // https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
debug_assert(SizeT, length >= 0); debug_assert(SizeT, length >= 0);
@ -137,22 +123,34 @@ template <typename T> struct Slice
/** /**
* @brief Like `.indices()` but with assertions. * @brief Like `.indices()` but with assertions.
*/ */
template <typename SizeT> Range<T> indices_checked(T length) template<typename SizeT>
{ Range<T> indices_checked(T length) {
// TODO: Switch to `SizeT length` // TODO: Switch to `SizeT length`
if (length < 0) if (length < 0) {
{
raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM, raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM,
NO_PARAM); NO_PARAM);
} }
if (this->step_defined && this->step == 0) if (this->step_defined && this->step == 0) {
{
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM); raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM);
} }
return this->indices<SizeT>(length); return this->indices<SizeT>(length);
} }
}; };
} // namespace } // namespace
extern "C" {
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
}

View File

@ -8,6 +8,7 @@ use crate::codegen::classes::{
}; };
use crate::codegen::expr::destructure_range; use crate::codegen::expr::destructure_range;
use crate::codegen::irrt::calculate_len_for_slice_range; use crate::codegen::irrt::calculate_len_for_slice_range;
use crate::codegen::macros::codegen_unreachable;
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
@ -25,7 +26,8 @@ use super::object::tuple::TupleObject;
/// ///
/// The generated message will contain the function name and the name of the unsupported type. /// The generated message will contain the function name and the name of the unsupported type.
fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) -> ! { fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) -> ! {
unreachable!( codegen_unreachable!(
ctx,
"{fn_name}() not supported for '{}'", "{fn_name}() not supported for '{}'",
tys.iter().map(|ty| format!("'{}'", ctx.unifier.stringify(*ty))).join(", "), tys.iter().map(|ty| format!("'{}'", ctx.unifier.stringify(*ty))).join(", "),
) )
@ -764,7 +766,7 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -868,7 +870,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
match fn_name { match fn_name {
"np_argmin" | "np_argmax" => llvm_int64.const_zero().into(), "np_argmin" | "np_argmax" => llvm_int64.const_zero().into(),
"np_max" | "np_min" => a, "np_max" | "np_min" => a,
_ => unreachable!(), _ => codegen_unreachable!(ctx),
} }
} }
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
@ -923,7 +925,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
"np_argmax" | "np_max" => { "np_argmax" | "np_max" => {
call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)) call_max(ctx, (elem_ty, accumulator), (elem_ty, elem))
} }
_ => unreachable!(), _ => codegen_unreachable!(ctx),
}; };
let updated_idx = match (accumulator, result) { let updated_idx = match (accumulator, result) {
@ -960,7 +962,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
match fn_name { match fn_name {
"np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(),
"np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(),
_ => unreachable!(), _ => codegen_unreachable!(ctx),
} }
} }
@ -1026,7 +1028,7 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1466,7 +1468,7 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1533,7 +1535,7 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1600,7 +1602,7 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1667,7 +1669,7 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1790,7 +1792,7 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1857,7 +1859,7 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };

View File

@ -1404,7 +1404,7 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();

View File

@ -11,6 +11,7 @@ use crate::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_int_umin, call_memcpy_generic, call_int_umin, call_memcpy_generic,
}, },
macros::codegen_unreachable,
need_sret, numpy, need_sret, numpy,
stmt::{ stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
@ -40,12 +41,9 @@ use std::cmp::min;
use std::iter::{repeat, repeat_with}; use std::iter::{repeat, repeat_with};
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use super::{ use super::object::{
model::*, any::AnyObject,
object::{ ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject},
any::AnyObject,
ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject},
},
}; };
pub fn get_subst_key( pub fn get_subst_key(
@ -116,7 +114,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let obj_id = match &*self.unifier.get_ty(ty) { let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id, TypeEnum::TObj { obj_id, .. } => *obj_id,
// we cannot have other types, virtual type should be handled by function calls // we cannot have other types, virtual type should be handled by function calls
_ => unreachable!(), _ => codegen_unreachable!(self),
}; };
let def = &self.top_level.definitions.read()[obj_id.0]; let def = &self.top_level.definitions.read()[obj_id.0];
let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() { let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() {
@ -127,7 +125,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
(attribute_index.0, Some(attribute_index.1 .2.clone())) (attribute_index.0, Some(attribute_index.1 .2.clone()))
} }
} else { } else {
unreachable!() codegen_unreachable!(self)
}; };
(index, value) (index, value)
} }
@ -137,7 +135,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
TypeEnum::TObj { fields, .. } => { TypeEnum::TObj { fields, .. } => {
fields.iter().find_position(|x| *x.0 == attr).unwrap().0 fields.iter().find_position(|x| *x.0 == attr).unwrap().0
} }
_ => unreachable!(), _ => codegen_unreachable!(self),
} }
} }
@ -192,7 +190,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
{ {
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} }
_ => unreachable!("must be option type"), _ => codegen_unreachable!(self, "must be option type"),
}; };
let val = self.gen_symbol_val(generator, v, ty); let val = self.gen_symbol_val(generator, v, ty);
let ptr = generator let ptr = generator
@ -208,7 +206,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
{ {
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} }
_ => unreachable!("must be option type"), _ => codegen_unreachable!(self, "must be option type"),
}; };
let actual_ptr_type = let actual_ptr_type =
self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default()); self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default());
@ -275,7 +273,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
{ {
self.ctx.i64_type() self.ctx.i64_type()
} else { } else {
unreachable!() codegen_unreachable!(self)
}; };
Some(ty.const_int(*val as u64, false).into()) Some(ty.const_int(*val as u64, false).into())
} }
@ -289,7 +287,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let (types, is_vararg_ctx) = if let TypeEnum::TTuple { ty, is_vararg_ctx } = &*ty { let (types, is_vararg_ctx) = if let TypeEnum::TTuple { ty, is_vararg_ctx } = &*ty {
(ty.clone(), *is_vararg_ctx) (ty.clone(), *is_vararg_ctx)
} else { } else {
unreachable!() codegen_unreachable!(self)
}; };
let values = zip(types, v.iter()) let values = zip(types, v.iter())
.map_while(|(ty, v)| self.gen_const(generator, v, ty)) .map_while(|(ty, v)| self.gen_const(generator, v, ty))
@ -334,7 +332,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
None None
} }
_ => unreachable!(), _ => codegen_unreachable!(self),
} }
} }
@ -348,7 +346,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
signed: bool, signed: bool,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) else { let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) else {
unreachable!() codegen_unreachable!(self)
}; };
let float = self.ctx.f64_type(); let float = self.ctx.f64_type();
match (op, signed) { match (op, signed) {
@ -423,7 +421,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.build_right_shift(lhs, rhs, signed, "rshift") .build_right_shift(lhs, rhs, signed, "rshift")
.map(Into::into) .map(Into::into)
.unwrap(), .unwrap(),
_ => unreachable!(), _ => codegen_unreachable!(self),
} }
} }
@ -435,7 +433,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
(Operator::Pow, s) => integer_power(generator, self, lhs, rhs, s).into(), (Operator::Pow, s) => integer_power(generator, self, lhs, rhs, s).into(),
// special implementation? // special implementation?
(Operator::MatMult, _) => unreachable!(), (Operator::MatMult, _) => codegen_unreachable!(self),
} }
} }
@ -447,7 +445,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
rhs: BasicValueEnum<'ctx>, rhs: BasicValueEnum<'ctx>,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else { let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else {
unreachable!( codegen_unreachable!(
self,
"Expected (FloatValue, FloatValue), got ({}, {})", "Expected (FloatValue, FloatValue), got ({}, {})",
lhs.get_type(), lhs.get_type(),
rhs.get_type() rhs.get_type()
@ -691,7 +690,7 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>(
def: &TopLevelDef, def: &TopLevelDef,
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
let TopLevelDef::Class { methods, .. } = def else { unreachable!() }; let TopLevelDef::Class { methods, .. } = def else { codegen_unreachable!(ctx) };
// TODO: what about other fields that require alloca? // TODO: what about other fields that require alloca?
let fun_id = methods.iter().find(|method| method.0 == "__init__".into()).map(|method| method.2); let fun_id = methods.iter().find(|method| method.0 == "__init__".into()).map(|method| method.2);
@ -723,7 +722,7 @@ pub fn gen_func_instance<'ctx>(
key, key,
) = fun ) = fun
else { else {
unreachable!() codegen_unreachable!(ctx)
}; };
if let Some(sym) = instance_to_symbol.get(&key) { if let Some(sym) = instance_to_symbol.get(&key) {
@ -755,7 +754,7 @@ pub fn gen_func_instance<'ctx>(
.collect(); .collect();
let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache); let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache);
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() }; let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { codegen_unreachable!(ctx) };
if let Some(obj) = &obj { if let Some(obj) = &obj {
let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache); let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache);
@ -1121,7 +1120,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
expr: &Expr<Option<Type>>, expr: &Expr<Option<Type>>,
) -> Result<Option<BasicValueEnum<'ctx>>, String> { ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let ExprKind::ListComp { elt, generators } = &expr.node else { unreachable!() }; let ExprKind::ListComp { elt, generators } = &expr.node else { codegen_unreachable!(ctx) };
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
@ -1380,13 +1379,13 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) { if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) {
ctx.unifier.get_representative(*params.iter().next().unwrap().1) ctx.unifier.get_representative(*params.iter().next().unwrap().1)
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let elem_ty2 = let elem_ty2 =
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) { if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) {
ctx.unifier.get_representative(*params.iter().next().unwrap().1) ctx.unifier.get_representative(*params.iter().next().unwrap().1)
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2)); debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2));
@ -1459,7 +1458,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
{ {
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
(elem_ty, left_val, right_val) (elem_ty, left_val, right_val)
@ -1469,12 +1468,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
{ {
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
(elem_ty, right_val, left_val) (elem_ty, right_val, left_val)
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let list_val = let list_val =
ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None); ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
@ -1641,7 +1640,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} else { } else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else { let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
unreachable!("must be tobj") codegen_unreachable!(ctx, "must be tobj")
}; };
let (op_name, id) = { let (op_name, id) = {
let normal_method_name = Binop::normal(op.base).op_info().method_name; let normal_method_name = Binop::normal(op.base).op_info().method_name;
@ -1662,19 +1661,19 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} else { } else {
let left_enum_ty = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let left_enum_ty = ctx.unifier.get_ty_immutable(left_ty.unwrap());
let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else { let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else {
unreachable!("must be tobj") codegen_unreachable!(ctx, "must be tobj")
}; };
let fn_ty = fields.get(&op_name).unwrap().0; let fn_ty = fields.get(&op_name).unwrap().0;
let fn_ty_enum = ctx.unifier.get_ty_immutable(fn_ty); let fn_ty_enum = ctx.unifier.get_ty_immutable(fn_ty);
let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else { unreachable!() }; let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else { codegen_unreachable!(ctx) };
sig.clone() sig.clone()
}; };
let fun_id = { let fun_id = {
let defs = ctx.top_level.definitions.read(); let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read(); let obj_def = defs.get(id.0).unwrap().read();
let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() }; let TopLevelDef::Class { methods, .. } = &*obj_def else { codegen_unreachable!(ctx) };
methods.iter().find(|method| method.0 == op_name).unwrap().2 methods.iter().find(|method| method.0 == op_name).unwrap().2
}; };
@ -1805,7 +1804,8 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
if op == ast::Unaryop::Invert { if op == ast::Unaryop::Invert {
ast::Unaryop::Not ast::Unaryop::Not
} else { } else {
unreachable!( codegen_unreachable!(
ctx,
"ufunc {} not supported for ndarray[bool, N]", "ufunc {} not supported for ndarray[bool, N]",
op.op_info().method_name, op.op_info().method_name,
) )
@ -1872,8 +1872,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (Some(left_ty), lhs) = left else { unreachable!() }; let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) };
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() }; let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) };
let op = ops[0]; let op = ops[0];
let is_ndarray1 = let is_ndarray1 =
@ -1980,7 +1980,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
let op = match op { let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ, ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ,
ast::Cmpop::NotEq => IntPredicate::NE, ast::Cmpop::NotEq => IntPredicate::NE,
_ if left_ty == ctx.primitives.bool => unreachable!(), _ if left_ty == ctx.primitives.bool => codegen_unreachable!(ctx),
ast::Cmpop::Lt => { ast::Cmpop::Lt => {
if use_unsigned_ops { if use_unsigned_ops {
IntPredicate::ULT IntPredicate::ULT
@ -2009,7 +2009,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
IntPredicate::SGE IntPredicate::SGE
} }
} }
_ => unreachable!(), _ => codegen_unreachable!(ctx),
}; };
ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap() ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap()
@ -2026,7 +2026,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE, ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT, ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE, ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
_ => unreachable!(), _ => codegen_unreachable!(ctx),
}; };
ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap() ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap()
} else if left_ty == ctx.primitives.str { } else if left_ty == ctx.primitives.str {
@ -2158,7 +2158,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
match (op, val) { match (op, val) {
(Cmpop::Eq, true) | (Cmpop::NotEq, false) => llvm_i1.const_all_ones(), (Cmpop::Eq, true) | (Cmpop::NotEq, false) => llvm_i1.const_all_ones(),
(Cmpop::Eq, false) | (Cmpop::NotEq, true) => llvm_i1.const_zero(), (Cmpop::Eq, false) | (Cmpop::NotEq, true) => llvm_i1.const_zero(),
(_, _) => unreachable!(), (_, _) => codegen_unreachable!(ctx),
} }
}; };
@ -2171,14 +2171,14 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
{ {
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let right_elem_ty = if let TypeEnum::TObj { params, .. } = let right_elem_ty = if let TypeEnum::TObj { params, .. } =
&*ctx.unifier.get_ty_immutable(right_ty) &*ctx.unifier.get_ty_immutable(right_ty)
{ {
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
if !ctx.unifier.unioned(left_elem_ty, right_elem_ty) { if !ctx.unifier.unioned(left_elem_ty, right_elem_ty) {
@ -2386,7 +2386,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
}) })
.map(BasicValueEnum::into_int_value)?; .map(BasicValueEnum::into_int_value)?;
Ok(ctx.builder.build_not(cmp, "").unwrap()) Ok(ctx.builder.build_not(
generator.bool_to_i1(ctx, cmp),
"",
).unwrap())
}, },
|_, ctx| { |_, ctx| {
let bb = ctx.builder.get_insert_block().unwrap(); let bb = ctx.builder.get_insert_block().unwrap();
@ -2535,7 +2538,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
.const_null() .const_null()
.into() .into()
} }
_ => unreachable!("must be option type"), _ => codegen_unreachable!(ctx, "must be option type"),
} }
} }
ExprKind::Name { id, .. } => match ctx.var_assignment.get(id) { ExprKind::Name { id, .. } => match ctx.var_assignment.get(id) {
@ -2545,29 +2548,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()), Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()),
None => { None => {
let resolver = ctx.resolver.clone(); let resolver = ctx.resolver.clone();
if let Some(res) = resolver.get_symbol_value(*id, ctx) { resolver.get_symbol_value(*id, ctx).unwrap()
res
} else {
// Allow "raise Exception" short form
let def_id = resolver.get_identifier_def(*id).map_err(|e| {
format!("{} (at {})", e.iter().next().unwrap(), expr.location)
})?;
let def = ctx.top_level.definitions.read();
if let TopLevelDef::Class { constructor, .. } = *def[def_id.0].read() {
let TypeEnum::TFunc(signature) =
ctx.unifier.get_ty(constructor.unwrap()).as_ref().clone()
else {
return Err(format!(
"Failed to resolve symbol {} (at {})",
id, expr.location
));
};
return Ok(generator
.gen_call(ctx, None, (&signature, def_id), Vec::default())?
.map(Into::into));
}
return Err(format!("Failed to resolve symbol {} (at {})", id, expr.location));
}
} }
}, },
ExprKind::List { elts, .. } => { ExprKind::List { elts, .. } => {
@ -2596,7 +2577,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) { if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) {
@ -2690,7 +2671,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
return generator.gen_expr(ctx, &modified_expr); return generator.gen_expr(ctx, &modified_expr);
} }
None => unreachable!("Function Type should not have attributes"), None => {
codegen_unreachable!(ctx, "Function Type should not have attributes")
}
} }
} else if let TypeEnum::TObj { obj_id, fields, params } = &*ctx.unifier.get_ty(c) { } else if let TypeEnum::TObj { obj_id, fields, params } = &*ctx.unifier.get_ty(c) {
if fields.is_empty() && params.is_empty() { if fields.is_empty() && params.is_empty() {
@ -2712,7 +2695,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
return generator.gen_expr(ctx, &modified_expr); return generator.gen_expr(ctx, &modified_expr);
} }
None => unreachable!(), None => codegen_unreachable!(ctx),
} }
} }
} }
@ -2814,7 +2797,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
(Some(a), None) => a.into(), (Some(a), None) => a.into(),
(None, Some(b)) => b.into(), (None, Some(b)) => b.into(),
(None, None) => unreachable!(), (None, None) => codegen_unreachable!(ctx),
} }
} }
ExprKind::BinOp { op, left, right } => { ExprKind::BinOp { op, left, right } => {
@ -2904,7 +2887,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ctx.unifier.get_call_signature(*call).unwrap() ctx.unifier.get_call_signature(*call).unwrap()
} else { } else {
let ty = func.custom.unwrap(); let ty = func.custom.unwrap();
let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) else { unreachable!() }; let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) else {
codegen_unreachable!(ctx)
};
sign.clone() sign.clone()
}; };
@ -2923,17 +2908,26 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) };
// Handle Class Method calls // Handle Class Method calls
// The attribute will be `DefinitionId` of the method if the call is to one of the parent methods
let func_id = attr.to_string().parse::<usize>();
let id = if let TypeEnum::TObj { obj_id, .. } = let id = if let TypeEnum::TObj { obj_id, .. } =
&*ctx.unifier.get_ty(value.custom.unwrap()) &*ctx.unifier.get_ty(value.custom.unwrap())
{ {
*obj_id *obj_id
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let fun_id = {
// Use the `DefinitionID` from attribute if it is available
let fun_id = if let Ok(func_id) = func_id {
DefinitionId(func_id)
} else {
let defs = ctx.top_level.definitions.read(); let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read(); let obj_def = defs.get(id.0).unwrap().read();
let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() }; let TopLevelDef::Class { methods, .. } = &*obj_def else {
codegen_unreachable!(ctx)
};
methods.iter().find(|method| method.0 == *attr).unwrap().2 methods.iter().find(|method| method.0 == *attr).unwrap().2
}; };
@ -3004,7 +2998,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
.unwrap(), .unwrap(),
)); ));
} }
ValueEnum::Dynamic(_) => unreachable!("option must be static or ptr"), ValueEnum::Dynamic(_) => {
codegen_unreachable!(ctx, "option must be static or ptr")
}
} }
} }
@ -3161,7 +3157,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node { if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node {
(*v).try_into().unwrap() (*v).try_into().unwrap()
} else { } else {
unreachable!("tuple subscript must be const int after type check"); codegen_unreachable!(
ctx,
"tuple subscript must be const int after type check"
);
}; };
match generator.gen_expr(ctx, value)? { match generator.gen_expr(ctx, value)? {
Some(ValueEnum::Dynamic(v)) => { Some(ValueEnum::Dynamic(v)) => {
@ -3184,7 +3183,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
None => return Ok(None), None => return Ok(None),
} }
} }
_ => unreachable!("should not be other subscriptable types after type check"), _ => codegen_unreachable!(
ctx,
"should not be other subscriptable types after type check"
),
} }
} }
ExprKind::ListComp { .. } => { ExprKind::ListComp { .. } => {
@ -3197,42 +3199,3 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
_ => unimplemented!(), _ => unimplemented!(),
})) }))
} }
/// Generate LLVM IR for an [`ExprKind::Slice`]
#[allow(clippy::type_complexity)]
pub fn gen_slice<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lower: &Option<Box<Expr<Option<Type>>>>,
upper: &Option<Box<Expr<Option<Type>>>>,
step: &Option<Box<Expr<Option<Type>>>>,
) -> Result<
(
Option<Instance<'ctx, Int<Int32>>>,
Option<Instance<'ctx, Int<Int32>>>,
Option<Instance<'ctx, Int<Int32>>>,
),
String,
> {
let mut help = |value_expr: &Option<Box<Expr<Option<Type>>>>| -> Result<_, String> {
Ok(match value_expr {
None => None,
Some(value_expr) => {
let value_expr = generator
.gen_expr(ctx, value_expr)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
let value_expr = Int(Int32).check_value(generator, ctx.ctx, value_expr).unwrap();
Some(value_expr)
}
})
};
let lower = help(lower)?;
let upper = help(upper)?;
let step = help(step)?;
Ok((lower, upper, step))
}

View File

@ -3,19 +3,19 @@ use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
use super::{ use super::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAdapter, UntypedArrayLikeAccessor, TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
}, },
llvm_intrinsics, llvm_intrinsics,
macros::codegen_unreachable,
model::*, model::*,
object::{ object::{
list::List, list::List,
ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray}, ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray},
}, },
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
use crate::codegen::classes::TypedArrayLikeAccessor; use function::FnCall;
use crate::codegen::stmt::gen_for_callback_incrementing;
use function::CallFunction;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
context::Context, context::Context,
@ -29,7 +29,7 @@ use itertools::Either;
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
#[must_use] #[must_use]
pub fn load_irrt(ctx: &Context) -> Module { pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
let bitcode_buf = MemoryBuffer::create_from_memory_range( let bitcode_buf = MemoryBuffer::create_from_memory_range(
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")), include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
"irrt_bitcode_buffer", "irrt_bitcode_buffer",
@ -45,6 +45,25 @@ pub fn load_irrt(ctx: &Context) -> Module {
let function = irrt_mod.get_function(symbol).unwrap(); let function = irrt_mod.get_function(symbol).unwrap();
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0)); function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
} }
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
let exn_id_type = ctx.i32_type();
let errors = &[
("EXN_INDEX_ERROR", "0:IndexError"),
("EXN_VALUE_ERROR", "0:ValueError"),
("EXN_ASSERTION_ERROR", "0:AssertionError"),
("EXN_TYPE_ERROR", "0:TypeError"),
];
for (irrt_name, symbol_name) in errors {
let exn_id = symbol_resolver.get_string_id(symbol_name);
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
});
global.set_initializer(&exn_id);
}
irrt_mod irrt_mod
} }
@ -62,7 +81,7 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
(64, 64, true) => "__nac3_int_exp_int64_t", (64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t", (32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t", (64, 64, false) => "__nac3_int_exp_uint64_t",
_ => unreachable!(), _ => codegen_unreachable!(ctx),
}; };
let base_type = base.get_type(); let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
@ -448,7 +467,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
BasicTypeEnum::IntType(t) => t.size_of(), BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(), BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(), BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => unreachable!(), _ => codegen_unreachable!(ctx),
}; };
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
} }
@ -593,7 +612,7 @@ where
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size", 32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64", 64 => "__nac3_ndarray_calc_size64",
bw => unreachable!("Unsupported size type bit width: {}", bw), bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
}; };
let ndarray_calc_size_fn_t = llvm_usize.fn_type( let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
@ -644,7 +663,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices", 32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64", 64 => "__nac3_ndarray_calc_nd_indices64",
bw => unreachable!("Unsupported size type bit width: {}", bw), bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
}; };
let ndarray_calc_nd_indices_fn = let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
@ -713,7 +732,7 @@ where
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index", 32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64", 64 => "__nac3_ndarray_flatten_index64",
bw => unreachable!("Unsupported size type bit width: {}", bw), bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
}; };
let ndarray_flatten_index_fn = let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
@ -781,7 +800,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast", 32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64", 64 => "__nac3_ndarray_calc_broadcast64",
bw => unreachable!("Unsupported size type bit width: {}", bw), bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
}; };
let ndarray_calc_broadcast_fn = let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
@ -901,7 +920,7 @@ pub fn call_ndarray_calc_broadcast_index<
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx", 32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64", 64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw), bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
}; };
let ndarray_calc_broadcast_fn = let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
@ -937,32 +956,6 @@ pub fn call_ndarray_calc_broadcast_index<
) )
} }
/// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
pub fn setup_irrt_exceptions<'ctx>(
ctx: &'ctx Context,
module: &Module<'ctx>,
symbol_resolver: &dyn SymbolResolver,
) {
let exn_id_type = ctx.i32_type();
let errors = &[
("EXN_INDEX_ERROR", "0:IndexError"),
("EXN_VALUE_ERROR", "0:ValueError"),
("EXN_ASSERTION_ERROR", "0:AssertionError"),
("EXN_TYPE_ERROR", "0:TypeError"),
];
for (irrt_name, symbol_name) in errors {
let exn_id = symbol_resolver.get_string_id(symbol_name);
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
let global = module.get_global(irrt_name).unwrap_or_else(|| {
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
});
global.set_initializer(&exn_id);
}
}
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}". // When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64". // When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
#[must_use] #[must_use]
@ -993,7 +986,7 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator +
ctx, ctx,
"__nac3_ndarray_util_assert_shape_no_negative", "__nac3_ndarray_util_assert_shape_no_negative",
); );
CallFunction::begin(generator, ctx, &name).arg(ndims).arg(shape).returning_void(); FnCall::builder(generator, ctx, &name).arg(ndims).arg(shape).returning_void();
} }
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
@ -1009,7 +1002,7 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
ctx, ctx,
"__nac3_ndarray_util_assert_output_shape_same", "__nac3_ndarray_util_assert_output_shape_same",
); );
CallFunction::begin(generator, ctx, &name) FnCall::builder(generator, ctx, &name)
.arg(ndarray_ndims) .arg(ndarray_ndims)
.arg(ndarray_shape) .arg(ndarray_shape)
.arg(output_ndims) .arg(output_ndims)
@ -1023,7 +1016,7 @@ pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) -> Instance<'ctx, Int<SizeT>> { ) -> Instance<'ctx, Int<SizeT>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_size"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("size") FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("size")
} }
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
@ -1032,7 +1025,7 @@ pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) -> Instance<'ctx, Int<SizeT>> { ) -> Instance<'ctx, Int<SizeT>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("nbytes") FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("nbytes")
} }
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
@ -1041,7 +1034,7 @@ pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) -> Instance<'ctx, Int<SizeT>> { ) -> Instance<'ctx, Int<SizeT>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_len"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("len") FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("len")
} }
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
@ -1050,7 +1043,7 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) -> Instance<'ctx, Int<Bool>> { ) -> Instance<'ctx, Int<Bool>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("is_c_contiguous") FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("is_c_contiguous")
} }
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
@ -1060,7 +1053,7 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
index: Instance<'ctx, Int<SizeT>>, index: Instance<'ctx, Int<SizeT>>,
) -> Instance<'ctx, Ptr<Int<Byte>>> { ) -> Instance<'ctx, Ptr<Int<Byte>>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
CallFunction::begin(generator, ctx, &name).arg(ndarray).arg(index).returning_auto("pelement") FnCall::builder(generator, ctx, &name).arg(ndarray).arg(index).returning_auto("pelement")
} }
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
@ -1071,7 +1064,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
) -> Instance<'ctx, Ptr<Int<Byte>>> { ) -> Instance<'ctx, Ptr<Int<Byte>>> {
let name = let name =
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
CallFunction::begin(generator, ctx, &name).arg(ndarray).arg(indices).returning_auto("pelement") FnCall::builder(generator, ctx, &name).arg(ndarray).arg(indices).returning_auto("pelement")
} }
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
@ -1081,7 +1074,7 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
) { ) {
let name = let name =
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape"); get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_void(); FnCall::builder(generator, ctx, &name).arg(ndarray).returning_void();
} }
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
@ -1091,7 +1084,7 @@ pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) { ) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
} }
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
@ -1102,16 +1095,16 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
indices: Instance<'ctx, Ptr<Int<SizeT>>>, indices: Instance<'ctx, Ptr<Int<SizeT>>>,
) { ) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
CallFunction::begin(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void(); FnCall::builder(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void();
} }
pub fn call_nac3_nditer_has_next<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
iter: Instance<'ctx, Ptr<Struct<NDIter>>>, iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
) -> Instance<'ctx, Int<Bool>> { ) -> Instance<'ctx, Int<Bool>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_next"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_element");
CallFunction::begin(generator, ctx, &name).arg(iter).returning_auto("has_next") FnCall::builder(generator, ctx, &name).arg(iter).returning_auto("has_element")
} }
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
@ -1120,7 +1113,7 @@ pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
iter: Instance<'ctx, Ptr<Struct<NDIter>>>, iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
) { ) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
CallFunction::begin(generator, ctx, &name).arg(iter).returning_void(); FnCall::builder(generator, ctx, &name).arg(iter).returning_void();
} }
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
@ -1132,7 +1125,7 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) { ) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_index"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_index");
CallFunction::begin(generator, ctx, &name) FnCall::builder(generator, ctx, &name)
.arg(num_indices) .arg(num_indices)
.arg(indices) .arg(indices)
.arg(src_ndarray) .arg(src_ndarray)
@ -1152,7 +1145,7 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato
ctx, ctx,
"__nac3_ndarray_array_set_and_validate_list_shape", "__nac3_ndarray_array_set_and_validate_list_shape",
); );
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndims).arg(shape).returning_void(); FnCall::builder(generator, ctx, &name).arg(list).arg(ndims).arg(shape).returning_void();
} }
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
@ -1166,7 +1159,7 @@ pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Siz
ctx, ctx,
"__nac3_ndarray_array_write_list_to_array", "__nac3_ndarray_array_write_list_to_array",
); );
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void(); FnCall::builder(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
} }
pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
@ -1181,11 +1174,7 @@ pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenera
ctx, ctx,
"__nac3_ndarray_reshape_resolve_and_check_new_shape", "__nac3_ndarray_reshape_resolve_and_check_new_shape",
); );
CallFunction::begin(generator, ctx, &name) FnCall::builder(generator, ctx, &name).arg(size).arg(new_ndims).arg(new_shape).returning_void();
.arg(size)
.arg(new_ndims)
.arg(new_shape)
.returning_void();
} }
pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
@ -1195,7 +1184,7 @@ pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>, dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) { ) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to");
CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
} }
pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
@ -1207,7 +1196,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
dst_shape: Instance<'ctx, Ptr<Int<SizeT>>>, dst_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) { ) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes");
CallFunction::begin(generator, ctx, &name) FnCall::builder(generator, ctx, &name)
.arg(num_shape_entries) .arg(num_shape_entries)
.arg(shape_entries) .arg(shape_entries)
.arg(dst_ndims) .arg(dst_ndims)
@ -1224,7 +1213,7 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
axes: Instance<'ctx, Ptr<Int<SizeT>>>, axes: Instance<'ctx, Ptr<Int<SizeT>>>,
) { ) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose");
CallFunction::begin(generator, ctx, &name) FnCall::builder(generator, ctx, &name)
.arg(src_ndarray) .arg(src_ndarray)
.arg(dst_ndarray) .arg(dst_ndarray)
.arg(num_axes) .arg(num_axes)

View File

@ -54,6 +54,22 @@ mod test;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator}; pub use generator::{CodeGenerator, DefaultCodeGenerator};
mod macros {
/// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
/// its first argument to provide Python source information to indicate the codegen location
/// causing the assertion.
macro_rules! codegen_unreachable {
($ctx:expr $(,)?) => {
std::unreachable!("unreachable code while processing {}", &$ctx.current_loc)
};
($ctx:expr, $($arg:tt)*) => {
std::unreachable!("unreachable code while processing {}: {}", &$ctx.current_loc, std::format!("{}", std::format_args!($($arg)+)))
};
}
pub(crate) use codegen_unreachable;
}
#[derive(Default)] #[derive(Default)]
pub struct StaticValueStore { pub struct StaticValueStore {
pub lookup: HashMap<Vec<(usize, u64)>, usize>, pub lookup: HashMap<Vec<(usize, u64)>, usize>,
@ -493,7 +509,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
Ptr(Struct(NDArray)).get_type(generator, ctx).as_basic_type_enum() Ptr(Struct(NDArray)).llvm_type(generator, ctx).as_basic_type_enum()
} }
_ => unreachable!( _ => unreachable!(

View File

@ -18,7 +18,7 @@ impl<'ctx> Model<'ctx> for Any<'ctx> {
type Value = BasicValueEnum<'ctx>; type Value = BasicValueEnum<'ctx>;
type Type = BasicTypeEnum<'ctx>; type Type = BasicTypeEnum<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>( fn llvm_type<G: CodeGenerator + ?Sized>(
&self, &self,
_generator: &G, _generator: &G,
_ctx: &'ctx Context, _ctx: &'ctx Context,

View File

@ -11,8 +11,8 @@ use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*; use super::*;
/// Trait for Rust structs identifying length values for [`Array`]. /// Trait for Rust structs identifying length values for [`Array`].
pub trait LenKind: fmt::Debug + Clone + Copy { pub trait ArrayLen: fmt::Debug + Clone + Copy {
fn get_length(&self) -> u32; fn length(&self) -> u32;
} }
/// A statically known length. /// A statically known length.
@ -23,14 +23,14 @@ pub struct Len<const N: u32>;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct AnyLen(pub u32); pub struct AnyLen(pub u32);
impl<const N: u32> LenKind for Len<N> { impl<const N: u32> ArrayLen for Len<N> {
fn get_length(&self) -> u32 { fn length(&self) -> u32 {
N N
} }
} }
impl LenKind for AnyLen { impl ArrayLen for AnyLen {
fn get_length(&self) -> u32 { fn length(&self) -> u32 {
self.0 self.0
} }
} }
@ -46,12 +46,16 @@ pub struct Array<Len, Item> {
pub item: Item, pub item: Item,
} }
impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> { impl<'ctx, Len: ArrayLen, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
type Value = ArrayValue<'ctx>; type Value = ArrayValue<'ctx>;
type Type = ArrayType<'ctx>; type Type = ArrayType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { fn llvm_type<G: CodeGenerator + ?Sized>(
self.item.get_type(generator, ctx).array_type(self.len.get_length()) &self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
self.item.llvm_type(generator, ctx).array_type(self.len.length())
} }
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>( fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
@ -65,11 +69,11 @@ impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}"))); return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}")));
}; };
if ty.len() != self.len.get_length() { if ty.len() != self.len.length() {
return Err(ModelError(format!( return Err(ModelError(format!(
"Expecting ArrayType with size {}, but got an ArrayType with size {}", "Expecting ArrayType with size {}, but got an ArrayType with size {}",
ty.len(), ty.len(),
self.len.get_length() self.len.length()
))); )));
} }
@ -81,7 +85,7 @@ impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
} }
} }
impl<'ctx, Len: LenKind, Item: Model<'ctx>> Instance<'ctx, Ptr<Array<Len, Item>>> { impl<'ctx, Len: ArrayLen, Item: Model<'ctx>> Instance<'ctx, Ptr<Array<Len, Item>>> {
/// Get the pointer to the `i`-th (0-based) array element. /// Get the pointer to the `i`-th (0-based) array element.
pub fn gep( pub fn gep(
&self, &self,
@ -91,15 +95,15 @@ impl<'ctx, Len: LenKind, Item: Model<'ctx>> Instance<'ctx, Ptr<Array<Len, Item>>
let zero = ctx.ctx.i32_type().const_zero(); let zero = ctx.ctx.i32_type().const_zero();
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], "").unwrap() }; let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], "").unwrap() };
Ptr(self.model.0.item).believe_value(ptr) unsafe { Ptr(self.model.0.item).believe_value(ptr) }
} }
/// Like `gep` but `i` is a constant. /// Like `gep` but `i` is a constant.
pub fn gep_const(&self, ctx: &CodeGenContext<'ctx, '_>, i: u64) -> Instance<'ctx, Ptr<Item>> { pub fn gep_const(&self, ctx: &CodeGenContext<'ctx, '_>, i: u64) -> Instance<'ctx, Ptr<Item>> {
assert!( assert!(
i < u64::from(self.model.0.len.get_length()), i < u64::from(self.model.0.len.length()),
"Index {i} is out of bounds. Array length = {}", "Index {i} is out of bounds. Array length = {}",
self.model.0.len.get_length() self.model.0.len.length()
); );
let i = ctx.ctx.i32_type().const_int(i, false); let i = ctx.ctx.i32_type().const_int(i, false);

View File

@ -11,7 +11,7 @@ use crate::codegen::{CodeGenContext, CodeGenerator};
pub struct ModelError(pub String); pub struct ModelError(pub String);
impl ModelError { impl ModelError {
// Append a context message to the error. /// Append a context message to the error.
pub(super) fn under_context(mut self, context: &str) -> Self { pub(super) fn under_context(mut self, context: &str) -> Self {
self.0.push_str(" ... in "); self.0.push_str(" ... in ");
self.0.push_str(context); self.0.push_str(context);
@ -47,7 +47,7 @@ impl ModelError {
/// } /// }
/// ``` /// ```
/// ///
/// ### Notes on converting between Inkwell and model. /// ### Notes on converting between Inkwell and model/ge.
/// ///
/// Suppose you have an [`IntValue`], and you want to pass it into a function that takes a [`Instance<'ctx, Int<Int32>>`]. You can do use /// Suppose you have an [`IntValue`], and you want to pass it into a function that takes a [`Instance<'ctx, Int<Int32>>`]. You can do use
/// [`Model::check_value`] or [`Model::believe_value`]. /// [`Model::check_value`] or [`Model::believe_value`].
@ -68,15 +68,16 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
/// Return the [`BasicType`] of this model. /// Return the [`BasicType`] of this model.
#[must_use] #[must_use]
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type; fn llvm_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context)
-> Self::Type;
/// Get the number of bytes of the [`BasicType`] of this model. /// Get the number of bytes of the [`BasicType`] of this model.
fn sizeof<G: CodeGenerator + ?Sized>( fn size_of<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &'ctx Context, ctx: &'ctx Context,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
self.get_type(generator, ctx).size_of().unwrap() self.llvm_type(generator, ctx).size_of().unwrap()
} }
/// Check if a [`BasicType`] matches the [`BasicType`] of this model. /// Check if a [`BasicType`] matches the [`BasicType`] of this model.
@ -89,9 +90,11 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
/// Create an instance from a value. /// Create an instance from a value.
/// ///
/// # Safety
///
/// Caller must make sure the type of `value` and the type of this `model` are equivalent. /// Caller must make sure the type of `value` and the type of this `model` are equivalent.
#[must_use] #[must_use]
fn believe_value(&self, value: Self::Value) -> Instance<'ctx, Self> { unsafe fn believe_value(&self, value: Self::Value) -> Instance<'ctx, Self> {
Instance { model: *self, value } Instance { model: *self, value }
} }
@ -110,7 +113,7 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
let Ok(value) = Self::Value::try_from(value) else { let Ok(value) = Self::Value::try_from(value) else {
unreachable!("check_type() has bad implementation") unreachable!("check_type() has bad implementation")
}; };
Ok(self.believe_value(value)) unsafe { Ok(self.believe_value(value)) }
} }
// Allocate a value on the stack and return its pointer. // Allocate a value on the stack and return its pointer.
@ -119,8 +122,8 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
generator: &mut G, generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Ptr<Self>> { ) -> Instance<'ctx, Ptr<Self>> {
let p = ctx.builder.build_alloca(self.get_type(generator, ctx.ctx), "").unwrap(); let p = ctx.builder.build_alloca(self.llvm_type(generator, ctx.ctx), "").unwrap();
Ptr(*self).believe_value(p) unsafe { Ptr(*self).believe_value(p) }
} }
// Allocate an array on the stack and return its pointer. // Allocate an array on the stack and return its pointer.
@ -130,8 +133,9 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
len: IntValue<'ctx>, len: IntValue<'ctx>,
) -> Instance<'ctx, Ptr<Self>> { ) -> Instance<'ctx, Ptr<Self>> {
let p = ctx.builder.build_array_alloca(self.get_type(generator, ctx.ctx), len, "").unwrap(); let p =
Ptr(*self).believe_value(p) ctx.builder.build_array_alloca(self.llvm_type(generator, ctx.ctx), len, "").unwrap();
unsafe { Ptr(*self).believe_value(p) }
} }
fn var_alloca<G: CodeGenerator + ?Sized>( fn var_alloca<G: CodeGenerator + ?Sized>(
@ -140,9 +144,9 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&str>, name: Option<&str>,
) -> Result<Instance<'ctx, Ptr<Self>>, String> { ) -> Result<Instance<'ctx, Ptr<Self>>, String> {
let ty = self.get_type(generator, ctx.ctx).as_basic_type_enum(); let ty = self.llvm_type(generator, ctx.ctx).as_basic_type_enum();
let p = generator.gen_var_alloc(ctx, ty, name)?; let p = generator.gen_var_alloc(ctx, ty, name)?;
Ok(Ptr(*self).believe_value(p)) unsafe { Ok(Ptr(*self).believe_value(p)) }
} }
fn array_var_alloca<G: CodeGenerator + ?Sized>( fn array_var_alloca<G: CodeGenerator + ?Sized>(
@ -153,9 +157,9 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
name: Option<&'ctx str>, name: Option<&'ctx str>,
) -> Result<Instance<'ctx, Ptr<Self>>, String> { ) -> Result<Instance<'ctx, Ptr<Self>>, String> {
// TODO: Remove ArraySliceValue // TODO: Remove ArraySliceValue
let ty = self.get_type(generator, ctx.ctx).as_basic_type_enum(); let ty = self.llvm_type(generator, ctx.ctx).as_basic_type_enum();
let p = generator.gen_array_var_alloc(ctx, ty, len, name)?; let p = generator.gen_array_var_alloc(ctx, ty, len, name)?;
Ok(Ptr(*self).believe_value(PointerValue::from(p))) unsafe { Ok(Ptr(*self).believe_value(PointerValue::from(p))) }
} }
/// Allocate a constant array. /// Allocate a constant array.
@ -176,7 +180,7 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
}; };
} }
let value = match self.get_type(generator, ctx).as_basic_type_enum() { let value = match self.llvm_type(generator, ctx).as_basic_type_enum() {
BasicTypeEnum::ArrayType(t) => make!(t, BasicValueEnum::into_array_value), BasicTypeEnum::ArrayType(t) => make!(t, BasicValueEnum::into_array_value),
BasicTypeEnum::IntType(t) => make!(t, BasicValueEnum::into_int_value), BasicTypeEnum::IntType(t) => make!(t, BasicValueEnum::into_int_value),
BasicTypeEnum::FloatType(t) => make!(t, BasicValueEnum::into_float_value), BasicTypeEnum::FloatType(t) => make!(t, BasicValueEnum::into_float_value),
@ -195,6 +199,7 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
pub struct Instance<'ctx, M: Model<'ctx>> { pub struct Instance<'ctx, M: Model<'ctx>> {
/// The model of this instance. /// The model of this instance.
pub model: M, pub model: M,
/// The value of this instance. /// The value of this instance.
/// ///
/// It is guaranteed the [`BasicType`] of `value` is consistent with that of `model`. /// It is guaranteed the [`BasicType`] of `value` is consistent with that of `model`.

View File

@ -63,7 +63,11 @@ impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float<N> {
type Value = FloatValue<'ctx>; type Value = FloatValue<'ctx>;
type Type = FloatType<'ctx>; type Type = FloatType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { fn llvm_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
self.0.get_float_type(generator, ctx) self.0.get_float_type(generator, ctx)
} }

View File

@ -35,7 +35,7 @@ struct Arg<'ctx> {
/// If `my_function_name` has not been declared in `ctx.module`, once `.returning()` is called, a function /// If `my_function_name` has not been declared in `ctx.module`, once `.returning()` is called, a function
/// declaration of `my_function_name` is added to `ctx.module`, where the [`FunctionType`] is deduced from /// declaration of `my_function_name` is added to `ctx.module`, where the [`FunctionType`] is deduced from
/// the argument types and returning type. /// the argument types and returning type.
pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> { pub struct FnCall<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> {
generator: &'d mut G, generator: &'d mut G,
ctx: &'b CodeGenContext<'ctx, 'a>, ctx: &'b CodeGenContext<'ctx, 'a>,
/// Function name /// Function name
@ -46,9 +46,9 @@ pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> {
attrs: Vec<&'static str>, attrs: Vec<&'static str>,
} }
impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> { impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> FnCall<'ctx, 'a, 'b, 'c, 'd, G> {
pub fn begin(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self { pub fn builder(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self {
CallFunction { generator, ctx, name, args: Vec::new(), attrs: Vec::new() } FnCall { generator, ctx, name, args: Vec::new(), attrs: Vec::new() }
} }
/// Push a list of LLVM function attributes to the function declaration. /// Push a list of LLVM function attributes to the function declaration.
@ -63,7 +63,7 @@ impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b,
#[must_use] #[must_use]
pub fn arg<M: Model<'ctx>>(mut self, arg: Instance<'ctx, M>) -> Self { pub fn arg<M: Model<'ctx>>(mut self, arg: Instance<'ctx, M>) -> Self {
let arg = Arg { let arg = Arg {
ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(), ty: arg.model.llvm_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(),
val: arg.value.as_basic_value_enum().into(), val: arg.value.as_basic_value_enum().into(),
}; };
self.args.push(arg); self.args.push(arg);
@ -73,7 +73,7 @@ impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b,
/// Call the function and expect the function to return a value of type of `return_model`. /// Call the function and expect the function to return a value of type of `return_model`.
#[must_use] #[must_use]
pub fn returning<M: Model<'ctx>>(self, name: &str, return_model: M) -> Instance<'ctx, M> { pub fn returning<M: Model<'ctx>>(self, name: &str, return_model: M) -> Instance<'ctx, M> {
let ret_ty = return_model.get_type(self.generator, self.ctx.ctx); let ret_ty = return_model.llvm_type(self.generator, self.ctx.ctx);
let ret = self.call(|tys| ret_ty.fn_type(tys, false), name); let ret = self.call(|tys| ret_ty.fn_type(tys, false), name);
let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work

View File

@ -100,7 +100,11 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int<N> {
type Value = IntValue<'ctx>; type Value = IntValue<'ctx>;
type Type = IntType<'ctx>; type Type = IntType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { fn llvm_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
self.0.get_int_type(generator, ctx) self.0.get_int_type(generator, ctx)
} }
@ -134,9 +138,10 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
generator: &mut G, generator: &mut G,
ctx: &'ctx Context, ctx: &'ctx Context,
value: u64, value: u64,
sign_extend: bool,
) -> Instance<'ctx, Self> { ) -> Instance<'ctx, Self> {
let value = self.get_type(generator, ctx).const_int(value, false); let value = self.llvm_type(generator, ctx).const_int(value, sign_extend);
self.believe_value(value) unsafe { self.believe_value(value) }
} }
pub fn const_0<G: CodeGenerator + ?Sized>( pub fn const_0<G: CodeGenerator + ?Sized>(
@ -144,8 +149,8 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
generator: &mut G, generator: &mut G,
ctx: &'ctx Context, ctx: &'ctx Context,
) -> Instance<'ctx, Self> { ) -> Instance<'ctx, Self> {
let value = self.get_type(generator, ctx).const_zero(); let value = self.llvm_type(generator, ctx).const_zero();
self.believe_value(value) unsafe { self.believe_value(value) }
} }
pub fn const_1<G: CodeGenerator + ?Sized>( pub fn const_1<G: CodeGenerator + ?Sized>(
@ -153,7 +158,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
generator: &mut G, generator: &mut G,
ctx: &'ctx Context, ctx: &'ctx Context,
) -> Instance<'ctx, Self> { ) -> Instance<'ctx, Self> {
self.const_int(generator, ctx, 1) self.const_int(generator, ctx, 1, false)
} }
pub fn const_all_ones<G: CodeGenerator + ?Sized>( pub fn const_all_ones<G: CodeGenerator + ?Sized>(
@ -161,8 +166,8 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
generator: &mut G, generator: &mut G,
ctx: &'ctx Context, ctx: &'ctx Context,
) -> Instance<'ctx, Self> { ) -> Instance<'ctx, Self> {
let value = self.get_type(generator, ctx).const_all_ones(); let value = self.llvm_type(generator, ctx).const_all_ones();
self.believe_value(value) unsafe { self.believe_value(value) }
} }
pub fn s_extend_or_bit_cast<G: CodeGenerator + ?Sized>( pub fn s_extend_or_bit_cast<G: CodeGenerator + ?Sized>(
@ -177,9 +182,9 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
); );
let value = ctx let value = ctx
.builder .builder
.build_int_s_extend_or_bit_cast(value, self.get_type(generator, ctx.ctx), "") .build_int_s_extend_or_bit_cast(value, self.llvm_type(generator, ctx.ctx), "")
.unwrap(); .unwrap();
self.believe_value(value) unsafe { self.believe_value(value) }
} }
pub fn s_extend<G: CodeGenerator + ?Sized>( pub fn s_extend<G: CodeGenerator + ?Sized>(
@ -193,8 +198,8 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
< self.0.get_int_type(generator, ctx.ctx).get_bit_width() < self.0.get_int_type(generator, ctx.ctx).get_bit_width()
); );
let value = let value =
ctx.builder.build_int_s_extend(value, self.get_type(generator, ctx.ctx), "").unwrap(); ctx.builder.build_int_s_extend(value, self.llvm_type(generator, ctx.ctx), "").unwrap();
self.believe_value(value) unsafe { self.believe_value(value) }
} }
pub fn z_extend_or_bit_cast<G: CodeGenerator + ?Sized>( pub fn z_extend_or_bit_cast<G: CodeGenerator + ?Sized>(
@ -209,9 +214,9 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
); );
let value = ctx let value = ctx
.builder .builder
.build_int_z_extend_or_bit_cast(value, self.get_type(generator, ctx.ctx), "") .build_int_z_extend_or_bit_cast(value, self.llvm_type(generator, ctx.ctx), "")
.unwrap(); .unwrap();
self.believe_value(value) unsafe { self.believe_value(value) }
} }
pub fn z_extend<G: CodeGenerator + ?Sized>( pub fn z_extend<G: CodeGenerator + ?Sized>(
@ -225,8 +230,8 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
< self.0.get_int_type(generator, ctx.ctx).get_bit_width() < self.0.get_int_type(generator, ctx.ctx).get_bit_width()
); );
let value = let value =
ctx.builder.build_int_z_extend(value, self.get_type(generator, ctx.ctx), "").unwrap(); ctx.builder.build_int_z_extend(value, self.llvm_type(generator, ctx.ctx), "").unwrap();
self.believe_value(value) unsafe { self.believe_value(value) }
} }
pub fn truncate_or_bit_cast<G: CodeGenerator + ?Sized>( pub fn truncate_or_bit_cast<G: CodeGenerator + ?Sized>(
@ -241,9 +246,9 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
); );
let value = ctx let value = ctx
.builder .builder
.build_int_truncate_or_bit_cast(value, self.get_type(generator, ctx.ctx), "") .build_int_truncate_or_bit_cast(value, self.llvm_type(generator, ctx.ctx), "")
.unwrap(); .unwrap();
self.believe_value(value) unsafe { self.believe_value(value) }
} }
pub fn truncate<G: CodeGenerator + ?Sized>( pub fn truncate<G: CodeGenerator + ?Sized>(
@ -257,8 +262,8 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
> self.0.get_int_type(generator, ctx.ctx).get_bit_width() > self.0.get_int_type(generator, ctx.ctx).get_bit_width()
); );
let value = let value =
ctx.builder.build_int_truncate(value, self.get_type(generator, ctx.ctx), "").unwrap(); ctx.builder.build_int_truncate(value, self.llvm_type(generator, ctx.ctx), "").unwrap();
self.believe_value(value) unsafe { self.believe_value(value) }
} }
/// `sext` or `trunc` an int to this model's int type. Does nothing if equal bit-widths. /// `sext` or `trunc` an int to this model's int type. Does nothing if equal bit-widths.
@ -272,7 +277,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width(); let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width();
match their_width.cmp(&our_width) { match their_width.cmp(&our_width) {
Ordering::Less => self.s_extend(generator, ctx, value), Ordering::Less => self.s_extend(generator, ctx, value),
Ordering::Equal => self.believe_value(value), Ordering::Equal => unsafe { self.believe_value(value) },
Ordering::Greater => self.truncate(generator, ctx, value), Ordering::Greater => self.truncate(generator, ctx, value),
} }
} }
@ -288,7 +293,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width(); let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width();
match their_width.cmp(&our_width) { match their_width.cmp(&our_width) {
Ordering::Less => self.z_extend(generator, ctx, value), Ordering::Less => self.z_extend(generator, ctx, value),
Ordering::Equal => self.believe_value(value), Ordering::Equal => unsafe { self.believe_value(value) },
Ordering::Greater => self.truncate(generator, ctx, value), Ordering::Greater => self.truncate(generator, ctx, value),
} }
} }
@ -301,7 +306,7 @@ impl Int<Bool> {
generator: &mut G, generator: &mut G,
ctx: &'ctx Context, ctx: &'ctx Context,
) -> Instance<'ctx, Self> { ) -> Instance<'ctx, Self> {
self.const_int(generator, ctx, 0) self.const_int(generator, ctx, 0, false)
} }
#[must_use] #[must_use]
@ -310,7 +315,7 @@ impl Int<Bool> {
generator: &mut G, generator: &mut G,
ctx: &'ctx Context, ctx: &'ctx Context,
) -> Instance<'ctx, Self> { ) -> Instance<'ctx, Self> {
self.const_int(generator, ctx, 1) self.const_int(generator, ctx, 1, false)
} }
} }
@ -390,19 +395,19 @@ impl<'ctx, N: IntKind<'ctx>> Instance<'ctx, Int<N>> {
#[must_use] #[must_use]
pub fn add(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self { pub fn add(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self {
let value = ctx.builder.build_int_add(self.value, other.value, "").unwrap(); let value = ctx.builder.build_int_add(self.value, other.value, "").unwrap();
self.model.believe_value(value) unsafe { self.model.believe_value(value) }
} }
#[must_use] #[must_use]
pub fn sub(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self { pub fn sub(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self {
let value = ctx.builder.build_int_sub(self.value, other.value, "").unwrap(); let value = ctx.builder.build_int_sub(self.value, other.value, "").unwrap();
self.model.believe_value(value) unsafe { self.model.believe_value(value) }
} }
#[must_use] #[must_use]
pub fn mul(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self { pub fn mul(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self {
let value = ctx.builder.build_int_mul(self.value, other.value, "").unwrap(); let value = ctx.builder.build_int_mul(self.value, other.value, "").unwrap();
self.model.believe_value(value) unsafe { self.model.believe_value(value) }
} }
pub fn compare( pub fn compare(
@ -412,6 +417,6 @@ impl<'ctx, N: IntKind<'ctx>> Instance<'ctx, Int<N>> {
other: Self, other: Self,
) -> Instance<'ctx, Int<Bool>> { ) -> Instance<'ctx, Int<Bool>> {
let value = ctx.builder.build_int_compare(op, self.value, other.value, "").unwrap(); let value = ctx.builder.build_int_compare(op, self.value, other.value, "").unwrap();
Int(Bool).believe_value(value) unsafe { Int(Bool).believe_value(value) }
} }
} }

View File

@ -31,9 +31,13 @@ impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr<Item> {
type Value = PointerValue<'ctx>; type Value = PointerValue<'ctx>;
type Type = PointerType<'ctx>; type Type = PointerType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { fn llvm_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
// TODO: LLVM 15: ctx.ptr_type(AddressSpace::default()) // TODO: LLVM 15: ctx.ptr_type(AddressSpace::default())
self.0.get_type(generator, ctx).ptr_type(AddressSpace::default()) self.0.llvm_type(generator, ctx).ptr_type(AddressSpace::default())
} }
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>( fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
@ -71,8 +75,8 @@ impl<'ctx, Item: Model<'ctx>> Ptr<Item> {
generator: &mut G, generator: &mut G,
ctx: &'ctx Context, ctx: &'ctx Context,
) -> Instance<'ctx, Ptr<Item>> { ) -> Instance<'ctx, Ptr<Item>> {
let ptr = self.get_type(generator, ctx).const_null(); let ptr = self.llvm_type(generator, ctx).const_null();
self.believe_value(ptr) unsafe { self.believe_value(ptr) }
} }
/// Cast a pointer into this model with [`inkwell::builder::Builder::build_pointer_cast`] /// Cast a pointer into this model with [`inkwell::builder::Builder::build_pointer_cast`]
@ -87,9 +91,9 @@ impl<'ctx, Item: Model<'ctx>> Ptr<Item> {
// ``` // ```
// return self.believe_value(ptr); // return self.believe_value(ptr);
// ``` // ```
let t = self.get_type(generator, ctx.ctx); let t = self.llvm_type(generator, ctx.ctx);
let ptr = ctx.builder.build_pointer_cast(ptr, t, "").unwrap(); let ptr = ctx.builder.build_pointer_cast(ptr, t, "").unwrap();
self.believe_value(ptr) unsafe { self.believe_value(ptr) }
} }
} }
@ -102,7 +106,7 @@ impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Item>> {
offset: IntValue<'ctx>, offset: IntValue<'ctx>,
) -> Instance<'ctx, Ptr<Item>> { ) -> Instance<'ctx, Ptr<Item>> {
let p = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[offset], "").unwrap() }; let p = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[offset], "").unwrap() };
self.model.believe_value(p) unsafe { self.model.believe_value(p) }
} }
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset. /// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset.
@ -110,9 +114,9 @@ impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Item>> {
pub fn offset_const( pub fn offset_const(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
offset: u64, offset: i64,
) -> Instance<'ctx, Ptr<Item>> { ) -> Instance<'ctx, Ptr<Item>> {
let offset = ctx.ctx.i32_type().const_int(offset, false); let offset = ctx.ctx.i32_type().const_int(offset as u64, true);
self.offset(ctx, offset) self.offset(ctx, offset)
} }
@ -128,7 +132,7 @@ impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Item>> {
pub fn set_index_const( pub fn set_index_const(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
index: u64, index: i64,
value: Instance<'ctx, Item>, value: Instance<'ctx, Item>,
) { ) {
self.offset_const(ctx, index).store(ctx, value); self.offset_const(ctx, index).store(ctx, value);
@ -147,7 +151,7 @@ impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Item>> {
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
index: u64, index: i64,
) -> Instance<'ctx, Item> { ) -> Instance<'ctx, Item> {
self.offset_const(ctx, index).load(generator, ctx) self.offset_const(ctx, index).load(generator, ctx)
} }
@ -190,13 +194,13 @@ impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Item>> {
/// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`]. /// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`].
pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int<Bool>> { pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int<Bool>> {
let value = ctx.builder.build_is_null(self.value, "").unwrap(); let value = ctx.builder.build_is_null(self.value, "").unwrap();
Int(Bool).believe_value(value) unsafe { Int(Bool).believe_value(value) }
} }
/// Check if the pointer is not null with [`inkwell::builder::Builder::build_is_not_null`]. /// Check if the pointer is not null with [`inkwell::builder::Builder::build_is_not_null`].
pub fn is_not_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int<Bool>> { pub fn is_not_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int<Bool>> {
let value = ctx.builder.build_is_not_null(self.value, "").unwrap(); let value = ctx.builder.build_is_not_null(self.value, "").unwrap();
Int(Bool).believe_value(value) unsafe { Int(Bool).believe_value(value) }
} }
/// `memcpy` from another pointer. /// `memcpy` from another pointer.
@ -208,9 +212,9 @@ impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Item>> {
num_items: IntValue<'ctx>, num_items: IntValue<'ctx>,
) { ) {
// Force extend `num_items` and `itemsize` to `i64` so their types would match. // Force extend `num_items` and `itemsize` to `i64` so their types would match.
let itemsize = self.model.sizeof(generator, ctx.ctx); let itemsize = self.model.size_of(generator, ctx.ctx);
let itemsize = Int(Int64).z_extend_or_truncate(generator, ctx, itemsize); let itemsize = Int(SizeT).z_extend_or_truncate(generator, ctx, itemsize);
let num_items = Int(Int64).z_extend_or_truncate(generator, ctx, num_items); let num_items = Int(SizeT).z_extend_or_truncate(generator, ctx, num_items);
let totalsize = itemsize.mul(ctx, num_items); let totalsize = itemsize.mul(ctx, num_items);
let is_volatile = ctx.ctx.bool_type().const_zero(); // is_volatile = false let is_volatile = ctx.ctx.bool_type().const_zero(); // is_volatile = false

View File

@ -13,16 +13,16 @@ use super::*;
/// A traveral that traverses a Rust `struct` that is used to declare an LLVM's struct's field types. /// A traveral that traverses a Rust `struct` that is used to declare an LLVM's struct's field types.
pub trait FieldTraversal<'ctx> { pub trait FieldTraversal<'ctx> {
/// Output type of [`FieldTraversal::add`]. /// Output type of [`FieldTraversal::add`].
type Out<M>; type Output<M>;
/// Traverse through the type of a declared field and do something with it. /// Traverse through the type of a declared field and do something with it.
/// ///
/// * `name` - The cosmetic name of the LLVM field. Used for debugging. /// * `name` - The cosmetic name of the LLVM field. Used for debugging.
/// * `model` - The [`Model`] representing the LLVM type of this field. /// * `model` - The [`Model`] representing the LLVM type of this field.
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M>; fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Output<M>;
/// Like [`FieldTraversal::add`] but [`Model`] is automatically inferred from its [`Default`] trait. /// Like [`FieldTraversal::add`] but [`Model`] is automatically inferred from its [`Default`] trait.
fn add_auto<M: Model<'ctx> + Default>(&mut self, name: &'static str) -> Self::Out<M> { fn add_auto<M: Model<'ctx> + Default>(&mut self, name: &'static str) -> Self::Output<M> {
self.add(name, M::default()) self.add(name, M::default())
} }
} }
@ -31,7 +31,7 @@ pub trait FieldTraversal<'ctx> {
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct GepField<M> { pub struct GepField<M> {
/// The GEP index of this field. This is the index to use with `build_gep`. /// The GEP index of this field. This is the index to use with `build_gep`.
pub gep_index: u64, pub gep_index: u32,
/// The cosmetic name of this field. /// The cosmetic name of this field.
pub name: &'static str, pub name: &'static str,
/// The [`Model`] of this field's type. /// The [`Model`] of this field's type.
@ -41,16 +41,16 @@ pub struct GepField<M> {
/// A traversal to calculate the GEP index of fields. /// A traversal to calculate the GEP index of fields.
pub struct GepFieldTraversal { pub struct GepFieldTraversal {
/// The current GEP index. /// The current GEP index.
gep_index_counter: u64, gep_index_counter: u32,
} }
impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal { impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal {
type Out<M> = GepField<M>; type Output<M> = GepField<M>;
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M> { fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Output<M> {
let gep_index = self.gep_index_counter; let gep_index = self.gep_index_counter;
self.gep_index_counter += 1; self.gep_index_counter += 1;
Self::Out { gep_index, name, model } Self::Output { gep_index, name, model }
} }
} }
@ -65,10 +65,10 @@ struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
} }
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> { impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> {
type Out<M> = (); // Checking types return nothing. type Output<M> = (); // Checking types return nothing.
fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Out<M> { fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Output<M> {
let t = model.get_type(self.generator, self.ctx).as_basic_type_enum(); let t = model.llvm_type(self.generator, self.ctx).as_basic_type_enum();
self.field_types.push(t); self.field_types.push(t);
} }
} }
@ -89,9 +89,9 @@ struct CheckTypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx>
for CheckTypeFieldTraversal<'ctx, 'a, G> for CheckTypeFieldTraversal<'ctx, 'a, G>
{ {
type Out<M> = (); // Checking types return nothing. type Output<M> = (); // Checking types return nothing.
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M> { fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Output<M> {
let gep_index = self.gep_index_counter; let gep_index = self.gep_index_counter;
self.gep_index_counter += 1; self.gep_index_counter += 1;
@ -100,7 +100,8 @@ impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx>
self.errors self.errors
.push(err.under_context(format!("field #{gep_index} '{name}'").as_str())); .push(err.under_context(format!("field #{gep_index} '{name}'").as_str()));
} }
} // Otherwise, it will be caught by Struct's `check_type`. }
// Otherwise, it will be caught by Struct's `check_type`.
} }
} }
@ -192,13 +193,13 @@ pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy {
/// Traverse through all fields of this [`StructKind`]. /// Traverse through all fields of this [`StructKind`].
/// ///
/// Only used internally in this module for implementing other components. /// Only used internally in this module for implementing other components.
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F>; fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F>;
/// Get a convenience structure to get a struct field's GEP index through its corresponding Rust field. /// Get a convenience structure to get a struct field's GEP index through its corresponding Rust field.
/// ///
/// Only used internally in this module for implementing other components. /// Only used internally in this module for implementing other components.
fn fields(&self) -> Self::Fields<GepFieldTraversal> { fn fields(&self) -> Self::Fields<GepFieldTraversal> {
self.traverse_fields(&mut GepFieldTraversal { gep_index_counter: 0 }) self.iter_fields(&mut GepFieldTraversal { gep_index_counter: 0 })
} }
/// Get the LLVM [`StructType`] of this [`StructKind`]. /// Get the LLVM [`StructType`] of this [`StructKind`].
@ -208,7 +209,7 @@ pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy {
ctx: &'ctx Context, ctx: &'ctx Context,
) -> StructType<'ctx> { ) -> StructType<'ctx> {
let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() }; let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() };
self.traverse_fields(&mut traversal); self.iter_fields(&mut traversal);
ctx.struct_type(&traversal.field_types, false) ctx.struct_type(&traversal.field_types, false)
} }
@ -242,7 +243,11 @@ impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct<S> {
type Value = StructValue<'ctx>; type Value = StructValue<'ctx>;
type Type = StructType<'ctx>; type Type = StructType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { fn llvm_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
self.0.get_struct_type(generator, ctx) self.0.get_struct_type(generator, ctx)
} }
@ -265,7 +270,7 @@ impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct<S> {
errors: Vec::new(), errors: Vec::new(),
scrutinee: ty, scrutinee: ty,
}; };
self.0.traverse_fields(&mut traversal); self.0.iter_fields(&mut traversal);
// Check the number of fields. // Check the number of fields.
let exp_num_fields = traversal.gep_index_counter; let exp_num_fields = traversal.gep_index_counter;
@ -298,7 +303,7 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct<S>> {
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>, GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{ {
let field = get_field(self.model.0.fields()); let field = get_field(self.model.0.fields());
let val = self.value.get_field_at_index(field.gep_index as u32).unwrap(); let val = self.value.get_field_at_index(field.gep_index).unwrap();
field.model.check_value(generator, ctx, val).unwrap() field.model.check_value(generator, ctx, val).unwrap()
} }
} }
@ -321,13 +326,13 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
ctx.builder ctx.builder
.build_in_bounds_gep( .build_in_bounds_gep(
self.value, self.value,
&[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)], &[llvm_i32.const_zero(), llvm_i32.const_int(u64::from(field.gep_index), false)],
field.name, field.name,
) )
.unwrap() .unwrap()
}; };
Ptr(field.model).believe_value(ptr) unsafe { Ptr(field.model).believe_value(ptr) }
} }
/// Convenience function equivalent to `.gep(...).load(...)`. /// Convenience function equivalent to `.gep(...).load(...)`.

View File

@ -34,7 +34,7 @@ where
start.value, start.value,
(stop.value, false), (stop.value, false),
|g, ctx, hooks, i| { |g, ctx, hooks, i| {
let i = int_model.believe_value(i); let i = unsafe { int_model.believe_value(i) };
body(g, ctx, hooks, i) body(g, ctx, hooks, i)
}, },
step.value, step.value,

View File

@ -12,6 +12,7 @@ use crate::{
call_ndarray_calc_size, call_ndarray_calc_size,
}, },
llvm_intrinsics::{self, call_memcpy_generic}, llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable,
model::*, model::*,
object::{ object::{
any::AnyObject, any::AnyObject,
@ -264,7 +265,7 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "").into() ctx.gen_string(generator, "").into()
} else { } else {
unreachable!() codegen_unreachable!(ctx)
} }
} }
@ -292,7 +293,7 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1").into() ctx.gen_string(generator, "1").into()
} else { } else {
unreachable!() codegen_unreachable!(ctx)
} }
} }
@ -360,7 +361,7 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
} }
_ => unreachable!(), _ => codegen_unreachable!(ctx),
} }
} }
@ -631,7 +632,7 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
} else if fill_value.is_int_value() || fill_value.is_float_value() { } else if fill_value.is_int_value() || fill_value.is_float_value() {
fill_value fill_value
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
Ok(value) Ok(value)
@ -2051,7 +2052,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
.build_float_mul(e1, elem2.into_float_value(), "") .build_float_mul(e1, elem2.into_float_value(), "")
.unwrap() .unwrap()
.as_basic_value_enum(), .as_basic_value_enum(),
_ => unreachable!(), _ => codegen_unreachable!(ctx),
}; };
let acc_val = ctx.builder.build_load(acc, "").unwrap(); let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val { let acc_val = match acc_val {
@ -2065,7 +2066,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
.build_float_add(e1, product.into_float_value(), "") .build_float_add(e1, product.into_float_value(), "")
.unwrap() .unwrap()
.as_basic_value_enum(), .as_basic_value_enum(),
_ => unreachable!(), _ => codegen_unreachable!(ctx),
}; };
ctx.builder.build_store(acc, acc_val).unwrap(); ctx.builder.build_store(acc, acc_val).unwrap();
@ -2082,7 +2083,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => { (BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum()) Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
} }
_ => unreachable!( _ => codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'", "{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty)) format!("'{}'", ctx.unifier.stringify(x1_ty))
), ),

View File

@ -8,9 +8,9 @@ use super::any::AnyObject;
/// Fields of [`List`] /// Fields of [`List`]
pub struct ListFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> { pub struct ListFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
/// Array pointer to content /// Array pointer to content
pub items: F::Out<Ptr<Item>>, pub items: F::Output<Ptr<Item>>,
/// Number of items in the array /// Number of items in the array
pub len: F::Out<Int<SizeT>>, pub len: F::Output<Int<SizeT>>,
} }
/// A list in NAC3. /// A list in NAC3.
@ -23,7 +23,7 @@ pub struct List<Item> {
impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for List<Item> { impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for List<Item> {
type Fields<F: FieldTraversal<'ctx>> = ListFields<'ctx, F, Item>; type Fields<F: FieldTraversal<'ctx>> = ListFields<'ctx, F, Item>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> { fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { Self::Fields {
items: traversal.add("items", Ptr(self.item)), items: traversal.add("items", Ptr(self.item)),
len: traversal.add_auto("len"), len: traversal.add_auto("len"),

View File

@ -40,7 +40,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// Validate `list` has a consistent shape. // Validate `list` has a consistent shape.
// Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`.
// If `list` has a consistent shape, deduce the shape and write it to `shape`. // If `list` has a consistent shape, deduce the shape and write it to `shape`.
let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int); let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int, false);
let shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); let shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
call_nac3_ndarray_array_set_and_validate_list_shape( call_nac3_ndarray_array_set_and_validate_list_shape(
generator, ctx, list_value, ndims, shape, generator, ctx, list_value, ndims, shape,

View File

@ -10,8 +10,8 @@ use super::NDArrayObject;
/// Fields of [`ShapeEntry`] /// Fields of [`ShapeEntry`]
pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> { pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Out<Int<SizeT>>, pub ndims: F::Output<Int<SizeT>>,
pub shape: F::Out<Ptr<Int<SizeT>>>, pub shape: F::Output<Ptr<Int<SizeT>>>,
} }
/// An IRRT structure used in broadcasting. /// An IRRT structure used in broadcasting.
@ -21,7 +21,7 @@ pub struct ShapeEntry;
impl<'ctx> StructKind<'ctx> for ShapeEntry { impl<'ctx> StructKind<'ctx> for ShapeEntry {
type Fields<F: FieldTraversal<'ctx>> = ShapeEntryFields<'ctx, F>; type Fields<F: FieldTraversal<'ctx>> = ShapeEntryFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> { fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") } Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") }
} }
} }
@ -73,19 +73,23 @@ fn broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
broadcast_shape: Instance<'ctx, Ptr<Int<SizeT>>>, broadcast_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) { ) {
// Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`. // Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`.
let num_shape_entries = let num_shape_entries = Int(SizeT).const_int(
Int(SizeT).const_int(generator, ctx.ctx, u64::try_from(in_shape_entries.len()).unwrap()); generator,
ctx.ctx,
u64::try_from(in_shape_entries.len()).unwrap(),
false,
);
let shape_entries = Struct(ShapeEntry).array_alloca(generator, ctx, num_shape_entries.value); let shape_entries = Struct(ShapeEntry).array_alloca(generator, ctx, num_shape_entries.value);
for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() { for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() {
let pshape_entry = shape_entries.offset_const(ctx, i as u64); let pshape_entry = shape_entries.offset_const(ctx, i64::try_from(i).unwrap());
let in_ndims = Int(SizeT).const_int(generator, ctx.ctx, *in_ndims); let in_ndims = Int(SizeT).const_int(generator, ctx.ctx, *in_ndims, false);
pshape_entry.set(ctx, |f| f.ndims, in_ndims); pshape_entry.set(ctx, |f| f.ndims, in_ndims);
pshape_entry.set(ctx, |f| f.shape, *in_shape); pshape_entry.set(ctx, |f| f.shape, *in_shape);
} }
let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims); let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims, false);
call_nac3_ndarray_broadcast_shapes( call_nac3_ndarray_broadcast_shapes(
generator, generator,
ctx, ctx,
@ -109,7 +113,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// Infer the broadcast output ndims. // Infer the broadcast output ndims.
let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap(); let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap();
let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims_int); let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims_int, false);
let broadcast_shape = Int(SizeT).array_alloca(generator, ctx, broadcast_ndims.value); let broadcast_shape = Int(SizeT).array_alloca(generator, ctx, broadcast_ndims.value);
let shape_entries = ndarrays let shape_entries = ndarrays

View File

@ -76,7 +76,7 @@ impl<'ctx> NDArrayObject<'ctx> {
shape: Instance<'ctx, Ptr<Int<SizeT>>>, shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) -> Self { ) -> Self {
// Validate `shape` // Validate `shape`
let ndims_llvm = Int(SizeT).const_int(generator, ctx.ctx, ndims); let ndims_llvm = Int(SizeT).const_int(generator, ctx.ctx, ndims, false);
call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape); call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims); let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims);

View File

@ -12,8 +12,8 @@ pub type NDIndexType = Byte;
/// Fields of [`NDIndex`] /// Fields of [`NDIndex`]
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct NDIndexFields<'ctx, F: FieldTraversal<'ctx>> { pub struct NDIndexFields<'ctx, F: FieldTraversal<'ctx>> {
pub type_: F::Out<Int<NDIndexType>>, // Defined to be uint8_t in IRRT pub type_: F::Output<Int<NDIndexType>>,
pub data: F::Out<Ptr<Int<Byte>>>, pub data: F::Output<Ptr<Int<Byte>>>,
} }
/// An IRRT representation of an ndarray subscript index. /// An IRRT representation of an ndarray subscript index.
@ -23,7 +23,7 @@ pub struct NDIndex;
impl<'ctx> StructKind<'ctx> for NDIndex { impl<'ctx> StructKind<'ctx> for NDIndex {
type Fields<F: FieldTraversal<'ctx>> = NDIndexFields<'ctx, F>; type Fields<F: FieldTraversal<'ctx>> = NDIndexFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> { fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { type_: traversal.add_auto("type"), data: traversal.add_auto("data") } Self::Fields { type_: traversal.add_auto("type"), data: traversal.add_auto("data") }
} }
} }
@ -49,7 +49,7 @@ impl<'ctx> RustNDIndex<'ctx> {
} }
} }
/// Write the contents to an LLVM [`NDIndex`]. /// Serialize this [`RustNDIndex`] by writing it into an LLVM [`NDIndex`].
fn write_to_ndindex<G: CodeGenerator + ?Sized>( fn write_to_ndindex<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
@ -59,7 +59,7 @@ impl<'ctx> RustNDIndex<'ctx> {
// Set `dst_ndindex_ptr->type` // Set `dst_ndindex_ptr->type`
dst_ndindex_ptr.gep(ctx, |f| f.type_).store( dst_ndindex_ptr.gep(ctx, |f| f.type_).store(
ctx, ctx,
Int(NDIndexType::default()).const_int(generator, ctx.ctx, self.get_type_id()), Int(NDIndexType::default()).const_int(generator, ctx.ctx, self.get_type_id(), false),
); );
// Set `dst_ndindex_ptr->data` // Set `dst_ndindex_ptr->data`
@ -84,7 +84,7 @@ impl<'ctx> RustNDIndex<'ctx> {
} }
} }
/// Allocate an array of `NDIndex`es on the stack and return the array pointer. /// Serialize a list of `RustNDIndex` as a newly allocated LLVM array of `NDIndex`.
pub fn make_ndindices<G: CodeGenerator + ?Sized>( pub fn make_ndindices<G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
@ -92,10 +92,14 @@ impl<'ctx> RustNDIndex<'ctx> {
) -> (Instance<'ctx, Int<SizeT>>, Instance<'ctx, Ptr<Struct<NDIndex>>>) { ) -> (Instance<'ctx, Int<SizeT>>, Instance<'ctx, Ptr<Struct<NDIndex>>>) {
let ndindex_model = Struct(NDIndex); let ndindex_model = Struct(NDIndex);
let num_ndindices = Int(SizeT).const_int(generator, ctx.ctx, in_ndindices.len() as u64); // Allocate the LLVM ndindices.
let num_ndindices =
Int(SizeT).const_int(generator, ctx.ctx, in_ndindices.len() as u64, false);
let ndindices = ndindex_model.array_alloca(generator, ctx, num_ndindices.value); let ndindices = ndindex_model.array_alloca(generator, ctx, num_ndindices.value);
// Initialize all of them.
for (i, in_ndindex) in in_ndindices.iter().enumerate() { for (i, in_ndindex) in in_ndindices.iter().enumerate() {
let pndindex = ndindices.offset_const(ctx, i as u64); let pndindex = ndindices.offset_const(ctx, i64::try_from(i).unwrap());
in_ndindex.write_to_ndindex(generator, ctx, pndindex); in_ndindex.write_to_ndindex(generator, ctx, pndindex);
} }
@ -155,10 +159,7 @@ pub mod util {
use nac3parser::ast::{Expr, ExprKind}; use nac3parser::ast::{Expr, ExprKind};
use crate::{ use crate::{
codegen::{ codegen::{model::*, object::utils::slice::util::gen_slice, CodeGenContext, CodeGenerator},
expr::gen_slice, model::*, object::utils::slice::RustSlice, CodeGenContext,
CodeGenerator,
},
typecheck::typedef::Type, typecheck::typedef::Type,
}; };
@ -206,8 +207,8 @@ pub mod util {
// so the code/implementation looks awkward - we have to do pattern matching on the expression // so the code/implementation looks awkward - we have to do pattern matching on the expression
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node { let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
// Handle slices // Handle slices
let (lower, upper, step) = gen_slice(generator, ctx, lower, upper, step)?; let slice = gen_slice(generator, ctx, lower, upper, step)?;
RustNDIndex::Slice(RustSlice { int_kind: Int32, start: lower, stop: upper, step }) RustNDIndex::Slice(slice)
} else { } else {
// Treat and handle everything else as a single element index. // Treat and handle everything else as a single element index.
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum( let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(

View File

@ -32,11 +32,11 @@ use super::{any::AnyObject, tuple::TupleObject};
/// Fields of [`NDArray`] /// Fields of [`NDArray`]
pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> { pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> {
pub data: F::Out<Ptr<Int<Byte>>>, pub data: F::Output<Ptr<Int<Byte>>>,
pub itemsize: F::Out<Int<SizeT>>, pub itemsize: F::Output<Int<SizeT>>,
pub ndims: F::Out<Int<SizeT>>, pub ndims: F::Output<Int<SizeT>>,
pub shape: F::Out<Ptr<Int<SizeT>>>, pub shape: F::Output<Ptr<Int<SizeT>>>,
pub strides: F::Out<Ptr<Int<SizeT>>>, pub strides: F::Output<Ptr<Int<SizeT>>>,
} }
/// A strided ndarray in NAC3. /// A strided ndarray in NAC3.
@ -48,7 +48,7 @@ pub struct NDArray;
impl<'ctx> StructKind<'ctx> for NDArray { impl<'ctx> StructKind<'ctx> for NDArray {
type Fields<F: FieldTraversal<'ctx>> = NDArrayFields<'ctx, F>; type Fields<F: FieldTraversal<'ctx>> = NDArrayFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> { fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { Self::Fields {
data: traversal.add_auto("data"), data: traversal.add_auto("data"),
itemsize: traversal.add_auto("itemsize"), itemsize: traversal.add_auto("itemsize"),
@ -98,7 +98,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &'ctx Context, ctx: &'ctx Context,
) -> Instance<'ctx, Int<SizeT>> { ) -> Instance<'ctx, Int<SizeT>> {
Int(SizeT).const_int(generator, ctx, self.ndims) Int(SizeT).const_int(generator, ctx, self.ndims, false)
} }
/// Allocate an ndarray on the stack given its `ndims` and `dtype`. /// Allocate an ndarray on the stack given its `ndims` and `dtype`.
@ -123,7 +123,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let itemsize = Int(SizeT).z_extend_or_truncate(generator, ctx, itemsize); let itemsize = Int(SizeT).z_extend_or_truncate(generator, ctx, itemsize);
ndarray.set(ctx, |f| f.itemsize, itemsize); ndarray.set(ctx, |f| f.itemsize, itemsize);
let ndims_val = Int(SizeT).const_int(generator, ctx.ctx, ndims); let ndims_val = Int(SizeT).const_int(generator, ctx.ctx, ndims, false);
ndarray.set(ctx, |f| f.ndims, ndims_val); ndarray.set(ctx, |f| f.ndims, ndims_val);
let shape = Int(SizeT).array_alloca(generator, ctx, ndims_val.value); let shape = Int(SizeT).array_alloca(generator, ctx, ndims_val.value);
@ -149,8 +149,8 @@ impl<'ctx> NDArrayObject<'ctx> {
// Write shape // Write shape
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape); let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
for (i, dim) in shape.iter().enumerate() { for (i, dim) in shape.iter().enumerate() {
let dim = Int(SizeT).const_int(generator, ctx.ctx, *dim); let dim = Int(SizeT).const_int(generator, ctx.ctx, *dim, false);
dst_shape.offset_const(ctx, i as u64).store(ctx, dim); dst_shape.offset_const(ctx, i64::try_from(i).unwrap()).store(ctx, dim);
} }
ndarray ndarray
@ -170,7 +170,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// Write shape // Write shape
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape); let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
for (i, dim) in shape.iter().enumerate() { for (i, dim) in shape.iter().enumerate() {
dst_shape.offset_const(ctx, i as u64).store(ctx, *dim); dst_shape.offset_const(ctx, i64::try_from(i).unwrap()).store(ctx, *dim);
} }
ndarray ndarray
@ -419,6 +419,8 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>, value: BasicValueEnum<'ctx>,
) { ) {
// TODO: It is possible to optimize this by exploiting contiguous strides with memset.
// Probably best to implement in IRRT.
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx); let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap(); ctx.builder.build_store(p, value).unwrap();
@ -443,7 +445,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let dim = self let dim = self
.instance .instance
.get(generator, ctx, |f| f.shape) .get(generator, ctx, |f| f.shape)
.get_index_const(generator, ctx, i) .get_index_const(generator, ctx, i64::try_from(i).unwrap())
.truncate_or_bit_cast(generator, ctx, Int32); .truncate_or_bit_cast(generator, ctx, Int32);
objects.push(AnyObject { objects.push(AnyObject {
@ -471,7 +473,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let dim = self let dim = self
.instance .instance
.get(generator, ctx, |f| f.strides) .get(generator, ctx, |f| f.strides)
.get_index_const(generator, ctx, i) .get_index_const(generator, ctx, i64::try_from(i).unwrap())
.truncate_or_bit_cast(generator, ctx, Int32); .truncate_or_bit_cast(generator, ctx, Int32);
objects.push(AnyObject { objects.push(AnyObject {

View File

@ -1,7 +1,7 @@
use inkwell::{types::BasicType, values::PointerValue, AddressSpace}; use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
use crate::codegen::{ use crate::codegen::{
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next}, irrt::{call_nac3_nditer_has_element, call_nac3_nditer_initialize, call_nac3_nditer_next},
model::*, model::*,
object::any::AnyObject, object::any::AnyObject,
stmt::{gen_for_callback, BreakContinueHooks}, stmt::{gen_for_callback, BreakContinueHooks},
@ -12,15 +12,15 @@ use super::NDArrayObject;
/// Fields of [`NDIter`] /// Fields of [`NDIter`]
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> { pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Out<Int<SizeT>>, pub ndims: F::Output<Int<SizeT>>,
pub shape: F::Out<Ptr<Int<SizeT>>>, pub shape: F::Output<Ptr<Int<SizeT>>>,
pub strides: F::Out<Ptr<Int<SizeT>>>, pub strides: F::Output<Ptr<Int<SizeT>>>,
pub indices: F::Out<Ptr<Int<SizeT>>>, pub indices: F::Output<Ptr<Int<SizeT>>>,
pub nth: F::Out<Int<SizeT>>, pub nth: F::Output<Int<SizeT>>,
pub element: F::Out<Ptr<Int<Byte>>>, pub element: F::Output<Ptr<Int<Byte>>>,
pub size: F::Out<Int<SizeT>>, pub size: F::Output<Int<SizeT>>,
} }
/// An IRRT helper structure used to iterate through an ndarray. /// An IRRT helper structure used to iterate through an ndarray.
@ -30,7 +30,7 @@ pub struct NDIter;
impl<'ctx> StructKind<'ctx> for NDIter { impl<'ctx> StructKind<'ctx> for NDIter {
type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>; type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> { fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { Self::Fields {
ndims: traversal.add_auto("ndims"), ndims: traversal.add_auto("ndims"),
shape: traversal.add_auto("shape"), shape: traversal.add_auto("shape"),
@ -72,20 +72,22 @@ impl<'ctx> NDIterHandle<'ctx> {
NDIterHandle { ndarray, instance: nditer, indices } NDIterHandle { ndarray, instance: nditer, indices }
} }
/// Is there a next element? /// Is the current iteration valid?
///
/// If true, then `element`, `indices` and `nth` contain details about the current element.
/// ///
/// If `ndarray` is unsized, this returns true only for the first iteration. /// If `ndarray` is unsized, this returns true only for the first iteration.
/// If `ndarray` is 0-sized, this always returns false. /// If `ndarray` is 0-sized, this always returns false.
#[must_use] #[must_use]
pub fn has_next<G: CodeGenerator + ?Sized>( pub fn has_element<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<Bool>> { ) -> Instance<'ctx, Int<Bool>> {
call_nac3_nditer_has_next(generator, ctx, self.instance) call_nac3_nditer_has_element(generator, ctx, self.instance)
} }
/// Go to the next element. If `has_next()` is false, then this has undefined behavior. /// Go to the next element. If `has_element()` is false, then this has undefined behavior.
/// ///
/// If `ndarray` is unsized, this can only be called once. /// If `ndarray` is unsized, this can only be called once.
/// If `ndarray` is 0-sized, this can never be called. /// If `ndarray` is 0-sized, this can never be called.
@ -166,7 +168,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx, ctx,
Some("ndarray_foreach"), Some("ndarray_foreach"),
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)), |generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value), |generator, ctx, nditer| Ok(nditer.has_element(generator, ctx).value),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|generator, ctx, nditer| { |generator, ctx, nditer| {
nditer.next(generator, ctx); nditer.next(generator, ctx);

View File

@ -77,7 +77,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
let int = input_sequence.index(ctx, i).value.into_int_value(); let int = input_sequence.index(ctx, i).value.into_int_value();
let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int); let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int);
result.set_index_const(ctx, i as u64, int); result.set_index_const(ctx, i64::try_from(i).unwrap(), int);
} }
(len, result) (len, result)

View File

@ -81,7 +81,7 @@ impl<'ctx> TupleObject<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<SizeT>> { ) -> Instance<'ctx, Int<SizeT>> {
Int(SizeT).const_int(generator, ctx.ctx, self.num_elements() as u64) Int(SizeT).const_int(generator, ctx.ctx, self.num_elements() as u64, false)
} }
/// Get the `i`-th (0-based) object in this tuple. /// Get the `i`-th (0-based) object in this tuple.

View File

@ -3,12 +3,12 @@ use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
/// Fields of [`Slice`] /// Fields of [`Slice`]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> { pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> {
pub start_defined: F::Out<Int<Bool>>, pub start_defined: F::Output<Int<Bool>>,
pub start: F::Out<Int<N>>, pub start: F::Output<Int<N>>,
pub stop_defined: F::Out<Int<Bool>>, pub stop_defined: F::Output<Int<Bool>>,
pub stop: F::Out<Int<N>>, pub stop: F::Output<Int<N>>,
pub step_defined: F::Out<Int<Bool>>, pub step_defined: F::Output<Int<Bool>>,
pub step: F::Out<Int<N>>, pub step: F::Output<Int<N>>,
} }
/// An IRRT representation of an (unresolved) slice. /// An IRRT representation of an (unresolved) slice.
@ -18,7 +18,7 @@ pub struct Slice<N>(pub N);
impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice<N> { impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice<N> {
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F, N>; type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F, N>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> { fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { Self::Fields {
start_defined: traversal.add_auto("start_defined"), start_defined: traversal.add_auto("start_defined"),
start: traversal.add("start", Int(self.0)), start: traversal.add("start", Int(self.0)),

View File

@ -1,15 +1,13 @@
use super::{ use super::{
super::symbol_resolver::ValueEnum, classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
expr::destructure_range, expr::{destructure_range, gen_binop_expr},
gen_in_range_check,
irrt::{handle_slice_indices, list_slice_assignment}, irrt::{handle_slice_indices, list_slice_assignment},
macros::codegen_unreachable,
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
use crate::{ use crate::{
codegen::{ symbol_resolver::ValueEnum,
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
expr::gen_binop_expr,
gen_in_range_check,
},
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::{ typecheck::{
magic_methods::Binop, magic_methods::Binop,
@ -121,7 +119,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
return Ok(None); return Ok(None);
}; };
let BasicValueEnum::PointerValue(ptr) = val else { let BasicValueEnum::PointerValue(ptr) = val else {
unreachable!(); codegen_unreachable!(ctx);
}; };
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
@ -135,7 +133,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
} }
.unwrap() .unwrap()
} }
_ => unreachable!(), _ => codegen_unreachable!(ctx),
})) }))
} }
@ -176,6 +174,14 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
} }
} }
let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?; let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?;
// Perform i1 <-> i8 conversion as needed
let val = if ctx.unifier.unioned(target.custom.unwrap(), ctx.primitives.bool) {
generator.bool_to_i8(ctx, val.into_int_value()).into()
} else {
val
};
ctx.builder.build_store(ptr, val).unwrap(); ctx.builder.build_store(ptr, val).unwrap();
} }
}; };
@ -193,12 +199,12 @@ pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
// Deconstruct the tuple `value` // Deconstruct the tuple `value`
let BasicValueEnum::StructValue(tuple) = value.to_basic_value_enum(ctx, generator, value_ty)? let BasicValueEnum::StructValue(tuple) = value.to_basic_value_enum(ctx, generator, value_ty)?
else { else {
unreachable!() codegen_unreachable!(ctx)
}; };
// NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer. // NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer.
let TypeEnum::TTuple { ty: tuple_tys, .. } = &*ctx.unifier.get_ty(value_ty) else { let TypeEnum::TTuple { ty: tuple_tys, .. } = &*ctx.unifier.get_ty(value_ty) else {
unreachable!(); codegen_unreachable!(ctx);
}; };
assert_eq!(tuple.get_type().count_fields() as usize, tuple_tys.len()); assert_eq!(tuple.get_type().count_fields() as usize, tuple_tys.len());
@ -258,7 +264,7 @@ pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
// Now assign with that sub-tuple to the starred target. // Now assign with that sub-tuple to the starred target.
generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?; generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?;
} else { } else {
unreachable!() // The typechecker ensures this codegen_unreachable!(ctx) // The typechecker ensures this
} }
// Handle assignment after the starred target // Handle assignment after the starred target
@ -306,7 +312,9 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
if let ExprKind::Slice { .. } = &key.node { if let ExprKind::Slice { .. } = &key.node {
// Handle assigning to a slice // Handle assigning to a slice
let ExprKind::Slice { lower, upper, step } = &key.node else { unreachable!() }; let ExprKind::Slice { lower, upper, step } = &key.node else {
codegen_unreachable!(ctx)
};
let Some((start, end, step)) = handle_slice_indices( let Some((start, end, step)) = handle_slice_indices(
lower, lower,
upper, upper,
@ -416,7 +424,9 @@ pub fn gen_for<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else { unreachable!() }; let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else {
codegen_unreachable!(ctx)
};
// var_assignment static values may be changed in another branch // var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch // if so, remove the static value as it may not be correct in this branch
@ -458,7 +468,7 @@ pub fn gen_for<G: CodeGenerator>(
let Some(target_i) = let Some(target_i) =
generator.gen_store_target(ctx, target, Some("for.target.addr"))? generator.gen_store_target(ctx, target, Some("for.target.addr"))?
else { else {
unreachable!() codegen_unreachable!(ctx)
}; };
let (start, stop, step) = destructure_range(ctx, iter_val); let (start, stop, step) = destructure_range(ctx, iter_val);
@ -901,7 +911,7 @@ pub fn gen_while<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::While { test, body, orelse, .. } = &stmt.node else { unreachable!() }; let StmtKind::While { test, body, orelse, .. } = &stmt.node else { codegen_unreachable!(ctx) };
// var_assignment static values may be changed in another branch // var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch // if so, remove the static value as it may not be correct in this branch
@ -931,7 +941,7 @@ pub fn gen_while<G: CodeGenerator>(
return Ok(()); return Ok(());
}; };
let BasicValueEnum::IntValue(test) = test else { unreachable!() }; let BasicValueEnum::IntValue(test) = test else { codegen_unreachable!(ctx) };
ctx.builder ctx.builder
.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb) .build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb)
@ -1079,7 +1089,7 @@ pub fn gen_if<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::If { test, body, orelse, .. } = &stmt.node else { unreachable!() }; let StmtKind::If { test, body, orelse, .. } = &stmt.node else { codegen_unreachable!(ctx) };
// var_assignment static values may be changed in another branch // var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch // if so, remove the static value as it may not be correct in this branch
@ -1202,11 +1212,11 @@ pub fn exn_constructor<'ctx>(
let zelf_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) { let zelf_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) {
obj_id.0 obj_id.0
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let defs = ctx.top_level.definitions.read(); let defs = ctx.top_level.definitions.read();
let def = defs[zelf_id].read(); let def = defs[zelf_id].read();
let TopLevelDef::Class { name: zelf_name, .. } = &*def else { unreachable!() }; let TopLevelDef::Class { name: zelf_name, .. } = &*def else { codegen_unreachable!(ctx) };
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name); let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name);
unsafe { unsafe {
let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap(); let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap();
@ -1314,7 +1324,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
target: &Stmt<Option<Type>>, target: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node else { let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node else {
unreachable!() codegen_unreachable!(ctx)
}; };
// if we need to generate anything related to exception, we must have personality defined // if we need to generate anything related to exception, we must have personality defined
@ -1391,7 +1401,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) { if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) {
*obj_id *obj_id
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name); let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name);
let exn_id = ctx.resolver.get_string_id(&exception_name); let exn_id = ctx.resolver.get_string_id(&exception_name);
@ -1663,6 +1673,23 @@ pub fn gen_return<G: CodeGenerator>(
} else { } else {
None None
}; };
// Remap boolean return type into i1
let value = value.map(|ret_val| {
// The "return type" of a sret function is in the first parameter
let expected_ty = if ctx.need_sret {
func.get_type().get_param_types()[0]
} else {
func.get_type().get_return_type().unwrap()
};
if matches!(expected_ty, BasicTypeEnum::IntType(ty) if ty.get_bit_width() == 1) {
generator.bool_to_i1(ctx, ret_val.into_int_value()).into()
} else {
ret_val
}
});
if let Some(return_target) = ctx.return_target { if let Some(return_target) = ctx.return_target {
if let Some(value) = value { if let Some(value) = value {
ctx.builder.build_store(ctx.return_buffer.unwrap(), value).unwrap(); ctx.builder.build_store(ctx.return_buffer.unwrap(), value).unwrap();
@ -1673,25 +1700,6 @@ pub fn gen_return<G: CodeGenerator>(
ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap()).unwrap(); ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap()).unwrap();
ctx.builder.build_return(None).unwrap(); ctx.builder.build_return(None).unwrap();
} else { } else {
// Remap boolean return type into i1
let value = value.map(|v| {
let expected_ty = func.get_type().get_return_type().unwrap();
let ret_val = v.as_basic_value_enum();
if expected_ty.is_int_type() && ret_val.is_int_value() {
let ret_type = expected_ty.into_int_type();
let ret_val = ret_val.into_int_value();
if ret_type.get_bit_width() == 1 && ret_val.get_type().get_bit_width() != 1 {
generator.bool_to_i1(ctx, ret_val)
} else {
ret_val
}
.into()
} else {
ret_val
}
});
let value = value.as_ref().map(|v| v as &dyn BasicValue); let value = value.as_ref().map(|v| v as &dyn BasicValue);
ctx.builder.build_return(value).unwrap(); ctx.builder.build_return(value).unwrap();
} }
@ -1760,7 +1768,30 @@ pub fn gen_stmt<G: CodeGenerator>(
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
StmtKind::Raise { exc, .. } => { StmtKind::Raise { exc, .. } => {
if let Some(exc) = exc { if let Some(exc) = exc {
let exc = if let Some(v) = generator.gen_expr(ctx, exc)? { let exn = if let ExprKind::Name { id, .. } = &exc.node {
// Handle "raise Exception" short form
let def_id = ctx.resolver.get_identifier_def(*id).map_err(|e| {
format!("{} (at {})", e.iter().next().unwrap(), exc.location)
})?;
let def = ctx.top_level.definitions.read();
let TopLevelDef::Class { constructor, .. } = *def[def_id.0].read() else {
return Err(format!("Failed to resolve symbol {id} (at {})", exc.location));
};
let TypeEnum::TFunc(signature) =
ctx.unifier.get_ty(constructor.unwrap()).as_ref().clone()
else {
return Err(format!("Failed to resolve symbol {id} (at {})", exc.location));
};
generator
.gen_call(ctx, None, (&signature, def_id), Vec::default())?
.map(Into::into)
} else {
generator.gen_expr(ctx, exc)?
};
let exc = if let Some(v) = exn {
v.to_basic_value_enum(ctx, generator, exc.custom.unwrap())? v.to_basic_value_enum(ctx, generator, exc.custom.unwrap())?
} else { } else {
return Ok(()); return Ok(());

View File

@ -23,7 +23,7 @@ impl Default for ComposerConfig {
} }
} }
type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>); pub type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
pub struct TopLevelComposer { pub struct TopLevelComposer {
// list of top level definitions, same as top level context // list of top level definitions, same as top level context
pub definition_ast_list: Vec<DefAst>, pub definition_ast_list: Vec<DefAst>,
@ -1822,7 +1822,12 @@ impl TopLevelComposer {
if *name != init_str_id { if *name != init_str_id {
unreachable!("must be init function here") unreachable!("must be init function here")
} }
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
let all_inited = Self::get_all_assigned_field(
object_id.0,
definition_ast_list,
body.as_slice(),
)?;
for (f, _, _) in fields { for (f, _, _) in fields {
if !all_inited.contains(f) { if !all_inited.contains(f) {
return Err(HashSet::from([ return Err(HashSet::from([

View File

@ -3,6 +3,7 @@ use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap}; use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
use ast::ExprKind;
use nac3parser::ast::{Constant, Location}; use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use strum_macros::EnumIter; use strum_macros::EnumIter;
@ -749,7 +750,16 @@ impl TopLevelComposer {
) )
} }
pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> { /// This function returns the fields that have been initialized in the `__init__` function of a class
/// The function takes as input:
/// * `class_id`: The `object_id` of the class whose function is being evaluated (check `TopLevelDef::Class`)
/// * `definition_ast_list`: A list of ast definitions and statements defined in `TopLevelComposer`
/// * `stmts`: The body of function being parsed. Each statment is analyzed to check varaible initialization statements
pub fn get_all_assigned_field(
class_id: usize,
definition_ast_list: &Vec<DefAst>,
stmts: &[Stmt<()>],
) -> Result<HashSet<StrRef>, HashSet<String>> {
let mut result = HashSet::new(); let mut result = HashSet::new();
for s in stmts { for s in stmts {
match &s.node { match &s.node {
@ -785,30 +795,138 @@ impl TopLevelComposer {
// TODO: do not check for For and While? // TODO: do not check for For and While?
ast::StmtKind::For { body, orelse, .. } ast::StmtKind::For { body, orelse, .. }
| ast::StmtKind::While { body, orelse, .. } => { | ast::StmtKind::While { body, orelse, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?); result.extend(Self::get_all_assigned_field(
result.extend(Self::get_all_assigned_field(orelse.as_slice())?); class_id,
definition_ast_list,
body.as_slice(),
)?);
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?);
} }
ast::StmtKind::If { body, orelse, .. } => { ast::StmtKind::If { body, orelse, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? let inited_for_sure = Self::get_all_assigned_field(
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?) class_id,
.copied() definition_ast_list,
.collect::<HashSet<_>>(); body.as_slice(),
)?
.intersection(&Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure); result.extend(inited_for_sure);
} }
ast::StmtKind::Try { body, orelse, finalbody, .. } => { ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? let inited_for_sure = Self::get_all_assigned_field(
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?) class_id,
.copied() definition_ast_list,
.collect::<HashSet<_>>(); body.as_slice(),
)?
.intersection(&Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure); result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
finalbody.as_slice(),
)?);
} }
ast::StmtKind::With { body, .. } => { ast::StmtKind::With { body, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?); result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
body.as_slice(),
)?);
}
// Variables Initialized in function calls
ast::StmtKind::Expr { value, .. } => {
let ExprKind::Call { func, .. } = &value.node else {
continue;
};
let ExprKind::Attribute { value, attr, .. } = &func.node else {
continue;
};
let ExprKind::Name { id, .. } = &value.node else {
continue;
};
// Need to consider the two cases:
// Case 1) Call to class function i.e. id = `self`
// Case 2) Call to class ancestor function i.e. id = ancestor_name
// We leave checking whether function in case 2 belonged to class ancestor or not to type checker
//
// According to current handling of `self`, function definition are fixed and do not change regardless
// of which object is passed as `self` i.e. virtual polymorphism is not supported
// Therefore, we change class id for case 2 to reflect behavior of our compiler
let class_name = if *id == "self".into() {
let ast::StmtKind::ClassDef { name, .. } =
&definition_ast_list[class_id].1.as_ref().unwrap().node
else {
unreachable!()
};
name
} else {
id
};
let parent_method = definition_ast_list.iter().find_map(|def| {
let (
class_def,
Some(ast::Located {
node: ast::StmtKind::ClassDef { name, body, .. },
..
}),
) = &def
else {
return None;
};
let TopLevelDef::Class { object_id: class_id, .. } = &*class_def.read()
else {
unreachable!()
};
if name == class_name {
body.iter().find_map(|m| {
let ast::StmtKind::FunctionDef { name, body, .. } = &m.node else {
return None;
};
if *name == *attr {
return Some((body.clone(), class_id.0));
}
None
})
} else {
None
}
});
// If method body is none then method does not exist
if let Some((method_body, class_id)) = parent_method {
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
method_body.as_slice(),
)?);
} else {
return Err(HashSet::from([format!(
"{}.{} not found in class {class_name} at {}",
*id, *attr, value.location
)]));
}
} }
ast::StmtKind::Pass { .. } ast::StmtKind::Pass { .. }
| ast::StmtKind::Assert { .. } | ast::StmtKind::Assert { .. }
| ast::StmtKind::Expr { .. } => {} | ast::StmtKind::AnnAssign { .. } => {}
_ => { _ => {
unimplemented!() unimplemented!()

View File

@ -520,6 +520,23 @@ pub fn typeof_binop(
} }
Operator::MatMult => { Operator::MatMult => {
// NOTE: NumPy matmul's LHS and RHS must both be ndarrays. Scalars are not allowed.
match (&*unifier.get_ty(lhs), &*unifier.get_ty(rhs)) {
(
TypeEnum::TObj { obj_id: lhs_obj_id, .. },
TypeEnum::TObj { obj_id: rhs_obj_id, .. },
) if *lhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap()
&& *rhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap() =>
{
// LHS and RHS have valid types
}
_ => {
let lhs_str = unifier.stringify(lhs);
let rhs_str = unifier.stringify(rhs);
return Err(format!("ndarray.__matmul__ only accepts ndarray operands, but left operand has type {lhs_str}, and right operand has type {rhs_str}"));
}
}
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {

View File

@ -12,6 +12,7 @@ use super::{
RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap, RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap,
}, },
}; };
use crate::toplevel::type_annotation::TypeAnnotation;
use crate::{ use crate::{
symbol_resolver::{SymbolResolver, SymbolValue}, symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{ toplevel::{
@ -102,6 +103,7 @@ pub struct Inferencer<'a> {
} }
type InferenceError = HashSet<String>; type InferenceError = HashSet<String>;
type OverrideResult = Result<Option<ast::Expr<Option<Type>>>, InferenceError>;
struct NaiveFolder(); struct NaiveFolder();
impl Fold<()> for NaiveFolder { impl Fold<()> for NaiveFolder {
@ -1711,6 +1713,86 @@ impl<'a> Inferencer<'a> {
Ok(None) Ok(None)
} }
/// Checks whether a class method is calling parent function
/// Returns [`None`] if its not a call to parent method, otherwise
/// returns a new `func` with class name replaced by `self` and method resolved to its `DefinitionID`
///
/// e.g. A.f1(self, ...) returns Some(self.{DefintionID(f1)})
fn check_overriding(&mut self, func: &ast::Expr<()>, args: &[ast::Expr<()>]) -> OverrideResult {
// `self` must be first argument for call to parent method
if let Some(Located { node: ExprKind::Name { id, .. }, .. }) = &args.first() {
if *id != "self".into() {
return Ok(None);
}
} else {
return Ok(None);
}
let Located {
node: ExprKind::Attribute { value, attr: method_name, ctx }, location, ..
} = func
else {
return Ok(None);
};
let ExprKind::Name { id: class_name, ctx: class_ctx } = &value.node else {
return Ok(None);
};
let zelf = &self.fold_expr(args[0].clone())?;
// Check whether the method belongs to class ancestors
let def_id = self.unifier.get_ty(zelf.custom.unwrap());
let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() };
let defs = self.top_level.definitions.read();
let res = {
if let TopLevelDef::Class { ancestors, .. } = &*defs[obj_id.0].read() {
let res = ancestors.iter().find_map(|f| {
let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() };
let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read() else {
unreachable!()
};
// Class names are stored as `__module__.class`
let name = name.to_string();
let (_, name) = name.rsplit_once('.').unwrap();
if name == class_name.to_string() {
return methods.iter().find_map(|f| {
if f.0 == *method_name {
return Some(*f);
}
None
});
}
None
});
res
} else {
None
}
};
match res {
Some(r) => {
let mut new_func = func.clone();
let mut new_value = value.clone();
new_value.node = ExprKind::Name { id: "self".into(), ctx: *class_ctx };
new_func.node =
ExprKind::Attribute { value: new_value.clone(), attr: *method_name, ctx: *ctx };
let mut new_func = self.fold_expr(new_func)?;
let ExprKind::Attribute { value, .. } = new_func.node else { unreachable!() };
new_func.node =
ExprKind::Attribute { value, attr: r.2 .0.to_string().into(), ctx: *ctx };
new_func.custom = Some(r.1);
Ok(Some(new_func))
}
None => report_error(
format!("Ancestor method [{class_name}.{method_name}] should be defined with same decorator as its overridden version").as_str(),
*location,
),
}
}
fn fold_call( fn fold_call(
&mut self, &mut self,
location: Location, location: Location,
@ -1724,8 +1806,20 @@ impl<'a> Inferencer<'a> {
return Ok(spec_call_func); return Ok(spec_call_func);
} }
let func = Box::new(self.fold_expr(func)?); // Check for call to parent method
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?; let override_res = self.check_overriding(&func, &args)?;
let is_override = override_res.is_some();
let func = if is_override { override_res.unwrap() } else { self.fold_expr(func)? };
let func = Box::new(func);
let mut args =
args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
// TODO: Handle passing of self to functions to allow runtime lookup of functions to be called
// Currently removing `self` and using compile time function definitions
if is_override {
args.remove(0);
}
let keywords = keywords let keywords = keywords
.into_iter() .into_iter()
.map(|v| fold::fold_keyword(self, v)) .map(|v| fold::fold_keyword(self, v))

View File

@ -7,11 +7,11 @@
#include <string.h> #include <string.h>
double dbl_nan(void) { double dbl_nan(void) {
return NAN; return NAN;
} }
double dbl_inf(void) { double dbl_inf(void) {
return INFINITY; return INFINITY;
} }
void output_bool(bool x) { void output_bool(bool x) {
@ -19,19 +19,19 @@ void output_bool(bool x) {
} }
void output_int32(int32_t x) { void output_int32(int32_t x) {
printf("%"PRId32"\n", x); printf("%" PRId32 "\n", x);
} }
void output_int64(int64_t x) { void output_int64(int64_t x) {
printf("%"PRId64"\n", x); printf("%" PRId64 "\n", x);
} }
void output_uint32(uint32_t x) { void output_uint32(uint32_t x) {
printf("%"PRIu32"\n", x); printf("%" PRIu32 "\n", x);
} }
void output_uint64(uint64_t x) { void output_uint64(uint64_t x) {
printf("%"PRIu64"\n", x); printf("%" PRIu64 "\n", x);
} }
void output_float64(double x) { void output_float64(double x) {
@ -52,7 +52,7 @@ void output_range(int32_t range[3]) {
} }
void output_asciiart(int32_t x) { void output_asciiart(int32_t x) {
static const char *chars = " .,-:;i+hHM$*#@ "; static const char* chars = " .,-:;i+hHM$*#@ ";
if (x < 0) { if (x < 0) {
putchar('\n'); putchar('\n');
} else { } else {
@ -61,12 +61,12 @@ void output_asciiart(int32_t x) {
} }
struct cslice { struct cslice {
void *data; void* data;
size_t len; size_t len;
}; };
void output_int32_list(struct cslice *slice) { void output_int32_list(struct cslice* slice) {
const int32_t *data = (int32_t *) slice->data; const int32_t* data = (int32_t*)slice->data;
putchar('['); putchar('[');
for (size_t i = 0; i < slice->len; ++i) { for (size_t i = 0; i < slice->len; ++i) {
@ -80,23 +80,23 @@ void output_int32_list(struct cslice *slice) {
putchar('\n'); putchar('\n');
} }
void output_str(struct cslice *slice) { void output_str(struct cslice* slice) {
const char *data = (const char *) slice->data; const char* data = (const char*)slice->data;
for (size_t i = 0; i < slice->len; ++i) { for (size_t i = 0; i < slice->len; ++i) {
putchar(data[i]); putchar(data[i]);
} }
} }
void output_strln(struct cslice *slice) { void output_strln(struct cslice* slice) {
output_str(slice); output_str(slice);
putchar('\n'); putchar('\n');
} }
uint64_t dbg_stack_address(__attribute__((unused)) struct cslice *slice) { uint64_t dbg_stack_address(__attribute__((unused)) struct cslice* slice) {
int i; int i;
void *ptr = (void *) &i; void* ptr = (void*)&i;
return (uintptr_t) ptr; return (uintptr_t)ptr;
} }
uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) { uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) {
@ -119,11 +119,12 @@ struct Exception {
uint32_t __nac3_raise(struct Exception* e) { uint32_t __nac3_raise(struct Exception* e) {
printf("__nac3_raise called. Exception details:\n"); printf("__nac3_raise called. Exception details:\n");
printf(" ID: %"PRIu32"\n", e->id); printf(" ID: %" PRIu32 "\n", e->id);
printf(" Location: %*s:%"PRIu32":%"PRIu32"\n" , (int) e->file.len, (const char*) e->file.data, e->line, e->column); printf(" Location: %*s:%" PRIu32 ":%" PRIu32 "\n", (int)e->file.len, (const char*)e->file.data, e->line,
printf(" Function: %*s\n" , (int) e->function.len, (const char*) e->function.data); e->column);
printf(" Message: \"%*s\"\n" , (int) e->message.len, (const char*) e->message.data); printf(" Function: %*s\n", (int)e->function.len, (const char*)e->function.data);
printf(" Params: {0}=%"PRId64", {1}=%"PRId64", {2}=%"PRId64"\n", e->param[0], e->param[1], e->param[2]); printf(" Message: \"%*s\"\n", (int)e->message.len, (const char*)e->message.data);
printf(" Params: {0}=%" PRId64 ", {1}=%" PRId64 ", {2}=%" PRId64 "\n", e->param[0], e->param[1], e->param[2]);
exit(101); exit(101);
__builtin_unreachable(); __builtin_unreachable();
} }

View File

@ -9,6 +9,7 @@ def output_bool(x: bool):
def example1(): def example1():
x, *ys, z = (1, 2, 3, 4, 5) x, *ys, z = (1, 2, 3, 4, 5)
output_int32(x) output_int32(x)
output_int32(len(ys))
output_int32(ys[0]) output_int32(ys[0])
output_int32(ys[1]) output_int32(ys[1])
output_int32(ys[2]) output_int32(ys[2])
@ -18,12 +19,14 @@ def example2():
x, y, *zs = (1, 2, 3, 4, 5) x, y, *zs = (1, 2, 3, 4, 5)
output_int32(x) output_int32(x)
output_int32(y) output_int32(y)
output_int32(len(zs))
output_int32(zs[0]) output_int32(zs[0])
output_int32(zs[1]) output_int32(zs[1])
output_int32(zs[2]) output_int32(zs[2])
def example3(): def example3():
*xs, y, z = (1, 2, 3, 4, 5) *xs, y, z = (1, 2, 3, 4, 5)
output_int32(len(xs))
output_int32(xs[0]) output_int32(xs[0])
output_int32(xs[1]) output_int32(xs[1])
output_int32(xs[2]) output_int32(xs[2])
@ -31,6 +34,12 @@ def example3():
output_int32(z) output_int32(z)
def example4(): def example4():
*xs, y, z = (4, 5)
output_int32(len(xs))
output_int32(y)
output_int32(z)
def example5():
# Example from: https://docs.python.org/3/reference/simple_stmts.html#assignment-statements # Example from: https://docs.python.org/3/reference/simple_stmts.html#assignment-statements
x = [0, 1] x = [0, 1]
i = 0 i = 0
@ -44,7 +53,7 @@ class A:
def __init__(self): def __init__(self):
self.value = 1000 self.value = 1000
def example5(): def example6():
ws = [88, 7, 8] ws = [88, 7, 8]
a = A() a = A()
x, [y, *ys, a.value], ws[0], (ws[0],) = 1, (2, False, 4, 5), 99, (6,) x, [y, *ys, a.value], ws[0], (ws[0],) = 1, (2, False, 4, 5), 99, (6,)
@ -63,4 +72,5 @@ def run() -> int32:
example3() example3()
example4() example4()
example5() example5()
example6()
return 0 return 0

View File

@ -10,23 +10,58 @@ class A:
def __init__(self, a: int32): def __init__(self, a: int32):
self.a = a self.a = a
def f1(self): def output_all_fields(self):
self.f2()
def f2(self):
output_int32(self.a) output_int32(self.a)
def set_a(self, a: int32):
self.a = a
class B(A): class B(A):
b: int32 b: int32
def __init__(self, b: int32): def __init__(self, b: int32):
self.a = b + 1 A.__init__(self, b + 1)
self.set_b(b)
def output_parent_fields(self):
A.output_all_fields(self)
def output_all_fields(self):
A.output_all_fields(self)
output_int32(self.b)
def set_b(self, b: int32):
self.b = b self.b = b
class C(B):
c: int32
def __init__(self, c: int32):
B.__init__(self, c + 1)
self.c = c
def output_parent_fields(self):
B.output_all_fields(self)
def output_all_fields(self):
B.output_all_fields(self)
output_int32(self.c)
def set_c(self, c: int32):
self.c = c
def run() -> int32: def run() -> int32:
aaa = A(5) ccc = C(10)
bbb = B(2) ccc.output_all_fields()
aaa.f1() ccc.set_a(1)
bbb.f1() ccc.set_b(2)
ccc.set_c(3)
ccc.output_all_fields()
bbb = B(10)
bbb.set_a(9)
bbb.set_b(8)
bbb.output_all_fields()
ccc.output_all_fields()
return 0 return 0

View File

@ -15,7 +15,6 @@ use std::{collections::HashMap, sync::Arc};
pub struct ResolverInternal { pub struct ResolverInternal {
pub id_to_type: Mutex<HashMap<StrRef, Type>>, pub id_to_type: Mutex<HashMap<StrRef, Type>>,
pub id_to_def: Mutex<HashMap<StrRef, DefinitionId>>, pub id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
pub class_names: Mutex<HashMap<StrRef, Type>>,
pub module_globals: Mutex<HashMap<StrRef, SymbolValue>>, pub module_globals: Mutex<HashMap<StrRef, SymbolValue>>,
pub str_store: Mutex<HashMap<String, i32>>, pub str_store: Mutex<HashMap<String, i32>>,
} }

View File

@ -14,7 +14,6 @@ use inkwell::{
memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*, memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*,
OptimizationLevel, OptimizationLevel,
}; };
use nac3core::codegen::irrt::setup_irrt_exceptions;
use nac3core::{ use nac3core::{
codegen::{ codegen::{
concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions, concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions,
@ -307,7 +306,6 @@ fn main() {
let internal_resolver: Arc<ResolverInternal> = ResolverInternal { let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
id_to_type: builtins_ty.into(), id_to_type: builtins_ty.into(),
id_to_def: builtins_def.into(), id_to_def: builtins_def.into(),
class_names: Mutex::default(),
module_globals: Mutex::default(), module_globals: Mutex::default(),
str_store: Mutex::default(), str_store: Mutex::default(),
} }
@ -318,8 +316,7 @@ fn main() {
let context = inkwell::context::Context::create(); let context = inkwell::context::Context::create();
// Process IRRT // Process IRRT
let irrt = load_irrt(&context); let irrt = load_irrt(&context, resolver.as_ref());
setup_irrt_exceptions(&context, &irrt, resolver.as_ref());
if emit_llvm { if emit_llvm {
irrt.write_bitcode_to_path(Path::new("irrt.bc")); irrt.write_bitcode_to_path(Path::new("irrt.bc"));
} }