forked from M-Labs/nac3
1
0
Fork 0

Compare commits

...

69 Commits

Author SHA1 Message Date
lyken 2ab7b299b8 core/ndstrides: refactor numpy indexing 2024-07-31 09:53:15 +08:00
lyken 86b0d31290 core/ndstrides: pub ScalarOrNDArray::to_basic_value_enum 2024-07-31 09:53:15 +08:00
lyken 6369db94ab core/codegen: gen_assign to take in value_ty 2024-07-31 09:53:15 +08:00
lyken 3d8240259c core/typecheck: Inferencer allow heterogenerous assignemnt 2024-07-31 09:53:15 +08:00
lyken e4f6adb1ec core/ndstrides: add numpy broadcasting utils 2024-07-31 09:53:15 +08:00
lyken eb295cf7e4 core/ndstrides: implement numpy broadcasting IRRT 2024-07-31 09:53:15 +08:00
lyken 7501a086d0 core/irrt: print_value add bool 2024-07-31 09:53:15 +08:00
lyken fb54d5d112 core/ndstrides: add TODO in np_reshape 2024-07-31 09:53:15 +08:00
lyken 3dc4b17310 core/ndstrides: introduce NDArrayObject & refactor reshape 2024-07-31 09:53:15 +08:00
lyken 7436513b64 core/model: add util.rs & gen_model_memcpy 2024-07-31 09:53:15 +08:00
lyken 7e056b9747 core/ndstrides: fix alloca_ndarray comment 2024-07-31 09:53:15 +08:00
lyken ac7cc15d90 core/ndstrides: remove unnecessary Result<_, String> 2024-07-31 09:53:15 +08:00
lyken 28e6f23034 core/ndstrides: rewrite and fix np_reshape() bug
Data content should be copied and strides should be updated after
negative indices are resolved.
2024-07-31 09:53:15 +08:00
lyken dfb8bf9748 core/ndstrides: fix and rewrite is_c_contiguous 2024-07-31 09:53:15 +08:00
lyken d5880b119a core/ndstrides: move functions to numpy_new/util.rs 2024-07-31 09:53:15 +08:00
lyken 2747869a45 core/ndstrides: implement general ndarray reshaping 2024-07-31 09:53:15 +08:00
lyken bd5cb14d0d core/ndstrides: implement general ndarray basic indexing 2024-07-31 09:53:15 +08:00
lyken 4b14609342 core/ndstrides: implement IRRT slice
Needed by ndarray indexing
2024-07-31 09:53:15 +08:00
lyken 2211c4d852 core/ndstrides: implement gen_foreach_ndarray_elements & np_{empty,ndarray,zeros,ones,full} 2024-07-31 09:53:15 +08:00
lyken 5b9ac9b09c core/ndstrides: implement ndarray len() 2024-07-31 09:53:15 +08:00
lyken 02e3ddfce6 core: make get_llvm_type return new NDArray with strides
NOTE: All old numpy functions are now impossible to run, until NDArray
with strides is fully implemented.
2024-07-31 09:53:15 +08:00
lyken 8ae9a4294b core/ndstrides: add basic ndarray IRRT functions 2024-07-31 09:53:15 +08:00
lyken e5fe86cc93 core/ndstrides: add ArrayWriter & make_shape_writer 2024-07-31 09:53:15 +08:00
lyken fd3d02bff0 core/ndstrides: add NDArray with strides definition 2024-07-31 09:53:15 +08:00
lyken 7502b14d55 core/irrt: add ErrorContext 2024-07-31 09:53:15 +08:00
lyken 5b7588df75 core/model: add and use CSlice and Exception 2024-07-31 09:53:15 +08:00
lyken 0477e2acfa core/irrt: comment arrays_match() 2024-07-31 09:53:15 +08:00
lyken bf0dcf325e core/irrt: add cstr_utils 2024-07-31 09:53:15 +08:00
lyken c772fdb83a core/model: introduce codegen/model 2024-07-31 09:53:15 +08:00
lyken c1369ea5bd core/irrt: introduce irrt testing
`cargo test -F test` would compile `nac3core/irrt/irrt_test.cpp`
targetted to the host machine (it gets to use `std`) and run the
test executable.
2024-07-31 09:52:43 +08:00
lyken ef28138291 core/irrt: split irrt.cpp into headers
To scale IRRT implementations
2024-07-31 09:52:43 +08:00
lyken 984843a46a core/irrt: build.rs capture IR defined constants 2024-07-31 09:52:43 +08:00
lyken c5626e4947 core/irrt: build.rs capture IR defined types 2024-07-31 09:52:43 +08:00
lyken e4ba5e6411 core/irrt: reformat 2024-07-31 09:52:43 +08:00
lyken 31d0fdd818 core: add .clang-format 2024-07-31 09:52:43 +08:00
lyken 3f0e7e28b8 core/irrt: comment build.rs & move irrt to its own dir
To prepare for future IRRT implementations, and to also make cargo
only have to watch a single directory.
2024-07-31 09:52:43 +08:00
David Mak 318a675ea6 standalone: Rename -m32 to -i386 2024-07-29 14:58:58 +08:00
David Mak 32e52ce198 standalone: Revert using uint32_t as slice length
Turns out list and str have always been size_t.
2024-07-29 14:58:29 +08:00
Sebastien Bourdeauducq 665ca8e32d cargo: update dependencies 2024-07-27 22:24:56 +08:00
Sebastien Bourdeauducq 12c12b1d80 flake: update nixpkgs 2024-07-27 22:22:20 +08:00
lyken 72972fa909 core/toplevel: add more numpy categories 2024-07-27 21:57:47 +08:00
lyken 142cd48594 core/toplevel: reorder PrimDef::details 2024-07-27 21:57:47 +08:00
lyken 8adfe781c5 core/toplevel: fix PrimDef method names 2024-07-27 21:57:47 +08:00
lyken 339b74161b core/toplevel: reorganize PrimDef 2024-07-27 21:57:47 +08:00
David Mak 8c5ba37d09 standalone: Add 32-bit execution tests to check_demo.sh 2024-07-26 13:35:40 +08:00
David Mak 05a8948ff2 core: Minor cleanup to use ListValue APIs 2024-07-26 13:35:40 +08:00
David Mak 6d171ec284 core: Add label name and hooks to gen_for functions 2024-07-26 13:35:40 +08:00
David Mak 0ba68f6657 core: Set target triple and datalayout for each module
Fixes an issue with inconsistent pointer sizes causing crashes.
2024-07-26 13:35:40 +08:00
David Mak 693b2a8863 core: Add support for 32-bit size_t on 64-bit targets 2024-07-26 13:35:40 +08:00
David Mak 5faeede0e5 Determine size_t using LLVM target machine 2024-07-26 13:35:38 +08:00
David Mak 266707df9d standalone: Add support for running 32-bit binaries 2024-07-26 13:32:38 +08:00
David Mak 3d3c258756 standalone: Remove support for --lli 2024-07-26 13:32:38 +08:00
David Mak ed1182cb24 standalone: Update format specifiers for exceptions
Use platform-agnostic identifiers instead.
2024-07-26 13:32:37 +08:00
David Mak fd025c1137 standalone: Use uint32_t for cslice length
Matching the expected type of string and list slices.
2024-07-26 13:32:21 +08:00
David Mak f139db9af9 meta: Update dependencies 2024-07-26 10:33:02 +08:00
lyken 44487b76ae standalone: interpret_demo.py remove duplicated section 2024-07-22 17:23:35 +08:00
lyken 1332f113e8 standalone: fix interpret_demo.py comments 2024-07-22 17:06:14 +08:00
Sébastien Bourdeauducq 7632d6f72a cargo: update dependencies 2024-07-21 11:00:25 +08:00
David Mak 4948395ca2 core/toplevel/type_annotation: Add handling for mismatching class def
Primitive types only contain fields in its Type and not its TopLevelDef.
This causes primitive object types to lack some fields.
2024-07-19 14:42:14 +08:00
David Mak 3db3061d99 artiq/symbol_resolver: Handle type of zero-length lists 2024-07-19 14:42:14 +08:00
David Mak 51c2175c80 core/codegen/stmt: Convert assertion values to i1 2024-07-19 14:42:14 +08:00
lyken 1a31a50b8a
standalone: fix __nac3_raise def in demo.c 2024-07-17 21:22:08 +08:00
lyken 6c10e3d056 core: cargo clippy 2024-07-12 21:18:53 +08:00
lyken 2dbc1ec659 cargo fmt 2024-07-12 21:16:38 +08:00
Sebastien Bourdeauducq c80378063a add np_argmin/argmax to interpret_demo environment 2024-07-12 13:27:52 +02:00
abdul124 513d30152b core: support raise exception short form 2024-07-12 18:58:34 +08:00
abdul124 45e9360c4d standalone: Add np_argmax and np_argmin tests 2024-07-12 18:19:56 +08:00
abdul124 2e01b77fc8 core: refactor np_max/np_min functions 2024-07-12 18:18:54 +08:00
abdul124 cea7cade51 core: add np_argmax/np_argmin functions 2024-07-12 18:18:28 +08:00
86 changed files with 6338 additions and 1208 deletions

3
.clang-format Normal file
View File

@ -0,0 +1,3 @@
BasedOnStyle: Google
IndentWidth: 4
ReflowComments: false

98
Cargo.lock generated
View File

@ -26,9 +26,9 @@ dependencies = [
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.14" version = "0.6.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"anstyle-parse", "anstyle-parse",
@ -41,33 +41,33 @@ dependencies = [
[[package]] [[package]]
name = "anstyle" name = "anstyle"
version = "1.0.7" version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1"
[[package]] [[package]]
name = "anstyle-parse" name = "anstyle-parse"
version = "0.2.4" version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb"
dependencies = [ dependencies = [
"utf8parse", "utf8parse",
] ]
[[package]] [[package]]
name = "anstyle-query" name = "anstyle-query"
version = "1.1.0" version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a"
dependencies = [ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]] [[package]]
name = "anstyle-wincon" name = "anstyle-wincon"
version = "3.0.3" version = "3.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"windows-sys", "windows-sys",
@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.0" version = "1.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8" checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f"
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
@ -129,9 +129,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.9" version = "4.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462" checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@ -139,9 +139,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.9" version = "4.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942" checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@ -151,27 +151,27 @@ dependencies = [
[[package]] [[package]]
name = "clap_derive" name = "clap_derive"
version = "4.5.8" version = "4.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bac35c6dafb060fd4d275d9a4ffae97917c13a6327903a8be2153cd964f7085" checksum = "5d029b67f89d30bbb547c89fd5161293c0aec155fc691d7924b64550662db93e"
dependencies = [ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.72",
] ]
[[package]] [[package]]
name = "clap_lex" name = "clap_lex"
version = "0.7.1" version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97"
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.1" version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0"
[[package]] [[package]]
name = "console" name = "console"
@ -421,7 +421,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.72",
] ]
[[package]] [[package]]
@ -440,9 +440,9 @@ dependencies = [
[[package]] [[package]]
name = "is_terminal_polyfill" name = "is_terminal_polyfill"
version = "1.70.0" version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]] [[package]]
name = "itertools" name = "itertools"
@ -513,9 +513,9 @@ checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]] [[package]]
name = "libloading" name = "libloading"
version = "0.8.4" version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"windows-targets", "windows-targets",
@ -749,7 +749,7 @@ dependencies = [
"phf_shared 0.11.2", "phf_shared 0.11.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.72",
] ]
[[package]] [[package]]
@ -778,9 +778,9 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315"
[[package]] [[package]]
name = "portable-atomic" name = "portable-atomic"
version = "1.6.0" version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
@ -850,7 +850,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.70", "syn 2.0.72",
] ]
[[package]] [[package]]
@ -863,7 +863,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.70", "syn 2.0.72",
] ]
[[package]] [[package]]
@ -927,9 +927,9 @@ dependencies = [
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.2" version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
dependencies = [ dependencies = [
"bitflags", "bitflags",
] ]
@ -1044,7 +1044,7 @@ checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.72",
] ]
[[package]] [[package]]
@ -1072,9 +1072,9 @@ dependencies = [
[[package]] [[package]]
name = "similar" name = "similar"
version = "2.5.0" version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa42c91313f1d05da9b26f267f931cf178d4aba455b4c4622dd7355eb80c6640" checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e"
[[package]] [[package]]
name = "siphasher" name = "siphasher"
@ -1134,7 +1134,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.70", "syn 2.0.72",
] ]
[[package]] [[package]]
@ -1150,9 +1150,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.70" version = "2.0.72"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16" checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1203,22 +1203,22 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.61" version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.61" version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.72",
] ]
[[package]] [[package]]
@ -1336,9 +1336,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.4" version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]] [[package]]
name = "walkdir" name = "walkdir"
@ -1486,5 +1486,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.72",
] ]

View File

@ -2,11 +2,11 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1720418205, "lastModified": 1721924956,
"narHash": "sha256-cPJoFPXU44GlhWg4pUk9oUPqurPlCFZ11ZQPk21GTPU=", "narHash": "sha256-Sb1jlyRO+N8jBXEX9Pg9Z1Qb8Bw9QyOgLDNMEpmjZ2M=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "655a58a72a6601292512670343087c2d75d859c1", "rev": "5ad6a14c6bf098e98800b091668718c336effc95",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -13,6 +13,7 @@
'' ''
mkdir -p $out/bin mkdir -p $out/bin
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.clang}/bin/clang $out/bin/clang-irrt-test
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
''; '';
nac3artiq = pkgs.python3Packages.toPythonModule ( nac3artiq = pkgs.python3Packages.toPythonModule (
@ -23,8 +24,9 @@
cargoLock = { cargoLock = {
lockFile = ./Cargo.lock; lockFile = ./Cargo.lock;
}; };
cargoTestFlags = [ "--features" "test" ];
passthru.cargoLock = cargoLock; passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ]; nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ]; checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase = checkPhase =
@ -149,7 +151,7 @@
buildInputs = with pkgs; [ buildInputs = with pkgs; [
# build dependencies # build dependencies
packages.x86_64-linux.llvm-nac3 packages.x86_64-linux.llvm-nac3
llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos (pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos
packages.x86_64-linux.llvm-tools-irrt packages.x86_64-linux.llvm-tools-irrt
cargo cargo
rustc rustc

View File

@ -0,0 +1,24 @@
from min_artiq import *
from numpy import int32
@nac3
class EmptyList:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@rpc
def get_empty(self) -> list[int32]:
return []
@kernel
def run(self):
a: list[int32] = self.get_empty()
if a != []:
raise ValueError
if __name__ == "__main__":
EmptyList().run()

View File

@ -24,6 +24,7 @@ use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use inkwell::{ use inkwell::{
context::Context,
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer,
module::{Linkage, Module}, module::{Linkage, Module},
passes::PassBuilderOptions, passes::PassBuilderOptions,
@ -625,7 +626,9 @@ impl Nac3 {
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let size_t = if self.isa == Isa::Host { 64 } else { 32 }; let size_t = Context::create()
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width();
let num_threads = if is_multithreaded() { 4 } else { 1 }; let num_threads = if is_multithreaded() { 4 } else { 1 };
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect(); let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names let threads: Vec<_> = thread_names
@ -644,6 +647,9 @@ impl Nac3 {
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = inkwell::context::Context::create(); let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback"); let module = context.create_module("attributes_writeback");
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
let builder = context.create_builder(); let builder = context.create_builder();
let (_, module, _) = gen_func_impl( let (_, module, _) = gen_func_impl(
&context, &context,

View File

@ -991,8 +991,15 @@ impl InnerResolver {
} }
_ => unreachable!("must be list"), _ => unreachable!("must be list"),
}; };
let ty = ctx.get_llvm_type(generator, elem_ty);
let size_t = generator.get_size_type(ctx.ctx); let size_t = generator.get_size_type(ctx.ctx);
let ty = if len == 0
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
{
// The default type for zero-length lists of unknown element type is size_t
size_t.into()
} else {
ctx.get_llvm_type(generator, elem_ty)
};
let arr_ty = ctx let arr_ty = ctx
.ctx .ctx
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);

View File

@ -1,3 +1,6 @@
[features]
test = []
[package] [package]
name = "nac3core" name = "nac3core"
version = "0.1.0" version = "0.1.0"

View File

@ -3,20 +3,34 @@ use std::{
env, env,
fs::File, fs::File,
io::Write, io::Write,
path::Path, path::{Path, PathBuf},
process::{Command, Stdio}, process::{Command, Stdio},
}; };
fn main() { const CMD_IRRT_CLANG: &str = "clang-irrt";
const FILE: &str = "src/codegen/irrt/irrt.cpp"; const CMD_IRRT_CLANG_TEST: &str = "clang-irrt-test";
const CMD_IRRT_LLVM_AS: &str = "llvm-as-irrt";
fn get_out_dir() -> PathBuf {
PathBuf::from(env::var("OUT_DIR").unwrap())
}
fn get_irrt_dir() -> &'static Path {
Path::new("irrt")
}
/// Compile `irrt.cpp` for use in `src/codegen`
fn compile_irrt_cpp() {
let out_dir = get_out_dir();
let irrt_dir = get_irrt_dir();
/* /*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode. * HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get. * Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/ */
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
let flags: &[&str] = &[ let flags: &[&str] = &[
"--target=wasm32", "--target=wasm32",
FILE,
"-x", "-x",
"c++", "c++",
"-fno-discard-value-names", "-fno-discard-value-names",
@ -33,13 +47,16 @@ fn main() {
"-Wextra", "-Wextra",
"-o", "-o",
"-", "-",
"-I",
irrt_dir.to_str().unwrap(),
irrt_cpp_path.to_str().unwrap(),
]; ];
println!("cargo:rerun-if-changed={FILE}"); // Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
let out_dir = env::var("OUT_DIR").unwrap(); println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
let out_path = Path::new(&out_dir);
let output = Command::new("clang-irrt") // Compile IRRT and capture the LLVM IR output
let output = Command::new(CMD_IRRT_CLANG)
.args(flags) .args(flags)
.output() .output()
.map(|o| { .map(|o| {
@ -52,7 +69,17 @@ fn main() {
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n"); let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
let mut filtered_output = String::with_capacity(output.len()); let mut filtered_output = String::with_capacity(output.len());
let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap(); // Filter out irrelevant IR
//
// Regex:
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks
// - `(?m:^declare.*?$)` captures LLVM `declare` lines
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
// - `(?m:^@.+?=.+$)` captures global constants
let regex_filter = Regex::new(
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
)
.unwrap();
for f in regex_filter.captures_iter(&output) { for f in regex_filter.captures_iter(&output) {
assert_eq!(f.len(), 1); assert_eq!(f.len(), 1);
filtered_output.push_str(&f[0]); filtered_output.push_str(&f[0]);
@ -63,20 +90,71 @@ fn main() {
.unwrap() .unwrap()
.replace_all(&filtered_output, ""); .replace_all(&filtered_output, "");
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT"); // For debugging
if env::var("DEBUG_DUMP_IRRT").is_ok() { // Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated
let mut file = File::create(out_path.join("irrt.ll")).unwrap(); const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
if env::var(DEBUG_DUMP_IRRT).is_ok() {
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
file.write_all(output.as_bytes()).unwrap(); file.write_all(output.as_bytes()).unwrap();
let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap();
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
file.write_all(filtered_output.as_bytes()).unwrap(); file.write_all(filtered_output.as_bytes()).unwrap();
} }
let mut llvm_as = Command::new("llvm-as-irrt") // Assemble the emitted and filtered IR to .bc
// That .bc will be integrated into nac3core's codegen
let mut llvm_as = Command::new(CMD_IRRT_LLVM_AS)
.stdin(Stdio::piped()) .stdin(Stdio::piped())
.arg("-o") .arg("-o")
.arg(out_path.join("irrt.bc")) .arg(out_dir.join("irrt.bc"))
.spawn() .spawn()
.unwrap(); .unwrap();
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
assert!(llvm_as.wait().unwrap().success()); assert!(llvm_as.wait().unwrap().success());
} }
/// Compile `irrt_test.cpp` for testing
fn compile_irrt_test_cpp() {
let out_dir = get_out_dir();
let irrt_dir = get_irrt_dir();
let exe_path = out_dir.join("irrt_test.out"); // Output path of the compiled test executable
let irrt_test_cpp_path = irrt_dir.join("irrt_test.cpp");
let flags: &[&str] = &[
irrt_test_cpp_path.to_str().unwrap(),
"-x",
"c++",
"-I",
irrt_dir.to_str().unwrap(),
"-g",
"-fno-discard-value-names",
"-O0",
"-Wall",
"-Wextra",
"-Werror=return-type",
"-lm", // for `tgamma()`, `lgamma()`
"-o",
exe_path.to_str().unwrap(),
];
Command::new(CMD_IRRT_CLANG_TEST)
.args(flags)
.output()
.map(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
o
})
.unwrap();
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
}
fn main() {
compile_irrt_cpp();
// https://github.com/rust-lang/cargo/issues/2549
// `cargo test -F test` to also build `irrt_test.cpp
if cfg!(feature = "test") {
compile_irrt_test_cpp();
}
}

10
nac3core/irrt/irrt.cpp Normal file
View File

@ -0,0 +1,10 @@
#define IRRT_DEFINE_TYPEDEF_INTS
#include <irrt_everything.hpp>
/*
* All IRRT implementations.
*
* We don't have pre-compiled objects, so we are writing all implementations in
* headers and concatenate them with `#include` into one massive source file that
* contains all the IRRT stuff.
*/

View File

@ -0,0 +1,39 @@
#pragma once
#include <irrt/int_defs.hpp>
/*
This file defines all ARTIQ-specific structures
*/
/**
* @brief ARTIQ's `cslice` object
*
* See https://docs.rs/cslice/0.3.0/src/cslice/lib.rs.html#33-37
*/
template <typename SizeT>
struct CSlice {
const char *base;
SizeT len;
};
/**
* @brief Int type of ARTIQ's `Exception` IDs.
*/
typedef uint32_t ExceptionId;
/**
* @brief ARTIQ's `Exception` object
*
* See https://github.com/m-labs/artiq/blob/b0d2705c385f64b6e6711c1726cd9178f40b598e/artiq/firmware/libeh/eh_artiq.rs#L1C1-L17C1
*/
template <typename SizeT>
struct Exception {
ExceptionId id;
CSlice<SizeT> file;
uint32_t line;
uint32_t column;
CSlice<SizeT> function;
CSlice<SizeT> message;
uint32_t param;
};

View File

@ -1,27 +1,15 @@
using int8_t = _BitInt(8); #pragma once
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32); #include <irrt/int_defs.hpp>
using uint32_t = unsigned _BitInt(32); #include <irrt/slice.hpp>
using int64_t = _BitInt(64); #include <irrt/utils.hpp>
using uint64_t = unsigned _BitInt(64);
// NDArray indices are always `uint32_t`. // NDArray indices are always `uint32_t`.
using NDIndex = uint32_t; using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;
namespace { namespace {
template <typename T> // adapted from GNU Scientific Library:
const T& max(const T& a, const T& b) { // https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
return a > b ? a : b;
}
template <typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function // need to make sure `exp >= 0` before calling this function
template <typename T> template <typename T>
T __nac3_int_exp_impl(T base, T exp) { T __nac3_int_exp_impl(T base, T exp) {
@ -38,12 +26,8 @@ T __nac3_int_exp_impl(T base, T exp) {
} }
template <typename SizeT> template <typename SizeT>
SizeT __nac3_ndarray_calc_size_impl( SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len,
const SizeT* list_data, SizeT begin_idx, SizeT end_idx) {
SizeT list_len,
SizeT begin_idx,
SizeT end_idx
) {
__builtin_assume(end_idx <= list_len); __builtin_assume(end_idx <= list_len);
SizeT num_elems = 1; SizeT num_elems = 1;
@ -56,12 +40,8 @@ SizeT __nac3_ndarray_calc_size_impl(
} }
template <typename SizeT> template <typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl( void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims,
SizeT index, SizeT num_dims, NDIndexInt* idxs) {
const SizeT* dims,
SizeT num_dims,
NDIndex* idxs
) {
SizeT stride = 1; SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) { for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1; SizeT i = num_dims - dim - 1;
@ -72,12 +52,9 @@ void __nac3_ndarray_calc_nd_indices_impl(
} }
template <typename SizeT> template <typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl( SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims,
const SizeT* dims, const NDIndexInt* indices,
SizeT num_dims, SizeT num_indices) {
const NDIndex* indices,
SizeT num_indices
) {
SizeT idx = 0; SizeT idx = 0;
SizeT stride = 1; SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) { for (SizeT i = 0; i < num_dims; ++i) {
@ -93,18 +70,17 @@ SizeT __nac3_ndarray_flatten_index_impl(
} }
template <typename SizeT> template <typename SizeT>
void __nac3_ndarray_calc_broadcast_impl( void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, SizeT lhs_ndims,
const SizeT* lhs_dims, const SizeT* rhs_dims, SizeT rhs_ndims,
SizeT lhs_ndims, SizeT* out_dims) {
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims
) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) { for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr; const SizeT* lhs_dim_sz =
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr; i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz =
i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1]; SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) { if (lhs_dim_sz == nullptr) {
@ -124,12 +100,10 @@ void __nac3_ndarray_calc_broadcast_impl(
} }
template <typename SizeT> template <typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl( void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
const SizeT* src_dims, SizeT src_ndims,
SizeT src_ndims, const NDIndexInt* in_idx,
const NDIndex* in_idx, NDIndexInt* out_idx) {
NDIndex* out_idx
) {
for (SizeT i = 0; i < src_ndims; ++i) { for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1; SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
@ -138,15 +112,15 @@ void __nac3_ndarray_calc_broadcast_idx_impl(
} // namespace } // namespace
extern "C" { extern "C" {
#define DEF_nac3_int_exp_(T) \ #define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) {\ T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp);\ return __nac3_int_exp_impl(base, exp); \
} }
DEF_nac3_int_exp_(int32_t) DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t) DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t) DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t) DEF_nac3_int_exp_(uint64_t);
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) { SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) { if (i < 0) {
@ -160,11 +134,8 @@ SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
return i; return i;
} }
SliceIndex __nac3_range_slice_len( SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end,
const SliceIndex start, const SliceIndex step) {
const SliceIndex end,
const SliceIndex step
) {
SliceIndex diff = end - start; SliceIndex diff = end - start;
if (diff > 0 && step > 0) { if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1; return ((diff - 1) / step) + 1;
@ -180,62 +151,52 @@ SliceIndex __nac3_range_slice_len(
// - All the index must *not* be out-of-bound or negative, // - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*, // - The end index is *inclusive*,
// - The length of src and dest slice size should already // - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest) // be checked: if dest.step == 1 then len(src) <= len(dest) else
// len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size( SliceIndex __nac3_list_slice_assign_var_size(
SliceIndex dest_start, SliceIndex dest_start, SliceIndex dest_end, SliceIndex dest_step,
SliceIndex dest_end, uint8_t* dest_arr, SliceIndex dest_arr_len, SliceIndex src_start,
SliceIndex dest_step, SliceIndex src_end, SliceIndex src_step, uint8_t* src_arr,
uint8_t* dest_arr, SliceIndex src_arr_len, const SliceIndex size) {
SliceIndex dest_arr_len, /* if dest_arr_len == 0, do nothing since we do not support
SliceIndex src_start, * extending list
SliceIndex src_end, */
SliceIndex src_step,
uint8_t* src_arr,
SliceIndex src_arr_len,
const SliceIndex size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len; if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */ /* if both step is 1, memmove directly, handle the dropping of
* the list, and shrink size */
if (src_step == dest_step && dest_step == 1) { if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0; const SliceIndex src_len =
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0; (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len =
(dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) { if (src_len > 0) {
__builtin_memmove( __builtin_memmove(dest_arr + dest_start * size,
dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
src_arr + src_start * size,
src_len * size
);
} }
if (dest_len > 0) { if (dest_len > 0) {
/* dropping */ /* dropping */
__builtin_memmove( __builtin_memmove(dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
dest_arr + (dest_end + 1) * size, (dest_arr_len - dest_end - 1) * size);
(dest_arr_len - dest_end - 1) * size
);
} }
/* shrink size */ /* shrink size */
return dest_arr_len - (dest_len - src_len); return dest_arr_len - (dest_len - src_len);
} }
/* if two range overlaps, need alloca */ /* if two range overlaps, need alloca */
uint8_t need_alloca = uint8_t need_alloca =
(dest_arr == src_arr) (dest_arr == src_arr) &&
&& !( !(max(dest_start, dest_end) < min(src_start, src_end) ||
max(dest_start, dest_end) < min(src_start, src_end) max(src_start, src_end) < min(dest_start, dest_end));
|| max(src_start, src_end) < min(dest_start, dest_end)
);
if (need_alloca) { if (need_alloca) {
uint8_t* tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size)); uint8_t* tmp =
reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size); __builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp; src_arr = tmp;
} }
SliceIndex src_ind = src_start; SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start; SliceIndex dest_ind = dest_start;
for (; for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */ /* for constant optimization */
if (size == 1) { if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1); __builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
@ -244,30 +205,26 @@ SliceIndex __nac3_list_slice_assign_var_size(
} else if (size == 8) { } else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8); __builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else { } else {
/* memcpy for var size, cannot overlap after previous alloca */ /* memcpy for var size, cannot overlap after previous
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size); * alloca */
__builtin_memcpy(dest_arr + dest_ind * size,
src_arr + src_ind * size, size);
} }
} }
/* only dest_step == 1 can we shrink the dest list. */ /* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */ /* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) { if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove( __builtin_memmove(
dest_arr + dest_ind * size, dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
dest_arr + (dest_end + 1) * size, (dest_arr_len - dest_end - 1) * size + size + size + size);
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1; return dest_arr_len - (dest_end - dest_ind) - 1;
} }
return dest_arr_len; return dest_arr_len;
} }
int32_t __nac3_isinf(double x) { int32_t __nac3_isinf(double x) { return __builtin_isinf(x); }
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) { int32_t __nac3_isnan(double x) { return __builtin_isnan(x); }
return __builtin_isnan(x);
}
double tgamma(double arg); double tgamma(double arg);
@ -320,95 +277,71 @@ double __nac3_j0(double x) {
return j0(x); return j0(x);
} }
uint32_t __nac3_ndarray_calc_size( uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len,
const uint32_t* list_data, uint32_t begin_idx, uint32_t end_idx) {
uint32_t list_len, return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx,
uint32_t begin_idx, end_idx);
uint32_t end_idx
) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
} }
uint64_t __nac3_ndarray_calc_size64( uint64_t __nac3_ndarray_calc_size64(const uint64_t* list_data,
const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx,
uint64_t list_len, uint64_t end_idx) {
uint64_t begin_idx, return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx,
uint64_t end_idx end_idx);
) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
} }
void __nac3_ndarray_calc_nd_indices( void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims,
uint32_t index, uint32_t num_dims, NDIndexInt* idxs) {
const uint32_t* dims,
uint32_t num_dims,
NDIndex* idxs
) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
} }
void __nac3_ndarray_calc_nd_indices64( void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims,
uint64_t index, uint64_t num_dims, NDIndexInt* idxs) {
const uint64_t* dims,
uint64_t num_dims,
NDIndex* idxs
) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
} }
uint32_t __nac3_ndarray_flatten_index( uint32_t __nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims,
const uint32_t* dims, const NDIndexInt* indices,
uint32_t num_dims, uint32_t num_indices) {
const NDIndex* indices, return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices,
uint32_t num_indices num_indices);
) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
} }
uint64_t __nac3_ndarray_flatten_index64( uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims,
const uint64_t* dims, const NDIndexInt* indices,
uint64_t num_dims, uint64_t num_indices) {
const NDIndex* indices, return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices,
uint64_t num_indices num_indices);
) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
} }
void __nac3_ndarray_calc_broadcast( void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, uint32_t lhs_ndims,
const uint32_t* lhs_dims, const uint32_t* rhs_dims, uint32_t rhs_ndims,
uint32_t lhs_ndims, uint32_t* out_dims) {
const uint32_t* rhs_dims, return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims,
uint32_t rhs_ndims, rhs_ndims, out_dims);
uint32_t* out_dims
) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
} }
void __nac3_ndarray_calc_broadcast64( void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
const uint64_t* lhs_dims, uint64_t lhs_ndims,
uint64_t lhs_ndims, const uint64_t* rhs_dims,
const uint64_t* rhs_dims, uint64_t rhs_ndims, uint64_t* out_dims) {
uint64_t rhs_ndims, return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims,
uint64_t* out_dims rhs_ndims, out_dims);
) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
} }
void __nac3_ndarray_calc_broadcast_idx( void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
const uint32_t* src_dims, uint32_t src_ndims,
uint32_t src_ndims, const NDIndexInt* in_idx,
const NDIndex* in_idx, NDIndexInt* out_idx) {
NDIndex* out_idx __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx,
) { out_idx);
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
} }
void __nac3_ndarray_calc_broadcast_idx64( void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
const uint64_t* src_dims, uint64_t src_ndims,
uint64_t src_ndims, const NDIndexInt* in_idx,
const NDIndex* in_idx, NDIndexInt* out_idx) {
NDIndex* out_idx __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx,
) { out_idx);
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
} }
} // extern "C" } // extern "C"

View File

@ -0,0 +1,92 @@
#pragma once
#include <irrt/artiq_defs.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/utils.hpp>
namespace {
/**
* @brief A (limited) set of known Exception IDs usable in IRRT
*/
struct ErrorContextExceptions {
ExceptionId index_error;
ExceptionId value_error;
ExceptionId assertion_error;
ExceptionId runtime_error;
ExceptionId type_error;
};
/**
* @brief The IRRT error context object
*
* This object contains all the details needed to propagate Python-like Exceptions in
* IRRT - within IRRT itself or propagate out of extern calls from nac3core.
*/
struct ErrorContext {
const ErrorContextExceptions *exceptions;
// Exception thrown by IRRT
ExceptionId exception_id;
// Points to empty c-string if there is no thrown Exception
const char *msg;
uint64_t param1;
uint64_t param2;
uint64_t param3;
void initialize(const ErrorContextExceptions *exceptions) {
this->exceptions = exceptions;
clear_error();
}
void clear_error() {
// NOTE: Point the msg to an empty str.
// Don't set it to nullptr - to implement `has_exception`
this->msg = "";
}
void set_exception(ExceptionId exception_id, const char *msg,
uint64_t param1 = 0, uint64_t param2 = 0,
uint64_t param3 = 0) {
this->exception_id = exception_id;
this->msg = msg;
this->param1 = param1;
this->param2 = param2;
this->param3 = param3;
}
bool has_exception() { return !cstr_utils::is_empty(msg); }
template <typename SizeT>
void get_exception_str(CSlice<SizeT> *dst_str) {
dst_str->base = msg;
dst_str->len = (SizeT)cstr_utils::length(msg);
}
};
} // namespace
extern "C" {
void __nac3_error_context_initialize(ErrorContext *errctx,
ErrorContextExceptions *exceptions) {
errctx->initialize(exceptions);
}
bool __nac3_error_context_has_exception(ErrorContext *errctx) {
return errctx->has_exception();
}
void __nac3_error_context_get_exception_str(ErrorContext *errctx,
CSlice<int32_t> *dst_str) {
errctx->get_exception_str<int32_t>(dst_str);
}
void __nac3_error_context_get_exception_str64(ErrorContext *errctx,
CSlice<int64_t> *dst_str) {
errctx->get_exception_str<int64_t>(dst_str);
}
// Used for testing
void __nac3_error_dummy_raise(ErrorContext *errctx) {
errctx->set_exception(errctx->exceptions->runtime_error,
"Error thrown from __nac3_error_dummy_raise");
}
}

View File

@ -0,0 +1,12 @@
#pragma once
// This is made toggleable since `irrt_test.cpp` itself would include
// headers that define these typedefs
#ifdef IRRT_DEFINE_TYPEDEF_INTS
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#endif

View File

@ -0,0 +1,315 @@
#pragma once
#include <irrt/error_context.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
namespace ndarray {
namespace basic {
namespace util {
/**
* @brief Asserts that `shape` does not contain negative dimensions.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape to check on
*/
template <typename SizeT>
void assert_shape_no_negative(ErrorContext* errctx, SizeT ndims,
const SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
if (shape[axis] < 0) {
errctx->set_exception(
errctx->exceptions->value_error,
"negative dimensions are not allowed; axis {0} "
"has dimension {1}",
axis, shape[axis]);
return;
}
}
}
/**
* @brief Returns the number of elements of an ndarray given its shape.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape of the ndarray
*/
template <typename SizeT>
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
SizeT size = 1;
for (SizeT axis = 0; axis < ndims; axis++) size *= shape[axis];
return size;
}
/**
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
*
* @param ndims Number of elements in `shape` and `indices`
* @param shape The shape of the ndarray
* @param indices The returned indices indexing the ndarray with shape `shape`.
* @param nth The index of the element of interest.
*/
template <typename SizeT>
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices,
SizeT nth) {
for (int32_t i = 0; i < ndims; i++) {
int32_t axis = ndims - i - 1;
int32_t dim = shape[axis];
indices[axis] = nth % dim;
nth /= dim;
}
}
} // namespace util
/**
* @brief Return the number of elements of an `ndarray`
*
* This function corresponds to `<an_ndarray>.size`
*/
template <typename SizeT>
SizeT size(const NDArray<SizeT>* ndarray) {
return util::calc_size_from_shape(ndarray->ndims, ndarray->shape);
}
/**
* @brief Return of the number of its content of an `ndarray`.
*
* This function corresponds to `<an_ndarray>.nbytes`.
*/
template <typename SizeT>
SizeT nbytes(const NDArray<SizeT>* ndarray) {
return size(ndarray) * ndarray->itemsize;
}
/**
* @brief Update the strides of an ndarray given an ndarray `shape`
* and assuming that the ndarray is fully c-contagious.
*
* You might want to read https://ajcr.net/stride-guide-part-1/.
*/
template <typename SizeT>
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndarray->ndims; i++) {
int axis = ndarray->ndims - i - 1;
ndarray->strides[axis] = stride_product * ndarray->itemsize;
stride_product *= ndarray->shape[axis];
}
}
/**
* @brief Return the pointer to the element indexed by `indices`.
*/
template <typename SizeT>
uint8_t* get_pelement_by_indices(const NDArray<SizeT>* ndarray,
const SizeT* indices) {
uint8_t* element = ndarray->data;
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
element += indices[dim_i] * ndarray->strides[dim_i];
return element;
}
/**
* @brief Return the pointer to the nth (0-based) element in a flattened view of `ndarray`.
*/
template <typename SizeT>
uint8_t* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
SizeT* indices = (SizeT*)__builtin_alloca(sizeof(SizeT) * ndarray->ndims);
util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, nth);
return get_pelement_by_indices(ndarray, indices);
}
/**
* @brief Like `get_nth_pelement` but asserts that `nth` is in bounds.
*/
template <typename SizeT>
uint8_t* checked_get_nth_pelement(ErrorContext* errctx,
const NDArray<SizeT>* ndarray, SizeT nth) {
SizeT arr_size = ndarray->size();
if (!(0 <= nth && nth < arr_size)) {
errctx->set_exception(
errctx->exceptions->index_error,
"index {0} is out of bounds, valid range is {1} <= index < {2}",
nth, 0, arr_size);
return 0;
}
return get_nth_pelement(ndarray, nth);
}
/**
* @brief Set an element in `ndarray`.
*
* @param pelement Pointer to the element in `ndarray` to be set.
* @param pvalue Pointer to the value `pelement` will be set to.
*/
template <typename SizeT>
void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement,
const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
}
/**
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
*
* This function corresponds to `<an_ndarray>.__len__`.
*
* @param dst_length The returned result
*/
template <typename SizeT>
void len(ErrorContext* errctx, const NDArray<SizeT>* ndarray,
SliceIndex* dst_length) {
// numpy prohibits `__len__` on unsized objects
if (ndarray->ndims == 0) {
errctx->set_exception(errctx->exceptions->type_error,
"len() of unsized object");
return;
}
*dst_length = (SliceIndex)ndarray->shape[0];
}
/**
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
*
* Both ndarrays will be viewed in their flatten views when copying the elements.
*/
template <typename SizeT>
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
__builtin_assume(src_ndarray->itemsize == dst_ndarray->itemsize);
for (SizeT i = 0; i < size(src_ndarray); i++) {
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
ndarray::basic::set_pelement_value(dst_ndarray, dst_element,
src_element);
}
}
/**
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
*
* You may want to see: ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
*/
template <typename SizeT>
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
// Other references:
// - tinynumpy's implementation: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
// - ndarray's flags["C_CONTIGUOUS"]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
// - ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
// From https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
//
// The traditional rule is that for an array to be flagged as C contiguous,
// the following must hold:
//
// strides[-1] == itemsize
// strides[i] == shape[i+1] * strides[i + 1]
// [...]
// According to these rules, a 0- or 1-dimensional array is either both
// C- and F-contiguous, or neither; and an array with 2+ dimensions
// can be C- or F- contiguous, or neither, but not both. Though there
// there are exceptions for arrays with zero or one item, in the first
// case the check is relaxed up to and including the first dimension
// with shape[i] == 0. In the second case `strides == itemsize` will
// can be true for all dimensions and both flags are set.
if (ndarray->ndims == 0) {
return true;
}
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
return false;
}
for (SizeT i = 0; i < ndarray->ndims - 1; i++) {
if (ndarray->strides[i] !=
ndarray->shape[i + 1] + ndarray->strides[i + 1]) {
return false;
}
}
return true;
}
} // namespace basic
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::basic;
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
return size(ndarray);
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
return size(ndarray);
}
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
return nbytes(ndarray);
}
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
return nbytes(ndarray);
}
void __nac3_ndarray_len(ErrorContext* errctx, NDArray<int32_t>* ndarray,
SliceIndex* dst_len) {
return len(errctx, ndarray, dst_len);
}
void __nac3_ndarray_len64(ErrorContext* errctx, NDArray<int64_t>* ndarray,
SliceIndex* dst_len) {
return len(errctx, ndarray, dst_len);
}
void __nac3_ndarray_util_assert_shape_no_negative(ErrorContext* errctx,
int32_t ndims,
int32_t* shape) {
util::assert_shape_no_negative(errctx, ndims, shape);
}
void __nac3_ndarray_util_assert_shape_no_negative64(ErrorContext* errctx,
int64_t ndims,
int64_t* shape) {
util::assert_shape_no_negative(errctx, ndims, shape);
}
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
set_strides_by_shape(ndarray);
}
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
return is_c_contiguous(ndarray);
}
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
return is_c_contiguous(ndarray);
}
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
uint8_t* __nac3_ndarray_get_nth_pelement(NDArray<int32_t>* ndarray,
int32_t index) {
return get_nth_pelement(ndarray, index);
}
uint8_t* __nac3_ndarray_get_nth_pelement64(NDArray<int64_t>* ndarray,
int64_t index) {
return get_nth_pelement(ndarray, index);
}
}

