forked from M-Labs/nac3
1
0
Fork 0

Compare commits

...

99 Commits

Author SHA1 Message Date
lyken 9fa0dfe202
WIP: core/ndstrides: hold 2024-08-15 16:16:59 +08:00
lyken 1c48d54afa
WIP: core/ndstrides: fix nditer 2024-08-15 15:13:02 +08:00
lyken a69a441bdd
WIP: core/ndstrides: checkpoint 13 2024-08-15 14:38:05 +08:00
lyken 4b765cfb27
WIP: core/ndstrides: remove ScalarObject 2024-08-15 13:34:48 +08:00
lyken f8b934096d
WIP: core/ndstrides: checkpoint 2024-08-15 11:41:33 +08:00
lyken 0df2f26c98
WIP: core/ndstrides: builtin_fns deleted 2024-08-15 11:01:56 +08:00
lyken 5dce27e87d
WIP: core/ndstrides: more iter and less builtin 2024-08-15 00:33:23 +08:00
lyken 15dfb2eaa0
WIP: core/ndstrides: on iter 2024-08-14 17:30:37 +08:00
lyken fd78f7a0e8
WIP: core/ndstrides: done 2024-08-14 15:56:59 +08:00
lyken 2fbe981701
WIP: core/ndstrides: AnyObject + TupleObject 2024-08-14 12:48:10 +08:00
lyken febe78b6a4
WIP: core/ndstrides: rename .value to .instance in *Object 2024-08-14 11:34:42 +08:00
lyken 18dcbf5bbc
WIP: core/ndstrides: {make,from}_simple_ndarray 2024-08-14 11:33:56 +08:00
lyken bb1687f8a4
WIP: core/ndstrides: minor cleanup 2024-08-14 10:19:09 +08:00
lyken 1d7184708f
WIP: core/ndstrides: checkpoint 9 2024-08-14 09:59:33 +08:00
lyken 82edcd9390
core/irrt: introduce irrt testing
`cargo test -F test` would compile `nac3core/irrt/irrt_test.cpp`
targetted to the host machine (it gets to use `std`) and run the
test executable.
2024-08-13 17:05:54 +08:00
lyken 0c3534c2f9
core/irrt: split irrt.cpp into headers
To scale IRRT implementations
2024-08-13 17:05:54 +08:00
lyken 5602812c8f
core/irrt: build.rs capture IR defined constants 2024-08-13 17:05:54 +08:00
lyken 51d26ad3bf
core/irrt: build.rs capture IR defined types 2024-08-13 17:05:54 +08:00
lyken 1d2c887146
core/irrt: reformat 2024-08-13 17:05:54 +08:00
lyken 7ba77ddbd6
core: add .clang-format 2024-08-13 17:05:54 +08:00
lyken a5a25f41bb
core/irrt: comment build.rs & move irrt to its own dir
To prepare for future IRRT implementations, and to also make cargo
only have to watch a single directory.
2024-08-13 17:05:54 +08:00
lyken 432c81a500
core: update insta after #489 2024-08-13 15:30:34 +08:00
David Mak 6beff7a268 [artiq] Implement core_log and rtio_log in terms of polymorphic_print
Implementation mostly references the original implementation in Python.
2024-08-13 15:19:03 +08:00
David Mak 6ca7aecd4a [artiq] Add core_log and rtio_log function declarations 2024-08-13 15:19:03 +08:00
David Mak 8fd7216243 [core] toplevel/composer: Add lateinit_builtins
This is required for the new core_log and rtio_log functions, which take
a generic type as its parameter. However, in ARTIQ builtins are
initialized using one unifier and then actually used by another unifier.

lateinit_builtins workaround this issue by deferring the initialization
of functions requiring type variables until the actual unifier is ready.
2024-08-13 15:19:03 +08:00
David Mak 4f5e417012 [core] codegen: Add function to get format constants for integers 2024-08-13 15:19:03 +08:00
David Mak a0614bad83 [core] codegen/expr: Make gen_string return `StructValue`
So that it is clear that the value itself is returned rather than a
pointer to the struct or its data.
2024-08-13 15:19:03 +08:00
David Mak 5539d144ed [core] Add `CodeGenContext::build_in_bounds_gep_and_load`
For safer accesses to `gep`-able values and faster fails.
2024-08-13 15:19:03 +08:00
David Mak b3891b9a0d standalone: Fix several issues post script refactoring
- Add helptext for check_demos.sh
- Add back support for using debug NAC3 for running tests
- Output error message when argument is not recognized
- Fixed last non-demo script argument being ignored
- Add back SSE2 requirement to NAC3 (required for mandelbrot)
2024-08-13 15:19:03 +08:00
David Mak 6fb8939179 [meta] Update dependencies 2024-08-13 15:19:03 +08:00
lyken 973dc5041a core/typecheck: Support tuple arg type in len() 2024-08-13 15:02:59 +08:00
David Mak d0da688aa7 standalone: Add tuple len test 2024-08-13 15:02:59 +08:00
David Mak 12c4e1cf48 core/toplevel/builtins: Add support for len() on tuples 2024-08-13 15:02:59 +08:00
David Mak 9b988647ed core/toplevel/builtins: Extract len() into builtin function 2024-08-13 15:02:59 +08:00
lyken 35a7cecc12
core/typecheck: fix np_array ndmin bug 2024-08-13 12:50:04 +08:00
lyken 7e3d87f841 core/codegen: fix bug in call_ceil function 2024-08-07 16:40:55 +08:00
David Mak ac0d83ef98 standalone: Add vararg.py 2024-08-06 11:48:42 +08:00
David Mak 3ff6db1a29 core/codegen: Add va_start and va_end intrinsics 2024-08-06 11:48:42 +08:00
David Mak d7b806afb4 core/codegen: Implement support for va_info on supported architectures 2024-08-06 11:48:40 +08:00
David Mak fac60c3974 core/codegen: Handle vararg in function generation 2024-08-06 11:46:00 +08:00
David Mak f5fb504a15 core/codegen/expr: Implement vararg handling in gen_call 2024-08-06 11:46:00 +08:00
David Mak faa3bb97ad core/typecheck/typedef: Add vararg to Unifier::stringify 2024-08-06 11:46:00 +08:00
David Mak 6a64c9d1de core/typecheck/typedef: Add is_vararg_ctx to TTuple 2024-08-06 11:45:54 +08:00
David Mak 3dc8498202 core/typecheck/typedef: Handle vararg parameters in unify_call 2024-08-06 11:43:13 +08:00
David Mak cbf79c5e9c core/typecheck/typedef: Add is_vararg to FuncArg, ConcreteFuncArg 2024-08-06 11:43:13 +08:00
David Mak b8aa17bf8c core/toplevel/composer: Add parsing for vararg parameter 2024-08-06 10:52:24 +08:00
David Mak f5b998cd9c core/codegen: Remove unnecessary mut from get_llvm*_type 2024-08-06 10:52:24 +08:00
David Mak c36f85ecb9 meta: Update dependencies 2024-08-06 10:52:24 +08:00
lyken 3a8c385e01 core/typecheck: fix missing ExprKind::Asterisk in fix_assignment_target_context 2024-08-05 19:30:48 +08:00
lyken 221de4d06a core/codegen: add missing comment 2024-08-05 19:30:48 +08:00
lyken fb9fe8edf2 core: reimplement assignment type inference and codegen
- distinguish between setitem and getitem
- allow starred assignment targets, but the assigned value would be a tuple
- allow both [...] and (...) to be target lists
2024-08-05 19:30:48 +08:00
lyken 894083c6a3 core/codegen: refactor gen_{for,comprehension} to match on iter type 2024-08-05 19:30:48 +08:00
Sébastien Bourdeauducq 669c6aca6b clean up and fix 32-bit demos 2024-08-05 19:04:25 +08:00
abdul124 63d2b49b09 core: remove np_linalg_matmul 2024-08-05 11:44:55 +08:00
abdul124 bf709889c4 standalone/demo: separate linalg functions from main workspace 2024-08-05 11:44:54 +08:00
abdul124 1c72698d02 core: add np_linalg_det and np_linalg_matrix_power functions 2024-07-31 18:02:54 +08:00
abdul124 54f883f0a5 core: implement np_dot using LLVM_IR 2024-07-31 15:53:51 +08:00
abdul124 4a6845dac6 standalone: add np.transpose and np.reshape functions 2024-07-31 13:23:07 +08:00
abdul124 00236f48bc core: add np.transpose and np.reshape functions 2024-07-31 13:23:07 +08:00
abdul124 a3e6bb2292 core/helper: add linalg section 2024-07-31 13:23:07 +08:00
abdul124 17171065b1 standalone: link linalg at runtime 2024-07-31 13:23:07 +08:00
abdul124 540b35ec84 standalone: move linalg functions to demo 2024-07-31 13:23:05 +08:00
abdul124 4bb00c52e3 core/builtin_fns: improve error reporting 2024-07-31 13:21:31 +08:00
abdul124 faf07527cb standalone: add runtime implementation for linalg functions 2024-07-31 13:21:28 +08:00
abdul124 d6a4d0a634 standalone: add linalg methods and tests 2024-07-29 16:48:06 +08:00
abdul124 2242c5af43 core: add linalg methods 2024-07-29 16:48:06 +08:00
David Mak 318a675ea6 standalone: Rename -m32 to -i386 2024-07-29 14:58:58 +08:00
David Mak 32e52ce198 standalone: Revert using uint32_t as slice length
Turns out list and str have always been size_t.
2024-07-29 14:58:29 +08:00
Sebastien Bourdeauducq 665ca8e32d cargo: update dependencies 2024-07-27 22:24:56 +08:00
Sebastien Bourdeauducq 12c12b1d80 flake: update nixpkgs 2024-07-27 22:22:20 +08:00
lyken 72972fa909 core/toplevel: add more numpy categories 2024-07-27 21:57:47 +08:00
lyken 142cd48594 core/toplevel: reorder PrimDef::details 2024-07-27 21:57:47 +08:00
lyken 8adfe781c5 core/toplevel: fix PrimDef method names 2024-07-27 21:57:47 +08:00
lyken 339b74161b core/toplevel: reorganize PrimDef 2024-07-27 21:57:47 +08:00
David Mak 8c5ba37d09 standalone: Add 32-bit execution tests to check_demo.sh 2024-07-26 13:35:40 +08:00
David Mak 05a8948ff2 core: Minor cleanup to use ListValue APIs 2024-07-26 13:35:40 +08:00
David Mak 6d171ec284 core: Add label name and hooks to gen_for functions 2024-07-26 13:35:40 +08:00
David Mak 0ba68f6657 core: Set target triple and datalayout for each module
Fixes an issue with inconsistent pointer sizes causing crashes.
2024-07-26 13:35:40 +08:00
David Mak 693b2a8863 core: Add support for 32-bit size_t on 64-bit targets 2024-07-26 13:35:40 +08:00
David Mak 5faeede0e5 Determine size_t using LLVM target machine 2024-07-26 13:35:38 +08:00
David Mak 266707df9d standalone: Add support for running 32-bit binaries 2024-07-26 13:32:38 +08:00
David Mak 3d3c258756 standalone: Remove support for --lli 2024-07-26 13:32:38 +08:00
David Mak ed1182cb24 standalone: Update format specifiers for exceptions
Use platform-agnostic identifiers instead.
2024-07-26 13:32:37 +08:00
David Mak fd025c1137 standalone: Use uint32_t for cslice length
Matching the expected type of string and list slices.
2024-07-26 13:32:21 +08:00
David Mak f139db9af9 meta: Update dependencies 2024-07-26 10:33:02 +08:00
lyken 44487b76ae standalone: interpret_demo.py remove duplicated section 2024-07-22 17:23:35 +08:00
lyken 1332f113e8 standalone: fix interpret_demo.py comments 2024-07-22 17:06:14 +08:00
Sébastien Bourdeauducq 7632d6f72a cargo: update dependencies 2024-07-21 11:00:25 +08:00
David Mak 4948395ca2 core/toplevel/type_annotation: Add handling for mismatching class def
Primitive types only contain fields in its Type and not its TopLevelDef.
This causes primitive object types to lack some fields.
2024-07-19 14:42:14 +08:00
David Mak 3db3061d99 artiq/symbol_resolver: Handle type of zero-length lists 2024-07-19 14:42:14 +08:00
David Mak 51c2175c80 core/codegen/stmt: Convert assertion values to i1 2024-07-19 14:42:14 +08:00
lyken 1a31a50b8a
standalone: fix __nac3_raise def in demo.c 2024-07-17 21:22:08 +08:00
lyken 6c10e3d056 core: cargo clippy 2024-07-12 21:18:53 +08:00
lyken 2dbc1ec659 cargo fmt 2024-07-12 21:16:38 +08:00
Sebastien Bourdeauducq c80378063a add np_argmin/argmax to interpret_demo environment 2024-07-12 13:27:52 +02:00
abdul124 513d30152b core: support raise exception short form 2024-07-12 18:58:34 +08:00
abdul124 45e9360c4d standalone: Add np_argmax and np_argmin tests 2024-07-12 18:19:56 +08:00
abdul124 2e01b77fc8 core: refactor np_max/np_min functions 2024-07-12 18:18:54 +08:00
abdul124 cea7cade51 core: add np_argmax/np_argmin functions 2024-07-12 18:18:28 +08:00
111 changed files with 14332 additions and 4095 deletions

3
.clang-format Normal file
View File

@ -0,0 +1,3 @@
BasedOnStyle: Google
IndentWidth: 4
ReflowComments: false

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
__pycache__ __pycache__
/target /target
/nac3standalone/demo/linalg/target
nix/windows/msys2 nix/windows/msys2

167
Cargo.lock generated
View File

@ -26,9 +26,9 @@ dependencies = [
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.14" version = "0.6.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"anstyle-parse", "anstyle-parse",
@ -41,36 +41,36 @@ dependencies = [
[[package]] [[package]]
name = "anstyle" name = "anstyle"
version = "1.0.7" version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1"
[[package]] [[package]]
name = "anstyle-parse" name = "anstyle-parse"
version = "0.2.4" version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb"
dependencies = [ dependencies = [
"utf8parse", "utf8parse",
] ]
[[package]] [[package]]
name = "anstyle-query" name = "anstyle-query"
version = "1.1.0" version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a"
dependencies = [ dependencies = [
"windows-sys", "windows-sys 0.52.0",
] ]
[[package]] [[package]]
name = "anstyle-wincon" name = "anstyle-wincon"
version = "3.0.3" version = "3.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"windows-sys", "windows-sys 0.52.0",
] ]
[[package]] [[package]]
@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.0" version = "1.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8" checksum = "e9e8aabfac534be767c909e0690571677d49f41bd8465ae876fe043d52ba5292"
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
@ -129,9 +129,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.9" version = "4.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462" checksum = "11d8838454fda655dafd3accb2b6e2bea645b9e4078abe84a22ceb947235c5cc"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@ -139,9 +139,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.9" version = "4.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942" checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@ -151,27 +151,27 @@ dependencies = [
[[package]] [[package]]
name = "clap_derive" name = "clap_derive"
version = "4.5.8" version = "4.5.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bac35c6dafb060fd4d275d9a4ffae97917c13a6327903a8be2153cd964f7085" checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0"
dependencies = [ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.74",
] ]
[[package]] [[package]]
name = "clap_lex" name = "clap_lex"
version = "0.7.1" version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97"
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.1" version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0"
[[package]] [[package]]
name = "console" name = "console"
@ -182,7 +182,7 @@ dependencies = [
"encode_unicode", "encode_unicode",
"lazy_static", "lazy_static",
"libc", "libc",
"windows-sys", "windows-sys 0.52.0",
] ]
[[package]] [[package]]
@ -302,7 +302,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys", "windows-sys 0.52.0",
] ]
[[package]] [[package]]
@ -385,9 +385,9 @@ dependencies = [
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "2.2.6" version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown 0.14.5", "hashbrown 0.14.5",
@ -421,7 +421,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.74",
] ]
[[package]] [[package]]
@ -440,9 +440,9 @@ dependencies = [
[[package]] [[package]]
name = "is_terminal_polyfill" name = "is_terminal_polyfill"
version = "1.70.0" version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]] [[package]]
name = "itertools" name = "itertools"
@ -513,9 +513,9 @@ checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]] [[package]]
name = "libloading" name = "libloading"
version = "0.8.4" version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"windows-targets", "windows-targets",
@ -616,7 +616,7 @@ name = "nac3core"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"crossbeam", "crossbeam",
"indexmap 2.2.6", "indexmap 2.3.0",
"indoc", "indoc",
"inkwell", "inkwell",
"insta", "insta",
@ -706,7 +706,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [ dependencies = [
"fixedbitset", "fixedbitset",
"indexmap 2.2.6", "indexmap 2.3.0",
] ]
[[package]] [[package]]
@ -749,7 +749,7 @@ dependencies = [
"phf_shared 0.11.2", "phf_shared 0.11.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.74",
] ]
[[package]] [[package]]
@ -778,15 +778,18 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315"
[[package]] [[package]]
name = "portable-atomic" name = "portable-atomic"
version = "1.6.0" version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.17" version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04"
dependencies = [
"zerocopy",
]
[[package]] [[package]]
name = "precomputed-hash" name = "precomputed-hash"
@ -850,7 +853,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.70", "syn 2.0.74",
] ]
[[package]] [[package]]
@ -863,7 +866,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.70", "syn 2.0.74",
] ]
[[package]] [[package]]
@ -927,9 +930,9 @@ dependencies = [
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.2" version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
dependencies = [ dependencies = [
"bitflags", "bitflags",
] ]
@ -947,9 +950,9 @@ dependencies = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.10.5" version = "1.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
@ -991,7 +994,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
"windows-sys", "windows-sys 0.52.0",
] ]
[[package]] [[package]]
@ -1029,31 +1032,32 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.204" version = "1.0.206"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" checksum = "5b3e4cd94123dd520a128bcd11e34d9e9e423e7e3e50425cb1b4b1e3549d0284"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.204" version = "1.0.206"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" checksum = "fabfb6138d2383ea8208cf98ccf69cdfb1aff4088460681d84189aa259762f97"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.74",
] ]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.120" version = "1.0.124"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" checksum = "66ad62847a56b3dba58cc891acd13884b9c61138d330c0d7b6181713d4fce38d"
dependencies = [ dependencies = [
"itoa", "itoa",
"memchr",
"ryu", "ryu",
"serde", "serde",
] ]
@ -1072,9 +1076,9 @@ dependencies = [
[[package]] [[package]]
name = "similar" name = "similar"
version = "2.5.0" version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa42c91313f1d05da9b26f267f931cf178d4aba455b4c4622dd7355eb80c6640" checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e"
[[package]] [[package]]
name = "siphasher" name = "siphasher"
@ -1134,7 +1138,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.70", "syn 2.0.74",
] ]
[[package]] [[package]]
@ -1150,9 +1154,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.70" version = "2.0.74"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16" checksum = "1fceb41e3d546d0bd83421d3409b1460cc7444cd389341a4c880fe7a042cb3d7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1161,20 +1165,21 @@ dependencies = [
[[package]] [[package]]
name = "target-lexicon" name = "target-lexicon"
version = "0.12.15" version = "0.12.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.10.1" version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"fastrand", "fastrand",
"once_cell",
"rustix", "rustix",
"windows-sys", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@ -1203,22 +1208,22 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.61" version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.61" version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.74",
] ]
[[package]] [[package]]
@ -1336,9 +1341,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.4" version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]] [[package]]
name = "walkdir" name = "walkdir"
@ -1374,11 +1379,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]] [[package]]
name = "winapi-util" name = "winapi-util"
version = "0.1.8" version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [ dependencies = [
"windows-sys", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@ -1396,6 +1401,15 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets",
]
[[package]] [[package]]
name = "windows-targets" name = "windows-targets"
version = "0.52.6" version = "0.52.6"
@ -1475,6 +1489,7 @@ version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
dependencies = [ dependencies = [
"byteorder",
"zerocopy-derive", "zerocopy-derive",
] ]
@ -1486,5 +1501,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.74",
] ]

View File

@ -2,11 +2,11 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1720418205, "lastModified": 1721924956,
"narHash": "sha256-cPJoFPXU44GlhWg4pUk9oUPqurPlCFZ11ZQPk21GTPU=", "narHash": "sha256-Sb1jlyRO+N8jBXEX9Pg9Z1Qb8Bw9QyOgLDNMEpmjZ2M=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "655a58a72a6601292512670343087c2d75d859c1", "rev": "5ad6a14c6bf098e98800b091668718c336effc95",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -6,6 +6,7 @@
outputs = { self, nixpkgs }: outputs = { self, nixpkgs }:
let let
pkgs = import nixpkgs { system = "x86_64-linux"; }; pkgs = import nixpkgs { system = "x86_64-linux"; };
pkgs32 = import nixpkgs { system = "i686-linux"; };
in rec { in rec {
packages.x86_64-linux = rec { packages.x86_64-linux = rec {
llvm-nac3 = pkgs.callPackage ./nix/llvm {}; llvm-nac3 = pkgs.callPackage ./nix/llvm {};
@ -13,8 +14,25 @@
'' ''
mkdir -p $out/bin mkdir -p $out/bin
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.clang}/bin/clang $out/bin/clang-irrt-test
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
''; '';
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
name = "demo-linalg-stub";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
demo-linalg-stub32 = pkgs32.rustPlatform.buildRustPackage {
name = "demo-linalg-stub32";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
nac3artiq = pkgs.python3Packages.toPythonModule ( nac3artiq = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage rec { pkgs.rustPlatform.buildRustPackage rec {
name = "nac3artiq"; name = "nac3artiq";
@ -23,8 +41,9 @@
cargoLock = { cargoLock = {
lockFile = ./Cargo.lock; lockFile = ./Cargo.lock;
}; };
cargoTestFlags = [ "--features" "test" ];
passthru.cargoLock = cargoLock; passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ]; nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ]; checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase = checkPhase =
@ -32,7 +51,9 @@
echo "Checking nac3standalone demos..." echo "Checking nac3standalone demos..."
pushd nac3standalone/demo pushd nac3standalone/demo
patchShebangs . patchShebangs .
./check_demos.sh export DEMO_LINALG_STUB=${demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${demo-linalg-stub32}/lib/liblinalg.a
./check_demos.sh -i686
popd popd
echo "Running Cargo tests..." echo "Running Cargo tests..."
cargoCheckHook cargoCheckHook
@ -149,7 +170,7 @@
buildInputs = with pkgs; [ buildInputs = with pkgs; [
# build dependencies # build dependencies
packages.x86_64-linux.llvm-nac3 packages.x86_64-linux.llvm-nac3
llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos (pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos
packages.x86_64-linux.llvm-tools-irrt packages.x86_64-linux.llvm-tools-irrt
cargo cargo
rustc rustc
@ -162,6 +183,11 @@
pre-commit pre-commit
rustfmt rustfmt
]; ];
shellHook =
''
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
'';
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -0,0 +1,24 @@
from min_artiq import *
from numpy import int32
@nac3
class EmptyList:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@rpc
def get_empty(self) -> list[int32]:
return []
@kernel
def run(self):
a: list[int32] = self.get_empty()
if a != []:
raise ValueError
if __name__ == "__main__":
EmptyList().run()

View File

@ -1,8 +1,10 @@
use nac3core::{ use nac3core::{
codegen::{ codegen::{
expr::gen_call, classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor},
expr::{destructure_range, gen_call},
irrt::call_ndarray_calc_size,
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave}, llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_with}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
@ -13,7 +15,11 @@ use nac3core::{
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
use inkwell::{ use inkwell::{
context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace, context::Context,
module::Linkage,
types::IntType,
values::{BasicValueEnum, StructValue},
AddressSpace, IntPredicate,
}; };
use pyo3::{ use pyo3::{
@ -23,10 +29,12 @@ use pyo3::{
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use itertools::Itertools;
use std::{ use std::{
collections::hash_map::DefaultHasher, collections::{hash_map::DefaultHasher, HashMap},
collections::HashMap,
hash::{Hash, Hasher}, hash::{Hash, Hasher},
iter::once,
mem,
sync::Arc, sync::Arc,
}; };
@ -386,7 +394,7 @@ fn gen_rpc_tag(
} else { } else {
let ty_enum = ctx.unifier.get_ty(ty); let ty_enum = ctx.unifier.get_ty(ty);
match &*ty_enum { match &*ty_enum {
TTuple { ty } => { TTuple { ty, is_vararg_ctx: false } => {
buffer.push(b't'); buffer.push(b't');
buffer.push(ty.len() as u8); buffer.push(ty.len() as u8);
for ty in ty { for ty in ty {
@ -700,6 +708,7 @@ pub fn attributes_writeback(
name: i.to_string().into(), name: i.to_string().into(),
ty: *ty, ty: *ty,
default_value: None, default_value: None,
is_vararg: false,
}) })
.collect(), .collect(),
ret: ctx.primitives.none, ret: ctx.primitives.none,
@ -723,3 +732,475 @@ pub fn rpc_codegen_callback() -> Arc<GenCall> {
rpc_codegen_callback_fn(ctx, obj, fun, args, generator) rpc_codegen_callback_fn(ctx, obj, fun, args, generator)
}))) })))
} }
/// Returns the `fprintf` format constant for the given [`llvm_int_t`][`IntType`] on a platform with
/// [`llvm_usize`] as its native word size.
///
/// Note that, similar to format constants in `<inttypes.h>`, these constants need to be prepended
/// with `%`.
#[must_use]
fn get_fprintf_format_constant<'ctx>(
llvm_usize: IntType<'ctx>,
llvm_int_t: IntType<'ctx>,
is_unsigned: bool,
) -> String {
debug_assert!(matches!(llvm_usize.get_bit_width(), 8 | 16 | 32 | 64));
let conv_spec = if is_unsigned { 'u' } else { 'd' };
// https://en.cppreference.com/w/c/language/arithmetic_types
// Note that NAC3 does **not** support LP32 and LLP64 configurations
match llvm_int_t.get_bit_width() {
8 => format!("hh{conv_spec}"),
16 => format!("h{conv_spec}"),
32 => conv_spec.to_string(),
64 => format!("{}{conv_spec}", if llvm_usize.get_bit_width() == 64 { "l" } else { "ll" }),
_ => todo!(
"Not yet implemented for i{} on {}-bit platform",
llvm_int_t.get_bit_width(),
llvm_usize.get_bit_width()
),
}
}
/// Prints one or more `values` to `core_log` or `rtio_log`.
///
/// * `separator` - The separator between multiple values.
/// * `suffix` - String to terminate the printed string, if any.
/// * `as_repr` - Whether the `repr()` output of values instead of `str()`.
/// * `as_rtio` - Whether to print to `rtio_log` instead of `core_log`.
fn polymorphic_print<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
values: &[(Type, ValueEnum<'ctx>)],
separator: &str,
suffix: Option<&str>,
as_repr: bool,
as_rtio: bool,
) -> Result<(), String> {
let printf = |ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
fmt: String,
args: Vec<BasicValueEnum<'ctx>>| {
debug_assert!(!fmt.is_empty());
debug_assert_eq!(fmt.as_bytes().last().unwrap(), &0u8);
let fn_name = if as_rtio { "rtio_log" } else { "core_log" };
let print_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| {
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let fn_t = if as_rtio {
let llvm_void = ctx.ctx.void_type();
llvm_void.fn_type(&[llvm_pi8.into()], true)
} else {
let llvm_i32 = ctx.ctx.i32_type();
llvm_i32.fn_type(&[llvm_pi8.into()], true)
};
ctx.module.add_function(fn_name, fn_t, None)
});
let fmt = ctx.gen_string(generator, &fmt).get_field(generator, ctx.ctx, |f| f.base).value;
ctx.builder
.build_call(
print_fn,
&once(fmt.into()).chain(args).map(BasicValueEnum::into).collect_vec(),
"",
)
.unwrap();
};
let llvm_i32 = ctx.ctx.i32_type();
let llvm_i64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let suffix = suffix.unwrap_or_default();
let mut fmt = String::new();
let mut args = Vec::new();
let flush = |ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
fmt: &mut String,
args: &mut Vec<BasicValueEnum<'ctx>>| {
if !fmt.is_empty() {
fmt.push('\0');
printf(ctx, generator, mem::take(fmt), mem::take(args));
}
};
for (ty, value) in values {
let ty = *ty;
let value = value.clone().to_basic_value_enum(ctx, generator, ty).unwrap();
if !fmt.is_empty() {
fmt.push_str(separator);
}
match &*ctx.unifier.get_ty_immutable(ty) {
TypeEnum::TTuple { ty: tys, is_vararg_ctx: false } => {
let pvalue = {
let pvalue = generator.gen_var_alloc(ctx, value.get_type(), None).unwrap();
ctx.builder.build_store(pvalue, value).unwrap();
pvalue
};
fmt.push('(');
flush(ctx, generator, &mut fmt, &mut args);
let tuple_vals = tys
.iter()
.enumerate()
.map(|(i, ty)| {
(*ty, {
let pfield =
ctx.builder.build_struct_gep(pvalue, i as u32, "").unwrap();
ValueEnum::from(ctx.builder.build_load(pfield, "").unwrap())
})
})
.collect_vec();
polymorphic_print(ctx, generator, &tuple_vals, ", ", None, true, as_rtio)?;
if tuple_vals.len() == 1 {
fmt.push_str(",)");
} else {
fmt.push(')');
}
}
TypeEnum::TFunc { .. } => todo!(),
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::None.id() => {
fmt.push_str("None");
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Bool.id() => {
fmt.push_str("%.*s");
let true_str = ctx.gen_string(generator, "True");
let true_data = true_str.get_field(generator, ctx.ctx, |f| f.base);
let true_len = true_str.get_field(generator, ctx.ctx, |f| f.len);
let false_str = ctx.gen_string(generator, "False");
let false_data = false_str.get_field(generator, ctx.ctx, |f| f.base);
let false_len = false_str.get_field(generator, ctx.ctx, |f| f.len);
let bool_val = generator.bool_to_i1(ctx, value.into_int_value());
args.extend([
ctx.builder
.build_select(bool_val, true_len.value, false_len.value, "")
.unwrap(),
ctx.builder
.build_select(bool_val, true_data.value, false_data.value, "")
.unwrap(),
]);
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == PrimDef::Int32.id()
|| *obj_id == PrimDef::Int64.id()
|| *obj_id == PrimDef::UInt32.id()
|| *obj_id == PrimDef::UInt64.id() =>
{
let is_unsigned =
*obj_id == PrimDef::UInt32.id() || *obj_id == PrimDef::UInt64.id();
let llvm_int_t = value.get_type().into_int_type();
debug_assert!(matches!(llvm_usize.get_bit_width(), 32 | 64));
debug_assert!(matches!(llvm_int_t.get_bit_width(), 32 | 64));
let fmt_spec = format!(
"%{}",
get_fprintf_format_constant(llvm_usize, llvm_int_t, is_unsigned)
);
fmt.push_str(fmt_spec.as_str());
args.push(value);
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Float.id() => {
fmt.push_str("%g");
args.push(value);
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Str.id() => {
if as_repr {
fmt.push_str("\"%.*s\"");
} else {
fmt.push_str("%.*s");
}
let str = value.into_struct_value();
let str_data = unsafe { str.get_field_at_index_unchecked(0) }.into_pointer_value();
let str_len = unsafe { str.get_field_at_index_unchecked(1) }.into_int_value();
args.extend(&[str_len.into(), str_data.into()]);
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let elem_ty = *params.iter().next().unwrap().1;
fmt.push('[');
flush(ctx, generator, &mut fmt, &mut args);
let val = ListValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None);
let len = val.load_size(ctx, None);
let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
polymorphic_print(
ctx,
generator,
&[(elem_ty, elem.into())],
"",
None,
true,
as_rtio,
)?;
gen_if_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::ULT, i, last, "")
.unwrap())
},
|generator, ctx| {
printf(ctx, generator, ", \0".into(), Vec::default());
Ok(())
},
|_, _| Ok(()),
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
fmt.push(']');
flush(ctx, generator, &mut fmt, &mut args);
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
fmt.push_str("array([");
flush(ctx, generator, &mut fmt, &mut args);
let val = NDArrayValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None);
let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None));
let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
polymorphic_print(
ctx,
generator,
&[(elem_ty, elem.into())],
"",
None,
true,
as_rtio,
)?;
gen_if_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::ULT, i, last, "")
.unwrap())
},
|generator, ctx| {
printf(ctx, generator, ", \0".into(), Vec::default());
Ok(())
},
|_, _| Ok(()),
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
fmt.push_str(")]");
flush(ctx, generator, &mut fmt, &mut args);
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Range.id() => {
fmt.push_str("range(");
flush(ctx, generator, &mut fmt, &mut args);
let val = RangeValue::from_ptr_val(value.into_pointer_value(), None);
let (start, stop, step) = destructure_range(ctx, val);
polymorphic_print(
ctx,
generator,
&[
(ctx.primitives.int32, start.into()),
(ctx.primitives.int32, stop.into()),
(ctx.primitives.int32, step.into()),
],
", ",
None,
false,
as_rtio,
)?;
fmt.push(')');
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Exception.id() => {
let fmt_str = format!(
"%{}(%{}, %{1:}, %{1:})",
get_fprintf_format_constant(llvm_usize, llvm_i32, false),
get_fprintf_format_constant(llvm_usize, llvm_i64, false),
);
let exn = value.into_pointer_value();
let name = ctx
.build_in_bounds_gep_and_load(
exn,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
)
.into_int_value();
let param0 = ctx
.build_in_bounds_gep_and_load(
exn,
&[llvm_i32.const_zero(), llvm_i32.const_int(6, false)],
None,
)
.into_int_value();
let param1 = ctx
.build_in_bounds_gep_and_load(
exn,
&[llvm_i32.const_zero(), llvm_i32.const_int(7, false)],
None,
)
.into_int_value();
let param2 = ctx
.build_in_bounds_gep_and_load(
exn,
&[llvm_i32.const_zero(), llvm_i32.const_int(8, false)],
None,
)
.into_int_value();
fmt.push_str(fmt_str.as_str());
args.extend_from_slice(&[name.into(), param0.into(), param1.into(), param2.into()]);
}
_ => unreachable!(
"Unsupported object type for polymorphic_print: {}",
ctx.unifier.stringify(ty)
),
}
}
fmt.push_str(suffix);
flush(ctx, generator, &mut fmt, &mut args);
Ok(())
}
/// Invokes the `core_log` intrinsic function.
pub fn call_core_log_impl<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
arg: (Type, BasicValueEnum<'ctx>),
) -> Result<(), String> {
let (arg_ty, arg_val) = arg;
polymorphic_print(ctx, generator, &[(arg_ty, arg_val.into())], " ", Some("\n"), false, false)?;
Ok(())
}
/// Invokes the `rtio_log` intrinsic function.
pub fn call_rtio_log_impl<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
channel: StructValue<'ctx>,
arg: (Type, BasicValueEnum<'ctx>),
) -> Result<(), String> {
let (arg_ty, arg_val) = arg;
polymorphic_print(
ctx,
generator,
&[(ctx.primitives.str, channel.into())],
" ",
Some("\x1E"),
false,
true,
)?;
polymorphic_print(ctx, generator, &[(arg_ty, arg_val.into())], " ", Some("\x1D"), false, true)?;
Ok(())
}
/// Generates a call to `core_log`.
pub fn gen_core_log<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<(), String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let value_ty = fun.0.args[0].ty;
let value_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, value_ty)?;
call_core_log_impl(ctx, generator, (value_ty, value_arg))
}
/// Generates a call to `rtio_log`.
pub fn gen_rtio_log<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<(), String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
let channel_ty = fun.0.args[0].ty;
assert!(ctx.unifier.unioned(channel_ty, ctx.primitives.str));
let channel_arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, channel_ty)?.into_struct_value();
let value_ty = fun.0.args[1].ty;
let value_arg = args[1].1.clone().to_basic_value_enum(ctx, generator, value_ty)?;
call_rtio_log_impl(ctx, generator, channel_arg, (value_ty, value_arg))
}

View File

@ -24,6 +24,7 @@ use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use inkwell::{ use inkwell::{
context::Context,
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer,
module::{Linkage, Module}, module::{Linkage, Module},
passes::PassBuilderOptions, passes::PassBuilderOptions,
@ -32,9 +33,10 @@ use inkwell::{
OptimizationLevel, OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3core::codegen::irrt::setup_irrt_exceptions;
use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions}; use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{TypeEnum, Unifier, VarMap}; use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap};
use nac3parser::{ use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef}, ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program, parser::parse_program,
@ -50,7 +52,7 @@ use nac3core::{
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry}, codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{ toplevel::{
composer::{ComposerConfig, TopLevelComposer}, composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef, DefinitionId, GenCall, TopLevelDef,
}, },
typecheck::typedef::{FunSignature, FuncArg}, typecheck::typedef::{FunSignature, FuncArg},
@ -59,13 +61,13 @@ use nac3core::{
use nac3ld::Linker; use nac3ld::Linker;
use tempfile::{self, TempDir};
use crate::codegen::attributes_writeback;
use crate::{ use crate::{
codegen::{rpc_codegen_callback, ArtiqCodeGenerator}, codegen::{
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
},
symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver}, symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
}; };
use tempfile::{self, TempDir};
mod codegen; mod codegen;
mod symbol_resolver; mod symbol_resolver;
@ -126,7 +128,7 @@ struct Nac3 {
isa: Isa, isa: Isa,
time_fns: &'static (dyn TimeFns + Sync), time_fns: &'static (dyn TimeFns + Sync),
primitive: PrimitiveStore, primitive: PrimitiveStore,
builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>, builtins: Vec<BuiltinFuncSpec>,
pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>, pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
primitive_ids: PrimitivePythonId, primitive_ids: PrimitivePythonId,
working_directory: TempDir, working_directory: TempDir,
@ -264,7 +266,7 @@ impl Nac3 {
arg_names.len(), arg_names.len(),
)); ));
} }
for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() { for (i, FuncArg { ty, default_value, name, .. }) in args.iter().enumerate() {
let in_name = match arg_names.get(i) { let in_name = match arg_names.get(i) {
Some(n) => n, Some(n) => n,
None if default_value.is_none() => { None if default_value.is_none() => {
@ -300,6 +302,64 @@ impl Nac3 {
None None
} }
/// Returns a [`Vec`] of builtins that needs to be initialized during method compilation time.
fn get_lateinit_builtins() -> Vec<Box<BuiltinFuncCreator>> {
vec![
Box::new(|primitives, unifier| {
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
(
"core_log".into(),
FunSignature {
args: vec![FuncArg {
name: "arg".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
}],
ret: primitives.none,
vars: into_var_map([arg_ty]),
},
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_core_log(ctx, &obj, fun, &args, generator)?;
Ok(None)
}))),
)
}),
Box::new(|primitives, unifier| {
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
(
"rtio_log".into(),
FunSignature {
args: vec![
FuncArg {
name: "channel".into(),
ty: primitives.str,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "arg".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
},
],
ret: primitives.none,
vars: into_var_map([arg_ty]),
},
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_rtio_log(ctx, &obj, fun, &args, generator)?;
Ok(None)
}))),
)
}),
]
}
fn compile_method<T>( fn compile_method<T>(
&self, &self,
obj: &PyAny, obj: &PyAny,
@ -312,6 +372,7 @@ impl Nac3 {
let size_t = self.isa.get_size_type(); let size_t = self.isa.get_size_type();
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new( let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
self.builtins.clone(), self.builtins.clone(),
Self::get_lateinit_builtins(),
ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" }, ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
size_t, size_t,
); );
@ -497,6 +558,11 @@ impl Nac3 {
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false) .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
.unwrap(); .unwrap();
// Process IRRT
let context = inkwell::context::Context::create();
let irrt = load_irrt(&context);
setup_irrt_exceptions(&context, &irrt, resolver.as_ref());
let fun_signature = let fun_signature =
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() }; FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
@ -625,7 +691,9 @@ impl Nac3 {
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let size_t = if self.isa == Isa::Host { 64 } else { 32 }; let size_t = Context::create()
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width();
let num_threads = if is_multithreaded() { 4 } else { 1 }; let num_threads = if is_multithreaded() { 4 } else { 1 };
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect(); let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names let threads: Vec<_> = thread_names
@ -644,6 +712,9 @@ impl Nac3 {
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = inkwell::context::Context::create(); let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback"); let module = context.create_module("attributes_writeback");
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
let builder = context.create_builder(); let builder = context.create_builder();
let (_, module, _) = gen_func_impl( let (_, module, _) = gen_func_impl(
&context, &context,
@ -662,7 +733,7 @@ impl Nac3 {
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}); });
let context = inkwell::context::Context::create(); // Link all modules into `main`.
let buffers = membuffers.lock(); let buffers = membuffers.lock();
let main = context let main = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main")) .create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
@ -691,8 +762,7 @@ impl Nac3 {
) )
.unwrap(); .unwrap();
main.link_in_module(load_irrt(&context)) main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
.map_err(|err| CompileError::new_err(err.to_string()))?;
let mut function_iter = main.get_first_function(); let mut function_iter = main.get_first_function();
while let Some(func) = function_iter { while let Some(func) = function_iter {
@ -847,7 +917,7 @@ impl Nac3 {
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS, Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS, Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
}; };
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(isa.get_size_type()).0; let (primitive, _) = TopLevelComposer::make_primitives(isa.get_size_type());
let builtins = vec![ let builtins = vec![
( (
"now_mu".into(), "now_mu".into(),
@ -863,6 +933,7 @@ impl Nac3 {
name: "t".into(), name: "t".into(),
ty: primitive.int64, ty: primitive.int64,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: primitive.none, ret: primitive.none,
vars: VarMap::new(), vars: VarMap::new(),
@ -882,6 +953,7 @@ impl Nac3 {
name: "dt".into(), name: "dt".into(),
ty: primitive.int64, ty: primitive.int64,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: primitive.none, ret: primitive.none,
vars: VarMap::new(), vars: VarMap::new(),

View File

@ -351,7 +351,7 @@ impl InnerResolver {
Ok(Ok((ndarray, false))) Ok(Ok((ndarray, false)))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false }), false)))
} else if ty_id == self.primitive_ids.option { } else if ty_id == self.primitive_ids.option {
Ok(Ok((primitives.option, false))) Ok(Ok((primitives.option, false)))
} else if ty_id == self.primitive_ids.none { } else if ty_id == self.primitive_ids.none {
@ -555,7 +555,10 @@ impl InnerResolver {
Err(err) => return Ok(Err(err)), Err(err) => return Ok(Err(err)),
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string())) _ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
}; };
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true))) Ok(Ok((
unifier.add_ty(TypeEnum::TTuple { ty: args, is_vararg_ctx: false }),
true,
)))
} }
TypeEnum::TObj { params, obj_id, .. } => { TypeEnum::TObj { params, obj_id, .. } => {
let subst = { let subst = {
@ -797,7 +800,9 @@ impl InnerResolver {
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives)) .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))
.collect(); .collect();
let types = types?; let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) Ok(types.map(|types| {
unifier.add_ty(TypeEnum::TTuple { ty: types, is_vararg_ctx: false })
}))
} }
// special handling for option type since its class member layout in python side // special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below // is special and cannot be mapped directly to a nac3 type as below
@ -991,8 +996,15 @@ impl InnerResolver {
} }
_ => unreachable!("must be list"), _ => unreachable!("must be list"),
}; };
let ty = ctx.get_llvm_type(generator, elem_ty);
let size_t = generator.get_size_type(ctx.ctx); let size_t = generator.get_size_type(ctx.ctx);
let ty = if len == 0
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
{
// The default type for zero-length lists of unknown element type is size_t
size_t.into()
} else {
ctx.get_llvm_type(generator, elem_ty)
};
let arr_ty = ctx let arr_ty = ctx
.ctx .ctx
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);
@ -1196,7 +1208,9 @@ impl InnerResolver {
Ok(Some(ndarray.as_pointer_value().into())) Ok(Some(ndarray.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() }; let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {
unreachable!()
};
let tup_tys = ty.iter(); let tup_tys = ty.iter();
let elements: &PyTuple = obj.downcast()?; let elements: &PyTuple = obj.downcast()?;

View File

@ -1,3 +1,6 @@
[features]
test = []
[package] [package]
name = "nac3core" name = "nac3core"
version = "0.1.0" version = "0.1.0"

View File

@ -3,43 +3,60 @@ use std::{
env, env,
fs::File, fs::File,
io::Write, io::Write,
path::Path, path::{Path, PathBuf},
process::{Command, Stdio}, process::{Command, Stdio},
}; };
fn main() { const CMD_IRRT_CLANG: &str = "clang-irrt";
const FILE: &str = "src/codegen/irrt/irrt.cpp"; const CMD_IRRT_CLANG_TEST: &str = "clang-irrt-test";
const CMD_IRRT_LLVM_AS: &str = "llvm-as-irrt";
fn get_out_dir() -> PathBuf {
PathBuf::from(env::var("OUT_DIR").unwrap())
}
fn get_irrt_dir() -> &'static Path {
Path::new("irrt")
}
/// Compile `irrt.cpp` for use in `src/codegen`
fn compile_irrt_cpp() {
let out_dir = get_out_dir();
let irrt_dir = get_irrt_dir();
/* /*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode. * HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get. * Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/ */
let flags: &[&str] = &[ let irrt_cpp_path = irrt_dir.join("irrt.cpp");
"--target=wasm32",
FILE,
"-x",
"c++",
"-fno-discard-value-names",
"-fno-exceptions",
"-fno-rtti",
match env::var("PROFILE").as_deref() {
Ok("debug") => "-O0",
Ok("release") => "-O3",
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
},
"-emit-llvm",
"-S",
"-Wall",
"-Wextra",
"-o",
"-",
];
println!("cargo:rerun-if-changed={FILE}"); let mut flags = vec![];
let out_dir = env::var("OUT_DIR").unwrap(); flags.push("--target=wasm32");
let out_path = Path::new(&out_dir); flags.extend(&["-x", "c++"]);
flags.extend(&["-fno-discard-value-names", "-fno-exceptions", "-fno-rtti"]);
flags.push("-emit-llvm");
flags.push("-S");
flags.extend(&["-Wall", "-Wextra"]);
flags.extend(&["-o", "-"]);
flags.extend(&["-I", irrt_dir.to_str().unwrap()]);
flags.push(irrt_cpp_path.to_str().unwrap());
let output = Command::new("clang-irrt") match env::var("PROFILE").as_deref() {
Ok("debug") => {
flags.push("-O0");
flags.push("-DIRRT_DEBUG");
}
Ok("release") => {
flags.push("-O3");
}
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
};
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
// Compile IRRT and capture the LLVM IR output
let output = Command::new(CMD_IRRT_CLANG)
.args(flags) .args(flags)
.output() .output()
.map(|o| { .map(|o| {
@ -52,7 +69,17 @@ fn main() {
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n"); let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
let mut filtered_output = String::with_capacity(output.len()); let mut filtered_output = String::with_capacity(output.len());
let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap(); // Filter out irrelevant IR
//
// Regex:
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks
// - `(?m:^declare.*?$)` captures LLVM `declare` lines
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
// - `(?m:^@.+?=.+$)` captures global constants
let regex_filter = Regex::new(
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
)
.unwrap();
for f in regex_filter.captures_iter(&output) { for f in regex_filter.captures_iter(&output) {
assert_eq!(f.len(), 1); assert_eq!(f.len(), 1);
filtered_output.push_str(&f[0]); filtered_output.push_str(&f[0]);
@ -63,20 +90,71 @@ fn main() {
.unwrap() .unwrap()
.replace_all(&filtered_output, ""); .replace_all(&filtered_output, "");
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT"); // For debugging
if env::var("DEBUG_DUMP_IRRT").is_ok() { // Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated
let mut file = File::create(out_path.join("irrt.ll")).unwrap(); const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
if env::var(DEBUG_DUMP_IRRT).is_ok() {
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
file.write_all(output.as_bytes()).unwrap(); file.write_all(output.as_bytes()).unwrap();
let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap();
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
file.write_all(filtered_output.as_bytes()).unwrap(); file.write_all(filtered_output.as_bytes()).unwrap();
} }
let mut llvm_as = Command::new("llvm-as-irrt") // Assemble the emitted and filtered IR to .bc
// That .bc will be integrated into nac3core's codegen
let mut llvm_as = Command::new(CMD_IRRT_LLVM_AS)
.stdin(Stdio::piped()) .stdin(Stdio::piped())
.arg("-o") .arg("-o")
.arg(out_path.join("irrt.bc")) .arg(out_dir.join("irrt.bc"))
.spawn() .spawn()
.unwrap(); .unwrap();
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
assert!(llvm_as.wait().unwrap().success()); assert!(llvm_as.wait().unwrap().success());
} }
/// Compile `irrt_test.cpp` for testing
fn compile_irrt_test_cpp() {
let out_dir = get_out_dir();
let irrt_dir = get_irrt_dir();
let exe_path = out_dir.join("irrt_test.out"); // Output path of the compiled test executable
let irrt_test_cpp_path = irrt_dir.join("irrt_test.cpp");
let flags: &[&str] = &[
irrt_test_cpp_path.to_str().unwrap(),
"-x",
"c++",
"-I",
irrt_dir.to_str().unwrap(),
"-g",
"-fno-discard-value-names",
"-O0",
"-Wall",
"-Wextra",
"-Werror=return-type",
"-lm", // for `tgamma()`, `lgamma()`
"-o",
exe_path.to_str().unwrap(),
];
Command::new(CMD_IRRT_CLANG_TEST)
.args(flags)
.output()
.map(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
o
})
.unwrap();
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
}
fn main() {
compile_irrt_cpp();
// https://github.com/rust-lang/cargo/issues/2549
// `cargo test -F test` to also build `irrt_test.cpp
if cfg!(feature = "test") {
compile_irrt_test_cpp();
}
}

10
nac3core/irrt/irrt.cpp Normal file
View File

@ -0,0 +1,10 @@
#define IRRT_DEFINE_TYPEDEF_INTS
#include <irrt_everything.hpp>
/*
* All IRRT implementations.
*
* We don't have pre-compiled objects, so we are writing all implementations in
* headers and concatenate them with `#include` into one massive source file that
* contains all the IRRT stuff.
*/

View File

@ -1,27 +1,17 @@
using int8_t = _BitInt(8); #pragma once
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32); #include <irrt/int_defs.hpp>
using uint32_t = unsigned _BitInt(32); #include <irrt/util.hpp>
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
// NDArray indices are always `uint32_t`. // NDArray indices are always `uint32_t`.
using NDIndex = uint32_t; using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`. // The type of an index or a value describing the length of a
// range/slice is always `int32_t`.
using SliceIndex = int32_t; using SliceIndex = int32_t;
namespace { namespace {
template <typename T> // adapted from GNU Scientific Library:
const T& max(const T& a, const T& b) { // https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
return a > b ? a : b;
}
template <typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function // need to make sure `exp >= 0` before calling this function
template <typename T> template <typename T>
T __nac3_int_exp_impl(T base, T exp) { T __nac3_int_exp_impl(T base, T exp) {
@ -38,12 +28,8 @@ T __nac3_int_exp_impl(T base, T exp) {
} }
template <typename SizeT> template <typename SizeT>
SizeT __nac3_ndarray_calc_size_impl( SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len,
const SizeT* list_data, SizeT begin_idx, SizeT end_idx) {
SizeT list_len,
SizeT begin_idx,
SizeT end_idx
) {
__builtin_assume(end_idx <= list_len); __builtin_assume(end_idx <= list_len);
SizeT num_elems = 1; SizeT num_elems = 1;
@ -56,12 +42,8 @@ SizeT __nac3_ndarray_calc_size_impl(
} }
template <typename SizeT> template <typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl( void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims,
SizeT index, SizeT num_dims, NDIndexInt* idxs) {
const SizeT* dims,
SizeT num_dims,
NDIndex* idxs
) {
SizeT stride = 1; SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) { for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1; SizeT i = num_dims - dim - 1;
@ -72,12 +54,9 @@ void __nac3_ndarray_calc_nd_indices_impl(
} }
template <typename SizeT> template <typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl( SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims,
const SizeT* dims, const NDIndexInt* indices,
SizeT num_dims, SizeT num_indices) {
const NDIndex* indices,
SizeT num_indices
) {
SizeT idx = 0; SizeT idx = 0;
SizeT stride = 1; SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) { for (SizeT i = 0; i < num_dims; ++i) {
@ -93,18 +72,17 @@ SizeT __nac3_ndarray_flatten_index_impl(
} }
template <typename SizeT> template <typename SizeT>
void __nac3_ndarray_calc_broadcast_impl( void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, SizeT lhs_ndims,
const SizeT* lhs_dims, const SizeT* rhs_dims, SizeT rhs_ndims,
SizeT lhs_ndims, SizeT* out_dims) {
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims
) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) { for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr; const SizeT* lhs_dim_sz =
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr; i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz =
i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1]; SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) { if (lhs_dim_sz == nullptr) {
@ -124,12 +102,10 @@ void __nac3_ndarray_calc_broadcast_impl(
} }
template <typename SizeT> template <typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl( void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
const SizeT* src_dims, SizeT src_ndims,
SizeT src_ndims, const NDIndexInt* in_idx,
const NDIndex* in_idx, NDIndexInt* out_idx) {
NDIndex* out_idx
) {
for (SizeT i = 0; i < src_ndims; ++i) { for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1; SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
@ -138,15 +114,15 @@ void __nac3_ndarray_calc_broadcast_idx_impl(
} // namespace } // namespace
extern "C" { extern "C" {
#define DEF_nac3_int_exp_(T) \ #define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) {\ T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp);\ return __nac3_int_exp_impl(base, exp); \
} }
DEF_nac3_int_exp_(int32_t) DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t) DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t) DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t) DEF_nac3_int_exp_(uint64_t);
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) { SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) { if (i < 0) {
@ -160,11 +136,8 @@ SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
return i; return i;
} }
SliceIndex __nac3_range_slice_len( SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end,
const SliceIndex start, const SliceIndex step) {
const SliceIndex end,
const SliceIndex step
) {
SliceIndex diff = end - start; SliceIndex diff = end - start;
if (diff > 0 && step > 0) { if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1; return ((diff - 1) / step) + 1;
@ -180,62 +153,52 @@ SliceIndex __nac3_range_slice_len(
// - All the index must *not* be out-of-bound or negative, // - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*, // - The end index is *inclusive*,
// - The length of src and dest slice size should already // - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest) // be checked: if dest.step == 1 then len(src) <= len(dest) else
// len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size( SliceIndex __nac3_list_slice_assign_var_size(
SliceIndex dest_start, SliceIndex dest_start, SliceIndex dest_end, SliceIndex dest_step,
SliceIndex dest_end, uint8_t* dest_arr, SliceIndex dest_arr_len, SliceIndex src_start,
SliceIndex dest_step, SliceIndex src_end, SliceIndex src_step, uint8_t* src_arr,
uint8_t* dest_arr, SliceIndex src_arr_len, const SliceIndex size) {
SliceIndex dest_arr_len, /* if dest_arr_len == 0, do nothing since we do not support
SliceIndex src_start, * extending list
SliceIndex src_end, */
SliceIndex src_step,
uint8_t* src_arr,
SliceIndex src_arr_len,
const SliceIndex size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len; if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */ /* if both step is 1, memmove directly, handle the dropping of
* the list, and shrink size */
if (src_step == dest_step && dest_step == 1) { if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0; const SliceIndex src_len =
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0; (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len =
(dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) { if (src_len > 0) {
__builtin_memmove( __builtin_memmove(dest_arr + dest_start * size,
dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
src_arr + src_start * size,
src_len * size
);
} }
if (dest_len > 0) { if (dest_len > 0) {
/* dropping */ /* dropping */
__builtin_memmove( __builtin_memmove(dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
dest_arr + (dest_end + 1) * size, (dest_arr_len - dest_end - 1) * size);
(dest_arr_len - dest_end - 1) * size
);
} }
/* shrink size */ /* shrink size */
return dest_arr_len - (dest_len - src_len); return dest_arr_len - (dest_len - src_len);
} }
/* if two range overlaps, need alloca */ /* if two range overlaps, need alloca */
uint8_t need_alloca = uint8_t need_alloca =
(dest_arr == src_arr) (dest_arr == src_arr) &&
&& !( !(max(dest_start, dest_end) < min(src_start, src_end) ||
max(dest_start, dest_end) < min(src_start, src_end) max(src_start, src_end) < min(dest_start, dest_end));
|| max(src_start, src_end) < min(dest_start, dest_end)
);
if (need_alloca) { if (need_alloca) {
uint8_t* tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size)); uint8_t* tmp =
reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size); __builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp; src_arr = tmp;
} }
SliceIndex src_ind = src_start; SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start; SliceIndex dest_ind = dest_start;
for (; for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */ /* for constant optimization */
if (size == 1) { if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1); __builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
@ -244,30 +207,26 @@ SliceIndex __nac3_list_slice_assign_var_size(
} else if (size == 8) { } else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8); __builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else { } else {
/* memcpy for var size, cannot overlap after previous alloca */ /* memcpy for var size, cannot overlap after previous
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size); * alloca */
__builtin_memcpy(dest_arr + dest_ind * size,
src_arr + src_ind * size, size);
} }
} }
/* only dest_step == 1 can we shrink the dest list. */ /* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */ /* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) { if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove( __builtin_memmove(
dest_arr + dest_ind * size, dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
dest_arr + (dest_end + 1) * size, (dest_arr_len - dest_end - 1) * size + size + size + size);
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1; return dest_arr_len - (dest_end - dest_ind) - 1;
} }
return dest_arr_len; return dest_arr_len;
} }
int32_t __nac3_isinf(double x) { int32_t __nac3_isinf(double x) { return __builtin_isinf(x); }
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) { int32_t __nac3_isnan(double x) { return __builtin_isnan(x); }
return __builtin_isnan(x);
}
double tgamma(double arg); double tgamma(double arg);
@ -320,95 +279,71 @@ double __nac3_j0(double x) {
return j0(x); return j0(x);
} }
uint32_t __nac3_ndarray_calc_size( uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len,
const uint32_t* list_data, uint32_t begin_idx, uint32_t end_idx) {
uint32_t list_len, return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx,
uint32_t begin_idx, end_idx);
uint32_t end_idx
) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
} }
uint64_t __nac3_ndarray_calc_size64( uint64_t __nac3_ndarray_calc_size64(const uint64_t* list_data,
const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx,
uint64_t list_len, uint64_t end_idx) {
uint64_t begin_idx, return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx,
uint64_t end_idx end_idx);
) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
} }
void __nac3_ndarray_calc_nd_indices( void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims,
uint32_t index, uint32_t num_dims, NDIndexInt* idxs) {
const uint32_t* dims,
uint32_t num_dims,
NDIndex* idxs
) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
} }
void __nac3_ndarray_calc_nd_indices64( void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims,
uint64_t index, uint64_t num_dims, NDIndexInt* idxs) {
const uint64_t* dims,
uint64_t num_dims,
NDIndex* idxs
) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
} }
uint32_t __nac3_ndarray_flatten_index( uint32_t __nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims,
const uint32_t* dims, const NDIndexInt* indices,
uint32_t num_dims, uint32_t num_indices) {
const NDIndex* indices, return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices,
uint32_t num_indices num_indices);
) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
} }
uint64_t __nac3_ndarray_flatten_index64( uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims,
const uint64_t* dims, const NDIndexInt* indices,
uint64_t num_dims, uint64_t num_indices) {
const NDIndex* indices, return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices,
uint64_t num_indices num_indices);
) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
} }
void __nac3_ndarray_calc_broadcast( void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, uint32_t lhs_ndims,
const uint32_t* lhs_dims, const uint32_t* rhs_dims, uint32_t rhs_ndims,
uint32_t lhs_ndims, uint32_t* out_dims) {
const uint32_t* rhs_dims, return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims,
uint32_t rhs_ndims, rhs_ndims, out_dims);
uint32_t* out_dims
) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
} }
void __nac3_ndarray_calc_broadcast64( void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
const uint64_t* lhs_dims, uint64_t lhs_ndims,
uint64_t lhs_ndims, const uint64_t* rhs_dims,
const uint64_t* rhs_dims, uint64_t rhs_ndims, uint64_t* out_dims) {
uint64_t rhs_ndims, return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims,
uint64_t* out_dims rhs_ndims, out_dims);
) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
} }
void __nac3_ndarray_calc_broadcast_idx( void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
const uint32_t* src_dims, uint32_t src_ndims,
uint32_t src_ndims, const NDIndexInt* in_idx,
const NDIndex* in_idx, NDIndexInt* out_idx) {
NDIndex* out_idx __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx,
) { out_idx);
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
} }
void __nac3_ndarray_calc_broadcast_idx64( void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
const uint64_t* src_dims, uint64_t src_ndims,
uint64_t src_ndims, const NDIndexInt* in_idx,
const NDIndex* in_idx, NDIndexInt* out_idx) {
NDIndex* out_idx __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx,
) { out_idx);
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
} }
} // extern "C" } // extern "C"

View File

@ -0,0 +1,9 @@
#pragma once
#include <irrt/int_defs.hpp>
template <typename SizeT>
struct CSlice {
uint8_t* base;
SizeT len;
};

View File

@ -0,0 +1,15 @@
#pragma once
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \
raise_exception(SizeT, EXN_ASSERTION_ERROR, \
"IRRT debug assert failed: " msg, param1, param2, param3);
#define debug_assert_eq(SizeT, lhs, rhs) \
if (IRRT_DEBUG_ASSERT_BOOL && (lhs) != (rhs)) { \
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
}
#define debug_assert(SizeT, expr) \
if (IRRT_DEBUG_ASSERT_BOOL && !(expr)) { \
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
}

View File

@ -0,0 +1,123 @@
#pragma once
#include <irrt/cslice.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/util.hpp>
/**
* @brief The int type of ARTIQ exception IDs.
*
* It is always `int32_t`
*/
typedef int32_t ExceptionId;
/*
* A set of exceptions IRRT can use.
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
* All exception IDs are initialized by `setup_irrt_exceptions`.
*/
#ifdef IRRT_TESTING
// If we are doing IRRT tests (i.e., running `cargo test -F test`), define them with a fake set of IDs.
ExceptionId EXN_INDEX_ERROR = 0;
ExceptionId EXN_VALUE_ERROR = 1;
ExceptionId EXN_ASSERTION_ERROR = 2;
ExceptionId EXN_RUNTIME_ERROR = 3;
ExceptionId EXN_TYPE_ERROR = 4;
#else
extern "C" {
ExceptionId EXN_INDEX_ERROR;
ExceptionId EXN_VALUE_ERROR;
ExceptionId EXN_ASSERTION_ERROR;
ExceptionId EXN_RUNTIME_ERROR;
ExceptionId EXN_TYPE_ERROR;
}
#endif
namespace {
/**
* @brief NAC3's Exception struct
*/
template <typename SizeT>
struct Exception {
ExceptionId id;
CSlice<SizeT> filename;
int32_t line;
int32_t column;
CSlice<SizeT> function;
CSlice<SizeT> msg;
int64_t params[3];
};
} // namespace
// Declare/Define `__nac3_raise`
#ifdef IRRT_TESTING
#include <cstdio>
void __nac3_raise(void* err) {
// TODO: Print the error content?
printf("__nac3_raise called. Exiting...\n");
exit(1);
}
#else
/**
* @brief Extern function to `__nac3_raise`
*
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
*/
extern "C" void __nac3_raise(void* err);
#endif
namespace {
const int64_t NO_PARAM = 0;
// Helper function to raise an exception with `__nac3_raise`
// Do not use this function directly. See `raise_exception`.
template <typename SizeT>
void _raise_exception_helper(ExceptionId id, const char* filename, int32_t line,
const char* function, const char* msg,
int64_t param0, int64_t param1, int64_t param2) {
Exception<SizeT> e = {
.id = id,
.filename = {.base = (uint8_t*)filename,
.len = (int32_t)cstr_utils::length(filename)},
.line = line,
.column = 0,
.function = {.base = (uint8_t*)function,
.len = (int32_t)cstr_utils::length(function)},
.msg = {.base = (uint8_t*)msg, .len = (int32_t)cstr_utils::length(msg)},
};
e.params[0] = param0;
e.params[1] = param1;
e.params[2] = param2;
__nac3_raise((void*)&e);
__builtin_unreachable();
}
/**
* @brief Raise an exception with location details (location in the IRRT source files).
* @param SizeT The runtime `size_t` type.
* @param id The ID of the exception to raise.
* @param msg A global constant C-string of the error message.
*
* `param0` and `param2` are optional format arguments of `msg`. They should be set to
* `NO_PARAM` to indicate they are unused.
*/
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, \
param0, param1, param2)
/**
* @brief Throw a dummy error for testing.
*/
template <typename SizeT>
void throw_dummy_error() {
raise_exception(SizeT, EXN_RUNTIME_ERROR, "dummy error", NO_PARAM, NO_PARAM,
NO_PARAM);
}
} // namespace
extern "C" {
void __nac3_throw_dummy_error() { throw_dummy_error<int32_t>(); }
void __nac3_throw_dummy_error64() { throw_dummy_error<int64_t>(); }
}

View File

@ -0,0 +1,12 @@
#pragma once
// This is made toggleable since `irrt_test.cpp` itself would include
// headers that define these typedefs
#ifdef IRRT_DEFINE_TYPEDEF_INTS
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#endif

View File

@ -0,0 +1,56 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/slice.hpp>
namespace {
/**
* @brief A list in NAC3.
*
* The `items` field is opaque. You must rely on external contexts to
* know how to interpret it.
*/
template <typename SizeT>
struct List {
uint8_t* items;
SizeT len;
};
namespace list {
template <typename SizeT>
void slice_assign(List<SizeT>* dst, List<SizeT>* src, SizeT itemsize,
UserSlice* user_slice) {
Slice slice = user_slice->indices_checked<SizeT>(dst->len);
// NOTE: Python does not have this restriction.
if (slice.len() != src->len) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"List destination has {} item(s), but source has {} "
"item(s). The lengths must match.",
slice.len(), src->len, NO_PARAM);
}
// TODO: Look into how the original implementation was implemented and optimized.
SizeT dst_i = slice.start;
SizeT src_i = 0;
while (src_i < slice.len()) {
__builtin_memcpy(dst->items + dst_i, src->items + src_i, itemsize);
src_i += 1;
dst_i += slice.step;
}
}
} // namespace list
} // namespace
extern "C" {
void __nac3_list_slice_assign(List<int32_t>* dst, List<int32_t>* src,
int32_t itemsize, UserSlice* user_slice) {
list::slice_assign(dst, src, itemsize, user_slice);
}
void __nac3_list_slice_assign64(List<int64_t>* dst, List<int64_t>* src,
int64_t itemsize, UserSlice* user_slice) {
list::slice_assign(dst, src, itemsize, user_slice);
}
}

View File

@ -0,0 +1,119 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/list.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
namespace ndarray {
namespace array {
// TODO: Document me
template <typename SizeT>
void set_and_validate_list_shape_helper(SizeT axis, List<SizeT>* list,
SizeT ndims, SizeT* shape) {
if (shape[axis] == -1) {
// Dimension is unspecified. Set it.
shape[axis] = list->len;
} else {
// Dimension is specified. Check.
if (shape[axis] != list->len) {
// Mismatch, throw an error.
// NOTE: NumPy's error message is more complex and needs more PARAMS to display.
raise_exception(SizeT, EXN_VALUE_ERROR,
"The requested array has an inhomogenous shape "
"after {0} dimension(s).",
axis, shape[axis], list->len);
}
}
if (axis + 1 == ndims) {
// `list` has type `list[ItemType]`
// Do nothing
} else {
// `list` has type `list[list[...]]`
List<SizeT>** lists = (List<SizeT>**)(list->items);
for (SizeT i = 0; i < list->len; i++) {
set_and_validate_list_shape_helper<SizeT>(axis + 1, lists[i], ndims,
shape);
}
}
}
// TODO: Document me
template <typename SizeT>
void set_and_validate_list_shape(List<SizeT>* list, SizeT ndims, SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
shape[axis] = -1; // Sentinel to say this dimension is unspecified.
}
set_and_validate_list_shape_helper<SizeT>(0, list, ndims, shape);
}
// TODO: Document me
template <typename SizeT>
void write_list_to_array_helper(SizeT axis, SizeT* index, List<SizeT>* list,
NDArray<SizeT>* ndarray) {
debug_assert_eq(SizeT, list->len, ndarray->shape[axis]);
if (IRRT_DEBUG_ASSERT_BOOL) {
if (!ndarray::basic::is_c_contiguous(ndarray)) {
raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0],
ndarray->strides[1], NO_PARAM);
}
}
if (axis + 1 == ndarray->ndims) {
// `list` has type `list[ItemType]`
// `ndarray` is contiguous, so we can do this, and this is fast.
uint8_t* dst = ndarray->data + (ndarray->itemsize * (*index));
__builtin_memcpy(dst, list->items, ndarray->itemsize * list->len);
*index += list->len;
} else {
// `list` has type `list[list[...]]`
List<SizeT>** lists = (List<SizeT>**)(list->items);
for (SizeT i = 0; i < list->len; i++) {
write_list_to_array_helper<SizeT>(axis + 1, index, lists[i],
ndarray);
}
}
}
// TODO: Document me
template <typename SizeT>
void write_list_to_array(List<SizeT>* list, NDArray<SizeT>* ndarray) {
// done after set_and_validate(list, ndims, shape), list is well-formed
// ndarray->data is allocated and owned
// ndarray->itemsize is set
// ndarray->ndims is set
// ndarray->shape is set
// ndarray->strides is ???
SizeT index = 0;
write_list_to_array_helper<SizeT>((SizeT)0, &index, list, ndarray);
}
} // namespace array
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::array;
void __nac3_array_set_and_validate_list_shape(List<int32_t>* list,
int32_t ndims, int32_t* shape) {
set_and_validate_list_shape(list, ndims, shape);
}
void __nac3_array_set_and_validate_list_shape64(List<int64_t>* list,
int64_t ndims, int64_t* shape) {
set_and_validate_list_shape(list, ndims, shape);
}
void __nac3_array_write_list_to_array(List<int32_t>* list,
NDArray<int32_t>* ndarray) {
write_list_to_array(list, ndarray);
}
void __nac3_array_write_list_to_array64(List<int64_t>* list,
NDArray<int64_t>* ndarray) {
write_list_to_array(list, ndarray);
}
}

View File

@ -0,0 +1,345 @@
#pragma once
#include <irrt/exception.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
namespace ndarray {
namespace basic {
/**
* @brief Asserts that `shape` does not contain negative dimensions.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape to check on
*/
template <typename SizeT>
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
if (shape[axis] < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"negative dimensions are not allowed; axis {0} "
"has dimension {1}",
axis, shape[axis], NO_PARAM);
}
}
}
/**
* @brief Check two shapes are the same in the context of writing outputting to an ndarray.
*
* This function throws error messages for output shape mismatches.
*/
template <typename SizeT>
void assert_output_shape_same(SizeT ndarray_ndims, const SizeT* ndarray_shape,
SizeT output_ndims, const SizeT* output_shape) {
if (ndarray_ndims != output_ndims) {
// There is no corresponding NumPy error message like this.
raise_exception(
SizeT, EXN_VALUE_ERROR,
"Cannot write output of ndims {0} to an ndarray with ndims {1}",
output_ndims, ndarray_ndims, NO_PARAM);
}
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
if (ndarray_shape[axis] != output_shape[axis]) {
// There is no corresponding NumPy error message like this.
raise_exception(
SizeT, EXN_VALUE_ERROR,
"Mismatched dimensions on axis {0}, output has "
"dimension {1}, but destination ndarray has dimension {2}.",
axis, output_shape[axis], ndarray_shape[axis]);
}
}
}
/**
* @brief Returns the number of elements of an ndarray given its shape.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape of the ndarray
*/
template <typename SizeT>
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
SizeT size = 1;
for (SizeT axis = 0; axis < ndims; axis++) size *= shape[axis];
return size;
}
/**
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
*
* @param ndims Number of elements in `shape` and `indices`
* @param shape The shape of the ndarray
* @param indices The returned indices indexing the ndarray with shape `shape`.
* @param nth The index of the element of interest.
*/
template <typename SizeT>
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices,
SizeT nth) {
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = ndims - i - 1;
SizeT dim = shape[axis];
indices[axis] = nth % dim;
nth /= dim;
}
}
/**
* @brief Return the number of elements of an `ndarray`
*
* This function corresponds to `<an_ndarray>.size`
*/
template <typename SizeT>
SizeT size(const NDArray<SizeT>* ndarray) {
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
}
/**
* @brief Return of the number of its content of an `ndarray`.
*
* This function corresponds to `<an_ndarray>.nbytes`.
*/
template <typename SizeT>
SizeT nbytes(const NDArray<SizeT>* ndarray) {
return size(ndarray) * ndarray->itemsize;
}
/**
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
*
* This function corresponds to `<an_ndarray>.__len__`.
*
* @param dst_length The returned result
*/
template <typename SizeT>
SizeT len(const NDArray<SizeT>* ndarray) {
// numpy prohibits `__len__` on unsized objects
if (ndarray->ndims == 0) {
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object",
NO_PARAM, NO_PARAM, NO_PARAM);
} else {
return ndarray->shape[0];
}
}
/**
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
*
* You may want to see: ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
*/
template <typename SizeT>
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
// Other references:
// - tinynumpy's implementation: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
// - ndarray's flags["C_CONTIGUOUS"]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
// - ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
// From https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
//
// The traditional rule is that for an array to be flagged as C contiguous,
// the following must hold:
//
// strides[-1] == itemsize
// strides[i] == shape[i+1] * strides[i + 1]
// [...]
// According to these rules, a 0- or 1-dimensional array is either both
// C- and F-contiguous, or neither; and an array with 2+ dimensions
// can be C- or F- contiguous, or neither, but not both. Though there
// there are exceptions for arrays with zero or one item, in the first
// case the check is relaxed up to and including the first dimension
// with shape[i] == 0. In the second case `strides == itemsize` will
// can be true for all dimensions and both flags are set.
if (ndarray->ndims == 0) {
return true;
}
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
return false;
}
for (SizeT i = 1; i < ndarray->ndims; i++) {
SizeT axis_i = ndarray->ndims - i - 1;
if (ndarray->strides[axis_i] !=
ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
return false;
}
}
return true;
}
/**
* @brief Return the pointer to the element indexed by `indices`.
*/
template <typename SizeT>
uint8_t* get_pelement_by_indices(const NDArray<SizeT>* ndarray,
const SizeT* indices) {
uint8_t* element = ndarray->data;
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
element += indices[dim_i] * ndarray->strides[dim_i];
return element;
}
/**
* @brief Convenience function. Like `get_pelement_by_indices` but
* reinterprets the element pointer.
*/
template <typename SizeT, typename T>
T* get_ptr(const NDArray<SizeT>* ndarray, const SizeT* indices) {
return (T*)get_pelement_by_indices(ndarray, indices);
}
/**
* @brief Return the pointer to the nth (0-based) element in a flattened view of `ndarray`.
*
* This function does no bound check.
*/
template <typename SizeT>
uint8_t* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
uint8_t* element = ndarray->data;
for (SizeT i = 0; i < ndarray->ndims; i++) {
SizeT axis = ndarray->ndims - i - 1;
SizeT dim = ndarray->shape[axis];
element += ndarray->strides[axis] * (nth % dim);
nth /= dim;
}
return element;
}
/**
* @brief Update the strides of an ndarray given an ndarray `shape`
* and assuming that the ndarray is fully c-contagious.
*
* You might want to read https://ajcr.net/stride-guide-part-1/.
*/
template <typename SizeT>
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndarray->ndims; i++) {
SizeT axis = ndarray->ndims - i - 1;
ndarray->strides[axis] = stride_product * ndarray->itemsize;
stride_product *= ndarray->shape[axis];
}
}
/**
* @brief Set an element in `ndarray`.
*
* @param pelement Pointer to the element in `ndarray` to be set.
* @param pvalue Pointer to the value `pelement` will be set to.
*/
template <typename SizeT>
void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement,
const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
}
/**
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
*
* Both ndarrays will be viewed in their flatten views when copying the elements.
*/
template <typename SizeT>
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// TODO: Make this faster with memcpy
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
for (SizeT i = 0; i < size(src_ndarray); i++) {
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
ndarray::basic::set_pelement_value(dst_ndarray, dst_element,
src_element);
}
}
} // namespace basic
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::basic;
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims,
int32_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims,
int64_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
const int32_t* ndarray_shape,
int32_t output_ndims,
const int32_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims,
output_shape);
}
void __nac3_ndarray_util_assert_output_shape_same64(
int64_t ndarray_ndims, const int64_t* ndarray_shape, int64_t output_ndims,
const int64_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims,
output_shape);
}
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
return size(ndarray);
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
return size(ndarray);
}
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
return nbytes(ndarray);
}
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
return nbytes(ndarray);
}
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) { return len(ndarray); }
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) { return len(ndarray); }
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
return is_c_contiguous(ndarray);
}
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
return is_c_contiguous(ndarray);
}
uint8_t* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray,
int32_t nth) {
return get_nth_pelement(ndarray, nth);
}
uint8_t* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray,
int64_t nth) {
return get_nth_pelement(ndarray, nth);
}
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
}

View File

@ -0,0 +1,171 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
namespace {
template <typename SizeT>
struct ShapeEntry {
SizeT ndims;
SizeT* shape;
};
} // namespace
namespace {
namespace ndarray {
namespace broadcast {
/**
* @brief Return true if `src_shape` can broadcast to `dst_shape`.
*
* See https://numpy.org/doc/stable/user/basics.broadcasting.html
*/
template <typename SizeT>
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape,
SizeT src_ndims, const SizeT* src_shape) {
if (src_ndims > target_ndims) {
return false;
}
for (SizeT i = 0; i < src_ndims; i++) {
SizeT target_dim = target_shape[target_ndims - i - 1];
SizeT src_dim = src_shape[src_ndims - i - 1];
if (!(src_dim == 1 || target_dim == src_dim)) {
return false;
}
}
return true;
}
/**
* @brief Performs `np.broadcast_shapes(<shapes>)`
*
* @param num_shapes Number of entries in `shapes`
* @param shapes The list of shape to do `np.broadcast_shapes` on.
* @param dst_ndims The length of `dst_shape`.
* `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it.
* for this function since they should already know in order to allocate `dst_shape` in the first place.
* @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result
* of `np.broadcast_shapes` and write it here.
*/
template <typename SizeT>
void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT>* shapes,
SizeT dst_ndims, SizeT* dst_shape) {
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
dst_shape[dst_axis] = 1;
}
#ifdef IRRT_DEBUG_ASSERT
SizeT max_ndims_found = 0;
#endif
for (SizeT i = 0; i < num_shapes; i++) {
ShapeEntry<SizeT> entry = shapes[i];
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
debug_assert(SizeT, entry.ndims <= dst_ndims);
#ifdef IRRT_DEBUG_ASSERT
max_ndims_found = max(max_ndims_found, entry.ndims);
#endif
for (SizeT j = 0; j < entry.ndims; j++) {
SizeT entry_axis = entry.ndims - j - 1;
SizeT dst_axis = dst_ndims - j - 1;
SizeT entry_dim = entry.shape[entry_axis];
SizeT dst_dim = dst_shape[dst_axis];
if (dst_dim == 1) {
dst_shape[dst_axis] = entry_dim;
} else if (entry_dim == 1 || entry_dim == dst_dim) {
// Do nothing
} else {
raise_exception(SizeT, EXN_VALUE_ERROR,
"shape mismatch: objects cannot be broadcast "
"to a single shape.",
NO_PARAM, NO_PARAM, NO_PARAM);
}
}
}
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
debug_assert_eq(SizeT, max_ndims_found, dst_ndims);
}
/**
* @brief Perform `np.broadcast_to(<ndarray>, <target_shape>)` and appropriate assertions.
*
* This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
* and return the result by modifying `dst_ndarray`.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape`
* - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is unchanged.
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
*/
template <typename SizeT>
void broadcast_to(const NDArray<SizeT>* src_ndarray,
NDArray<SizeT>* dst_ndarray) {
if (!ndarray::broadcast::can_broadcast_shape_to(
dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
src_ndarray->shape)) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"operands could not be broadcast together", NO_PARAM,
NO_PARAM, NO_PARAM);
}
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
for (SizeT i = 0; i < dst_ndarray->ndims; i++) {
SizeT src_axis = src_ndarray->ndims - i - 1;
SizeT dst_axis = dst_ndarray->ndims - i - 1;
if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 &&
dst_ndarray->shape[dst_axis] != 1)) {
// Freeze the steps in-place
dst_ndarray->strides[dst_axis] = 0;
} else {
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
}
}
} // namespace broadcast
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::broadcast;
void __nac3_ndarray_broadcast_to(NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
broadcast_to(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_broadcast_to64(NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
broadcast_to(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_broadcast_shapes(int32_t num_shapes,
const ShapeEntry<int32_t>* shapes,
int32_t dst_ndims, int32_t* dst_shape) {
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
}
void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes,
const ShapeEntry<int64_t>* shapes,
int64_t dst_ndims, int64_t* dst_shape) {
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
}
}

View File

@ -0,0 +1,44 @@
#pragma once
namespace {
/**
* @brief The NDArray object
*
* The official numpy implementations: https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
*/
template <typename SizeT>
struct NDArray {
/**
* @brief The underlying data this `ndarray` is pointing to.
*
* Must be set to `nullptr` to indicate that this NDArray's `data` is uninitialized.
*/
uint8_t* data;
/**
* @brief The number of bytes of a single element in `data`.
*/
SizeT itemsize;
/**
* @brief The number of dimensions of this shape.
*/
SizeT ndims;
/**
* @brief The NDArray shape, with length equal to `ndims`.
*
* Note that it may contain 0.
*/
SizeT* shape;
/**
* @brief Array strides, with length equal to `ndims`
*
* The stride values are in units of bytes, not number of elements.
*
* Note that `strides` can have negative values.
*/
SizeT* strides;
};
} // namespace

View File

@ -0,0 +1,221 @@
#pragma once
#include <irrt/exception.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
namespace {
typedef uint8_t NDIndexType;
/**
* @brief A single element index
*
* `data` points to a `SliceIndex`.
*/
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
/**
* @brief A slice index
*
* `data` points to a `UserRange`.
*/
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
/**
* @brief `np.newaxis` / `None`
*
* `data` is unused.
*/
const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2;
/**
* @brief `Ellipsis` / `...`
*
* `data` is unused.
*/
const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3;
/**
* @brief An index used in ndarray indexing
*/
struct NDIndex {
/**
* @brief Enum tag to specify the type of index.
*
* Please see comments of each enum constant.
*/
NDIndexType type;
/**
* @brief The accompanying data associated with `type`.
*
* Please see comments of each enum constant.
*/
uint8_t* data;
};
} // namespace
namespace {
namespace ndarray {
namespace indexing {
/**
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
*
* This is function very similar to performing `dst_ndarray = src_ndarray[indexes]` in Python (where the variables
* can all be found in the parameter of this function).
*
* In other words, this function takes in an ndarray (`src_ndarray`), index it with `indexes`, and return the
* indexed array (by writing the result to `dst_ndarray`).
*
* This function also does proper assertions on `indexes`.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after
* indexing `src_ndarray` with `indexes`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed.
* - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works.
*
* @param indexes Indexes to index `src_ndarray`, ordered in the same way you would write them in Python.
* @param src_ndarray The NDArray to be indexed.
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
*/
template <typename SizeT>
void index(SizeT num_indexes, const NDIndex* indexes,
const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// First, validate `indexes`.
// Expected value of `dst_ndarray->ndims`.
SizeT expected_dst_ndims = src_ndarray->ndims;
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
SizeT num_indexed = 0;
// There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis.
SizeT num_ellipsis = 0;
for (SizeT i = 0; i < num_indexes; i++) {
if (indexes[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
expected_dst_ndims--;
num_indexed++;
} else if (indexes[i].type == ND_INDEX_TYPE_SLICE) {
num_indexed++;
} else if (indexes[i].type == ND_INDEX_TYPE_NEWAXIS) {
expected_dst_ndims++;
} else if (indexes[i].type == ND_INDEX_TYPE_ELLIPSIS) {
num_ellipsis++;
if (num_ellipsis > 1) {
raise_exception(
SizeT, EXN_INDEX_ERROR,
"an index can only have a single ellipsis ('...')",
NO_PARAM, NO_PARAM, NO_PARAM);
}
} else {
__builtin_unreachable();
}
}
debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims);
if (src_ndarray->ndims - num_indexed < 0) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"too many indices for array: array is {0}-dimensional, "
"but {1} were indexed",
src_ndarray->ndims, num_indexes, NO_PARAM);
}
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// Reference code: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
SizeT src_axis = 0;
SizeT dst_axis = 0;
for (SliceIndex i = 0; i < num_indexes; i++) {
const NDIndex* index = &indexes[i];
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
SliceIndex input = *((SliceIndex*)index->data);
SliceIndex k = slice::resolve_index_in_length(
src_ndarray->shape[src_axis], input);
if (k == slice::OUT_OF_BOUNDS) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"index {0} is out of bounds for axis {1} "
"with size {2}",
input, src_axis, src_ndarray->shape[src_axis]);
}
dst_ndarray->data += k * src_ndarray->strides[src_axis];
src_axis++;
} else if (index->type == ND_INDEX_TYPE_SLICE) {
UserSlice* input = (UserSlice*)index->data;
Slice slice =
input->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
dst_ndarray->data +=
(SizeT)slice.start * src_ndarray->strides[src_axis];
dst_ndarray->strides[dst_axis] =
((SizeT)slice.step) * src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = (SizeT)slice.len();
dst_axis++;
src_axis++;
} else if (index->type == ND_INDEX_TYPE_NEWAXIS) {
dst_ndarray->strides[dst_axis] = 0;
dst_ndarray->shape[dst_axis] = 1;
dst_axis++;
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
// The number of ':' entries this '...' implies.
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
for (SizeT j = 0; j < ellipsis_size; j++) {
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_axis++;
src_axis++;
}
} else {
__builtin_unreachable();
}
}
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
debug_assert_eq(SizeT, src_ndarray->ndims, src_axis);
debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis);
}
} // namespace indexing
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::indexing;
void __nac3_ndarray_index(int32_t num_indexes, NDIndex* indexes,
NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
index(num_indexes, indexes, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_index64(int64_t num_indexes, NDIndex* indexes,
NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
index(num_indexes, indexes, src_ndarray, dst_ndarray);
}
}

View File

@ -0,0 +1,118 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
/**
* @brief Helper struct to enumerate through all indices under a shape.
*
* i.e., If `shape` is `[3, 2]`, by repeating `next()`, then you get:
* - `[0, 0]`
* - `[0, 1]`
* - `[1, 0]`
* - `[1, 1]`
* - `[2, 0]`
* - `[2, 1]`
* - end.
*
* Interesting cases:
* - If ndims == 0, there is one enumeration.
* - If shape contains zeroes, there are no enumerations.
*/
template <typename SizeT>
struct NDIter {
SizeT ndims;
SizeT* shape;
SizeT* strides;
/**
* @brief The current indices.
*
* Must be allocated by the caller.
*/
SizeT* indices;
/**
* @brief The nth (0-based) index of the current indices.
*/
SizeT nth;
/**
* @brief Pointer to the current element.
*/
uint8_t* element;
/**
* @brief The product of shape.
*/
SizeT size;
// TODO:: There is something called backstrides to speedup iteration.
// See https://ajcr.net/stride-guide-part-1/, and https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
// Maybe LLVM is clever and knows how to optimize.
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element,
SizeT* indices) {
this->ndims = ndims;
this->shape = shape;
this->strides = strides;
this->indices = indices;
this->element = element;
// Compute size and backstrides
this->size = 1;
for (SizeT i = 0; i < ndims; i++) {
this->size *= shape[i];
}
for (SizeT axis = 0; axis < ndims; axis++) indices[axis] = 0;
nth = 0;
}
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides,
ndarray->data, indices);
}
bool has_next() { return nth < size; }
void next() {
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = ndims - i - 1;
indices[axis]++;
if (indices[axis] >= shape[axis]) {
indices[axis] = 0;
// TODO: Can be optimized with backstrides.
element -= strides[axis] * (shape[axis] - 1);
} else {
element += strides[axis];
break;
}
}
nth++;
}
};
} // namespace
extern "C" {
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray,
int32_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
void __nac3_nditer_initialize64(NDIter<int64_t>* iter,
NDArray<int64_t>* ndarray, int64_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
bool __nac3_nditer_has_next(NDIter<int32_t>* iter) { return iter->has_next(); }
bool __nac3_nditer_has_next64(NDIter<int64_t>* iter) { return iter->has_next(); }
void __nac3_nditer_next(NDIter<int32_t>* iter) { iter->next(); }
void __nac3_nditer_next64(NDIter<int64_t>* iter) { iter->next(); }
}

View File

@ -0,0 +1,194 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/broadcast.hpp>
#include <irrt/ndarray/iter.hpp>
// NOTE: Everything would be much easier and elegant if einsum is implemented.
namespace {
namespace ndarray {
namespace matmul {
/*
* In einsum notation, the output is the broadcasts performed by `np.einsum("...ij,...jk->...ik", a, b)`.
*
* Example:
* Suppose `a_shape == [99, 1, 97, 4, 2]`
* and `b_shape == [ 1, 98, 1, 2, 5]`,
*
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
* `new_b_shape == [99, 98, 97, 2, 5]`,
* and `dst_shape == [99, 98, 97, 4, 5]`.
* ^^^^^^^^^^ ^^^^
* (by broadcast) (4x2 @ 2x5 => 4x5)
*/
template <typename SizeT>
void calculate_shapes(SizeT a_ndims, SizeT* a_shape, SizeT b_ndims,
SizeT* b_shape, SizeT final_ndims, SizeT* new_a_shape,
SizeT* new_b_shape, SizeT* dst_shape) {
debug_assert(SizeT, a_ndims >= 2);
debug_assert(SizeT, b_ndims >= 2);
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
const SizeT num_entries = 2;
ShapeEntry<SizeT> entries[num_entries] = {
{.ndims = a_ndims - 2, .shape = a_shape},
{.ndims = b_ndims - 2, .shape = b_shape}};
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries,
final_ndims - 2, new_a_shape);
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries,
final_ndims - 2, new_b_shape);
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries,
final_ndims - 2, dst_shape);
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
}
/**
* @brief Perform `np.matmul(a, b)` but the inputs are both rank >=2 matrices and `a.shape[:-2] == b.shape[:-2]`.
*
* The compatibility of `a` and `b` (for their `.shape[-2:]`) are asserted.
*
* Also see https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy-matmul.
*
* This function expects `dst_ndarray` to contain the following content when called:
* - `dst_ndarray->data` is allocated. Can be uninitialized.
* - `dst_ndarray->itemsize` is set to `sizeof(T)`.
* - `dst_ndarray->ndims` is set appropriately.
* - `dst_ndarray->shape` is set appropriately.
* - `dst_ndarray->strides` is ignored.
*
* Moreover, the shapes of `a_ndarray`, `b_ndarray`, and `dst_ndarray` **must be the same**. This implies
*/
template <typename SizeT, typename T>
void matmul_at_least_2d(NDArray<SizeT>* a_ndarray, NDArray<SizeT>* b_ndarray,
NDArray<SizeT>* dst_ndarray) {
// All inputs' ndims should be >= 2 and be the same.
debug_assert_eq(SizeT, a_ndarray->ndims, b_ndarray->ndims);
debug_assert_eq(SizeT, a_ndarray->ndims, dst_ndarray->ndims);
debug_assert(SizeT, a_ndarray->ndims >= 2);
debug_assert_eq(SizeT, a_ndarray->itemsize, sizeof(T));
debug_assert_eq(SizeT, b_ndarray->itemsize, sizeof(T));
debug_assert_eq(SizeT, dst_ndarray->itemsize, sizeof(T));
if (IRRT_DEBUG_ASSERT_BOOL) {
// Check that the shapes are the same.
for (SizeT i = 0; i < a_ndarray->ndims - 2; i++) {
if (dst_ndarray->shape[0] != a_ndarray->shape[0]) {
raise_debug_assert(
SizeT, "Bad shape. At axis {0}, a has {1}, dst has {2}", i,
a_ndarray->shape[i], dst_ndarray->shape[i]);
}
if (dst_ndarray->shape[0] != b_ndarray->shape[0]) {
raise_debug_assert(
SizeT, "Bad shape. At axis {0}, b has {1}, dst has {2}", i,
b_ndarray->shape[i], dst_ndarray->shape[i]);
}
}
}
// Number of dimensions dedicated to stacking
// e.g., [4, 6, 1, 2, 3]
// ^^^^^^^ count these
const SizeT u = a_ndarray->ndims - 2; // Alias
SizeT* a_mat_shape = a_ndarray->shape + u;
SizeT* b_mat_shape = b_ndarray->shape + u;
SizeT* dst_mat_shape = dst_ndarray->shape + u;
// Assert that dst_ndarray has the correct shape
debug_assert_eq(SizeT, dst_mat_shape[0], a_mat_shape[0]);
debug_assert_eq(SizeT, dst_mat_shape[1], b_mat_shape[1]);
// Check that a and b are compatible for matmul
if (a_mat_shape[1] != b_mat_shape[0]) {
// This is a custom error message. Different from NumPy.
raise_exception(
SizeT, EXN_VALUE_ERROR,
"Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
a_mat_shape[1], b_mat_shape[0], NO_PARAM);
}
// Iterate through shape[:-2]. i.e,
// Given a = [5, 4, 3, m, p] and b = [5, 4, 3, p, n]. We iterate through [5, 4, 3].
SizeT* indices =
(SizeT*)__builtin_alloca(sizeof(SizeT) * dst_ndarray->ndims);
SizeT* mat_indices = indices + u;
NDIter<SizeT> iter;
iter.initialize(u, dst_ndarray->shape, dst_ndarray->strides,
dst_ndarray->data, indices);
for (; iter.has_next(); iter.next()) {
for (SizeT i = 0; i < dst_mat_shape[0]; i++) {
for (SizeT j = 0; j < dst_mat_shape[1]; j++) {
// `indices` is being reused to index into different ndarrays.
mat_indices[0] = i;
mat_indices[1] = j;
T* d = ndarray::basic::get_ptr<SizeT, T>(dst_ndarray, indices);
*d = 0;
for (SizeT k = 0; k < a_ndarray->shape[1]; k++) {
mat_indices[0] = i;
mat_indices[1] = k;
T* a =
ndarray::basic::get_ptr<SizeT, T>(a_ndarray, indices);
mat_indices[0] = k;
mat_indices[1] = j;
T* b =
ndarray::basic::get_ptr<SizeT, T>(b_ndarray, indices);
*d += (*a) * (*b);
}
}
}
}
}
} // namespace matmul
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::matmul;
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, int32_t* a_shape,
int32_t b_ndims, int32_t* b_shape,
int32_t final_ndims,
int32_t* new_a_shape,
int32_t* new_b_shape,
int32_t* dst_shape) {
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims,
new_a_shape, new_b_shape, dst_shape);
}
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, int64_t* a_shape,
int64_t b_ndims, int64_t* b_shape,
int64_t final_ndims,
int64_t* new_a_shape,
int64_t* new_b_shape,
int64_t* dst_shape) {
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims,
new_a_shape, new_b_shape, dst_shape);
}
void __nac3_ndarray_float64_matmul_at_least_2d(NDArray<int32_t>* a_ndarray,
NDArray<int32_t>* b_ndarray,
NDArray<int32_t>* dst_ndarray) {
matmul_at_least_2d<int32_t, double>(a_ndarray, b_ndarray, dst_ndarray);
}
void __nac3_ndarray_float64_matmul_at_least_2d64(
NDArray<int64_t>* a_ndarray, NDArray<int64_t>* b_ndarray,
NDArray<int64_t>* dst_ndarray) {
matmul_at_least_2d<int64_t, double>(a_ndarray, b_ndarray, dst_ndarray);
}
}

View File

@ -0,0 +1,106 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
namespace ndarray {
namespace reshape {
/**
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
*
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
* modified to contain the resolved dimension.
*
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
*
* @param size The `.size` of `<ndarray>`
* @param new_ndims Number of elements in `new_shape`
* @param new_shape Target shape to reshape to
*/
template <typename SizeT>
void resolve_and_check_new_shape(SizeT size, SizeT new_ndims,
SizeT* new_shape) {
// Is there a -1 in `new_shape`?
bool neg1_exists = false;
// Location of -1, only initialized if `neg1_exists` is true
SizeT neg1_axis_i;
// The computed ndarray size of `new_shape`
SizeT new_size = 1;
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
SizeT dim = new_shape[axis_i];
if (dim < 0) {
if (dim == -1) {
if (neg1_exists) {
// Multiple `-1` found. Throw an error.
raise_exception(SizeT, EXN_VALUE_ERROR,
"can only specify one unknown dimension",
NO_PARAM, NO_PARAM, NO_PARAM);
} else {
neg1_exists = true;
neg1_axis_i = axis_i;
}
} else {
// TODO: What? In `np.reshape` any negative dimensions is
// treated like its `-1`.
//
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
//
// It is not documented by numpy.
// Throw an error for now...
raise_exception(
SizeT, EXN_VALUE_ERROR,
"Found non -1 negative dimension {0} on axis {1}", dim,
axis_i, NO_PARAM);
}
} else {
new_size *= dim;
}
}
bool can_reshape;
if (neg1_exists) {
// Let `x` be the unknown dimension
// Solve `x * <new_size> = <size>`
if (new_size == 0 && size == 0) {
// `x` has infinitely many solutions
can_reshape = false;
} else if (new_size == 0 && size != 0) {
// `x` has no solutions
can_reshape = false;
} else if (size % new_size != 0) {
// `x` has no integer solutions
can_reshape = false;
} else {
can_reshape = true;
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
}
} else {
can_reshape = (new_size == size);
}
if (!can_reshape) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"cannot reshape array of size {0} into given shape",
size, NO_PARAM, NO_PARAM);
}
}
} // namespace reshape
} // namespace ndarray
} // namespace
extern "C" {
void __nac3_ndarray_resolve_and_check_new_shape(int32_t size, int32_t new_ndims,
int32_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
void __nac3_ndarray_resolve_and_check_new_shape64(int64_t size,
int64_t new_ndims,
int64_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
}

View File

@ -0,0 +1,145 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
/*
* Notes on `np.transpose(<array>, <axes>)`
*
* TODO: `axes`, if specified, can actually contain negative indices,
* but it is not documented in numpy.
*
* Supporting it for now.
*/
namespace {
namespace ndarray {
namespace transpose {
/**
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
*
* Note that `np.transpose`'s `<axe>` argument is optional. If the argument
* is specified but the user, use this function to do assertions on it.
*
* @param ndims The number of dimensions of `<array>`
* @param num_axes Number of elements in `<axes>` as specified by the user.
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
* @param axes The user specified `<axes>`.
*/
template <typename SizeT>
void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) {
if (ndims != num_axes) {
raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array",
NO_PARAM, NO_PARAM, NO_PARAM);
}
// TODO: Optimize this
bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
for (SizeT i = 0; i < ndims; i++) axe_specified[i] = false;
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
if (axis == slice::OUT_OF_BOUNDS) {
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
raise_exception(
SizeT, EXN_VALUE_ERROR,
"axis {0} is out of bounds for array of dimension {1}", axis,
ndims, NO_PARAM);
}
if (axe_specified[axis]) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"repeated axis in transpose", NO_PARAM, NO_PARAM,
NO_PARAM);
}
axe_specified[axis] = true;
}
}
/**
* @brief Create a transpose view of `src_ndarray` and perform proper assertions.
*
* This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
* If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
*
* The transpose view created is returned by modifying `dst_ndarray`.
*
* The caller is responsible for setting up `dst_ndarray` before calling this function.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged
* - `dst_ndarray->shape` is updated according to how `np.transpose` works
* - `dst_ndarray->strides` is updated according to how `np.transpose` works
*
* @param src_ndarray The NDArray to build a transpose view on
* @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above,
* @param num_axes Number of elements in axes. Unused if `axes` is nullptr.
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`.
*/
template <typename SizeT>
void transpose(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray,
SizeT num_axes, const SizeT* axes) {
debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims);
const auto ndims = src_ndarray->ndims;
if (axes != nullptr) assert_transpose_axes(ndims, num_axes, axes);
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
if (axes == nullptr) {
// `np.transpose(<array>, axes=None)`
/*
* Minor note: `np.transpose(<array>, axes=None)` is equivalent to
* `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
* is reversing the order of strides and shape.
*
* This is a fast implementation to handle this special (but very common) case.
*/
for (SizeT axis = 0; axis < ndims; axis++) {
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
}
} else {
// `np.transpose(<array>, <axes>)`
// Permute strides and shape according to `axes`, while resolving negative indices in `axes`
for (SizeT axis = 0; axis < ndims; axis++) {
// `i` cannot be OUT_OF_BOUNDS because of assertions
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
dst_ndarray->shape[axis] = src_ndarray->shape[i];
dst_ndarray->strides[axis] = src_ndarray->strides[i];
}
}
}
} // namespace transpose
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::transpose;
void __nac3_ndarray_transpose(const NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray, int32_t num_axes,
const int32_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
void __nac3_ndarray_transpose64(const NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray, int64_t num_axes,
const int64_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
}

View File

@ -0,0 +1,167 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/slice.hpp>
#include <irrt/util.hpp>
#include "exception.hpp"
// The type of an index or a value describing the length of a
// range/slice is always `int32_t`.
using SliceIndex = int32_t;
namespace {
/**
* @brief A Python-like slice with resolved indices.
*
* "Resolved indices" means that `start` and `stop` must be positive and are
* bound to a known length.
*/
struct Slice {
SliceIndex start;
SliceIndex stop;
SliceIndex step;
/**
* @brief Calculate and return the length / the number of the slice.
*
* If this were a Python range, this function would be `len(range(start, stop, step))`.
*/
SliceIndex len() {
SliceIndex diff = stop - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
};
namespace slice {
/**
* @brief Resolve a slice index under a given length like Python indexing.
*
* In Python, if you have a `list` of length 100, `list[-1]` resolves to
* `list[99]`, so `resolve_index_in_length_clamped(100, -1)` returns `99`.
*
* If `length` is 0, 0 is returned for any value of `index`.
*
* If `index` is out of bounds, clamps the returned value between `0` and
* `length - 1` (inclusive).
*
*/
SliceIndex resolve_index_in_length_clamped(SliceIndex length,
SliceIndex index) {
if (index < 0) {
return max<SliceIndex>(length + index, 0);
} else {
return min<SliceIndex>(length, index);
}
}
const SliceIndex OUT_OF_BOUNDS = -1;
/**
* @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS`
* if `index` is out of bounds.
*/
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) {
SliceIndex resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length) {
return resolved;
} else {
return OUT_OF_BOUNDS;
}
}
} // namespace slice
/**
* @brief A Python-like slice with **unresolved** indices.
*/
struct UserSlice {
bool start_defined;
SliceIndex start;
bool stop_defined;
SliceIndex stop;
bool step_defined;
SliceIndex step;
UserSlice() { this->reset(); }
void reset() {
this->start_defined = false;
this->stop_defined = false;
this->step_defined = false;
}
void set_start(SliceIndex start) {
this->start_defined = true;
this->start = start;
}
void set_stop(SliceIndex stop) {
this->stop_defined = true;
this->stop = stop;
}
void set_step(SliceIndex step) {
this->step_defined = true;
this->step = step;
}
/**
* @brief Resolve this slice.
*
* In Python, this would be `slice(start, stop, step).indices(length)`.
*
* @return A `Slice` with the resolved indices.
*/
Slice indices(SliceIndex length) {
Slice result;
result.step = step_defined ? step : 1;
bool step_is_negative = result.step < 0;
if (start_defined) {
result.start =
slice::resolve_index_in_length_clamped(length, start);
} else {
result.start = step_is_negative ? length - 1 : 0;
}
if (stop_defined) {
result.stop = slice::resolve_index_in_length_clamped(length, stop);
} else {
result.stop = step_is_negative ? -1 : length;
}
return result;
}
/**
* @brief Like `.indices()` but with assertions.
*/
template <typename SizeT>
Slice indices_checked(SliceIndex length) {
// TODO: Switch to `SizeT length`
if (length < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"length should not be negative, got {0}", length,
NO_PARAM, NO_PARAM);
}
if (this->step_defined && this->step == 0) {
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero",
NO_PARAM, NO_PARAM, NO_PARAM);
}
return this->indices(length);
}
};
} // namespace

101
nac3core/irrt/irrt/util.hpp Normal file
View File

@ -0,0 +1,101 @@
#pragma once
namespace {
template <typename T>
const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template <typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
template <typename T>
bool arrays_match(int len, T* as, T* bs) {
for (int i = 0; i < len; i++) {
if (as[i] != bs[i]) return false;
}
return true;
}
namespace cstr_utils {
/**
* @brief Return true if `str` is empty.
*/
bool is_empty(const char* str) { return str[0] == '\0'; }
/**
* @brief Implementation of `strcmp()`
*/
int8_t compare(const char* a, const char* b) {
uint32_t i = 0;
while (true) {
if (a[i] < b[i]) {
return -1;
} else if (a[i] > b[i]) {
return 1;
} else {
if (a[i] == '\0') {
return 0;
} else {
i++;
}
}
}
}
/**
* @brief Return true two strings have the same content.
*/
int8_t equal(const char* a, const char* b) { return compare(a, b) == 0; }
/**
* @brief Implementation of `strlen()`.
*/
uint32_t length(const char* str) {
uint32_t length = 0;
while (*str != '\0') {
length++;
str++;
}
return length;
}
/**
* @brief Copy a null-terminated string to a buffer with limited size and guaranteed null-termination.
*
* `dst_max_size` must be greater than 0, otherwise this function has undefined behavior.
*
* This function attempts to copy everything from `src` from `dst`, and *always* null-terminates `dst`.
*
* If the size of `dst` is too small, the final byte (`dst[dst_max_size - 1]`) of `dst` will be set to
* the null terminator.
*
* @param src String to copy from.
* @param dst Buffer to copy string to.
* @param dst_max_size
* Number of bytes of this buffer, including the space needed for the null terminator.
* Must be greater than 0.
* @return If `dst` is too small to contain everything in `src`.
*/
bool copy(const char* src, char* dst, uint32_t dst_max_size) {
for (uint32_t i = 0; i < dst_max_size; i++) {
bool is_last = i + 1 == dst_max_size;
if (is_last && src[i] != '\0') {
dst[i] = '\0';
return false;
}
if (src[i] == '\0') {
dst[i] = '\0';
return true;
}
dst[i] = src[i];
}
__builtin_unreachable();
}
} // namespace cstr_utils
} // namespace

View File

@ -0,0 +1,24 @@
#pragma once
#ifdef IRRT_DEBUG
#define IRRT_DEBUG_ASSERT
#define IRRT_DEBUG_ASSERT_BOOL true
#else
#define IRRT_DEBUG_ASSERT_BOOL false
#endif
#include <irrt/core.hpp>
#include <irrt/debug.hpp>
#include <irrt/exception.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/list.hpp>
#include <irrt/ndarray/array.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/broadcast.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/indexing.hpp>
#include <irrt/ndarray/iter.hpp>
#include <irrt/ndarray/product.hpp>
#include <irrt/ndarray/reshape.hpp>
#include <irrt/ndarray/transpose.hpp>
#include <irrt/util.hpp>

View File

@ -0,0 +1,25 @@
// This file will be compiled like a real C++ program,
// and we do have the luxury to use the standard libraries.
// That is if the nix flakes do not have issues... especially on msys2...
#include <cstdint>
#include <cstdio>
#include <cstdlib>
// Special macro to inform `#include <irrt/*>` that we are testing.
#define IRRT_TESTING
// Note that failure unit tests are not supported.
#include <test/test_core.hpp>
#include <test/test_ndarray_basic.hpp>
#include <test/test_ndarray_broadcast.hpp>
#include <test/test_ndarray_indexing.hpp>
int main() {
test::core::run();
test::ndarray_basic::run();
test::ndarray_indexing::run();
test::ndarray_broadcast::run();
return 0;
}

View File

@ -0,0 +1,11 @@
#pragma once
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <irrt_everything.hpp>
#include <test/util.hpp>
/*
Include this header for every test_*.cpp
*/

View File

@ -0,0 +1,16 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace core {
void test_int_exp() {
BEGIN_TEST();
assert_values_match(125L, __nac3_int_exp_impl<int64_t>(5, 3));
assert_values_match(3125L, __nac3_int_exp_impl<int64_t>(5, 5));
}
void run() { test_int_exp(); }
} // namespace core
} // namespace test

View File

@ -0,0 +1,30 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_basic {
void test_calc_size_from_shape_normal() {
// Test shapes with normal values
BEGIN_TEST();
int64_t shape[4] = {2, 3, 5, 7};
assert_values_match(
210L, ndarray::basic::util::calc_size_from_shape<int64_t>(4, shape));
}
void test_calc_size_from_shape_has_zero() {
// Test shapes with 0 in them
BEGIN_TEST();
int64_t shape[4] = {2, 0, 5, 7};
assert_values_match(
0L, ndarray::basic::util::calc_size_from_shape<int64_t>(4, shape));
}
void run() {
test_calc_size_from_shape_normal();
test_calc_size_from_shape_has_zero();
}
} // namespace ndarray_basic
} // namespace test

View File

@ -0,0 +1,127 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_broadcast {
void test_can_broadcast_shape() {
BEGIN_TEST();
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){3}, 5, (int32_t[]){1, 1, 1, 1, 3}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){3}, 2, (int32_t[]){3, 1}));
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){3}, 1, (int32_t[]){3}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){1}, 1, (int32_t[]){3}));
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){1}, 1, (int32_t[]){1}));
assert_values_match(
true, ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 3, (int32_t[]){256, 1, 3}));
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){3}));
assert_values_match(false,
ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){2}));
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){1}));
// In cases when the shapes contain zero(es)
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){0}, 1, (int32_t[]){1}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){0}, 1, (int32_t[]){2}));
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
4, (int32_t[]){0, 4, 0, 0}, 1, (int32_t[]){1}));
assert_values_match(
true, ndarray::broadcast::util::can_broadcast_shape_to(
4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 1, 1, 1}));
assert_values_match(
true, ndarray::broadcast::util::can_broadcast_shape_to(
4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 4, 1, 1}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 3}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 0}));
}
void test_ndarray_broadcast() {
/*
# array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
# >>> [[19.9 29.9 39.9 49.9]]
#
# array = np.broadcast_to(array, (2, 3, 4))
# >>> [[[19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]]
# >>> [[19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]]]
#
# assery array.strides == (0, 0, 8)
*/
BEGIN_TEST();
double in_data[4] = {19.9, 29.9, 39.9, 49.9};
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = {1, 4};
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {.data = (uint8_t*)in_data,
.itemsize = sizeof(double),
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides};
ndarray::basic::set_strides_by_shape(&ndarray);
const int32_t dst_ndims = 3;
int32_t dst_shape[dst_ndims] = {2, 3, 4};
int32_t dst_strides[dst_ndims] = {};
NDArray<int32_t> dst_ndarray = {
.ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides};
ndarray::broadcast::broadcast_to(&ndarray, &dst_ndarray);
assert_arrays_match(dst_ndims, ((int32_t[]){0, 0, 8}), dst_ndarray.strides);
assert_values_match(19.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 0}))));
assert_values_match(29.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 1}))));
assert_values_match(39.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 2}))));
assert_values_match(49.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 3}))));
assert_values_match(19.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 0}))));
assert_values_match(29.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 1}))));
assert_values_match(39.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 2}))));
assert_values_match(49.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 3}))));
assert_values_match(49.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){1, 2, 3}))));
}
void run() {
test_can_broadcast_shape();
test_ndarray_broadcast();
}
} // namespace ndarray_broadcast
} // namespace test

View File

@ -0,0 +1,165 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_indexing {
void test_normal_1() {
/*
Reference Python code:
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[-2:, 1::2]
# array([[ 5., 7.],
# [ 9., 11.]])
assert dst_ndarray.shape == (2, 2)
assert dst_ndarray.strides == (32, 16)
assert dst_ndarray[0, 0] == 5.0
assert dst_ndarray[0, 1] == 7.0
assert dst_ndarray[1, 0] == 9.0
assert dst_ndarray[1, 1] == 11.0
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
int64_t src_itemsize = sizeof(double);
const int64_t src_ndims = 2;
int64_t src_shape[src_ndims] = {3, 4};
int64_t src_strides[src_ndims] = {};
NDArray<int64_t> src_ndarray = {.data = (uint8_t *)src_data,
.itemsize = src_itemsize,
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int64_t dst_ndims = 2;
int64_t dst_shape[dst_ndims] = {999, 999}; // Empty values
int64_t dst_strides[dst_ndims] = {999, 999}; // Empty values
NDArray<int64_t> dst_ndarray = {.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
// Create the subscripts in `ndarray[-2::, 1::2]`
UserSlice subscript_1;
subscript_1.set_start(-2);
UserSlice subscript_2;
subscript_2.set_start(1);
subscript_2.set_step(2);
const int64_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
ndarray::indexing::index(num_indexes, indexes, &src_ndarray, &dst_ndarray);
int64_t expected_shape[dst_ndims] = {2, 2};
int64_t expected_strides[dst_ndims] = {32, 16};
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
// dst_ndarray[0, 0]
assert_values_match(5.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){0, 0})));
// dst_ndarray[0, 1]
assert_values_match(7.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){0, 1})));
// dst_ndarray[1, 0]
assert_values_match(9.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){1, 0})));
// dst_ndarray[1, 1]
assert_values_match(11.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){1, 1})));
}
void test_normal_2() {
/*
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[2, ::-2]
# array([11., 9.])
assert dst_ndarray.shape == (2,)
assert dst_ndarray.strides == (-16,)
assert dst_ndarray[0] == 11.0
assert dst_ndarray[1] == 9.0
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
int64_t src_itemsize = sizeof(double);
const int64_t src_ndims = 2;
int64_t src_shape[src_ndims] = {3, 4};
int64_t src_strides[src_ndims] = {};
NDArray<int64_t> src_ndarray = {.data = (uint8_t *)src_data,
.itemsize = src_itemsize,
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int64_t dst_ndims = 1;
int64_t dst_shape[dst_ndims] = {999}; // Empty values
int64_t dst_strides[dst_ndims] = {999}; // Empty values
NDArray<int64_t> dst_ndarray = {.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
// Create the subscripts in `ndarray[2, ::-2]`
int64_t subscript_1 = 2;
UserSlice subscript_2;
subscript_2.set_step(-2);
const int64_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
ndarray::indexing::index(num_indexes, indexes, &src_ndarray, &dst_ndarray);
int64_t expected_shape[dst_ndims] = {2};
int64_t expected_strides[dst_ndims] = {-16};
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
assert_values_match(11.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){0})));
assert_values_match(9.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){1})));
}
void run() {
test_normal_1();
test_normal_2();
}
} // namespace ndarray_indexing
} // namespace test

131
nac3core/irrt/test/util.hpp Normal file
View File

@ -0,0 +1,131 @@
#pragma once
#include <cstdio>
#include <cstdlib>
template <class T>
void print_value(const T& value);
template <>
void print_value(const bool& value) {
printf("%s", value ? "true" : "false");
}
template <>
void print_value(const int8_t& value) {
printf("%d", value);
}
template <>
void print_value(const int32_t& value) {
printf("%d", value);
}
template <>
void print_value(const int64_t& value) {
printf("%d", value);
}
template <>
void print_value(const uint8_t& value) {
printf("%u", value);
}
template <>
void print_value(const uint32_t& value) {
printf("%u", value);
}
template <>
void print_value(const uint64_t& value) {
printf("%d", value);
}
template <>
void print_value(const float& value) {
printf("%f", value);
}
template <>
void print_value(const double& value) {
printf("%f", value);
}
void __begin_test(const char* function_name, const char* file, int line) {
printf("######### Running %s @ %s:%d\n", function_name, file, line);
}
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
void test_fail() {
printf("[!] Test failed. Exiting with status code 1.\n");
exit(1);
}
template <typename T>
void debug_print_array(int len, const T* as) {
printf("[");
for (int i = 0; i < len; i++) {
if (i != 0) printf(", ");
print_value(as[i]);
}
printf("]");
}
void print_assertion_passed(const char* file, int line) {
printf("[*] Assertion passed on %s:%d\n", file, line);
}
void print_assertion_failed(const char* file, int line) {
printf("[!] Assertion failed on %s:%d\n", file, line);
}
void __assert_true(const char* file, int line, bool cond) {
if (cond) {
print_assertion_passed(file, line);
} else {
print_assertion_failed(file, line);
test_fail();
}
}
#define assert_true(cond) __assert_true(__FILE__, __LINE__, cond)
template <typename T>
void __assert_arrays_match(const char* file, int line, int len,
const T* expected, const T* got) {
if (arrays_match(len, expected, got)) {
print_assertion_passed(file, line);
} else {
print_assertion_failed(file, line);
printf("Expect = ");
debug_print_array(len, expected);
printf("\n");
printf(" Got = ");
debug_print_array(len, got);
printf("\n");
test_fail();
}
}
#define assert_arrays_match(len, expected, got) \
__assert_arrays_match(__FILE__, __LINE__, len, expected, got)
template <typename T>
void __assert_values_match(const char* file, int line, T expected, T got) {
if (expected == got) {
print_assertion_passed(file, line);
} else {
print_assertion_failed(file, line);
printf("Expect = ");
print_value(expected);
printf("\n");
printf(" Got = ");
print_value(got);
printf("\n");
test_fail();
}
}
#define assert_values_match(expected, got) \
__assert_values_match(__FILE__, __LINE__, expected, got)

File diff suppressed because it is too large Load Diff

View File

@ -1717,6 +1717,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(len, false), (len, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {

View File

@ -25,6 +25,7 @@ pub struct ConcreteFuncArg {
pub name: StrRef, pub name: StrRef,
pub ty: ConcreteType, pub ty: ConcreteType,
pub default_value: Option<SymbolValue>, pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -46,6 +47,7 @@ pub enum ConcreteTypeEnum {
TPrimitive(Primitive), TPrimitive(Primitive),
TTuple { TTuple {
ty: Vec<ConcreteType>, ty: Vec<ConcreteType>,
is_vararg_ctx: bool,
}, },
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
@ -102,8 +104,16 @@ impl ConcreteTypeStore {
.iter() .iter()
.map(|arg| ConcreteFuncArg { .map(|arg| ConcreteFuncArg {
name: arg.name, name: arg.name,
ty: self.from_unifier_type(unifier, primitives, arg.ty, cache), ty: if arg.is_vararg {
let tuple_ty = unifier
.add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true });
self.from_unifier_type(unifier, primitives, tuple_ty, cache)
} else {
self.from_unifier_type(unifier, primitives, arg.ty, cache)
},
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
}) })
.collect(), .collect(),
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache), ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
@ -158,11 +168,12 @@ impl ConcreteTypeStore {
cache.insert(ty, None); cache.insert(ty, None);
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple { TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple {
ty: ty ty: ty
.iter() .iter()
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache)) .map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(), .collect(),
is_vararg_ctx: *is_vararg_ctx,
}, },
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id, obj_id: *obj_id,
@ -248,11 +259,12 @@ impl ConcreteTypeStore {
*cache.get_mut(&cty).unwrap() = Some(ty); *cache.get_mut(&cty).unwrap() = Some(ty);
return ty; return ty;
} }
ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple { ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple {
ty: ty ty: ty
.iter() .iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache)) .map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(), .collect(),
is_vararg_ctx: *is_vararg_ctx,
}, },
ConcreteTypeEnum::TVirtual { ty } => { ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
@ -277,6 +289,7 @@ impl ConcreteTypeStore {
name: arg.name, name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache), ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: false,
}) })
.collect(), .collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache), ret: self.to_unifier_type(unifier, primitives, *ret, cache),

File diff suppressed because it is too large Load Diff

View File

@ -130,3 +130,62 @@ pub fn call_ldexp<'ctx>(
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap() .unwrap()
} }
/// Macro to generate `np_linalg` and `sp_linalg` functions
/// The function takes as input `NDArray` and returns ()
///
/// Arguments:
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
/// * (2/3/4): Number of `NDArray` that function takes as input
///
/// Note:
/// The operands and resulting `NDArray` are both passed as input to the funcion
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
/// The function changes the content of the output `NDArray` in-place
macro_rules! generate_linalg_extern_fn {
($fn_name:ident, $extern_fn:literal, 2) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
};
($fn_name:ident, $extern_fn:literal, 3) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $extern_fn:literal, 4) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: BasicValueEnum<'ctx>)*,
name: Option<&str>,
){
const FN_NAME: &str = $extern_fn;
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
}
};
}
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);

View File

@ -123,11 +123,45 @@ pub trait CodeGenerator {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> ) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {
gen_assign(self, ctx, target, value) gen_assign(self, ctx, target, value, value_ty)
}
/// Generate code for an assignment expression where LHS is a `"target_list"`.
///
/// See <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
fn gen_assign_target_list<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
targets: &Vec<Expr<Option<Type>>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign_target_list(self, ctx, targets, value, value_ty)
}
/// Generate code for an item assignment.
///
/// i.e., `target[key] = value`
fn gen_setitem<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
key: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_setitem(self, ctx, target, key, value, value_ty)
} }
/// Generate code for a while expression. /// Generate code for a while expression.

View File

@ -1,5 +1,12 @@
use crate::symbol_resolver::SymbolResolver;
use crate::typecheck::typedef::Type; use crate::typecheck::typedef::Type;
mod test;
use super::model::*;
use super::object::ndarray::broadcast::ShapeEntry;
use super::object::ndarray::indexing::{NDIndex, UserSlice};
use super::structure::{List, NDArray, NDIter};
use super::{ use super::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
@ -9,6 +16,8 @@ use super::{
}; };
use crate::codegen::classes::TypedArrayLikeAccessor; use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use function::{get_sizet_dependent_function_name, CallFunction};
use inkwell::values::BasicValue;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
context::Context, context::Context,
@ -414,14 +423,27 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
.unwrap(); .unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap(); let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap(); let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert(
generator, // TODO: Temporary fix. Rewrite `list_slice_assignment` later
cond, // Exception params should have been i64
"0:ValueError", {
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}", let param_model = IntModel(Int64);
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc, let src_slice_len =
); param_model.s_extend_or_bit_cast(generator, ctx, src_slice_len, "src_slice_len");
let dest_slice_len =
param_model.s_extend_or_bit_cast(generator, ctx, dest_slice_len, "dest_slice_len");
let dest_idx_2 = param_model.s_extend_or_bit_cast(generator, ctx, dest_idx.2, "dest_idx_2");
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len.value), Some(dest_slice_len.value), Some(dest_idx_2.value)],
ctx.current_loc,
);
}
let new_len = { let new_len = {
let args = vec![ let args = vec![
@ -798,6 +820,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(min_ndims, false), (min_ndims, false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
@ -872,7 +895,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
} }
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] /// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted /// containing the indices used for accessing `array` corresponding to the index of the broadcast
/// array `broadcast_idx`. /// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index< pub fn call_ndarray_calc_broadcast_index<
'ctx, 'ctx,
@ -927,3 +950,337 @@ pub fn call_ndarray_calc_broadcast_index<
Box::new(|_, v| v.into()), Box::new(|_, v| v.into()),
) )
} }
pub fn call_nac3_throw_dummy_error<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'_, '_>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_throw_dummy_error");
CallFunction::begin(generator, ctx, &name).returning_void();
}
/// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
pub fn setup_irrt_exceptions<'ctx>(
ctx: &'ctx Context,
module: &Module<'ctx>,
symbol_resolver: &dyn SymbolResolver,
) {
let exn_id_type = ctx.i32_type();
let errors = &[
("EXN_INDEX_ERROR", "0:IndexError"),
("EXN_VALUE_ERROR", "0:ValueError"),
("EXN_ASSERTION_ERROR", "0:AssertionError"),
("EXN_RUNTIME_ERROR", "0:RuntimeError"),
("EXN_TYPE_ERROR", "0:TypeError"),
];
for (irrt_name, symbol_name) in errors {
let exn_id = symbol_resolver.get_string_id(symbol_name);
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
let global = module.get_global(irrt_name).unwrap_or_else(|| {
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
});
global.set_initializer(&exn_id);
}
}
pub fn call_nac3_list_slice_assign<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dst: Ptr<'ctx, StructModel<List<IntModel<Byte>>>>,
src: Ptr<'ctx, StructModel<List<IntModel<Byte>>>>,
itemsize: Int<'ctx, SizeT>,
user_slice: Ptr<'ctx, StructModel<UserSlice>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_list_slice_assign");
CallFunction::begin(generator, ctx, &name)
.arg(dst)
.arg(src)
.arg(itemsize)
.arg(user_slice)
.returning_void();
}
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: Int<'ctx, SizeT>,
shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_util_assert_shape_no_negative",
);
CallFunction::begin(generator, ctx, &name).arg(ndims).arg(shape).returning_void();
}
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ndims: Int<'ctx, SizeT>,
ndarray_shape: Ptr<'ctx, IntModel<SizeT>>,
output_ndims: Int<'ctx, SizeT>,
output_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_util_assert_output_shape_same",
);
CallFunction::begin(generator, ctx, &name)
.arg(ndarray_ndims)
.arg(ndarray_shape)
.arg(output_ndims)
.arg(output_shape)
.returning_void();
}
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NDArray>>,
) -> Int<'ctx, SizeT> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
CallFunction::begin(generator, ctx, &name).arg(pndarray).returning_auto("size")
}
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NDArray>>,
) -> Int<'ctx, SizeT> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
CallFunction::begin(generator, ctx, &name).arg(pndarray).returning_auto("nbytes")
}
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NDArray>>,
) -> Int<'ctx, SizeT> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
CallFunction::begin(generator, ctx, &name).arg(pndarray).returning_auto("len")
}
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: Ptr<'ctx, StructModel<NDArray>>,
) -> Int<'ctx, Bool> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
CallFunction::begin(generator, ctx, &name).arg(ndarray_ptr).returning_auto("is_c_contiguous")
}
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NDArray>>,
index: Int<'ctx, SizeT>,
) -> Ptr<'ctx, IntModel<Byte>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
CallFunction::begin(generator, ctx, &name).arg(pndarray).arg(index).returning_auto("pelement")
}
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pdnarray: Ptr<'ctx, StructModel<NDArray>>,
) {
let name =
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
CallFunction::begin(generator, ctx, &name).arg(pdnarray).returning_void();
}
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: Ptr<'ctx, StructModel<NDArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NDArray>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
}
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
num_indexes: Int<'ctx, SizeT>,
indexes: Ptr<'ctx, StructModel<NDIndex>>,
src_ndarray: Ptr<'ctx, StructModel<NDArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NDArray>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_index");
CallFunction::begin(generator, ctx, &name)
.arg(num_indexes)
.arg(indexes)
.arg(src_ndarray)
.arg(dst_ndarray)
.returning_void();
}
pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: Ptr<'ctx, StructModel<NDArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NDArray>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to");
CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
}
pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
num_shape_entries: Int<'ctx, SizeT>,
shape_entries: Ptr<'ctx, StructModel<ShapeEntry>>,
dst_ndims: Int<'ctx, SizeT>,
dst_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes");
CallFunction::begin(generator, ctx, &name)
.arg(num_shape_entries)
.arg(shape_entries)
.arg(dst_ndims)
.arg(dst_shape)
.returning_void();
}
pub fn call_nac3_ndarray_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: Int<'ctx, SizeT>,
new_ndims: Int<'ctx, SizeT>,
new_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_resolve_and_check_new_shape",
);
CallFunction::begin(generator, ctx, &name)
.arg(size)
.arg(new_ndims)
.arg(new_shape)
.returning_void();
}
pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: Ptr<'ctx, StructModel<NDArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NDArray>>,
num_axes: Int<'ctx, SizeT>,
axes: Ptr<'ctx, IntModel<SizeT>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose");
CallFunction::begin(generator, ctx, &name)
.arg(src_ndarray)
.arg(dst_ndarray)
.arg(num_axes)
.arg(axes)
.returning_void();
}
#[allow(clippy::too_many_arguments)]
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a_ndims: Int<'ctx, SizeT>,
a_shape: Ptr<'ctx, IntModel<SizeT>>,
b_ndims: Int<'ctx, SizeT>,
b_shape: Ptr<'ctx, IntModel<SizeT>>,
final_ndims: Int<'ctx, SizeT>,
new_a_shape: Ptr<'ctx, IntModel<SizeT>>,
new_b_shape: Ptr<'ctx, IntModel<SizeT>>,
dst_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let name =
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
CallFunction::begin(generator, ctx, &name)
.arg(a_ndims)
.arg(a_shape)
.arg(b_ndims)
.arg(b_shape)
.arg(final_ndims)
.arg(new_a_shape)
.arg(new_b_shape)
.arg(dst_shape)
.returning_void();
}
pub fn call_nac3_ndarray_float64_matmul_at_least_2d<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a_ndarray: Ptr<'ctx, StructModel<NDArray>>,
b_ndarray: Ptr<'ctx, StructModel<NDArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NDArray>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_float64_matmul_at_least_2d",
);
CallFunction::begin(generator, ctx, &name)
.arg(a_ndarray)
.arg(b_ndarray)
.arg(dst_ndarray)
.returning_void();
}
pub fn call_nac3_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: Ptr<'ctx, StructModel<List<IntModel<Byte>>>>,
ndims: Int<'ctx, SizeT>,
shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_array_set_and_validate_list_shape",
);
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndims).arg(shape).returning_void();
}
pub fn call_nac3_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: Ptr<'ctx, StructModel<List<IntModel<Byte>>>>,
ndarray: Ptr<'ctx, StructModel<NDArray>>,
) {
let name =
get_sizet_dependent_function_name(generator, ctx, "__nac3_array_write_list_to_array");
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
}
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Ptr<'ctx, StructModel<NDIter>>,
ndarray: Ptr<'ctx, StructModel<NDArray>>,
indices: Ptr<'ctx, IntModel<SizeT>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
CallFunction::begin(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void();
}
pub fn call_nac3_nditer_has_next<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Ptr<'ctx, StructModel<NDIter>>,
) -> Int<'ctx, Bool> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_next");
CallFunction::begin(generator, ctx, &name).arg(iter).returning_auto("has_next")
}
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Ptr<'ctx, StructModel<NDIter>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
CallFunction::begin(generator, ctx, &name).arg(iter).returning_void();
}

View File

@ -0,0 +1,26 @@
#[cfg(test)]
mod tests {
use std::{path::Path, process::Command};
#[test]
fn run_irrt_test() {
assert!(
cfg!(feature = "test"),
"Please do `cargo test -F test` to compile `irrt_test.out` and run test"
);
let irrt_test_out_path = Path::new(concat!(env!("OUT_DIR"), "/irrt_test.out"));
let output = Command::new(irrt_test_out_path.to_str().unwrap()).output().unwrap();
if !output.status.success() {
eprintln!("irrt_test failed with status {}:", output.status);
eprintln!("====== stdout ======");
eprintln!("{}", String::from_utf8(output.stdout).unwrap());
eprintln!("====== stderr ======");
eprintln!("{}", String::from_utf8(output.stderr).unwrap());
eprintln!("====================");
panic!("irrt_test failed");
}
}
}

View File

@ -35,6 +35,40 @@ fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
unreachable!() unreachable!()
} }
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_start";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_end";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic) /// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic. /// intrinsic.
pub fn call_stacksave<'ctx>( pub fn call_stacksave<'ctx>(

View File

@ -1,7 +1,7 @@
use crate::{ use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, codegen::classes::{ListType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver}, symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -24,6 +24,7 @@ use inkwell::{
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
use model::*;
use nac3parser::ast::{Location, Stmt, StrRef}; use nac3parser::ast::{Location, Stmt, StrRef};
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
@ -32,8 +33,8 @@ use std::sync::{
Arc, Arc,
}; };
use std::thread; use std::thread;
use structure::{CSlice, Exception, NDArray};
pub mod builtin_fns;
pub mod classes; pub mod classes;
pub mod concrete_type; pub mod concrete_type;
pub mod expr; pub mod expr;
@ -41,8 +42,12 @@ pub mod extern_fns;
mod generator; mod generator;
pub mod irrt; pub mod irrt;
pub mod llvm_intrinsics; pub mod llvm_intrinsics;
pub mod model;
pub mod numpy; pub mod numpy;
pub mod numpy_new;
pub mod object;
pub mod stmt; pub mod stmt;
pub mod structure;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
@ -68,6 +73,16 @@ pub struct CodeGenLLVMOptions {
pub target: CodeGenTargetMachineOptions, pub target: CodeGenTargetMachineOptions,
} }
impl CodeGenLLVMOptions {
/// Creates a [`TargetMachine`] using the target options specified by this struct.
///
/// See [`Target::create_target_machine`].
#[must_use]
pub fn create_target_machine(&self) -> Option<TargetMachine> {
self.target.create_target_machine(self.opt_level)
}
}
/// Additional options for code generation for the target machine. /// Additional options for code generation for the target machine.
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub struct CodeGenTargetMachineOptions { pub struct CodeGenTargetMachineOptions {
@ -158,11 +173,11 @@ pub struct CodeGenContext<'ctx, 'a> {
pub registry: &'a WorkerRegistry, pub registry: &'a WorkerRegistry,
/// Cache for constant strings. /// Cache for constant strings.
pub const_strings: HashMap<String, BasicValueEnum<'ctx>>, pub const_strings: HashMap<String, Struct<'ctx, CSlice>>,
/// [`BasicBlock`] containing all `alloca` statements for the current function. /// [`BasicBlock`] containing all `alloca` statements for the current function.
pub init_bb: BasicBlock<'ctx>, pub init_bb: BasicBlock<'ctx>,
pub exception_val: Option<PointerValue<'ctx>>, pub exception_val: Option<Ptr<'ctx, StructModel<Exception>>>,
/// The header and exit basic blocks of a loop in this context. See /// The header and exit basic blocks of a loop in this context. See
/// <https://llvm.org/docs/LoopTerminology.html> for explanation of these terminology. /// <https://llvm.org/docs/LoopTerminology.html> for explanation of these terminology.
@ -338,6 +353,10 @@ impl WorkerRegistry {
let mut builder = context.create_builder(); let mut builder = context.create_builder();
let mut module = context.create_module(generator.get_name()); let mut module = context.create_module(generator.get_name());
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag( module.add_basic_value_flag(
"Debug Info Version", "Debug Info Version",
inkwell::module::FlagBehavior::Warning, inkwell::module::FlagBehavior::Warning,
@ -361,6 +380,10 @@ impl WorkerRegistry {
errors.insert(e); errors.insert(e);
// create a new empty module just to continue codegen and collect errors // create a new empty module just to continue codegen and collect errors
module = context.create_module(&format!("{}_recover", generator.get_name())); module = context.create_module(&format!("{}_recover", generator.get_name()));
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
} }
} }
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
@ -426,7 +449,7 @@ pub struct CodeGenTask {
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &mut G, generator: &G,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -471,12 +494,8 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let pndarray_model = PtrModel(StructModel(NDArray));
let element_type = get_llvm_type( pndarray_model.get_type(generator, ctx).as_basic_type_enum()
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
} }
_ => unreachable!( _ => unreachable!(
@ -520,8 +539,10 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
}; };
return ty; return ty;
} }
TTuple { ty } => { TTuple { ty, is_vararg_ctx } => {
// a struct with fields in the order present in the tuple // a struct with fields in the order present in the tuple
assert!(!is_vararg_ctx, "Tuples in vararg context must be instantiated with the correct number of arguments before calling get_llvm_type");
let fields = ty let fields = ty
.iter() .iter()
.map(|ty| { .map(|ty| {
@ -551,7 +572,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>( fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &mut G, generator: &G,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -589,6 +610,40 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
need_sret_impl(ty, true) need_sret_impl(ty, true)
} }
/// Returns the [`BasicTypeEnum`] representing a `va_list` struct for variadic arguments.
fn get_llvm_valist_type<'ctx>(ctx: &'ctx Context, triple: &TargetTriple) -> BasicTypeEnum<'ctx> {
let triple = TargetMachine::normalize_triple(triple);
let triple = triple.as_str().to_str().unwrap();
let arch = triple.split('-').next().unwrap();
let llvm_pi8 = ctx.i8_type().ptr_type(AddressSpace::default());
// Referenced from parseArch() in llvm/lib/Support/Triple.cpp
match arch {
"i386" | "i486" | "i586" | "i686" | "riscv32" => {
ctx.i8_type().ptr_type(AddressSpace::default()).into()
}
"amd64" | "x86_64" | "x86_64h" => {
let llvm_i32 = ctx.i32_type();
let va_list_tag = ctx.opaque_struct_type("struct.__va_list_tag");
va_list_tag.set_body(
&[llvm_i32.into(), llvm_i32.into(), llvm_pi8.into(), llvm_pi8.into()],
false,
);
va_list_tag.into()
}
"armv7" => {
let va_list = ctx.opaque_struct_type("struct.__va_list");
va_list.set_body(&[llvm_pi8.into()], false);
va_list.into()
}
triple => {
todo!("Unsupported platform for varargs: {triple}")
}
}
}
/// Implementation for generating LLVM IR for a function. /// Implementation for generating LLVM IR for a function.
pub fn gen_func_impl< pub fn gen_func_impl<
'ctx, 'ctx,
@ -646,43 +701,19 @@ pub fn gen_func_impl<
..primitives ..primitives
}; };
let mut type_cache: HashMap<_, _> = [ let cslice_model = StructModel(CSlice);
let pexn_model = PtrModel(StructModel(Exception));
let mut type_cache: HashMap<_, BasicTypeEnum<'ctx>> = [
(primitives.int32, context.i32_type().into()), (primitives.int32, context.i32_type().into()),
(primitives.int64, context.i64_type().into()), (primitives.int64, context.i64_type().into()),
(primitives.uint32, context.i32_type().into()), (primitives.uint32, context.i32_type().into()),
(primitives.uint64, context.i64_type().into()), (primitives.uint64, context.i64_type().into()),
(primitives.float, context.f64_type().into()), (primitives.float, context.f64_type().into()),
(primitives.bool, context.i8_type().into()), (primitives.bool, context.i8_type().into()),
(primitives.str, { (primitives.str, cslice_model.get_type(generator, context).into()),
let name = "str";
match module.get_struct_type(name) {
None => {
let str_type = context.opaque_struct_type("str");
let fields = [
context.i8_type().ptr_type(AddressSpace::default()).into(),
generator.get_size_type(context).into(),
];
str_type.set_body(&fields, false);
str_type.into()
}
Some(t) => t.as_basic_type_enum(),
}
}),
(primitives.range, RangeType::new(context).as_base_type().into()), (primitives.range, RangeType::new(context).as_base_type().into()),
(primitives.exception, { (primitives.exception, pexn_model.get_type(generator, context).into()),
let name = "Exception";
if let Some(t) = module.get_struct_type(name) {
t.ptr_type(AddressSpace::default()).as_basic_type_enum()
} else {
let exception = context.opaque_struct_type("Exception");
let int32 = context.i32_type().into();
let int64 = context.i64_type().into();
let str_ty = module.get_struct_type("str").unwrap().as_basic_type_enum();
let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64];
exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
}
}),
] ]
.iter() .iter()
.copied() .copied()
@ -700,6 +731,7 @@ pub fn gen_func_impl<
name: arg.name, name: arg.name,
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache), ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache),
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
}) })
.collect_vec(), .collect_vec(),
task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache), task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache),
@ -722,7 +754,10 @@ pub fn gen_func_impl<
let has_sret = ret_type.map_or(false, |ty| need_sret(ty)); let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
let mut params = args let mut params = args
.iter() .iter()
.filter(|arg| !arg.is_vararg)
.map(|arg| { .map(|arg| {
debug_assert!(!arg.is_vararg);
get_llvm_abi_type( get_llvm_abi_type(
context, context,
&module, &module,
@ -741,9 +776,12 @@ pub fn gen_func_impl<
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into()); params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
} }
debug_assert!(matches!(args.iter().filter(|arg| arg.is_vararg).count(), 0..=1));
let vararg_arg = args.iter().find(|arg| arg.is_vararg);
let fn_type = match ret_type { let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false), Some(ret_type) if !has_sret => ret_type.fn_type(&params, vararg_arg.is_some()),
_ => context.void_type().fn_type(&params, false), _ => context.void_type().fn_type(&params, vararg_arg.is_some()),
}; };
let symbol = &task.symbol_name; let symbol = &task.symbol_name;
@ -773,7 +811,9 @@ pub fn gen_func_impl<
let mut var_assignment = HashMap::new(); let mut var_assignment = HashMap::new();
let offset = u32::from(has_sret); let offset = u32::from(has_sret);
for (n, arg) in args.iter().enumerate() {
// Store non-vararg argument values into local variables
for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) {
let param = fn_val.get_nth_param((n as u32) + offset).unwrap(); let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
let local_type = get_llvm_type( let local_type = get_llvm_type(
context, context,
@ -806,6 +846,8 @@ pub fn gen_func_impl<
var_assignment.insert(arg.name, (alloca, None, 0)); var_assignment.insert(arg.name, (alloca, None, 0));
} }
// TODO: Save vararg parameters as list
let return_buffer = if has_sret { let return_buffer = if has_sret {
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value()) Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
} else { } else {
@ -1028,3 +1070,9 @@ fn gen_in_range_check<'ctx>(
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap() ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
} }
/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments
/// passed to the variadic function.
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
format!("__{}_va_count", &arg_name).into()
}

View File

@ -0,0 +1,40 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
};
use crate::codegen::CodeGenerator;
use super::*;
#[derive(Debug, Clone, Copy)]
pub struct AnyModel<'ctx>(pub BasicTypeEnum<'ctx>);
pub type Anything<'ctx> = Instance<'ctx, AnyModel<'ctx>>;
impl<'ctx> Model<'ctx> for AnyModel<'ctx> {
type Value = BasicValueEnum<'ctx>;
type Type = BasicTypeEnum<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> Self::Type {
self.0
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
_generator: &mut G,
_ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
if ty == self.0 {
Ok(())
} else {
Err(ModelError(format!("Expecting {}, but got {}", self.0, ty)))
}
}
}

View File

@ -0,0 +1,122 @@
use inkwell::{
context::Context,
types::{ArrayType, BasicType, BasicTypeEnum},
values::ArrayValue,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
/// A Model for an [`ArrayType`].
#[derive(Debug, Clone, Copy)]
pub struct ArrayModel<Element> {
pub len: u32,
pub element: Element,
}
pub type Array<'ctx, Element> = Instance<'ctx, ArrayModel<Element>>;
impl<'ctx, Element: Model<'ctx>> Model<'ctx> for ArrayModel<Element> {
type Value = ArrayValue<'ctx>;
type Type = ArrayType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.element.get_type(generator, ctx).array_type(self.len)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let BasicTypeEnum::ArrayType(ty) = ty else {
return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}")));
};
if ty.len() != self.len {
return Err(ModelError(format!(
"Expecting ArrayType with size {}, but got an ArrayType with size {}",
ty.len(),
self.len
)));
}
self.element
.check_type(generator, ctx, ty.get_element_type())
.map_err(|err| err.under_context("an ArrayType"))?;
Ok(())
}
}
impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, ArrayModel<Element>> {
/// Get the pointer to the `i`-th (0-based) array element.
pub fn at<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
i: u32,
name: &str,
) -> Ptr<'ctx, Element> {
assert!(i < self.model.0.len);
let zero = ctx.ctx.i32_type().const_zero();
let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
PtrModel(self.model.0.element).check_value(generator, ctx.ctx, ptr).unwrap()
}
}
/// Like [`ArrayModel`] but length is strongly-typed.
#[derive(Debug, Clone, Copy, Default)]
pub struct NArrayModel<const LEN: u32, Element>(pub Element);
pub type NArray<'ctx, const LEN: u32, Element> = Instance<'ctx, NArrayModel<LEN, Element>>;
impl<'ctx, const LEN: u32, Element: Model<'ctx>> NArrayModel<LEN, Element> {
/// Forget the `LEN` constant generic and get an [`ArrayModel`] with the same length.
pub fn forget_len(&self) -> ArrayModel<Element> {
ArrayModel { element: self.0, len: LEN }
}
}
impl<'ctx, const LEN: u32, Element: Model<'ctx>> Model<'ctx> for NArrayModel<LEN, Element> {
type Value = ArrayValue<'ctx>;
type Type = ArrayType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
// Convenient implementation
self.forget_len().get_type(generator, ctx)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
// Convenient implementation
self.forget_len().check_type(generator, ctx, ty)
}
}
impl<'ctx, const LEN: u32, Element: Model<'ctx>> Ptr<'ctx, NArrayModel<LEN, Element>> {
/// Get the pointer to the `i`-th (0-based) array element.
pub fn at_const<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
i: u32,
name: &str,
) -> Ptr<'ctx, Element> {
assert!(i < LEN);
let zero = ctx.ctx.i32_type().const_zero();
let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
PtrModel(self.model.0 .0).check_value(generator, ctx.ctx, ptr).unwrap()
}
}

View File

@ -0,0 +1,123 @@
use std::fmt;
use inkwell::{context::Context, types::*, values::*};
use super::*;
use crate::codegen::{CodeGenContext, CodeGenerator};
#[derive(Debug, Clone)]
pub struct ModelError(pub String);
impl ModelError {
pub(super) fn under_context(mut self, context: &str) -> Self {
self.0.push_str(" ... in ");
self.0.push_str(context);
self
}
}
pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
type Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>;
type Type: BasicType<'ctx>;
/// Return the [`BasicType`] of this model.
#[must_use]
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type;
/// Check if a [`BasicType`] is the same type of this model.
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError>;
/// Create an instance from a value with [`Instance::model`] being this model.
///
/// Caller must make sure the type of `value` and the type of this `model` are equivalent.
#[must_use]
fn believe_value(&self, value: Self::Value) -> Instance<'ctx, Self> {
Instance { model: *self, value }
}
/// Check if a [`BasicValue`]'s type is equivalent to the type of this model.
/// Wrap it into an [`Instance`] if it is.
fn check_value<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
value: V,
) -> Result<Instance<'ctx, Self>, ModelError> {
let value = value.as_basic_value_enum();
self.check_type(generator, ctx, value.get_type())
.map_err(|err| err.under_context(format!("the value {value:?}").as_str()))?;
let Ok(value) = Self::Value::try_from(value) else {
unreachable!("check_type() has bad implementation")
};
Ok(self.believe_value(value))
}
// Allocate a value on the stack and return its pointer.
fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
name: &str,
) -> Ptr<'ctx, Self> {
let pmodel = PtrModel(*self);
let p = ctx.builder.build_alloca(self.get_type(generator, ctx.ctx), name).unwrap();
pmodel.believe_value(p)
}
// Allocate an array on the stack and return its pointer.
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
len: IntValue<'ctx>,
name: &str,
) -> Ptr<'ctx, Self> {
let pmodel = PtrModel(*self);
let p =
ctx.builder.build_array_alloca(self.get_type(generator, ctx.ctx), len, name).unwrap();
pmodel.believe_value(p)
}
fn var_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> Result<Ptr<'ctx, Self>, String> {
let pmodel = PtrModel(*self);
let ty = self.get_type(generator, ctx.ctx).as_basic_type_enum();
let p = generator.gen_var_alloc(ctx, ty, name)?;
Ok(pmodel.believe_value(p))
}
fn array_var_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
len: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Result<Ptr<'ctx, Self>, String> {
// TODO: Remove ArraySliceValue
let pmodel = PtrModel(*self);
let ty = self.get_type(generator, ctx.ctx).as_basic_type_enum();
let p = generator.gen_array_var_alloc(ctx, ty, len, name)?;
Ok(pmodel.believe_value(PointerValue::from(p)))
}
}
#[derive(Debug, Clone, Copy)]
pub struct Instance<'ctx, M: Model<'ctx>> {
/// The model of this instance.
pub model: M,
/// The value of this instance.
///
/// Caller must make sure the type of `value` and the type of this `model` are equivalent,
/// down to having the same [`IntType::get_bit_width`] in case of [`IntType`] for example.
pub value: M::Value,
}

View File

@ -0,0 +1,88 @@
use std::fmt;
use inkwell::{context::Context, types::FloatType, values::FloatValue};
use crate::codegen::CodeGenerator;
use super::*;
pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx>;
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Float32;
#[derive(Debug, Clone, Copy, Default)]
pub struct Float64;
impl<'ctx> FloatKind<'ctx> for Float32 {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx> {
ctx.f32_type()
}
}
impl<'ctx> FloatKind<'ctx> for Float64 {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx> {
ctx.f64_type()
}
}
#[derive(Debug, Clone, Copy)]
pub struct AnyFloat<'ctx>(FloatType<'ctx>);
impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> FloatType<'ctx> {
self.0
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct FloatModel<N>(pub N);
pub type Float<'ctx, N> = Instance<'ctx, FloatModel<N>>;
impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for FloatModel<N> {
type Value = FloatValue<'ctx>;
type Type = FloatType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_float_type(generator, ctx)
}
fn check_type<T: inkwell::types::BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = FloatType::try_from(ty) else {
return Err(ModelError(format!("Expecting FloatType, but got {ty:?}")));
};
let exp_ty = self.0.get_float_type(generator, ctx);
// TODO: Inkwell does not have get_bit_width for FloatType?
// TODO: Quick hack for now, but does this actually work?
if ty != exp_ty {
return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}")));
}
Ok(())
}
}

View File

@ -0,0 +1,125 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
types::{BasicMetadataTypeEnum, BasicType, FunctionType},
values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue},
};
use itertools::Itertools;
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
#[must_use]
pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'_, '_>,
name: &str,
) -> String {
let mut name = name.to_owned();
match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => {}
64 => name.push_str("64"),
bit_width => {
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
}
}
name
}
#[derive(Debug, Clone, Copy)]
struct Arg<'ctx> {
ty: BasicMetadataTypeEnum<'ctx>,
val: BasicMetadataValueEnum<'ctx>,
}
/// A structure to construct & call an LLVM function.
///
/// This is a helper to reduce IRRT Inkwell function call boilerplate
// TODO: Remove the lifetimes somehow? There is 4 of them.
pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> {
generator: &'d mut G,
ctx: &'b CodeGenContext<'ctx, 'a>,
/// Function name
name: &'c str,
/// Call arguments
args: Vec<Arg<'ctx>>,
/// LLVM function Attributes
attrs: Vec<&'static str>,
}
impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> {
pub fn begin(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self {
CallFunction { generator, ctx, name, args: Vec::new(), attrs: Vec::new() }
}
/// Push a list of LLVM function attributes to the function declaration.
#[must_use]
pub fn attrs(mut self, attrs: Vec<&'static str>) -> Self {
self.attrs = attrs;
self
}
/// Push a call argument to the function call.
#[allow(clippy::needless_pass_by_value)]
#[must_use]
pub fn arg<M: Model<'ctx>>(mut self, arg: Instance<'ctx, M>) -> Self {
let arg = Arg {
ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(),
val: arg.value.as_basic_value_enum().into(),
};
self.args.push(arg);
self
}
/// Call the function and expect the function to return a value of type of `return_model`.
#[must_use]
pub fn returning<M: Model<'ctx>>(self, name: &str, return_model: M) -> Instance<'ctx, M> {
let ret_ty = return_model.get_type(self.generator, self.ctx.ctx);
let ret = self.get_function(|tys| ret_ty.fn_type(tys, false), name);
let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work
let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work
ret
}
/// Like [`CallFunction::returning_`] but `return_model` is automatically inferred.
#[must_use]
pub fn returning_auto<M: Model<'ctx> + Default>(self, name: &str) -> Instance<'ctx, M> {
self.returning(name, M::default())
}
/// Call the function and expect the function to return a void-type.
pub fn returning_void(self) {
let ret_ty = self.ctx.ctx.void_type();
let _ = self.get_function(|tys| ret_ty.fn_type(tys, false), "");
}
fn get_function<F>(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx>
where
F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>,
{
// Get the LLVM function.
let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| {
// Declare the function if it doesn't exist.
let tys = self.args.iter().map(|arg| arg.ty).collect_vec();
let func_type = make_fn_type(&tys);
let func = self.ctx.module.add_function(self.name, func_type, None);
for attr in &self.attrs {
func.add_attribute(
AttributeLoc::Function,
self.ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
let vals = self.args.iter().map(|arg| arg.val).collect_vec();
self.ctx.builder.build_call(func, &vals, return_value_name).unwrap()
}
}

View File

@ -0,0 +1,275 @@
use std::fmt;
use inkwell::{context::Context, types::IntType, values::IntValue, IntPredicate};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
pub trait IntKind<'ctx>: fmt::Debug + Clone + Copy {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx>;
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Bool;
#[derive(Debug, Clone, Copy, Default)]
pub struct Byte;
#[derive(Debug, Clone, Copy, Default)]
pub struct Int32;
#[derive(Debug, Clone, Copy, Default)]
pub struct Int64;
#[derive(Debug, Clone, Copy, Default)]
pub struct SizeT;
impl<'ctx> IntKind<'ctx> for Bool {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.bool_type()
}
}
impl<'ctx> IntKind<'ctx> for Byte {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.i8_type()
}
}
impl<'ctx> IntKind<'ctx> for Int32 {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.i32_type()
}
}
impl<'ctx> IntKind<'ctx> for Int64 {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.i64_type()
}
}
impl<'ctx> IntKind<'ctx> for SizeT {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
generator.get_size_type(ctx)
}
}
#[derive(Debug, Clone, Copy)]
pub struct AnyInt<'ctx>(pub IntType<'ctx>);
impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> IntType<'ctx> {
self.0
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct IntModel<N>(pub N);
pub type Int<'ctx, N> = Instance<'ctx, IntModel<N>>;
impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for IntModel<N> {
type Value = IntValue<'ctx>;
type Type = IntType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_int_type(generator, ctx)
}
fn check_type<T: inkwell::types::BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = IntType::try_from(ty) else {
return Err(ModelError(format!("Expecting IntType, but got {ty:?}")));
};
let exp_ty = self.0.get_int_type(generator, ctx);
if ty.get_bit_width() != exp_ty.get_bit_width() {
return Err(ModelError(format!(
"Expecting IntType to have {} bit(s), but got {} bit(s)",
exp_ty.get_bit_width(),
ty.get_bit_width()
)));
}
Ok(())
}
}
impl<'ctx, N: IntKind<'ctx>> IntModel<N> {
pub fn constant<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
value: u64,
) -> Int<'ctx, N> {
let value = self.get_type(generator, ctx).const_int(value, false);
self.believe_value(value)
}
pub fn const_0<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, N> {
self.constant(generator, ctx, 0)
}
pub fn const_1<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, N> {
self.constant(generator, ctx, 1)
}
pub fn const_all_1s<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, N> {
let value = self.get_type(generator, ctx).const_all_ones();
self.believe_value(value)
}
pub fn s_extend_or_bit_cast<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx
.builder
.build_int_s_extend_or_bit_cast(value, self.get_type(generator, ctx.ctx), name)
.unwrap();
self.believe_value(value)
}
pub fn truncate<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
name: &str,
) -> Int<'ctx, N> {
let value =
ctx.builder.build_int_truncate(value, self.get_type(generator, ctx.ctx), name).unwrap();
self.believe_value(value)
}
}
impl IntModel<Bool> {
#[must_use]
pub fn const_false<'ctx, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, Bool> {
self.constant(generator, ctx, 0)
}
#[must_use]
pub fn const_true<'ctx, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, Bool> {
self.constant(generator, ctx, 1)
}
}
impl<'ctx, N: IntKind<'ctx>> Int<'ctx, N> {
pub fn s_extend_or_bit_cast<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
name: &str,
) -> Int<'ctx, NewN> {
IntModel(to_int_kind).s_extend_or_bit_cast(generator, ctx, self.value, name)
}
pub fn truncate<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
name: &str,
) -> Int<'ctx, NewN> {
IntModel(to_int_kind).truncate(generator, ctx, self.value, name)
}
#[must_use]
pub fn add(
&self,
ctx: &CodeGenContext<'ctx, '_>,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx.builder.build_int_add(self.value, other.value, name).unwrap();
self.model.believe_value(value)
}
#[must_use]
pub fn sub(
&self,
ctx: &CodeGenContext<'ctx, '_>,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx.builder.build_int_sub(self.value, other.value, name).unwrap();
self.model.believe_value(value)
}
#[must_use]
pub fn mul(
&self,
ctx: &CodeGenContext<'ctx, '_>,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx.builder.build_int_mul(self.value, other.value, name).unwrap();
self.model.believe_value(value)
}
pub fn compare(
&self,
ctx: &CodeGenContext<'ctx, '_>,
op: IntPredicate,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
let value = ctx.builder.build_int_compare(op, self.value, other.value, name).unwrap();
bool_model.believe_value(value)
}
}

View File

@ -0,0 +1,17 @@
mod any;
mod array;
mod core;
mod float;
pub mod function;
mod int;
mod ptr;
mod structure;
pub mod util;
pub use any::*;
pub use array::*;
pub use core::*;
pub use float::*;
pub use int::*;
pub use ptr::*;
pub use structure::*;

View File

@ -0,0 +1,145 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
#[derive(Debug, Clone, Copy, Default)]
pub struct PtrModel<Element>(pub Element);
pub type Ptr<'ctx, Element> = Instance<'ctx, PtrModel<Element>>;
impl<'ctx, Element: Model<'ctx>> Model<'ctx> for PtrModel<Element> {
type Value = PointerValue<'ctx>;
type Type = PointerType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_type(generator, ctx).ptr_type(AddressSpace::default())
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = PointerType::try_from(ty) else {
return Err(ModelError(format!("Expecting PointerType, but got {ty:?}")));
};
let elem_ty = ty.get_element_type();
let Ok(elem_ty) = BasicTypeEnum::try_from(elem_ty) else {
return Err(ModelError(format!(
"Expecting pointer element type to be a BasicTypeEnum, but got {elem_ty:?}"
)));
};
// TODO: inkwell `get_element_type()` will be deprecated.
// Remove the check for `get_element_type()` when the time comes.
self.0
.check_type(generator, ctx, elem_ty)
.map_err(|err| err.under_context("a PointerType"))?;
Ok(())
}
}
impl<'ctx, Element: Model<'ctx>> PtrModel<Element> {
/// Return a ***constant*** nullptr.
pub fn nullptr<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Ptr<'ctx, Element> {
let ptr = self.get_type(generator, ctx).const_null();
self.believe_value(ptr)
}
/// Cast a pointer into this model with [`inkwell::builder::Builder::build_pointer_cast`]
pub fn pointer_cast<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
ptr: PointerValue<'ctx>,
name: &str,
) -> Ptr<'ctx, Element> {
let ptr =
ctx.builder.build_pointer_cast(ptr, self.get_type(generator, ctx.ctx), name).unwrap();
self.believe_value(ptr)
}
}
impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, Element> {
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`].
#[must_use]
pub fn offset<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
offset: IntValue<'ctx>,
name: &str,
) -> Ptr<'ctx, Element> {
let new_ptr =
unsafe { ctx.builder.build_in_bounds_gep(self.value, &[offset], name).unwrap() };
self.model.check_value(generator, ctx.ctx, new_ptr).unwrap()
}
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset.
#[must_use]
pub fn offset_const<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
offset: u64,
name: &str,
) -> Ptr<'ctx, Element> {
let offset = ctx.ctx.i32_type().const_int(offset, false);
self.offset(generator, ctx, offset, name)
}
/// Load the value with [`inkwell::builder::Builder::build_load`].
pub fn load<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
name: &str,
) -> Instance<'ctx, Element> {
let value = ctx.builder.build_load(self.value, name).unwrap();
self.model.0.check_value(generator, ctx.ctx, value).unwrap() // If unwrap() panics, there is a logic error.
}
/// Store a value with [`inkwell::builder::Builder::build_store`].
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: Instance<'ctx, Element>) {
ctx.builder.build_store(self.value, value.value).unwrap();
}
/// Return a casted pointer of element type `NewElement` with [`inkwell::builder::Builder::build_pointer_cast`].
pub fn pointer_cast<NewElement: Model<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
new_model: NewElement,
name: &str,
) -> Ptr<'ctx, NewElement> {
PtrModel(new_model).pointer_cast(generator, ctx, self.value, name)
}
/// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`].
pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
let value = ctx.builder.build_is_null(self.value, name).unwrap();
bool_model.believe_value(value)
}
/// Check if the pointer is not null with [`inkwell::builder::Builder::build_is_not_null`].
pub fn is_not_null(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
let value = ctx.builder.build_is_not_null(self.value, name).unwrap();
bool_model.believe_value(value)
}
}

View File

@ -0,0 +1,222 @@
use std::fmt;
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, StructType},
values::StructValue,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
#[derive(Debug, Clone, Copy)]
pub struct GepField<M> {
pub gep_index: u64,
pub name: &'static str,
pub model: M,
}
pub trait FieldTraversal<'ctx> {
type Out<M>;
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M>;
/// Like [`FieldTraversal::visit`] but [`Model`] is automatically inferred from [`Default`] trait.
fn add_auto<M: Model<'ctx> + Default>(&mut self, name: &'static str) -> Self::Out<M> {
self.add(name, M::default())
}
}
pub struct GepFieldTraversal {
gep_index_counter: u64,
}
impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal {
type Out<M> = GepField<M>;
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M> {
let gep_index = self.gep_index_counter;
self.gep_index_counter += 1;
Self::Out { gep_index, name, model }
}
}
struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
generator: &'a G,
ctx: &'ctx Context,
field_types: Vec<BasicTypeEnum<'ctx>>,
}
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> {
type Out<M> = ();
fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Out<M> {
let t = model.get_type(self.generator, self.ctx).as_basic_type_enum();
self.field_types.push(t);
}
}
struct CheckTypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
generator: &'a mut G,
ctx: &'ctx Context,
index: u32,
scrutinee: StructType<'ctx>,
errors: Vec<ModelError>,
}
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx>
for CheckTypeFieldTraversal<'ctx, 'a, G>
{
type Out<M> = ();
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M> {
let i = self.index;
self.index += 1;
if let Some(t) = self.scrutinee.get_field_type_at_index(i) {
if let Err(err) = model.check_type(self.generator, self.ctx, t) {
self.errors.push(err.under_context(format!("field #{i} '{name}'").as_str()));
}
} // Otherwise, it will be caught
}
}
pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy {
type Fields<F: FieldTraversal<'ctx>>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F>;
fn fields(&self) -> Self::Fields<GepFieldTraversal> {
self.traverse_fields(&mut GepFieldTraversal { gep_index_counter: 0 })
}
fn get_struct_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> StructType<'ctx> {
let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() };
self.traverse_fields(&mut traversal);
ctx.struct_type(&traversal.field_types, false)
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct StructModel<S>(pub S);
pub type Struct<'ctx, S> = Instance<'ctx, StructModel<S>>;
impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for StructModel<S> {
type Value = StructValue<'ctx>;
type Type = StructType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_struct_type(generator, ctx)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = StructType::try_from(ty) else {
return Err(ModelError(format!("Expecting StructType, but got {ty:?}")));
};
let mut traversal =
CheckTypeFieldTraversal { generator, ctx, index: 0, errors: Vec::new(), scrutinee: ty };
self.0.traverse_fields(&mut traversal);
let exp_num_fields = traversal.index;
let got_num_fields = u32::try_from(ty.get_field_types().len()).unwrap();
if exp_num_fields != got_num_fields {
return Err(ModelError(format!(
"Expecting StructType with {exp_num_fields} field(s), but got {got_num_fields}"
)));
}
if !traversal.errors.is_empty() {
return Err(traversal.errors[0].clone()); // TODO: Return other errors as well
}
Ok(())
}
}
impl<'ctx, S: StructKind<'ctx>> Struct<'ctx, S> {
pub fn get_field<G: CodeGenerator + ?Sized, M, GetField>(
&self,
generator: &mut G,
ctx: &'ctx Context,
get_field: GetField,
) -> Instance<'ctx, M>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
let field = get_field(self.model.0.fields());
let val = self.value.get_field_at_index(field.gep_index as u32).unwrap();
field.model.check_value(generator, ctx, val).unwrap()
}
}
impl<'ctx, S: StructKind<'ctx>> Ptr<'ctx, StructModel<S>> {
pub fn gep<M, GetField>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
get_field: GetField,
) -> Ptr<'ctx, M>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
let field = get_field(self.model.0 .0.fields());
let llvm_i32 = ctx.ctx.i32_type(); // i64 would segfault
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.value,
&[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)],
field.name,
)
.unwrap()
};
let ptr_model = PtrModel(field.model);
ptr_model.believe_value(ptr)
}
/// Convenience function equivalent to `.gep(...).load(...)`.
pub fn get<M, GetField, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
get_field: GetField,
name: &str,
) -> Instance<'ctx, M>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
self.gep(ctx, get_field).load(generator, ctx, name)
}
/// Convenience function equivalent to `.gep(...).store(...)`.
pub fn set<M, GetField>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
get_field: GetField,
value: Instance<'ctx, M>,
) where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
self.gep(ctx, get_field).store(ctx, value);
}
}
// TODO: Add an opaque struct type?

View File

@ -0,0 +1,62 @@
use inkwell::{types::BasicType, values::IntValue};
/// `llvm.memcpy` but under the [`Model`] abstraction
use crate::codegen::{
llvm_intrinsics::call_memcpy_generic,
stmt::{gen_for_callback_incrementing, BreakContinueHooks},
CodeGenContext, CodeGenerator,
};
use super::*;
/// Convenience function.
///
/// Like [`call_memcpy_generic`] but with model abstractions and `is_volatile` set to `false`.
pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_array: Ptr<'ctx, Item>,
src_array: Ptr<'ctx, Item>,
num_items: IntValue<'ctx>,
) {
let itemsize = Item::default().get_type(generator, ctx.ctx).size_of().unwrap();
let totalsize = ctx.builder.build_int_mul(itemsize, num_items, "totalsize").unwrap(); // TODO: Int types may not match.
let is_volatile = ctx.ctx.bool_type().const_zero();
call_memcpy_generic(ctx, dst_array.value, src_array.value, totalsize, is_volatile);
}
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
/// The [`IntKind`] is automatically inferred.
pub fn gen_for_model_auto<'ctx, 'a, G, F, I>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
start: Int<'ctx, I>,
stop: Int<'ctx, I>,
step: Int<'ctx, I>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Int<'ctx, I>,
) -> Result<(), String>,
I: IntKind<'ctx> + Default,
{
let int_model = IntModel(I::default());
gen_for_callback_incrementing(
generator,
ctx,
None,
start.value,
(stop.value, false),
|g, ctx, hooks, i| {
let i = int_model.believe_value(i);
body(g, ctx, hooks, i)
},
step.value,
)
}

View File

@ -26,12 +26,15 @@ use crate::{
typedef::{FunSignature, Type, TypeEnum}, typedef::{FunSignature, Type, TypeEnum},
}, },
}; };
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{ use inkwell::{
types::BasicType, types::BasicType,
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
values::BasicValue,
};
use nac3parser::ast::{Operator, StrRef}; use nac3parser::ast::{Operator, StrRef};
/// Creates an uninitialized `NDArray` instance. /// Creates an uninitialized `NDArray` instance.
@ -86,6 +89,7 @@ where
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(shape_len, false), (shape_len, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -131,6 +135,7 @@ where
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(shape_len, false), (shape_len, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -157,7 +162,7 @@ where
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. /// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
@ -252,7 +257,7 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into() ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "") ctx.gen_string(generator, "").value.into()
} else { } else {
unreachable!() unreachable!()
} }
@ -280,7 +285,7 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into() ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1") ctx.gen_string(generator, "1").value.into()
} else { } else {
unreachable!() unreachable!()
} }
@ -382,6 +387,7 @@ where
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(ndarray_num_elems, false), (ndarray_num_elems, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -703,11 +709,12 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
gen_for_range_callback( gen_for_range_callback(
generator, generator,
ctx, ctx,
None,
true, true,
|_, _| Ok(llvm_usize.const_zero()), |_, _| Ok(llvm_usize.const_zero()),
(|_, ctx| Ok(src_lst.load_size(ctx, None)), false), (|_, ctx| Ok(src_lst.load_size(ctx, None)), false),
|_, _| Ok(llvm_usize.const_int(1, false)), |_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, i| { |generator, ctx, _, i| {
let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); let offset = ctx.builder.build_int_mul(stride, i, "").unwrap();
let dst_ptr = let dst_ptr =
@ -943,11 +950,12 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
gen_for_range_callback( gen_for_range_callback(
generator, generator,
ctx, ctx,
None,
true, true,
|_, _| Ok(llvm_usize.const_zero()), |_, _| Ok(llvm_usize.const_zero()),
(|_, _| Ok(stop), false), (|_, _| Ok(stop), false),
|_, _| Ok(llvm_usize.const_int(1, false)), |_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, _| { |generator, ctx, _, _| {
let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into())
.ptr_type(AddressSpace::default()); .ptr_type(AddressSpace::default());
@ -1086,13 +1094,17 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
// If there are no (remaining) slice expressions, memcpy the entire dimension // If there are no (remaining) slice expressions, memcpy the entire dimension
if slices.is_empty() { if slices.is_empty() {
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
let stride = call_ndarray_calc_size( let stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&src_arr.dim_sizes(), &src_arr.dim_sizes(),
(Some(llvm_usize.const_int(dim, false)), None), (Some(llvm_usize.const_int(dim, false)), None),
); );
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); let stride =
ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap();
let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap();
call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero());
@ -1126,11 +1138,12 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
gen_for_range_callback( gen_for_range_callback(
generator, generator,
ctx, ctx,
None,
false, false,
|_, _| Ok(start), |_, _| Ok(start),
(|_, _| Ok(stop), true), (|_, _| Ok(stop), true),
|_, _| Ok(step), |_, _| Ok(step),
|generator, ctx, src_i| { |generator, ctx, _, src_i| {
// Calculate the offset of the active slice // Calculate the offset of the active slice
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
let dst_i = let dst_i =
@ -1243,6 +1256,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_int(slices.len() as u64, false), llvm_usize.const_int(slices.len() as u64, false),
(this.load_ndims(ctx), false), (this.load_ndims(ctx), false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
@ -1647,6 +1661,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_i32.const_zero(), llvm_i32.const_zero(),
(common_dim, false), (common_dim, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -2014,3 +2029,493 @@ pub fn gen_ndarray_fill<'ctx>(
Ok(()) Ok(())
} }
/// Generates LLVM IR for `ndarray.transpose`.
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_transpose";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
// Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&n1,
|_, ctx, n| Ok(n.load_ndims(ctx)),
|generator, ctx, n, idx| {
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
let new_idx = ctx
.builder
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
.unwrap();
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) }
},
)
.unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
ctx.builder.build_store(rem_idx, idx).unwrap();
// Incrementally calculate the new index in the transposed array
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
// The formula used for indexing is:
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1.load_ndims(ctx), false),
|generator, ctx, _, ndim| {
let ndim_rev =
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
let ndim_rev = ctx
.builder
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
.unwrap();
let dim = unsafe {
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None)
};
let rem_idx_val =
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
let new_idx_val =
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
let add_component =
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
let rem_idx_val =
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
let new_idx_val =
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
///
/// * `x1` - `NDArray` to reshape.
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
shape: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_reshape";
let (x1_ty, x1) = x1;
let (_, shape) = shape;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
// Check for -1 in dimensions
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_list.load_size(ctx, None), false),
|generator, ctx, _, idx| {
let ele =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
ele,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_neg_value =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_neg_value = ctx
.builder
.build_int_add(
num_neg_value,
llvm_usize.const_int(1, false),
"",
)
.unwrap();
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
Ok(None)
},
|_, ctx| {
let acc_value =
ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_value =
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
ctx.builder.build_store(acc, acc_value).unwrap();
Ok(None)
},
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
// Generate the output shape by filling -1 with `rem`
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
let ndims = shape_tuple.get_type().count_fields();
// Check for -1 in dims
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_negs =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_negs = ctx
.builder
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
.unwrap();
ctx.builder.build_store(num_neg, num_negs).unwrap();
Ok(None)
},
|_, ctx| {
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(None)
},
)?;
}
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize);
// Reconstruct shape filling negatives with rem
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
let dim = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
let shape_int = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
shape_int,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(n_sz)),
|_, ctx| {
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
},
)?
.unwrap()
.into_int_value();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => unreachable!(),
}
.unwrap();
// Only allow one dimension to be negative
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"can only specify one unknown dimension",
[None, None, None],
ctx.current_loc,
);
// The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError",
"cannot reshape array of size {0} into provided shape of size {1}",
[Some(n_sz), Some(out_sz), None],
ctx.current_loc,
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul`
///
/// The input `NDArray` are flattened and treated as 1D
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_dot";
let (x1_ty, x1) = x1;
let (_, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
match (x1, x2) {
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
"0:ValueError",
"shapes ({0}), ({1}) not aligned",
[Some(n1_sz), Some(n2_sz), None],
ctx.current_loc,
);
let identity =
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1_sz, false),
|generator, ctx, _, idx| {
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
let product = match elem1 {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_mul(e1, elem2.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_mul(e1, elem2.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_add(e1, product.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_add(e1, product.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap();
Ok(acc_val)
}
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
_ => unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
),
}
}

View File

@ -0,0 +1,210 @@
// TODO: Replace numpy.rs
use inkwell::values::{BasicValue, BasicValueEnum};
use nac3parser::ast::StrRef;
use crate::{
codegen::object::{ndarray::scalar::split_scalar_or_ndarray, tuple::TupleObject},
symbol_resolver::ValueEnum,
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId},
typecheck::typedef::{FunSignature, Type},
};
use super::{
irrt::call_nac3_ndarray_util_assert_shape_no_negative,
model::*,
object::{
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
AnyObject,
},
CodeGenContext, CodeGenerator,
};
/// Generates LLVM IR for `np.broadcast_to`.
pub fn gen_ndarray_broadcast_to<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 input
let input_ty = fun.0.args[0].ty;
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
let input = AnyObject { ty: input_ty, value: input };
// Parse argument #2 shape
let shape_ty = fun.0.args[1].ty;
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Define models
let sizet_model = IntModel(SizeT);
// Extract broadcast_ndims, this is the only way to get the
// ndims of the ndarray result statically.
let (_, broadcast_ndims_ty) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let broadcast_ndims = extract_ndims(&ctx.unifier, broadcast_ndims_ty);
// Process `input`
let in_ndarray = split_scalar_or_ndarray(generator, ctx, input).as_ndarray(generator, ctx);
// Process `shape`
let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape);
// NOTE: shape.size should equal to `broadcasted_ndims`.
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
call_nac3_ndarray_util_assert_shape_no_negative(
generator,
ctx,
broadcast_ndims_llvm,
broadcast_shape,
);
// Create broadcast view
let broadcast_ndarray =
in_ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape);
Ok(broadcast_ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.reshape`.
pub fn gen_ndarray_reshape<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 input
let input_ty = fun.0.args[0].ty;
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
let input = AnyObject { ty: input_ty, value: input };
// Parse argument #2 shape
let shape_ty = fun.0.args[1].ty;
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Extract reshaped_ndims
let (_, reshaped_ndims_ty) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let reshaped_ndims = extract_ndims(&ctx.unifier, reshaped_ndims_ty);
// Process `input`
let in_ndarray = split_scalar_or_ndarray(generator, ctx, input).as_ndarray(generator, ctx);
// Process the shape input from user and resolve negative indices.
// The resulting `new_shape`'s size should be equal to reshaped_ndims.
// This is ensured by the typechecker.
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape);
let reshaped_ndarray = in_ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape);
Ok(reshaped_ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.arange`.
pub fn gen_ndarray_arange<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse argument #1 len
let input_ty = fun.0.args[0].ty;
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?.into_int_value();
// Implementation
let input_dim = IntModel(SizeT).s_extend_or_bit_cast(generator, ctx, input, "input_dim");
let ndarray = NDArrayObject::from_np_arange(generator, ctx, input_dim);
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.size`.
pub fn gen_ndarray_size<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let size = ndarray.size(generator, ctx).truncate(generator, ctx, Int32, "size");
Ok(size.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.shape`.
pub fn gen_ndarray_shape<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse argument #1 ndarray
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Process ndarray
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
Ok(ndarray.make_shape_tuple(generator, ctx).value.as_basic_value_enum())
}
/// Generates LLVM IR for `<ndarray>.strides`.
pub fn gen_ndarray_strides<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
// TODO: Code duplication: This function looks exactly like `gen_ndarray_shapes`.
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse argument #1 ndarray
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Process ndarray
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
for i in 0..ndarray.ndims {
let dim = ndarray
.instance
.get(generator, ctx, |f| f.strides, "")
.offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects
.push(AnyObject { ty: ctx.primitives.int32, value: dim.value.as_basic_value_enum() });
}
let strides = TupleObject::create(generator, ctx, objects, "strides");
Ok(strides.value.as_basic_value_enum())
}

View File

@ -0,0 +1,121 @@
use crate::{
codegen::{
irrt::{call_nac3_list_slice_assign, list_slice_assignment},
model::*,
object::ndarray::indexing::UserSlice,
structure::List,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
};
use super::{ndarray::indexing::RustUserSlice, AnyObject};
/// A NAC3 Python List object.
#[derive(Debug, Clone, Copy)]
pub struct ListObject<'ctx> {
/// Typechecker type of the list items
pub item_type: Type,
pub instance: Ptr<'ctx, StructModel<List<AnyModel<'ctx>>>>,
}
impl<'ctx> ListObject<'ctx> {
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> Self {
// Check typechecker type and extract `item_type`
let item_type = match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
iter_type_vars(params).next().unwrap().ty // Extract `item_type`
}
_ => {
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(object.ty))
}
};
let item_model = AnyModel(ctx.get_llvm_type(generator, item_type));
let plist_model = PtrModel(StructModel(List { item: item_model }));
// Create object
let value = plist_model.check_value(generator, ctx.ctx, object.value).unwrap();
ListObject { item_type, instance: value }
}
/// Get the `items` field as an opaque pointer.
pub fn get_opaque_items_ptr<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Ptr<'ctx, IntModel<Byte>> {
self.instance.get(generator, ctx, |f| f.items, "items").pointer_cast(
generator,
ctx,
IntModel(Byte),
"items_opaque",
)
}
/// Get the value of this [`ListObject`] as a list with opaque items.
///
/// This function allocates on the stack to create the list, but the
/// reference to the `items` are preserved.
pub fn get_opaque_list_ptr<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Ptr<'ctx, StructModel<List<IntModel<Byte>>>> {
let opaque_list_model = StructModel(List { item: IntModel(Byte) });
let opaque_list_ptr = opaque_list_model.alloca(generator, ctx, "opaque_list_ptr");
// Copy items pointer
let items = self.get_opaque_items_ptr(generator, ctx);
opaque_list_ptr.set(ctx, |f| f.items, items);
// Copy len
let len = self.instance.get(generator, ctx, |f| f.len, "len");
opaque_list_ptr.set(ctx, |f| f.len, len);
opaque_list_ptr
}
/// Get the `len()` of this list.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
self.instance.get(generator, ctx, |f| f.len, "list_len")
}
pub fn slice_assign_from<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
user_slice: &RustUserSlice<'ctx>,
source: ListObject<'ctx>,
) {
// Sanity check
assert!(ctx.unifier.unioned(self.item_type, source.item_type));
let user_slice_model = StructModel(UserSlice);
let puser_slice = user_slice_model.alloca(generator, ctx, "user_slice");
user_slice.write_to_user_slice(generator, ctx, puser_slice);
let itemsize = self.instance.model.get_type(generator, ctx.ctx).size_of();
call_nac3_list_slice_assign(
generator,
ctx,
self.get_opaque_list_ptr(generator, ctx),
source.instance.value,
itemsize,
user_slice,
);
todo!()
}
}

View File

@ -0,0 +1,608 @@
use inkwell::{
values::{BasicValue, BasicValueEnum, FloatValue, IntValue},
FloatPredicate, IntPredicate,
};
use itertools::Itertools;
use list::ListObject;
use ndarray::{NDArrayObject, NDArrayOut};
use range::RangeObject;
use tuple::TupleObject;
use crate::{
toplevel::helper::PrimDef,
typecheck::typedef::{Type, TypeEnum},
};
use super::{llvm_intrinsics, model::*, CodeGenContext, CodeGenerator};
pub mod list;
pub mod ndarray;
pub mod range;
pub mod tuple;
/// Convenience function to crash the program when types of arguments are not supported.
/// Used to be debugged with a stacktrace.
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
where
I: IntoIterator<Item = Type>,
{
unreachable!(
"unsupported types found '{}'",
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
)
}
#[derive(Debug, Clone, Copy)]
pub enum FloorOrCeil {
Floor,
Ceil,
}
#[derive(Debug, Clone, Copy)]
pub enum MinOrMax {
Min,
Max,
}
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64]
}
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.uint32, ctx.primitives.uint64]
}
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
}
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
]
}
#[derive(Debug, Clone, Copy)]
pub struct AnyObject<'ctx> {
pub ty: Type,
pub value: BasicValueEnum<'ctx>,
}
impl<'ctx> AnyObject<'ctx> {
/// Returns true if this object's type is a [`TypeEnum::TObj`] and has the object ID as `prim`.
pub fn is_obj(&self, ctx: &mut CodeGenContext<'ctx, '_>, prim: PrimDef) -> bool {
match &*ctx.unifier.get_ty(self.ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id == prim.id(),
_ => false,
}
}
/// Returns true if this object's type is a [`TypeEnum::TTuple`]
pub fn is_tuple(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. })
}
pub fn into_tuple() {}
pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.int32)
}
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Int32> {
assert!(self.is_int32(ctx));
IntModel(Int32).believe_value(self.value.into_int_value())
}
pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.uint32)
}
pub fn is_int64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.int64)
}
pub fn is_uint64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.uint64)
}
pub fn is_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.bool)
}
/// Returns true if the object type is `bool`, `int32`, `int64`, `uint32`, or `uint64`.
pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, int_like(ctx))
}
/// Returns true if the object type is `int32`, `int64`.
pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, signed_ints(ctx))
}
/// Returns true if the object type is `uint32`, `uint64`.
pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx))
}
pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
assert!(self.is_int_like(ctx));
self.value.into_int_value()
}
pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
self.is_obj(ctx, PrimDef::Float)
}
pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> {
assert!(self.is_float(ctx));
FloatModel(Float64).believe_value(self.value.into_float_value())
}
pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
self.is_obj(ctx, PrimDef::NDArray)
}
pub fn into_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> NDArrayObject<'ctx> {
NDArrayObject::from_object(generator, ctx, *self)
}
/// Create an object from a boolean from an i1.
///
/// NOTE: In NAC3, booleans are i8. This function does converts the input i1 to an i8.
pub fn from_bool(ctx: &mut CodeGenContext<'ctx, '_>, n: Int<'ctx, Bool>) -> AnyObject<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let value = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
AnyObject { value: value.as_basic_value_enum(), ty: ctx.primitives.bool }
}
/// Helper function to compare two scalars.
///
/// Only int-to-int and float-to-float comparisons are allowed.
///
/// Panic otherwise.
pub fn compare_int_or_float_by_predicate<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: AnyObject<'ctx>,
rhs: AnyObject<'ctx>,
int_predicate: IntPredicate,
float_predicate: FloatPredicate,
name: &str,
) -> Int<'ctx, Bool> {
assert!(ctx.unifier.unioned(lhs.ty, rhs.ty), "lhs and rhs type should be the same");
let bool_model = IntModel(Bool);
let common_ty = lhs.ty;
let result = if lhs.is_float(ctx) {
let lhs = lhs.into_float(ctx);
let rhs = rhs.into_float(ctx);
ctx.builder.build_float_compare(float_predicate, lhs.value, rhs.value, name).unwrap()
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
let lhs = lhs.into_int(ctx);
let rhs = rhs.into_int(ctx);
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
} else {
unsupported_type(ctx, [lhs.ty, rhs.ty]);
};
bool_model.check_value(generator, ctx.ctx, result).unwrap()
}
/// Helper function for `int32()`, `int64()`, `uint32()`, and `uint64()`.
///
/// TODO: Document me
fn cast_to_int_conversion<'a, G, HandleFloatFn>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_int_ty: Type,
handle_float: HandleFloatFn,
) -> AnyObject<'ctx>
where
G: CodeGenerator + ?Sized,
HandleFloatFn:
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
{
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
let result = if self.is_float(ctx) {
// Handle float to int
let n = self.into_float(ctx);
handle_float(generator, ctx, n.value)
} else if self.is_int_like(ctx) {
// Handle int to a new int type
let n = self.into_int(ctx);
if n.get_type().get_bit_width() <= ret_int_ty_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, ret_int_ty_llvm, "zext").unwrap()
} else {
ctx.builder.build_int_truncate(n, ret_int_ty_llvm, "trunc").unwrap()
}
} else {
unsupported_type(ctx, [self.ty]);
};
assert_eq!(ret_int_ty_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
AnyObject { value: result.into(), ty: ret_int_ty }
}
/// Call `int32()` on this object.
#[must_use]
pub fn call_int32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(
generator,
ctx,
ctx.primitives.int32,
|_generator, ctx, input| {
let n =
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
},
)
}
/// Call `int64()` on this object.
#[must_use]
pub fn call_int64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(
generator,
ctx,
ctx.primitives.int64,
|_generator, ctx, input| {
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
},
)
}
/// Call `uint32()` on this object.
#[must_use]
pub fn call_uint32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint32, |_generator, ctx, n| {
let n_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int32 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder
.build_select(
n_gez,
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
to_int32,
"conv",
)
.unwrap()
.into_int_value()
})
}
/// Call `uint64()` on this object.
#[must_use]
pub fn call_uint64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint64, |_generator, ctx, n| {
let val_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int64 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap().into_int_value()
})
}
// Get the `len()` of this object.
#[must_use]
pub fn call_len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
// TODO: Switch to returning SizeT
let result = match &*ctx.unifier.get_ty_immutable(self.ty) {
TypeEnum::TTuple { .. } => {
let tuple = TupleObject::from_object(ctx, *self);
tuple.len(generator, ctx).truncate(generator, ctx, Int32, "tuple_len_32")
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
{
let range = RangeObject::from_object(generator, ctx, *self);
range.len(generator, ctx)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let list = ListObject::from_object(generator, ctx, *self);
list.len(generator, ctx).truncate(generator, ctx, Int32, "list_len_i32")
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_object(generator, ctx, *self);
ndarray.len(generator, ctx).truncate(generator, ctx, Int32, "ndarray_len_i32")
}
_ => unreachable!(),
};
AnyObject { ty: ctx.primitives.int32, value: result.value.as_basic_value_enum() }
}
/// Like [`AnyObject::call_bool`] but this returns an `Int<'ctx, Bool>` instead of an object.
pub fn bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
if self.is_int_like(ctx) {
let n = self.into_int(ctx);
let n = ctx
.builder
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
.unwrap();
bool_model.believe_value(n)
} else if self.is_float(ctx) {
let n = self.value.into_float_value();
let n = ctx
.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
.unwrap();
bool_model.believe_value(n)
} else {
unsupported_type(ctx, [self.ty])
}
}
/// Call `bool()` on this object.
#[must_use]
pub fn call_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let n = self.bool(ctx);
AnyObject::from_bool(ctx, n)
}
/// Call `float()` on this object.
#[must_use]
pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let f64_model = FloatModel(Float64);
let llvm_f64 = ctx.ctx.f64_type();
let result = if self.is_float(ctx) {
self.into_float(ctx)
} else if self.is_signed_int(ctx) || self.is_bool(ctx) {
let n = self.into_int(ctx);
let n = ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap();
f64_model.believe_value(n)
} else if self.is_unsigned_int(ctx) {
let n = self.into_int(ctx);
let n = ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap();
f64_model.believe_value(n)
} else {
unsupported_type(ctx, [self.ty]);
};
AnyObject { ty: ctx.primitives.float, value: result.value.as_basic_value_enum() }
}
// Call `abs()` on this object.
#[must_use]
pub fn call_abs<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
if self.is_float(ctx) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
AnyObject { value: n.into(), ty: ctx.primitives.float }
} else if self.is_unsigned_int(ctx) || self.is_signed_int(ctx) {
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
let n = self.value.into_int_value();
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
AnyObject { value: n.into(), ty: self.ty }
} else if self.is_ndarray(ctx) {
let ndarray = self.into_ndarray(generator, ctx);
ndarray
.map(
generator,
ctx,
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|generator, ctx, scalar| Ok(scalar.call_abs(generator, ctx)),
)
.unwrap()
.to_any_object(ctx)
} else {
unsupported_type(ctx, [self.ty])
}
}
// Call `round()` on this object.
//
// It is possible to specify which kind of int type to return.
#[must_use]
pub fn call_round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ret_int_ty: Type,
) -> AnyObject<'ctx> {
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
let result = if ctx.unifier.unioned(self.ty, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "round").unwrap()
} else {
unsupported_type(ctx, [self.ty])
};
AnyObject { ty: ret_int_ty, value: result.as_basic_value_enum() }
}
/// Call `np_round()` on this object.
///
/// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
#[must_use]
pub fn call_np_round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
if self.is_float(ctx) {
let n = self.into_float(ctx);
let n = llvm_intrinsics::call_float_rint(ctx, n.value, None);
AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() }
} else if self.is_ndarray(ctx) {
let ndarray = self.into_ndarray(generator, ctx);
ndarray
.map(
generator,
ctx,
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|generator, ctx, scalar| Ok(scalar.call_np_round(generator, ctx)),
)
.unwrap()
.to_any_object(ctx)
} else {
unsupported_type(ctx, [self.ty])
}
}
/// Call `min()` or `max()` on two objects.
#[must_use]
pub fn call_min_or_max(
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
a: AnyObject<'ctx>,
b: AnyObject<'ctx>,
) -> AnyObject<'ctx> {
if !ctx.unifier.unioned(a.ty, b.ty) {
unsupported_type(ctx, [a.ty, b.ty])
}
let common_ty = a.ty;
if a.is_float(ctx) {
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
};
let a = a.into_float(ctx).value;
let b = b.into_float(ctx).value;
let result = function(ctx, a, b, None);
AnyObject { value: result.as_basic_value_enum(), ty: ctx.primitives.float }
} else if a.is_unsigned_int(ctx) || a.is_bool(ctx) {
// Treating bool has an unsigned int since that is convenient
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_int_umin,
MinOrMax::Max => llvm_intrinsics::call_int_umax,
};
let a = a.into_int(ctx);
let b = b.into_int(ctx);
let result = function(ctx, a, b, None);
AnyObject { value: result.as_basic_value_enum(), ty: common_ty }
} else if a.is_signed_int(ctx) {
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_int_smin,
MinOrMax::Max => llvm_intrinsics::call_int_smax,
};
let a = a.into_int(ctx);
let b = b.into_int(ctx);
let result = function(ctx, a, b, None);
AnyObject { value: result.as_basic_value_enum(), ty: common_ty }
} else {
unsupported_type(ctx, [common_ty])
}
}
/// Call `floor()` or `ceil()` on this object.
///
/// It is possible to specify which kind of int type to return.
#[must_use]
pub fn call_floor_or_ceil<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: FloorOrCeil,
ret_int_ty: Type,
) -> Self {
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
if self.is_float(ctx) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.into_float(ctx).value;
let n = function(ctx, n, None);
let n = ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "").unwrap();
AnyObject { ty: ret_int_ty, value: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.ty])
}
}
/// Call `np_floor()` or `np_ceil()` on this object.
#[must_use]
pub fn call_np_floor_or_ceil<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: FloorOrCeil,
) -> Self {
// TODO:
if self.is_float(ctx) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.into_float(ctx).value;
let n = function(ctx, n, None);
AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() }
} else if self.is_ndarray(ctx) {
let ndarray = self.into_ndarray(generator, ctx);
ndarray
.map(
generator,
ctx,
NDArrayOut::NewNDArray { dtype: ctx.primitives.float },
|generator, ctx, scalar| Ok(scalar.call_np_floor_or_ceil(generator, ctx, kind)),
)
.unwrap()
.to_any_object(ctx)
} else {
unsupported_type(ctx, [self.ty])
}
}
}

View File

@ -0,0 +1,176 @@
use super::NDArrayObject;
use crate::{
codegen::{
irrt::{call_nac3_array_set_and_validate_list_shape, call_nac3_array_write_list_to_array},
model::*,
object::{list::ListObject, AnyObject},
stmt::gen_if_else_expr_callback,
CodeGenContext, CodeGenerator,
},
toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims},
typecheck::typedef::{Type, TypeEnum},
};
fn get_list_object_dtype_and_ndims<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> (Type, u64) {
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list.item_type);
let ndims = arraylike_get_ndims(&mut ctx.unifier, list.item_type);
let ndims = ndims + 1; // To count `list` itself.
(dtype, ndims)
}
impl<'ctx> NDArrayObject<'ctx> {
fn from_np_array_list_copy_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> Self {
let sizet_model = IntModel(SizeT);
let (dtype, ndims_int) = get_list_object_dtype_and_ndims(ctx, list);
let list_value = list.get_opaque_list_ptr(generator, ctx);
// Validate `list` has a consistent shape.
// Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`.
// If `list` has a consistent shape, deduce the shape and write it to `shape`.
let ndims = sizet_model.constant(generator, ctx.ctx, ndims_int);
let shape = sizet_model.array_alloca(generator, ctx, ndims.value, "shape");
call_nac3_array_set_and_validate_list_shape(generator, ctx, list_value, ndims, shape);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int, "ndarray_from_list");
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.create_data(generator, ctx);
// Copy all contents from the list.
call_nac3_array_write_list_to_array(generator, ctx, list_value, ndarray.instance);
ndarray
}
fn from_np_array_list_try_no_copy_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> Self {
// np_array without copy is only possible `list` is not nested.
// If `list` is `list[T]`, we can create an ndarray with `data` set
// to the array pointer of `list`.
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
if ndims == 1 {
// `list` is not nested, does not need to copy.
let ndarray =
NDArrayObject::alloca(generator, ctx, dtype, 1, "ndarray_from_list_no_copy");
// Set data
let data = list.get_opaque_items_ptr(generator, ctx);
ndarray.instance.set(ctx, |f| f.data, data);
// Set shape
// dim = list->len;
// shape[0] = dim;
let shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
let dim = list.instance.get(generator, ctx, |f| f.len, "dim");
shape.offset(generator, ctx, zero.value, "pdim").store(ctx, dim);
// Set strides, the `data` is contiguous
ndarray.update_strides_by_shape(generator, ctx);
// Done
ndarray
} else {
// `list` is nested, it is impossible to not copy.
NDArrayObject::from_np_array_list_copy_impl(generator, ctx, list)
}
}
fn from_np_array_list_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
copy: Int<'ctx, Bool>,
) -> Self {
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
let ndarray = gen_if_else_expr_callback(
generator,
ctx,
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray = NDArrayObject::from_np_array_list_copy_impl(generator, ctx, list);
Ok(Some(ndarray.instance.value))
},
|generator, ctx| {
let ndarray =
NDArrayObject::from_np_array_list_try_no_copy_impl(generator, ctx, list);
Ok(Some(ndarray.instance.value))
},
)
.unwrap()
.unwrap();
NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims)
}
pub fn from_np_array_ndarray_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
copy: Int<'ctx, Bool>,
) -> Self {
let ndarray_val = gen_if_else_expr_callback(
generator,
ctx,
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray = ndarray.make_copy(generator, ctx, "np_array_copied_ndarray"); // Force copy
Ok(Some(ndarray.instance.value))
},
|_generator, _ctx| {
// No need to copy. Return `ndarray` itself.
Ok(Some(ndarray.instance.value))
},
)
.unwrap()
.unwrap();
NDArrayObject::from_value_and_unpacked_types(
generator,
ctx,
ndarray_val,
ndarray.dtype,
ndarray.ndims,
)
}
pub fn from_np_array<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
copy: Int<'ctx, Bool>,
) -> Self {
match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let list = ListObject::from_object(generator, ctx, object);
NDArrayObject::from_np_array_list_impl(generator, ctx, list, copy)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_object(generator, ctx, object);
NDArrayObject::from_np_array_ndarray_impl(generator, ctx, ndarray, copy)
}
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this
}
}
}

View File

@ -0,0 +1,161 @@
use itertools::Itertools;
use crate::codegen::{
irrt::{
call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to,
call_nac3_ndarray_util_assert_shape_no_negative,
},
model::*,
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
/// Fields of [`ShapeEntry`]
pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Out<IntModel<SizeT>>,
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
}
/// An IRRT structure used in broadcasting.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShapeEntry;
impl<'ctx> StructKind<'ctx> for ShapeEntry {
type Fields<F: FieldTraversal<'ctx>> = ShapeEntryFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") }
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Create a broadcast view on this ndarray with a target shape.
///
/// The input shape will be checked to make sure that it contains no negative values.
///
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
/// The caller has to figure this out for this function.
/// * `target_shape` - An array pointer pointing to the target shape.
#[must_use]
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target_ndims: u64,
target_shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let target_ndims_llvm = IntModel(SizeT).constant(generator, ctx.ctx, target_ndims);
call_nac3_ndarray_util_assert_shape_no_negative(
generator,
ctx,
target_ndims_llvm,
target_shape,
);
let broadcast_ndarray = NDArrayObject::alloca(
generator,
ctx,
self.dtype,
target_ndims,
"broadcast_ndarray_to_dst",
);
broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape);
call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance);
broadcast_ndarray
}
}
/// A result produced by [`broadcast_all_ndarrays`]
#[derive(Debug, Clone)]
pub struct BroadcastAllResult<'ctx> {
/// The statically known `ndims` of the broadcast result.
pub ndims: u64,
/// The broadcasting shape.
pub shape: Ptr<'ctx, IntModel<SizeT>>,
/// Broadcasted views on the inputs.
///
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
/// is the same as the input.
pub ndarrays: Vec<NDArrayObject<'ctx>>,
}
pub fn broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
in_entries: &[(Ptr<'ctx, IntModel<SizeT>>, u64)],
broadcast_ndims: u64,
broadcast_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let sizet_model = IntModel(SizeT);
let shape_model = StructModel(ShapeEntry);
// Prepare input shape entries
let num_shape_entries =
sizet_model.constant(generator, ctx.ctx, u64::try_from(in_entries.len()).unwrap());
let shape_entries =
shape_model.array_alloca(generator, ctx, num_shape_entries.value, "shape_entries");
for (i, (in_shape, in_ndims)) in in_entries.iter().enumerate() {
let i = sizet_model.constant(generator, ctx.ctx, i as u64).value;
let pshape_entry = shape_entries.offset(generator, ctx, i, "shape_entry");
let in_ndims = sizet_model.constant(generator, ctx.ctx, *in_ndims);
pshape_entry.set(ctx, |f| f.ndims, in_ndims);
pshape_entry.set(ctx, |f| f.shape, *in_shape);
}
let broadcast_ndims = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
call_nac3_ndarray_broadcast_shapes(
generator,
ctx,
num_shape_entries,
shape_entries,
broadcast_ndims,
broadcast_shape,
);
}
impl<'ctx> NDArrayObject<'ctx> {
// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently.
pub fn broadcast<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarrays: &[Self],
) -> BroadcastAllResult<'ctx> {
assert!(!ndarrays.is_empty());
let sizet_model = IntModel(SizeT);
// Infer the broadcast output ndims.
let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap();
let broadcast_ndims = sizet_model.constant(generator, ctx.ctx, broadcast_ndims_int);
let broadcast_shape =
sizet_model.array_alloca(generator, ctx, broadcast_ndims.value, "broadcast_shape");
let shape_entries = ndarrays
.iter()
.map(|ndarray| {
(ndarray.instance.get(generator, ctx, |f| f.shape, "shape"), ndarray.ndims)
})
.collect_vec();
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape);
// Broadcast all the inputs to shape `dst_shape`.
let broadcast_ndarrays: Vec<_> = ndarrays
.iter()
.map(|ndarray| {
ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, broadcast_shape)
})
.collect_vec();
BroadcastAllResult {
ndims: broadcast_ndims_int,
shape: broadcast_shape,
ndarrays: broadcast_ndarrays,
}
}
}

View File

@ -0,0 +1,238 @@
use inkwell::{values::BasicValueEnum, IntPredicate};
use super::NDArrayObject;
use crate::{
codegen::{
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, object::AnyObject,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
/// Get the zero value in `np.zeros()` of a `dtype`.
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.ctx.i32_type().const_zero().into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.ctx.i64_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.ctx.f64_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string(generator, "").value.into()
} else {
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
}
}
/// Get the one value in `np.ones()` of a `dtype`.
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
ctx.ctx.i32_type().const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
ctx.ctx.i64_type().const_int(1, is_signed).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.ctx.f64_type().const_float(1.0).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string(generator, "1").value.into()
} else {
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Create an ndarray like `np.empty`.
pub fn from_np_empty<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
// Validate `shape`
// TODO: Should the caller be responsible for this instead?
let ndims_llvm = IntModel(SizeT).constant(generator, ctx.ctx, ndims);
call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims, "full_ndarray");
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.create_data(generator, ctx);
ndarray
}
/// Create an ndarray like `np.full`.
pub fn from_np_full<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
fill_value: AnyObject<'ctx>,
) -> Self {
// Sanity check on `fill_value`'s dtype.
assert!(ctx.unifier.unioned(dtype, fill_value.ty));
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
ndarray.fill(generator, ctx, fill_value);
ndarray
}
/// Create an ndarray like `np.zero`.
pub fn from_np_zero<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let fill_value = ndarray_zero_value(generator, ctx, dtype);
let fill_value = AnyObject { value: fill_value, ty: dtype };
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}
/// Create an ndarray like `np.ones`.
pub fn from_np_ones<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let fill_value = ndarray_one_value(generator, ctx, dtype);
let fill_value = AnyObject { value: fill_value, ty: dtype };
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}
/// Create an ndarray like `np.arange`.
/// The returned ndarray's `dtype` is always `float`
pub fn from_np_arange<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
length: Int<'ctx, SizeT>,
) -> Self {
let ndarray = NDArrayObject::alloca(
generator,
ctx,
ctx.primitives.float,
1, // ndims = 1
"arange_ndarray",
);
// `ndarray.shape[0] = length`
ndarray
.instance
.get(generator, ctx, |f| f.shape, "shape")
.offset_const(generator, ctx, 0, "dim")
.store(ctx, length);
// Create data and set elements
ndarray.create_data(generator, ctx);
ndarray
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
// Get the index of the current element, convert that index to float, and write it.
// This is how we get [0.0, 1.0, 2.0, ...].
let index = nditer.get_index(generator, ctx);
let pelement = nditer.get_pointer(generator, ctx);
let val = ctx
.builder
.build_unsigned_int_to_float(index.value, ctx.ctx.f64_type(), "val")
.unwrap();
ctx.builder.build_store(pelement, val).unwrap();
Ok(())
})
.unwrap();
ndarray
}
/// Create an ndarray like `np.eye`.
pub fn from_np_eye<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
nrows: Int<'ctx, SizeT>,
ncols: Int<'ctx, SizeT>,
offset: Int<'ctx, SizeT>,
) -> Self {
let ndzero = ndarray_zero_value(generator, ctx, dtype);
let ndone = ndarray_one_value(generator, ctx, dtype);
let ndarray = NDArrayObject::alloca_dynamic_shape(
generator,
ctx,
dtype,
&[nrows, ncols],
"eye_ndarray",
);
// Create data and make the matrix like look np.eye()
ndarray.create_data(generator, ctx);
ndarray
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
// and this loop would not execute.
// Load up `row_i` and `col_i` from indices.
let row_i = nditer
.get_indices()
.offset_const(generator, ctx, 0, "")
.load(generator, ctx, "row_i");
let col_i = nditer
.get_indices()
.offset_const(generator, ctx, 1, "")
.load(generator, ctx, "col_i");
// Write to element
let be_one =
row_i.add(ctx, offset, "").compare(ctx, IntPredicate::EQ, col_i, "write_one");
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap();
Ok(())
})
.unwrap();
ndarray
}
/// Create an ndarray like `np.identity`.
pub fn from_np_identity<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
size: Int<'ctx, SizeT>,
) -> Self {
// Convenient implementation
let offset = IntModel(SizeT).const_0(generator, ctx.ctx);
NDArrayObject::from_np_eye(generator, ctx, dtype, size, size, offset)
}
}

View File

@ -0,0 +1,128 @@
use inkwell::{FloatPredicate, IntPredicate};
use crate::codegen::{
model::*,
object::{AnyObject, MinOrMax},
stmt::gen_if_callback,
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
impl<'ctx> NDArrayObject<'ctx> {
/// Helper function to implement NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
///
/// Generate LLVM IR to find the extremum and index of the **first** extremum value.
///
/// Care has also been taken to make the error messages match that of NumPy.
fn min_max_argmin_argmax_helper<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
on_empty_err_msg: &str,
) -> (AnyObject<'ctx>, Int<'ctx, SizeT>) {
let sizet_model = IntModel(SizeT);
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
// If the ndarray is empty, throw an error.
let is_empty = self.is_empty(generator, ctx);
ctx.make_assert(
generator,
is_empty.value,
"0:ValueError",
on_empty_err_msg,
[None, None, None],
ctx.current_loc,
);
// Setup and initialize the extremum to be the first element in the ndarray
let pextremum_index = sizet_model.alloca(generator, ctx, "extremum_index");
let pextremum = ctx.builder.build_alloca(dtype_llvm, "extremum").unwrap();
let zero = sizet_model.const_0(generator, ctx.ctx);
pextremum_index.store(ctx, zero);
let first_scalar = self.get_nth_scalar(generator, ctx, zero);
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = AnyObject { ty: self.dtype, value: old_extremum };
let scalar = nditer.get_scalar(generator, ctx);
let new_extremum = AnyObject::call_min_or_max(ctx, kind, old_extremum, scalar);
gen_if_callback(
generator,
ctx,
|generator, ctx| {
// Is new_extremum is more extreme than old_extremum?
let cmp = AnyObject::compare_int_or_float_by_predicate(
generator,
ctx,
new_extremum,
old_extremum,
IntPredicate::NE,
FloatPredicate::ONE,
"",
);
Ok(cmp.value)
},
|generator, ctx| {
// Yes, update the extremum index
let index = nditer.get_index(generator, ctx);
pextremum_index.store(ctx, index);
Ok(())
},
|_generator, _ctx| {
// No, do nothing
Ok(())
},
)
})
.unwrap();
// Finally return the extremum and extremum index.
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
let extremum = AnyObject { ty: self.dtype, value: extremum };
(extremum, extremum_index)
}
/// Invoke NAC3's builtin `np_min()` or `np_max()`.
pub fn min_or_max<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
) -> AnyObject<'ctx> {
let on_empty_err_msg = format!(
"zero-size array to reduction operation {} which has no identity",
match kind {
MinOrMax::Min => "minimum",
MinOrMax::Max => "maximum",
}
);
self.min_max_argmin_argmax_helper(generator, ctx, kind, &on_empty_err_msg).0
}
/// Invoke NAC3's builtin `np_argmin()` or `np_argmax()`.
pub fn argmin_or_argmax<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
) -> Int<'ctx, SizeT> {
let on_empty_err_msg = format!(
"attempt to get {} of an empty sequence",
match kind {
MinOrMax::Min => "argmin",
MinOrMax::Max => "argmax",
}
);
self.min_max_argmin_argmax_helper(generator, ctx, kind, &on_empty_err_msg).1
}
}

View File

@ -0,0 +1,321 @@
use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator};
use super::NDArrayObject;
pub type NDIndexType = Byte;
/// Fields of [`NDIndex`]
#[derive(Debug, Clone, Copy)]
pub struct NDIndexFields<'ctx, F: FieldTraversal<'ctx>> {
pub type_: F::Out<IntModel<NDIndexType>>, // Defined to be uint8_t in IRRT
pub data: F::Out<PtrModel<IntModel<Byte>>>,
}
/// An IRRT representation fo an ndarray subscript index.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NDIndex;
impl<'ctx> StructKind<'ctx> for NDIndex {
type Fields<F: FieldTraversal<'ctx>> = NDIndexFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { type_: traversal.add_auto("type"), data: traversal.add_auto("data") }
}
}
/// Fields of [`UserSlice`]
#[derive(Debug, Clone)]
pub struct UserSliceFields<'ctx, F: FieldTraversal<'ctx>> {
pub start_defined: F::Out<IntModel<Bool>>,
pub start: F::Out<IntModel<Int32>>,
pub stop_defined: F::Out<IntModel<Bool>>,
pub stop: F::Out<IntModel<Int32>>,
pub step_defined: F::Out<IntModel<Bool>>,
pub step: F::Out<IntModel<Int32>>,
}
/// An IRRT representation of a user slice.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UserSlice;
impl<'ctx> StructKind<'ctx> for UserSlice {
type Fields<F: FieldTraversal<'ctx>> = UserSliceFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
start_defined: traversal.add_auto("start_defined"),
start: traversal.add_auto("start"),
stop_defined: traversal.add_auto("stop_defined"),
stop: traversal.add_auto("stop"),
step_defined: traversal.add_auto("step_defined"),
step: traversal.add_auto("step"),
}
}
}
/// A convenience structure to prepare a [`UserSlice`].
#[derive(Debug, Clone)]
pub struct RustUserSlice<'ctx> {
pub start: Option<Int<'ctx, Int32>>,
pub stop: Option<Int<'ctx, Int32>>,
pub step: Option<Int<'ctx, Int32>>,
}
impl<'ctx> RustUserSlice<'ctx> {
/// Write the contents to an LLVM [`UserSlice`].
pub fn write_to_user_slice<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_slice_ptr: Ptr<'ctx, StructModel<UserSlice>>,
) {
let bool_model = IntModel(Bool);
let false_ = bool_model.constant(generator, ctx.ctx, 0);
let true_ = bool_model.constant(generator, ctx.ctx, 1);
// TODO: Code duplication. Probably okay...?
match self.start {
Some(start) => {
dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start);
}
None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_),
}
match self.stop {
Some(stop) => {
dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop);
}
None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_),
}
match self.step {
Some(step) => {
dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step);
}
None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_),
}
}
}
// A convenience enum variant to store the content and type of an NDIndex in high level.
#[derive(Debug, Clone)]
pub enum RustNDIndex<'ctx> {
SingleElement(Int<'ctx, Int32>), // TODO: To be SizeT
Slice(RustUserSlice<'ctx>),
NewAxis,
Ellipsis,
}
impl<'ctx> RustNDIndex<'ctx> {
/// Get the value to set `NDIndex::type` for this variant.
fn get_type_id(&self) -> u64 {
// Defined in IRRT, must be in sync
match self {
RustNDIndex::SingleElement(_) => 0,
RustNDIndex::Slice(_) => 1,
RustNDIndex::NewAxis => 2,
RustNDIndex::Ellipsis => 3,
}
}
/// Write the contents to an LLVM [`NDIndex`].
fn write_to_ndindex<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_ndindex_ptr: Ptr<'ctx, StructModel<NDIndex>>,
) {
let ndindex_type_model = IntModel(NDIndexType::default());
let i32_model = IntModel(Int32);
let user_slice_model = StructModel(UserSlice);
// Set `dst_ndindex_ptr->type`
dst_ndindex_ptr
.gep(ctx, |f| f.type_)
.store(ctx, ndindex_type_model.constant(generator, ctx.ctx, self.get_type_id()));
// Set `dst_ndindex_ptr->data`
match self {
RustNDIndex::SingleElement(in_index) => {
let index_ptr = i32_model.alloca(generator, ctx, "index");
index_ptr.store(ctx, *in_index);
dst_ndindex_ptr
.gep(ctx, |f| f.data)
.store(ctx, index_ptr.pointer_cast(generator, ctx, IntModel(Byte), ""));
}
RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr = user_slice_model.alloca(generator, ctx, "user_slice");
in_rust_slice.write_to_user_slice(generator, ctx, user_slice_ptr);
dst_ndindex_ptr
.gep(ctx, |f| f.data)
.store(ctx, user_slice_ptr.pointer_cast(generator, ctx, IntModel(Byte), ""));
}
RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {}
}
}
/// Allocate an array of `NDIndex`es on the stack and return its stack pointer.
pub fn alloca_ndindexes<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
in_ndindexes: &[RustNDIndex<'ctx>],
) -> (Int<'ctx, SizeT>, Ptr<'ctx, StructModel<NDIndex>>) {
let sizet_model = IntModel(SizeT);
let ndindex_model = StructModel(NDIndex);
let num_ndindexes = sizet_model.constant(generator, ctx.ctx, in_ndindexes.len() as u64);
let ndindexes =
ndindex_model.array_alloca(generator, ctx, num_ndindexes.value, "ndindexes");
for (i, in_ndindex) in in_ndindexes.iter().enumerate() {
let i = sizet_model.constant(generator, ctx.ctx, i as u64);
let pndindex = ndindexes.offset(generator, ctx, i.value, "");
in_ndindex.write_to_ndindex(generator, ctx, pndindex);
}
(num_ndindexes, ndindexes)
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Get the ndims [`Type`] after indexing with a given slice.
#[must_use]
pub fn deduce_ndims_after_indexing_with(&self, indexes: &[RustNDIndex<'ctx>]) -> u64 {
let mut ndims = self.ndims;
for index in indexes {
match index {
RustNDIndex::SingleElement(_) => {
ndims -= 1; // Single elements decrements ndims
}
RustNDIndex::NewAxis => {
ndims += 1; // `np.newaxis` / `none` adds a new axis
}
RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {}
}
}
ndims
}
/// Index into the ndarray, and return a newly-allocated view on this ndarray.
///
/// This function behaves like NumPy's ndarray indexing, but if the indexes index
/// into a single element, an unsized ndarray is returned.
#[must_use]
pub fn index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indexes: &[RustNDIndex<'ctx>],
name: &str,
) -> Self {
let dst_ndims = self.deduce_ndims_after_indexing_with(indexes);
let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims, name);
let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(generator, ctx, indexes);
call_nac3_ndarray_index(
generator,
ctx,
num_indexes,
indexes,
self.instance,
dst_ndarray.instance,
);
dst_ndarray
}
}
pub mod util {
use itertools::Itertools;
use nac3parser::ast::{Constant, Expr, ExprKind};
use crate::{
codegen::{expr::gen_slice, model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
};
use super::{RustNDIndex, RustUserSlice};
/// Generate LLVM code to transform an ndarray subscript expression to
/// its list of [`RustNDIndex`]
///
/// i.e.,
/// ```python
/// my_ndarray[::3, 1, :2:]
/// ^^^^^^^^^^^ Then these into a three `RustNDIndex`es
/// ```
pub fn gen_ndarray_subscript_ndindexes<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
subscript: &Expr<Option<Type>>,
) -> Result<Vec<RustNDIndex<'ctx>>, String> {
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
let i32_model = IntModel(Int32);
// Annoying notes about `slice`
// - `my_array[5]`
// - slice is a `Constant`
// - `my_array[:5]`
// - slice is a `Slice`
// - `my_array[:]`
// - slice is a `Slice`, but lower upper step would all be `Option::None`
// - `my_array[:, :]`
// - slice is now a `Tuple` of two `Slice`-s
//
// In summary:
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
//
// So we first "flatten" out the slice expression
let index_exprs = match &subscript.node {
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
_ => vec![subscript],
};
// Process all index expressions
let mut rust_ndindexes: Vec<RustNDIndex> = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here.
for index_expr in index_exprs {
// NOTE: Currently nac3core's slices do not have an object representation,
// so the code/implementation looks awkward - we have to do pattern matching on the expression
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
// Handle slices
// Helper function here to deduce code duplication
let (lower, upper, step) = gen_slice(generator, ctx, lower, upper, step)?;
RustNDIndex::Slice(RustUserSlice { start: lower, stop: upper, step })
} else if let ExprKind::Constant { value: Constant::Ellipsis, .. } = &index_expr.node {
// Handle '...'
RustNDIndex::Ellipsis
} else {
match &*ctx.unifier.get_ty(index_expr.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
{
// Handle `np.newaxis` / `None`
RustNDIndex::NewAxis
}
_ => {
// Treat and handle everything else as a single element index.
let index =
generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
ctx,
generator,
ctx.primitives.int32, // Must be int32, this checks for illegal values
)?;
let index = i32_model.check_value(generator, ctx.ctx, index).unwrap();
RustNDIndex::SingleElement(index)
}
}
};
rust_ndindexes.push(ndindex);
}
Ok(rust_ndindexes)
}
}

View File

@ -0,0 +1,200 @@
use itertools::Itertools;
use crate::{
codegen::{
object::ndarray::{AnyObject, NDArrayObject},
stmt::gen_for_callback,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
use super::{nditer::NDIterHandle, scalar::ScalarOrNDArray, NDArrayOut};
impl<'ctx> NDArrayObject<'ctx> {
/// TODO: Document me. Has complex behavior.
/// and explain why `ret_dtype` has to be specified beforehand.
pub fn broadcasting_starmap<'a, G, MappingFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarrays: &[Self],
out: NDArrayOut<'ctx>,
mapping: MappingFn,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
&[AnyObject<'ctx>],
) -> Result<AnyObject<'ctx>, String>,
{
// Broadcast inputs
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
let out_ndarray = match out {
NDArrayOut::NewNDArray { dtype } => {
// Create a new ndarray based on the broadcast shape.
let result_ndarray = NDArrayObject::alloca(
generator,
ctx,
dtype,
broadcast_result.ndims,
"mapped_ndarray",
);
result_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
result_ndarray.create_data(generator, ctx);
result_ndarray
}
NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => {
// Use an existing ndarray.
// Check that its shape is compatible with the broadcast shape.
result_ndarray.check_can_be_written_by_out(
generator,
ctx,
broadcast_result.ndims,
broadcast_result.shape,
);
result_ndarray
}
};
// Map element-wise and store results into `mapped_ndarray`.
let nditer = NDIterHandle::new(generator, ctx, out_ndarray);
gen_for_callback(
generator,
ctx,
Some("broadcast_starmap"),
|generator, ctx| {
// Create NDIters for all broadcasted input ndarrays.
let other_nditers = broadcast_result
.ndarrays
.iter()
.map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray))
.collect_vec();
Ok((nditer, other_nditers))
},
|generator, ctx, (out_nditer, _in_nditers)| {
// We can simply use `out_nditer`'s `has_next()`.
// `in_nditers`' `has_next()`s should return the same value.
Ok(out_nditer.has_next(generator, ctx).value)
},
|generator, ctx, _hooks, (out_nditer, in_nditers)| {
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
// and write to `out_ndarray`.
let in_scalars =
in_nditers.iter().map(|nditer| nditer.get_scalar(generator, ctx)).collect_vec();
let result = mapping(generator, ctx, &in_scalars)?;
// Sanity check on result's ty
assert!(ctx.unifier.unioned(result.ty, out_ndarray.dtype));
let p = out_nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, result.value).unwrap();
Ok(())
},
|generator, ctx, (out_nditer, in_nditers)| {
// Advance all iterators
out_nditer.next(generator, ctx);
in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx));
Ok(())
},
)?;
Ok(out_ndarray)
}
pub fn map<'a, G, Mapping>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
out: NDArrayOut<'ctx>,
mapping: Mapping,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
Mapping: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
AnyObject<'ctx>,
) -> Result<AnyObject<'ctx>, String>,
{
NDArrayObject::broadcasting_starmap(
generator,
ctx,
&[*self],
out,
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
)
}
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// TODO: Document me. Has complex behavior.
/// and explain why `ret_dtype` has to be specified beforehand.
pub fn broadcasting_starmap<'a, G, MappingFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: &[Self],
ret_dtype: Type,
mapping: MappingFn,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
&[AnyObject<'ctx>],
) -> Result<AnyObject<'ctx>, String>,
{
// Check if all inputs are AnyObjects
let all_scalars: Option<Vec<_>> = inputs.iter().map(AnyObject::try_from).try_collect().ok();
if let Some(scalars) = all_scalars {
let scalar = mapping(generator, ctx, &scalars)?;
// Sanity check on scalar's type
assert!(ctx.unifier.unioned(scalar.ty, ret_dtype));
Ok(ScalarOrNDArray::Scalar(scalar))
} else {
// Promote all input to ndarrays and map through them.
let inputs = inputs.iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec();
let ndarray = NDArrayObject::broadcasting_starmap(
generator,
ctx,
&inputs,
NDArrayOut::NewNDArray { dtype: ret_dtype },
mapping,
)?;
Ok(ScalarOrNDArray::NDArray(ndarray))
}
}
pub fn map<'a, G, Mapping>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: Type,
mapping: Mapping,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
Mapping: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
AnyObject<'ctx>,
) -> Result<AnyObject<'ctx>, String>,
{
ScalarOrNDArray::broadcasting_starmap(
generator,
ctx,
&[*self],
ret_dtype,
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
)
}
}

View File

@ -0,0 +1,831 @@
pub mod array;
pub mod broadcast;
pub mod factory;
pub mod functions;
pub mod indexing;
pub mod mapping;
pub mod nalgebra;
pub mod nditer;
pub mod product;
pub mod scalar;
pub mod shape_util;
use crate::{
codegen::{
irrt::{
call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement,
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
call_nac3_ndarray_resolve_and_check_new_shape, call_nac3_ndarray_set_strides_by_shape,
call_nac3_ndarray_size, call_nac3_ndarray_transpose,
call_nac3_ndarray_util_assert_output_shape_same,
},
model::*,
stmt::{gen_for_callback, BreakContinueHooks},
structure::{NDArray, SimpleNDArray},
CodeGenContext, CodeGenerator,
},
toplevel::{
helper::{create_ndims, extract_ndims},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
},
typecheck::typedef::Type,
};
use indexing::RustNDIndex;
use inkwell::{
context::Context,
types::BasicType,
values::{BasicValue, PointerValue},
AddressSpace, IntPredicate,
};
use nditer::NDIterHandle;
use scalar::ScalarOrNDArray;
use util::call_memcpy_model;
use super::{tuple::TupleObject, AnyObject};
/// A NAC3 Python ndarray object.
#[derive(Debug, Clone, Copy)]
pub struct NDArrayObject<'ctx> {
pub dtype: Type,
pub ndims: u64,
pub instance: Ptr<'ctx, StructModel<NDArray>>,
}
impl<'ctx> NDArrayObject<'ctx> {
/// Create an [`NDArrayObject`] from an LLVM value and its typechecker [`Type`].
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims)
}
/// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's
/// `dtype` and `ndims`.
pub fn from_value_and_unpacked_types<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
value: V,
dtype: Type,
ndims: u64,
) -> Self {
let pndarray_model = PtrModel(StructModel(NDArray));
let value = pndarray_model.check_value(generator, ctx.ctx, value).unwrap();
NDArrayObject { dtype, ndims, instance: value }
}
/// Forget that this is an ndarray and convert to an [`AnyObject`].
pub fn to_any_object(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let ty = self.get_ndarray_type(ctx);
AnyObject { value: self.instance.value.as_basic_value_enum(), ty }
}
/// Create a [`SimpleNDArray`] from the contents of this ndarray.
///
/// This function may or may not be expensive depending on if this ndarray has contiguous data.
///
/// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the `data` field of
/// the returned [`SimpleNDArray`] and copy contents of this ndarray to there.
///
/// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created [`SimpleNDArray`]
/// will have the same `data` field as this ndarray.
///
/// The `item_model` sets the [`Model`] of the returned [`SimpleNDArray`]'s `Item` model, and should match the
/// `ctx.get_llvm_type()` of this ndarray's `dtype`. Otherwise this function panics.
pub fn make_simple_ndarray<G: CodeGenerator + ?Sized, Item: Model<'ctx>>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
item_model: Item,
name: &str,
) -> Ptr<'ctx, StructModel<SimpleNDArray<Item>>> {
// Sanity check on `self.dtype` and `item_model`.
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
item_model.check_type(generator, ctx.ctx, dtype_llvm).unwrap();
let simple_ndarray_model = StructModel(SimpleNDArray { item: item_model });
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
// Allocate and setup the resulting [`SimpleNDArray`].
let result = simple_ndarray_model.alloca(generator, ctx, name);
// Set ndims and shape.
let ndims = self.get_ndims(generator, ctx.ctx);
result.set(ctx, |f| f.ndims, ndims);
let shape = self.instance.get(generator, ctx, |f| f.shape, "shape");
result.set(ctx, |f| f.shape, shape);
// Set data, but we do things differently if this ndarray is contiguous.
let is_contiguous = self.is_c_contiguous(generator, ctx);
ctx.builder.build_conditional_branch(is_contiguous.value, then_bb, else_bb).unwrap();
// Inserting into then_bb; This ndarray is contiguous.
let data = self.instance.get(generator, ctx, |f| f.data, "");
let data = data.pointer_cast(generator, ctx, item_model, "");
result.set(ctx, |f| f.data, data);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Inserting into else_bb; This ndarray is not contiguous. Do a full-copy on `data`.
// TODO: Reimplement this? This method does give us the contiguous `data`, but
// this creates a few extra bytes of useless information because an entire NDArray
// is allocated. Though this is super convenient.
let data = self.make_copy(generator, ctx, "").instance.get(generator, ctx, |f| f.data, "");
let data = data.pointer_cast(generator, ctx, item_model, "");
result.set(ctx, |f| f.data, data);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition to end_bb for continuation
ctx.builder.position_at_end(end_bb);
result
}
/// Create an [`NDArrayObject`] from a [`SimpleNDArray`].
///
/// This operation is super cheap. The newly created [`NDArray`] will not copy contents
/// from `simple_ndarray`, but only having its `data` and `shape` pointing to `simple_array`.
pub fn from_simple_ndarray<G: CodeGenerator + ?Sized, Item: Model<'ctx>>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
simple_ndarray: Ptr<'ctx, StructModel<SimpleNDArray<Item>>>,
dtype: Type,
ndims: u64,
) -> Self {
// Sanity check on `dtype` and `simple_array`'s `Item` model.
let dtype_llvm = ctx.get_llvm_type(generator, dtype);
simple_ndarray.model.0 .0.item.check_type(generator, ctx.ctx, dtype_llvm).unwrap();
let byte_model = IntModel(Byte);
// TODO: Check if `ndims` is consistent with that in `simple_array`?
// Allocate the resulting ndarray.
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims, "from_simple_ndarray");
// Set data, shape by simply copying addresses.
let data = simple_ndarray
.get(generator, ctx, |f| f.data, "")
.pointer_cast(generator, ctx, byte_model, "data");
ndarray.instance.set(ctx, |f| f.data, data);
let shape = simple_ndarray.get(generator, ctx, |f| f.shape, "shape");
ndarray.instance.set(ctx, |f| f.shape, shape);
// Set strides. We know `data` is contiguous.
ndarray.update_strides_by_shape(generator, ctx);
ndarray
}
/// Get the typechecker ndarray type of this [`NDArrayObject`].
pub fn get_ndarray_type(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Type {
let ndims = create_ndims(&mut ctx.unifier, self.ndims);
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(self.dtype), Some(ndims))
}
/// Get the `np.size()` of this ndarray.
pub fn size<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_size(generator, ctx, self.instance)
}
/// Get the `ndarray.nbytes` of this ndarray.
pub fn nbytes<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_nbytes(generator, ctx, self.instance)
}
/// Get the `len()` of this ndarray.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_len(generator, ctx, self.instance)
}
/// Check if this ndarray is C-contiguous.
///
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance)
}
/// Get the pointer to the n-th (0-based) element.
///
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
///
/// There is no out-of-bounds check.
pub fn get_nth_pointer<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
name: &str,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, nth);
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name)
.unwrap()
}
/// Get the n-th (0-based) scalar.
///
/// There is no out-of-bounds check.
pub fn get_nth_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
) -> AnyObject<'ctx> {
let p = self.get_nth_pointer(generator, ctx, nth, "value");
let value = ctx.builder.build_load(p, "value").unwrap();
AnyObject { ty: self.dtype, value }
}
/// Set the n-th (0-based) scalar.
///
/// There is no out-of-bounds check.
pub fn set_nth_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
scalar: AnyObject<'ctx>,
) {
// Sanity check on scalar's `dtype`
assert!(ctx.unifier.unioned(scalar.ty, self.dtype));
let pscalar = self.get_nth_pointer(generator, ctx, nth, "pscalar");
ctx.builder.build_store(pscalar, scalar.value).unwrap();
}
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
///
/// Please refer to the IRRT implementation to see its purpose.
pub fn update_strides_by_shape<G: CodeGenerator + ?Sized>(
self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
}
/// Copy data from another ndarray.
///
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes
/// do not matter. The copying order is determined by how their flattened views look.
///
/// Panics if the `dtype`s of ndarrays are different.
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src: NDArrayObject<'ctx>,
) {
assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match");
call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance);
}
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
///
/// `shape` and `strides` will be automatically allocated on the stack.
//e
/// The returned ndarray's content will be:
/// - `data`: set to `nullptr`.
/// - `itemsize`: set to the `sizeof()` of `dtype`.
/// - `ndims`: set to the value of `ndims`.
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
pub fn alloca<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
name: &str,
) -> Self {
let sizet_model = IntModel(SizeT);
let ndarray_model = StructModel(NDArray);
let ndarray_data_model = PtrModel(IntModel(Byte));
let pndarray = ndarray_model.alloca(generator, ctx, name);
let data = ndarray_data_model.nullptr(generator, ctx.ctx);
pndarray.set(ctx, |f| f.data, data);
let itemsize = ctx.get_llvm_type(generator, dtype).size_of().unwrap();
let itemsize =
sizet_model.s_extend_or_bit_cast(generator, ctx, itemsize, "alloca_itemsize");
pndarray.set(ctx, |f| f.itemsize, itemsize);
let ndims_val = sizet_model.constant(generator, ctx.ctx, ndims);
pndarray.set(ctx, |f| f.ndims, ndims_val);
let shape = sizet_model.array_alloca(generator, ctx, ndims_val.value, "alloca_shape");
pndarray.set(ctx, |f| f.shape, shape);
let strides = sizet_model.array_alloca(generator, ctx, ndims_val.value, "alloca_strides");
pndarray.set(ctx, |f| f.strides, strides);
NDArrayObject { dtype, ndims, instance: pndarray }
}
/// Convenience function.
/// Like [`NDArrayObject::alloca_uninitialized`] but directly takes the typechecker type of the ndarray.
pub fn alloca_ndarray_type<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ty: Type,
name: &str,
) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
Self::alloca(generator, ctx, dtype, ndims, name)
}
/// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
pub fn alloca_constant_shape<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
shape: &[u64],
name: &str,
) -> Self {
let sizet_model = IntModel(SizeT);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64, name);
// Write shape
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
for (i, dim) in shape.iter().enumerate() {
let dim = sizet_model.constant(generator, ctx.ctx, *dim);
dst_shape.offset_const(generator, ctx, i as u64, "").store(ctx, dim);
}
ndarray
}
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
pub fn alloca_dynamic_shape<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
shape: &[Int<'ctx, SizeT>],
name: &str,
) -> Self {
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64, name);
// Write shape
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
for (i, dim) in shape.iter().enumerate() {
dst_shape.offset_const(generator, ctx, i as u64, "").store(ctx, *dim);
}
ndarray
}
/// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
///
/// The new ndarray will own its data and will be C-contiguous.
#[must_use]
pub fn make_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: &str,
) -> Self {
let clone = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims, name);
let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx, "shape");
clone.copy_shape_from_array(generator, ctx, shape);
clone.create_data(generator, ctx);
clone.copy_data_from(generator, ctx, *self);
clone
}
/// Get this ndarray's `ndims` as an LLVM constant.
pub fn get_ndims<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, SizeT> {
let sizet_model = IntModel(SizeT);
sizet_model.constant(generator, ctx, self.ndims)
}
/// Get if this ndarray's `np.size` is `0` - containing no content.
pub fn is_empty<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
let sizet_model = IntModel(SizeT);
let size = self.size(generator, ctx);
size.compare(ctx, IntPredicate::EQ, sizet_model.const_0(generator, ctx.ctx), "is_empty")
}
/// Return true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
///
/// This is a staticially known property of ndarrays. This is why it is returning
/// a Rust boolean instead of a [`BasicValue`].
#[must_use]
pub fn is_unsized(&self) -> bool {
self.ndims == 0
}
/// If this ndarray is unsized, return its sole value as a [`AnyObject`]. Otherwise, do nothing.
pub fn split_unsized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ScalarOrNDArray<'ctx> {
if self.is_unsized() {
// NOTE: `np.size(self) == 0` here is never possible.
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
ScalarOrNDArray::Scalar(self.get_nth_scalar(generator, ctx, zero))
} else {
ScalarOrNDArray::NDArray(*self)
}
}
/// Initialize an ndarray's `data` by allocating a buffer on the stack.
/// The allocated data buffer is considered to be *owned* by the ndarray.
///
/// `strides` of the ndarray will also be updated with `set_strides_by_shape`.
///
/// `shape` and `itemsize` of the ndarray ***must*** be initialized first.
pub fn create_data<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
let byte_model = IntModel(Byte);
let nbytes = self.nbytes(generator, ctx);
let data = byte_model.array_alloca(generator, ctx, nbytes.value, "data");
self.instance.set(ctx, |f| f.data, data);
self.update_strides_by_shape(generator, ctx);
}
/// Copy shape dimensions from an array.
pub fn copy_shape_from_array<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let dst_shape = self.instance.get(generator, ctx, |f| f.shape, "dst_shape");
let num_items = self.get_ndims(generator, ctx.ctx).value;
call_memcpy_model(generator, ctx, dst_shape, src_shape, num_items);
}
/// Copy shape dimensions from an ndarray.
/// Panics if `ndims` mismatches.
pub fn copy_shape_from_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayObject<'ctx>,
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_shape = src_ndarray.instance.get(generator, ctx, |f| f.shape, "src_shape");
self.copy_shape_from_array(generator, ctx, src_shape);
}
/// Copy strides dimensions from an array.
pub fn copy_strides_from_array<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_strides: Ptr<'ctx, IntModel<SizeT>>,
) {
let dst_strides = self.instance.get(generator, ctx, |f| f.strides, "dst_strides");
let num_items = self.get_ndims(generator, ctx.ctx).value;
call_memcpy_model(generator, ctx, dst_strides, src_strides, num_items);
}
/// Copy strides dimensions from an ndarray.
/// Panics if `ndims` mismatches.
pub fn copy_strides_from_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayObject<'ctx>,
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_strides = src_ndarray.instance.get(generator, ctx, |f| f.strides, "src_strides");
self.copy_strides_from_array(generator, ctx, src_strides);
}
/// Iterate through every element in the ndarray.
///
/// `body` also access to [`BreakContinueHooks`] to short-circuit.
pub fn foreach<'a, G, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
NDIterHandle<'ctx>,
) -> Result<(), String>,
{
gen_for_callback(
generator,
ctx,
Some("ndarray_foreach"),
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|generator, ctx, nditer| {
nditer.next(generator, ctx);
Ok(())
},
)
}
/// Make sure the ndarray is at least `ndmin`-dimensional.
///
/// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape.
/// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray.
#[must_use]
pub fn atleast_nd<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndmin: u64,
) -> Self {
if self.ndims < ndmin {
let mut indices = vec![];
for _ in self.ndims..ndmin {
indices.push(RustNDIndex::NewAxis);
}
indices.push(RustNDIndex::Ellipsis);
self.index(generator, ctx, &indices, "atleast_nd_ndarray")
} else {
*self
}
}
/// Fill the ndarray with a scalar.
///
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
pub fn fill<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
scalar: AnyObject<'ctx>,
) {
// Sanity check on scalar's type.
assert!(ctx.unifier.unioned(self.dtype, scalar.ty));
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, scalar.value).unwrap();
Ok(())
})
.unwrap();
}
/// Create a reshaped view on this ndarray like `np.reshape()`.
///
/// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be modified as a result.
///
/// If reshape without copying is impossible, this function will allocate a new ndarray and copy contents.
///
/// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`].
/// * `new_shape` - The target shape to do `np.reshape()`.
#[must_use]
pub fn reshape_or_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
new_ndims: u64,
new_shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
// TODO: The current criterion for whether to do a full copy or not is by checking `is_c_contiguous`,
// but this is not optimal - there are cases when the ndarray is not contiguous but could be reshaped
// without copying data. Look into how numpy does it.
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
let dst_ndarray =
NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims, "reshaped_ndarray");
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
// Reolsve negative indices
let size = self.size(generator, ctx);
let dst_ndims = dst_ndarray.get_ndims(generator, ctx.ctx);
let dst_shape =
dst_ndarray.instance.get(generator, ctx, |f| f.shape, "reshaped_ndarray_shape");
call_nac3_ndarray_resolve_and_check_new_shape(generator, ctx, size, dst_ndims, dst_shape);
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
// Inserting into then_bb: reshape is possible without copying
ctx.builder.position_at_end(then_bb);
dst_ndarray.update_strides_by_shape(generator, ctx);
dst_ndarray.instance.set(
ctx,
|f| f.data,
self.instance.get(generator, ctx, |f| f.data, "data"),
);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Inserting into else_bb: reshape is impossible without copying
ctx.builder.position_at_end(else_bb);
dst_ndarray.create_data(generator, ctx);
dst_ndarray.copy_data_from(generator, ctx, *self);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition for continuation
ctx.builder.position_at_end(end_bb);
dst_ndarray
}
/// Create a flattened view of this ndarray, like `np.ravel()`.
///
/// Uses [`NDArrayObject::reshape_or_copy`] under-the-hood so ndarray may or may not be copied.
#[must_use]
pub fn ravel_or_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
// Define models
let sizet_model = IntModel(SizeT);
let num0 = sizet_model.const_0(generator, ctx.ctx);
let num1 = sizet_model.const_1(generator, ctx.ctx);
let num_neg1 = sizet_model.const_all_1s(generator, ctx.ctx);
// Create `[-1]` and pass to `reshape_or_copy`.
let new_shape = sizet_model.array_alloca(generator, ctx, num1.value, "new_shape");
new_shape.offset(generator, ctx, num0.value, "").store(ctx, num_neg1);
self.reshape_or_copy(generator, ctx, 1, new_shape)
}
/// Create a transposed view on this ndarray like `np.transpose(<ndarray>, <axes> = None)`.
/// * `axes` - If specified, should be an array of the permutation (negative indices are **allowed**).
#[must_use]
pub fn transpose<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
axes: Option<Ptr<'ctx, IntModel<SizeT>>>,
) -> Self {
// Define models
let sizet_model = IntModel(SizeT);
let transposed_ndarray =
NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims, "transposed_ndarray");
let num_axes = self.get_ndims(generator, ctx.ctx);
// `axes = nullptr` if `axes` is unspecified.
let axes = axes.unwrap_or_else(|| PtrModel(sizet_model).nullptr(generator, ctx.ctx));
call_nac3_ndarray_transpose(
generator,
ctx,
self.instance,
transposed_ndarray.instance,
num_axes,
axes,
);
transposed_ndarray
}
/// Check if this `NDArray` can be used as an `out` ndarray for an operation.
///
/// Raise an exception if the shapes do not match.
pub fn check_can_be_written_by_out<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
out_ndims: u64,
out_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let sizet_model = IntModel(SizeT);
let ndarray_ndims = self.get_ndims(generator, ctx.ctx);
let ndarray_shape = self.instance.get(generator, ctx, |f| f.shape, "shape");
let output_ndims = sizet_model.constant(generator, ctx.ctx, out_ndims);
let output_shape = out_shape;
call_nac3_ndarray_util_assert_output_shape_same(
generator,
ctx,
ndarray_ndims,
ndarray_shape,
output_ndims,
output_shape,
);
}
/// Create the shape tuple of this ndarray like `np.shape(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_shape_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Don't return a tuple of int32s.
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.shape, "")
.offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::create(generator, ctx, objects, "shape")
}
/// Create the strides tuple of this ndarray like `np.strides(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_strides_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Don't return a tuple of int32s.
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.strides, "")
.offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::create(generator, ctx, objects, "strides")
}
}
/// TODO: Document me
#[derive(Debug, Clone, Copy)]
pub enum NDArrayOut<'ctx> {
NewNDArray { dtype: Type },
WriteToNDArray { ndarray: NDArrayObject<'ctx> },
}

View File

@ -0,0 +1,53 @@
use inkwell::values::{BasicValue, BasicValueEnum};
use crate::codegen::{model::*, structure::SimpleNDArray, CodeGenContext, CodeGenerator};
use super::NDArrayObject;
pub fn perform_nalgebra_call<'ctx, 'a, const NUM_INPUTS: usize, const NUM_OUTPUTS: usize, G, F>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: [NDArrayObject<'ctx>; NUM_INPUTS],
output_ndims: [u64; NUM_OUTPUTS],
invoke_function: F,
) -> [NDArrayObject<'ctx>; NUM_OUTPUTS]
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut CodeGenContext<'ctx, 'a>,
[BasicValueEnum<'ctx>; NUM_INPUTS],
[BasicValueEnum<'ctx>; NUM_OUTPUTS],
),
{
// TODO: Allow stacked inputs. See NumPy docs.
let f64_model = FloatModel(Float64);
let simple_ndarray_model = StructModel(SimpleNDArray { item: f64_model });
// Prepare inputs & outputs and invoke
let inputs = inputs.map(|input| {
// Sanity check. Typechecker ensures this.
assert!(ctx.unifier.unioned(input.dtype, ctx.primitives.float));
input
.make_simple_ndarray(generator, ctx, FloatModel(Float64), "nalgebra_input")
.value
.as_basic_value_enum()
});
let outputs = [simple_ndarray_model.alloca(generator, ctx, "nalgebra_output"); NUM_OUTPUTS];
invoke_function(ctx, inputs, outputs.map(|output| output.value.as_basic_value_enum()));
// Turn the outputs into strided NDArrays
let mut output_i = 0;
outputs.map(|output| {
let out = NDArrayObject::from_simple_ndarray(
generator,
ctx,
output,
ctx.primitives.float,
output_ndims[output_i],
);
output_i += 1;
out
})
}

View File

@ -0,0 +1,88 @@
use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
use crate::codegen::{
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
model::*,
object::AnyObject,
structure::NDIter,
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
#[derive(Debug, Clone)]
pub struct NDIterHandle<'ctx> {
ndarray: NDArrayObject<'ctx>,
instance: Ptr<'ctx, StructModel<NDIter>>,
indices: Ptr<'ctx, IntModel<SizeT>>,
}
impl<'ctx> NDIterHandle<'ctx> {
pub fn new<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
) -> Self {
let nditer = StructModel(NDIter).alloca(generator, ctx, "nditer");
let ndims = ndarray.get_ndims(generator, ctx.ctx);
let indices = IntModel(SizeT).array_alloca(generator, ctx, ndims.value, "indices");
call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices);
NDIterHandle { ndarray, instance: nditer, indices }
}
#[must_use]
pub fn has_next<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
call_nac3_nditer_has_next(generator, ctx, self.instance)
}
pub fn next<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_nditer_next(generator, ctx, self.instance);
}
#[must_use]
pub fn get_pointer<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype);
let p = self.instance.get(generator, ctx, |f| f.element, "element");
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element")
.unwrap()
}
#[must_use]
pub fn get_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
let p = self.get_pointer(generator, ctx);
let value = ctx.builder.build_load(p, "value").unwrap();
AnyObject { ty: self.ndarray.dtype, value }
}
#[must_use]
pub fn get_index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
self.instance.get(generator, ctx, |f| f.nth, "index")
}
#[must_use]
pub fn get_indices(&self) -> Ptr<'ctx, IntModel<SizeT>> {
self.indices
}
}

View File

@ -0,0 +1,159 @@
use std::cmp::max;
use crate::codegen::{
irrt::{
call_nac3_ndarray_float64_matmul_at_least_2d, call_nac3_ndarray_matmul_calculate_shapes,
},
model::*,
object::ndarray::indexing::RustNDIndex,
CodeGenContext, CodeGenerator,
};
use super::{NDArrayObject, NDArrayOut};
impl<'ctx> NDArrayObject<'ctx> {
/// TODO: Document me
fn matmul_helper<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: Self,
b: Self,
) -> Self {
assert!(a.ndims >= 2);
assert!(b.ndims >= 2);
assert!(ctx.unifier.unioned(ctx.primitives.float, a.dtype));
assert!(ctx.unifier.unioned(ctx.primitives.float, b.dtype));
let sizet_model = IntModel(SizeT);
let final_ndims_int = max(a.ndims, b.ndims);
let a_ndims = a.get_ndims(generator, ctx.ctx);
let a_shape = a.instance.get(generator, ctx, |f| f.shape, "a_shape");
let b_ndims = b.get_ndims(generator, ctx.ctx);
let b_shape = b.instance.get(generator, ctx, |f| f.shape, "b_shape");
let final_ndims = sizet_model.constant(generator, ctx.ctx, final_ndims_int);
let new_a_shape =
sizet_model.array_alloca(generator, ctx, final_ndims.value, "new_a_shape");
let new_b_shape =
sizet_model.array_alloca(generator, ctx, final_ndims.value, "new_b_shape");
let dst_shape = sizet_model.array_alloca(generator, ctx, final_ndims.value, "dst_shape");
call_nac3_ndarray_matmul_calculate_shapes(
generator,
ctx,
a_ndims,
a_shape,
b_ndims,
b_shape,
final_ndims,
new_a_shape,
new_b_shape,
dst_shape,
);
let new_a = a.broadcast_to(generator, ctx, final_ndims_int, new_a_shape);
let new_b = b.broadcast_to(generator, ctx, final_ndims_int, new_b_shape);
let dst = NDArrayObject::alloca(
generator,
ctx,
ctx.primitives.float,
final_ndims_int,
"matmul_result",
);
dst.copy_shape_from_array(generator, ctx, dst_shape);
dst.create_data(generator, ctx);
call_nac3_ndarray_float64_matmul_at_least_2d(
generator,
ctx,
new_a.instance,
new_b.instance,
dst.instance,
);
dst
}
/// Perform `np.matmul` according to the rules in
/// <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>.
///
/// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`].
pub fn matmul<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: Self,
b: Self,
out: NDArrayOut<'ctx>,
) -> Self {
// Sanity check, but type inference should prevent this.
assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input");
/*
If both arguments are 2-D they are multiplied like conventional matrices.
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
*/
let new_a = if a.ndims == 1 {
// Prepend 1 to its dimensions
a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis], "new_a")
} else {
a
};
let new_b = if b.ndims == 1 {
// Append 1 to its dimensions
b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis], "new_a")
} else {
b
};
// NOTE: `result` will always be a newly allocated ndarray.
// Current implementation cannot do in-place matrix muliplication.
let mut result = NDArrayObject::matmul_helper(generator, ctx, new_a, new_b);
let i32_model = IntModel(Int32); // TODO: Upgrade to SizeT
let zero = i32_model.const_0(generator, ctx.ctx);
if a.ndims == 1 {
// Remove the prepended 1
result = result.index(
generator,
ctx,
&[RustNDIndex::SingleElement(zero), RustNDIndex::Ellipsis],
"result_no_prepend_1",
);
}
if b.ndims == 1 {
// Remove the appended 1
result = result.index(
generator,
ctx,
&[RustNDIndex::Ellipsis, RustNDIndex::SingleElement(zero)],
"result_no_append_1",
);
}
match out {
NDArrayOut::NewNDArray { dtype } => {
// We don't support auto-casting right now, nor anything other than float64.
// Force the output dtype to be float64.
assert!(ctx.unifier.unioned(ctx.primitives.float, dtype));
result
}
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
// TODO: It is possible to check the shapes before computing the matmul to save resources.
let result_shape = result.instance.get(generator, ctx, |f| f.shape, "result_shape");
out_ndarray.check_can_be_written_by_out(generator, ctx, result.ndims, result_shape);
// TODO: We can just set `out_ndarray.data` to `result.data`. Should we?
out_ndarray.copy_data_from(generator, ctx, result);
out_ndarray
}
}
}
}

View File

@ -0,0 +1,131 @@
use inkwell::values::{BasicValue, BasicValueEnum};
use crate::{
codegen::{model::*, object::AnyObject, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
};
use super::NDArrayObject;
impl<'ctx> AnyObject<'ctx> {
/// Promote this scalar to an unsized ndarray (like doing `np.asarray`).
///
/// The scalar value is allocated onto the stack, and the ndarray's `data` will point to that
/// allocated value.
pub fn as_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> NDArrayObject<'ctx> {
let pbyte_model = PtrModel(IntModel(Byte));
// We have to put the value on the stack to get a data pointer.
let data = ctx.builder.build_alloca(self.value.get_type(), "as_ndarray_scalar").unwrap();
ctx.builder.build_store(data, self.value).unwrap();
let data = pbyte_model.pointer_cast(generator, ctx, data, "data");
let ndarray = NDArrayObject::alloca(generator, ctx, self.ty, 0, "scalar_ndarray");
ndarray.instance.set(ctx, |f| f.data, data);
ndarray
}
}
/// A convenience enum for implementing scalar/ndarray agnostic utilities.
#[derive(Debug, Clone, Copy)]
pub enum ScalarOrNDArray<'ctx> {
Scalar(AnyObject<'ctx>),
NDArray(NDArrayObject<'ctx>),
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.value,
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
}
}
#[must_use]
pub fn into_scalar(&self) -> AnyObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"),
ScalarOrNDArray::Scalar(scalar) => *scalar,
}
}
#[must_use]
pub fn into_ndarray(&self) -> NDArrayObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(_scalar) => panic!("Got Scalar"),
}
}
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
/// - If this is an ndarray, the ndarray is returned.
/// - If this is a scalar, an unsized ndarray view is created on it.
pub fn as_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> NDArrayObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(scalar) => scalar.as_ndarray(generator, ctx),
}
}
#[must_use]
pub fn dtype(&self) -> Type {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.ty,
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
}
}
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
ScalarOrNDArray::NDArray(_ndarray) => Err(()),
}
}
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(_scalar) => Err(()),
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
}
}
}
/// Split an [`AnyObject`] into a [`ScalarOrNDArray`] depending on its [`Type`].
pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> ScalarOrNDArray<'ctx> {
// TODO: Automatically convert a list into an ndarray?
match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_object(generator, ctx, object);
ScalarOrNDArray::NDArray(ndarray)
}
_ => {
let scalar = AnyObject { ty: object.ty, value: object.value };
ScalarOrNDArray::Scalar(scalar)
}
}
}

View File

@ -0,0 +1,111 @@
use util::gen_for_model_auto;
use crate::{
codegen::{
model::*,
object::{list::ListObject, tuple::TupleObject, AnyObject},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::TypeEnum,
};
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
///
/// * `sequence` - The `sequence` parameter.
/// * `sequence_ty` - The typechecker type of `sequence`
///
/// The `sequence` argument type may only be one of the following:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// All `int32` values will be sign-extended to `SizeT`.
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
input_sequence: AnyObject<'ctx>,
) -> (Int<'ctx, SizeT>, Ptr<'ctx, IntModel<SizeT>>) {
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
let one = sizet_model.const_1(generator, ctx.ctx);
// The result `list` to return.
match &*ctx.unifier.get_ty(input_sequence.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
// Check `input_sequence`
let input_sequence = ListObject::from_object(generator, ctx, input_sequence);
let len = input_sequence.instance.gep(ctx, |f| f.len).load(generator, ctx, "len");
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
gen_for_model_auto(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| {
// Load the i-th int32 in the input sequence
let int = input_sequence
.instance
.get(generator, ctx, |f| f.items, "")
.offset(generator, ctx, i.value, "")
.load(generator, ctx, "")
.value
.into_int_value();
// Cast to SizeT
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
// Store
result.offset(generator, ctx, i.value, "int").store(ctx, int);
Ok(())
})
.unwrap();
(len, result)
}
TypeEnum::TTuple { .. } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
let input_sequence = TupleObject::from_object(ctx, input_sequence);
let len_int = input_sequence.len_static();
let len = sizet_model.constant(generator, ctx.ctx, len_int as u64);
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
for i in 0..len_int {
// Get the i-th element off of the tuple and load it into `result`.
let int = input_sequence.get(ctx, i, "dim").value.into_int_value();
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
let offset = sizet_model.constant(generator, ctx.ctx, i as u64);
result.offset(generator, ctx, offset.value, "int").store(ctx, int);
}
(len, result)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
{
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
let input_int = input_sequence.value.into_int_value();
let len = sizet_model.const_1(generator, ctx.ctx);
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, input_int, "int");
// Storing into result[0]
result.store(ctx, int);
(len, result)
}
_ => panic!(
"encountered unknown sequence type: {}",
ctx.unifier.stringify(input_sequence.ty)
),
}
}

View File

@ -0,0 +1,41 @@
use crate::codegen::{
irrt::calculate_len_for_slice_range, model::*, structure::RangeModel, CodeGenContext,
CodeGenerator,
};
use super::AnyObject;
/// A `range` in NAC3
pub struct RangeObject<'ctx> {
pub instance: Ptr<'ctx, RangeModel>,
}
impl<'ctx> RangeObject<'ctx> {
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> Self {
assert!(ctx.unifier.unioned(ctx.primitives.range, object.ty)); // Sanity check on type.
let model = PtrModel(RangeModel::default());
let instance = model.check_value(generator, ctx.ctx, object.value).unwrap();
RangeObject { instance }
}
/// Get the `len()` of this range.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Int32> {
let start = self.instance.gep_start(generator, ctx, "").load(generator, ctx, "start");
let stop = self.instance.gep_stop(generator, ctx, "").load(generator, ctx, "stop");
let step = self.instance.gep_step(generator, ctx, "").load(generator, ctx, "step");
// TODO: Refactor this
let len =
calculate_len_for_slice_range(generator, ctx, start.value, stop.value, step.value);
IntModel(Int32).check_value(generator, ctx.ctx, len).unwrap()
}
}

View File

@ -0,0 +1,113 @@
use core::panic;
use inkwell::values::StructValue;
use itertools::Itertools;
use crate::{
codegen::{model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
};
use super::AnyObject;
/// A NAC3 tuple object.
///
/// NOTE: This struct has no copy trait.
#[derive(Debug, Clone)]
pub struct TupleObject<'ctx> {
/// The type of the tuple.
pub tys: Vec<Type>,
/// The underlying LLVM value of this tuple.
pub value: StructValue<'ctx>,
}
impl<'ctx> TupleObject<'ctx> {
// NOTE: There is no Model abstraction for Tuples with arbitrary lengths.
// Everything has to be done raw with Inkwell.
pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self {
// TODO: Keep `is_vararg_ctx` from TTuple?
// Sanity check on object type.
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else {
panic!(
"Expected type to be a TypeEnum::TTuple, got {}",
ctx.unifier.stringify(object.ty)
);
};
let value = object.value.into_struct_value();
let value_num_fields = value.get_type().count_fields() as usize;
assert!(
value_num_fields == tys.len(),
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
tys.len(),
value_num_fields
);
TupleObject { tys: tys.clone(), value }
}
/// Convenience function. Create a [`TupleObject`] from an iterator of objects.
pub fn create<I, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
objects: I,
name: &str,
) -> Self
where
I: IntoIterator<Item = AnyObject<'ctx>>,
{
let (values, tys): (Vec<_>, Vec<_>) =
objects.into_iter().map(|object| (object.value, object.ty)).unzip();
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
let llvm_tuple_ty = ctx.ctx.struct_type(&llvm_tys, false);
let pllvm_tuple = ctx.builder.build_alloca(llvm_tuple_ty, "tuple").unwrap();
for (i, val) in values.into_iter().enumerate() {
let pval = ctx.builder.build_struct_gep(pllvm_tuple, i as u32, "value").unwrap();
ctx.builder.build_store(pval, val).unwrap();
}
let value = ctx.builder.build_load(pllvm_tuple, name).unwrap().into_struct_value();
TupleObject { tys, value }
}
/// Get the `len()` of this tuple statically.
///
/// We statically know the lengths of tuples in NAC3 when compiling.
#[must_use]
pub fn len_static(&self) -> usize {
self.tys.len()
}
/// Get the `len()` of this tuple.
#[must_use]
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
IntModel(SizeT).constant(generator, ctx.ctx, self.len_static() as u64)
}
/// Check if this tuple is an empty/unit tuple.
#[must_use]
pub fn is_empty(&self) -> bool {
self.len_static() == 0
}
/// Get the `i`-th (0-based) object in this tuple.
pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> {
assert!(
i < self.len_static(),
"Tuple object with length {} have index {i}",
self.len_static()
);
let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap();
let ty = self.tys[i];
AnyObject { ty, value }
}
}

View File

@ -1,8 +1,14 @@
use super::model::*;
use super::object::ndarray::indexing::util::gen_ndarray_subscript_ndindexes;
use super::object::ndarray::scalar::split_scalar_or_ndarray;
use super::object::ndarray::NDArrayObject;
use super::object::AnyObject;
use super::{ use super::{
super::symbol_resolver::ValueEnum, super::symbol_resolver::ValueEnum,
expr::destructure_range, expr::destructure_range,
irrt::{handle_slice_indices, list_slice_assignment}, irrt::{handle_slice_indices, list_slice_assignment},
CodeGenContext, CodeGenerator, structure::{CSlice, Exception},
CodeGenContext, CodeGenerator, Int32, IntModel, Ptr, StructModel,
}; };
use crate::{ use crate::{
codegen::{ codegen::{
@ -10,10 +16,10 @@ use crate::{
expr::gen_binop_expr, expr::gen_binop_expr,
gen_in_range_check, gen_in_range_check,
}, },
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::{ typecheck::{
magic_methods::Binop, magic_methods::Binop,
typedef::{FunSignature, Type, TypeEnum}, typedef::{iter_type_vars, FunSignature, Type, TypeEnum},
}, },
}; };
use inkwell::{ use inkwell::{
@ -23,10 +29,10 @@ use inkwell::{
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
IntPredicate, IntPredicate,
}; };
use itertools::{izip, Itertools};
use nac3parser::ast::{ use nac3parser::ast::{
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef, Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
}; };
use std::convert::TryFrom;
/// See [`CodeGenerator::gen_var_alloc`]. /// See [`CodeGenerator::gen_var_alloc`].
pub fn gen_var<'ctx>( pub fn gen_var<'ctx>(
@ -97,8 +103,6 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
pattern: &Expr<Option<Type>>, pattern: &Expr<Option<Type>>,
name: Option<&str>, name: Option<&str>,
) -> Result<Option<PointerValue<'ctx>>, String> { ) -> Result<Option<PointerValue<'ctx>>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
// very similar to gen_expr, but we don't do an extra load at the end // very similar to gen_expr, but we don't do an extra load at the end
// and we flatten nested tuples // and we flatten nested tuples
Ok(Some(match &pattern.node { Ok(Some(match &pattern.node {
@ -137,65 +141,6 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
} }
.unwrap() .unwrap()
} }
ExprKind::Subscript { value, slice, .. } => {
match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
let v = generator
.gen_expr(ctx, value)?
.unwrap()
.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
.into_pointer_value();
let v = ListValue::from_ptr_val(v, llvm_usize, None);
let len = v.load_size(ctx, Some("len"));
let raw_index = generator
.gen_expr(ctx, slice)?
.unwrap()
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value();
let raw_index = ctx
.builder
.build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext")
.unwrap();
// handle negative index
let is_negative = ctx
.builder
.build_int_compare(
IntPredicate::SLT,
raw_index,
generator.get_size_type(ctx.ctx).const_zero(),
"is_neg",
)
.unwrap();
let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted").unwrap();
let index = ctx
.builder
.build_select(is_negative, adjusted, raw_index, "index")
.map(BasicValueEnum::into_int_value)
.unwrap();
// unsigned less than is enough, because negative index after adjustment is
// bigger than the length (for unsigned cmp)
let bound_check = ctx
.builder
.build_int_compare(IntPredicate::ULT, index, len, "inbound")
.unwrap();
ctx.make_assert(
generator,
bound_check,
"0:IndexError",
"index {0} out of bounds 0:{1}",
[Some(raw_index), Some(len), None],
slice.location,
);
v.data().ptr_offset(ctx, generator, &index, name)
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
todo!()
}
_ => unreachable!(),
}
}
_ => unreachable!(), _ => unreachable!(),
})) }))
} }
@ -206,70 +151,20 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> { ) -> Result<(), String> {
let llvm_usize = generator.get_size_type(ctx.ctx); // See https://docs.python.org/3/reference/simple_stmts.html#assignment-statements.
match &target.node { match &target.node {
ExprKind::Tuple { elts, .. } => { ExprKind::Subscript { value: target, slice: key, .. } => {
let BasicValueEnum::StructValue(v) = // Handle "slicing" or "subscription"
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? generator.gen_setitem(ctx, target, key, value, value_ty)?;
else {
unreachable!()
};
for (i, elt) in elts.iter().enumerate() {
let v = ctx
.builder
.build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem")
.unwrap();
generator.gen_assign(ctx, elt, v.into())?;
}
} }
ExprKind::Subscript { value: ls, slice, .. } ExprKind::Tuple { elts, .. } | ExprKind::List { elts, .. } => {
if matches!(&slice.node, ExprKind::Slice { .. }) => // Fold on `"[" [target_list] "]"` and `"(" [target_list] ")"`
{ generator.gen_assign_target_list(ctx, elts, value, value_ty)?;
let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() };
let ls = generator
.gen_expr(ctx, ls)?
.unwrap()
.to_basic_value_enum(ctx, generator, ls.custom.unwrap())?
.into_pointer_value();
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
let Some((start, end, step)) =
handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))?
else {
return Ok(());
};
let value = value
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
.into_pointer_value();
let value = ListValue::from_ptr_val(value, llvm_usize, None);
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
*params.iter().next().unwrap().1
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
}
_ => unreachable!(),
};
let ty = ctx.get_llvm_type(generator, ty);
let Some(src_ind) = handle_slice_indices(
&None,
&None,
&None,
ctx,
generator,
value.load_size(ctx, None),
)?
else {
return Ok(());
};
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
} }
_ => { _ => {
// Handle attribute and direct variable assignments.
let name = if let ExprKind::Name { id, .. } = &target.node { let name = if let ExprKind::Name { id, .. } = &target.node {
format!("{id}.addr") format!("{id}.addr")
} else { } else {
@ -293,6 +188,272 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
Ok(()) Ok(())
} }
/// See [`CodeGenerator::gen_assign_target_list`].
pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
targets: &Vec<Expr<Option<Type>>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> {
// Deconstruct the tuple `value`
let BasicValueEnum::StructValue(tuple) = value.to_basic_value_enum(ctx, generator, value_ty)?
else {
unreachable!()
};
// NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer.
let TypeEnum::TTuple { ty: tuple_tys, .. } = &*ctx.unifier.get_ty(value_ty) else {
unreachable!();
};
assert_eq!(tuple.get_type().count_fields() as usize, tuple_tys.len());
let tuple = (0..tuple.get_type().count_fields())
.map(|i| ctx.builder.build_extract_value(tuple, i, "item").unwrap())
.collect_vec();
// Find the starred target if it exists.
let mut starred_target_index: Option<usize> = None; // Index of the "starred" target. If it exists, there may only be one.
for (i, target) in targets.iter().enumerate() {
if matches!(target.node, ExprKind::Starred { .. }) {
assert!(starred_target_index.is_none()); // The typechecker ensures this
starred_target_index = Some(i);
}
}
if let Some(starred_target_index) = starred_target_index {
assert!(tuple_tys.len() >= targets.len() - 1); // The typechecker ensures this
let a = starred_target_index; // Number of RHS values before the starred target
let b = tuple_tys.len() - (targets.len() - 1 - starred_target_index); // Number of RHS values after the starred target
// Thus `tuple[a..b]` is assigned to the starred target.
// Handle assignment before the starred target
for (target, val, val_ty) in
izip!(&targets[..starred_target_index], &tuple[..a], &tuple_tys[..a])
{
generator.gen_assign(ctx, target, ValueEnum::Dynamic(*val), *val_ty)?;
}
// Handle assignment to the starred target
if let ExprKind::Starred { value: target, .. } = &targets[starred_target_index].node {
let vals = &tuple[a..b];
let val_tys = &tuple_tys[a..b];
// Create a sub-tuple from `value` for the starred target.
let sub_tuple_ty = ctx
.ctx
.struct_type(&vals.iter().map(BasicValueEnum::get_type).collect_vec(), false);
let psub_tuple_val =
ctx.builder.build_alloca(sub_tuple_ty, "starred_target_value_ptr").unwrap();
for (i, val) in vals.iter().enumerate() {
let pitem = ctx
.builder
.build_struct_gep(psub_tuple_val, i as u32, "starred_target_value_item")
.unwrap();
ctx.builder.build_store(pitem, *val).unwrap();
}
let sub_tuple_val =
ctx.builder.build_load(psub_tuple_val, "starred_target_value").unwrap();
// Create the typechecker type of the sub-tuple
let sub_tuple_ty =
ctx.unifier.add_ty(TypeEnum::TTuple { ty: val_tys.to_vec(), is_vararg_ctx: false });
// Now assign with that sub-tuple to the starred target.
generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?;
} else {
unreachable!() // The typechecker ensures this
}
// Handle assignment after the starred target
for (target, val, val_ty) in
izip!(&targets[starred_target_index + 1..], &tuple[b..], &tuple_tys[b..])
{
generator.gen_assign(ctx, target, ValueEnum::Dynamic(*val), *val_ty)?;
}
} else {
assert_eq!(tuple_tys.len(), targets.len()); // The typechecker ensures this
for (target, val, val_ty) in izip!(targets, tuple, tuple_tys) {
generator.gen_assign(ctx, target, ValueEnum::Dynamic(val), *val_ty)?;
}
}
Ok(())
}
/// See [`CodeGenerator::gen_setitem`].
pub fn gen_setitem<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
key: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> {
let target_ty = target.custom.unwrap();
let key_ty = key.custom.unwrap();
match &*ctx.unifier.get_ty(target_ty) {
TypeEnum::TObj { obj_id, params: list_params, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// Handle list item assignment
let llvm_usize = generator.get_size_type(ctx.ctx);
let target_item_ty = iter_type_vars(list_params).next().unwrap().ty;
let target = generator
.gen_expr(ctx, target)?
.unwrap()
.to_basic_value_enum(ctx, generator, target_ty)?
.into_pointer_value();
let target = ListValue::from_ptr_val(target, llvm_usize, None);
if let ExprKind::Slice { .. } = &key.node {
// Handle assigning to a slice
let ExprKind::Slice { lower, upper, step } = &key.node else { unreachable!() };
let Some((start, end, step)) = handle_slice_indices(
lower,
upper,
step,
ctx,
generator,
target.load_size(ctx, None),
)?
else {
return Ok(());
};
let value =
value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value();
let value = ListValue::from_ptr_val(value, llvm_usize, None);
let target_item_ty = ctx.get_llvm_type(generator, target_item_ty);
let Some(src_ind) = handle_slice_indices(
&None,
&None,
&None,
ctx,
generator,
value.load_size(ctx, None),
)?
else {
return Ok(());
};
list_slice_assignment(
generator,
ctx,
target_item_ty,
target,
(start, end, step),
value,
src_ind,
);
} else {
// Handle assigning to an index
let len = target.load_size(ctx, Some("len"));
let index = generator
.gen_expr(ctx, key)?
.unwrap()
.to_basic_value_enum(ctx, generator, key_ty)?
.into_int_value();
let index = ctx
.builder
.build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext")
.unwrap();
// handle negative index
let is_negative = ctx
.builder
.build_int_compare(
IntPredicate::SLT,
index,
generator.get_size_type(ctx.ctx).const_zero(),
"is_neg",
)
.unwrap();
let adjusted = ctx.builder.build_int_add(index, len, "adjusted").unwrap();
let index = ctx
.builder
.build_select(is_negative, adjusted, index, "index")
.map(BasicValueEnum::into_int_value)
.unwrap();
// unsigned less than is enough, because negative index after adjustment is
// bigger than the length (for unsigned cmp)
let bound_check = ctx
.builder
.build_int_compare(IntPredicate::ULT, index, len, "inbound")
.unwrap();
ctx.make_assert(
generator,
bound_check,
"0:IndexError",
"index {0} out of bounds 0:{1}",
[Some(index), Some(len), None],
key.location,
);
// Write value to index on list
let item_ptr =
target.data().ptr_offset(ctx, generator, &index, Some("list_item_ptr"));
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
ctx.builder.build_store(item_ptr, value).unwrap();
}
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
// Handle NDArray item assignment
// Process target
let target = generator
.gen_expr(ctx, target)?
.unwrap()
.to_basic_value_enum(ctx, generator, target_ty)?;
let target = AnyObject { value: target, ty: target_ty };
// Process key
let key = gen_ndarray_subscript_ndindexes(generator, ctx, key)?;
// Process value
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
let value = AnyObject { value, ty: value_ty };
/*
Reference code:
```python
target = target[key]
value = np.asarray(value)
shape = np.broadcast_shape((target, value))
target = np.broadcast_to(target, shape)
value = np.broadcast_to(value, shape)
...and finally copy 1-1 from value to target.
```
*/
let target = NDArrayObject::from_object(generator, ctx, target);
let target = target.index(generator, ctx, &key, "assign_target_ndarray");
let value = split_scalar_or_ndarray(generator, ctx, value).as_ndarray(generator, ctx);
let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]);
let target = broadcast_result.ndarrays[0];
let value = broadcast_result.ndarrays[1];
target.copy_data_from(generator, ctx, value);
}
_ => {
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
}
}
Ok(())
}
/// See [`CodeGenerator::gen_for`]. /// See [`CodeGenerator::gen_for`].
pub fn gen_for<G: CodeGenerator>( pub fn gen_for<G: CodeGenerator>(
generator: &mut G, generator: &mut G,
@ -315,9 +476,6 @@ pub fn gen_for<G: CodeGenerator>(
let orelse_bb = let orelse_bb =
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") }; if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
// Whether the iterable is a range() expression
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
// The BB containing the increment expression // The BB containing the increment expression
let incr_bb = ctx.ctx.append_basic_block(current, "for.incr"); let incr_bb = ctx.ctx.append_basic_block(current, "for.incr");
// The BB containing the loop condition check // The BB containing the loop condition check
@ -326,113 +484,132 @@ pub fn gen_for<G: CodeGenerator>(
// store loop bb information and restore it later // store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb)); let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb));
let iter_ty = iter.custom.unwrap();
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())? v.to_basic_value_enum(ctx, generator, iter_ty)?
} else { } else {
return Ok(()); return Ok(());
}; };
if is_iterable_range_expr {
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
// Internal variable for loop; Cannot be assigned
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))?
else {
unreachable!()
};
let (start, stop, step) = destructure_range(ctx, iter_val);
ctx.builder.build_store(i, start).unwrap();
// Check "If step is zero, ValueError is raised."
let rangenez =
ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "").unwrap();
ctx.make_assert(
generator,
rangenez,
"ValueError",
"range() arg 3 must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
match &*ctx.unifier.get_ty(iter_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
{ {
ctx.builder.position_at_end(cond_bb); let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
ctx.builder // Internal variable for loop; Cannot be assigned
.build_conditional_branch( let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
gen_in_range_check( // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
ctx, let Some(target_i) =
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), generator.gen_store_target(ctx, target, Some("for.target.addr"))?
stop, else {
step, unreachable!()
), };
body_bb, let (start, stop, step) = destructure_range(ctx, iter_val);
orelse_bb,
ctx.builder.build_store(i, start).unwrap();
// Check "If step is zero, ValueError is raised."
let rangenez = ctx
.builder
.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
rangenez,
"ValueError",
"range() arg 3 must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
{
ctx.builder.position_at_end(cond_bb);
ctx.builder
.build_conditional_branch(
gen_in_range_check(
ctx,
ctx.builder
.build_load(i, "")
.map(BasicValueEnum::into_int_value)
.unwrap(),
stop,
step,
),
body_bb,
orelse_bb,
)
.unwrap();
}
ctx.builder.position_at_end(incr_bb);
let next_i = ctx
.builder
.build_int_add(
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
step,
"inc",
) )
.unwrap(); .unwrap();
ctx.builder.build_store(i, next_i).unwrap();
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
ctx.builder.position_at_end(body_bb);
ctx.builder
.build_store(
target_i,
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
)
.unwrap();
generator.gen_block(ctx, body.iter())?;
} }
TypeEnum::TObj { obj_id, params: list_params, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap();
let len = ctx
.build_gep_and_load(
iter_val.into_pointer_value(),
&[zero, int32.const_int(1, false)],
Some("len"),
)
.into_int_value();
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
ctx.builder.position_at_end(incr_bb); ctx.builder.position_at_end(cond_bb);
let next_i = ctx let index = ctx
.builder .builder
.build_int_add( .build_load(index_addr, "for.index")
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), .map(BasicValueEnum::into_int_value)
step, .unwrap();
"inc", let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond").unwrap();
) ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap();
.unwrap();
ctx.builder.build_store(i, next_i).unwrap();
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(incr_bb);
ctx.builder let index =
.build_store( ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap();
target_i, let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap();
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), ctx.builder.build_store(index_addr, inc).unwrap();
) ctx.builder.build_unconditional_branch(cond_bb).unwrap();
.unwrap();
generator.gen_block(ctx, body.iter())?;
} else {
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap();
let len = ctx
.build_gep_and_load(
iter_val.into_pointer_value(),
&[zero, int32.const_int(1, false)],
Some("len"),
)
.into_int_value();
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
ctx.builder.position_at_end(cond_bb); ctx.builder.position_at_end(body_bb);
let index = ctx let arr_ptr = ctx
.builder .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr"))
.build_load(index_addr, "for.index") .into_pointer_value();
.map(BasicValueEnum::into_int_value) let index = ctx
.unwrap(); .builder
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond").unwrap(); .build_load(index_addr, "for.index")
ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap(); .map(BasicValueEnum::into_int_value)
.unwrap();
ctx.builder.position_at_end(incr_bb); let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
let index = let val_ty = iter_type_vars(list_params).next().unwrap().ty;
ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap(); generator.gen_assign(ctx, target, val.into(), val_ty)?;
let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap(); generator.gen_block(ctx, body.iter())?;
ctx.builder.build_store(index_addr, inc).unwrap(); }
ctx.builder.build_unconditional_branch(cond_bb).unwrap(); _ => {
panic!("unsupported for loop iterator type: {}", ctx.unifier.stringify(iter_ty));
ctx.builder.position_at_end(body_bb); }
let arr_ptr = ctx
.build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr"))
.into_pointer_value();
let index = ctx
.builder
.build_load(index_addr, "for.index")
.map(BasicValueEnum::into_int_value)
.unwrap();
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
generator.gen_assign(ctx, target, val.into())?;
generator.gen_block(ctx, body.iter())?;
} }
for (k, (_, _, counter)) in &var_assignment { for (k, (_, _, counter)) in &var_assignment {
@ -494,6 +671,7 @@ pub struct BreakContinueHooks<'ctx> {
pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
label: Option<&str>,
init: InitFn, init: InitFn,
cond: CondFn, cond: CondFn,
body: BodyFn, body: BodyFn,
@ -504,18 +682,24 @@ where
I: Clone, I: Clone,
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>, InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>, CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn: BodyFn: FnOnce(
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>, &mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
I,
) -> Result<(), String>,
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{ {
let label = label.unwrap_or("for");
let current_bb = ctx.builder.get_insert_block().unwrap(); let current_bb = ctx.builder.get_insert_block().unwrap();
let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init"); let init_bb = ctx.ctx.insert_basic_block_after(current_bb, &format!("{label}.init"));
// The BB containing the loop condition check // The BB containing the loop condition check
let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, "for.cond"); let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, &format!("{label}.cond"));
let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, "for.body"); let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, &format!("{label}.body"));
// The BB containing the increment expression // The BB containing the increment expression
let update_bb = ctx.ctx.insert_basic_block_after(body_bb, "for.update"); let update_bb = ctx.ctx.insert_basic_block_after(body_bb, &format!("{label}.update"));
let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, "for.end"); let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, &format!("{label}.end"));
// store loop bb information and restore it later // store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((update_bb, cont_bb)); let loop_bb = ctx.loop_target.replace((update_bb, cont_bb));
@ -572,6 +756,7 @@ where
pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
label: Option<&str>,
init_val: IntValue<'ctx>, init_val: IntValue<'ctx>,
max_val: (IntValue<'ctx>, bool), max_val: (IntValue<'ctx>, bool),
body: BodyFn, body: BodyFn,
@ -582,7 +767,7 @@ where
BodyFn: FnOnce( BodyFn: FnOnce(
&mut G, &mut G,
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks, BreakContinueHooks<'ctx>,
IntValue<'ctx>, IntValue<'ctx>,
) -> Result<(), String>, ) -> Result<(), String>,
{ {
@ -591,6 +776,7 @@ where
gen_for_callback( gen_for_callback(
generator, generator,
ctx, ctx,
label,
|generator, ctx| { |generator, ctx| {
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?; let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
ctx.builder.build_store(i_addr, init_val).unwrap(); ctx.builder.build_store(i_addr, init_val).unwrap();
@ -642,9 +828,11 @@ where
/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like /// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like
/// iterable. This value will be extended to the size of `start`. /// iterable. This value will be extended to the size of `start`.
/// - `body_fn`: A lambda of IR statements within the loop body. /// - `body_fn`: A lambda of IR statements within the loop body.
#[allow(clippy::too_many_arguments)]
pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
label: Option<&str>,
is_unsigned: bool, is_unsigned: bool,
start_fn: StartFn, start_fn: StartFn,
(stop_fn, stop_inclusive): (StopFn, bool), (stop_fn, stop_inclusive): (StopFn, bool),
@ -656,13 +844,19 @@ where
StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, BodyFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks,
IntValue<'ctx>,
) -> Result<(), String>,
{ {
let init_val_t = start_fn(generator, ctx).map(IntValue::get_type).unwrap(); let init_val_t = start_fn(generator, ctx).map(IntValue::get_type).unwrap();
gen_for_callback( gen_for_callback(
generator, generator,
ctx, ctx,
label,
|generator, ctx| { |generator, ctx| {
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?; let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
@ -720,10 +914,10 @@ where
Ok(cond) Ok(cond)
}, },
|generator, ctx, _, (i_addr, _)| { |generator, ctx, hooks, (i_addr, _)| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
body_fn(generator, ctx, i) body_fn(generator, ctx, hooks, i)
}, },
|generator, ctx, (i_addr, _)| { |generator, ctx, (i_addr, _)| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
@ -1113,47 +1307,36 @@ pub fn exn_constructor<'ctx>(
pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>( pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
exception: Option<&BasicValueEnum<'ctx>>, exception: Option<Ptr<'ctx, StructModel<Exception>>>,
loc: Location, loc: Location,
) { ) {
if let Some(exception) = exception { if let Some(pexn) = exception {
unsafe { let i32_model = IntModel(Int32);
let int32 = ctx.ctx.i32_type(); let cslice_model = StructModel(CSlice);
let zero = int32.const_zero();
let exception = exception.into_pointer_value();
let file_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr")
.unwrap();
let filename = ctx.gen_string(generator, loc.file.0);
ctx.builder.build_store(file_ptr, filename).unwrap();
let row_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr")
.unwrap();
ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)).unwrap();
let col_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr")
.unwrap();
ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)).unwrap();
let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); // Get and store filename
let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); let filename = loc.file.0;
let name_ptr = ctx let filename = ctx.gen_string(generator, &String::from(filename)).value;
.builder let filename = cslice_model.check_value(generator, ctx.ctx, filename).unwrap();
.build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr") pexn.set(ctx, |f| f.filename, filename);
.unwrap();
ctx.builder.build_store(name_ptr, fun_name).unwrap(); let row = i32_model.constant(generator, ctx.ctx, loc.row as u64);
} pexn.set(ctx, |f| f.line, row);
let column = i32_model.constant(generator, ctx.ctx, loc.column as u64);
pexn.set(ctx, |f| f.column, column);
let current_fn = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let fn_name = ctx.gen_string(generator, current_fn.get_name().to_str().unwrap());
pexn.set(ctx, |f| f.function, fn_name);
let raise = get_builtins(generator, ctx, "__nac3_raise"); let raise = get_builtins(generator, ctx, "__nac3_raise");
let exception = *exception; ctx.build_call_or_invoke(raise, &[pexn.value.into()], "raise");
ctx.build_call_or_invoke(raise, &[exception], "raise");
} else { } else {
let resume = get_builtins(generator, ctx, "__nac3_resume"); let resume = get_builtins(generator, ctx, "__nac3_resume");
ctx.build_call_or_invoke(resume, &[], "resume"); ctx.build_call_or_invoke(resume, &[], "resume");
} }
ctx.builder.build_unreachable().unwrap(); ctx.builder.build_unreachable().unwrap();
} }
@ -1575,14 +1758,14 @@ pub fn gen_stmt<G: CodeGenerator>(
} }
StmtKind::AnnAssign { target, value, .. } => { StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value { if let Some(value) = value {
let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) }; let Some(value_enum) = generator.gen_expr(ctx, value)? else { return Ok(()) };
generator.gen_assign(ctx, target, value)?; generator.gen_assign(ctx, target, value_enum, value.custom.unwrap())?;
} }
} }
StmtKind::Assign { targets, value, .. } => { StmtKind::Assign { targets, value, .. } => {
let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) }; let Some(value_enum) = generator.gen_expr(ctx, value)? else { return Ok(()) };
for target in targets { for target in targets {
generator.gen_assign(ctx, target, value.clone())?; generator.gen_assign(ctx, target, value_enum.clone(), value.custom.unwrap())?;
} }
} }
StmtKind::Continue { .. } => { StmtKind::Continue { .. } => {
@ -1596,15 +1779,16 @@ pub fn gen_stmt<G: CodeGenerator>(
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
StmtKind::AugAssign { target, op, value, .. } => { StmtKind::AugAssign { target, op, value, .. } => {
let value = gen_binop_expr( let value_enum = gen_binop_expr(
generator, generator,
ctx, ctx,
target, target,
Binop::aug_assign(*op), Binop::aug_assign(*op),
value, value,
stmt.location, stmt.location,
)?; )?
generator.gen_assign(ctx, target, value.unwrap())?; .unwrap();
generator.gen_assign(ctx, target, value_enum, value.custom.unwrap())?;
} }
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
StmtKind::Raise { exc, .. } => { StmtKind::Raise { exc, .. } => {
@ -1614,30 +1798,41 @@ pub fn gen_stmt<G: CodeGenerator>(
} else { } else {
return Ok(()); return Ok(());
}; };
gen_raise(generator, ctx, Some(&exc), stmt.location);
let pexn_model = PtrModel(StructModel(Exception));
let exn = pexn_model.check_value(generator, ctx.ctx, exc).unwrap();
gen_raise(generator, ctx, Some(exn), stmt.location);
} else { } else {
gen_raise(generator, ctx, None, stmt.location); gen_raise(generator, ctx, None, stmt.location);
} }
} }
StmtKind::Assert { test, msg, .. } => { StmtKind::Assert { test, msg, .. } => {
let test = if let Some(v) = generator.gen_expr(ctx, test)? { let byte_model = IntModel(Byte);
v.to_basic_value_enum(ctx, generator, test.custom.unwrap())? let cslice_model = StructModel(CSlice);
} else {
let Some(test) = generator.gen_expr(ctx, test)? else {
return Ok(()); return Ok(());
}; };
let test = test.to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
let test = byte_model.check_value(generator, ctx.ctx, test).unwrap(); // Python `bool` is represented as `i8` in nac3core
// Check `msg`
let err_msg = match msg { let err_msg = match msg {
Some(msg) => { Some(msg) => {
if let Some(v) = generator.gen_expr(ctx, msg)? { let Some(msg) = generator.gen_expr(ctx, msg)? else {
v.to_basic_value_enum(ctx, generator, msg.custom.unwrap())?
} else {
return Ok(()); return Ok(());
} };
let msg = msg.to_basic_value_enum(ctx, generator, ctx.primitives.str)?;
cslice_model.check_value(generator, ctx.ctx, msg).unwrap()
} }
None => ctx.gen_string(generator, ""), None => ctx.gen_string(generator, ""),
}; };
ctx.make_assert_impl( ctx.make_assert_impl(
generator, generator,
test.into_int_value(), test.value,
"0:AssertionError", "0:AssertionError",
err_msg, err_msg,
[None, None, None], [None, None, None],

View File

@ -0,0 +1,256 @@
use inkwell::context::Context;
use crate::codegen::model::*;
use super::{CodeGenContext, CodeGenerator};
/// Fields of [`CSlice`]
pub struct CSliceFields<'ctx, F: FieldTraversal<'ctx>> {
/// Pointer to data.
pub base: F::Out<PtrModel<IntModel<Byte>>>,
/// Number of bytes of data.
pub len: F::Out<IntModel<SizeT>>,
}
/// See <https://crates.io/crates/cslice>.
///
/// Additionally, see <https://github.com/m-labs/artiq/blob/b0d2705c385f64b6e6711c1726cd9178f40b598e/artiq/firmware/libeh/eh_artiq.rs>)
/// for ARTIQ-specific notes.
#[derive(Debug, Clone, Copy, Default)]
pub struct CSlice;
impl<'ctx> StructKind<'ctx> for CSlice {
type Fields<F: FieldTraversal<'ctx>> = CSliceFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { base: traversal.add_auto("base"), len: traversal.add_auto("len") }
}
}
impl StructModel<CSlice> {
/// Create a [`CSlice`].
///
/// `base` and `len` must be LLVM global constants.
pub fn create_const<'ctx, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
base: Ptr<'ctx, IntModel<Byte>>,
len: Int<'ctx, SizeT>,
) -> Struct<'ctx, CSlice> {
let value = self
.0
.get_struct_type(generator, ctx)
.const_named_struct(&[base.value.into(), len.value.into()]);
self.believe_value(value)
}
}
/// The LLVM int type of an Exception ID.
pub type ExceptionId = Int32;
/// Fields of [`Exception<'ctx>`]
///
/// The definition came from `pub struct Exception<'a>` in
/// <https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs>.
pub struct ExceptionFields<'ctx, F: FieldTraversal<'ctx>> {
/// nac3core's ID of the exception
pub id: F::Out<IntModel<ExceptionId>>,
/// The name of the file this `Exception` was raised in.
pub filename: F::Out<StructModel<CSlice>>,
/// The line number in the file this `Exception` was raised in.
pub line: F::Out<IntModel<Int32>>,
/// The column number in the file this `Exception` was raised in.
pub column: F::Out<IntModel<Int32>>,
/// The name of the Python function this `Exception` was raised in.
pub function: F::Out<StructModel<CSlice>>,
/// The message of this Exception.
///
/// The message can optionally contain integer parameters `{0}`, `{1}`, and `{2}` in its string,
/// where they will be substituted by `params[0]`, `params[1]`, and `params[2]` respectively (as `int64_t`s).
/// Here is an example:
///
/// ```ignore
/// "Index {0} is out of bounds! List only has {1} element(s)."
/// ```
///
/// In this case, `params[0]` and `params[1]` must be specified, and `params[2]` is ***unused***.
/// Having only 3 parameters is a constraint in ARTIQ.
pub msg: F::Out<StructModel<CSlice>>,
pub params: [F::Out<IntModel<Int64>>; 3],
}
/// nac3core & ARTIQ's Exception
#[derive(Debug, Clone, Copy, Default)]
pub struct Exception;
impl<'ctx> StructKind<'ctx> for Exception {
type Fields<F: FieldTraversal<'ctx>> = ExceptionFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
id: traversal.add_auto("id"),
filename: traversal.add_auto("filename"),
line: traversal.add_auto("line"),
column: traversal.add_auto("column"),
function: traversal.add_auto("function"),
msg: traversal.add_auto("msg"),
params: [
traversal.add_auto("params[0]"),
traversal.add_auto("params[1]"),
traversal.add_auto("params[2]"),
],
}
}
}
/// Fields of [`List`]
pub struct ListFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
/// Array pointer to content
pub items: F::Out<PtrModel<Item>>,
/// Number of items in the array
pub len: F::Out<IntModel<SizeT>>,
}
/// A list in NAC3.
#[derive(Debug, Clone, Copy, Default)]
pub struct List<Item> {
/// Model of the list items
pub item: Item,
}
impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for List<Item> {
type Fields<F: FieldTraversal<'ctx>> = ListFields<'ctx, F, Item>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
items: traversal.add("data", PtrModel(self.item)),
len: traversal.add_auto("len"),
}
}
}
/// Fields of [`NDArray`]
pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> {
pub data: F::Out<PtrModel<IntModel<Byte>>>,
pub itemsize: F::Out<IntModel<SizeT>>,
pub ndims: F::Out<IntModel<SizeT>>,
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
pub strides: F::Out<PtrModel<IntModel<SizeT>>>,
}
/// A strided ndarray in NAC3.
///
/// See IRRT implementation for details about its fields.
#[derive(Debug, Clone, Copy, Default)]
pub struct NDArray;
impl<'ctx> StructKind<'ctx> for NDArray {
type Fields<F: FieldTraversal<'ctx>> = NDArrayFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
data: traversal.add_auto("data"),
itemsize: traversal.add_auto("itemsize"),
ndims: traversal.add_auto("ndims"),
shape: traversal.add_auto("shape"),
strides: traversal.add_auto("strides"),
}
}
}
/// Fields of [`SimpleNDArray`]
#[derive(Debug, Clone, Copy)]
pub struct SimpleNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
pub ndims: F::Out<IntModel<SizeT>>,
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
pub data: F::Out<PtrModel<Item>>,
}
/// An ndarray without strides and non-opaque `data` field in NAC3.
#[derive(Debug, Clone, Copy)]
pub struct SimpleNDArray<M> {
/// [`Model`] of the items.
pub item: M,
}
impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for SimpleNDArray<Item> {
type Fields<F: FieldTraversal<'ctx>> = SimpleNDArrayFields<'ctx, F, Item>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
ndims: traversal.add_auto("ndims"),
shape: traversal.add_auto("shape"),
data: traversal.add("data", PtrModel(self.item)),
}
}
}
/// Fields of [`NDIter`]
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Out<IntModel<SizeT>>,
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
pub strides: F::Out<PtrModel<IntModel<SizeT>>>,
pub indices: F::Out<PtrModel<IntModel<SizeT>>>,
pub nth: F::Out<IntModel<SizeT>>,
pub element: F::Out<PtrModel<IntModel<Byte>>>,
pub size: F::Out<IntModel<SizeT>>,
}
/// An IRRT helper structure used when iterating through an ndarray.
#[derive(Debug, Clone, Copy, Default)]
pub struct NDIter;
impl<'ctx> StructKind<'ctx> for NDIter {
type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
ndims: traversal.add_auto("ndims"),
shape: traversal.add_auto("shape"),
strides: traversal.add_auto("strides"),
indices: traversal.add_auto("indices"),
nth: traversal.add_auto("nth"),
element: traversal.add_auto("element"),
size: traversal.add_auto("size"),
}
}
}
/// A NAC3 `range`. It is an array of 3 int32s.
// TODO: Use `pub type RangeModel<N> = NArrayModel<3, IntModel<N>>` in the future when
// `range` type is type dependent.
pub type RangeModel = NArrayModel<3, IntModel<Int32>>;
impl<'ctx> Ptr<'ctx, RangeModel> {
pub fn gep_start<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: &str,
) -> Ptr<'ctx, IntModel<Int32>> {
self.at_const(generator, ctx, 0, name)
}
pub fn gep_stop<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: &str,
) -> Ptr<'ctx, IntModel<Int32>> {
self.at_const(generator, ctx, 1, name)
}
pub fn gep_step<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: &str,
) -> Ptr<'ctx, IntModel<Int32>> {
self.at_const(generator, ctx, 2, name)
}
}

View File

@ -94,7 +94,7 @@ fn test_primitives() {
"}; "};
let statements = parse_program(source, FileName::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0; let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -109,8 +109,18 @@ fn test_primitives() {
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature { let signature = FunSignature {
args: vec![ args: vec![
FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }, FuncArg {
FuncArg { name: "b".into(), ty: primitives.int32, default_value: None }, name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "b".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
], ],
ret: primitives.int32, ret: primitives.int32,
vars: VarMap::new(), vars: VarMap::new(),
@ -189,6 +199,8 @@ fn test_primitives() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 { define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
@ -246,14 +258,19 @@ fn test_simple_call() {
"}; "};
let statements_2 = parse_program(source_2, FileName::default()).unwrap(); let statements_2 = parse_program(source_2, FileName::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0; let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone()); unifier.top_level = Some(top_level.clone());
let signature = FunSignature { let signature = FunSignature {
args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }], args: vec![FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
}],
ret: primitives.int32, ret: primitives.int32,
vars: VarMap::new(), vars: VarMap::new(),
}; };
@ -368,6 +385,8 @@ fn test_simple_call() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 { define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {

View File

@ -78,14 +78,14 @@ impl SymbolValue {
} }
Constant::Tuple(t) => { Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty); let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else { let TypeEnum::TTuple { ty, is_vararg_ctx } = expected_ty.as_ref() else {
return Err(format!( return Err(format!(
"Expected {:?}, but got Tuple", "Expected {:?}, but got Tuple",
expected_ty.get_type_name() expected_ty.get_type_name()
)); ));
}; };
assert_eq!(ty.len(), t.len()); assert!(*is_vararg_ctx || ty.len() == t.len());
let elems = t let elems = t
.iter() .iter()
@ -155,7 +155,7 @@ impl SymbolValue {
SymbolValue::Bool(_) => primitives.bool, SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => { SymbolValue::Tuple(vs) => {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>(); let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys }) unifier.add_ty(TypeEnum::TTuple { ty: vs_tys, is_vararg_ctx: false })
} }
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
} }
@ -482,7 +482,7 @@ pub fn parse_type_annotation<T>(
parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt) parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty })) Ok(unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }))
} else { } else {
Err(HashSet::from(["Expected multiple elements for tuple".into()])) Err(HashSet::from(["Expected multiple elements for tuple".into()]))
} }

File diff suppressed because it is too large Load Diff

View File

@ -44,12 +44,27 @@ pub struct TopLevelComposer {
pub size_t: u32, pub size_t: u32,
} }
/// The specification for a builtin function, consisting of the function name, the function
/// signature, and a [code generation callback][`GenCall`].
pub type BuiltinFuncSpec = (StrRef, FunSignature, Arc<GenCall>);
/// A function that creates a [`BuiltinFuncSpec`] using the provided [`PrimitiveStore`] and
/// [`Unifier`].
pub type BuiltinFuncCreator = dyn Fn(&PrimitiveStore, &mut Unifier) -> BuiltinFuncSpec;
impl TopLevelComposer { impl TopLevelComposer {
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol /// return a composer and things to make a "primitive" symbol resolver, so that the symbol
/// resolver can later figure out primitive type definitions when passed a primitive type name /// resolver can later figure out primitive tye definitions when passed a primitive type name
///
/// `lateinit_builtins` are specifically for the ARTIQ module. Since the [`Unifier`] instance
/// used to create builtin functions do not persist until method compilation, any types
/// created (e.g. [`TypeEnum::TVar`]) also do not persist. Those functions should be instead put
/// in `lateinit_builtins`, where they will be instantiated with the [`Unifier`] instance used
/// for method compilation.
#[must_use] #[must_use]
pub fn new( pub fn new(
builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>, builtins: Vec<BuiltinFuncSpec>,
lateinit_builtins: Vec<Box<BuiltinFuncCreator>>,
core_config: ComposerConfig, core_config: ComposerConfig,
size_t: u32, size_t: u32,
) -> (Self, HashMap<StrRef, DefinitionId>, HashMap<StrRef, Type>) { ) -> (Self, HashMap<StrRef, DefinitionId>, HashMap<StrRef, Type>) {
@ -119,7 +134,13 @@ impl TopLevelComposer {
} }
} }
for (name, sig, codegen_callback) in builtins { // Materialize lateinit_builtins, now that the unifier is ready
let lateinit_builtins = lateinit_builtins
.into_iter()
.map(|builtin| builtin(&primitives_ty, &mut unifier))
.collect_vec();
for (name, sig, codegen_callback) in builtins.into_iter().chain(lateinit_builtins) {
let fun_sig = unifier.add_ty(TypeEnum::TFunc(sig)); let fun_sig = unifier.add_ty(TypeEnum::TFunc(sig));
builtin_ty.insert(name, fun_sig); builtin_ty.insert(name, fun_sig);
builtin_id.insert(name, DefinitionId(definition_ast_list.len())); builtin_id.insert(name, DefinitionId(definition_ast_list.len()));
@ -766,6 +787,7 @@ impl TopLevelComposer {
let target_ty = get_type_from_type_annotation_kinds( let target_ty = get_type_from_type_annotation_kinds(
&temp_def_list, &temp_def_list,
unifier, unifier,
primitives,
&def, &def,
&mut subst_list, &mut subst_list,
)?; )?;
@ -859,7 +881,73 @@ impl TopLevelComposer {
let resolver = &**resolver; let resolver = &**resolver;
let mut function_var_map = VarMap::new(); let mut function_var_map = VarMap::new();
let arg_types = {
let vararg = args
.vararg
.as_ref()
.map(|vararg| -> Result<_, HashSet<String>> {
let vararg = vararg.as_ref();
let annotation = vararg
.node
.annotation
.as_ref()
.ok_or_else(|| {
HashSet::from([format!(
"function parameter `{}` needs type annotation at {}",
vararg.node.arg, vararg.location
)])
})?
.as_ref();
let type_annotation = parse_ast_to_type_annotation_kinds(
resolver,
temp_def_list.as_slice(),
unifier,
primitives_store,
annotation,
// NOTE: since only class need this, for function
// it should be fine to be empty map
HashMap::new(),
)?;
let type_vars_within =
get_type_var_contained_in_type_annotation(&type_annotation)
.into_iter()
.map(|x| -> Result<TypeVar, HashSet<String>> {
let TypeAnnotation::TypeVar(ty) = x else {
unreachable!("must be type var annotation kind")
};
let id = Self::get_var_id(ty, unifier)?;
Ok(TypeVar { id, ty })
})
.collect::<Result<Vec<_>, _>>()?;
for var in type_vars_within {
if let Some(prev_ty) = function_var_map.insert(var.id, var.ty) {
// if already have the type inserted, make sure they are the same thing
assert_eq!(prev_ty, var.ty);
}
}
let ty = get_type_from_type_annotation_kinds(
temp_def_list.as_ref(),
unifier,
primitives_store,
&type_annotation,
&mut None,
)?;
Ok(FuncArg {
name: vararg.node.arg,
ty,
default_value: Some(SymbolValue::Tuple(Vec::default())),
is_vararg: true,
})
})
.transpose()?;
let mut arg_types = {
// make sure no duplicate parameter // make sure no duplicate parameter
let mut defined_parameter_name: HashSet<_> = HashSet::new(); let mut defined_parameter_name: HashSet<_> = HashSet::new();
for x in &args.args { for x in &args.args {
@ -936,6 +1024,7 @@ impl TopLevelComposer {
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
temp_def_list.as_ref(), temp_def_list.as_ref(),
unifier, unifier,
primitives_store,
&type_annotation, &type_annotation,
&mut None, &mut None,
)?; )?;
@ -959,11 +1048,18 @@ impl TopLevelComposer {
v v
}), }),
}, },
is_vararg: false,
}) })
}) })
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
}; };
if let Some(vararg) = vararg {
arg_types.push(vararg);
};
let arg_types = arg_types;
let return_ty = { let return_ty = {
if let Some(returns) = returns { if let Some(returns) = returns {
let return_ty_annotation = { let return_ty_annotation = {
@ -1002,6 +1098,7 @@ impl TopLevelComposer {
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(
&temp_def_list, &temp_def_list,
unifier, unifier,
primitives_store,
&return_ty_annotation, &return_ty_annotation,
&mut None, &mut None,
)? )?
@ -1214,6 +1311,7 @@ impl TopLevelComposer {
}) })
} }
}, },
is_vararg: false,
}; };
// push the dummy type and the type annotation // push the dummy type and the type annotation
// into the list for later unification // into the list for later unification
@ -1622,6 +1720,7 @@ impl TopLevelComposer {
let self_type = get_type_from_type_annotation_kinds( let self_type = get_type_from_type_annotation_kinds(
&def_list, &def_list,
unifier, unifier,
primitives_ty,
&make_self_type_annotation(type_vars, *object_id), &make_self_type_annotation(type_vars, *object_id),
&mut None, &mut None,
)?; )?;
@ -1638,21 +1737,25 @@ impl TopLevelComposer {
name: "msg".into(), name: "msg".into(),
ty: string, ty: string,
default_value: Some(SymbolValue::Str(String::new())), default_value: Some(SymbolValue::Str(String::new())),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "param0".into(), name: "param0".into(),
ty: int64, ty: int64,
default_value: Some(SymbolValue::I64(0)), default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "param1".into(), name: "param1".into(),
ty: int64, ty: int64,
default_value: Some(SymbolValue::I64(0)), default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "param2".into(), name: "param2".into(),
ty: int64, ty: int64,
default_value: Some(SymbolValue::I64(0)), default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
}, },
], ],
ret: self_type, ret: self_type,
@ -1803,7 +1906,11 @@ impl TopLevelComposer {
let ty_ann = make_self_type_annotation(type_vars, *class_id); let ty_ann = make_self_type_annotation(type_vars, *class_id);
let self_ty = get_type_from_type_annotation_kinds( let self_ty = get_type_from_type_annotation_kinds(
&def_list, unifier, &ty_ann, &mut None, &def_list,
unifier,
primitives_ty,
&ty_ann,
&mut None,
)?; )?;
vars.extend(type_vars.iter().map(|ty| { vars.extend(type_vars.iter().map(|ty| {
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else { let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {
@ -1858,6 +1965,7 @@ impl TopLevelComposer {
name: a.name, name: a.name,
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
default_value: a.default_value.clone(), default_value: a.default_value.clone(),
is_vararg: false,
}) })
.collect_vec() .collect_vec()
}; };

View File

@ -27,17 +27,22 @@ pub enum PrimDef {
List, List,
NDArray, NDArray,
// Member Functions // Option methods
OptionIsSome, FunOptionIsSome,
OptionIsNone, FunOptionIsNone,
OptionUnwrap, FunOptionUnwrap,
NDArrayCopy,
NDArrayFill, // Option-related functions
FunInt32, FunSome,
FunInt64,
FunUInt32, // NDArray methods
FunUInt64, FunNDArrayCopy,
FunFloat, FunNDArrayFill,
// Range methods
FunRangeInit,
// NumPy factory functions
FunNpNDArray, FunNpNDArray,
FunNpEmpty, FunNpEmpty,
FunNpZeros, FunNpZeros,
@ -46,26 +51,28 @@ pub enum PrimDef {
FunNpArray, FunNpArray,
FunNpEye, FunNpEye,
FunNpIdentity, FunNpIdentity,
FunRound, FunNpArange,
FunRound64,
// NumPy view functions
FunNpBroadcastTo,
FunNpReshape,
FunNpTranspose,
// NumPy NDArray property getters
FunNpSize,
FunNpShape,
FunNpStrides,
// Miscellaneous NumPy & SciPy functions
FunNpRound, FunNpRound,
FunRangeInit,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunNpFloor, FunNpFloor,
FunCeil,
FunCeil64,
FunNpCeil, FunNpCeil,
FunLen,
FunMin,
FunNpMin, FunNpMin,
FunNpMinimum, FunNpMinimum,
FunMax, FunNpArgmin,
FunNpMax, FunNpMax,
FunNpMaximum, FunNpMaximum,
FunAbs, FunNpArgmax,
FunNpIsNan, FunNpIsNan,
FunNpIsInf, FunNpIsInf,
FunNpSin, FunNpSin,
@ -104,14 +111,43 @@ pub enum PrimDef {
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
// Top-Level Functions // Linalg functions
FunSome, FunNpDot,
FunNpLinalgCholesky,
FunNpLinalgQr,
FunNpLinalgSvd,
FunNpLinalgInv,
FunNpLinalgPinv,
FunNpLinalgMatrixPower,
FunNpLinalgDet,
FunSpLinalgLu,
FunSpLinalgSchur,
FunSpLinalgHessenberg,
// Miscellaneous Python & NAC3 functions
FunInt32,
FunInt64,
FunUInt32,
FunUInt64,
FunFloat,
FunRound,
FunRound64,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunCeil,
FunCeil64,
FunLen,
FunMin,
FunMax,
FunAbs,
} }
/// Associated details of a [`PrimDef`] /// Associated details of a [`PrimDef`]
pub enum PrimDefDetails { pub enum PrimDefDetails {
PrimFunction { name: &'static str, simple_name: &'static str }, PrimFunction { name: &'static str, simple_name: &'static str },
PrimClass { name: &'static str }, PrimClass { name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type },
} }
impl PrimDef { impl PrimDef {
@ -153,15 +189,17 @@ impl PrimDef {
#[must_use] #[must_use]
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self.details() { match self.details() {
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name, PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name, .. } => {
name
}
} }
} }
/// Get the associated details of this [`PrimDef`] /// Get the associated details of this [`PrimDef`]
#[must_use] #[must_use]
pub fn details(self) -> PrimDefDetails { pub fn details(self) -> PrimDefDetails {
fn class(name: &'static str) -> PrimDefDetails { fn class(name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type) -> PrimDefDetails {
PrimDefDetails::PrimClass { name } PrimDefDetails::PrimClass { name, get_ty_fn }
} }
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails { fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
@ -169,29 +207,37 @@ impl PrimDef {
} }
match self { match self {
PrimDef::Int32 => class("int32"), // Classes
PrimDef::Int64 => class("int64"), PrimDef::Int32 => class("int32", |primitives| primitives.int32),
PrimDef::Float => class("float"), PrimDef::Int64 => class("int64", |primitives| primitives.int64),
PrimDef::Bool => class("bool"), PrimDef::Float => class("float", |primitives| primitives.float),
PrimDef::None => class("none"), PrimDef::Bool => class("bool", |primitives| primitives.bool),
PrimDef::Range => class("range"), PrimDef::None => class("none", |primitives| primitives.none),
PrimDef::Str => class("str"), PrimDef::Range => class("range", |primitives| primitives.range),
PrimDef::Exception => class("Exception"), PrimDef::Str => class("str", |primitives| primitives.str),
PrimDef::UInt32 => class("uint32"), PrimDef::Exception => class("Exception", |primitives| primitives.exception),
PrimDef::UInt64 => class("uint64"), PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32),
PrimDef::Option => class("Option"), PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64),
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")), PrimDef::Option => class("Option", |primitives| primitives.option),
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")), PrimDef::List => class("list", |primitives| primitives.list),
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")), PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray),
PrimDef::List => class("list"),
PrimDef::NDArray => class("ndarray"), // Option methods
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")), PrimDef::FunOptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")), PrimDef::FunOptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::FunInt32 => fun("int32", None), PrimDef::FunOptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None), // Option-related functions
PrimDef::FunUInt64 => fun("uint64", None), PrimDef::FunSome => fun("Some", None),
PrimDef::FunFloat => fun("float", None),
// NDArray methods
PrimDef::FunNDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::FunNDArrayFill => fun("ndarray.fill", Some("fill")),
// Range methods
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
// NumPy factory functions
PrimDef::FunNpNDArray => fun("np_ndarray", None), PrimDef::FunNpNDArray => fun("np_ndarray", None),
PrimDef::FunNpEmpty => fun("np_empty", None), PrimDef::FunNpEmpty => fun("np_empty", None),
PrimDef::FunNpZeros => fun("np_zeros", None), PrimDef::FunNpZeros => fun("np_zeros", None),
@ -200,26 +246,28 @@ impl PrimDef {
PrimDef::FunNpArray => fun("np_array", None), PrimDef::FunNpArray => fun("np_array", None),
PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None), PrimDef::FunNpIdentity => fun("np_identity", None),
PrimDef::FunRound => fun("round", None), PrimDef::FunNpArange => fun("np_arange", None),
PrimDef::FunRound64 => fun("round64", None),
// NumPy view functions
PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None),
PrimDef::FunNpReshape => fun("np_reshape", None),
PrimDef::FunNpTranspose => fun("np_transpose", None),
// NumPy NDArray property getters
PrimDef::FunNpSize => fun("np_size", None),
PrimDef::FunNpShape => fun("np_shape", None),
PrimDef::FunNpStrides => fun("np_strides", None),
// Miscellaneous NumPy & SciPy functions
PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunNpFloor => fun("np_floor", None), PrimDef::FunNpFloor => fun("np_floor", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunNpCeil => fun("np_ceil", None), PrimDef::FunNpCeil => fun("np_ceil", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunNpMin => fun("np_min", None), PrimDef::FunNpMin => fun("np_min", None),
PrimDef::FunNpMinimum => fun("np_minimum", None), PrimDef::FunNpMinimum => fun("np_minimum", None),
PrimDef::FunMax => fun("max", None), PrimDef::FunNpArgmin => fun("np_argmin", None),
PrimDef::FunNpMax => fun("np_max", None), PrimDef::FunNpMax => fun("np_max", None),
PrimDef::FunNpMaximum => fun("np_maximum", None), PrimDef::FunNpMaximum => fun("np_maximum", None),
PrimDef::FunAbs => fun("abs", None), PrimDef::FunNpArgmax => fun("np_argmax", None),
PrimDef::FunNpIsNan => fun("np_isnan", None), PrimDef::FunNpIsNan => fun("np_isnan", None),
PrimDef::FunNpIsInf => fun("np_isinf", None), PrimDef::FunNpIsInf => fun("np_isinf", None),
PrimDef::FunNpSin => fun("np_sin", None), PrimDef::FunNpSin => fun("np_sin", None),
@ -257,7 +305,38 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunSome => fun("Some", None),
// Linalg functions
PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None),
PrimDef::FunNpLinalgDet => fun("np_linalg_det", None),
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
// Miscellaneous Python & NAC3 functions
PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunAbs => fun("abs", None),
} }
} }
} }
@ -408,9 +487,9 @@ impl TopLevelComposer {
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(), obj_id: PrimDef::Option.id(),
fields: vec![ fields: vec![
(PrimDef::OptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::FunOptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::OptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::FunOptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::OptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)), (PrimDef::FunOptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)),
] ]
.into_iter() .into_iter()
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
@ -444,6 +523,7 @@ impl TopLevelComposer {
name: "value".into(), name: "value".into(),
ty: ndarray_dtype_tvar.ty, ty: ndarray_dtype_tvar.ty,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: none, ret: none,
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
@ -451,8 +531,8 @@ impl TopLevelComposer {
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
fields: Mapping::from([ fields: Mapping::from([
(PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)), (PrimDef::FunNDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)),
(PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)), (PrimDef::FunNDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)),
]), ]),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });
@ -938,3 +1018,23 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
_ => 0, _ => 0,
} }
} }
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
/// The `ndims` must only contain 1 value.
#[must_use]
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
panic!("ndims_ty should be a TLiteral");
};
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
let ndims = values[0].clone();
u64::try_from(ndims).unwrap()
}
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
}

View File

@ -31,6 +31,7 @@ pub mod builtins;
pub mod composer; pub mod composer;
pub mod helper; pub mod helper;
pub mod numpy; pub mod numpy;
pub mod option;
pub mod type_annotation; pub mod type_annotation;
use composer::*; use composer::*;
use type_annotation::*; use type_annotation::*;

View File

@ -0,0 +1,46 @@
use itertools::Itertools;
use crate::{
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
// TODO: This entire module is duplicated code (numpy.rs also has these kinds of things)
/// Creates a `option` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `option`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `option`, or [`None`] if the type variable is not
/// specialized.
pub fn make_option_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
) -> Type {
subst_option_tvars(unifier, primitives.option, dtype)
}
/// Substitutes type variables in `option`.
///
/// * `dtype` - The element type of the `option`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_option_tvars(unifier: &mut Unifier, option: Type, dtype: Option<Type>) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(option) else {
panic!("Expected `option` to be TObj, but got {}", unifier.stringify(option))
};
debug_assert_eq!(*obj_id, PrimDef::Option.id());
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
debug_assert_eq!(tvar_ids.len(), 1);
let mut tvar_subst = VarMap::new();
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
unifier.subst(option, &tvar_subst).unwrap_or(option)
}

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n",
] ]

View File

@ -117,7 +117,8 @@ impl SymbolResolver for Resolver {
"register" "register"
)] )]
fn test_simple_register(source: Vec<&str>) { fn test_simple_register(source: Vec<&str>) {
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0; let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
for s in source { for s in source {
let ast = parse_program(s, FileName::default()).unwrap(); let ast = parse_program(s, FileName::default()).unwrap();
@ -137,7 +138,8 @@ fn test_simple_register(source: Vec<&str>) {
"register" "register"
)] )]
fn test_simple_register_without_constructor(source: &str) { fn test_simple_register_without_constructor(source: &str) {
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0; let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let ast = parse_program(source, FileName::default()).unwrap(); let ast = parse_program(source, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
composer.register_top_level(ast, None, "", true).unwrap(); composer.register_top_level(ast, None, "", true).unwrap();
@ -171,7 +173,8 @@ fn test_simple_register_without_constructor(source: &str) {
"function compose" "function compose"
)] )]
fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0; let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = Arc::new(ResolverInternal { let internal_resolver = Arc::new(ResolverInternal {
id_to_def: Mutex::default(), id_to_def: Mutex::default(),
@ -519,7 +522,8 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
)] )]
fn test_analyze(source: &[&str], res: &[&str]) { fn test_analyze(source: &[&str], res: &[&str]) {
let print = false; let print = false;
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0; let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar( let internal_resolver = make_internal_resolver_with_tvar(
vec![ vec![
@ -696,7 +700,8 @@ fn test_analyze(source: &[&str], res: &[&str]) {
)] )]
fn test_inference(source: Vec<&str>, res: &[&str]) { fn test_inference(source: Vec<&str>, res: &[&str]) {
let print = true; let print = true;
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0; let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar( let internal_resolver = make_internal_resolver_with_tvar(
vec![ vec![

View File

@ -1,8 +1,9 @@
use super::*; use super::*;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::{PrimDef, PrimDefDetails};
use crate::typecheck::typedef::VarMap; use crate::typecheck::typedef::VarMap;
use nac3parser::ast::Constant; use nac3parser::ast::Constant;
use strum::IntoEnumIterator;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum TypeAnnotation { pub enum TypeAnnotation {
@ -357,6 +358,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
pub fn get_type_from_type_annotation_kinds( pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore,
ann: &TypeAnnotation, ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>, subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, HashSet<String>> { ) -> Result<Type, HashSet<String>> {
@ -379,100 +381,141 @@ pub fn get_type_from_type_annotation_kinds(
let param_ty = params let param_ty = params
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list) get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let subst = { let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) {
// check for compatible range // Primitive TopLevelDefs do not contain all fields that are present in their Type
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check // counterparts, so directly perform subst on the Type instead.
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
}
}
TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => { let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else {
let ty = range[0]; unreachable!()
let ok: bool = { };
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"), let base_ty = get_ty_fn(primitives);
let params =
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) {
params.clone()
} else {
unreachable!()
};
unifier
.subst(
get_ty_fn(primitives),
&params
.iter()
.zip(param_ty)
.map(|(obj_tv, param)| (*obj_tv.0, param))
.collect(),
)
.unwrap_or(base_ty)
} else {
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
}
}
TypeEnum::TVar {
id, range, name, loc, is_const_generic: true, ..
} => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp =
unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
}
}
result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
} }
} }
result
ty
}; };
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
}
}
Ok(ty) Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
@ -490,6 +533,7 @@ pub fn get_type_from_type_annotation_kinds(
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
unifier, unifier,
primitives,
ty.as_ref(), ty.as_ref(),
subst_list, subst_list,
)?; )?;
@ -499,10 +543,16 @@ pub fn get_type_from_type_annotation_kinds(
let tys = tys let tys = tys
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list) get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys })) Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys, is_vararg_ctx: false }))
} }
} }
} }

View File

@ -34,13 +34,18 @@ impl<'a> Inferencer<'a> {
self.should_have_value(pattern)?; self.should_have_value(pattern)?;
Ok(()) Ok(())
} }
ExprKind::Tuple { elts, .. } => { ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
for elt in elts { for elt in elts {
self.check_pattern(elt, defined_identifiers)?; self.check_pattern(elt, defined_identifiers)?;
self.should_have_value(elt)?; self.should_have_value(elt)?;
} }
Ok(()) Ok(())
} }
ExprKind::Starred { value, .. } => {
self.check_pattern(value, defined_identifiers)?;
self.should_have_value(value)?;
Ok(())
}
ExprKind::Subscript { value, slice, .. } => { ExprKind::Subscript { value, slice, .. } => {
self.check_expr(value, defined_identifiers)?; self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?; self.should_have_value(value)?;
@ -75,7 +80,7 @@ impl<'a> Inferencer<'a> {
return Err(HashSet::from([format!( return Err(HashSet::from([format!(
"expected concrete type at {} but got {}", "expected concrete type at {} but got {}",
expr.location, expr.location,
self.unifier.get_ty(*ty).get_type_name() self.unifier.stringify(*ty)
)])); )]));
} }
} }
@ -218,7 +223,7 @@ impl<'a> Inferencer<'a> {
] ]
.iter() .iter()
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)), .any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)), TypeEnum::TTuple { ty, .. } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false, _ => false,
} }
} }

View File

@ -1,5 +1,5 @@
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef};
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::typecheck::{ use crate::typecheck::{
type_inferencer::*, type_inferencer::*,
@ -197,6 +197,7 @@ pub fn impl_binop(
ty: other_ty, ty: other_ty,
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
is_vararg: false,
}], }],
})), })),
false, false,
@ -261,6 +262,7 @@ pub fn impl_cmpop(
ty: other_ty, ty: other_ty,
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
is_vararg: false,
}], }],
})), })),
false, false,
@ -518,36 +520,41 @@ pub fn typeof_binop(
} }
Operator::MatMult => { Operator::MatMult => {
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); let lhs_dtype = arraylike_flatten_element_type(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { let rhs_dtype = arraylike_flatten_element_type(unifier, rhs);
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1); let lhs_ndims = arraylike_get_ndims(unifier, lhs);
u64::try_from(values[0].clone()).unwrap() let rhs_ndims = arraylike_get_ndims(unifier, rhs);
if !(unifier.unioned(lhs_dtype, primitives.float)
&& unifier.unioned(rhs_dtype, primitives.float))
{
return Err(format!(
"ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}",
unifier.stringify(lhs),
unifier.stringify(rhs)
));
}
let result_ndims = match (lhs_ndims, rhs_ndims) {
(0, _) | (_, 0) => {
return Err(
"ndarray.__matmul__ does not allow unsized ndarray input".to_string()
)
} }
_ => unreachable!(), (1, 1) => 0,
}; (1, _) => rhs_ndims - 1,
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); (_, 1) => lhs_ndims - 1,
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { (m, n) => max(m, n),
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
}
_ => unreachable!(),
}; };
match (lhs_ndims, rhs_ndims) { if result_ndims == 0 {
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, // If the result is unsized, NumPy returns a scalar.
(lhs, rhs) if lhs == 0 || rhs == 0 => { primitives.float
return Err(format!( } else {
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", let result_ndims_ty =
u8::from(rhs == 0) unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None);
)) make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty))
}
(lhs, rhs) => {
return Err(format!(
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
))
}
} }
} }
@ -746,7 +753,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], Some(ndarray_t));
impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);

View File

@ -183,9 +183,10 @@ impl<'a> Display for DisplayTypeError<'a> {
} }
result result
} }
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) (
if ty1.len() != ty2.len() => TypeEnum::TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
{ TypeEnum::TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
) if !is_vararg1 && !is_vararg2 && ty1.len() != ty2.len() => {
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Tuple length mismatch: got {t1} and {t2}") write!(f, "Tuple length mismatch: got {t1} and {t2}")

File diff suppressed because it is too large Load Diff

View File

@ -83,7 +83,12 @@ impl TestEnvironment {
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], args: vec![FuncArg {
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
ret: int32, ret: int32,
vars: VarMap::new(), vars: VarMap::new(),
})); }));
@ -224,7 +229,12 @@ impl TestEnvironment {
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], args: vec![FuncArg {
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
ret: int32, ret: int32,
vars: VarMap::new(), vars: VarMap::new(),
})); }));

View File

@ -1,15 +1,14 @@
use indexmap::IndexMap; use indexmap::IndexMap;
use itertools::Itertools; use itertools::{repeat_n, Itertools};
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::iter::zip; use std::iter::{repeat, zip};
use std::rc::Rc; use std::rc::Rc;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet}; use std::{borrow::Cow, collections::HashSet};
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
use super::magic_methods::Binop; use super::magic_methods::Binop;
use super::type_error::{TypeError, TypeErrorKind}; use super::type_error::{TypeError, TypeErrorKind};
use super::unification_table::{UnificationKey, UnificationTable}; use super::unification_table::{UnificationKey, UnificationTable};
@ -115,6 +114,7 @@ pub struct FuncArg {
pub name: StrRef, pub name: StrRef,
pub ty: Type, pub ty: Type,
pub default_value: Option<SymbolValue>, pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
} }
impl FuncArg { impl FuncArg {
@ -233,6 +233,12 @@ pub enum TypeEnum {
TTuple { TTuple {
/// The types of elements present in this tuple. /// The types of elements present in this tuple.
ty: Vec<Type>, ty: Vec<Type>,
/// Whether this tuple is used in a vararg context.
///
/// If `true`, `ty` must only contain one type, and the tuple is assumed to contain any
/// number of `ty`-typed values.
is_vararg_ctx: bool,
}, },
/// An object type. /// An object type.
@ -336,6 +342,14 @@ impl Unifier {
self.unification_table.unioned(a, b) self.unification_table.unioned(a, b)
} }
/// Determine if a type unions with a type in `tys`.
pub fn unioned_any<I>(&mut self, a: Type, tys: I) -> bool
where
I: IntoIterator<Item = Type>,
{
tys.into_iter().any(|ty| self.unioned(a, ty))
}
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier { pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
let lock = unifier.lock().unwrap(); let lock = unifier.lock().unwrap();
Unifier { Unifier {
@ -527,7 +541,7 @@ impl Unifier {
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| { TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec() ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
}), }),
TypeEnum::TTuple { ty } => { TypeEnum::TTuple { ty, is_vararg_ctx } => {
let tuples = ty let tuples = ty
.iter() .iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
@ -537,7 +551,12 @@ impl Unifier {
None None
} else { } else {
Some( Some(
tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(), tuples
.into_iter()
.map(|ty| {
self.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: *is_vararg_ctx })
})
.collect(),
) )
} }
} }
@ -581,7 +600,7 @@ impl Unifier {
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TTuple { ty, .. } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => { TObj { params: vars, .. } => {
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
} }
@ -649,6 +668,7 @@ impl Unifier {
// Get details about the function signature/parameters. // Get details about the function signature/parameters.
let num_params = signature.args.len(); let num_params = signature.args.len();
let is_vararg = signature.args.iter().any(|arg| arg.is_vararg);
// Force the type vars in `b` and `signature' to be up-to-date. // Force the type vars in `b` and `signature' to be up-to-date.
let b = self.instantiate_fun(b, signature); let b = self.instantiate_fun(b, signature);
@ -737,7 +757,7 @@ impl Unifier {
}; };
// Check for "too many arguments" // Check for "too many arguments"
if num_params < posargs.len() { if !is_vararg && num_params < posargs.len() {
let expected_min_count = let expected_min_count =
signature.args.iter().filter(|param| param.is_required()).count(); signature.args.iter().filter(|param| param.is_required()).count();
let expected_max_count = num_params; let expected_max_count = num_params;
@ -770,6 +790,19 @@ impl Unifier {
type_check_arg(param.name, param.ty, arg_ty)?; type_check_arg(param.name, param.ty, arg_ty)?;
} }
if is_vararg {
debug_assert!(!signature.args.is_empty());
let vararg_args = posargs.iter().skip(signature.args.len());
let vararg_param = signature.args.last().unwrap();
for (&arg_ty, param) in zip(vararg_args, repeat(vararg_param)) {
// `param_info` for this argument would've already been marked as supplied
// during non-vararg posarg typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
}
// Now consume all keyword arguments and typecheck them. // Now consume all keyword arguments and typecheck them.
for (&param_name, &arg_ty) in kwargs { for (&param_name, &arg_ty) in kwargs {
// We will also use this opportunity to check if this keyword argument is "legal". // We will also use this opportunity to check if this keyword argument is "legal".
@ -959,7 +992,10 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => { (
TVar { fields: Some(fields), range, is_const_generic: false, .. },
TTuple { ty, .. },
) => {
let len = i32::try_from(ty.len()).unwrap(); let len = i32::try_from(ty.len()).unwrap();
for (k, v) in fields { for (k, v) in fields {
match *k { match *k {
@ -1056,15 +1092,47 @@ impl Unifier {
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { (
if ty1.len() != ty2.len() { TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
} ) => {
for (x, y) in ty1.iter().zip(ty2.iter()) { // Rules for Tuples:
if self.unify_impl(*x, *y, false).is_err() { // - ty1: is_vararg && ty2: is_vararg -> ty1[0] == ty2[0]
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); // - ty1: is_vararg && ty2: !is_vararg -> type error (not enough info to infer the correct number of arguments)
// - ty1: !is_vararg && ty2: is_vararg -> ty1[..] == ty2[0]
// - ty1: !is_vararg && ty2: !is_vararg -> ty1.len() == ty2.len() && ty1[i] == ty2[i]
debug_assert!(!is_vararg1 || ty1.len() == 1);
debug_assert!(!is_vararg2 || ty2.len() == 1);
match (*is_vararg1, *is_vararg2) {
(true, true) => {
if self.unify_impl(ty1[0], ty2[0], false).is_err() {
return Self::incompatible_types(a, b);
}
}
(true, false) => return Self::incompatible_types(a, b),
(false, true) => {
for y in ty2 {
if self.unify_impl(ty1[0], *y, false).is_err() {
return Self::incompatible_types(a, b);
}
}
}
(false, false) => {
if ty1.len() != ty2.len() {
return Self::incompatible_types(a, b);
}
for (x, y) in ty1.iter().zip(ty2.iter()) {
if self.unify_impl(*x, *y, false).is_err() {
return Self::incompatible_types(a, b);
}
}
} }
} }
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
(TVar { fields: Some(map), range, .. }, TObj { obj_id, fields, params }) => { (TVar { fields: Some(map), range, .. }, TObj { obj_id, fields, params }) => {
@ -1307,10 +1375,22 @@ impl Unifier {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {
format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", ")) format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", "))
} }
TypeEnum::TTuple { ty } => { TypeEnum::TTuple { ty, is_vararg_ctx } => {
let mut fields = if *is_vararg_ctx {
ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); debug_assert_eq!(ty.len(), 1);
format!("tuple[{}]", fields.join(", ")) let field = self.internal_stringify(
*ty.iter().next().unwrap(),
obj_to_name,
var_to_name,
notes,
);
format!("tuple[*{field}]")
} else {
let mut fields = ty
.iter()
.map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
format!("tuple[{}]", fields.join(", "))
}
} }
TypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty } => {
format!( format!(
@ -1335,17 +1415,21 @@ impl Unifier {
.args .args
.iter() .iter()
.map(|arg| { .map(|arg| {
let vararg_prefix = if arg.is_vararg { "*" } else { "" };
if let Some(dv) = &arg.default_value { if let Some(dv) = &arg.default_value {
format!( format!(
"{}:{}={}", "{}:{}{}={}",
arg.name, arg.name,
vararg_prefix,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes), self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes),
dv dv
) )
} else { } else {
format!( format!(
"{}:{}", "{}:{}{}",
arg.name, arg.name,
vararg_prefix,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes) self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes)
) )
} }
@ -1431,7 +1515,7 @@ impl Unifier {
match &*ty { match &*ty {
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None, TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None,
TypeEnum::TVar { id, .. } => mapping.get(id).copied(), TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
TypeEnum::TTuple { ty } => { TypeEnum::TTuple { ty, is_vararg_ctx } => {
let mut new_ty = Cow::from(ty); let mut new_ty = Cow::from(ty);
for (i, t) in ty.iter().enumerate() { for (i, t) in ty.iter().enumerate() {
if let Some(t1) = self.subst_impl(*t, mapping, cache) { if let Some(t1) = self.subst_impl(*t, mapping, cache) {
@ -1439,7 +1523,10 @@ impl Unifier {
} }
} }
if matches!(new_ty, Cow::Owned(_)) { if matches!(new_ty, Cow::Owned(_)) {
Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() })) Some(self.add_ty(TypeEnum::TTuple {
ty: new_ty.into_owned(),
is_vararg_ctx: *is_vararg_ctx,
}))
} else { } else {
None None
} }
@ -1599,16 +1686,37 @@ impl Unifier {
} }
} }
(TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())), (TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())),
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) if ty1.len() == ty2.len() => { (
let ty: Vec<_> = zip(ty1.iter(), ty2.iter()) TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
.map(|(a, b)| self.get_intersection(*a, *b)) TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
.try_collect()?; ) => {
if ty.iter().any(Option::is_some) { if *is_vararg1 && *is_vararg2 {
Ok(Some(self.add_ty(TTuple { let isect_ty = self.get_intersection(ty1[0], ty2[0])?;
ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(), Ok(isect_ty.map(|ty| self.add_ty(TTuple { ty: vec![ty], is_vararg_ctx: true })))
})))
} else { } else {
Ok(None) let zip_iter: Box<dyn Iterator<Item = (&Type, &Type)>> =
match (*is_vararg1, *is_vararg2) {
(true, _) => Box::new(repeat_n(&ty1[0], ty2.len()).zip(ty2.iter())),
(_, false) => Box::new(ty1.iter().zip(repeat_n(&ty2[0], ty1.len()))),
_ => {
if ty1.len() != ty2.len() {
return Err(());
}
Box::new(ty1.iter().zip(ty2.iter()))
}
};
let ty: Vec<_> =
zip_iter.map(|(a, b)| self.get_intersection(*a, *b)).try_collect()?;
Ok(if ty.iter().any(Option::is_some) {
Some(self.add_ty(TTuple {
ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
is_vararg_ctx: false,
}))
} else {
None
})
} }
} }
// TODO(Derppening): #444 // TODO(Derppening): #444

View File

@ -28,7 +28,10 @@ impl Unifier {
TypeEnum::TVar { fields: Some(map1), .. }, TypeEnum::TVar { fields: Some(map1), .. },
TypeEnum::TVar { fields: Some(map2), .. }, TypeEnum::TVar { fields: Some(map2), .. },
) => self.map_eq2(map1, map2), ) => self.map_eq2(map1, map2),
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => { (
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: false },
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: false },
) => {
ty1.len() == ty2.len() ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
} }
@ -178,7 +181,7 @@ impl TestEnvironment {
ty.push(result.0); ty.push(result.0);
s = result.1; s = result.1;
} }
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..]) (self.unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }), &s[1..])
} }
"Record" => { "Record" => {
let mut s = &typ[end..]; let mut s = &typ[end..];
@ -608,7 +611,7 @@ fn test_instantiation() {
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty; let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty; let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty;
let t = env.unifier.get_dummy_var().ty; let t = env.unifier.get_dummy_var().ty;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] }); let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2], is_vararg_ctx: false });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty; let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty;
// t = TypeVar('t') // t = TypeVar('t')
// v = TypeVar('v', int, bool) // v = TypeVar('v', int, bool)

View File

@ -3,23 +3,66 @@
set -e set -e
if [ -z "$1" ]; then if [ -z "$1" ]; then
echo "Requires at least one argument" echo "No argument supplied"
exit 1 exit 1
fi fi
declare -a nac3args declare -a nac3args
while [ $# -gt 1 ]; do
case "$1" in
--help)
echo "Usage: check_demo.sh [--debug] [-i686] -- [NAC3ARGS...] demo"
exit
;;
--debug)
debug=1
;;
-i686)
i686=1
;;
--)
shift
break
;;
*)
echo "Unrecognized argument \"$1\""
exit 1
;;
esac
shift
done
while [ $# -gt 1 ]; do while [ $# -gt 1 ]; do
nac3args+=("$1") nac3args+=("$1")
shift shift
done done
demo="$1" demo="$1"
echo -n "Checking $demo... "
./interpret_demo.py "$demo" > interpreted.log
./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run.log
diff -Nau interpreted.log run_lli.log
echo "ok"
rm -f interpreted.log run.log run_lli.log echo "### Checking $demo..."
echo ">>>>>> Running $demo with the Python interpreter"
./interpret_demo.py "$demo" > interpreted.log
if [ -n "$i686" ]; then
echo "...... Trying NAC3's 32-bit code generator output"
if [ -n "$debug" ]; then
./run_demo.sh --debug -i686 --out run_32.log -- "${nac3args[@]}" "$demo"
else
./run_demo.sh -i686 --out run_32.log -- "${nac3args[@]}" "$demo"
fi
diff -Nau interpreted.log run_32.log
fi
echo "...... Trying NAC3's 64-bit code generator output"
if [ -n "$debug" ]; then
./run_demo.sh --debug --out run_64.log -- "${nac3args[@]}" "$demo"
else
./run_demo.sh --out run_64.log -- "${nac3args[@]}" "$demo"
fi
diff -Nau interpreted.log run_64.log
echo "...... OK"
rm -f interpreted.log \
run_32.log run_64.log

View File

@ -2,6 +2,11 @@
set -e set -e
if [ "$1" == "--help" ]; then
echo "Usage: check_demos.sh [CHECKARGS...] [--] [NAC3ARGS...]"
exit
fi
count=0 count=0
for demo in src/*.py; do for demo in src/*.py; do
./check_demo.sh "$@" "$demo" ./check_demo.sh "$@" "$demo"

Some files were not shown because too many files have changed in this diff Show More