View File

@ -0,0 +1,221 @@
#pragma once
#include <irrt/error_context.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
namespace {
template <typename SizeT>
struct ShapeEntry {
SizeT ndims;
SizeT* shape;
};
} // namespace
namespace {
namespace ndarray {
namespace broadcast {
namespace util {
/**
* @brief Return true if `src_shape` can broadcast to `dst_shape`.
*/
template <typename SizeT>
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape,
SizeT src_ndims, const SizeT* src_shape) {
/*
* // See https://numpy.org/doc/stable/user/basics.broadcasting.html
* This function handles this example:
* ```
* Image (3d array): 256 x 256 x 3
* Scale (1d array): 3
* Result (3d array): 256 x 256 x 3
* ```
* Other interesting examples to consider:
* - `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
* - `can_broadcast_shape_to([3], [3, 1]) == false`
* - `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
* In cases when the shapes contain zero(es):
* - `can_broadcast_shape_to([0], [1]) == true`
* - `can_broadcast_shape_to([0], [2]) == false`
* - `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
* - `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
* - `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
* - `can_broadcast_shape_to([4, 3], [0, 3]) == false`
* - `can_broadcast_shape_to([4, 3], [0, 0]) == false`
*/
// This is essentially doing the following in Python:
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
SizeT target_dim_i = target_ndims - i - 1;
SizeT src_dim_i = src_ndims - i - 1;
bool target_dim_exists = target_dim_i >= 0;
bool src_dim_exists = src_dim_i >= 0;
SizeT target_dim = target_dim_exists ? target_shape[target_dim_i] : 1;
SizeT src_dim = src_dim_exists ? src_shape[src_dim_i] : 1;
bool ok = src_dim == 1 || target_dim == src_dim;
if (!ok) return false;
}
return true;
}
/**
* @brief Performs `np.broadcast_shapes`
*/
template <typename SizeT>
void broadcast_shapes(ErrorContext* errctx, SizeT num_shapes,
const ShapeEntry<SizeT>* shapes, SizeT dst_ndims,
SizeT* dst_shape) {
// `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it
// for this function since it should already know in order to allocate `dst_shape` in the first place.
// `dst_shape` must be pre-allocated.
// `dst_shape` does not have to be initialized
// TODO: Implementation is not obvious
// This is essentially a `mconcat` where the neutral element is `[1, 1, 1, 1, ...]`, and the operation is commutative.
// Set `dst_shape` to all `1`s.
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
dst_shape[dst_axis] = 0;
}
for (SizeT i = 0; i < num_shapes; i++) {
ShapeEntry<SizeT> entry = shapes[i];
SizeT entry_axis = entry.ndims - i;
SizeT dst_axis = dst_ndims - i;
SizeT entry_dim = entry.shape[entry_axis];
SizeT dst_dim = dst_shape[dst_axis];
if (dst_dim == 1) {
dst_shape[dst_axis] = entry_dim;
} else if (entry_dim == 1) {
// Do nothing
} else if (entry_dim == dst_dim) {
// Do nothing
} else {
errctx->set_exception(errctx->exceptions->value_error,
"shape mismatch: objects cannot be broadcast "
"to a single shape.");
return;
}
}
}
} // namespace util
/**
* @brief Perform `np.broadcast_to(<ndarray>, <target_shape>)` and appropriate assertions.
*
* Cautious note on https://github.com/numpy/numpy/issues/21744..
*
* This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
* and return the result by modifying `dst_ndarray`.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape`
* - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is unchanged.
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
*/
template <typename SizeT>
void broadcast_to(ErrorContext* errctx, const NDArray<SizeT>* src_ndarray,
NDArray<SizeT>* dst_ndarray) {
/*
* Cautions:
* ```
* xs = np.zeros((4,))
* ys = np.zero((4, 1))
* ys[:] = xs # ok
*
* xs = np.zeros((1, 4))
* ys = np.zero((4,))
* ys[:] = xs # allowed
* # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule.
* # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744
* # This implementation will NOT support this assignment.
* ```
*/
if (!ndarray::broadcast::util::can_broadcast_shape_to(
dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
src_ndarray->shape)) {
errctx->set_exception(errctx->exceptions->value_error,
"operands could not be broadcast together");
return;
}
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// TODO: Implementation is not obvious
SizeT stride_product = 1;
for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) {
SizeT src_ndarray_dim_i = src_ndarray->ndims - i - 1;
SizeT dst_dim_i = dst_ndarray->ndims - i - 1;
bool src_ndarray_dim_exists = src_ndarray_dim_i >= 0;
bool dst_dim_exists = dst_dim_i >= 0;
bool c1 = src_ndarray_dim_exists &&
src_ndarray->shape[src_ndarray_dim_i] == 1;
bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1;
if (!src_ndarray_dim_exists || (c1 && c2)) {
dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place
} else {
dst_ndarray->strides[dst_dim_i] =
stride_product * src_ndarray->itemsize;
stride_product *= src_ndarray->shape[src_ndarray_dim_i];
}
}
}
} // namespace broadcast
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::broadcast;
void __nac3_ndarray_broadcast_to(ErrorContext* errctx,
NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
broadcast_to(errctx, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_broadcast_to64(ErrorContext* errctx,
NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
broadcast_to(errctx, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_broadcast_shapes(ErrorContext* errctx, int32_t num_shapes,
const ShapeEntry<int32_t>* shapes,
int32_t dst_ndims, int32_t* dst_shape) {
ndarray::broadcast::util::broadcast_shapes(errctx, num_shapes, shapes,
dst_ndims, dst_shape);
}
void __nac3_ndarray_broadcast_shapes64(ErrorContext* errctx, int64_t num_shapes,
const ShapeEntry<int64_t>* shapes,
int64_t dst_ndims, int64_t* dst_shape) {
ndarray::broadcast::util::broadcast_shapes(errctx, num_shapes, shapes,
dst_ndims, dst_shape);
}
}

View File

@ -0,0 +1,44 @@
#pragma once
namespace {
/**
* @brief The NDArray object
*
* The official numpy implementations: https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
*/
template <typename SizeT>
struct NDArray {
/**
* @brief The underlying data this `ndarray` is pointing to.
*
* Must be set to `nullptr` to indicate that this NDArray's `data` is uninitialized.
*/
uint8_t* data;
/**
* @brief The number of bytes of a single element in `data`.
*/
SizeT itemsize;
/**
* @brief The number of dimensions of this shape.
*/
SizeT ndims;
/**
* @brief The NDArray shape, with length equal to `ndims`.
*
* Note that it may contain 0.
*/
SizeT* shape;
/**
* @brief Array strides, with length equal to `ndims`
*
* The stride values are in units of bytes, not number of elements.
*
* Note that `strides` can have negative values.
*/
SizeT* strides;
};
} // namespace

View File

@ -0,0 +1,200 @@
#pragma once
#include <irrt/error_context.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
namespace {
typedef uint8_t NDIndexType;
/**
* @brief A single element index
*
* See https://numpy.org/doc/stable/user/basics.indexing.html#single-element-indexing
*
* `data` points to a `SliceIndex`.
*/
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
/**
* @brief A slice index
*
* See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
*
* `data` points to a `UserRange`.
*/
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
/**
* @brief An index used in ndarray indexing
*/
struct NDIndex {
/**
* @brief Enum tag to specify the type of index.
*
* Please see comments of each enum constant.
*/
NDIndexType type;
/**
* @brief The accompanying data associated with `type`.
*
* Please see comments of each enum constant.
*/
uint8_t* data;
};
} // namespace
namespace {
namespace ndarray {
namespace indexing {
namespace util {
/**
* @brief Return the expected rank of the resulting ndarray
* created by indexing an ndarray of rank `ndims` using `indexes`.
*/
template <typename SizeT>
void deduce_ndims_after_indexing(ErrorContext* errctx, SizeT* final_ndims,
SizeT ndims, SizeT num_indexes,
const NDIndex* indexes) {
if (num_indexes > ndims) {
errctx->set_exception(errctx->exceptions->index_error,
"too many indices for array: array is "
"{0}-dimensional, but {1} were indexed",
ndims, num_indexes);
return;
}
*final_ndims = ndims;
for (SizeT i = 0; i < num_indexes; i++) {
if (indexes[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
// An index demotes the rank by 1
(*final_ndims)--;
}
}
}
} // namespace util
/**
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
*
* This is function very similar to performing `dst_ndarray = src_ndarray[indexes]` in Python (where the variables
* can all be found in the parameter of this function).
*
* In other words, this function takes in an ndarray (`src_ndarray`), index it with `indexes`, and return the
* indexed array (by writing the result to `dst_ndarray`).
*
* This function also does proper assertions on `indexes`.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after
* indexing `src_ndarray` with `indexes`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed.
* - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works.
*
* @param indexes Indexes to index `src_ndarray`, ordered in the same way you would write them in Python.
* @param src_ndarray The NDArray to be indexed.
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
*/
template <typename SizeT>
void index(ErrorContext* errctx, SizeT num_indexes, const NDIndex* indexes,
const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// Reference code: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
SizeT src_axis = 0;
SizeT dst_axis = 0;
for (SliceIndex i = 0; i < num_indexes; i++) {
const NDIndex* index = &indexes[i];
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
SliceIndex input = *((SliceIndex*)index->data);
SliceIndex k = slice::resolve_index_in_length(
src_ndarray->shape[src_axis], input);
if (k == slice::OUT_OF_BOUNDS) {
errctx->set_exception(errctx->exceptions->index_error,
"index {0} is out of bounds for axis {1} "
"with size {2}",
input, src_axis,
src_ndarray->shape[src_axis]);
return;
}
dst_ndarray->data += k * src_ndarray->strides[src_axis];
src_axis++;
} else if (index->type == ND_INDEX_TYPE_SLICE) {
UserSlice* input = (UserSlice*)index->data;
Slice slice;
input->indices_checked(errctx, src_ndarray->shape[src_axis],
&slice);
if (errctx->has_exception()) {
return;
}
dst_ndarray->data +=
(SizeT)slice.start * src_ndarray->strides[src_axis];
dst_ndarray->strides[dst_axis] =
((SizeT)slice.step) * src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = (SizeT)slice.len();
dst_axis++;
src_axis++;
} else {
__builtin_unreachable();
}
}
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
}
} // namespace indexing
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::indexing;
void __nac3_ndarray_indexing_deduce_ndims_after_indexing(
ErrorContext* errctx, int32_t* result, int32_t ndims, int32_t num_indexes,
const NDIndex* indexes) {
ndarray::indexing::util::deduce_ndims_after_indexing(errctx, result, ndims,
num_indexes, indexes);
}
void __nac3_ndarray_indexing_deduce_ndims_after_indexing64(
ErrorContext* errctx, int64_t* result, int64_t ndims, int64_t num_indexes,
const NDIndex* indexes) {
ndarray::indexing::util::deduce_ndims_after_indexing(errctx, result, ndims,
num_indexes, indexes);
}
void __nac3_ndarray_index(ErrorContext* errctx, int32_t num_indexes,
NDIndex* indexes, NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
index(errctx, num_indexes, indexes, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_index64(ErrorContext* errctx, int64_t num_indexes,
NDIndex* indexes, NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
index(errctx, num_indexes, indexes, src_ndarray, dst_ndarray);
}
}

View File

@ -0,0 +1,117 @@
#pragma once
#include <irrt/error_context.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
namespace ndarray {
namespace reshape {
namespace util {
/**
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
*
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
* modified to contain the resolved dimension.
*
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
*
* @param size The `.size` of `<ndarray>`
* @param new_ndims Number of elements in `new_shape`
* @param new_shape Target shape to reshape to
*/
template <typename SizeT>
void resolve_and_check_new_shape(ErrorContext* errctx, SizeT size,
SizeT new_ndims, SizeT* new_shape) {
// Is there a -1 in `new_shape`?
bool neg1_exists = false;
// Location of -1, only initialized if `neg1_exists` is true
SizeT neg1_axis_i;
// The computed ndarray size of `new_shape`
SizeT new_size = 1;
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
SizeT dim = new_shape[axis_i];
if (dim < 0) {
if (dim == -1) {
if (neg1_exists) {
// Multiple `-1` found. Throw an error.
errctx->set_exception(
errctx->exceptions->value_error,
"can only specify one unknown dimension");
return;
} else {
neg1_exists = true;
neg1_axis_i = axis_i;
}
} else {
// TODO: What? In `np.reshape` any negative dimensions is
// treated like its `-1`.
//
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
//
// It is not documented by numpy.
// Throw an error for now...
errctx->set_exception(
errctx->exceptions->value_error,
"Found negative dimension {0} on axis {1}", dim, axis_i);
return;
}
} else {
new_size *= dim;
}
}
bool can_reshape;
if (neg1_exists) {
// Let `x` be the unknown dimension
// solve `x * <new_size> = <size>`
if (new_size == 0 && size == 0) {
// `x` has infinitely many solutions
can_reshape = false;
} else if (new_size == 0 && size != 0) {
// `x` has no solutions
can_reshape = false;
} else if (size % new_size != 0) {
// `x` has no integer solutions
can_reshape = false;
} else {
can_reshape = true;
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
}
} else {
can_reshape = (new_size == size);
}
if (!can_reshape) {
errctx->set_exception(
errctx->exceptions->value_error,
"cannot reshape array of size {0} into given shape", size);
return;
}
}
} // namespace util
} // namespace reshape
} // namespace ndarray
} // namespace
extern "C" {
void __nac3_ndarray_resolve_and_check_new_shape(ErrorContext* errctx,
int32_t size, int32_t new_ndims,
int32_t* new_shape) {
ndarray::reshape::util::resolve_and_check_new_shape(errctx, size, new_ndims,
new_shape);
}
void __nac3_ndarray_resolve_and_check_new_shape64(ErrorContext* errctx,
int64_t size,
int64_t new_ndims,
int64_t* new_shape) {
ndarray::reshape::util::resolve_and_check_new_shape(errctx, size, new_ndims,
new_shape);
}
}

View File

@ -0,0 +1,166 @@
#pragma once
#include <irrt/error_context.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/slice.hpp>
#include <irrt/utils.hpp>
// The type of an index or a value describing the length of a
// range/slice is always `int32_t`.
using SliceIndex = int32_t;
namespace {
/**
* @brief A Python-like slice with resolved indices.
*
* "Resolved indices" means that `start` and `stop` must be positive and are
* bound to a known length.
*/
struct Slice {
SliceIndex start;
SliceIndex stop;
SliceIndex step;
/**
* @brief Calculate and return the length / the number of the slice.
*
* If this were a Python range, this function would be `len(range(start, stop, step))`.
*/
SliceIndex len() {
SliceIndex diff = stop - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
};
namespace slice {
/**
* @brief Resolve a slice index under a given length like Python indexing.
*
* In Python, if you have a `list` of length 100, `list[-1]` resolves to
* `list[99]`, so `resolve_index_in_length_clamped(100, -1)` returns `99`.
*
* If `length` is 0, 0 is returned for any value of `index`.
*
* If `index` is out of bounds, clamps the returned value between `0` and
* `length - 1` (inclusive).
*
*/
SliceIndex resolve_index_in_length_clamped(SliceIndex length,
SliceIndex index) {
if (index < 0) {
return max<SliceIndex>(length + index, 0);
} else {
return min<SliceIndex>(length, index);
}
}
const SliceIndex OUT_OF_BOUNDS = -1;
/**
* @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS`
* if `index` is out of bounds.
*/
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) {
SliceIndex resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length) {
return resolved;
} else {
return OUT_OF_BOUNDS;
}
}
} // namespace slice
/**
* @brief A Python-like slice with **unresolved** indices.
*/
struct UserSlice {
bool start_defined;
SliceIndex start;
bool stop_defined;
SliceIndex stop;
bool step_defined;
SliceIndex step;
UserSlice() { this->reset(); }
void reset() {
this->start_defined = false;
this->stop_defined = false;
this->step_defined = false;
}
void set_start(SliceIndex start) {
this->start_defined = true;
this->start = start;
}
void set_stop(SliceIndex stop) {
this->stop_defined = true;
this->stop = stop;
}
void set_step(SliceIndex step) {
this->step_defined = true;
this->step = step;
}
/**
* @brief Resolve this slice.
*
* In Python, this would be `slice(start, stop, step).indices(length)`.
*
* @return A `Slice` with the resolved indices.
*/
Slice indices(SliceIndex length) {
Slice result;
result.step = step_defined ? step : 1;
bool step_is_negative = result.step < 0;
if (start_defined) {
result.start =
slice::resolve_index_in_length_clamped(length, start);
} else {
result.start = step_is_negative ? length - 1 : 0;
}
if (stop_defined) {
result.stop = slice::resolve_index_in_length_clamped(length, stop);
} else {
result.stop = step_is_negative ? -1 : length;
}
return result;
}
/**
* @brief Like `.indices()` but with assertions.
*/
void indices_checked(ErrorContext* errctx, SliceIndex length,
Slice* result) {
if (length < 0) {
errctx->set_exception(errctx->exceptions->value_error,
"length should not be negative, got {0}",
length);
return;
}
if (this->step_defined && this->step == 0) {
errctx->set_exception(errctx->exceptions->value_error,
"slice step cannot be zero");
return;
}
*result = this->indices(length);
}
};
} // namespace

View File

@ -0,0 +1,104 @@
#pragma once
namespace {
template <typename T>
const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template <typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
/**
* @brief Compare contents of two arrays with the same length.
*/
template <typename T>
bool arrays_match(int len, T* as, T* bs) {
for (int i = 0; i < len; i++) {
if (as[i] != bs[i]) return false;
}
return true;
}
namespace cstr_utils {
/**
* @brief Return true if `str` is empty.
*/
bool is_empty(const char* str) { return str[0] == '\0'; }
/**
* @brief Implementation of `strcmp()`
*/
int8_t compare(const char* a, const char* b) {
uint32_t i = 0;
while (true) {
if (a[i] < b[i]) {
return -1;
} else if (a[i] > b[i]) {
return 1;
} else {
if (a[i] == '\0') {
return 0;
} else {
i++;
}
}
}
}
/**
* @brief Return true two strings have the same content.
*/
int8_t equal(const char* a, const char* b) { return compare(a, b) == 0; }
/**
* @brief Implementation of `strlen()`.
*/
uint32_t length(const char* str) {
uint32_t length = 0;
while (*str != '\0') {
length++;
str++;
}
return length;
}
/**
* @brief Copy a null-terminated string to a buffer with limited size and guaranteed null-termination.
*
* `dst_max_size` must be greater than 0, otherwise this function has undefined behavior.
*
* This function attempts to copy everything from `src` from `dst`, and *always* null-terminates `dst`.
*
* If the size of `dst` is too small, the final byte (`dst[dst_max_size - 1]`) of `dst` will be set to
* the null terminator.
*
* @param src String to copy from.
* @param dst Buffer to copy string to.
* @param dst_max_size
* Number of bytes of this buffer, including the space needed for the null terminator.
* Must be greater than 0.
* @return If `dst` is too small to contain everything in `src`.
*/
bool copy(const char* src, char* dst, uint32_t dst_max_size) {
for (uint32_t i = 0; i < dst_max_size; i++) {
bool is_last = i + 1 == dst_max_size;
if (is_last && src[i] != '\0') {
dst[i] = '\0';
return false;
}
if (src[i] == '\0') {
dst[i] = '\0';
return true;
}
dst[i] = src[i];
}
__builtin_unreachable();
}
} // namespace cstr_utils
} // namespace

View File

@ -0,0 +1,13 @@
#pragma once
#include <irrt/artiq_defs.hpp>
#include <irrt/core.hpp>
#include <irrt/error_context.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/broadcast.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/indexing.hpp>
#include <irrt/ndarray/reshape.hpp>
#include <irrt/slice.hpp>
#include <irrt/utils.hpp>

View File

@ -0,0 +1,20 @@
// This file will be compiled like a real C++ program,
// and we do have the luxury to use the standard libraries.
// That is if the nix flakes do not have issues... especially on msys2...
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <test/test_core.hpp>
#include <test/test_ndarray_basic.hpp>
#include <test/test_ndarray_broadcast.hpp>
#include <test/test_ndarray_indexing.hpp>
#include <test/test_slice.hpp>
int main() {
test::core::run();
test::slice::run();
test::ndarray_basic::run();
test::ndarray_indexing::run();
test::ndarray_broadcast::run();
return 0;
}

View File

@ -0,0 +1,11 @@
#pragma once
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <irrt_everything.hpp>
#include <test/util.hpp>
/*
Include this header for every test_*.cpp
*/

View File

@ -0,0 +1,16 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace core {
void test_int_exp() {
BEGIN_TEST();
assert_values_match(125, __nac3_int_exp_impl<int32_t>(5, 3));
assert_values_match(3125, __nac3_int_exp_impl<int32_t>(5, 5));
}
void run() { test_int_exp(); }
} // namespace core
} // namespace test

View File

@ -0,0 +1,30 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_basic {
void test_calc_size_from_shape_normal() {
// Test shapes with normal values
BEGIN_TEST();
int32_t shape[4] = {2, 3, 5, 7};
assert_values_match(
210, ndarray::basic::util::calc_size_from_shape<int32_t>(4, shape));
}
void test_calc_size_from_shape_has_zero() {
// Test shapes with 0 in them
BEGIN_TEST();
int32_t shape[4] = {2, 0, 5, 7};
assert_values_match(
0, ndarray::basic::util::calc_size_from_shape<int32_t>(4, shape));
}
void run() {
test_calc_size_from_shape_normal();
test_calc_size_from_shape_has_zero();
}
} // namespace ndarray_basic
} // namespace test

View File

@ -0,0 +1,129 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_broadcast {
void test_can_broadcast_shape() {
BEGIN_TEST();
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){3}, 5, (int32_t[]){1, 1, 1, 1, 3}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){3}, 2, (int32_t[]){3, 1}));
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){3}, 1, (int32_t[]){3}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){1}, 1, (int32_t[]){3}));
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){1}, 1, (int32_t[]){1}));
assert_values_match(
true, ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 3, (int32_t[]){256, 1, 3}));
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){3}));
assert_values_match(false,
ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){2}));
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){1}));
// In cases when the shapes contain zero(es)
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){0}, 1, (int32_t[]){1}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){0}, 1, (int32_t[]){2}));
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
4, (int32_t[]){0, 4, 0, 0}, 1, (int32_t[]){1}));
assert_values_match(
true, ndarray::broadcast::util::can_broadcast_shape_to(
4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 1, 1, 1}));
assert_values_match(
true, ndarray::broadcast::util::can_broadcast_shape_to(
4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 4, 1, 1}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 3}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 0}));
}
void test_ndarray_broadcast() {
/*
# array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
# >>> [[19.9 29.9 39.9 49.9]]
#
# array = np.broadcast_to(array, (2, 3, 4))
# >>> [[[19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]]
# >>> [[19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]]]
#
# assery array.strides == (0, 0, 8)
*/
BEGIN_TEST();
double in_data[4] = {19.9, 29.9, 39.9, 49.9};
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = {1, 4};
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {.data = (uint8_t*)in_data,
.itemsize = sizeof(double),
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides};
ndarray::basic::set_strides_by_shape(&ndarray);
const int32_t dst_ndims = 3;
int32_t dst_shape[dst_ndims] = {2, 3, 4};
int32_t dst_strides[dst_ndims] = {};
NDArray<int32_t> dst_ndarray = {
.ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides};
ErrorContext errctx = create_testing_errctx();
ndarray::broadcast::broadcast_to(&errctx, &ndarray, &dst_ndarray);
assert_errctx_no_exception(&errctx);
assert_arrays_match(dst_ndims, ((int32_t[]){0, 0, 8}), dst_ndarray.strides);
assert_values_match(19.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 0}))));
assert_values_match(29.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 1}))));
assert_values_match(39.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 2}))));
assert_values_match(49.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 3}))));
assert_values_match(19.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 0}))));
assert_values_match(29.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 1}))));
assert_values_match(39.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 2}))));
assert_values_match(49.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 3}))));
assert_values_match(49.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){1, 2, 3}))));
}
void run() {
test_can_broadcast_shape();
test_ndarray_broadcast();
}
} // namespace ndarray_broadcast
} // namespace test

View File

@ -0,0 +1,220 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_indexing {
void test_normal_1() {
/*
Reference Python code:
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[-2:, 1::2]
# array([[ 5., 7.],
# [ 9., 11.]])
assert dst_ndarray.shape == (2, 2)
assert dst_ndarray.strides == (32, 16)
assert dst_ndarray[0, 0] == 5.0
assert dst_ndarray[0, 1] == 7.0
assert dst_ndarray[1, 0] == 9.0
assert dst_ndarray[1, 1] == 11.0
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
int32_t src_itemsize = sizeof(double);
const int32_t src_ndims = 2;
int32_t src_shape[src_ndims] = {3, 4};
int32_t src_strides[src_ndims] = {};
NDArray<int32_t> src_ndarray = {.data = (uint8_t *)src_data,
.itemsize = src_itemsize,
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int32_t dst_ndims = 2;
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
NDArray<int32_t> dst_ndarray = {.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
// Create the subscripts in `ndarray[-2::, 1::2]`
UserSlice subscript_1;
subscript_1.set_start(-2);
UserSlice subscript_2;
subscript_2.set_start(1);
subscript_2.set_step(2);
const int32_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
ErrorContext errctx = create_testing_errctx();
ndarray::indexing::index(&errctx, num_indexes, indexes, &src_ndarray,
&dst_ndarray);
assert_errctx_no_exception(&errctx);
int32_t expected_shape[dst_ndims] = {2, 2};
int32_t expected_strides[dst_ndims] = {32, 16};
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
// dst_ndarray[0, 0]
assert_values_match(5.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){0, 0})));
// dst_ndarray[0, 1]
assert_values_match(7.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){0, 1})));
// dst_ndarray[1, 0]
assert_values_match(9.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){1, 0})));
// dst_ndarray[1, 1]
assert_values_match(11.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){1, 1})));
}
void test_normal_2() {
/*
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[2, ::-2]
# array([11., 9.])
assert dst_ndarray.shape == (2,)
assert dst_ndarray.strides == (-16,)
assert dst_ndarray[0] == 11.0
assert dst_ndarray[1] == 9.0
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
int32_t src_itemsize = sizeof(double);
const int32_t src_ndims = 2;
int32_t src_shape[src_ndims] = {3, 4};
int32_t src_strides[src_ndims] = {};
NDArray<int32_t> src_ndarray = {.data = (uint8_t *)src_data,
.itemsize = src_itemsize,
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int32_t dst_ndims = 1;
int32_t dst_shape[dst_ndims] = {999}; // Empty values
int32_t dst_strides[dst_ndims] = {999}; // Empty values
NDArray<int32_t> dst_ndarray = {.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
// Create the subscripts in `ndarray[2, ::-2]`
int32_t subscript_1 = 2;
UserSlice subscript_2;
subscript_2.set_step(-2);
const int32_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
ErrorContext errctx = create_testing_errctx();
ndarray::indexing::index(&errctx, num_indexes, indexes, &src_ndarray,
&dst_ndarray);
assert_errctx_no_exception(&errctx);
int32_t expected_shape[dst_ndims] = {2};
int32_t expected_strides[dst_ndims] = {-16};
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
assert_values_match(11.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){0})));
assert_values_match(9.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){1})));
}
void test_index_subscript_out_of_bounds() {
/*
# Consider `my_array`
print(my_array.shape)
# (4, 5, 6)
my_array[2, 100] # error, index subscript at axis 1 is out of bounds
*/
BEGIN_TEST();
// Prepare src_ndarray
const int32_t src_ndims = 2;
int32_t src_shape[src_ndims] = {3, 4};
int32_t src_strides[src_ndims] = {};
NDArray<int32_t> src_ndarray = {
.data = (uint8_t *)nullptr, // placeholder, we wouldn't access it
.itemsize = sizeof(double), // placeholder
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Create the subscripts in `my_array[2, 100]`
int32_t subscript_1 = 2;
int32_t subscript_2 = 100;
const int32_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT,
.data = (uint8_t *)&subscript_2}};
// Prepare dst_ndarray
const int32_t dst_ndims = 0;
int32_t dst_shape[dst_ndims] = {};
int32_t dst_strides[dst_ndims] = {};
NDArray<int32_t> dst_ndarray = {.data = nullptr, // placehloder
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
ErrorContext errctx = create_testing_errctx();
ndarray::indexing::index(&errctx, num_indexes, indexes, &src_ndarray,
&dst_ndarray);
assert_errctx_has_exception(&errctx, errctx.exceptions->index_error);
}
void run() {
test_normal_1();
test_normal_2();
test_index_subscript_out_of_bounds();
}
} // namespace ndarray_indexing
} // namespace test

View File

@ -0,0 +1,92 @@
#pragma once
#include <irrt_everything.hpp>
#include <test/includes.hpp>
namespace test {
namespace slice {
void test_slice_normal() {
// Normal situation
BEGIN_TEST();
UserSlice user_slice;
user_slice.set_stop(5);
Slice slice = user_slice.indices(100);
printf("%d, %d, %d\n", slice.start, slice.stop, slice.step);
assert_values_match(0, slice.start);
assert_values_match(5, slice.stop);
assert_values_match(1, slice.step);
}
void test_slice_start_too_large() {
// Start is too large and should be clamped to length
BEGIN_TEST();
UserSlice user_slice;
user_slice.set_start(400);
Slice slice = user_slice.indices(100);
assert_values_match(100, slice.start);
assert_values_match(100, slice.stop);
assert_values_match(1, slice.step);
}
void test_slice_negative_start_stop() {
// Negative start/stop should be resolved
BEGIN_TEST();
UserSlice user_slice;
user_slice.set_start(-10);
user_slice.set_stop(-5);
Slice slice = user_slice.indices(100);
assert_values_match(90, slice.start);
assert_values_match(95, slice.stop);
assert_values_match(1, slice.step);
}
void test_slice_only_negative_step() {
// Things like `[::-5]` should be handled correctly
BEGIN_TEST();
UserSlice user_slice;
user_slice.set_step(-5);
Slice slice = user_slice.indices(100);
assert_values_match(99, slice.start);
assert_values_match(-1, slice.stop);
assert_values_match(-5, slice.step);
}
void test_slice_step_zero() {
// Step = 0 is a value error
BEGIN_TEST();
ErrorContext errctx = create_testing_errctx();
UserSlice user_slice;
user_slice.set_start(2);
user_slice.set_stop(12);
user_slice.set_step(0);
Slice slice;
user_slice.indices_checked(&errctx, 100, &slice);
assert_errctx_has_exception(&errctx, errctx.exceptions->value_error);
}
void run() {
test_slice_normal();
test_slice_start_too_large();
test_slice_negative_start_stop();
test_slice_only_negative_step();
test_slice_step_zero();
}
} // namespace slice
} // namespace test

188
nac3core/irrt/test/util.hpp Normal file
View File

@ -0,0 +1,188 @@
#pragma once
#include <cstdio>
#include <cstdlib>
template <class T>
void print_value(const T& value);
template <>
void print_value(const bool& value) {
printf("%s", value ? "true" : "false");
}
template <>
void print_value(const int8_t& value) {
printf("%d", value);
}
template <>
void print_value(const int32_t& value) {
printf("%d", value);
}
template <>
void print_value(const uint8_t& value) {
printf("%u", value);
}
template <>
void print_value(const uint32_t& value) {
printf("%u", value);
}
template <>
void print_value(const float& value) {
printf("%f", value);
}
template <>
void print_value(const double& value) {
printf("%f", value);
}
void __begin_test(const char* function_name, const char* file, int line) {
printf("######### Running %s @ %s:%d\n", function_name, file, line);
}
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
void test_fail() {
printf("[!] Test failed. Exiting with status code 1.\n");
exit(1);
}
template <typename T>
void debug_print_array(int len, const T* as) {
printf("[");
for (int i = 0; i < len; i++) {
if (i != 0) printf(", ");
print_value(as[i]);
}
printf("]");
}
void print_assertion_passed(const char* file, int line) {
printf("[*] Assertion passed on %s:%d\n", file, line);
}
void print_assertion_failed(const char* file, int line) {
printf("[!] Assertion failed on %s:%d\n", file, line);
}
void __assert_true(const char* file, int line, bool cond) {
if (cond) {
print_assertion_passed(file, line);
} else {
print_assertion_failed(file, line);
test_fail();
}
}
#define assert_true(cond) __assert_true(__FILE__, __LINE__, cond)
template <typename T>
void __assert_arrays_match(const char* file, int line, int len,
const T* expected, const T* got) {
if (arrays_match(len, expected, got)) {
print_assertion_passed(file, line);
} else {
print_assertion_failed(file, line);
printf("Expect = ");
debug_print_array(len, expected);
printf("\n");
printf(" Got = ");
debug_print_array(len, got);
printf("\n");
test_fail();
}
}
#define assert_arrays_match(len, expected, got) \
__assert_arrays_match(__FILE__, __LINE__, len, expected, got)
template <typename T>
void __assert_values_match(const char* file, int line, T expected, T got) {
if (expected == got) {
print_assertion_passed(file, line);
} else {
print_assertion_failed(file, line);
printf("Expect = ");
print_value(expected);
printf("\n");
printf(" Got = ");
print_value(got);
printf("\n");
test_fail();
}
}
#define assert_values_match(expected, got) \
__assert_values_match(__FILE__, __LINE__, expected, got)
// A fake set of ExceptionIds for testing only
const ErrorContextExceptions TEST_ERROR_CONTEXT_EXCEPTIONS = {
.index_error = 0,
.value_error = 1,
.assertion_error = 2,
.runtime_error = 3,
.type_error = 4,
};
ErrorContext create_testing_errctx() {
// Everything is global so it is fine to directly return a struct
// ErrorContext
ErrorContext errctx;
errctx.initialize(&TEST_ERROR_CONTEXT_EXCEPTIONS);
return errctx;
}
void print_errctx_content(ErrorContext* errctx) {
if (errctx->has_exception()) {
printf(
"(Exception ID %d): %s ... where param1 = %ld, param2 = %ld, "
"param3 = "
"%ld\n",
errctx->exception_id, errctx->msg, errctx->param1, errctx->param2,
errctx->param3);
} else {
printf("<no exception>\n");
}
}
void __assert_errctx_no_exception(const char* file, int line,
ErrorContext* errctx) {
if (errctx->has_exception()) {
print_assertion_failed(file, line);
printf("Expecting no exception but caught the following:\n\n");
print_errctx_content(errctx);
test_fail();
}
}
#define assert_errctx_no_exception(errctx) \
__assert_errctx_no_exception(__FILE__, __LINE__, errctx)
void __assert_errctx_has_exception(const char* file, int line,
ErrorContext* errctx,
ExceptionId expected_exception_id) {
if (errctx->has_exception()) {
if (errctx->exception_id != expected_exception_id) {
print_assertion_failed(file, line);
printf(
"Expecting exception id %d but got exception id %d. Error "
"caught:\n\n",
expected_exception_id, errctx->exception_id);
print_errctx_content(errctx);
test_fail();
}
} else {
print_assertion_failed(file, line);
printf("Expecting an exception, but there is none.");
test_fail();
}
}
#define assert_errctx_has_exception(errctx, expected_exception_id) \
__assert_errctx_has_exception(__FILE__, __LINE__, errctx, \
expected_exception_id)

View File

@ -661,90 +661,6 @@ pub fn call_min<'ctx>(
} }
} }
/// Invokes the `np_min` builtin function.
pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_min";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a;
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a
}
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx
.builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
"zero-size array to reduction operation minimum which has no identity",
[None, None, None],
ctx.current_loc,
);
}
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe {
let identity =
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
}
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(1, false),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
})
}
/// Invokes the `np_minimum` builtin function. /// Invokes the `np_minimum` builtin function.
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
@ -877,18 +793,20 @@ pub fn call_max<'ctx>(
} }
} }
/// Invokes the `np_max` builtin function. /// Invokes the `np_max`, `np_min`, `np_argmax`, `np_argmin` functions
pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( /// * `fn_name`: Can be one of `"np_argmin"`, `"np_argmax"`, `"np_max"`, `"np_min"`
pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>), a: (Type, BasicValueEnum<'ctx>),
fn_name: &str,
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_max"; debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name));
let llvm_int64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a; let (a_ty, a) = a;
Ok(match a { Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([ debug_assert!([
@ -902,9 +820,12 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
.iter() .iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty))); .any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a match fn_name {
"np_argmin" | "np_argmax" => llvm_int64.const_zero().into(),
"np_max" | "np_min" => a,
_ => unreachable!(),
}
} }
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
@ -923,41 +844,82 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
generator, generator,
n_sz_eqz, n_sz_eqz,
"0:ValueError", "0:ValueError",
"zero-size array to reduction operation minimum which has no identity", format!("zero-size array to reduction operation {fn_name}").as_str(),
[None, None, None], [None, None, None],
ctx.current_loc, ctx.current_loc,
); );
} }
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?;
unsafe { unsafe {
let identity = let identity =
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap(); ctx.builder.build_store(accumulator_addr, identity).unwrap();
ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap();
} }
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
llvm_usize.const_int(1, false), None,
llvm_int64.const_int(1, false),
(n_sz, false), (n_sz, false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)); let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();
let result = match fn_name {
"np_argmin" | "np_min" => {
call_min(ctx, (elem_ty, accumulator), (elem_ty, elem))
}
"np_argmax" | "np_max" => {
call_max(ctx, (elem_ty, accumulator), (elem_ty, elem))
}
_ => unreachable!(),
};
let updated_idx = match (accumulator, result) {
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(),
idx.into(),
cur_idx,
"",
)
.unwrap(),
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => ctx
.builder
.build_select(
ctx.builder
.build_float_compare(FloatPredicate::ONE, m, n, "")
.unwrap(),
idx.into(),
cur_idx,
"",
)
.unwrap(),
_ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]),
};
ctx.builder.build_store(res_idx, updated_idx).unwrap();
ctx.builder.build_store(accumulator_addr, result).unwrap(); ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(()) Ok(())
}, },
llvm_usize.const_int(1, false), llvm_int64.const_int(1, false),
)?; )?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); match fn_name {
accumulator "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(),
"np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(),
_ => unreachable!(),
}
} }
_ => unsupported_type(ctx, FN_NAME, &[a_ty]), _ => unsupported_type(ctx, fn_name, &[a_ty]),
}) })
} }

View File

@ -1717,6 +1717,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(len, false), (len, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {

View File

@ -1,10 +1,15 @@
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use super::{
irrt::slice::{RustUserSlice, SliceIndex},
numpy_new::object::{NDArrayObject, ScalarOrNDArray},
structure::ndarray::NpArray,
};
use crate::{ use crate::{
codegen::{ codegen::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType,
ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
}, },
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type, gen_in_range_check, get_llvm_abi_type, get_llvm_type,
@ -21,11 +26,7 @@ use crate::{
CodeGenContext, CodeGenTask, CodeGenerator, CodeGenContext, CodeGenTask, CodeGenerator,
}, },
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{ toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
typecheck::{ typecheck::{
magic_methods::{Binop, BinopVariant, HasOpInfo}, magic_methods::{Binop, BinopVariant, HasOpInfo},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
@ -39,8 +40,18 @@ use inkwell::{
}; };
use itertools::{chain, izip, Either, Itertools}; use itertools::{chain, izip, Either, Itertools};
use nac3parser::ast::{ use nac3parser::ast::{
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Located, Location, Operator,
Unaryop, StrRef, Unaryop,
};
use ndarray::indexing::RustNDIndex;
use super::{
model::*,
structure::{
cslice::CSlice,
exception::{Exception, ExceptionId},
},
}; };
pub fn get_subst_key( pub fn get_subst_key(
@ -281,24 +292,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
None None
} }
} }
Constant::Str(v) => { Constant::Str(s) => Some(self.gen_string(generator, s).value.into()),
assert!(self.unifier.unioned(ty, self.primitives.str));
if let Some(v) = self.const_strings.get(v) {
Some(*v)
} else {
let str_ptr = self
.builder
.build_global_string_ptr(v, "const")
.map(|v| v.as_pointer_value().into())
.unwrap();
let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false);
let ty = self.get_llvm_type(generator, self.primitives.str);
let val =
ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into();
self.const_strings.insert(v.to_string(), val);
Some(val)
}
}
Constant::Ellipsis => { Constant::Ellipsis => {
let msg = self.gen_string(generator, "NotImplementedError"); let msg = self.gen_string(generator, "NotImplementedError");
@ -560,96 +554,127 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
/// Helper function for generating a LLVM variable storing a [String]. /// Helper function for generating a LLVM variable storing a [String].
pub fn gen_string<G, S>(&mut self, generator: &mut G, s: S) -> BasicValueEnum<'ctx> pub fn gen_string<G>(&mut self, generator: &mut G, string: &str) -> Struct<'ctx, CSlice>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
S: Into<String>,
{ {
self.gen_const(generator, &Constant::Str(s.into()), self.primitives.str).unwrap() self.const_strings.get(string).copied().unwrap_or_else(|| {
let type_context = generator.type_context(self.ctx);
let sizet_model = IntModel(SizeT);
let pbyte_model = PtrModel(IntModel(Byte));
let cslice_model = StructModel(CSlice);
let base = self.builder.build_global_string_ptr(string, "constant_string").unwrap();
let base = pbyte_model.believe_value(base.as_pointer_value());
let len = sizet_model.constant(type_context, self.ctx, string.len() as u64);
let cslice = cslice_model.create_const(type_context, self, base, len);
self.const_strings.insert(string.to_owned(), cslice);
cslice
})
}
pub fn raise_exn_impl<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut G,
exn_id: Int<'ctx, ExceptionId>,
msg: Struct<'ctx, CSlice>,
params: [Option<Int<'ctx, Int64>>; 3],
loc: Location,
) {
let exn_model = StructModel(Exception);
let exn = self.exception_val.unwrap_or_else(|| {
let exn = exn_model.var_alloca(generator, self, Some("exn")).unwrap();
*self.exception_val.insert(exn)
});
exn.gep(self, |f| f.id).store(self, exn_id);
exn.gep(self, |f| f.msg).store(self, msg);
for (i, param) in params.iter().enumerate() {
if let Some(param) = param {
exn.gep(self, |f| f.params[i]).store(self, *param);
}
}
gen_raise(generator, self, Some(exn), loc);
} }
pub fn raise_exn<G: CodeGenerator + ?Sized>( pub fn raise_exn<G: CodeGenerator + ?Sized>(
&mut self, &mut self,
generator: &mut G, generator: &mut G,
name: &str, name: &str,
msg: BasicValueEnum<'ctx>, msg: Struct<'ctx, CSlice>,
params: [Option<IntValue<'ctx>>; 3], params: [Option<Int<'ctx, Int64>>; 3],
loc: Location, loc: Location,
) { ) {
let zelf = if let Some(exception_val) = self.exception_val { let tyctx = generator.type_context(self.ctx);
exception_val let exn_id_model = IntModel(ExceptionId::default());
} else {
let ty = self.get_llvm_type(generator, self.primitives.exception).into_pointer_type(); let exn_id = self.resolver.get_string_id(name);
let zelf_ty: BasicTypeEnum = ty.get_element_type().into_struct_type().into(); let exn_id = exn_id_model.constant(tyctx, self.ctx, exn_id as u64);
let zelf = generator.gen_var_alloc(self, zelf_ty, Some("exn")).unwrap();
*self.exception_val.insert(zelf) self.raise_exn_impl(generator, exn_id, msg, params, loc);
};
let int32 = self.ctx.i32_type();
let zero = int32.const_zero();
unsafe {
let id_ptr = self.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap();
let id = self.resolver.get_string_id(name);
self.builder.build_store(id_ptr, int32.const_int(id as u64, false)).unwrap();
let ptr = self
.builder
.build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg")
.unwrap();
self.builder.build_store(ptr, msg).unwrap();
let i64_zero = self.ctx.i64_type().const_zero();
for (i, attr_ind) in [6, 7, 8].iter().enumerate() {
let ptr = self
.builder
.build_in_bounds_gep(
zelf,
&[zero, int32.const_int(*attr_ind, false)],
"exn.param",
)
.unwrap();
let val = params[i].map_or(i64_zero, |v| {
self.builder.build_int_s_extend(v, self.ctx.i64_type(), "sext").unwrap()
});
self.builder.build_store(ptr, val).unwrap();
}
}
gen_raise(generator, self, Some(&zelf.into()), loc);
} }
pub fn make_assert<G: CodeGenerator + ?Sized>( pub fn make_assert<G: CodeGenerator + ?Sized>(
&mut self, &mut self,
generator: &mut G, generator: &mut G,
cond: IntValue<'ctx>, cond: IntValue<'ctx>, // IntType can have arbitrary bit width
err_name: &str, err_name: &str,
err_msg: &str, err_msg: &str,
params: [Option<IntValue<'ctx>>; 3], params: [Option<IntValue<'ctx>>; 3],
loc: Location, loc: Location,
) { ) {
let type_context = generator.type_context(self.ctx);
let param_model = IntModel(Int64);
let err_msg = self.gen_string(generator, err_msg); let err_msg = self.gen_string(generator, err_msg);
let ctx = self.ctx;
let params =
params.map(|p| p.map(|p| param_model.check_value(type_context, ctx, p).unwrap()));
self.make_assert_impl(generator, cond, err_name, err_msg, params, loc); self.make_assert_impl(generator, cond, err_name, err_msg, params, loc);
} }
pub fn make_assert_impl<G: CodeGenerator + ?Sized>( pub fn make_assert_impl<G: CodeGenerator + ?Sized>(
&mut self, &mut self,
generator: &mut G, generator: &mut G,
cond: IntValue<'ctx>, cond: IntValue<'ctx>, // IntType can have arbitrary bit width
err_name: &str, err_name: &str,
err_msg: BasicValueEnum<'ctx>, err_msg: Struct<'ctx, CSlice>,
params: [Option<IntValue<'ctx>>; 3], params: [Option<Int<'ctx, Int64>>; 3],
loc: Location, loc: Location,
) { ) {
let i1 = self.ctx.bool_type(); let type_context = generator.type_context(self.ctx);
let i1_true = i1.const_all_ones(); let bool_model = IntModel(Bool);
// we assume that the condition is most probably true, so the normal path is the most
// probable path // We assume that the condition is most probably true, so the normal path is the most
// even if this assumption is violated, it does not matter as exception unwinding is // probable path even if this assumption is violated, it does not matter as exception unwinding is.
// slow anyway... let cond = call_expect(
let cond = call_expect(self, cond, i1_true, Some("expect")); self,
generator.bool_to_i1(self, cond),
bool_model.const_true(type_context, self.ctx).value,
Some("expect"),
);
let current_bb = self.builder.get_insert_block().unwrap(); let current_bb = self.builder.get_insert_block().unwrap();
let current_fun = current_bb.get_parent().unwrap(); let current_fun = current_bb.get_parent().unwrap();
let then_block = self.ctx.insert_basic_block_after(current_bb, "succ"); let then_block = self.ctx.insert_basic_block_after(current_bb, "succ");
let exn_block = self.ctx.append_basic_block(current_fun, "fail"); let exn_block = self.ctx.append_basic_block(current_fun, "fail");
self.builder.build_conditional_branch(cond, then_block, exn_block).unwrap(); self.builder.build_conditional_branch(cond, then_block, exn_block).unwrap();
// Inserting into `exn_block`
self.builder.position_at_end(exn_block); self.builder.position_at_end(exn_block);
self.raise_exn(generator, err_name, err_msg, params, loc); self.raise_exn(generator, err_name, err_msg, params, loc);
// Continuation
self.builder.position_at_end(then_block); self.builder.position_at_end(then_block);
} }
} }
@ -951,9 +976,9 @@ pub fn destructure_range<'ctx>(
/// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting /// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting
/// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified. /// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified.
/// ///
/// Setting `ty` to [`None`] implies that the list does not have a known element type, which is only /// Setting `ty` to [`None`] implies that the list is empty **and** does not have a known element
/// valid for empty lists. It is undefined behavior to generate a sized list with an unknown element /// type, and will therefore set the `list.data` type as `size_t*`. It is undefined behavior to
/// type. /// generate a sized list with an unknown element type.
pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>( pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -1016,7 +1041,6 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap()); let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap());
let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
let list; let list;
let list_content;
if is_range { if is_range {
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
@ -1047,7 +1071,6 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
list_alloc_size.into_int_value(), list_alloc_size.into_int_value(),
Some("listcomp.addr"), Some("listcomp.addr"),
); );
list_content = list.data().base_ptr(ctx, generator);
let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap();
ctx.builder ctx.builder
@ -1083,10 +1106,10 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
) )
.into_int_value(); .into_int_value();
list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp"));
list_content = list.data().base_ptr(ctx, generator);
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?;
// counter = -1 // counter = -1
ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)).unwrap(); ctx.builder.build_store(counter, size_t.const_all_ones()).unwrap();
ctx.builder.build_unconditional_branch(test_bb).unwrap(); ctx.builder.build_unconditional_branch(test_bb).unwrap();
ctx.builder.position_at_end(test_bb); ctx.builder.position_at_end(test_bb);
@ -1105,7 +1128,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
) )
.into_pointer_value(); .into_pointer_value();
let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val"));
generator.gen_assign(ctx, target, val.into())?; generator.gen_assign(ctx, target, val.into(), ctx.primitives.int32)?;
} }
// Emits the content of `cont_bb` // Emits the content of `cont_bb`
@ -1143,7 +1166,8 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
return Ok(None); return Ok(None);
}; };
let i = ctx.builder.build_load(index, "i").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(index, "i").map(BasicValueEnum::into_int_value).unwrap();
let elem_ptr = unsafe { ctx.builder.build_gep(list_content, &[i], "elem_ptr") }.unwrap(); let elem_ptr =
unsafe { list.data().ptr_offset_unchecked(ctx, generator, &i, Some("elem_ptr")) };
let val = elem.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?; let val = elem.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?;
ctx.builder.build_store(elem_ptr, val).unwrap(); ctx.builder.build_store(elem_ptr, val).unwrap();
ctx.builder ctx.builder
@ -1226,6 +1250,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2)); debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2));
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1);
let sizeof_elem = llvm_elem_ty.size_of().unwrap();
let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None); let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
@ -1237,14 +1262,25 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None); let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None);
let lhs_len = ctx let lhs_size = ctx
.builder .builder
.build_int_mul(lhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "") .build_int_z_extend_or_bit_cast(
lhs.load_size(ctx, None),
sizeof_elem.get_type(),
"",
)
.unwrap(); .unwrap();
let rhs_len = ctx let lhs_len = ctx.builder.build_int_mul(lhs_size, sizeof_elem, "").unwrap();
let rhs_size = ctx
.builder .builder
.build_int_mul(rhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "") .build_int_z_extend_or_bit_cast(
rhs.load_size(ctx, None),
sizeof_elem.get_type(),
"",
)
.unwrap(); .unwrap();
let rhs_len = ctx.builder.build_int_mul(rhs_size, sizeof_elem, "").unwrap();
let list_ptr = new_list.data().base_ptr(ctx, generator); let list_ptr = new_list.data().base_ptr(ctx, generator);
call_memcpy_generic( call_memcpy_generic(
@ -1309,6 +1345,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None); let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None);
let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty);
let sizeof_elem = elem_llvm_ty.size_of().unwrap();
let new_list = allocate_list( let new_list = allocate_list(
generator, generator,
@ -1321,6 +1358,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(int_val, false), (int_val, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -1332,15 +1370,18 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None) new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None)
}; };
let memcpy_sz = ctx let list_size = ctx
.builder .builder
.build_int_mul( .build_int_z_extend_or_bit_cast(
list_val.load_size(ctx, None), list_val.load_size(ctx, None),
elem_llvm_ty.size_of().unwrap(), sizeof_elem.get_type(),
"", "",
) )
.unwrap(); .unwrap();
let memcpy_sz =
ctx.builder.build_int_mul(list_size, sizeof_elem, "").unwrap();
call_memcpy_generic( call_memcpy_generic(
ctx, ctx,
ptr, ptr,
@ -1928,6 +1969,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(left_val.load_size(ctx, None), false), (left_val.load_size(ctx, None), false),
|generator, ctx, hooks, i| { |generator, ctx, hooks, i| {
@ -2088,324 +2130,98 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
) )
} }
/// Generates code for a subscript expression on an `ndarray`. pub fn gen_ndarray_subscript_ndindexes<'ctx, G: CodeGenerator>(
///
/// * `ty` - The `Type` of the `NDArray` elements.
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `v` - The `NDArray` value.
/// * `slice` - The slice expression used to subscript into the `ndarray`.
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type, subscript: &Expr<Option<Type>>,
ndims: Type, ) -> Result<Vec<RustNDIndex<'ctx>>, String> {
v: NDArrayValue<'ctx>, // TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
slice: &Expr<Option<Type>>, let tyctx = generator.type_context(ctx.ctx);
) -> Result<Option<ValueEnum<'ctx>>, String> { let slice_index_model = IntModel(SliceIndex::default());
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { // Annoying notes about `slice`
unreachable!() // - `my_array[5]`
// - slice is a `Constant`
// - `my_array[:5]`
// - slice is a `Slice`
// - `my_array[:]`
// - slice is a `Slice`, but lower upper step would all be `Option::None`
// - `my_array[:, :]`
// - slice is now a `Tuple` of two `Slice`-s
//
// In summary:
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
//
// So we first "flatten" out the slice expression
let index_exprs = match &subscript.node {
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
_ => vec![subscript],
}; };
let ndims = values // Process all index expressions
.iter() let mut rust_ndindexes: Vec<RustNDIndex> = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here.
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) for index_expr in index_exprs {
.collect::<Result<Vec<_>, _>>() // NOTE: Currently nac3core's slices do not have an object representation,
.map_err(|val| { // so the code/implementation looks awkward - we have to do pattern matching on the expression
format!( let ndindex = if let ExprKind::Slice { lower: start, upper: stop, step } = &index_expr.node
"Expected non-negative literal for ndarray.ndims, got {}", {
i128::try_from(val).unwrap() // Helper function here to deduce code duplication
) type ValueExpr = Option<Box<Located<ExprKind<Option<Type>>, Option<Type>>>>;
})?; let mut help = |value_expr: &ValueExpr| -> Result<_, String> {
Ok(match value_expr {
None => None,
Some(value_expr) => {
let value_expr = generator
.gen_expr(ctx, value_expr)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
assert!(!ndims.is_empty()); let value_expr =
slice_index_model.check_value(tyctx, ctx.ctx, value_expr).unwrap();
// The number of dimensions subscripted by the index expression. Some(value_expr)
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a }
// dimension will remove a dimension. })
let subscripted_dims = match &slice.node {
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
if let ExprKind::Slice { .. } = &value_subexpr.node {
acc
} else {
acc + 1
}
}),
ExprKind::Slice { .. } => 0,
_ => 1,
};
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None,
);
let ndarray_ty =
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
// Check that len is non-zero
let len = v.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(),
"0:IndexError",
"too many indices for array: array is {0}-dimensional but 1 were indexed",
[Some(len), None, None],
slice.location,
);
// Normalizes a possibly-negative index to its corresponding positive index
let normalize_index = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
dim: u64| {
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "")
.unwrap())
},
|_, _| Ok(Some(index)),
|generator, ctx| {
let llvm_i32 = ctx.ctx.i32_type();
let len = unsafe {
v.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, true),
None,
)
};
let index = ctx
.builder
.build_int_add(
len,
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(),
"",
)
.unwrap();
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
},
)
.map(|v| v.map(BasicValueEnum::into_int_value))
};
// Converts a slice expression into a slice-range tuple
let expr_to_slice = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
node: &ExprKind<Option<Type>>,
dim: u64| {
match node {
ExprKind::Constant { value: Constant::Int(v), .. } => {
let Some(index) =
normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)?
else {
return Ok(None);
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
ExprKind::Slice { lower, upper, step } => {
let dim_sz = unsafe {
v.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, false),
None,
)
};
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz)
}
_ => {
let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) };
let index = index
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
return Ok(None);
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
}
};
let make_indices_arr = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>|
-> Result<_, String> {
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(elts.len() as u64, false),
None,
)?;
for (i, elt) in elts.iter().enumerate() {
let Some(index) = generator.gen_expr(ctx, elt)? else {
return Ok(None);
};
let index = index
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
return Ok(None);
};
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
None,
)
};
ctx.builder.build_store(store_ptr, index).unwrap();
}
Some(index_addr)
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(1u64, false),
None,
)?;
let index =
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}; };
ctx.builder.build_store(store_ptr, index).unwrap();
Some(index_addr) let start = help(start)?;
let stop = help(stop)?;
let step = help(step)?;
RustNDIndex::Slice(RustUserSlice { start, stop, step })
} else { } else {
None // Anything else that is not a slice (might be illegal values),
}) // For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
};
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 { let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; ctx,
generator,
ctx.primitives.int32,
)?;
let index = slice_index_model.check_value(tyctx, ctx.ctx, index).unwrap();
v.data().get(ctx, generator, &index_addr, None).into() RustNDIndex::SingleElement(index)
} else { };
match &slice.node { rust_ndindexes.push(ndindex);
ExprKind::Tuple { elts, .. } => { }
let slices = elts Ok(rust_ndindexes)
.iter() }
.enumerate()
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
.collect::<Result<Vec<_>, _>>()?;
if slices.len() < elts.len() {
return Ok(None);
}
let slices = slices.into_iter().map(Option::unwrap).collect_vec(); /// Generates code for a subscript expression on an `ndarray`.
///
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into() /// * `elem_ty` - The `Type` of the `NDArray` elements.
} /// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `src_ndarray` - The `NDArray` value.
ExprKind::Slice { .. } => { /// * `subscript` - The subscript expression used to index into the `ndarray`.
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { pub fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
return Ok(None); generator: &mut G,
}; ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into() subscript: &Expr<Option<Type>>,
} ) -> Result<ScalarOrNDArray<'ctx>, String> {
let indexes = gen_ndarray_subscript_ndindexes(generator, ctx, subscript)?;
_ => { Ok(ndarray.index(generator, ctx, &indexes, "subndarray"))
// Accessing an element from a multi-dimensional `ndarray`
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray =
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
.unwrap(),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr,
ctx.builder
.build_int_mul(
ndarray_num_elems,
llvm_ndarray_data_t.size_of().unwrap(),
"",
)
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
ndarray.as_base_value().into()
}
}
}))
} }
/// See [`CodeGenerator::gen_expr`]. /// See [`CodeGenerator::gen_expr`].
@ -2457,7 +2273,29 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()), Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()),
None => { None => {
let resolver = ctx.resolver.clone(); let resolver = ctx.resolver.clone();
resolver.get_symbol_value(*id, ctx).unwrap() if let Some(res) = resolver.get_symbol_value(*id, ctx) {
res
} else {
// Allow "raise Exception" short form
let def_id = resolver.get_identifier_def(*id).map_err(|e| {
format!("{} (at {})", e.iter().next().unwrap(), expr.location)
})?;
let def = ctx.top_level.definitions.read();
if let TopLevelDef::Class { constructor, .. } = *def[def_id.0].read() {
let TypeEnum::TFunc(signature) =
ctx.unifier.get_ty(constructor.unwrap()).as_ref().clone()
else {
return Err(format!(
"Failed to resolve symbol {} (at {})",
id, expr.location
));
};
return Ok(generator
.gen_call(ctx, None, (&signature, def_id), Vec::default())?
.map(Into::into));
}
return Err(format!("Failed to resolve symbol {} (at {})", id, expr.location));
}
} }
}, },
ExprKind::List { elts, .. } => { ExprKind::List { elts, .. } => {
@ -3026,17 +2864,22 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
} }
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); let tyctx = generator.type_context(ctx.ctx);
let pndarray_model = PtrModel(StructModel(NpArray));
let v = if let Some(v) = generator.gen_expr(ctx, value)? { let (&dtype, &ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
.into_pointer_value() let Some(ndarray) = generator.gen_expr(ctx, value)? else {
} else {
return Ok(None); return Ok(None);
}; };
let v = NDArrayValue::from_ptr_val(v, usize, None);
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); let ndarray =
ndarray.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
let ndarray = pndarray_model.check_value(tyctx, ctx.ctx, ndarray).unwrap();
let ndarray = NDArrayObject { dtype, ndims, instance: ndarray };
let result = gen_ndarray_subscript_expr(generator, ctx, ndarray, slice)?;
return Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())));
} }
TypeEnum::TTuple { .. } => { TypeEnum::TTuple { .. } => {
let index: u32 = let index: u32 =

View File

@ -123,11 +123,12 @@ pub trait CodeGenerator {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> ) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {
gen_assign(self, ctx, target, value) gen_assign(self, ctx, target, value, value_ty)
} }
/// Generate code for a while expression. /// Generate code for a while expression.

View File

@ -0,0 +1,198 @@
use super::util::{function::CallFunction, get_sizet_dependent_function_name};
use crate::codegen::{
model::*,
structure::{cslice::CSlice, exception::ExceptionId},
CodeGenContext, CodeGenerator,
};
#[allow(clippy::struct_field_names)]
pub struct ErrorContextExceptionsFields<F: FieldVisitor> {
pub index_error: F::Field<IntModel<ExceptionId>>,
pub value_error: F::Field<IntModel<ExceptionId>>,
pub assertion_error: F::Field<IntModel<ExceptionId>>,
pub runtime_error: F::Field<IntModel<ExceptionId>>,
pub type_error: F::Field<IntModel<ExceptionId>>,
}
/// Corresponds to IRRT's `struct ErrorContextExceptions`
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ErrorContextExceptions;
impl StructKind for ErrorContextExceptions {
type Fields<F: FieldVisitor> = ErrorContextExceptionsFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields {
index_error: visitor.add("index_error"),
value_error: visitor.add("value_error"),
assertion_error: visitor.add("assertion_error"),
runtime_error: visitor.add("runtime_error"),
type_error: visitor.add("type_error"),
}
}
}
pub struct ErrorContextFields<F: FieldVisitor> {
pub exceptions: F::Field<PtrModel<StructModel<ErrorContextExceptions>>>,
pub exception_id: F::Field<IntModel<ExceptionId>>,
pub msg: F::Field<PtrModel<IntModel<Byte>>>,
pub param1: F::Field<IntModel<Int64>>,
pub param2: F::Field<IntModel<Int64>>,
pub param3: F::Field<IntModel<Int64>>,
}
/// Corresponds to IRRT's `struct ErrorContext`
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ErrorContext;
impl StructKind for ErrorContext {
type Fields<F: FieldVisitor> = ErrorContextFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields {
exceptions: visitor.add("exceptions"),
exception_id: visitor.add("exception_id"),
msg: visitor.add("msg"),
param1: visitor.add("param1"),
param2: visitor.add("param2"),
param3: visitor.add("param3"),
}
}
}
/// Build an [`ErrorContextExceptions`] loaded with resolved [`ExceptionID`]s according to the [`SymbolResolver`].
fn build_error_context_exceptions<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
) -> Ptr<'ctx, StructModel<ErrorContextExceptions>> {
let exceptions =
StructModel(ErrorContextExceptions).alloca(tyctx, ctx, "error_context_exceptions");
let i32_model = IntModel(Int32);
let get_string_id = |string_id| {
i32_model.constant(tyctx, ctx.ctx, ctx.resolver.get_string_id(string_id) as u64)
};
exceptions.gep(ctx, |f| f.index_error).store(ctx, get_string_id("0:IndexError"));
exceptions.gep(ctx, |f| f.value_error).store(ctx, get_string_id("0:ValueError"));
exceptions.gep(ctx, |f| f.assertion_error).store(ctx, get_string_id("0:AssertionError"));
exceptions.gep(ctx, |f| f.runtime_error).store(ctx, get_string_id("0:RuntimeError"));
exceptions.gep(ctx, |f| f.type_error).store(ctx, get_string_id("0:TypeError"));
exceptions
}
pub fn call_nac3_error_context_initialize<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
perrctx: Ptr<'ctx, StructModel<ErrorContext>>,
pexceptions: Ptr<'ctx, StructModel<ErrorContextExceptions>>,
) {
CallFunction::begin(tyctx, ctx, "__nac3_error_context_initialize")
.arg("errctx", perrctx)
.arg("exceptions", pexceptions)
.returning_void();
}
pub fn call_nac3_error_context_has_exception<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
perrctx: Ptr<'ctx, StructModel<ErrorContext>>,
) -> Int<'ctx, Bool> {
CallFunction::begin(tyctx, ctx, "__nac3_error_context_has_exception")
.arg("errctx", perrctx)
.returning("has_exception")
}
pub fn call_nac3_error_context_get_exception_str<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
perrctx: Ptr<'ctx, StructModel<ErrorContext>>,
dst_str: Ptr<'ctx, StructModel<CSlice>>,
) {
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_error_context_get_exception_str"),
)
.arg("errctx", perrctx)
.arg("dst_str", dst_str)
.returning_void();
}
/// Setup a [`ErrorContext`] that could be passed to IRRT functions taking in a `ErrorContext* errctx`
/// for error reporting purposes.
///
/// Also see: [`check_error_context`]
pub fn setup_error_context<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
) -> Ptr<'ctx, StructModel<ErrorContext>> {
let errctx_model = StructModel(ErrorContext);
let exceptions = build_error_context_exceptions(tyctx, ctx);
let errctx_ptr = errctx_model.alloca(tyctx, ctx, "errctx");
call_nac3_error_context_initialize(tyctx, ctx, errctx_ptr, exceptions);
errctx_ptr
}
/// Check a [`ErrorContext`] to see if it contains error. **If there is an error,
/// a Pythonic exception will be raised in the firmware**.
pub fn check_error_context<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
perrctx: Ptr<'ctx, StructModel<ErrorContext>>,
) {
let tyctx = generator.type_context(ctx.ctx);
let cslice_model = StructModel(CSlice);
let current_bb = ctx.builder.get_insert_block().unwrap();
let irrt_has_exception_bb = ctx.ctx.insert_basic_block_after(current_bb, "irrt_has_exception");
let end_bb = ctx.ctx.insert_basic_block_after(irrt_has_exception_bb, "end");
// Inserting into `current_bb`
let has_exception = call_nac3_error_context_has_exception(tyctx, ctx, perrctx);
ctx.builder
.build_conditional_branch(has_exception.value, irrt_has_exception_bb, end_bb)
.unwrap();
// Inserting into `irrt_has_exception_bb`
ctx.builder.position_at_end(irrt_has_exception_bb);
// Load all the values for `ctx.make_assert_impl_by_id`
let pexception_str = cslice_model.alloca(tyctx, ctx, "exception_str");
call_nac3_error_context_get_exception_str(tyctx, ctx, perrctx, pexception_str);
let exception_id = perrctx.gep(ctx, |f| f.exception_id).load(tyctx, ctx, "exception_id");
let msg = pexception_str.load(tyctx, ctx, "msg");
let param1 = perrctx.gep(ctx, |f| f.param1).load(tyctx, ctx, "param1");
let param2 = perrctx.gep(ctx, |f| f.param2).load(tyctx, ctx, "param2");
let param3 = perrctx.gep(ctx, |f| f.param3).load(tyctx, ctx, "param3");
ctx.raise_exn_impl(
generator,
exception_id,
msg,
[Some(param1), Some(param2), Some(param3)],
ctx.current_loc,
);
// Position to `end_bb` for continuation
ctx.builder.position_at_end(end_bb);
}
pub fn call_nac3_dummy_raise<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext,
) {
let tyctx = generator.type_context(ctx.ctx);
let errctx = setup_error_context(tyctx, ctx);
CallFunction::begin(tyctx, ctx, "__nac3_error_dummy_raise")
.arg("errctx", errctx)
.returning_void();
check_error_context(generator, ctx, errctx);
}

View File

@ -1,5 +1,12 @@
use crate::typecheck::typedef::Type; use crate::typecheck::typedef::Type;
pub mod error_context;
pub mod ndarray;
pub mod slice;
mod test;
mod util;
use super::model::*;
use super::{ use super::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
@ -414,14 +421,29 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
.unwrap(); .unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap(); let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap(); let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert(
generator, // TODO: Temporary fix. Rewrite `list_slice_assignment` later
cond, // Exception params should have been i64
"0:ValueError", {
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}", let type_context = generator.type_context(ctx.ctx);
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)], let param_model = IntModel(Int64);
ctx.current_loc,
); let src_slice_len =
param_model.s_extend_or_bit_cast(type_context, ctx, src_slice_len, "src_slice_len");
let dest_slice_len =
param_model.s_extend_or_bit_cast(type_context, ctx, dest_slice_len, "dest_slice_len");
let dest_idx_2 =
param_model.s_extend_or_bit_cast(type_context, ctx, dest_idx.2, "dest_idx_2");
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len.value), Some(dest_slice_len.value), Some(dest_idx_2.value)],
ctx.current_loc,
);
}
let new_len = { let new_len = {
let args = vec![ let args = vec![
@ -798,6 +820,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(min_ndims, false), (min_ndims, false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {

View File

@ -0,0 +1,153 @@
use crate::codegen::irrt::error_context::{check_error_context, setup_error_context};
use crate::codegen::irrt::slice::SliceIndex;
use crate::codegen::irrt::util::function::CallFunction;
use crate::codegen::irrt::util::get_sizet_dependent_function_name;
use crate::codegen::model::*;
use crate::codegen::structure::ndarray::NpArray;
use crate::codegen::{CodeGenContext, CodeGenerator};
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: Ptr<'ctx, StructModel<NpArray>>,
) -> Int<'ctx, SizeT> {
let tyctx = generator.type_context(ctx.ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_size"),
)
.arg("ndarray", ndarray_ptr)
.returning("size")
}
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: Ptr<'ctx, StructModel<NpArray>>,
) -> Int<'ctx, SizeT> {
let tyctx = generator.type_context(ctx.ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_nbytes"),
)
.arg("ndarray", ndarray_ptr)
.returning("nbytes")
}
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: Ptr<'ctx, StructModel<NpArray>>,
) -> Int<'ctx, SliceIndex> {
let tyctx = generator.type_context(ctx.ctx);
let slice_index_model = IntModel(SliceIndex::default());
let dst_len = slice_index_model.alloca(tyctx, ctx, "dst_len");
let errctx = setup_error_context(tyctx, ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_len"),
)
.arg("errctx", errctx)
.arg("ndarray", ndarray_ptr)
.arg("dst_len", dst_len)
.returning_void();
check_error_context(generator, ctx, errctx);
dst_len.load(tyctx, ctx, "len")
}
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: Int<'ctx, SizeT>,
shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let tyctx = generator.type_context(ctx.ctx);
let errctx = setup_error_context(tyctx, ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_util_assert_shape_no_negative"),
)
.arg("errctx", errctx)
.arg("ndims", ndims)
.arg("shape", shape)
.returning_void();
check_error_context(generator, ctx, errctx);
}
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: Ptr<'ctx, StructModel<NpArray>>,
) {
let tyctx = generator.type_context(ctx.ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_set_strides_by_shape"),
)
.arg("ndarray", ndarray_ptr)
.returning_void();
}
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: Ptr<'ctx, StructModel<NpArray>>,
) -> Int<'ctx, Bool> {
let tyctx = generator.type_context(ctx.ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_is_c_contiguous"),
)
.arg("ndarray", ndarray_ptr)
.returning("is_c_contiguous")
}
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NpArray>>,
) {
let tyctx = generator.type_context(ctx.ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_copy_data"),
)
.arg("src_ndarray", src_ndarray)
.arg("dst_ndarray", dst_ndarray)
.returning_void();
}
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NpArray>>,
index: Int<'ctx, SizeT>,
) -> Ptr<'ctx, IntModel<Byte>> {
let tyctx = generator.type_context(ctx.ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_get_nth_pelement"),
)
.arg("ndarray", pndarray)
.arg("index", index)
.returning("pelement")
}

View File

@ -0,0 +1,74 @@
use crate::codegen::{
irrt::{
error_context::{check_error_context, setup_error_context},
util::{function::CallFunction, get_sizet_dependent_function_name},
},
model::*,
structure::ndarray::NpArray,
CodeGenContext, CodeGenerator,
};
pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NpArray>>,
) {
let tyctx = generator.type_context(ctx.ctx);
let perrctx = setup_error_context(tyctx, ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_broadcast_to"),
)
.arg("errctx", perrctx)
.arg("src_ndarray", src_ndarray)
.arg("dst_ndarray", dst_ndarray)
.returning_void();
check_error_context(generator, ctx, perrctx);
}
/// Fields of [`ShapeEntry`]
pub struct ShapeEntryFields<F: FieldVisitor> {
pub ndims: F::Field<IntModel<SizeT>>,
pub shape: F::Field<PtrModel<IntModel<SizeT>>>,
}
#[derive(Debug, Clone, Copy, Default)]
pub struct ShapeEntry;
impl StructKind for ShapeEntry {
type Fields<F: FieldVisitor> = ShapeEntryFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields { ndims: visitor.add("ndims"), shape: visitor.add("shape") }
}
}
pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
num_shape_entries: Int<'ctx, SizeT>,
shape_entries: Ptr<'ctx, StructModel<ShapeEntry>>,
dst_ndims: Int<'ctx, SizeT>,
dst_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let tyctx = generator.type_context(ctx.ctx);
let perrctx = setup_error_context(tyctx, ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_broadcast_shapes"),
)
.arg("errctx", perrctx)
.arg("num_shapes", num_shape_entries)
.arg("shapes", shape_entries)
.arg("dst_ndims", dst_ndims)
.arg("dst_shape", dst_shape)
.returning_void();
check_error_context(generator, ctx, perrctx);
}

View File

@ -0,0 +1,170 @@
use crate::codegen::{
irrt::{
error_context::{check_error_context, setup_error_context},
slice::{RustUserSlice, SliceIndex, UserSlice},
util::{function::CallFunction, get_sizet_dependent_function_name},
},
model::*,
structure::ndarray::NpArray,
CodeGenContext, CodeGenerator,
};
pub type NDIndexType = Byte;
#[derive(Debug, Clone, Copy)]
pub struct NDIndexFields<F: FieldVisitor> {
pub type_: F::Field<IntModel<NDIndexType>>, // Defined to be uint8_t in IRRT
pub data: F::Field<PtrModel<IntModel<Byte>>>,
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NDIndex;
impl StructKind for NDIndex {
type Fields<F: FieldVisitor> = NDIndexFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields { type_: visitor.add("type"), data: visitor.add("data") }
}
}
// An enum variant to store the content
// and type of an NDIndex in high level.
#[derive(Debug, Clone)]
pub enum RustNDIndex<'ctx> {
SingleElement(Int<'ctx, SliceIndex>),
Slice(RustUserSlice<'ctx>),
}
impl<'ctx> RustNDIndex<'ctx> {
fn get_type_id(&self) -> u64 {
// Defined in IRRT, must be in sync
match self {
RustNDIndex::SingleElement(_) => 0,
RustNDIndex::Slice(_) => 1,
}
}
fn write_to_ndindex(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
dst_ndindex_ptr: Ptr<'ctx, StructModel<NDIndex>>,
) {
let ndindex_type_model = IntModel(NDIndexType::default());
let slice_index_model = IntModel(SliceIndex::default());
let user_slice_model = StructModel(UserSlice);
// Set `dst_ndindex_ptr->type`
dst_ndindex_ptr
.gep(ctx, |f| f.type_)
.store(ctx, ndindex_type_model.constant(tyctx, ctx.ctx, self.get_type_id()));
// Set `dst_ndindex_ptr->data`
let data = match self {
RustNDIndex::SingleElement(in_index) => {
let index_ptr = slice_index_model.alloca(tyctx, ctx, "index");
index_ptr.store(ctx, *in_index);
index_ptr.transmute(tyctx, ctx, IntModel(Byte), "")
}
RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr = user_slice_model.alloca(tyctx, ctx, "user_slice");
in_rust_slice.write_to_user_slice(tyctx, ctx, user_slice_ptr);
user_slice_ptr.transmute(tyctx, ctx, IntModel(Byte), "")
}
};
dst_ndindex_ptr.gep(ctx, |f| f.data).store(ctx, data);
}
/// Allocate an array of `NDIndex`es on the stack and return its stack pointer.
pub fn alloca_ndindexes(
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
in_ndindexes: &[RustNDIndex<'ctx>],
) -> (Int<'ctx, SizeT>, Ptr<'ctx, StructModel<NDIndex>>) {
let sizet_model = IntModel(SizeT);
let ndindex_model = StructModel(NDIndex);
let num_ndindexes = sizet_model.constant(tyctx, ctx.ctx, in_ndindexes.len() as u64);
let ndindexes = ndindex_model.array_alloca(tyctx, ctx, num_ndindexes.value, "ndindexes");
for (i, in_ndindex) in in_ndindexes.iter().enumerate() {
let i = sizet_model.constant(tyctx, ctx.ctx, i as u64);
let pndindex = ndindexes.offset(tyctx, ctx, i.value, "");
in_ndindex.write_to_ndindex(tyctx, ctx, pndindex);
}
(num_ndindexes, ndindexes)
}
#[must_use]
pub fn deduce_ndims_after_indexing(indices: &[RustNDIndex], original_ndims: u64) -> u64 {
let mut final_ndims = original_ndims;
for index in indices {
match index {
RustNDIndex::SingleElement(_) => {
final_ndims -= 1;
}
RustNDIndex::Slice(_) => {}
}
}
final_ndims
}
}
pub fn call_nac3_ndarray_indexing_deduce_ndims_after_indexing<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: Int<'ctx, SizeT>,
num_ndindexes: Int<'ctx, SizeT>,
ndindexs: Ptr<'ctx, StructModel<NDIndex>>,
) -> Int<'ctx, SizeT> {
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let pfinal_ndims = sizet_model.alloca(tyctx, ctx, "pfinal_ndims");
let errctx_ptr = setup_error_context(tyctx, ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(
tyctx,
"__nac3_ndarray_indexing_deduce_ndims_after_indexing",
),
)
.arg("errctx", errctx_ptr)
.arg("result", pfinal_ndims)
.arg("ndims", ndims)
.arg("num_ndindexs", num_ndindexes)
.arg("ndindexs", ndindexs)
.returning_void();
check_error_context(generator, ctx, errctx_ptr);
pfinal_ndims.load(tyctx, ctx, "final_ndims")
}
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
num_indexes: Int<'ctx, SizeT>,
indexes: Ptr<'ctx, StructModel<NDIndex>>,
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NpArray>>,
) {
let tyctx = generator.type_context(ctx.ctx);
let perrctx = setup_error_context(tyctx, ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_index"),
)
.arg("errctx", perrctx)
.arg("num_indexes", num_indexes)
.arg("indexes", indexes)
.arg("src_ndarray", src_ndarray)
.arg("dst_ndarray", dst_ndarray)
.returning_void();
check_error_context(generator, ctx, perrctx);
}

View File

@ -0,0 +1,4 @@
pub mod basic;
pub mod broadcast;
pub mod indexing;
pub mod reshape;

View File

@ -0,0 +1,31 @@
use crate::codegen::{
irrt::{
error_context::{check_error_context, setup_error_context},
util::{function::CallFunction, get_sizet_dependent_function_name},
},
model::*,
CodeGenContext, CodeGenerator,
};
pub fn call_nac3_ndarray_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: Int<'ctx, SizeT>,
new_ndims: Int<'ctx, SizeT>,
new_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let tyctx = generator.type_context(ctx.ctx);
let perrctx = setup_error_context(tyctx, ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_resolve_and_check_new_shape"),
)
.arg("errctx", perrctx)
.arg("size", size)
.arg("new_ndims", new_ndims)
.arg("new_shape", new_shape)
.returning_void();
check_error_context(generator, ctx, perrctx);
}

View File

@ -0,0 +1,81 @@
use crate::codegen::{model::*, CodeGenContext};
// nac3core's slicing index/length values are always int32_t
pub type SliceIndex = Int32;
#[derive(Debug, Clone)]
pub struct UserSliceFields<F: FieldVisitor> {
pub start_defined: F::Field<IntModel<Bool>>,
pub start: F::Field<IntModel<SliceIndex>>,
pub stop_defined: F::Field<IntModel<Bool>>,
pub stop: F::Field<IntModel<SliceIndex>>,
pub step_defined: F::Field<IntModel<Bool>>,
pub step: F::Field<IntModel<SliceIndex>>,
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UserSlice;
impl StructKind for UserSlice {
type Fields<F: FieldVisitor> = UserSliceFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields {
start_defined: visitor.add("start_defined"),
start: visitor.add("start"),
stop_defined: visitor.add("stop_defined"),
stop: visitor.add("stop"),
step_defined: visitor.add("step_defined"),
step: visitor.add("step"),
}
}
}
#[derive(Debug, Clone)]
pub struct RustUserSlice<'ctx> {
pub start: Option<Int<'ctx, SliceIndex>>,
pub stop: Option<Int<'ctx, SliceIndex>>,
pub step: Option<Int<'ctx, SliceIndex>>,
}
impl<'ctx> RustUserSlice<'ctx> {
// Set the values of an LLVM UserSlice
// in the format of Python's `slice()`
pub fn write_to_user_slice(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
dst_slice_ptr: Ptr<'ctx, StructModel<UserSlice>>,
) {
let bool_model = IntModel(Bool);
let false_ = bool_model.constant(tyctx, ctx.ctx, 0);
let true_ = bool_model.constant(tyctx, ctx.ctx, 1);
// TODO: Code duplication. Probably okay...?
match self.start {
Some(start) => {
dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start);
}
None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_),
}
match self.stop {
Some(stop) => {
dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop);
}
None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_),
}
match self.step {
Some(step) => {
dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step);
}
None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_),
}
}
}

View File

@ -0,0 +1,26 @@
#[cfg(test)]
mod tests {
use std::{path::Path, process::Command};
#[test]
fn run_irrt_test() {
assert!(
cfg!(feature = "test"),
"Please do `cargo test -F test` to compile `irrt_test.out` and run test"
);
let irrt_test_out_path = Path::new(concat!(env!("OUT_DIR"), "/irrt_test.out"));
let output = Command::new(irrt_test_out_path.to_str().unwrap()).output().unwrap();
if !output.status.success() {
eprintln!("irrt_test failed with status {}:", output.status);
eprintln!("====== stdout ======");
eprintln!("{}", String::from_utf8(output.stdout).unwrap());
eprintln!("====== stderr ======");
eprintln!("{}", String::from_utf8(output.stderr).unwrap());
eprintln!("====================");
panic!("irrt_test failed");
}
}
}

View File

@ -0,0 +1,103 @@
use crate::codegen::model::*;
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
#[must_use]
pub fn get_sizet_dependent_function_name(tyctx: TypeContext<'_>, name: &str) -> String {
let mut name = name.to_owned();
match tyctx.size_type.get_bit_width() {
32 => {}
64 => name.push_str("64"),
bit_width => {
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
}
}
name
}
pub mod function {
use crate::codegen::{model::*, CodeGenContext};
use inkwell::{
types::{BasicMetadataTypeEnum, BasicType, FunctionType},
values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue},
};
use itertools::Itertools;
#[derive(Debug, Clone, Copy)]
struct Arg<'ctx> {
ty: BasicMetadataTypeEnum<'ctx>,
val: BasicMetadataValueEnum<'ctx>,
}
/// Helper structure to reduce IRRT Inkwell function call boilerplate
/// TODO: Optimize
pub struct CallFunction<'ctx, 'a, 'b, 'c> {
tyctx: TypeContext<'ctx>,
ctx: &'b CodeGenContext<'ctx, 'a>,
/// Function name
name: &'c str,
/// Call arguments
args: Vec<Arg<'ctx>>,
}
impl<'ctx, 'a, 'b, 'c> CallFunction<'ctx, 'a, 'b, 'c> {
pub fn begin(
tyctx: TypeContext<'ctx>,
ctx: &'b CodeGenContext<'ctx, 'a>,
name: &'c str,
) -> Self {
CallFunction { tyctx, ctx, name, args: Vec::new() }
}
/// Push a call argument to the function call.
///
/// The `_name` parameter is there for self-documentation purposes.
#[allow(clippy::needless_pass_by_value)]
pub fn arg<M: Model>(mut self, _name: &str, arg: Instance<'ctx, M>) -> Self {
let arg = Arg {
ty: arg.model.get_type(self.tyctx, self.ctx.ctx).as_basic_type_enum().into(),
val: arg.value.as_basic_value_enum().into(),
};
self.args.push(arg);
self
}
/// Like [`CallFunction::returning_`] but `return_model` is automatically inferred.
pub fn returning<M: Model>(self, name: &str) -> Instance<'ctx, M> {
self.returning_(name, M::default())
}
/// Call the function and expect the function to return a value of type of `return_model`.
pub fn returning_<M: Model>(self, name: &str, return_model: M) -> Instance<'ctx, M> {
let ret_ty = return_model.get_type(self.tyctx, self.ctx.ctx);
let ret = self.get_function(|tys| ret_ty.fn_type(tys, false), name);
let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work
let ret = return_model.check_value(self.tyctx, self.ctx.ctx, ret).unwrap(); // Must work
ret
}
/// Call the function and expect the function to return a void-type.
pub fn returning_void(self) {
let ret_ty = self.ctx.ctx.void_type();
let _ = self.get_function(|tys| ret_ty.fn_type(tys, false), "");
}
fn get_function<F>(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx>
where
F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>,
{
// Get the LLVM function, declare the function if it doesn't exist - it will be defined by other
// components of NAC3.
let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| {
let tys = self.args.iter().map(|arg| arg.ty).collect_vec();
let fn_type = make_fn_type(&tys);
self.ctx.module.add_function(self.name, fn_type, None)
});
let vals = self.args.iter().map(|arg| arg.val).collect_vec();
self.ctx.builder.build_call(func, &vals, return_value_name).unwrap()
}
}
}

View File

@ -1,7 +1,7 @@
use crate::{ use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, codegen::classes::{ListType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver}, symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -24,6 +24,7 @@ use inkwell::{
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
use model::*;
use nac3parser::ast::{Location, Stmt, StrRef}; use nac3parser::ast::{Location, Stmt, StrRef};
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
@ -32,6 +33,7 @@ use std::sync::{
Arc, Arc,
}; };
use std::thread; use std::thread;
use structure::{cslice::CSlice, exception::Exception, ndarray::NpArray};
pub mod builtin_fns; pub mod builtin_fns;
pub mod classes; pub mod classes;
@ -41,8 +43,12 @@ pub mod extern_fns;
mod generator; mod generator;
pub mod irrt; pub mod irrt;
pub mod llvm_intrinsics; pub mod llvm_intrinsics;
pub mod model;
pub mod numpy; pub mod numpy;
pub mod numpy_new;
pub mod stmt; pub mod stmt;
pub mod structure;
pub mod util;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
@ -68,6 +74,16 @@ pub struct CodeGenLLVMOptions {
pub target: CodeGenTargetMachineOptions, pub target: CodeGenTargetMachineOptions,
} }
impl CodeGenLLVMOptions {
/// Creates a [`TargetMachine`] using the target options specified by this struct.
///
/// See [`Target::create_target_machine`].
#[must_use]
pub fn create_target_machine(&self) -> Option<TargetMachine> {
self.target.create_target_machine(self.opt_level)
}
}
/// Additional options for code generation for the target machine. /// Additional options for code generation for the target machine.
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub struct CodeGenTargetMachineOptions { pub struct CodeGenTargetMachineOptions {
@ -158,11 +174,11 @@ pub struct CodeGenContext<'ctx, 'a> {
pub registry: &'a WorkerRegistry, pub registry: &'a WorkerRegistry,
/// Cache for constant strings. /// Cache for constant strings.
pub const_strings: HashMap<String, BasicValueEnum<'ctx>>, pub const_strings: HashMap<String, Struct<'ctx, CSlice>>,
/// [`BasicBlock`] containing all `alloca` statements for the current function. /// [`BasicBlock`] containing all `alloca` statements for the current function.
pub init_bb: BasicBlock<'ctx>, pub init_bb: BasicBlock<'ctx>,
pub exception_val: Option<PointerValue<'ctx>>, pub exception_val: Option<Ptr<'ctx, StructModel<Exception>>>,
/// The header and exit basic blocks of a loop in this context. See /// The header and exit basic blocks of a loop in this context. See
/// <https://llvm.org/docs/LoopTerminology.html> for explanation of these terminology. /// <https://llvm.org/docs/LoopTerminology.html> for explanation of these terminology.
@ -338,6 +354,10 @@ impl WorkerRegistry {
let mut builder = context.create_builder(); let mut builder = context.create_builder();
let mut module = context.create_module(generator.get_name()); let mut module = context.create_module(generator.get_name());
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag( module.add_basic_value_flag(
"Debug Info Version", "Debug Info Version",
inkwell::module::FlagBehavior::Warning, inkwell::module::FlagBehavior::Warning,
@ -361,6 +381,10 @@ impl WorkerRegistry {
errors.insert(e); errors.insert(e);
// create a new empty module just to continue codegen and collect errors // create a new empty module just to continue codegen and collect errors
module = context.create_module(&format!("{}_recover", generator.get_name())); module = context.create_module(&format!("{}_recover", generator.get_name()));
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
} }
} }
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
@ -471,12 +495,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let tyctx = generator.type_context(ctx);
let element_type = get_llvm_type( let pndarray_model = PtrModel(StructModel(NpArray));
ctx, module, generator, unifier, top_level, type_cache, dtype, pndarray_model.get_type(tyctx, ctx).into()
);
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
} }
_ => unreachable!( _ => unreachable!(
@ -646,43 +667,20 @@ pub fn gen_func_impl<
..primitives ..primitives
}; };
let mut type_cache: HashMap<_, _> = [ let type_context = generator.type_context(context);
let cslice_model = StructModel(CSlice);
let pexn_model = PtrModel(StructModel(Exception));
let mut type_cache: HashMap<_, BasicTypeEnum<'ctx>> = [
(primitives.int32, context.i32_type().into()), (primitives.int32, context.i32_type().into()),
(primitives.int64, context.i64_type().into()), (primitives.int64, context.i64_type().into()),
(primitives.uint32, context.i32_type().into()), (primitives.uint32, context.i32_type().into()),
(primitives.uint64, context.i64_type().into()), (primitives.uint64, context.i64_type().into()),
(primitives.float, context.f64_type().into()), (primitives.float, context.f64_type().into()),
(primitives.bool, context.i8_type().into()), (primitives.bool, context.i8_type().into()),
(primitives.str, { (primitives.str, cslice_model.get_type(type_context, context).into()),
let name = "str";
match module.get_struct_type(name) {
None => {
let str_type = context.opaque_struct_type("str");
let fields = [
context.i8_type().ptr_type(AddressSpace::default()).into(),
generator.get_size_type(context).into(),
];
str_type.set_body(&fields, false);
str_type.into()
}
Some(t) => t.as_basic_type_enum(),
}
}),
(primitives.range, RangeType::new(context).as_base_type().into()), (primitives.range, RangeType::new(context).as_base_type().into()),
(primitives.exception, { (primitives.exception, pexn_model.get_type(type_context, context).into()),
let name = "Exception";
if let Some(t) = module.get_struct_type(name) {
t.ptr_type(AddressSpace::default()).as_basic_type_enum()
} else {
let exception = context.opaque_struct_type("Exception");
let int32 = context.i32_type().into();
let int64 = context.i64_type().into();
let str_ty = module.get_struct_type("str").unwrap().as_basic_type_enum();
let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64];
exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
}
}),
] ]
.iter() .iter()
.copied() .copied()

View File

@ -0,0 +1,161 @@
use std::fmt;
use inkwell::{context::Context, types::*, values::*};
use super::*;
use crate::codegen::{CodeGenContext, CodeGenerator};
#[derive(Clone, Copy)]
pub struct TypeContext<'ctx> {
pub size_type: IntType<'ctx>,
}
pub trait HasTypeContext {
fn type_context<'ctx>(&self, ctx: &'ctx Context) -> TypeContext<'ctx>;
}
impl<T: CodeGenerator + ?Sized> HasTypeContext for T {
fn type_context<'ctx>(&self, ctx: &'ctx Context) -> TypeContext<'ctx> {
TypeContext { size_type: self.get_size_type(ctx) }
}
}
#[derive(Debug, Clone)]
pub struct ModelError(pub String);
impl ModelError {
pub(super) fn under_context(mut self, context: &str) -> Self {
self.0.push_str(" ... in ");
self.0.push_str(context);
self
}
}
/// A [`Model`] is a singleton object that uniquely identifies a [`BasicType`]
/// solely from a [`CodeGenerator`] and a [`Context`].
pub trait Model: CheckType + fmt::Debug + Clone + Copy + Default {
type Value<'ctx>: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>;
type Type<'ctx>: BasicType<'ctx>;
/// Return the [`BasicType`] of this model.
fn get_type<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Self::Type<'ctx>;
/// Check if a [`BasicType`] is the same type of this model.
fn check_type<'ctx, T: BasicType<'ctx>>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
self.check_type_impl(tyctx, ctx, ty.as_basic_type_enum())
}
/// Create an instance from a value with [`Instance::model`] being this model.
///
/// Caller must make sure the type of `value` and the type of this `model` are equivalent.
fn believe_value<'ctx>(&self, value: Self::Value<'ctx>) -> Instance<'ctx, Self> {
Instance { model: *self, value }
}
/// Check if a [`BasicValue`]'s type is equivalent to the type of this model.
/// Wrap it into an [`Instance`] if it is.
fn check_value<'ctx, V: BasicValue<'ctx>>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
value: V,
) -> Result<Instance<'ctx, Self>, ModelError> {
let value = value.as_basic_value_enum();
self.check_type(tyctx, ctx, value.get_type())
.map_err(|err| err.under_context("the value {value:?}"))?;
let Ok(value) = Self::Value::try_from(value) else {
unreachable!("check_type() has bad implementation")
};
Ok(self.believe_value(value))
}
// Allocate a value on the stack and return its pointer.
fn alloca<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
name: &str,
) -> Ptr<'ctx, Self> {
let pmodel = PtrModel(*self);
let p = ctx.builder.build_alloca(self.get_type(tyctx, ctx.ctx), name).unwrap();
pmodel.believe_value(p)
}
// Allocate an array on the stack and return its pointer.
fn array_alloca<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
len: IntValue<'ctx>,
name: &str,
) -> Ptr<'ctx, Self> {
let pmodel = PtrModel(*self);
let p = ctx.builder.build_array_alloca(self.get_type(tyctx, ctx.ctx), len, name).unwrap();
pmodel.believe_value(p)
}
fn var_alloca<'ctx, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> Result<Ptr<'ctx, Self>, String> {
let tyctx = generator.type_context(ctx.ctx);
let pmodel = PtrModel(*self);
let p = generator.gen_var_alloc(
ctx,
self.get_type(tyctx, ctx.ctx).as_basic_type_enum(),
name,
)?;
Ok(pmodel.believe_value(p))
}
fn array_var_alloca<'ctx, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
len: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Result<Ptr<'ctx, Self>, String> {
let tyctx = generator.type_context(ctx.ctx);
// TODO: Remove ArraySliceValue
let pmodel = PtrModel(*self);
let p = generator.gen_array_var_alloc(
ctx,
self.get_type(tyctx, ctx.ctx).as_basic_type_enum(),
len,
name,
)?;
Ok(pmodel.believe_value(PointerValue::from(p)))
}
}
#[derive(Debug, Clone, Copy)]
pub struct Instance<'ctx, M: Model> {
/// The model of this instance.
pub model: M,
/// The value of this instance.
///
/// Caller must make sure the type of `value` and the type of this `model` are equivalent,
/// down to having the same [`IntType::get_bit_width`] in case of [`IntType`] for example.
pub value: M::Value<'ctx>,
}
// NOTE: Must be Rust object-safe - This must be typeable for a Rust trait object.
pub trait CheckType {
fn check_type_impl<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
ty: BasicTypeEnum<'ctx>,
) -> Result<(), ModelError>;
}

View File

@ -0,0 +1,228 @@
use std::fmt;
use inkwell::{
context::Context,
types::{BasicTypeEnum, IntType},
values::IntValue,
IntPredicate,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
pub trait IntKind: fmt::Debug + Clone + Copy + Default {
fn get_int_type<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> IntType<'ctx>;
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Bool;
#[derive(Debug, Clone, Copy, Default)]
pub struct Byte;
#[derive(Debug, Clone, Copy, Default)]
pub struct Int32;
#[derive(Debug, Clone, Copy, Default)]
pub struct Int64;
#[derive(Debug, Clone, Copy, Default)]
pub struct SizeT;
impl IntKind for Bool {
fn get_int_type<'ctx>(&self, _tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
ctx.bool_type()
}
}
impl IntKind for Byte {
fn get_int_type<'ctx>(&self, _tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
ctx.i8_type()
}
}
impl IntKind for Int32 {
fn get_int_type<'ctx>(&self, _tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
ctx.i32_type()
}
}
impl IntKind for Int64 {
fn get_int_type<'ctx>(&self, _tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
ctx.i64_type()
}
}
impl IntKind for SizeT {
fn get_int_type<'ctx>(&self, tyctx: TypeContext<'ctx>, _ctx: &'ctx Context) -> IntType<'ctx> {
tyctx.size_type
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct IntModel<N: IntKind>(pub N);
pub type Int<'ctx, N> = Instance<'ctx, IntModel<N>>;
impl<N: IntKind> CheckType for IntModel<N> {
fn check_type_impl<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
ty: BasicTypeEnum<'ctx>,
) -> Result<(), ModelError> {
let Ok(ty) = IntType::try_from(ty) else {
return Err(ModelError(format!("Expecting IntType, but got {ty:?}")));
};
let exp_ty = self.0.get_int_type(tyctx, ctx);
if ty.get_bit_width() != exp_ty.get_bit_width() {
return Err(ModelError(format!(
"Expecting IntType to have {} bit(s), but got {} bit(s)",
exp_ty.get_bit_width(),
ty.get_bit_width()
)));
}
Ok(())
}
}
impl<N: IntKind> Model for IntModel<N> {
type Value<'ctx> = IntValue<'ctx>;
type Type<'ctx> = IntType<'ctx>;
#[must_use]
fn get_type<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Self::Type<'ctx> {
self.0.get_int_type(tyctx, ctx)
}
}
impl<N: IntKind> IntModel<N> {
pub fn constant<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
value: u64,
) -> Int<'ctx, N> {
let value = self.get_type(tyctx, ctx).const_int(value, false);
self.believe_value(value)
}
pub fn const_0<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Int<'ctx, N> {
self.constant(tyctx, ctx, 0)
}
pub fn const_1<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Int<'ctx, N> {
self.constant(tyctx, ctx, 1)
}
pub fn s_extend_or_bit_cast<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx
.builder
.build_int_s_extend_or_bit_cast(value, self.get_type(tyctx, ctx.ctx), name)
.unwrap();
self.believe_value(value)
}
pub fn truncate<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
name: &str,
) -> Int<'ctx, N> {
let value =
ctx.builder.build_int_truncate(value, self.get_type(tyctx, ctx.ctx), name).unwrap();
self.believe_value(value)
}
}
impl IntModel<Bool> {
#[must_use]
pub fn const_false<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
) -> Int<'ctx, Bool> {
self.constant(tyctx, ctx, 0)
}
#[must_use]
pub fn const_true<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
) -> Int<'ctx, Bool> {
self.constant(tyctx, ctx, 1)
}
}
impl<'ctx, N: IntKind> Int<'ctx, N> {
pub fn s_extend_or_bit_cast<NewN: IntKind, G: CodeGenerator + ?Sized>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
name: &str,
) -> Int<'ctx, NewN> {
IntModel(to_int_kind).s_extend_or_bit_cast(tyctx, ctx, self.value, name)
}
pub fn truncate<NewN: IntKind, G: CodeGenerator + ?Sized>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
name: &str,
) -> Int<'ctx, NewN> {
IntModel(to_int_kind).truncate(tyctx, ctx, self.value, name)
}
#[must_use]
pub fn add<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx.builder.build_int_add(self.value, other.value, name).unwrap();
self.model.believe_value(value)
}
#[must_use]
pub fn sub<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx.builder.build_int_sub(self.value, other.value, name).unwrap();
self.model.believe_value(value)
}
#[must_use]
pub fn mul<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx.builder.build_int_mul(self.value, other.value, name).unwrap();
self.model.believe_value(value)
}
pub fn compare<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
op: IntPredicate,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
let value = ctx.builder.build_int_compare(op, self.value, other.value, name).unwrap();
bool_model.believe_value(value)
}
}

View File

@ -0,0 +1,12 @@
mod core;
mod int;
mod ptr;
mod slice;
mod structure;
pub mod util;
pub use core::*;
pub use int::*;
pub use ptr::*;
pub use slice::*;
pub use structure::*;

View File

@ -0,0 +1,142 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use crate::codegen::CodeGenContext;
use super::*;
#[derive(Debug, Clone, Copy, Default)]
pub struct PtrModel<Element>(pub Element);
pub type Ptr<'ctx, Element> = Instance<'ctx, PtrModel<Element>>;
impl<Element: CheckType> CheckType for PtrModel<Element> {
fn check_type_impl<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
ty: BasicTypeEnum<'ctx>,
) -> Result<(), super::ModelError> {
let Ok(ty) = PointerType::try_from(ty) else {
return Err(ModelError(format!("Expecting PointerType, but got {ty:?}")));
};
let elem_ty = ty.get_element_type();
let Ok(elem_ty) = BasicTypeEnum::try_from(elem_ty) else {
return Err(ModelError(format!(
"Expecting pointer element type to be a BasicTypeEnum, but got {elem_ty:?}"
)));
};
// TODO: inkwell `get_element_type()` will be deprecated.
// Remove the check for `get_element_type()` when the time comes.
self.0
.check_type_impl(tyctx, ctx, elem_ty)
.map_err(|err| err.under_context("a PointerType"))?;
Ok(())
}
}
impl<Element: Model> Model for PtrModel<Element> {
type Value<'ctx> = PointerValue<'ctx>;
type Type<'ctx> = PointerType<'ctx>;
fn get_type<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Self::Type<'ctx> {
self.0.get_type(tyctx, ctx).ptr_type(AddressSpace::default())
}
}
impl<Element: Model> PtrModel<Element> {
/// Return a ***constant*** nullptr.
pub fn nullptr<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
) -> Ptr<'ctx, Element> {
let ptr = self.get_type(tyctx, ctx).const_null();
self.believe_value(ptr)
}
pub fn transmute<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
ptr: PointerValue<'ctx>,
name: &str,
) -> Ptr<'ctx, Element> {
let ptr = ctx.builder.build_pointer_cast(ptr, self.get_type(tyctx, ctx.ctx), name).unwrap();
self.believe_value(ptr)
}
}
impl<'ctx, Element: Model> Ptr<'ctx, Element> {
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`].
#[must_use]
pub fn offset(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
offset: IntValue<'ctx>,
name: &str,
) -> Ptr<'ctx, Element> {
let new_ptr =
unsafe { ctx.builder.build_in_bounds_gep(self.value, &[offset], name).unwrap() };
self.model.check_value(tyctx, ctx.ctx, new_ptr).unwrap()
}
// Load the `i`-th element (0-based) on the array with [`inkwell::builder::Builder::build_in_bounds_gep`].
pub fn ix(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
i: IntValue<'ctx>,
name: &str,
) -> Instance<'ctx, Element> {
self.offset(tyctx, ctx, i, name).load(tyctx, ctx, name)
}
/// Load the value with [`inkwell::builder::Builder::build_load`].
pub fn load(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
name: &str,
) -> Instance<'ctx, Element> {
let value = ctx.builder.build_load(self.value, name).unwrap();
self.model.0.check_value(tyctx, ctx.ctx, value).unwrap() // If unwrap() panics, there is a logic error.
}
/// Store a value with [`inkwell::builder::Builder::build_store`].
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: Instance<'ctx, Element>) {
ctx.builder.build_store(self.value, value.value).unwrap();
}
/// Return a casted pointer of element type `NewElement` with [`inkwell::builder::Builder::build_pointer_cast`].
pub fn transmute<NewElement: Model>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
new_model: NewElement,
name: &str,
) -> Ptr<'ctx, NewElement> {
PtrModel(new_model).transmute(tyctx, ctx, self.value, name)
}
/// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`].
pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
let value = ctx.builder.build_is_null(self.value, name).unwrap();
bool_model.believe_value(value)
}
/// Check if the pointer is not null with [`inkwell::builder::Builder::build_is_not_null`].
pub fn is_not_null(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
let value = ctx.builder.build_is_not_null(self.value, name).unwrap();
bool_model.believe_value(value)
}
}

View File

@ -0,0 +1,72 @@
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
/// A slice - literally just a pointer and a length value.
///
/// NOTE: This is NOT a [`Model`].
pub struct ArraySlice<'ctx, Len: IntKind, Item: Model> {
pub base: Ptr<'ctx, Item>,
pub len: Int<'ctx, Len>,
}
impl<'ctx, Len: IntKind, Item: Model> ArraySlice<'ctx, Len, Item> {
/// Get the `idx`-nth element of this [`ArraySlice`], but doesn't do an assertion to see if `idx` is out of bounds or not.
///
/// Also see [`ArraySlice::ix`].
pub fn ix_unchecked(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
idx: Int<'ctx, Len>,
name: &str,
) -> Ptr<'ctx, Item> {
let element_ptr = unsafe {
ctx.builder.build_in_bounds_gep(self.base.value, &[idx.value], name).unwrap()
};
self.base.model.check_value(tyctx, ctx.ctx, element_ptr).unwrap()
}
/// Call [`ArraySlice::ix_unchecked`], but checks if `idx` is in bounds, otherwise a runtime `IndexError` will be thrown.
pub fn ix<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
idx: Int<'ctx, Len>,
name: &str,
) -> Ptr<'ctx, Item> {
let tyctx = generator.type_context(ctx.ctx);
let len_model = IntModel(Len::default());
// Assert `0 <= idx < length` and throw an Exception if `idx` is out of bounds
let lower_bounded = ctx
.builder
.build_int_compare(
inkwell::IntPredicate::SLE,
len_model.constant(tyctx, ctx.ctx, 0).value,
idx.value,
"lower_bounded",
)
.unwrap();
let upper_bounded = ctx
.builder
.build_int_compare(
inkwell::IntPredicate::SLT,
idx.value,
self.len.value,
"upper_bounded",
)
.unwrap();
let bounded = ctx.builder.build_and(lower_bounded, upper_bounded, "bounded").unwrap();
ctx.make_assert(
generator,
bounded,
"0:IndexError",
"nac3core LLVM codegen attempting to access out of bounds array index {0}. Must satisfy 0 <= index < {2}",
[ Some(idx.value), Some(self.len.value), None],
ctx.current_loc
);
self.ix_unchecked(tyctx, ctx, idx, name)
}
}

View File

@ -0,0 +1,174 @@
use std::fmt;
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, StructType},
values::StructValue,
};
use itertools::izip;
use crate::codegen::CodeGenContext;
use super::*;
#[derive(Debug, Clone, Copy)]
pub struct GepField<M: Model> {
pub gep_index: u64,
pub name: &'static str,
pub model: M,
}
pub trait FieldVisitor {
type Field<M: Model + 'static>;
fn add<M: Model + 'static>(&mut self, name: &'static str) -> Self::Field<M>;
}
pub struct GepFieldVisitor {
gep_index_counter: u64,
}
impl FieldVisitor for GepFieldVisitor {
type Field<M: Model + 'static> = GepField<M>;
fn add<M: Model + 'static>(&mut self, name: &'static str) -> Self::Field<M> {
let gep_index = self.gep_index_counter;
self.gep_index_counter += 1;
Self::Field { gep_index, name, model: M::default() }
}
}
struct TypeFieldVisitor<'ctx> {
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
field_types: Vec<BasicTypeEnum<'ctx>>,
}
impl<'ctx> FieldVisitor for TypeFieldVisitor<'ctx> {
type Field<M: Model + 'static> = ();
fn add<M: Model + 'static>(&mut self, _name: &'static str) -> Self::Field<M> {
self.field_types.push(M::default().get_type(self.tyctx, self.ctx).as_basic_type_enum());
}
}
struct CheckTypeEntry {
check_type: Box<dyn CheckType + 'static>,
name: &'static str,
}
struct CheckTypeFieldVisitor<'ctx> {
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
check_types: Vec<CheckTypeEntry>,
}
impl<'ctx> FieldVisitor for CheckTypeFieldVisitor<'ctx> {
type Field<M: Model + 'static> = ();
fn add<M: Model + 'static>(&mut self, name: &'static str) -> Self::Field<M> {
self.check_types.push(CheckTypeEntry { check_type: Box::<M>::default(), name });
}
}
pub trait StructKind: fmt::Debug + Clone + Copy + Default {
type Fields<F: FieldVisitor>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F>;
fn fields(&self) -> Self::Fields<GepFieldVisitor> {
self.visit_fields(&mut GepFieldVisitor { gep_index_counter: 0 })
}
fn get_struct_type<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
) -> StructType<'ctx> {
let mut visitor = TypeFieldVisitor { tyctx, ctx, field_types: Vec::new() };
self.visit_fields(&mut visitor);
ctx.struct_type(&visitor.field_types, false)
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct StructModel<S: StructKind>(pub S);
pub type Struct<'ctx, S> = Instance<'ctx, StructModel<S>>;
impl<S: StructKind> CheckType for StructModel<S> {
fn check_type_impl<'ctx>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &'ctx Context,
ty: BasicTypeEnum<'ctx>,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = StructType::try_from(ty) else {
return Err(ModelError(format!("Expecting StructType, but got {ty:?}")));
};
let field_types = ty.get_field_types();
let check_types = {
let mut builder = CheckTypeFieldVisitor { tyctx, ctx, check_types: Vec::new() };
self.0.visit_fields(&mut builder);
builder.check_types
};
if check_types.len() != field_types.len() {
return Err(ModelError(format!(
"Expecting StructType to have {} field(s), but got {} field(s)",
check_types.len(),
field_types.len()
)));
}
for (field_i, (entry, field_type)) in izip!(check_types, field_types).enumerate() {
let field_at = field_i + 1;
entry.check_type.check_type_impl(tyctx, ctx, field_type).map_err(|err| {
err.under_context(format!("struct field #{field_at} '{}'", entry.name).as_str())
})?;
}
Ok(())
}
}
impl<S: StructKind> Model for StructModel<S> {
type Value<'ctx> = StructValue<'ctx>;
type Type<'ctx> = StructType<'ctx>;
fn get_type<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Self::Type<'ctx> {
self.0.get_struct_type(tyctx, ctx)
}
}
impl<'ctx, S: StructKind> Ptr<'ctx, StructModel<S>> {
pub fn gep<M, GetField>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
get_field: GetField,
) -> Ptr<'ctx, M>
where
M: Model,
GetField: FnOnce(S::Fields<GepFieldVisitor>) -> GepField<M>,
{
let field = get_field(self.model.0 .0.fields());
let llvm_i32 = ctx.ctx.i32_type(); // must be i32, if its i64 then rust segfaults
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.value,
&[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)],
field.name,
)
.unwrap()
};
let ptr_model = PtrModel(field.model);
ptr_model.believe_value(ptr)
}
}

View File

@ -0,0 +1,23 @@
use inkwell::{types::BasicType, values::IntValue};
use crate::codegen::{llvm_intrinsics::call_memcpy_generic, CodeGenContext};
use super::*;
pub fn gen_model_memcpy<'ctx, M: Model>(
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
dst: Ptr<'ctx, M>,
src: Ptr<'ctx, M>,
num_elements: IntValue<'ctx>,
volatile: bool,
) {
let bool_model = IntModel(Bool);
let itemsize = M::default().get_type(tyctx, ctx.ctx).size_of().unwrap();
let totalsize =
ctx.builder.build_int_mul(itemsize, num_elements, "model_memcpy_totalsize").unwrap();
let is_volatile = bool_model.constant(tyctx, ctx.ctx, u64::from(volatile));
call_memcpy_generic(ctx, dst.value, src.value, totalsize, is_volatile.value);
}

View File

@ -86,6 +86,7 @@ where
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(shape_len, false), (shape_len, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -131,6 +132,7 @@ where
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(shape_len, false), (shape_len, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -252,7 +254,7 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into() ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "") ctx.gen_string(generator, "").value.into()
} else { } else {
unreachable!() unreachable!()
} }
@ -280,7 +282,7 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into() ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1") ctx.gen_string(generator, "1").value.into()
} else { } else {
unreachable!() unreachable!()
} }
@ -382,6 +384,7 @@ where
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(ndarray_num_elems, false), (ndarray_num_elems, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -703,11 +706,12 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
gen_for_range_callback( gen_for_range_callback(
generator, generator,
ctx, ctx,
None,
true, true,
|_, _| Ok(llvm_usize.const_zero()), |_, _| Ok(llvm_usize.const_zero()),
(|_, ctx| Ok(src_lst.load_size(ctx, None)), false), (|_, ctx| Ok(src_lst.load_size(ctx, None)), false),
|_, _| Ok(llvm_usize.const_int(1, false)), |_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, i| { |generator, ctx, _, i| {
let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); let offset = ctx.builder.build_int_mul(stride, i, "").unwrap();
let dst_ptr = let dst_ptr =
@ -943,11 +947,12 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
gen_for_range_callback( gen_for_range_callback(
generator, generator,
ctx, ctx,
None,
true, true,
|_, _| Ok(llvm_usize.const_zero()), |_, _| Ok(llvm_usize.const_zero()),
(|_, _| Ok(stop), false), (|_, _| Ok(stop), false),
|_, _| Ok(llvm_usize.const_int(1, false)), |_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, _| { |generator, ctx, _, _| {
let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into())
.ptr_type(AddressSpace::default()); .ptr_type(AddressSpace::default());
@ -1086,13 +1091,17 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
// If there are no (remaining) slice expressions, memcpy the entire dimension // If there are no (remaining) slice expressions, memcpy the entire dimension
if slices.is_empty() { if slices.is_empty() {
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
let stride = call_ndarray_calc_size( let stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&src_arr.dim_sizes(), &src_arr.dim_sizes(),
(Some(llvm_usize.const_int(dim, false)), None), (Some(llvm_usize.const_int(dim, false)), None),
); );
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); let stride =
ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap();
let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap();
call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero());
@ -1126,11 +1135,12 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
gen_for_range_callback( gen_for_range_callback(
generator, generator,
ctx, ctx,
None,
false, false,
|_, _| Ok(start), |_, _| Ok(start),
(|_, _| Ok(stop), true), (|_, _| Ok(stop), true),
|_, _| Ok(step), |_, _| Ok(step),
|generator, ctx, src_i| { |generator, ctx, _, src_i| {
// Calculate the offset of the active slice // Calculate the offset of the active slice
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
let dst_i = let dst_i =
@ -1243,6 +1253,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_int(slices.len() as u64, false), llvm_usize.const_int(slices.len() as u64, false),
(this.load_ndims(ctx), false), (this.load_ndims(ctx), false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
@ -1647,6 +1658,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_i32.const_zero(), llvm_i32.const_zero(),
(common_dim, false), (common_dim, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {

View File

@ -0,0 +1,113 @@
use itertools::Itertools;
use crate::{
codegen::{
irrt::ndarray::broadcast::{
call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to, ShapeEntry,
},
model::*,
numpy_new::util::{create_ndims, extract_ndims},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
use super::object::NDArrayObject;
#[derive(Debug, Clone)]
pub struct BroadcastAllResult<'ctx> {
/// The statically known `ndims` of the broadcast result.
pub ndims: u64,
/// The broadcasting shape.
pub shape: Ptr<'ctx, IntModel<SizeT>>,
/// Broadcasted views on the inputs.
///
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
/// is the same as the input.
pub ndarrays: Vec<NDArrayObject<'ctx>>,
}
// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently.
pub fn broadcast_all_ndarrays<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarrays: Vec<NDArrayObject<'ctx>>,
) -> BroadcastAllResult<'ctx> {
assert!(!ndarrays.is_empty());
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let shape_model = StructModel(ShapeEntry);
// We can deduce the final ndims statically and immediately.
// It should be `max([ ndarray.ndims for ndarray in ndarrays ])`.
let broadcast_ndims =
ndarrays.iter().map(|ndarray| extract_ndims(&ctx.unifier, ndarray.ndims)).max().unwrap();
let broadcast_ndims_ty = create_ndims(&mut ctx.unifier, broadcast_ndims);
// NOTE: Now prepare before calling `call_nac3_ndarray_broadcast_shapes`
// Prepare input shape entries
let num_shape_entries =
sizet_model.constant(tyctx, ctx.ctx, u64::try_from(ndarrays.len()).unwrap());
let shape_entries =
shape_model.array_alloca(tyctx, ctx, num_shape_entries.value, "shape_entries");
for (i, ndarray) in ndarrays.iter().enumerate() {
let i = sizet_model.constant(tyctx, ctx.ctx, i as u64).value;
let this_shape = ndarray.instance.gep(ctx, |f| f.shape).load(tyctx, ctx, "this_shape");
let this_ndims = ndarray.instance.gep(ctx, |f| f.ndims).load(tyctx, ctx, "this_ndims");
let shape_entry = shape_entries.offset(tyctx, ctx, i, "shape_entry");
shape_entry.gep(ctx, |f| f.shape).store(ctx, this_shape);
shape_entry.gep(ctx, |f| f.ndims).store(ctx, this_ndims);
}
// Prepare destination
let dst_ndims = sizet_model.constant(tyctx, ctx.ctx, broadcast_ndims);
let dst_shape = sizet_model.array_alloca(tyctx, ctx, dst_ndims.value, "dst_shape");
call_nac3_ndarray_broadcast_shapes(
generator,
ctx,
num_shape_entries,
shape_entries,
dst_ndims,
dst_shape,
);
// Now that we know about the broadcasting shape, broadcast all the inputs.
// Broadcast all the inputs to shape `dst_shape`
let broadcasted_ndarrays = ndarrays
.into_iter()
.map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims_ty, dst_shape))
.collect_vec();
BroadcastAllResult { ndims: broadcast_ndims, shape: dst_shape, ndarrays: broadcasted_ndarrays }
}
impl<'ctx> NDArrayObject<'ctx> {
/// Broadcast an ndarray to a target shape.
#[must_use]
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target_ndims_ty: Type,
target_shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
// Please see comment in IRRT on how the caller should prepare `dst_ndarray`
let dst_ndarray = NDArrayObject::alloca(
generator,
ctx,
target_ndims_ty,
self.dtype,
"broadcast_ndarray_to_dst",
);
dst_ndarray.copy_shape(generator, ctx, target_shape);
call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, dst_ndarray.instance);
dst_ndarray
}
}

View File

@ -0,0 +1,217 @@
use inkwell::{
types::BasicType,
values::{BasicValue, BasicValueEnum, PointerValue},
AddressSpace,
};
use nac3parser::ast::StrRef;
use crate::{
codegen::{
model::*,
numpy_new::util::{alloca_ndarray, init_ndarray_data_by_alloca, init_ndarray_shape},
structure::ndarray::NpArray,
util::shape::make_shape_writer,
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
toplevel::DefinitionId,
typecheck::typedef::{FunSignature, Type},
};
use super::util::gen_foreach_ndarray_elements;
/// Helper function to create an ndarray with uninitialized values
///
/// * `elem_ty` - The [`Type`] of the ndarray elements
/// * `shape` - The user input shape argument
/// * `shape_ty` - The [`Type`] of the shape argument
/// * `name` - LLVM IR name of the returned ndarray
fn create_empty_ndarray<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
name: &str,
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String>
where
G: CodeGenerator + ?Sized,
{
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let shape_writer = make_shape_writer(generator, ctx, shape, shape_ty);
let ndims = shape_writer.len;
let ndarray = alloca_ndarray(generator, ctx, ndims, name);
init_ndarray_shape(generator, ctx, ndarray, &shape_writer)?;
let itemsize = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
let itemsize = sizet_model.check_value(tyctx, ctx.ctx, itemsize).unwrap();
ndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize);
// Needs `itemsize` and `shape` initialized
init_ndarray_data_by_alloca(generator, ctx, ndarray);
Ok(ndarray)
}
/// Helper function to create an ndarray full of a value.
///
/// * `elem_ty` - The [`Type`] of the ndarray elements and the fill value
/// * `shape` - The user input shape argument
/// * `shape_ty` - The [`Type`] of the shape argument
/// * `fill_value` - The user specified fill value
/// * `name` - LLVM IR name of the returned ndarray
fn create_full_ndarray<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
fill_value: BasicValueEnum<'ctx>,
name: &str,
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String>
where
G: CodeGenerator + ?Sized,
{
let pndarray = create_empty_ndarray(generator, ctx, elem_ty, shape, shape_ty, name)?;
gen_foreach_ndarray_elements(
generator,
ctx,
pndarray,
|_generator, ctx, _hooks, _i, pelement| {
// Cannot use Model here, fill_value's type is not statically known.
let pfill_value_ty = fill_value.get_type().ptr_type(AddressSpace::default());
let pelement =
ctx.builder.build_pointer_cast(pelement.value, pfill_value_ty, "pelement").unwrap();
ctx.builder.build_store(pelement, fill_value).unwrap();
Ok(())
},
)?;
Ok(pndarray)
}
/// Generates LLVM IR for `np.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Implementation
let ndarray_ptr = create_empty_ndarray(
generator,
context,
context.primitives.float,
shape,
shape_ty,
"ndarray",
)?;
Ok(ndarray_ptr.value)
}
/// Generates LLVM IR for `np.zeros`.
pub fn gen_ndarray_zeros<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Implementation
// NOTE: Currently nac3's `np.zeros` is always `float64`.
let float64_ty = context.primitives.float;
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
let ndarray_ptr = create_full_ndarray(
generator,
context,
float64_ty, // `elem_ty` is always `float64`
shape,
shape_ty,
float64_llvm_type.const_zero().as_basic_value_enum(),
"ndarray",
)?;
Ok(ndarray_ptr.value)
}
/// Generates LLVM IR for `np.ones`.
pub fn gen_ndarray_ones<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Implementation
// NOTE: Currently nac3's `np.ones` is always `float64`.
let float64_ty = context.primitives.float;
let float64_llvm_type = context.get_llvm_type(generator, float64_ty).into_float_type();
let ndarray_ptr = create_full_ndarray(
generator,
context,
float64_ty, // `elem_ty` is always `float64`
shape,
shape_ty,
float64_llvm_type.const_float(1.0).as_basic_value_enum(),
"ndarray",
)?;
Ok(ndarray_ptr.value)
}
/// Generates LLVM IR for `np.full`.
pub fn gen_ndarray_full<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 shape
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Parse argument #2 fill_value
let fill_value_ty = fun.0.args[1].ty;
let fill_value_arg =
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
// Implementation
let ndarray_ptr = create_full_ndarray(
generator,
context,
fill_value_ty,
shape_arg,
shape_ty,
fill_value_arg,
"ndarray",
)?;
Ok(ndarray_ptr.value)
}

View File

@ -0,0 +1,76 @@
use crate::{
codegen::{
irrt::ndarray::indexing::{call_nac3_ndarray_index, RustNDIndex},
model::*,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{Type, Unifier},
};
use super::{
object::{NDArrayObject, ScalarObject, ScalarOrNDArray},
util::{create_ndims, extract_ndims},
};
impl<'ctx> NDArrayObject<'ctx> {
pub fn deduce_ndims_after_indexing_with(
&self,
unifier: &mut Unifier,
indexes: &[RustNDIndex<'ctx>],
) -> Type {
let ndims = extract_ndims(unifier, self.ndims);
let new_ndims = RustNDIndex::deduce_ndims_after_indexing(indexes, ndims);
create_ndims(unifier, new_ndims)
}
#[must_use]
pub fn index_always_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indexes: &[RustNDIndex<'ctx>],
name: &str,
) -> Self {
let tyctx = generator.type_context(ctx.ctx);
let dst_ndims = self.deduce_ndims_after_indexing_with(&mut ctx.unifier, indexes);
let dst_ndarray = NDArrayObject::alloca(generator, ctx, dst_ndims, self.dtype, name);
let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(tyctx, ctx, indexes);
call_nac3_ndarray_index(
generator,
ctx,
num_indexes,
indexes,
self.instance,
dst_ndarray.instance,
);
dst_ndarray
}
pub fn index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indexes: &[RustNDIndex<'ctx>],
name: &str,
) -> ScalarOrNDArray<'ctx> {
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let subndarray = self.index_always_ndarray(generator, ctx, indexes, name);
if subndarray.is_unsized(&ctx.unifier) {
// TODO: This actually never fails, don't use the `checked_` version.
let value = subndarray.checked_get_nth_element(
generator,
ctx,
sizet_model.const_0(tyctx, ctx.ctx),
name,
);
ScalarOrNDArray::Scalar(ScalarObject { dtype: self.dtype, value })
} else {
ScalarOrNDArray::NDArray(subndarray)
}
}
}

View File

@ -0,0 +1,6 @@
pub mod broadcast;
pub mod factory;
pub mod indexing;
pub mod object;
pub mod util;
pub mod view;

View File

@ -0,0 +1,69 @@
use inkwell::values::{BasicValue, BasicValueEnum};
use crate::{
codegen::{model::*, structure::ndarray::NpArray, CodeGenContext},
toplevel::numpy::unpack_ndarray_var_tys,
typecheck::typedef::{Type, TypeEnum},
};
/// An LLVM ndarray instance with its typechecker [`Type`]s.
#[derive(Debug, Clone, Copy)]
pub struct NDArrayObject<'ctx> {
pub dtype: Type,
pub ndims: Type,
pub instance: Ptr<'ctx, StructModel<NpArray>>,
}
/// An LLVM numpy scalar with its [`Type`].
#[derive(Debug, Clone, Copy)]
pub struct ScalarObject<'ctx> {
pub dtype: Type,
pub value: BasicValueEnum<'ctx>,
}
#[derive(Debug, Clone, Copy)]
pub enum ScalarOrNDArray<'ctx> {
Scalar(ScalarObject<'ctx>),
NDArray(NDArrayObject<'ctx>),
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.value,
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
}
}
}
impl<'ctx> From<ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> {
fn from(input: ScalarOrNDArray<'ctx>) -> BasicValueEnum<'ctx> {
input.to_basic_value_enum()
}
}
/// Split an [`BasicValueEnum<'ctx>`] into a [`ScalarOrNDArray`] depending
/// on its [`Type`].
pub fn split_scalar_or_ndarray<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
input: BasicValueEnum<'ctx>,
input_ty: Type,
) -> ScalarOrNDArray<'ctx> {
let pndarray_model = PtrModel(StructModel(NpArray));
let input_ty_enum = ctx.unifier.get_ty(input_ty);
match &*input_ty_enum {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let value = pndarray_model.check_value(tyctx, ctx.ctx, input).unwrap();
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, input_ty);
ScalarOrNDArray::NDArray(NDArrayObject { dtype, ndims, instance: value })
}
_ => ScalarOrNDArray::Scalar(ScalarObject { dtype: input_ty, value: input }),
}
}

View File

@ -0,0 +1,328 @@
use inkwell::{
types::BasicType,
values::{BasicValueEnum, PointerValue},
AddressSpace,
};
use util::gen_model_memcpy;
use crate::{
codegen::{
irrt::ndarray::basic::{
call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement,
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_nbytes,
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
call_nac3_ndarray_util_assert_shape_no_negative,
},
model::*,
stmt::BreakContinueHooks,
structure::ndarray::NpArray,
util::{array_writer::ArrayWriter, control::gen_model_for},
CodeGenContext, CodeGenerator,
},
symbol_resolver::SymbolValue,
typecheck::typedef::{Type, TypeEnum, Unifier},
};
use super::object::{NDArrayObject, ScalarOrNDArray};
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
#[must_use]
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
panic!("ndims_ty should be a TLiteral");
};
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
let ndims = values[0].clone();
u64::try_from(ndims).unwrap()
}
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
}
/// Allocate an ndarray on the stack given its `ndims`.
///
/// `shape` and `strides` will be automatically allocated on the stack.
///
/// The returned ndarray's content will be:
/// - `data`: `nullptr`
/// - `itemsize`: **uninitialized** value
/// - `ndims`: initialized value, set to the input `ndims`
/// - `shape`: initialized pointer to an allocated stack with **uninitialized** values
/// - `strides`: initialized pointer to an allocated stack with **uninitialized** values
pub fn alloca_ndarray<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: Int<'ctx, SizeT>,
name: &str,
) -> Ptr<'ctx, StructModel<NpArray>>
where
G: CodeGenerator + ?Sized,
{
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let ndarray_model = StructModel(NpArray);
let ndarray_data_model = PtrModel(IntModel(Byte));
// Setup ndarray
let ndarray_ptr = ndarray_model.alloca(tyctx, ctx, name);
let shape = sizet_model.array_alloca(tyctx, ctx, ndims.value, "shape");
let strides = sizet_model.array_alloca(tyctx, ctx, ndims.value, "strides");
ndarray_ptr.gep(ctx, |f| f.data).store(ctx, ndarray_data_model.nullptr(tyctx, ctx.ctx));
ndarray_ptr.gep(ctx, |f| f.ndims).store(ctx, ndims);
ndarray_ptr.gep(ctx, |f| f.shape).store(ctx, shape);
ndarray_ptr.gep(ctx, |f| f.strides).store(ctx, strides);
ndarray_ptr
}
/// Initialize an ndarray's `shape` and asserts on.
/// `shape`'s values and prohibit illegal inputs like negative dimensions.
pub fn init_ndarray_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NpArray>>,
shape_writer: &ArrayWriter<'ctx, G, SizeT, IntModel<SizeT>>,
) -> Result<(), String> {
let tyctx = generator.type_context(ctx.ctx);
let shape = pndarray.gep(ctx, |f| f.shape).load(tyctx, ctx, "shape");
(shape_writer.write)(generator, ctx, shape)?;
call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, shape_writer.len, shape);
Ok(())
}
/// Initialize an ndarray's `data` by allocating a buffer on the stack.
/// The allocated data buffer is considered to be *owned* by the ndarray.
///
/// `strides` of the ndarray will also be updated with `set_strides_by_shape`.
///
/// `shape` and `itemsize` of the ndarray ***must*** be initialized first.
pub fn init_ndarray_data_by_alloca<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NpArray>>,
) {
let tyctx = generator.type_context(ctx.ctx);
let ndarray_data_model = IntModel(Byte);
let nbytes = call_nac3_ndarray_nbytes(generator, ctx, pndarray);
let data = ndarray_data_model.array_alloca(tyctx, ctx, nbytes.value, "data");
pndarray.gep(ctx, |f| f.data).store(ctx, data);
call_nac3_ndarray_set_strides_by_shape(generator, ctx, pndarray);
}
/// Iterate through all elements in an ndarray.
///
/// `body` is given the index of an element and an opaque pointer (as an `uint8_t*`, you might want to cast it) to the element.
///
/// Short-circuiting is possible with the given [`BreakContinueHooks`].
pub fn gen_foreach_ndarray_elements<'ctx, G, F>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NpArray>>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: Fn(
&mut G,
&mut CodeGenContext<'ctx, '_>,
BreakContinueHooks<'ctx>,
Int<'ctx, SizeT>,
Ptr<'ctx, IntModel<Byte>>,
) -> Result<(), String>,
{
// TODO: Make this more efficient - use a special NDArray iterator?
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let size = call_nac3_ndarray_size(generator, ctx, pndarray);
gen_model_for(
generator,
ctx,
sizet_model.const_0(tyctx, ctx.ctx),
size,
sizet_model.const_1(tyctx, ctx.ctx),
|generator, ctx, hooks, index| {
let pelement = call_nac3_ndarray_get_nth_pelement(generator, ctx, pndarray, index);
body(generator, ctx, hooks, index, pelement)
},
)
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Convert `input` to an ndarray - behaves like `np.asarray`.
pub fn as_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> NDArrayObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(scalar) => {
let tyctx = generator.type_context(ctx.ctx);
let pbyte_model = PtrModel(IntModel(Byte));
// We have to put the value on the stack to get a data pointer.
let data =
ctx.builder.build_alloca(scalar.value.get_type(), "as_ndarray_scalar").unwrap();
ctx.builder.build_store(data, scalar.value).unwrap();
let data = pbyte_model.transmute(tyctx, ctx, data, "data");
let ndims_ty = create_ndims(&mut ctx.unifier, 0);
let ndarray = NDArrayObject::alloca(
generator,
ctx,
ndims_ty,
scalar.dtype,
"scalar_as_ndarray",
);
ndarray.instance.gep(ctx, |f| f.data).store(ctx, data);
// No need to initialize/setup strides or shapes - because `ndims` is 0.
// So we only have to set `data`, `itemsize`, and `ndims = 0`.
ndarray
}
}
}
}
impl<'ctx> NDArrayObject<'ctx> {
pub fn alloca<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: Type,
dtype: Type,
name: &str,
) -> Self {
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let ndims_int = sizet_model.constant(tyctx, ctx.ctx, extract_ndims(&ctx.unifier, ndims));
let instance = alloca_ndarray(generator, ctx, ndims_int, name);
// Set itemsize
let dtype_ty = ctx.get_llvm_type(generator, dtype);
let itemsize = dtype_ty.size_of().unwrap();
let itemsize = sizet_model.s_extend_or_bit_cast(tyctx, ctx, itemsize, "itemsize");
instance.gep(ctx, |f| f.itemsize).store(ctx, itemsize);
NDArrayObject { dtype, ndims, instance }
}
pub fn copy_shape<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let self_shape = self.instance.gep(ctx, |f| f.shape).load(tyctx, ctx, "self_shape");
let ndims_int =
sizet_model.constant(tyctx, ctx.ctx, extract_ndims(&ctx.unifier, self.ndims));
gen_model_memcpy(tyctx, ctx, self_shape, src_shape, ndims_int.value, false);
}
pub fn copy_shape_from<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayObject<'ctx>,
) {
let tyctx = generator.type_context(ctx.ctx);
let src_shape = src_ndarray.instance.gep(ctx, |f| f.shape).load(tyctx, ctx, "src_shape");
self.copy_shape(generator, ctx, src_shape);
}
pub fn update_strides_by_shape<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
}
pub fn checked_get_nth_pelement<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
i: Int<'ctx, SizeT>,
name: &str,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, i);
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name)
.unwrap()
}
pub fn checked_get_nth_element<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
i: Int<'ctx, SizeT>,
name: &str,
) -> BasicValueEnum<'ctx> {
let pelement = self.checked_get_nth_pelement(generator, ctx, i, "pelement");
ctx.builder.build_load(pelement, name).unwrap()
}
#[must_use]
pub fn is_unsized(&self, unifier: &Unifier) -> bool {
extract_ndims(unifier, self.ndims) == 0
}
pub fn size<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_size(generator, ctx, self.instance)
}
pub fn nbytes<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_nbytes(generator, ctx, self.instance)
}
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance)
}
pub fn alloca_owned_data<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
init_ndarray_data_by_alloca(generator, ctx, self.instance);
}
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src: NDArrayObject<'ctx>,
) {
assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match");
call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance);
}
}

View File

@ -0,0 +1,114 @@
use inkwell::values::PointerValue;
use nac3parser::ast::StrRef;
use crate::{
codegen::{
irrt::ndarray::reshape::call_nac3_ndarray_resolve_and_check_new_shape,
model::*,
numpy_new::{object::split_scalar_or_ndarray, util::extract_ndims},
util::shape::make_shape_writer,
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
toplevel::{numpy::unpack_ndarray_var_tys, DefinitionId},
typecheck::typedef::{FunSignature, Type},
};
use super::object::NDArrayObject;
impl<'ctx> NDArrayObject<'ctx> {
#[must_use]
pub fn reshape_or_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
new_ndims: Type,
new_shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let tyctx = generator.type_context(ctx.ctx);
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
let dst_ndarray =
NDArrayObject::alloca(generator, ctx, new_ndims, self.dtype, "reshaped_ndarray");
dst_ndarray.copy_shape(generator, ctx, new_shape);
dst_ndarray.update_strides_by_shape(generator, ctx);
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
// Inserting into then_bb: reshape is possible without copying
ctx.builder.position_at_end(then_bb);
dst_ndarray
.instance
.gep(ctx, |f| f.data)
.store(ctx, dst_ndarray.instance.gep(ctx, |f| f.data).load(tyctx, ctx, "data"));
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Inserting into else_bb: reshape is impossible without copying
ctx.builder.position_at_end(else_bb);
dst_ndarray.alloca_owned_data(generator, ctx);
dst_ndarray.copy_data_from(generator, ctx, *self);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition for continuation
ctx.builder.position_at_end(end_bb);
dst_ndarray
}
}
/// Generates LLVM IR for `np.reshape`.
pub fn gen_ndarray_reshape<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 input
let input_ty = fun.0.args[0].ty;
let input_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
// Parse argument #2 shape
let shape_ty = fun.0.args[1].ty;
let shape_arg = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
// Define models
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
// Extract reshaped_ndims
let (_, reshaped_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let reshaped_ndims_int = extract_ndims(&ctx.unifier, reshaped_ndims);
// Process `input`
let ndarray =
split_scalar_or_ndarray(tyctx, ctx, input_arg, input_ty).as_ndarray(generator, ctx);
// Process the shape input from user and resolve negative indices
let new_shape = make_shape_writer(generator, ctx, shape_arg, shape_ty).alloca_array_and_write(
generator,
ctx,
"new_shape",
)?;
let size = ndarray.size(generator, ctx);
call_nac3_ndarray_resolve_and_check_new_shape(
generator,
ctx,
size,
sizet_model.constant(tyctx, ctx.ctx, reshaped_ndims_int),
new_shape,
);
// Reshape
let reshaped_ndarray = ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape);
Ok(reshaped_ndarray.instance.value)
}

View File

@ -1,8 +1,11 @@
use super::model::*;
use super::structure::cslice::CSlice;
use super::{ use super::{
super::symbol_resolver::ValueEnum, super::symbol_resolver::ValueEnum,
expr::destructure_range, expr::destructure_range,
irrt::{handle_slice_indices, list_slice_assignment}, irrt::{handle_slice_indices, list_slice_assignment},
CodeGenContext, CodeGenerator, structure::exception::Exception,
CodeGenContext, CodeGenerator, Int32, IntModel, Ptr, StructModel,
}; };
use crate::{ use crate::{
codegen::{ codegen::{
@ -206,6 +209,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> { ) -> Result<(), String> {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
@ -222,7 +226,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
.builder .builder
.build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem") .build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem")
.unwrap(); .unwrap();
generator.gen_assign(ctx, elt, v.into())?; generator.gen_assign(ctx, elt, v.into(), value_ty)?;
} }
} }
ExprKind::Subscript { value: ls, slice, .. } ExprKind::Subscript { value: ls, slice, .. }
@ -431,7 +435,7 @@ pub fn gen_for<G: CodeGenerator>(
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val")); let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
generator.gen_assign(ctx, target, val.into())?; generator.gen_assign(ctx, target, val.into(), ctx.primitives.int32)?;
generator.gen_block(ctx, body.iter())?; generator.gen_block(ctx, body.iter())?;
} }
@ -494,6 +498,7 @@ pub struct BreakContinueHooks<'ctx> {
pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
label: Option<&str>,
init: InitFn, init: InitFn,
cond: CondFn, cond: CondFn,
body: BodyFn, body: BodyFn,
@ -504,18 +509,24 @@ where
I: Clone, I: Clone,
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>, InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>, CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn: BodyFn: FnOnce(
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>, &mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
I,
) -> Result<(), String>,
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{ {
let label = label.unwrap_or("for");
let current_bb = ctx.builder.get_insert_block().unwrap(); let current_bb = ctx.builder.get_insert_block().unwrap();
let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init"); let init_bb = ctx.ctx.insert_basic_block_after(current_bb, &format!("{label}.init"));
// The BB containing the loop condition check // The BB containing the loop condition check
let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, "for.cond"); let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, &format!("{label}.cond"));
let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, "for.body"); let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, &format!("{label}.body"));
// The BB containing the increment expression // The BB containing the increment expression
let update_bb = ctx.ctx.insert_basic_block_after(body_bb, "for.update"); let update_bb = ctx.ctx.insert_basic_block_after(body_bb, &format!("{label}.update"));
let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, "for.end"); let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, &format!("{label}.end"));
// store loop bb information and restore it later // store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((update_bb, cont_bb)); let loop_bb = ctx.loop_target.replace((update_bb, cont_bb));
@ -572,6 +583,7 @@ where
pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
label: Option<&str>,
init_val: IntValue<'ctx>, init_val: IntValue<'ctx>,
max_val: (IntValue<'ctx>, bool), max_val: (IntValue<'ctx>, bool),
body: BodyFn, body: BodyFn,
@ -582,7 +594,7 @@ where
BodyFn: FnOnce( BodyFn: FnOnce(
&mut G, &mut G,
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks, BreakContinueHooks<'ctx>,
IntValue<'ctx>, IntValue<'ctx>,
) -> Result<(), String>, ) -> Result<(), String>,
{ {
@ -591,6 +603,7 @@ where
gen_for_callback( gen_for_callback(
generator, generator,
ctx, ctx,
label,
|generator, ctx| { |generator, ctx| {
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?; let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
ctx.builder.build_store(i_addr, init_val).unwrap(); ctx.builder.build_store(i_addr, init_val).unwrap();
@ -642,9 +655,11 @@ where
/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like /// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like
/// iterable. This value will be extended to the size of `start`. /// iterable. This value will be extended to the size of `start`.
/// - `body_fn`: A lambda of IR statements within the loop body. /// - `body_fn`: A lambda of IR statements within the loop body.
#[allow(clippy::too_many_arguments)]
pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
label: Option<&str>,
is_unsigned: bool, is_unsigned: bool,
start_fn: StartFn, start_fn: StartFn,
(stop_fn, stop_inclusive): (StopFn, bool), (stop_fn, stop_inclusive): (StopFn, bool),
@ -656,13 +671,19 @@ where
StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, BodyFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks,
IntValue<'ctx>,
) -> Result<(), String>,
{ {
let init_val_t = start_fn(generator, ctx).map(IntValue::get_type).unwrap(); let init_val_t = start_fn(generator, ctx).map(IntValue::get_type).unwrap();
gen_for_callback( gen_for_callback(
generator, generator,
ctx, ctx,
label,
|generator, ctx| { |generator, ctx| {
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?; let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
@ -720,10 +741,10 @@ where
Ok(cond) Ok(cond)
}, },
|generator, ctx, _, (i_addr, _)| { |generator, ctx, hooks, (i_addr, _)| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
body_fn(generator, ctx, i) body_fn(generator, ctx, hooks, i)
}, },
|generator, ctx, (i_addr, _)| { |generator, ctx, (i_addr, _)| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
@ -1113,47 +1134,37 @@ pub fn exn_constructor<'ctx>(
pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>( pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
exception: Option<&BasicValueEnum<'ctx>>, exception: Option<Ptr<'ctx, StructModel<Exception>>>,
loc: Location, loc: Location,
) { ) {
if let Some(exception) = exception { if let Some(pexn) = exception {
unsafe { let type_context = generator.type_context(ctx.ctx);
let int32 = ctx.ctx.i32_type(); let i32_model = IntModel(Int32);
let zero = int32.const_zero(); let cslice_model = StructModel(CSlice);
let exception = exception.into_pointer_value();
let file_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr")
.unwrap();
let filename = ctx.gen_string(generator, loc.file.0);
ctx.builder.build_store(file_ptr, filename).unwrap();
let row_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr")
.unwrap();
ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)).unwrap();
let col_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr")
.unwrap();
ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)).unwrap();
let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); // Get and store filename
let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); let filename = loc.file.0;
let name_ptr = ctx let filename = ctx.gen_string(generator, &String::from(filename)).value;
.builder let filename = cslice_model.check_value(type_context, ctx.ctx, filename).unwrap();
.build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr") pexn.gep(ctx, |f| f.filename).store(ctx, filename);
.unwrap();
ctx.builder.build_store(name_ptr, fun_name).unwrap(); let row = i32_model.constant(type_context, ctx.ctx, loc.row as u64);
} pexn.gep(ctx, |f| f.line).store(ctx, row);
let column = i32_model.constant(type_context, ctx.ctx, loc.column as u64);
pexn.gep(ctx, |f| f.column).store(ctx, column);
let current_fn = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let fn_name = ctx.gen_string(generator, current_fn.get_name().to_str().unwrap());
pexn.gep(ctx, |f| f.function).store(ctx, fn_name);
let raise = get_builtins(generator, ctx, "__nac3_raise"); let raise = get_builtins(generator, ctx, "__nac3_raise");
let exception = *exception; ctx.build_call_or_invoke(raise, &[pexn.value.into()], "raise");
ctx.build_call_or_invoke(raise, &[exception], "raise");
} else { } else {
let resume = get_builtins(generator, ctx, "__nac3_resume"); let resume = get_builtins(generator, ctx, "__nac3_resume");
ctx.build_call_or_invoke(resume, &[], "resume"); ctx.build_call_or_invoke(resume, &[], "resume");
} }
ctx.builder.build_unreachable().unwrap(); ctx.builder.build_unreachable().unwrap();
} }
@ -1575,14 +1586,16 @@ pub fn gen_stmt<G: CodeGenerator>(
} }
StmtKind::AnnAssign { target, value, .. } => { StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value { if let Some(value) = value {
let value_ty = value.custom.unwrap();
let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) }; let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) };
generator.gen_assign(ctx, target, value)?; generator.gen_assign(ctx, target, value, value_ty)?;
} }
} }
StmtKind::Assign { targets, value, .. } => { StmtKind::Assign { targets, value, .. } => {
let value_ty = value.custom.unwrap();
let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) }; let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) };
for target in targets { for target in targets {
generator.gen_assign(ctx, target, value.clone())?; generator.gen_assign(ctx, target, value.clone(), value_ty)?;
} }
} }
StmtKind::Continue { .. } => { StmtKind::Continue { .. } => {
@ -1596,6 +1609,7 @@ pub fn gen_stmt<G: CodeGenerator>(
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
StmtKind::AugAssign { target, op, value, .. } => { StmtKind::AugAssign { target, op, value, .. } => {
let value_ty = value.custom.unwrap();
let value = gen_binop_expr( let value = gen_binop_expr(
generator, generator,
ctx, ctx,
@ -1604,7 +1618,7 @@ pub fn gen_stmt<G: CodeGenerator>(
value, value,
stmt.location, stmt.location,
)?; )?;
generator.gen_assign(ctx, target, value.unwrap())?; generator.gen_assign(ctx, target, value.unwrap(), value_ty)?;
} }
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
StmtKind::Raise { exc, .. } => { StmtKind::Raise { exc, .. } => {
@ -1614,30 +1628,43 @@ pub fn gen_stmt<G: CodeGenerator>(
} else { } else {
return Ok(()); return Ok(());
}; };
gen_raise(generator, ctx, Some(&exc), stmt.location);
let type_context = generator.type_context(ctx.ctx);
let pexn_model = PtrModel(StructModel(Exception));
let exn = pexn_model.check_value(type_context, ctx.ctx, exc).unwrap();
gen_raise(generator, ctx, Some(exn), stmt.location);
} else { } else {
gen_raise(generator, ctx, None, stmt.location); gen_raise(generator, ctx, None, stmt.location);
} }
} }
StmtKind::Assert { test, msg, .. } => { StmtKind::Assert { test, msg, .. } => {
let test = if let Some(v) = generator.gen_expr(ctx, test)? { let type_context = generator.type_context(ctx.ctx);
v.to_basic_value_enum(ctx, generator, test.custom.unwrap())? let byte_model = IntModel(Byte);
} else { let cslice_model = StructModel(CSlice);
let Some(test) = generator.gen_expr(ctx, test)? else {
return Ok(()); return Ok(());
}; };
let test = test.to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
let test = byte_model.check_value(type_context, ctx.ctx, test).unwrap(); // Python `bool` is represented as `i8` in nac3core
// Check `msg`
let err_msg = match msg { let err_msg = match msg {
Some(msg) => { Some(msg) => {
if let Some(v) = generator.gen_expr(ctx, msg)? { let Some(msg) = generator.gen_expr(ctx, msg)? else {
v.to_basic_value_enum(ctx, generator, msg.custom.unwrap())?
} else {
return Ok(()); return Ok(());
} };
let msg = msg.to_basic_value_enum(ctx, generator, ctx.primitives.str)?;
cslice_model.check_value(type_context, ctx.ctx, msg).unwrap()
} }
None => ctx.gen_string(generator, ""), None => ctx.gen_string(generator, ""),
}; };
ctx.make_assert_impl( ctx.make_assert_impl(
generator, generator,
test.into_int_value(), test.value,
"0:AssertionError", "0:AssertionError",
err_msg, err_msg,
[None, None, None], [None, None, None],

View File

@ -0,0 +1,43 @@
use crate::codegen::{model::*, CodeGenContext};
/// Fields of [`CSlice<'ctx>`].
pub struct CSliceFields<F: FieldVisitor> {
/// Pointer to the data.
pub base: F::Field<PtrModel<IntModel<Byte>>>,
/// Number of bytes of the data.
pub len: F::Field<IntModel<SizeT>>,
}
/// See <https://crates.io/crates/cslice>.
///
/// Additionally, see <https://github.com/m-labs/artiq/blob/b0d2705c385f64b6e6711c1726cd9178f40b598e/artiq/firmware/libeh/eh_artiq.rs>)
/// for ARTIQ-specific notes.
#[derive(Debug, Clone, Copy, Default)]
pub struct CSlice;
impl StructKind for CSlice {
type Fields<F: FieldVisitor> = CSliceFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields { base: visitor.add("base"), len: visitor.add("len") }
}
}
impl StructModel<CSlice> {
/// Create a [`CSlice`].
///
/// `base` and `len` must be LLVM global constants.
pub fn create_const<'ctx>(
&self,
type_context: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
base: Ptr<'ctx, IntModel<Byte>>,
len: Int<'ctx, SizeT>,
) -> Struct<'ctx, CSlice> {
let value = self
.0
.get_struct_type(type_context, ctx.ctx)
.const_named_struct(&[base.value.into(), len.value.into()]);
self.believe_value(value)
}
}

View File

@ -0,0 +1,57 @@
use crate::codegen::model::*;
use super::cslice::CSlice;
/// The LLVM int type of an Exception ID.
pub type ExceptionId = Int32;
/// Fields of [`Exception<'ctx>`]
///
/// The definition came from `pub struct Exception<'a>` in
/// <https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs>.
pub struct ExceptionFields<F: FieldVisitor> {
/// nac3core's ID of the exception
pub id: F::Field<IntModel<ExceptionId>>,
/// The name of the file this `Exception` was raised in.
pub filename: F::Field<StructModel<CSlice>>,
/// The line number in the file this `Exception` was raised in.
pub line: F::Field<IntModel<Int32>>,
/// The column number in the file this `Exception` was raised in.
pub column: F::Field<IntModel<Int32>>,
/// The name of the Python function this `Exception` was raised in.
pub function: F::Field<StructModel<CSlice>>,
/// The message of this Exception.
///
/// The message can optionally contain integer parameters `{0}`, `{1}`, and `{2}` in its string,
/// where they will be substituted by `params[0]`, `params[1]`, and `params[2]` respectively (as `int64_t`s).
/// Here is an example:
///
/// ```ignore
/// "Index {0} is out of bounds! List only has {1} element(s)."
/// ```
///
/// In this case, `params[0]` and `params[1]` must be specified, and `params[2]` is ***unused***.
/// Having only 3 parameters is a constraint in ARTIQ.
pub msg: F::Field<StructModel<CSlice>>,
pub params: [F::Field<IntModel<Int64>>; 3],
}
/// nac3core & ARTIQ's Exception
#[derive(Debug, Clone, Copy, Default)]
pub struct Exception;
impl StructKind for Exception {
type Fields<F: FieldVisitor> = ExceptionFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields {
id: visitor.add("id"),
filename: visitor.add("filename"),
line: visitor.add("line"),
column: visitor.add("column"),
function: visitor.add("function"),
msg: visitor.add("msg"),
params: [visitor.add("params[0]"), visitor.add("params[1]"), visitor.add("params[2]")],
}
}
}

View File

@ -0,0 +1,3 @@
pub mod cslice;
pub mod exception;
pub mod ndarray;

View File

@ -0,0 +1,27 @@
use crate::codegen::*;
pub struct NpArrayFields<F: FieldVisitor> {
pub data: F::Field<PtrModel<IntModel<Byte>>>,
pub itemsize: F::Field<IntModel<SizeT>>,
pub ndims: F::Field<IntModel<SizeT>>,
pub shape: F::Field<PtrModel<IntModel<SizeT>>>,
pub strides: F::Field<PtrModel<IntModel<SizeT>>>,
}
// TODO: Rename to `NDArray` when the old NDArray is removed.
#[derive(Debug, Clone, Copy, Default)]
pub struct NpArray;
impl StructKind for NpArray {
type Fields<F: FieldVisitor> = NpArrayFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields {
data: visitor.add("data"),
itemsize: visitor.add("itemsize"),
ndims: visitor.add("ndims"),
shape: visitor.add("shape"),
strides: visitor.add("strides"),
}
}
}

View File

@ -189,6 +189,8 @@ fn test_primitives() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 { define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
@ -368,6 +370,8 @@ fn test_simple_call() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 { define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {

View File

@ -0,0 +1,34 @@
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
/// A closure containing details on how to write to/initialize an array.
#[allow(clippy::type_complexity)]
pub struct ArrayWriter<'ctx, G: CodeGenerator + ?Sized, Len: IntKind, Item: Model> {
/// Number of items to write
pub len: Int<'ctx, Len>,
/// Implementation to write to an array given its base pointer.
pub write: Box<
dyn Fn(
&mut G,
&mut CodeGenContext<'ctx, '_>,
Ptr<'ctx, Item>, // Base pointer
) -> Result<(), String>
+ 'ctx,
>,
}
impl<'ctx, G: CodeGenerator + ?Sized, Len: IntKind, Item: Model> ArrayWriter<'ctx, G, Len, Item> {
pub fn alloca_array_and_write(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: &str,
) -> Result<Ptr<'ctx, Item>, String> {
let tyctx = generator.type_context(ctx.ctx);
let item_model = Item::default();
let item_array = item_model.array_alloca(tyctx, ctx, self.len.value, name);
(self.write)(generator, ctx, item_array)?;
Ok(item_array)
}
}

View File

@ -0,0 +1,42 @@
use crate::codegen::{
model::*,
stmt::{gen_for_callback_incrementing, BreakContinueHooks},
CodeGenContext, CodeGenerator,
};
// TODO: Document
// TODO: Rename function
/// Only allows positive steps
pub fn gen_model_for<'ctx, 'a, G, F, I>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
start: Int<'ctx, I>,
stop: Int<'ctx, I>,
step: Int<'ctx, I>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Int<'ctx, I>,
) -> Result<(), String>,
I: IntKind,
{
let int_model = IntModel(I::default());
gen_for_callback_incrementing(
generator,
ctx,
None,
start.value,
(stop.value, false),
|g, ctx, hooks, i| {
let i = int_model.believe_value(i);
body(g, ctx, hooks, i)
},
step.value,
)
}

View File

@ -0,0 +1,3 @@
pub mod array_writer;
pub mod control;
pub mod shape;

View File

@ -0,0 +1,127 @@
use inkwell::values::BasicValueEnum;
use crate::{
codegen::{
classes::{ListValue, UntypedArrayLikeAccessor},
model::*,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{Type, TypeEnum},
};
use super::{array_writer::ArrayWriter, control::gen_model_for};
// TODO: Generalize to complex iterables under a common interface
/// Create an [`ArrayWriter`] from a NumPy-like `shape` argument input.
/// * `shape` - The `shape` parameter.
/// * `shape_ty` - The element type of the `NDArray`.
///
/// The `shape` argument type may only be one of the following:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// The `int32` values will be sign-extended to `SizeT`
pub fn make_shape_writer<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
) -> ArrayWriter<'ctx, G, SizeT, IntModel<SizeT>>
where
G: CodeGenerator + ?Sized,
{
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
match &*ctx.unifier.get_ty(shape_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
// TODO: Remove ListValue with Model
let shape = ListValue::from_ptr_val(shape.into_pointer_value(), tyctx.size_type, None);
let len =
sizet_model.check_value(tyctx, ctx.ctx, shape.load_size(ctx, Some("len"))).unwrap();
ArrayWriter {
len,
write: Box::new(move |generator, ctx, dst_array| {
gen_model_for(
generator,
ctx,
sizet_model.constant(tyctx, ctx.ctx, 0),
len,
sizet_model.constant(tyctx, ctx.ctx, 1),
|generator, ctx, _hooks, i| {
let dim =
shape.data().get(ctx, generator, &i.value, None).into_int_value();
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, dim, "");
dst_array.offset(tyctx, ctx, i.value, "pdim").store(ctx, dim);
Ok(())
},
)
}),
}
}
TypeEnum::TTuple { ty: tuple_types } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
let ndims = tuple_types.len();
// A tuple has to be a StructValue
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
let shape = shape.into_struct_value();
ArrayWriter {
len: sizet_model.constant(tyctx, ctx.ctx, ndims as u64),
write: Box::new(move |_generator, ctx, dst_array| {
for axis in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape, axis as u32, format!("dim{axis}").as_str())
.unwrap()
.into_int_value();
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, dim, "");
dst_array
.offset(
tyctx,
ctx,
sizet_model.constant(tyctx, ctx.ctx, axis as u64).value,
"pdim",
)
.store(ctx, dim);
}
Ok(())
}),
}
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
{
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
// The value has to be an integer
let shape_int = shape.into_int_value();
ArrayWriter {
len: sizet_model.constant(tyctx, ctx.ctx, 1),
write: Box::new(move |_generator, ctx, dst_array| {
let dim = sizet_model.s_extend_or_bit_cast(tyctx, ctx, shape_int, "");
// Set shape[0] = shape_int
dst_array
.offset(tyctx, ctx, sizet_model.constant(tyctx, ctx.ctx, 0).value, "pdim")
.store(ctx, dim);
Ok(())
}),
}
}
_ => panic!("encountered shape type"),
}
}

View File

@ -9,16 +9,20 @@ use inkwell::{
IntPredicate, IntPredicate,
}; };
use itertools::Either; use itertools::Either;
use ndarray::basic::call_nac3_ndarray_len;
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use crate::{ use crate::{
codegen::{ codegen::{
builtin_fns, builtin_fns,
classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor}, classes::{ProxyValue, RangeValue},
expr::destructure_range, expr::destructure_range,
irrt::*, irrt::*,
model::*,
numpy::*, numpy::*,
numpy_new,
stmt::exn_constructor, stmt::exn_constructor,
structure::ndarray::NpArray,
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
@ -346,8 +350,8 @@ impl<'a> BuiltinBuilder<'a> {
let (is_some_ty, unwrap_ty, option_tvar) = let (is_some_ty, unwrap_ty, option_tvar) =
if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() { if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() {
( (
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::FunOptionIsSome.simple_name().into()).unwrap(),
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), *fields.get(&PrimDef::FunOptionUnwrap.simple_name().into()).unwrap(),
iter_type_vars(params).next().unwrap(), iter_type_vars(params).next().unwrap(),
) )
} else { } else {
@ -362,9 +366,9 @@ impl<'a> BuiltinBuilder<'a> {
let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap(); let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap();
let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap(); let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap();
let ndarray_copy_ty = let ndarray_copy_ty =
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::FunNDArrayCopy.simple_name().into()).unwrap();
let ndarray_fill_ty = let ndarray_fill_ty =
*ndarray_fields.get(&PrimDef::NDArrayFill.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::FunNDArrayFill.simple_name().into()).unwrap();
let num_ty = unifier.get_fresh_var_with_range( let num_ty = unifier.get_fresh_var_with_range(
&[int32, int64, float, boolean, uint32, uint64], &[int32, int64, float, boolean, uint32, uint64],
@ -464,14 +468,14 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Exception => self.build_exception_class_related(prim), PrimDef::Exception => self.build_exception_class_related(prim),
PrimDef::Option PrimDef::Option
| PrimDef::OptionIsSome | PrimDef::FunOptionIsSome
| PrimDef::OptionIsNone | PrimDef::FunOptionIsNone
| PrimDef::OptionUnwrap | PrimDef::FunOptionUnwrap
| PrimDef::FunSome => self.build_option_class_related(prim), | PrimDef::FunSome => self.build_option_class_related(prim),
PrimDef::List => self.build_list_class_related(prim), PrimDef::List => self.build_list_class_related(prim),
PrimDef::NDArray | PrimDef::NDArrayCopy | PrimDef::NDArrayFill => { PrimDef::NDArray | PrimDef::FunNDArrayCopy | PrimDef::FunNDArrayFill => {
self.build_ndarray_class_related(prim) self.build_ndarray_class_related(prim)
} }
@ -492,6 +496,8 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpEye | PrimDef::FunNpEye
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
PrimDef::FunNpReshape => self.build_ndarray_view_functions(prim),
PrimDef::FunStr => self.build_str_function(), PrimDef::FunStr => self.build_str_function(),
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
@ -510,7 +516,9 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim), PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim),
PrimDef::FunNpMin | PrimDef::FunNpMax => self.build_np_min_max_function(prim), PrimDef::FunNpArgmin | PrimDef::FunNpArgmax | PrimDef::FunNpMin | PrimDef::FunNpMax => {
self.build_np_max_min_function(prim)
}
PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => { PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => {
self.build_np_minimum_maximum_function(prim) self.build_np_minimum_maximum_function(prim)
@ -562,7 +570,7 @@ impl<'a> BuiltinBuilder<'a> {
match (&tld, prim.details()) { match (&tld, prim.details()) {
( (
TopLevelDef::Class { name, object_id, .. }, TopLevelDef::Class { name, object_id, .. },
PrimDefDetails::PrimClass { name: exp_name }, PrimDefDetails::PrimClass { name: exp_name, .. },
) => { ) => {
let exp_object_id = prim.id(); let exp_object_id = prim.id();
assert_eq!(name, &exp_name.into()); assert_eq!(name, &exp_name.into());
@ -792,9 +800,9 @@ impl<'a> BuiltinBuilder<'a> {
prim, prim,
&[ &[
PrimDef::Option, PrimDef::Option,
PrimDef::OptionIsSome, PrimDef::FunOptionIsSome,
PrimDef::OptionIsNone, PrimDef::FunOptionIsNone,
PrimDef::OptionUnwrap, PrimDef::FunOptionUnwrap,
PrimDef::FunSome, PrimDef::FunSome,
], ],
); );
@ -807,9 +815,9 @@ impl<'a> BuiltinBuilder<'a> {
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0), Self::create_method(PrimDef::FunOptionIsSome, self.is_some_ty.0),
Self::create_method(PrimDef::OptionIsNone, self.is_some_ty.0), Self::create_method(PrimDef::FunOptionIsNone, self.is_some_ty.0),
Self::create_method(PrimDef::OptionUnwrap, self.unwrap_ty.0), Self::create_method(PrimDef::FunOptionUnwrap, self.unwrap_ty.0),
], ],
ancestors: vec![TypeAnnotation::CustomClass { ancestors: vec![TypeAnnotation::CustomClass {
id: prim.id(), id: prim.id(),
@ -820,7 +828,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::OptionUnwrap => TopLevelDef::Function { PrimDef::FunOptionUnwrap => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unwrap_ty.0, signature: self.unwrap_ty.0,
@ -834,7 +842,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::OptionIsNone | PrimDef::OptionIsSome => TopLevelDef::Function { PrimDef::FunOptionIsNone | PrimDef::FunOptionIsSome => TopLevelDef::Function {
name: prim.name().to_string(), name: prim.name().to_string(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.is_some_ty.0, signature: self.is_some_ty.0,
@ -855,10 +863,10 @@ impl<'a> BuiltinBuilder<'a> {
}; };
let returned_int = match prim { let returned_int = match prim {
PrimDef::OptionIsNone => { PrimDef::FunOptionIsNone => {
ctx.builder.build_is_null(ptr, prim.simple_name()) ctx.builder.build_is_null(ptr, prim.simple_name())
} }
PrimDef::OptionIsSome => { PrimDef::FunOptionIsSome => {
ctx.builder.build_is_not_null(ptr, prim.simple_name()) ctx.builder.build_is_not_null(ptr, prim.simple_name())
} }
_ => unreachable!(), _ => unreachable!(),
@ -931,7 +939,7 @@ impl<'a> BuiltinBuilder<'a> {
fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef { fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed( debug_assert_prim_is_allowed(
prim, prim,
&[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill], &[PrimDef::NDArray, PrimDef::FunNDArrayCopy, PrimDef::FunNDArrayFill],
); );
match prim { match prim {
@ -942,8 +950,8 @@ impl<'a> BuiltinBuilder<'a> {
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0), Self::create_method(PrimDef::FunNDArrayCopy, self.ndarray_copy_ty.0),
Self::create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0), Self::create_method(PrimDef::FunNDArrayFill, self.ndarray_fill_ty.0),
], ],
ancestors: Vec::default(), ancestors: Vec::default(),
constructor: None, constructor: None,
@ -951,7 +959,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::NDArrayCopy => TopLevelDef::Function { PrimDef::FunNDArrayCopy => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_copy_ty.0, signature: self.ndarray_copy_ty.0,
@ -968,7 +976,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::NDArrayFill => TopLevelDef::Function { PrimDef::FunNDArrayFill => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_fill_ty.0, signature: self.ndarray_fill_ty.0,
@ -1200,9 +1208,11 @@ impl<'a> BuiltinBuilder<'a> {
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], &[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| { Box::new(move |ctx, obj, fun, args, generator| {
let func = match prim { let func = match prim {
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => {
PrimDef::FunNpZeros => gen_ndarray_zeros, numpy_new::factory::gen_ndarray_empty
PrimDef::FunNpOnes => gen_ndarray_ones, }
PrimDef::FunNpZeros => numpy_new::factory::gen_ndarray_zeros,
PrimDef::FunNpOnes => numpy_new::factory::gen_ndarray_ones,
_ => unreachable!(), _ => unreachable!(),
}; };
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum())) func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
@ -1270,7 +1280,7 @@ impl<'a> BuiltinBuilder<'a> {
// type variable // type variable
&[(self.list_int32, "shape"), (tv.ty, "fill_value")], &[(self.list_int32, "shape"), (tv.ty, "fill_value")],
Box::new(move |ctx, obj, fun, args, generator| { Box::new(move |ctx, obj, fun, args, generator| {
gen_ndarray_full(ctx, &obj, fun, &args, generator) numpy_new::factory::gen_ndarray_full(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
}), }),
) )
@ -1325,6 +1335,41 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
// Build functions related to NDArray views
fn build_ndarray_view_functions(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape]);
match prim {
PrimDef::FunNpReshape => {
// TODO: Support scalar inputs, e.g., `np.reshape(99, (1, 1, 1, 1))`
let new_ndim_ty = self.unifier.get_fresh_var(Some("NewNDim".into()), None);
let returned_ndarray_ty = make_ndarray_ty(
self.unifier,
self.primitives,
Some(self.ndarray_dtype_tvar.ty),
Some(new_ndim_ty.ty),
);
create_fn_by_codegen(
self.unifier,
&into_var_map([self.ndarray_dtype_tvar, self.ndarray_ndims_tvar, new_ndim_ty]),
prim.name(),
returned_ndarray_ty,
&[
(self.primitives.ndarray, "array"),
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"),
],
Box::new(|ctx, obj, fun, args, generator| {
numpy_new::view::gen_ndarray_reshape(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
)
}
_ => unreachable!(),
}
}
/// Build the `str()` function. /// Build the `str()` function.
fn build_str_function(&mut self) -> TopLevelDef { fn build_str_function(&mut self) -> TopLevelDef {
let prim = PrimDef::FunStr; let prim = PrimDef::FunStr;
@ -1462,51 +1507,13 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i32 = ctx.ctx.i32_type(); let tyctx = generator.type_context(ctx.ctx);
let llvm_usize = generator.get_size_type(ctx.ctx); let pndarray_model = PtrModel(StructModel(NpArray));
let arg = NDArrayValue::from_ptr_val( let ndarray =
arg.into_pointer_value(), pndarray_model.check_value(tyctx, ctx.ctx, arg).unwrap();
llvm_usize, let len = call_nac3_ndarray_len(generator, ctx, ndarray);
None, Some(len.value.as_basic_value_enum())
);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::NE,
ndims,
llvm_usize.const_zero(),
"",
)
.unwrap(),
"0:TypeError",
&format!("{name}() of unsized object", name = prim.name()),
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
if len.get_type().get_bit_width() == 32 {
Some(len.into())
} else {
Some(
ctx.builder
.build_int_truncate(len, llvm_i32, "len")
.map(Into::into)
.unwrap(),
)
}
} }
_ => unreachable!(), _ => unreachable!(),
} }
@ -1555,39 +1562,45 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build the functions `np_min()` and `np_max()`. /// Build the functions `np_max()`, `np_min()`, `np_argmax()` and `np_argmin()`
fn build_np_min_max_function(&mut self, prim: PrimDef) -> TopLevelDef { /// Calls `call_numpy_max_min` with the function name
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMin, PrimDef::FunNpMax]); fn build_np_max_min_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax],
);
let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None); let (var_map, ret_ty) = match prim {
let var_map = self PrimDef::FunNpArgmax | PrimDef::FunNpArgmin => {
.num_or_ndarray_var_map (self.num_or_ndarray_var_map.clone(), self.primitives.int64)
.clone() }
.into_iter() PrimDef::FunNpMax | PrimDef::FunNpMin => {
.chain(once((ret_ty.id, ret_ty.ty))) let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None);
.collect::<IndexMap<_, _>>(); let var_map = self
.num_or_ndarray_var_map
.clone()
.into_iter()
.chain(once((ret_ty.id, ret_ty.ty)))
.collect::<IndexMap<_, _>>();
(var_map, ret_ty.ty)
}
_ => unreachable!(),
};
create_fn_by_codegen( create_fn_by_codegen(
self.unifier, self.unifier,
&var_map, &var_map,
prim.name(), prim.name(),
ret_ty.ty, ret_ty,
&[(self.float_or_ndarray_ty.ty, "a")], &[(self.num_or_ndarray_ty.ty, "a")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let a_ty = fun.0.args[0].ty; let a_ty = fun.0.args[0].ty;
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
let func = match prim { Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), prim.name())?))
PrimDef::FunNpMin => builtin_fns::call_numpy_min,
PrimDef::FunNpMax => builtin_fns::call_numpy_max,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (a_ty, a))?))
}), }),
) )
} }
/// Build the functions `np_minimum()` and `np_maximum()`. /// Build the functions `np_minimum()` and `np_maximum()`.
fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef { fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]); debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]);

View File

@ -766,6 +766,7 @@ impl TopLevelComposer {
let target_ty = get_type_from_type_annotation_kinds( let target_ty = get_type_from_type_annotation_kinds(
&temp_def_list, &temp_def_list,
unifier, unifier,
primitives,
&def, &def,
&mut subst_list, &mut subst_list,
)?; )?;
@ -936,6 +937,7 @@ impl TopLevelComposer {
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
temp_def_list.as_ref(), temp_def_list.as_ref(),
unifier, unifier,
primitives_store,
&type_annotation, &type_annotation,
&mut None, &mut None,
)?; )?;
@ -1002,6 +1004,7 @@ impl TopLevelComposer {
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(
&temp_def_list, &temp_def_list,
unifier, unifier,
primitives_store,
&return_ty_annotation, &return_ty_annotation,
&mut None, &mut None,
)? )?
@ -1622,6 +1625,7 @@ impl TopLevelComposer {
let self_type = get_type_from_type_annotation_kinds( let self_type = get_type_from_type_annotation_kinds(
&def_list, &def_list,
unifier, unifier,
primitives_ty,
&make_self_type_annotation(type_vars, *object_id), &make_self_type_annotation(type_vars, *object_id),
&mut None, &mut None,
)?; )?;
@ -1803,7 +1807,11 @@ impl TopLevelComposer {
let ty_ann = make_self_type_annotation(type_vars, *class_id); let ty_ann = make_self_type_annotation(type_vars, *class_id);
let self_ty = get_type_from_type_annotation_kinds( let self_ty = get_type_from_type_annotation_kinds(
&def_list, unifier, &ty_ann, &mut None, &def_list,
unifier,
primitives_ty,
&ty_ann,
&mut None,
)?; )?;
vars.extend(type_vars.iter().map(|ty| { vars.extend(type_vars.iter().map(|ty| {
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else { let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {

View File

@ -27,17 +27,22 @@ pub enum PrimDef {
List, List,
NDArray, NDArray,
// Member Functions // Option methods
OptionIsSome, FunOptionIsSome,
OptionIsNone, FunOptionIsNone,
OptionUnwrap, FunOptionUnwrap,
NDArrayCopy,
NDArrayFill, // Option-related functions
FunInt32, FunSome,
FunInt64,
FunUInt32, // NDArray methods
FunUInt64, FunNDArrayCopy,
FunFloat, FunNDArrayFill,
// Range methods
FunRangeInit,
// NumPy factory functions
FunNpNDArray, FunNpNDArray,
FunNpEmpty, FunNpEmpty,
FunNpZeros, FunNpZeros,
@ -46,26 +51,20 @@ pub enum PrimDef {
FunNpArray, FunNpArray,
FunNpEye, FunNpEye,
FunNpIdentity, FunNpIdentity,
FunRound,
FunRound64, // NumPy view functions
FunNpReshape,
// Miscellaneous NumPy & SciPy functions
FunNpRound, FunNpRound,
FunRangeInit,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunNpFloor, FunNpFloor,
FunCeil,
FunCeil64,
FunNpCeil, FunNpCeil,
FunLen,
FunMin,
FunNpMin, FunNpMin,
FunNpMinimum, FunNpMinimum,
FunMax, FunNpArgmin,
FunNpMax, FunNpMax,
FunNpMaximum, FunNpMaximum,
FunAbs, FunNpArgmax,
FunNpIsNan, FunNpIsNan,
FunNpIsInf, FunNpIsInf,
FunNpSin, FunNpSin,
@ -104,14 +103,30 @@ pub enum PrimDef {
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
// Top-Level Functions // Miscellaneous Python & NAC3 functions
FunSome, FunInt32,
FunInt64,
FunUInt32,
FunUInt64,
FunFloat,
FunRound,
FunRound64,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunCeil,
FunCeil64,
FunLen,
FunMin,
FunMax,
FunAbs,
} }
/// Associated details of a [`PrimDef`] /// Associated details of a [`PrimDef`]
pub enum PrimDefDetails { pub enum PrimDefDetails {
PrimFunction { name: &'static str, simple_name: &'static str }, PrimFunction { name: &'static str, simple_name: &'static str },
PrimClass { name: &'static str }, PrimClass { name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type },
} }
impl PrimDef { impl PrimDef {
@ -153,15 +168,17 @@ impl PrimDef {
#[must_use] #[must_use]
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self.details() { match self.details() {
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name, PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name, .. } => {
name
}
} }
} }
/// Get the associated details of this [`PrimDef`] /// Get the associated details of this [`PrimDef`]
#[must_use] #[must_use]
pub fn details(self) -> PrimDefDetails { pub fn details(self) -> PrimDefDetails {
fn class(name: &'static str) -> PrimDefDetails { fn class(name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type) -> PrimDefDetails {
PrimDefDetails::PrimClass { name } PrimDefDetails::PrimClass { name, get_ty_fn }
} }
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails { fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
@ -169,29 +186,37 @@ impl PrimDef {
} }
match self { match self {
PrimDef::Int32 => class("int32"), // Classes
PrimDef::Int64 => class("int64"), PrimDef::Int32 => class("int32", |primitives| primitives.int32),
PrimDef::Float => class("float"), PrimDef::Int64 => class("int64", |primitives| primitives.int64),
PrimDef::Bool => class("bool"), PrimDef::Float => class("float", |primitives| primitives.float),
PrimDef::None => class("none"), PrimDef::Bool => class("bool", |primitives| primitives.bool),
PrimDef::Range => class("range"), PrimDef::None => class("none", |primitives| primitives.none),
PrimDef::Str => class("str"), PrimDef::Range => class("range", |primitives| primitives.range),
PrimDef::Exception => class("Exception"), PrimDef::Str => class("str", |primitives| primitives.str),
PrimDef::UInt32 => class("uint32"), PrimDef::Exception => class("Exception", |primitives| primitives.exception),
PrimDef::UInt64 => class("uint64"), PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32),
PrimDef::Option => class("Option"), PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64),
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")), PrimDef::Option => class("Option", |primitives| primitives.option),
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")), PrimDef::List => class("list", |primitives| primitives.list),
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")), PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray),
PrimDef::List => class("list"),
PrimDef::NDArray => class("ndarray"), // Option methods
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")), PrimDef::FunOptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")), PrimDef::FunOptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::FunInt32 => fun("int32", None), PrimDef::FunOptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None), // Option-related functions
PrimDef::FunUInt64 => fun("uint64", None), PrimDef::FunSome => fun("Some", None),
PrimDef::FunFloat => fun("float", None),
// NDArray methods
PrimDef::FunNDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::FunNDArrayFill => fun("ndarray.fill", Some("fill")),
// Range methods
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
// NumPy factory functions
PrimDef::FunNpNDArray => fun("np_ndarray", None), PrimDef::FunNpNDArray => fun("np_ndarray", None),
PrimDef::FunNpEmpty => fun("np_empty", None), PrimDef::FunNpEmpty => fun("np_empty", None),
PrimDef::FunNpZeros => fun("np_zeros", None), PrimDef::FunNpZeros => fun("np_zeros", None),
@ -200,26 +225,20 @@ impl PrimDef {
PrimDef::FunNpArray => fun("np_array", None), PrimDef::FunNpArray => fun("np_array", None),
PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None), PrimDef::FunNpIdentity => fun("np_identity", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None), // NumPy view functions
PrimDef::FunNpReshape => fun("np_reshape", None),
// Miscellaneous NumPy & SciPy functions
PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunNpFloor => fun("np_floor", None), PrimDef::FunNpFloor => fun("np_floor", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunNpCeil => fun("np_ceil", None), PrimDef::FunNpCeil => fun("np_ceil", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunNpMin => fun("np_min", None), PrimDef::FunNpMin => fun("np_min", None),
PrimDef::FunNpMinimum => fun("np_minimum", None), PrimDef::FunNpMinimum => fun("np_minimum", None),
PrimDef::FunMax => fun("max", None), PrimDef::FunNpArgmin => fun("np_argmin", None),
PrimDef::FunNpMax => fun("np_max", None), PrimDef::FunNpMax => fun("np_max", None),
PrimDef::FunNpMaximum => fun("np_maximum", None), PrimDef::FunNpMaximum => fun("np_maximum", None),
PrimDef::FunAbs => fun("abs", None), PrimDef::FunNpArgmax => fun("np_argmax", None),
PrimDef::FunNpIsNan => fun("np_isnan", None), PrimDef::FunNpIsNan => fun("np_isnan", None),
PrimDef::FunNpIsInf => fun("np_isinf", None), PrimDef::FunNpIsInf => fun("np_isinf", None),
PrimDef::FunNpSin => fun("np_sin", None), PrimDef::FunNpSin => fun("np_sin", None),
@ -257,7 +276,25 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunSome => fun("Some", None),
// Miscellaneous Python & NAC3 functions
PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunAbs => fun("abs", None),
} }
} }
} }
@ -408,9 +445,9 @@ impl TopLevelComposer {
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(), obj_id: PrimDef::Option.id(),
fields: vec![ fields: vec![
(PrimDef::OptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::FunOptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::OptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::FunOptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::OptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)), (PrimDef::FunOptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)),
] ]
.into_iter() .into_iter()
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
@ -451,8 +488,8 @@ impl TopLevelComposer {
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
fields: Mapping::from([ fields: Mapping::from([
(PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)), (PrimDef::FunNDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)),
(PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)), (PrimDef::FunNDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)),
]), ]),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar237]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar237\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(250)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar236, typevar237]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar236\", \"typevar237\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(256)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n",
] ]

View File

@ -1,8 +1,9 @@
use super::*; use super::*;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::{PrimDef, PrimDefDetails};
use crate::typecheck::typedef::VarMap; use crate::typecheck::typedef::VarMap;
use nac3parser::ast::Constant; use nac3parser::ast::Constant;
use strum::IntoEnumIterator;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum TypeAnnotation { pub enum TypeAnnotation {
@ -357,6 +358,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
pub fn get_type_from_type_annotation_kinds( pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore,
ann: &TypeAnnotation, ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>, subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, HashSet<String>> { ) -> Result<Type, HashSet<String>> {
@ -379,100 +381,141 @@ pub fn get_type_from_type_annotation_kinds(
let param_ty = params let param_ty = params
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list) get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let subst = { let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) {
// check for compatible range // Primitive TopLevelDefs do not contain all fields that are present in their Type
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check // counterparts, so directly perform subst on the Type instead.
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
}
}
TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => { let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else {
let ty = range[0]; unreachable!()
let ok: bool = { };
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"), let base_ty = get_ty_fn(primitives);
let params =
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) {
params.clone()
} else {
unreachable!()
};
unifier
.subst(
get_ty_fn(primitives),
&params
.iter()
.zip(param_ty)
.map(|(obj_tv, param)| (*obj_tv.0, param))
.collect(),
)
.unwrap_or(base_ty)
} else {
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
}
}
TypeEnum::TVar {
id, range, name, loc, is_const_generic: true, ..
} => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp =
unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
}
}
result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
} }
} }
result
ty
}; };
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
}
}
Ok(ty) Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
@ -490,6 +533,7 @@ pub fn get_type_from_type_annotation_kinds(
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
unifier, unifier,
primitives,
ty.as_ref(), ty.as_ref(),
subst_list, subst_list,
)?; )?;
@ -499,7 +543,13 @@ pub fn get_type_from_type_annotation_kinds(
let tys = tys let tys = tys
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list) get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys })) Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))

View File

@ -389,7 +389,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
} }
ast::StmtKind::Assign { targets, value, .. } => { ast::StmtKind::Assign { targets, value, .. } => {
for target in targets { for target in targets {
self.unify(target.custom.unwrap(), value.custom.unwrap(), &target.location)?; self.fold_assign(target, value)?;
} }
} }
ast::StmtKind::Raise { exc, cause, .. } => { ast::StmtKind::Raise { exc, cause, .. } => {
@ -398,7 +398,10 @@ impl<'a> Fold<()> for Inferencer<'a> {
} }
if let Some(exc) = exc { if let Some(exc) = exc {
self.virtual_checks.push(( self.virtual_checks.push((
exc.custom.unwrap(), match &*self.unifier.get_ty(exc.custom.unwrap()) {
TypeEnum::TFunc(sign) => sign.ret,
_ => exc.custom.unwrap(),
},
self.primitives.exception, self.primitives.exception,
exc.location, exc.location,
)); ));
@ -1387,6 +1390,55 @@ impl<'a> Inferencer<'a> {
})); }));
} }
// Handle `np.reshape(<array>, <shape>)`
if ["np_reshape".into()].contains(id) && args.len() == 2 {
// Extract arguments
let array_expr = args.remove(0);
let shape_expr = args.remove(0);
// Fold `<array>`
let array = self.fold_expr(array_expr)?;
let array_ty = array.custom.unwrap();
let (array_dtype, _) = unpack_ndarray_var_tys(self.unifier, array_ty);
// Fold `<shape>`
let (target_ndims, target_shape) =
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?;
let target_shape_ty = target_shape.custom.unwrap();
// ... and deduce the return type of the call
let target_ndims_ty =
self.unifier.get_fresh_literal(vec![SymbolValue::U64(target_ndims)], None);
let ret = make_ndarray_ty(
self.unifier,
self.primitives,
Some(array_dtype),
Some(target_ndims_ty),
);
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "array".into(), ty: array_ty, default_value: None },
FuncArg { name: "shape".into(), ty: target_shape_ty, default_value: None },
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![array, target_shape],
keywords: vec![],
},
}));
}
// 2-argument ndarray n-dimensional creation functions // 2-argument ndarray n-dimensional creation functions
if id == &"np_full".into() && args.len() == 2 { if id == &"np_full".into() && args.len() == 2 {
let ExprKind::List { elts, .. } = &args[0].node else { let ExprKind::List { elts, .. } = &args[0].node else {
@ -2107,4 +2159,58 @@ impl<'a> Inferencer<'a> {
self.constrain(body.custom.unwrap(), orelse.custom.unwrap(), &body.location)?; self.constrain(body.custom.unwrap(), orelse.custom.unwrap(), &body.location)?;
Ok(body.custom.unwrap()) Ok(body.custom.unwrap())
} }
fn fold_assign(
&mut self,
target: &ast::Expr<Option<Type>>,
value: &ast::Expr<Option<Type>>,
) -> Result<(), HashSet<String>> {
let target_ty = target.custom.unwrap();
let value_ty = value.custom.unwrap();
match (&target.node, &*self.unifier.get_ty(target_ty)) {
(ExprKind::Subscript { .. }, TypeEnum::TObj { obj_id: target_obj_id, .. })
if *target_obj_id == self.primitives.ndarray.obj_id(self.unifier).unwrap() =>
{
// Pattern match expressions like `my_ndarray[slices] = value`.
// TODO: `(my_ndarray[slices1], my_ndarray[slices2]) = (value1, value2)` are not supported for now.
// Suppose `my_ndarray` has type `ndarray[target_dtype, ndims]`
// value's type could be one of the following:
// Case 1. `target_dtype`
// Case 2. `ndarray[target_dtype, ?]`
// Case 3. list, tuple, iterables (TODO: NOT IMPLEMENTED)
let (target_dtype, _) = unpack_ndarray_var_tys(self.unifier, target_ty);
// Typecheck `value_ty`
match &*self.unifier.get_ty(value_ty) {
TypeEnum::TObj { obj_id: value_obj_id, .. }
if *value_obj_id
== self.primitives.ndarray.obj_id(self.unifier).unwrap() =>
{
// Case 2
// - `dtype` of `target_ty` and `value_ty` must unify.
// - `ndims` of `value_ty` is ignored.
let (value_dtype, _) = unpack_ndarray_var_tys(self.unifier, value_ty);
self.unify(target_dtype, value_dtype, &target.location)?;
}
_ => {
// If `value_ty` is not an ndarray, simply typecheck as through it has to be Case 1.
self.unify(target_dtype, value_ty, &target.location)?;
}
}
}
_ => {
// To handle
// - variable assignments `target = value`
// - and attribute assignments `target.my_attr = value`
//
// For these cases in nac3core, types of LHS and RHS must unify
self.unify(target_ty, value_ty, &target.location)?;
}
}
Ok(())
}
} }

View File

@ -14,12 +14,21 @@ while [ $# -gt 1 ]; do
done done
demo="$1" demo="$1"
echo -n "Checking $demo... " echo "### Checking $demo..."
./interpret_demo.py "$demo" > interpreted.log
./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run.log
diff -Nau interpreted.log run_lli.log
echo "ok"
rm -f interpreted.log run.log run_lli.log # Get reference output
echo ">>>>>> Running $demo with the Python interpreter"
./interpret_demo.py "$demo" > interpreted.log
echo "...... Trying NAC3's 32-bit code generator output"
./run_demo.sh -i386 --out run_32.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run_32.log
echo "...... Trying NAC3's 64-bit code generator output"
./run_demo.sh --out run_64.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run_64.log
echo "...... OK"
rm -f interpreted.log \
run_32.log run_64.log

View File

@ -6,8 +6,6 @@
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#define usize size_t
double dbl_nan(void) { double dbl_nan(void) {
return NAN; return NAN;
} }
@ -64,14 +62,14 @@ void output_asciiart(int32_t x) {
struct cslice { struct cslice {
void *data; void *data;
usize len; size_t len;
}; };
void output_int32_list(struct cslice *slice) { void output_int32_list(struct cslice *slice) {
const int32_t *data = (int32_t *) slice->data; const int32_t *data = (int32_t *) slice->data;
putchar('['); putchar('[');
for (usize i = 0; i < slice->len; ++i) { for (size_t i = 0; i < slice->len; ++i) {
if (i == slice->len - 1) { if (i == slice->len - 1) {
printf("%d", data[i]); printf("%d", data[i]);
} else { } else {
@ -85,7 +83,7 @@ void output_int32_list(struct cslice *slice) {
void output_str(struct cslice *slice) { void output_str(struct cslice *slice) {
const char *data = (const char *) slice->data; const char *data = (const char *) slice->data;
for (usize i = 0; i < slice->len; ++i) { for (size_t i = 0; i < slice->len; ++i) {
putchar(data[i]); putchar(data[i]);
} }
} }
@ -107,8 +105,25 @@ uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t
__builtin_unreachable(); __builtin_unreachable();
} }
uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) { // See `struct Exception<'a>` in
printf("__nac3_raise(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context); // https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs
struct Exception {
uint32_t id;
struct cslice file;
uint32_t line;
uint32_t column;
struct cslice function;
struct cslice message;
int64_t param[3];
};
uint32_t __nac3_raise(struct Exception* e) {
printf("__nac3_raise called. Exception details:\n");
printf(" ID: %"PRIu32"\n", e->id);
printf(" Location: %*s:%"PRIu32":%"PRIu32"\n" , (int) e->file.len, (const char*) e->file.data, e->line, e->column);
printf(" Function: %*s\n" , (int) e->function.len, (const char*) e->function.data);
printf(" Message: \"%*s\"\n" , (int) e->message.len, (const char*) e->message.data);
printf(" Params: {0}=%"PRId64", {1}=%"PRId64", {2}=%"PRId64"\n", e->param[0], e->param[1], e->param[2]);
exit(101); exit(101);
__builtin_unreachable(); __builtin_unreachable();
} }

View File

@ -167,7 +167,7 @@ def patch(module):
module.ceil64 = _ceil module.ceil64 = _ceil
module.np_ceil = np.ceil module.np_ceil = np.ceil
# NumPy ndarray functions # NumPy NDArray factory functions
module.ndarray = NDArray module.ndarray = NDArray
module.np_ndarray = np.ndarray module.np_ndarray = np.ndarray
module.np_empty = np.empty module.np_empty = np.empty
@ -178,13 +178,18 @@ def patch(module):
module.np_identity = np.identity module.np_identity = np.identity
module.np_array = np.array module.np_array = np.array
# NumPy view functions
module.np_reshape = np.reshape
# NumPy Math functions # NumPy Math functions
module.np_isnan = np.isnan module.np_isnan = np.isnan
module.np_isinf = np.isinf module.np_isinf = np.isinf
module.np_min = np.min module.np_min = np.min
module.np_minimum = np.minimum module.np_minimum = np.minimum
module.np_argmin = np.argmin
module.np_max = np.max module.np_max = np.max
module.np_maximum = np.maximum module.np_maximum = np.maximum
module.np_argmax = np.argmax
module.np_sin = np.sin module.np_sin = np.sin
module.np_cos = np.cos module.np_cos = np.cos
module.np_exp = np.exp module.np_exp = np.exp
@ -216,7 +221,7 @@ def patch(module):
module.np_hypot = np.hypot module.np_hypot = np.hypot
module.np_nextafter = np.nextafter module.np_nextafter = np.nextafter
# SciPy Math Functions # SciPy Math functions
module.sp_spec_erf = special.erf module.sp_spec_erf = special.erf
module.sp_spec_erfc = special.erfc module.sp_spec_erfc = special.erfc
module.sp_spec_gamma = special.gamma module.sp_spec_gamma = special.gamma
@ -224,15 +229,6 @@ def patch(module):
module.sp_spec_j0 = special.j0 module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1 module.sp_spec_j1 = special.j1
# NumPy NDArray Functions
module.np_ndarray = np.ndarray
module.np_empty = np.empty
module.np_zeros = np.zeros
module.np_ones = np.ones
module.np_full = np.full
module.np_eye = np.eye
module.np_identity = np.identity
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)
modname = prefix + filename.stem modname = prefix + filename.stem

View File

@ -11,19 +11,19 @@ declare -a nac3args
while [ $# -ge 1 ]; do while [ $# -ge 1 ]; do
case "$1" in case "$1" in
--help) --help)
echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--lli] [--debug] -- [NAC3ARGS...]" echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--debug] [-i386] -- [NAC3ARGS...]"
exit exit
;; ;;
--out) --out)
shift shift
outfile="$1" outfile="$1"
;; ;;
--lli)
use_lli=1
;;
--debug) --debug)
debug=1 debug=1
;; ;;
-i386)
i386=1
;;
--) --)
shift shift
break break
@ -50,29 +50,23 @@ else
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo
if [ -z "$use_lli" ]; then
if [ -z "$i386" ]; then
$nac3standalone "${nac3args[@]}" $nac3standalone "${nac3args[@]}"
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -lm -o demo module.o demo.o clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o
if [ -z "$outfile" ]; then
./demo
else
./demo > "$outfile"
fi
else else
$nac3standalone --emit-llvm "${nac3args[@]}" # Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations
clang -c -std=gnu11 -Wall -Wextra -O3 -emit-llvm -o demo.bc demo.c $nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}"
shopt -s nullglob clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
llvm-link -o nac3out.bc module*.bc main.bc clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o
shopt -u nullglob fi
if [ -z "$outfile" ]; then if [ -z "$outfile" ]; then
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc ./demo
else else
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile" ./demo > "$outfile"
fi
fi fi

View File

@ -867,6 +867,13 @@ def test_ndarray_minimum_broadcast_rhs_scalar():
output_ndarray_float_2(min_x_zeros) output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones) output_ndarray_float_2(min_x_ones)
def test_ndarray_argmin():
x = np_array([[1., 2.], [3., 4.]])
y = np_argmin(x)
output_ndarray_float_2(x)
output_int64(y)
def test_ndarray_max(): def test_ndarray_max():
x = np_identity(2) x = np_identity(2)
y = np_max(x) y = np_max(x)
@ -910,6 +917,13 @@ def test_ndarray_maximum_broadcast_rhs_scalar():
output_ndarray_float_2(max_x_zeros) output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones) output_ndarray_float_2(max_x_ones)
def test_ndarray_argmax():
x = np_array([[1., 2.], [3., 4.]])
y = np_argmax(x)
output_ndarray_float_2(x)
output_int64(y)
def test_ndarray_abs(): def test_ndarray_abs():
x = np_identity(2) x = np_identity(2)
y = abs(x) y = abs(x)
@ -1524,11 +1538,13 @@ def run() -> int32:
test_ndarray_minimum_broadcast() test_ndarray_minimum_broadcast()
test_ndarray_minimum_broadcast_lhs_scalar() test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_minimum_broadcast_rhs_scalar() test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_argmin()
test_ndarray_max() test_ndarray_max()
test_ndarray_maximum() test_ndarray_maximum()
test_ndarray_maximum_broadcast() test_ndarray_maximum_broadcast()
test_ndarray_maximum_broadcast_lhs_scalar() test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_maximum_broadcast_rhs_scalar() test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_argmax()
test_ndarray_abs() test_ndarray_abs()
test_ndarray_isnan() test_ndarray_isnan()
test_ndarray_isinf() test_ndarray_isinf()

View File

@ -9,15 +9,11 @@
#![allow(clippy::too_many_lines, clippy::wildcard_imports)] #![allow(clippy::too_many_lines, clippy::wildcard_imports)]
use clap::Parser; use clap::Parser;
use inkwell::context::Context;
use inkwell::{ use inkwell::{
memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*, memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*,
OptimizationLevel, OptimizationLevel,
}; };
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions, concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions,
@ -39,6 +35,10 @@ use nac3parser::{
ast::{Constant, Expr, ExprKind, StmtKind, StrRef}, ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
parser, parser,
}; };
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
mod basic_symbol_resolver; mod basic_symbol_resolver;
use basic_symbol_resolver::*; use basic_symbol_resolver::*;
@ -113,7 +113,9 @@ fn handle_typevar_definition(
x, x,
HashMap::new(), HashMap::new(),
)?; )?;
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None) get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let loc = func.location; let loc = func.location;
@ -152,7 +154,7 @@ fn handle_typevar_definition(
HashMap::new(), HashMap::new(),
)?; )?;
let constraint = let constraint =
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?; get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty, &mut None)?;
let loc = func.location; let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty) Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)
@ -239,8 +241,6 @@ fn handle_assignment_pattern(
} }
fn main() { fn main() {
const SIZE_T: u32 = usize::BITS;
let cli = CommandLineArgs::parse(); let cli = CommandLineArgs::parse();
let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } = let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } =
cli; cli;
@ -273,6 +273,24 @@ fn main() {
_ => OptimizationLevel::Aggressive, _ => OptimizationLevel::Aggressive,
}; };
let target_machine_options = CodeGenTargetMachineOptions {
triple,
cpu: mcpu,
features: target_features,
reloc_mode: RelocMode::PIC,
..host_target_machine
};
let size_t = Context::create()
.ptr_sized_int_type(
&target_machine_options
.create_target_machine(opt_level)
.map(|tm| tm.get_target_data())
.unwrap(),
None,
)
.get_bit_width();
let program = match fs::read_to_string(file_name.clone()) { let program = match fs::read_to_string(file_name.clone()) {
Ok(program) => program, Ok(program) => program,
Err(err) => { Err(err) => {
@ -281,9 +299,9 @@ fn main() {
} }
}; };
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(SIZE_T).0; let primitive: PrimitiveStore = TopLevelComposer::make_primitives(size_t).0;
let (mut composer, builtins_def, builtins_ty) = let (mut composer, builtins_def, builtins_ty) =
TopLevelComposer::new(vec![], ComposerConfig::default(), SIZE_T); TopLevelComposer::new(vec![], ComposerConfig::default(), size_t);
let internal_resolver: Arc<ResolverInternal> = ResolverInternal { let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
id_to_type: builtins_ty.into(), id_to_type: builtins_ty.into(),
@ -371,16 +389,7 @@ fn main() {
instance_to_stmt[""].clone() instance_to_stmt[""].clone()
}; };
let llvm_options = CodeGenLLVMOptions { let llvm_options = CodeGenLLVMOptions { opt_level, target: target_machine_options };
opt_level,
target: CodeGenTargetMachineOptions {
triple,
cpu: mcpu,
features: target_features,
reloc_mode: RelocMode::PIC,
..host_target_machine
},
};
let task = CodeGenTask { let task = CodeGenTask {
subst: Vec::default(), subst: Vec::default(),
@ -403,7 +412,7 @@ fn main() {
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let threads = (0..threads) let threads = (0..threads)
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), SIZE_T))) .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), size_t)))
.collect(); .collect();
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task); registry.add_task(task);