Compare commits

..

9 Commits

Author SHA1 Message Date
lyken e719d9396d asd 2024-07-11 00:34:06 +08:00
lyken 27f2e8b391 core: irrt general numpy broadcasting 2024-07-10 20:10:14 +08:00
lyken 5f4c406b37 core: irrt general numpy slicing 2024-07-10 20:10:14 +08:00
lyken 31ab9675ca core: more irrt 2024-07-10 20:10:14 +08:00
lyken 5fd5d65377 core: build.rs rewrite regex to capture `= type` 2024-07-10 20:10:14 +08:00
lyken 01042aecfb core: move irrt c++ sources to /nac3core/irrt 2024-07-10 20:10:14 +08:00
lyken 8754f252f6 core: IRRT -Werror=return-type 2024-07-10 20:10:14 +08:00
lyken 17207a4ebe core: add irrt_test 2024-07-10 20:10:14 +08:00
lyken e3a4675fc6 core: comment out numpy 2024-07-10 20:10:05 +08:00
105 changed files with 8667 additions and 12216 deletions

View File

@ -1,32 +0,0 @@
BasedOnStyle: LLVM
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -1
AlignEscapedNewlines: Left
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: Yes
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortFunctionsOnASingleLine: Inline
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: AfterColon
BreakInheritanceList: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ContinuationIndentWidth: 4
DerivePointerAlignment: false
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
MaxEmptyLinesToKeep: 1
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SortUsingDeclarations: true
SpaceAfterTemplateKeyword: false
SpacesBeforeTrailingComments: 2
TabWidth: 4
UseTab: Never

1
.gitignore vendored
View File

@ -1,4 +1,3 @@
__pycache__ __pycache__
/target /target
/nac3standalone/demo/linalg/target
nix/windows/msys2 nix/windows/msys2

View File

@ -8,17 +8,17 @@ repos:
hooks: hooks:
- id: nac3-cargo-fmt - id: nac3-cargo-fmt
name: nac3 cargo format name: nac3 cargo format
entry: nix entry: cargo
language: system language: system
types: [file, rust] types: [file, rust]
pass_filenames: false pass_filenames: false
description: Runs cargo fmt on the codebase. description: Runs cargo fmt on the codebase.
args: [develop, -c, cargo, fmt, --all] args: [fmt]
- id: nac3-cargo-clippy - id: nac3-cargo-clippy
name: nac3 cargo clippy name: nac3 cargo clippy
entry: nix entry: cargo
language: system language: system
types: [file, rust] types: [file, rust]
pass_filenames: false pass_filenames: false
description: Runs cargo clippy on the codebase. description: Runs cargo clippy on the codebase.
args: [develop, -c, cargo, clippy, --tests] args: [clippy, --tests]

443
Cargo.lock generated
View File

@ -26,9 +26,9 @@ dependencies = [
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.17" version = "0.6.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23a1e53f0f5d86382dafe1cf314783b2044280f406e7e1506368220ad11b1338" checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"anstyle-parse", "anstyle-parse",
@ -41,67 +41,67 @@ dependencies = [
[[package]] [[package]]
name = "anstyle" name = "anstyle"
version = "1.0.9" version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b"
[[package]] [[package]]
name = "anstyle-parse" name = "anstyle-parse"
version = "0.2.6" version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4"
dependencies = [ dependencies = [
"utf8parse", "utf8parse",
] ]
[[package]] [[package]]
name = "anstyle-query" name = "anstyle-query"
version = "1.1.2" version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391"
dependencies = [ dependencies = [
"windows-sys 0.59.0", "windows-sys",
] ]
[[package]] [[package]]
name = "anstyle-wincon" name = "anstyle-wincon"
version = "3.0.6" version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"windows-sys 0.59.0", "windows-sys",
] ]
[[package]] [[package]]
name = "ascii-canvas" name = "ascii-canvas"
version = "4.0.0" version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef1e3e699d84ab1b0911a1010c5c106aa34ae89aeac103be5ce0c3859db1e891" checksum = "8824ecca2e851cec16968d54a01dd372ef8f95b244fb84b84e70128be347c3c6"
dependencies = [ dependencies = [
"term", "term",
] ]
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.4.0" version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]] [[package]]
name = "bit-set" name = "bit-set"
version = "0.8.0" version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
dependencies = [ dependencies = [
"bit-vec", "bit-vec",
] ]
[[package]] [[package]]
name = "bit-vec" name = "bit-vec"
version = "0.8.0" version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
@ -109,15 +109,6 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]] [[package]]
name = "byteorder" name = "byteorder"
version = "1.5.0" version = "1.5.0"
@ -126,12 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.31" version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8"
dependencies = [
"shlex",
]
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
@ -141,9 +129,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.20" version = "4.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@ -151,9 +139,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.20" version = "4.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@ -163,27 +151,27 @@ dependencies = [
[[package]] [[package]]
name = "clap_derive" name = "clap_derive"
version = "4.5.18" version = "4.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" checksum = "2bac35c6dafb060fd4d275d9a4ffae97917c13a6327903a8be2153cd964f7085"
dependencies = [ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.85", "syn 2.0.70",
] ]
[[package]] [[package]]
name = "clap_lex" name = "clap_lex"
version = "0.7.2" version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70"
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.3" version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
[[package]] [[package]]
name = "console" name = "console"
@ -194,16 +182,7 @@ dependencies = [
"encode_unicode", "encode_unicode",
"lazy_static", "lazy_static",
"libc", "libc",
"windows-sys 0.52.0", "windows-sys",
]
[[package]]
name = "cpufeatures"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0"
dependencies = [
"libc",
] ]
[[package]] [[package]]
@ -263,23 +242,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
[[package]] [[package]]
name = "crypto-common" name = "crunchy"
version = "0.1.6" version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "dirs-next"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1"
dependencies = [ dependencies = [
"generic-array", "cfg-if",
"typenum", "dirs-sys-next",
] ]
[[package]] [[package]]
name = "digest" name = "dirs-sys-next"
version = "0.10.7" version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d"
dependencies = [ dependencies = [
"block-buffer", "libc",
"crypto-common", "redox_users",
"winapi",
] ]
[[package]] [[package]]
@ -316,14 +302,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.52.0", "windows-sys",
] ]
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "2.1.1" version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
[[package]] [[package]]
name = "fixedbitset" name = "fixedbitset"
@ -340,16 +326,6 @@ dependencies = [
"byteorder", "byteorder",
] ]
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]] [[package]]
name = "getopts" name = "getopts"
version = "0.2.21" version = "0.2.21"
@ -385,12 +361,6 @@ dependencies = [
"ahash", "ahash",
] ]
[[package]]
name = "hashbrown"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb"
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.4.1" version = "0.4.1"
@ -403,15 +373,6 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "home"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
dependencies = [
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "1.9.3" version = "1.9.3"
@ -424,12 +385,12 @@ dependencies = [
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "2.6.0" version = "2.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown 0.15.0", "hashbrown 0.14.5",
] ]
[[package]] [[package]]
@ -440,9 +401,9 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]] [[package]]
name = "inkwell" name = "inkwell"
version = "0.5.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40fb405537710d51f6bdbc8471365ddd4cd6d3a3c3ad6e0c8291691031ba94b2" checksum = "b597a7b2cdf279aeef6d7149071e35e4bc87c2cf05a5b7f2d731300bffe587ea"
dependencies = [ dependencies = [
"either", "either",
"inkwell_internals", "inkwell_internals",
@ -454,13 +415,13 @@ dependencies = [
[[package]] [[package]]
name = "inkwell_internals" name = "inkwell_internals"
version = "0.10.0" version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44" checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.85", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -479,9 +440,18 @@ dependencies = [
[[package]] [[package]]
name = "is_terminal_polyfill" name = "is_terminal_polyfill"
version = "1.70.1" version = "1.70.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800"
[[package]]
name = "itertools"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
@ -498,45 +468,35 @@ version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "keccak"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654"
dependencies = [
"cpufeatures",
]
[[package]] [[package]]
name = "lalrpop" name = "lalrpop"
version = "0.22.0" version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06093b57658c723a21da679530e061a8c25340fa5a6f98e313b542268c7e2a1f" checksum = "55cb077ad656299f160924eb2912aa147d7339ea7d69e1b5517326fdcec3c1ca"
dependencies = [ dependencies = [
"ascii-canvas", "ascii-canvas",
"bit-set", "bit-set",
"ena", "ena",
"itertools", "itertools 0.11.0",
"lalrpop-util", "lalrpop-util",
"petgraph", "petgraph",
"pico-args", "pico-args",
"regex", "regex",
"regex-syntax", "regex-syntax",
"sha3",
"string_cache", "string_cache",
"term", "term",
"tiny-keccak",
"unicode-xid", "unicode-xid",
"walkdir", "walkdir",
] ]
[[package]] [[package]]
name = "lalrpop-util" name = "lalrpop-util"
version = "0.22.0" version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "feee752d43abd0f4807a921958ab4131f692a44d4d599733d4419c5d586176ce" checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553"
dependencies = [ dependencies = [
"regex-automata", "regex-automata",
"rustversion",
] ]
[[package]] [[package]]
@ -547,20 +507,30 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.161" version = "0.2.155"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]] [[package]]
name = "libloading" name = "libloading"
version = "0.8.5" version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"windows-targets", "windows-targets",
] ]
[[package]]
name = "libredox"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
dependencies = [
"bitflags",
"libc",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.6" version = "0.5.6"
@ -621,9 +591,11 @@ dependencies = [
name = "nac3artiq" name = "nac3artiq"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"itertools", "inkwell",
"itertools 0.13.0",
"nac3core", "nac3core",
"nac3ld", "nac3ld",
"nac3parser",
"parking_lot", "parking_lot",
"pyo3", "pyo3",
"tempfile", "tempfile",
@ -634,6 +606,7 @@ name = "nac3ast"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"fxhash", "fxhash",
"lazy_static",
"parking_lot", "parking_lot",
"string-interner", "string-interner",
] ]
@ -643,11 +616,11 @@ name = "nac3core"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"crossbeam", "crossbeam",
"indexmap 2.6.0", "indexmap 2.2.6",
"indoc", "indoc",
"inkwell", "inkwell",
"insta", "insta",
"itertools", "itertools 0.13.0",
"nac3parser", "nac3parser",
"parking_lot", "parking_lot",
"rayon", "rayon",
@ -685,7 +658,9 @@ name = "nac3standalone"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"inkwell",
"nac3core", "nac3core",
"nac3parser",
"parking_lot", "parking_lot",
] ]
@ -697,9 +672,9 @@ checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.20.2" version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
@ -731,7 +706,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [ dependencies = [
"fixedbitset", "fixedbitset",
"indexmap 2.6.0", "indexmap 2.2.6",
] ]
[[package]] [[package]]
@ -774,7 +749,7 @@ dependencies = [
"phf_shared 0.11.2", "phf_shared 0.11.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.85", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -803,18 +778,15 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315"
[[package]] [[package]]
name = "portable-atomic" name = "portable-atomic"
version = "1.9.0" version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.20" version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
dependencies = [
"zerocopy",
]
[[package]] [[package]]
name = "precomputed-hash" name = "precomputed-hash"
@ -824,9 +796,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.89" version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
@ -878,7 +850,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.85", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -891,14 +863,14 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.85", "syn 2.0.70",
] ]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.37" version = "1.0.36"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
] ]
@ -955,18 +927,29 @@ dependencies = [
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.7" version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd"
dependencies = [ dependencies = [
"bitflags", "bitflags",
] ]
[[package]] [[package]]
name = "regex" name = "redox_users"
version = "1.11.1" version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891"
dependencies = [
"getrandom",
"libredox",
"thiserror",
]
[[package]]
name = "regex"
version = "1.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
@ -976,9 +959,9 @@ dependencies = [
[[package]] [[package]]
name = "regex-automata" name = "regex-automata"
version = "0.4.8" version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
@ -987,9 +970,9 @@ dependencies = [
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.8.5" version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]] [[package]]
name = "runkernel" name = "runkernel"
@ -1000,22 +983,22 @@ dependencies = [
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.38" version = "0.38.34"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
"windows-sys 0.52.0", "windows-sys",
] ]
[[package]] [[package]]
name = "rustversion" name = "rustversion"
version = "1.0.18" version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6"
[[package]] [[package]]
name = "ryu" name = "ryu"
@ -1046,32 +1029,31 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.214" version = "1.0.204"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.214" version = "1.0.204"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.85", "syn 2.0.70",
] ]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.132" version = "1.0.120"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5"
dependencies = [ dependencies = [
"itoa", "itoa",
"memchr",
"ryu", "ryu",
"serde", "serde",
] ]
@ -1088,27 +1070,11 @@ dependencies = [
"yaml-rust", "yaml-rust",
] ]
[[package]]
name = "sha3"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60"
dependencies = [
"digest",
"keccak",
]
[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]] [[package]]
name = "similar" name = "similar"
version = "2.6.0" version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" checksum = "fa42c91313f1d05da9b26f267f931cf178d4aba455b4c4622dd7355eb80c6640"
[[package]] [[package]]
name = "siphasher" name = "siphasher"
@ -1168,7 +1134,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.85", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -1184,9 +1150,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.85" version = "2.0.70"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1195,31 +1161,31 @@ dependencies = [
[[package]] [[package]]
name = "target-lexicon" name = "target-lexicon"
version = "0.12.16" version = "0.12.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2"
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.13.0" version = "3.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"fastrand", "fastrand",
"once_cell",
"rustix", "rustix",
"windows-sys 0.59.0", "windows-sys",
] ]
[[package]] [[package]]
name = "term" name = "term"
version = "1.0.0" version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4df4175de05129f31b80458c6df371a15e7fc3fd367272e6bf938e5c351c7ea0" checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f"
dependencies = [ dependencies = [
"home", "dirs-next",
"windows-sys 0.52.0", "rustversion",
"winapi",
] ]
[[package]] [[package]]
@ -1237,29 +1203,32 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.65" version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.65" version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.85", "syn 2.0.70",
] ]
[[package]] [[package]]
name = "typenum" name = "tiny-keccak"
version = "1.17.0" version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237"
dependencies = [
"crunchy",
]
[[package]] [[package]]
name = "unic-char-property" name = "unic-char-property"
@ -1315,27 +1284,27 @@ dependencies = [
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.13" version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]] [[package]]
name = "unicode-width" name = "unicode-width"
version = "0.1.14" version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
[[package]] [[package]]
name = "unicode-xid" name = "unicode-xid"
version = "0.2.6" version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
[[package]] [[package]]
name = "unicode_names2" name = "unicode_names2"
version = "1.3.0" version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1673eca9782c84de5f81b82e4109dcfb3611c8ba0d52930ec4a9478f547b2dd" checksum = "addeebf294df7922a1164f729fb27ebbbcea99cc32b3bf08afab62757f707677"
dependencies = [ dependencies = [
"phf", "phf",
"unicode_names2_generator", "unicode_names2_generator",
@ -1343,9 +1312,9 @@ dependencies = [
[[package]] [[package]]
name = "unicode_names2_generator" name = "unicode_names2_generator"
version = "1.3.0" version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91e5b84611016120197efd7dc93ef76774f4e084cd73c9fb3ea4a86c570c56e" checksum = "f444b8bba042fe3c1251ffaca35c603f2dc2ccc08d595c65a8c4f76f3e8426c0"
dependencies = [ dependencies = [
"getopts", "getopts",
"log", "log",
@ -1367,9 +1336,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.5" version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]] [[package]]
name = "walkdir" name = "walkdir"
@ -1388,14 +1357,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]] [[package]]
name = "winapi-util" name = "winapi"
version = "0.1.9" version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [ dependencies = [
"windows-sys 0.59.0", "winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
] ]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
dependencies = [
"windows-sys",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.52.0" version = "0.52.0"
@ -1405,15 +1396,6 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets",
]
[[package]] [[package]]
name = "windows-targets" name = "windows-targets"
version = "0.52.6" version = "0.52.6"
@ -1493,7 +1475,6 @@ version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
dependencies = [ dependencies = [
"byteorder",
"zerocopy-derive", "zerocopy-derive",
] ]
@ -1505,5 +1486,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.85", "syn 2.0.70",
] ]

View File

@ -2,11 +2,11 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1727348695, "lastModified": 1720418205,
"narHash": "sha256-J+PeFKSDV+pHL7ukkfpVzCOO7mBSrrpJ3svwBFABbhI=", "narHash": "sha256-cPJoFPXU44GlhWg4pUk9oUPqurPlCFZ11ZQPk21GTPU=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "1925c603f17fc89f4c8f6bf6f631a802ad85d784", "rev": "655a58a72a6601292512670343087c2d75d859c1",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -6,7 +6,6 @@
outputs = { self, nixpkgs }: outputs = { self, nixpkgs }:
let let
pkgs = import nixpkgs { system = "x86_64-linux"; }; pkgs = import nixpkgs { system = "x86_64-linux"; };
pkgs32 = import nixpkgs { system = "i686-linux"; };
in rec { in rec {
packages.x86_64-linux = rec { packages.x86_64-linux = rec {
llvm-nac3 = pkgs.callPackage ./nix/llvm {}; llvm-nac3 = pkgs.callPackage ./nix/llvm {};
@ -14,24 +13,9 @@
'' ''
mkdir -p $out/bin mkdir -p $out/bin
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.clang}/bin/clang $out/bin/clang-irrt-test
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
''; '';
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
name = "demo-linalg-stub";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
demo-linalg-stub32 = pkgs32.rustPlatform.buildRustPackage {
name = "demo-linalg-stub32";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
nac3artiq = pkgs.python3Packages.toPythonModule ( nac3artiq = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage rec { pkgs.rustPlatform.buildRustPackage rec {
name = "nac3artiq"; name = "nac3artiq";
@ -40,8 +24,9 @@
cargoLock = { cargoLock = {
lockFile = ./Cargo.lock; lockFile = ./Cargo.lock;
}; };
cargoTestFlags = [ "--features" "test" ];
passthru.cargoLock = cargoLock; passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ]; nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ]; checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase = checkPhase =
@ -49,9 +34,7 @@
echo "Checking nac3standalone demos..." echo "Checking nac3standalone demos..."
pushd nac3standalone/demo pushd nac3standalone/demo
patchShebangs . patchShebangs .
export DEMO_LINALG_STUB=${demo-linalg-stub}/lib/liblinalg.a ./check_demos.sh
export DEMO_LINALG_STUB32=${demo-linalg-stub32}/lib/liblinalg.a
./check_demos.sh -i686
popd popd
echo "Running Cargo tests..." echo "Running Cargo tests..."
cargoCheckHook cargoCheckHook
@ -168,7 +151,7 @@
buildInputs = with pkgs; [ buildInputs = with pkgs; [
# build dependencies # build dependencies
packages.x86_64-linux.llvm-nac3 packages.x86_64-linux.llvm-nac3
(pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos
packages.x86_64-linux.llvm-tools-irrt packages.x86_64-linux.llvm-tools-irrt
cargo cargo
rustc rustc
@ -180,12 +163,10 @@
clippy clippy
pre-commit pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
shellHook = # https://nixos.wiki/wiki/Rust#Shell.nix_example
'' RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
'';
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -12,10 +12,15 @@ crate-type = ["cdylib"]
itertools = "0.13" itertools = "0.13"
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] } pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
parking_lot = "0.12" parking_lot = "0.12"
tempfile = "3.13" tempfile = "3.10"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
nac3ld = { path = "../nac3ld" } nac3ld = { path = "../nac3ld" }
[dependencies.inkwell]
version = "0.4"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[features] [features]
init-llvm-profile = [] init-llvm-profile = []
no-escape-analysis = ["nac3core/no-escape-analysis"]

View File

@ -1,24 +0,0 @@
from min_artiq import *
from numpy import int32
@nac3
class EmptyList:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@rpc
def get_empty(self) -> list[int32]:
return []
@kernel
def run(self):
a: list[int32] = self.get_empty()
if a != []:
raise ValueError
if __name__ == "__main__":
EmptyList().run()

View File

@ -112,15 +112,10 @@ def extern(function):
register_function(function) register_function(function)
return function return function
def rpc(function):
def rpc(arg=None, flags={}): """Decorates a function declaration defined by the core device runtime."""
"""Decorates a function or method to be executed on the host interpreter.""" register_function(function)
if arg is None: return function
def inner_decorator(function):
return rpc(function, flags)
return inner_decorator
register_function(arg)
return arg
def kernel(function_or_method): def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device.""" """Decorates a function or method to be executed on the core device."""

View File

@ -1,26 +0,0 @@
from min_artiq import *
from numpy import ndarray, zeros as np_zeros
@nac3
class StrFail:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def hello(self, arg: str):
pass
@kernel
def consume_ndarray(self, arg: ndarray[str, 1]):
pass
def run(self):
self.hello("world")
self.consume_ndarray(np_zeros([10], dtype=str))
if __name__ == "__main__":
StrFail().run()

File diff suppressed because it is too large Load Diff

View File

@ -16,65 +16,63 @@
clippy::wildcard_imports clippy::wildcard_imports
)] )]
use std::{ use std::collections::{HashMap, HashSet};
collections::{HashMap, HashSet}, use std::fs;
fs, use std::io::Write;
io::Write, use std::process::Command;
process::Command, use std::rc::Rc;
rc::Rc, use std::sync::Arc;
sync::Arc,
};
use itertools::Itertools; use inkwell::{
use parking_lot::{Mutex, RwLock}; memory_buffer::MemoryBuffer,
use pyo3::{ module::{Linkage, Module},
create_exception, exceptions, passes::PassBuilderOptions,
prelude::*, support::is_multithreaded,
types::{PyBytes, PyDict, PySet}, targets::*,
OptimizationLevel,
}; };
use tempfile::{self, TempDir}; use itertools::Itertools;
use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{TypeEnum, Unifier, VarMap};
use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program,
};
use pyo3::create_exception;
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use parking_lot::{Mutex, RwLock};
use nac3core::{ use nac3core::{
codegen::{ codegen::irrt::load_irrt,
concrete_type::ConcreteTypeStore, gen_func_impl, irrt::load_irrt, CodeGenLLVMOptions, codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
CodeGenTargetMachineOptions, CodeGenTask, WithCall, WorkerRegistry,
},
inkwell::{
context::Context,
memory_buffer::MemoryBuffer,
module::{Linkage, Module},
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel,
},
nac3parser::{
ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
parser::parse_program,
},
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{ toplevel::{
builtins::get_exn_constructor, composer::{ComposerConfig, TopLevelComposer},
composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef, DefinitionId, GenCall, TopLevelDef,
}, },
typecheck::{ typecheck::typedef::{FunSignature, FuncArg},
type_inferencer::PrimitiveStore, typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
}; };
use nac3ld::Linker; use nac3ld::Linker;
use codegen::{ use tempfile::{self, TempDir};
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
use crate::codegen::attributes_writeback;
use crate::{
codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
}; };
use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver};
use timeline::TimeFns;
mod codegen; mod codegen;
mod symbol_resolver; mod symbol_resolver;
mod timeline; mod timeline;
use timeline::TimeFns;
#[derive(PartialEq, Clone, Copy)] #[derive(PartialEq, Clone, Copy)]
enum Isa { enum Isa {
Host, Host,
@ -128,7 +126,7 @@ struct Nac3 {
isa: Isa, isa: Isa,
time_fns: &'static (dyn TimeFns + Sync), time_fns: &'static (dyn TimeFns + Sync),
primitive: PrimitiveStore, primitive: PrimitiveStore,
builtins: Vec<BuiltinFuncSpec>, builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>,
pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>, pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
primitive_ids: PrimitivePythonId, primitive_ids: PrimitivePythonId,
working_directory: TempDir, working_directory: TempDir,
@ -195,8 +193,10 @@ impl Nac3 {
body.retain(|stmt| { body.retain(|stmt| {
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node { if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
decorator_list.iter().any(|decorator| { decorator_list.iter().any(|decorator| {
if let Some(id) = decorator_id_string(decorator) { if let ExprKind::Name { id, .. } = decorator.node {
id == "kernel" || id == "portable" || id == "rpc" id.to_string() == "kernel"
|| id.to_string() == "portable"
|| id.to_string() == "rpc"
} else { } else {
false false
} }
@ -209,8 +209,9 @@ impl Nac3 {
} }
StmtKind::FunctionDef { ref decorator_list, .. } => { StmtKind::FunctionDef { ref decorator_list, .. } => {
decorator_list.iter().any(|decorator| { decorator_list.iter().any(|decorator| {
if let Some(id) = decorator_id_string(decorator) { if let ExprKind::Name { id, .. } = decorator.node {
id == "extern" || id == "kernel" || id == "portable" || id == "rpc" let id = id.to_string();
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
} else { } else {
false false
} }
@ -263,7 +264,7 @@ impl Nac3 {
arg_names.len(), arg_names.len(),
)); ));
} }
for (i, FuncArg { ty, default_value, name, .. }) in args.iter().enumerate() { for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() {
let in_name = match arg_names.get(i) { let in_name = match arg_names.get(i) {
Some(n) => n, Some(n) => n,
None if default_value.is_none() => { None if default_value.is_none() => {
@ -299,64 +300,6 @@ impl Nac3 {
None None
} }
/// Returns a [`Vec`] of builtins that needs to be initialized during method compilation time.
fn get_lateinit_builtins() -> Vec<Box<BuiltinFuncCreator>> {
vec![
Box::new(|primitives, unifier| {
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
(
"core_log".into(),
FunSignature {
args: vec![FuncArg {
name: "arg".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
}],
ret: primitives.none,
vars: into_var_map([arg_ty]),
},
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_core_log(ctx, &obj, fun, &args, generator)?;
Ok(None)
}))),
)
}),
Box::new(|primitives, unifier| {
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
(
"rtio_log".into(),
FunSignature {
args: vec![
FuncArg {
name: "channel".into(),
ty: primitives.str,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "arg".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
},
],
ret: primitives.none,
vars: into_var_map([arg_ty]),
},
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_rtio_log(ctx, &obj, fun, &args, generator)?;
Ok(None)
}))),
)
}),
]
}
fn compile_method<T>( fn compile_method<T>(
&self, &self,
obj: &PyAny, obj: &PyAny,
@ -369,7 +312,6 @@ impl Nac3 {
let size_t = self.isa.get_size_type(); let size_t = self.isa.get_size_type();
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new( let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
self.builtins.clone(), self.builtins.clone(),
Self::get_lateinit_builtins(),
ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" }, ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
size_t, size_t,
); );
@ -446,6 +388,7 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(), pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(), primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(), global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
name_to_pyid: name_to_pyid.clone(), name_to_pyid: name_to_pyid.clone(),
module: module.clone(), module: module.clone(),
id_to_pyval: RwLock::default(), id_to_pyval: RwLock::default(),
@ -476,25 +419,9 @@ impl Nac3 {
match &stmt.node { match &stmt.node {
StmtKind::FunctionDef { decorator_list, .. } => { StmtKind::FunctionDef { decorator_list, .. } => {
if decorator_list if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
.iter() store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap();
.any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string())) rpc_ids.push((None, def_id));
{
store_fun
.call1(
py,
(
def_id.0.into_py(py),
module.getattr(py, name.to_string().as_str()).unwrap(),
),
)
.unwrap();
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
rpc_ids.push((None, def_id, is_async));
} }
} }
StmtKind::ClassDef { name, body, .. } => { StmtKind::ClassDef { name, body, .. } => {
@ -502,26 +429,19 @@ impl Nac3 {
let class_obj = module.getattr(py, class_name.as_str()).unwrap(); let class_obj = module.getattr(py, class_name.as_str()).unwrap();
for stmt in body { for stmt in body {
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node { if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
if decorator_list.iter().any(|decorator| { if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
decorator_id_string(decorator) == Some("rpc".to_string())
}) {
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
if name == &"__init__".into() { if name == &"__init__".into() {
return Err(CompileError::new_err(format!( return Err(CompileError::new_err(format!(
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})", "compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
class_name, stmt.location class_name, stmt.location
))); )));
} }
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async)); rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
} }
} }
} }
} }
_ => (), _ => ()
} }
let id = *name_to_pyid.get(&name).unwrap(); let id = *name_to_pyid.get(&name).unwrap();
@ -560,6 +480,7 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(), pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(), primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(), global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
id_to_pyval: RwLock::default(), id_to_pyval: RwLock::default(),
id_to_primitive: RwLock::default(), id_to_primitive: RwLock::default(),
field_to_val: RwLock::default(), field_to_val: RwLock::default(),
@ -576,10 +497,6 @@ impl Nac3 {
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false) .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
.unwrap(); .unwrap();
// Process IRRT
let context = Context::create();
let irrt = load_irrt(&context, resolver.as_ref());
let fun_signature = let fun_signature =
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() }; FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
@ -617,12 +534,13 @@ impl Nac3 {
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
{ {
let rpc_codegen = rpc_codegen_callback();
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
for (class_data, id, is_async) in &rpc_ids { for (class_data, id) in &rpc_ids {
let mut def = defs[id.0].write(); let mut def = defs[id.0].write();
match &mut *def { match &mut *def {
TopLevelDef::Function { codegen_callback, .. } => { TopLevelDef::Function { codegen_callback, .. } => {
*codegen_callback = Some(rpc_codegen_callback(*is_async)); *codegen_callback = Some(rpc_codegen.clone());
} }
TopLevelDef::Class { methods, .. } => { TopLevelDef::Class { methods, .. } => {
let (class_def, method_name) = class_data.as_ref().unwrap(); let (class_def, method_name) = class_data.as_ref().unwrap();
@ -633,7 +551,7 @@ impl Nac3 {
if let TopLevelDef::Function { codegen_callback, .. } = if let TopLevelDef::Function { codegen_callback, .. } =
&mut *defs[id.0].write() &mut *defs[id.0].write()
{ {
*codegen_callback = Some(rpc_codegen_callback(*is_async)); *codegen_callback = Some(rpc_codegen.clone());
store_fun store_fun
.call1( .call1(
py, py,
@ -648,11 +566,6 @@ impl Nac3 {
} }
} }
} }
TopLevelDef::Variable { .. } => {
return Err(CompileError::new_err(String::from(
"Unsupported @rpc annotation on global variable",
)))
}
} }
} }
} }
@ -712,9 +625,7 @@ impl Nac3 {
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let size_t = context let size_t = if self.isa == Isa::Host { 64 } else { 32 };
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width();
let num_threads = if is_multithreaded() { 4 } else { 1 }; let num_threads = if is_multithreaded() { 4 } else { 1 };
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect(); let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names let threads: Vec<_> = thread_names
@ -731,11 +642,8 @@ impl Nac3 {
let mut generator = let mut generator =
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = Context::create(); let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback"); let module = context.create_module("attributes_writeback");
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
let builder = context.create_builder(); let builder = context.create_builder();
let (_, module, _) = gen_func_impl( let (_, module, _) = gen_func_impl(
&context, &context,
@ -754,7 +662,7 @@ impl Nac3 {
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}); });
// Link all modules into `main`. let context = inkwell::context::Context::create();
let buffers = membuffers.lock(); let buffers = membuffers.lock();
let main = context let main = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main")) .create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
@ -783,7 +691,8 @@ impl Nac3 {
) )
.unwrap(); .unwrap();
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?; main.link_in_module(load_irrt(&context))
.map_err(|err| CompileError::new_err(err.to_string()))?;
let mut function_iter = main.get_first_function(); let mut function_iter = main.get_first_function();
while let Some(func) = function_iter { while let Some(func) = function_iter {
@ -869,41 +778,6 @@ impl Nac3 {
} }
} }
/// Retrieves the Name.id from a decorator, supports decorators with arguments.
fn decorator_id_string(decorator: &Located<ExprKind>) -> Option<String> {
if let ExprKind::Name { id, .. } = decorator.node {
// Bare decorator
return Some(id.to_string());
} else if let ExprKind::Call { func, .. } = &decorator.node {
// Decorators that are calls (e.g. "@rpc()") have Call for the node,
// need to extract the id from within.
if let ExprKind::Name { id, .. } = func.node {
return Some(id.to_string());
}
}
None
}
/// Retrieves flags from a decorator, if any.
fn decorator_get_flags(decorator: &Located<ExprKind>) -> Vec<Constant> {
let mut flags = vec![];
if let ExprKind::Call { keywords, .. } = &decorator.node {
for keyword in keywords {
if keyword.node.arg != Some("flags".into()) {
continue;
}
if let ExprKind::Set { elts } = &keyword.node.value.node {
for elt in elts {
if let ExprKind::Constant { value, .. } = &elt.node {
flags.push(value.clone());
}
}
}
}
}
flags
}
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> { fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
let linker_args = vec![ let linker_args = vec![
"-shared".to_string(), "-shared".to_string(),
@ -973,7 +847,7 @@ impl Nac3 {
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS, Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS, Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
}; };
let (primitive, _) = TopLevelComposer::make_primitives(isa.get_size_type()); let primitive: PrimitiveStore = TopLevelComposer::make_primitives(isa.get_size_type()).0;
let builtins = vec![ let builtins = vec![
( (
"now_mu".into(), "now_mu".into(),
@ -989,7 +863,6 @@ impl Nac3 {
name: "t".into(), name: "t".into(),
ty: primitive.int64, ty: primitive.int64,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: primitive.none, ret: primitive.none,
vars: VarMap::new(), vars: VarMap::new(),
@ -1009,7 +882,6 @@ impl Nac3 {
name: "dt".into(), name: "dt".into(),
ty: primitive.int64, ty: primitive.int64,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: primitive.none, ret: primitive.none,
vars: VarMap::new(), vars: VarMap::new(),

View File

@ -1,30 +1,14 @@
use std::{ use inkwell::{
collections::{HashMap, HashSet}, types::{BasicType, BasicTypeEnum},
sync::{ values::BasicValueEnum,
atomic::{AtomicBool, Ordering::Relaxed}, AddressSpace,
Arc,
},
}; };
use itertools::Itertools; use itertools::Itertools;
use parking_lot::RwLock;
use pyo3::{
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
classes::{NDArrayType, ProxyType}, classes::{NDArrayType, ProxyType},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
inkwell::{
module::Linkage,
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
},
nac3parser::ast::{self, StrRef},
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{ toplevel::{
helper::PrimDef, helper::PrimDef,
@ -36,8 +20,21 @@ use nac3core::{
typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap}, typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap},
}, },
}; };
use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock};
use pyo3::{
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc,
},
};
use super::PrimitivePythonId; use crate::PrimitivePythonId;
pub enum PrimitiveValue { pub enum PrimitiveValue {
I32(i32), I32(i32),
@ -82,6 +79,7 @@ pub struct InnerResolver {
pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>, pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>,
pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>, pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>,
pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>, pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>,
pub class_names: Mutex<HashMap<StrRef, Type>>,
pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>, pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>, pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
pub primitive_ids: PrimitivePythonId, pub primitive_ids: PrimitivePythonId,
@ -135,8 +133,6 @@ impl StaticValue for PythonValue {
format!("{}_const", self.id).as_str(), format!("{}_const", self.id).as_str(),
); );
global.set_constant(true); global.set_constant(true);
// Set linkage of global to private to avoid name collisions
global.set_linkage(Linkage::Private);
global.set_initializer(&ctx.ctx.const_struct( global.set_initializer(&ctx.ctx.const_struct(
&[ctx.ctx.i32_type().const_int(u64::from(id), false).into()], &[ctx.ctx.i32_type().const_int(u64::from(id), false).into()],
false, false,
@ -167,7 +163,7 @@ impl StaticValue for PythonValue {
PrimitiveValue::Bool(val) => { PrimitiveValue::Bool(val) => {
ctx.ctx.i8_type().const_int(u64::from(*val), false).into() ctx.ctx.i8_type().const_int(u64::from(*val), false).into()
} }
PrimitiveValue::Str(val) => ctx.gen_string(generator, val).into(), PrimitiveValue::Str(val) => ctx.ctx.const_string(val.as_bytes(), true).into(),
}); });
} }
if let Some(global) = ctx.module.get_global(&self.id.to_string()) { if let Some(global) = ctx.module.get_global(&self.id.to_string()) {
@ -355,7 +351,7 @@ impl InnerResolver {
Ok(Ok((ndarray, false))) Ok(Ok((ndarray, false)))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false }), false))) Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
} else if ty_id == self.primitive_ids.option { } else if ty_id == self.primitive_ids.option {
Ok(Ok((primitives.option, false))) Ok(Ok((primitives.option, false)))
} else if ty_id == self.primitive_ids.none { } else if ty_id == self.primitive_ids.none {
@ -559,10 +555,7 @@ impl InnerResolver {
Err(err) => return Ok(Err(err)), Err(err) => return Ok(Err(err)),
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string())) _ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
}; };
Ok(Ok(( Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true)))
unifier.add_ty(TypeEnum::TTuple { ty: args, is_vararg_ctx: false }),
true,
)))
} }
TypeEnum::TObj { params, obj_id, .. } => { TypeEnum::TObj { params, obj_id, .. } => {
let subst = { let subst = {
@ -804,9 +797,7 @@ impl InnerResolver {
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives)) .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))
.collect(); .collect();
let types = types?; let types = types?;
Ok(types.map(|types| { Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
unifier.add_ty(TypeEnum::TTuple { ty: types, is_vararg_ctx: false })
}))
} }
// special handling for option type since its class member layout in python side // special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below // is special and cannot be mapped directly to a nac3 type as below
@ -981,7 +972,7 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
let val: String = obj.extract().unwrap(); let val: String = obj.extract().unwrap();
self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone()));
Ok(Some(ctx.gen_string(generator, val).into())) Ok(Some(ctx.ctx.const_string(val.as_bytes(), true).into()))
} else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 {
let val: f64 = obj.extract().unwrap(); let val: f64 = obj.extract().unwrap();
self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val)); self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val));
@ -1000,15 +991,8 @@ impl InnerResolver {
} }
_ => unreachable!("must be list"), _ => unreachable!("must be list"),
}; };
let ty = ctx.get_llvm_type(generator, elem_ty);
let size_t = generator.get_size_type(ctx.ctx); let size_t = generator.get_size_type(ctx.ctx);
let ty = if len == 0
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
{
// The default type for zero-length lists of unknown element type is size_t
size_t.into()
} else {
ctx.get_llvm_type(generator, elem_ty)
};
let arr_ty = ctx let arr_ty = ctx
.ctx .ctx
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);
@ -1212,9 +1196,7 @@ impl InnerResolver {
Ok(Some(ndarray.as_pointer_value().into())) Ok(Some(ndarray.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };
unreachable!()
};
let tup_tys = ty.iter(); let tup_tys = ty.iter();
let elements: &PyTuple = obj.downcast()?; let elements: &PyTuple = obj.downcast()?;
@ -1470,7 +1452,6 @@ impl SymbolResolver for Resolver {
&self, &self,
id: StrRef, id: StrRef,
_: &mut CodeGenContext<'ctx, '_>, _: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
let sym_value = { let sym_value = {
let id_to_val = self.0.id_to_pyval.read(); let id_to_val = self.0.id_to_pyval.read();

View File

@ -1,12 +1,9 @@
use itertools::Either; use inkwell::{
values::{BasicValueEnum, CallSiteValue},
use nac3core::{ AddressSpace, AtomicOrdering,
codegen::CodeGenContext,
inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
},
}; };
use itertools::Either;
use nac3core::codegen::CodeGenContext;
/// Functions for manipulating the timeline. /// Functions for manipulating the timeline.
pub trait TimeFns { pub trait TimeFns {
@ -34,7 +31,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx let now_hiptr = ctx
.builder .builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
@ -83,7 +80,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx let now_hiptr = ctx
.builder .builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
@ -112,7 +109,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx let now_hiptr = ctx
.builder .builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
@ -210,7 +207,7 @@ impl TimeFns for NowPinningTimeFns {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx let now_hiptr = ctx
.builder .builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
@ -261,7 +258,7 @@ impl TimeFns for NowPinningTimeFns {
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap(); let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now_hiptr = ctx let now_hiptr = ctx
.builder .builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();

View File

@ -10,6 +10,7 @@ constant-optimization = ["fold"]
fold = [] fold = []
[dependencies] [dependencies]
lazy_static = "1.5"
parking_lot = "0.12" parking_lot = "0.12"
string-interner = "0.17" string-interner = "0.17"
fxhash = "0.2" fxhash = "0.2"

View File

@ -5,12 +5,14 @@ pub use crate::location::Location;
use fxhash::FxBuildHasher; use fxhash::FxBuildHasher;
use parking_lot::{Mutex, MutexGuard}; use parking_lot::{Mutex, MutexGuard};
use std::{cell::RefCell, collections::HashMap, fmt, sync::LazyLock}; use std::{cell::RefCell, collections::HashMap, fmt};
use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner}; use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner};
pub type Interner = StringInterner<DefaultBackend, FxBuildHasher>; pub type Interner = StringInterner<DefaultBackend, FxBuildHasher>;
static INTERNER: LazyLock<Mutex<Interner>> = lazy_static! {
LazyLock::new(|| Mutex::new(StringInterner::with_hasher(FxBuildHasher::default()))); static ref INTERNER: Mutex<Interner> =
Mutex::new(StringInterner::with_hasher(FxBuildHasher::default()));
}
thread_local! { thread_local! {
static LOCAL_INTERNER: RefCell<HashMap<String, StrRef>> = RefCell::default(); static LOCAL_INTERNER: RefCell<HashMap<String, StrRef>> = RefCell::default();

View File

@ -14,6 +14,9 @@
clippy::wildcard_imports clippy::wildcard_imports
)] )]
#[macro_use]
extern crate lazy_static;
mod ast_gen; mod ast_gen;
mod constant; mod constant;
#[cfg(feature = "fold")] #[cfg(feature = "fold")]

View File

@ -1,26 +1,26 @@
[features]
test = []
[package] [package]
name = "nac3core" name = "nac3core"
version = "0.1.0" version = "0.1.0"
authors = ["M-Labs"] authors = ["M-Labs"]
edition = "2021" edition = "2021"
[features]
no-escape-analysis = []
[dependencies] [dependencies]
itertools = "0.13" itertools = "0.13"
crossbeam = "0.8" crossbeam = "0.8"
indexmap = "2.6" indexmap = "2.2"
parking_lot = "0.12" parking_lot = "0.12"
rayon = "1.10" rayon = "1.8"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
strum = "0.26" strum = "0.26.2"
strum_macros = "0.26" strum_macros = "0.26.4"
[dependencies.inkwell] [dependencies.inkwell]
version = "0.5" version = "0.4"
default-features = false default-features = false
features = ["llvm14-0-prefer-dynamic", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"] features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[dev-dependencies] [dev-dependencies]
test-case = "1.2.0" test-case = "1.2.0"

View File

@ -1,3 +1,4 @@
use regex::Regex;
use std::{ use std::{
env, env,
fs::File, fs::File,
@ -6,53 +7,39 @@ use std::{
process::{Command, Stdio}, process::{Command, Stdio},
}; };
use regex::Regex; fn compile_irrt(irrt_dir: &Path, out_dir: &Path) {
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir);
let irrt_dir = Path::new("irrt");
let irrt_cpp_path = irrt_dir.join("irrt.cpp"); let irrt_cpp_path = irrt_dir.join("irrt.cpp");
/* /*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode. * HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get. * Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/ */
let mut flags: Vec<&str> = vec![ let flags: &[&str] = &[
"--target=wasm32", "--target=wasm32",
irrt_cpp_path.to_str().unwrap(),
"-x", "-x",
"c++", "c++",
"-std=c++20",
"-fno-discard-value-names", "-fno-discard-value-names",
"-fno-exceptions", "-fno-exceptions",
"-fno-rtti", "-fno-rtti",
match env::var("PROFILE").as_deref() {
Ok("debug") => "-O0",
Ok("release") => "-O3",
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
},
"-emit-llvm", "-emit-llvm",
"-S", "-S",
"-Wall", "-Wall",
"-Wextra", "-Wextra",
"-o", "-Werror=return-type",
"-",
"-I", "-I",
irrt_dir.to_str().unwrap(), irrt_dir.to_str().unwrap(),
irrt_cpp_path.to_str().unwrap(), "-o",
"-",
]; ];
match env::var("PROFILE").as_deref() { println!("cargo:rerun-if-changed={}", out_dir.to_str().unwrap());
Ok("debug") => {
flags.push("-O0");
flags.push("-DIRRT_DEBUG_ASSERT");
}
Ok("release") => {
flags.push("-O3");
}
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
}
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
// Compile IRRT and capture the LLVM IR output
let output = Command::new("clang-irrt") let output = Command::new("clang-irrt")
.args(flags) .args(flags)
.output() .output()
@ -66,17 +53,11 @@ fn main() {
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n"); let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
let mut filtered_output = String::with_capacity(output.len()); let mut filtered_output = String::with_capacity(output.len());
// Filter out irrelevant IR // (?ms:^define.*?\}$) to capture `define` blocks
// // (?m:^declare.*?$) to capture `declare` blocks
// Regex: // (?m:^%.+?=\s*type\s*\{.+?\}$) to capture `type` declarations
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks let regex_filter =
// - `(?m:^declare.*?$)` captures LLVM `declare` lines Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)").unwrap();
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
// - `(?m:^@.+?=.+$)` captures global constants
let regex_filter = Regex::new(
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
)
.unwrap();
for f in regex_filter.captures_iter(&output) { for f in regex_filter.captures_iter(&output) {
assert_eq!(f.len(), 1); assert_eq!(f.len(), 1);
filtered_output.push_str(&f[0]); filtered_output.push_str(&f[0]);
@ -87,14 +68,10 @@ fn main() {
.unwrap() .unwrap()
.replace_all(&filtered_output, ""); .replace_all(&filtered_output, "");
// For debugging println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT");
// Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated if env::var("DEBUG_DUMP_IRRT").is_ok() {
const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
if env::var(DEBUG_DUMP_IRRT).is_ok() {
let mut file = File::create(out_dir.join("irrt.ll")).unwrap(); let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
file.write_all(output.as_bytes()).unwrap(); file.write_all(output.as_bytes()).unwrap();
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap(); let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
file.write_all(filtered_output.as_bytes()).unwrap(); file.write_all(filtered_output.as_bytes()).unwrap();
} }
@ -108,3 +85,50 @@ fn main() {
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
assert!(llvm_as.wait().unwrap().success()); assert!(llvm_as.wait().unwrap().success());
} }
fn compile_irrt_test(irrt_dir: &Path, out_dir: &Path) {
let irrt_test_cpp_path = irrt_dir.join("irrt_test.cpp");
let exe_path = out_dir.join("irrt_test.out");
let flags: &[&str] = &[
irrt_test_cpp_path.to_str().unwrap(),
"-x",
"c++",
"-I",
irrt_dir.to_str().unwrap(),
"-g",
"-fno-discard-value-names",
"-O0",
"-Wall",
"-Wextra",
"-Werror=return-type",
"-lm", // for `tgamma()`, `lgamma()`
"-o",
exe_path.to_str().unwrap(),
];
Command::new("clang-irrt-test")
.args(flags)
.output()
.map(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
o
})
.unwrap();
println!("cargo:rerun-if-changed={}", out_dir.to_str().unwrap());
}
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir);
let irrt_dir = Path::new("./irrt");
compile_irrt(irrt_dir, out_dir);
// https://github.com/rust-lang/cargo/issues/2549
// `cargo test -F test` to also build `irrt_test.cpp
if cfg!(feature = "test") {
compile_irrt_test(irrt_dir, out_dir);
}
}

View File

@ -1,6 +1,5 @@
#include "irrt/exception.hpp" #include "irrt_everything.hpp"
#include "irrt/int_types.hpp"
#include "irrt/list.hpp" /*
#include "irrt/math.hpp" This file will be read by `clang-irrt` to conveniently produce LLVM IR for `nac3core/codegen`.
#include "irrt/ndarray.hpp" */
#include "irrt/slice.hpp"

437
nac3core/irrt/irrt.hpp Normal file
View File

@ -0,0 +1,437 @@
#ifndef IRRT_DONT_TYPEDEF_INTS
typedef _BitInt(8) int8_t;
typedef unsigned _BitInt(8) uint8_t;
typedef _BitInt(32) int32_t;
typedef unsigned _BitInt(32) uint32_t;
typedef _BitInt(64) int64_t;
typedef unsigned _BitInt(64) uint64_t;
#endif
// NDArray indices are always `uint32_t`.
typedef uint32_t NDIndex;
// The type of an index or a value describing the length of a range/slice is
// always `int32_t`.
typedef int32_t SliceIndex;
template <typename T>
static T max(T a, T b) {
return a > b ? a : b;
}
template <typename T>
static T min(T a, T b) {
return a > b ? b : a;
}
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template <typename T>
static T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
template <typename SizeT>
static SizeT __nac3_ndarray_calc_size_impl(
const SizeT *list_data,
SizeT list_len,
SizeT begin_idx,
SizeT end_idx
) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template <typename SizeT>
static void __nac3_ndarray_calc_nd_indices_impl(
SizeT index,
const SizeT *dims,
SizeT num_dims,
NDIndex *idxs
) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template <typename SizeT>
static SizeT __nac3_ndarray_flatten_index_impl(
const SizeT *dims,
SizeT num_dims,
const NDIndex *indices,
SizeT num_indices
) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template <typename SizeT>
static void __nac3_ndarray_calc_broadcast_impl(
const SizeT *lhs_dims,
SizeT lhs_ndims,
const SizeT *rhs_dims,
SizeT rhs_ndims,
SizeT *out_dims
) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template <typename SizeT>
static void __nac3_ndarray_calc_broadcast_idx_impl(
const SizeT *src_dims,
SizeT src_ndims,
const NDIndex *in_idx,
NDIndex *out_idx
) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
template<typename SizeT>
static void __nac3_ndarray_strides_from_shape_impl(
SizeT ndims,
SizeT *shape,
SizeT *dst_strides
) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndims; i++) {
int dim_i = ndims - i - 1;
dst_strides[dim_i] = stride_product;
stride_product *= shape[dim_i];
}
}
extern "C" {
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) {\
return __nac3_int_exp_impl(base, exp);\
}
DEF_nac3_int_exp_(int32_t)
DEF_nac3_int_exp_(int64_t)
DEF_nac3_int_exp_(uint32_t)
DEF_nac3_int_exp_(uint64_t)
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(
const SliceIndex start,
const SliceIndex end,
const SliceIndex step
) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(
SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
uint8_t *dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
uint8_t *src_arr,
SliceIndex src_arr_len,
const SliceIndex size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(
dest_arr + dest_start * size,
src_arr + src_start * size,
src_len * size
);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(
dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr)
&& !(
max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end)
);
if (need_alloca) {
uint8_t *tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (;
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
uint32_t __nac3_ndarray_calc_size(
const uint32_t *list_data,
uint32_t list_len,
uint32_t begin_idx,
uint32_t end_idx
) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t __nac3_ndarray_calc_size64(
const uint64_t *list_data,
uint64_t list_len,
uint64_t begin_idx,
uint64_t end_idx
) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(
uint32_t index,
const uint32_t* dims,
uint32_t num_dims,
NDIndex* idxs
) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(
uint64_t index,
const uint64_t* dims,
uint64_t num_dims,
NDIndex* idxs
) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t __nac3_ndarray_flatten_index(
const uint32_t* dims,
uint32_t num_dims,
const NDIndex* indices,
uint32_t num_indices
) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t __nac3_ndarray_flatten_index64(
const uint64_t* dims,
uint64_t num_dims,
const NDIndex* indices,
uint64_t num_indices
) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(
const uint32_t *lhs_dims,
uint32_t lhs_ndims,
const uint32_t *rhs_dims,
uint32_t rhs_ndims,
uint32_t *out_dims
) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(
const uint64_t *lhs_dims,
uint64_t lhs_ndims,
const uint64_t *rhs_dims,
uint64_t rhs_ndims,
uint64_t *out_dims
) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(
const uint32_t *src_dims,
uint32_t src_ndims,
const NDIndex *in_idx,
NDIndex *out_idx
) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(
const uint64_t *src_dims,
uint64_t src_ndims,
const NDIndex *in_idx,
NDIndex *out_idx
) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_strides_from_shape(uint32_t ndims, uint32_t* shape, uint32_t* dst_strides) {
__nac3_ndarray_strides_from_shape_impl(ndims, shape, dst_strides);
}
void __nac3_ndarray_strides_from_shape64(uint64_t ndims, uint64_t* shape, uint64_t* dst_strides) {
__nac3_ndarray_strides_from_shape_impl(ndims, shape, dst_strides);
}
}

View File

@ -1,9 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
template<typename SizeT>
struct CSlice {
uint8_t* base;
SizeT len;
};

View File

@ -1,25 +0,0 @@
#pragma once
// Set in nac3core/build.rs
#ifdef IRRT_DEBUG_ASSERT
#define IRRT_DEBUG_ASSERT_BOOL true
#else
#define IRRT_DEBUG_ASSERT_BOOL false
#endif
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \
raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
#define debug_assert_eq(SizeT, lhs, rhs) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if ((lhs) != (rhs)) { \
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
} \
}
#define debug_assert(SizeT, expr) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if (!(expr)) { \
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
} \
}

View File

@ -1,82 +0,0 @@
#pragma once
#include "irrt/cslice.hpp"
#include "irrt/int_types.hpp"
/**
* @brief The int type of ARTIQ exception IDs.
*/
typedef int32_t ExceptionId;
/*
* Set of exceptions C++ IRRT can use.
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
*/
extern "C" {
ExceptionId EXN_INDEX_ERROR;
ExceptionId EXN_VALUE_ERROR;
ExceptionId EXN_ASSERTION_ERROR;
ExceptionId EXN_TYPE_ERROR;
}
/**
* @brief Extern function to `__nac3_raise`
*
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
*/
extern "C" void __nac3_raise(void* err);
namespace {
/**
* @brief NAC3's Exception struct
*/
template<typename SizeT>
struct Exception {
ExceptionId id;
CSlice<SizeT> filename;
int32_t line;
int32_t column;
CSlice<SizeT> function;
CSlice<SizeT> msg;
int64_t params[3];
};
constexpr int64_t NO_PARAM = 0;
template<typename SizeT>
void _raise_exception_helper(ExceptionId id,
const char* filename,
int32_t line,
const char* function,
const char* msg,
int64_t param0,
int64_t param1,
int64_t param2) {
Exception<SizeT> e = {
.id = id,
.filename = {.base = reinterpret_cast<const uint8_t*>(filename), .len = __builtin_strlen(filename)},
.line = line,
.column = 0,
.function = {.base = reinterpret_cast<const uint8_t*>(function), .len = __builtin_strlen(function)},
.msg = {.base = reinterpret_cast<const uint8_t*>(msg), .len = __builtin_strlen(msg)},
};
e.params[0] = param0;
e.params[1] = param1;
e.params[2] = param2;
__nac3_raise(reinterpret_cast<void*>(&e));
__builtin_unreachable();
}
/**
* @brief Raise an exception with location details (location in the IRRT source files).
* @param SizeT The runtime `size_t` type.
* @param id The ID of the exception to raise.
* @param msg A global constant C-string of the error message.
*
* `param0` to `param2` are optional format arguments of `msg`. They should be set to
* `NO_PARAM` to indicate they are unused.
*/
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
} // namespace

View File

@ -1,22 +0,0 @@
#pragma once
#if __STDC_VERSION__ >= 202000
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#else
using int8_t = _ExtInt(8);
using uint8_t = unsigned _ExtInt(8);
using int32_t = _ExtInt(32);
using uint32_t = unsigned _ExtInt(32);
using int64_t = _ExtInt(64);
using uint64_t = unsigned _ExtInt(64);
#endif
// NDArray indices are always `uint32_t`.
using NDIndex = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;

View File

@ -1,75 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
#include "irrt/math_util.hpp"
extern "C" {
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
uint8_t* dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
uint8_t* src_arr,
SliceIndex src_arr_len,
const SliceIndex size) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0)
return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca = (dest_arr == src_arr)
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca) {
uint8_t* tmp = reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
} // extern "C"

View File

@ -1,93 +0,0 @@
#pragma once
namespace {
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template<typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
} // namespace
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp); \
}
extern "C" {
// Putting semicolons here to make clang-format not reformat this into
// a stair shape.
DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t);
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
}

View File

@ -1,13 +0,0 @@
#pragma once
namespace {
template<typename T>
const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template<typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
} // namespace

View File

@ -1,144 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims,
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C" {
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims,
const uint32_t* rhs_dims,
uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims,
uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
}

View File

@ -1,28 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
extern "C" {
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
}

View File

@ -0,0 +1,216 @@
#pragma once
#include "irrt_utils.hpp"
#include "irrt_typedefs.hpp"
/*
This header contains IRRT implementations
that do not deserved to be categorized (e.g., into numpy, etc.)
Check out other *.hpp files before including them here!!
*/
// The type of an index or a value describing the length of a range/slice is
// always `int32_t`.
namespace {
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template <typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
}
extern "C" {
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) {\
return __nac3_int_exp_impl(base, exp);\
}
DEF_nac3_int_exp_(int32_t)
DEF_nac3_int_exp_(int64_t)
DEF_nac3_int_exp_(uint32_t)
DEF_nac3_int_exp_(uint64_t)
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(
const SliceIndex start,
const SliceIndex end,
const SliceIndex step
) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(
SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
uint8_t *dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
uint8_t *src_arr,
SliceIndex src_arr_len,
const SliceIndex size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(
dest_arr + dest_start * size,
src_arr + src_start * size,
src_len * size
);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(
dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr)
&& !(
max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end)
);
if (need_alloca) {
uint8_t *tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (;
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
}

View File

@ -0,0 +1,14 @@
#pragma once
#include "irrt_utils.hpp"
#include "irrt_typedefs.hpp"
#include "irrt_basic.hpp"
#include "irrt_slice.hpp"
#include "irrt_numpy_ndarray.hpp"
/*
All IRRT implementations.
We don't have any pre-compiled objects, so we are writing all implementations in headers and
concatenate them with `#include` into one massive source file that contains all the IRRT stuff.
*/

View File

@ -0,0 +1,466 @@
#pragma once
#include "irrt_utils.hpp"
#include "irrt_typedefs.hpp"
#include "irrt_slice.hpp"
/*
NDArray-related implementations.
`*/
// NDArray indices are always `uint32_t`.
using NDIndex = uint32_t;
namespace {
namespace ndarray_util {
template <typename SizeT>
static void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
for (int32_t i = 0; i < ndims; i++) {
int32_t dim_i = ndims - i - 1;
int32_t dim = shape[dim_i];
indices[dim_i] = nth % dim;
nth /= dim;
}
}
// Compute the strides of an ndarray given an ndarray `shape`
// and assuming that the ndarray is *fully C-contagious*.
//
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
template <typename SizeT>
static void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndims; i++) {
int dim_i = ndims - i - 1;
dst_strides[dim_i] = stride_product * itemsize;
stride_product *= shape[dim_i];
}
}
// Compute the size/# of elements of an ndarray given its shape
template <typename SizeT>
static SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
SizeT size = 1;
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) size *= shape[dim_i];
return size;
}
template <typename SizeT>
static bool can_broadcast_shape_to(
const SizeT target_ndims,
const SizeT *target_shape,
const SizeT src_ndims,
const SizeT *src_shape
) {
/*
// See https://numpy.org/doc/stable/user/basics.broadcasting.html
This function handles this example:
```
Image (3d array): 256 x 256 x 3
Scale (1d array): 3
Result (3d array): 256 x 256 x 3
```
Other interesting examples to consider:
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
- `can_broadcast_shape_to([3], [3, 1]) == false`
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
In cases when the shapes contain zero(es):
- `can_broadcast_shape_to([0], [1]) == true`
- `can_broadcast_shape_to([0], [2]) == false`
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
*/
// This is essentially doing the following in Python:
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
SizeT target_dim_i = target_ndims - i - 1;
SizeT src_dim_i = src_ndims - i - 1;
bool target_dim_exists = target_dim_i >= 0;
bool src_dim_exists = src_dim_i >= 0;
SizeT target_dim = target_dim_exists ? target_shape[target_dim_i] : 1;
SizeT src_dim = src_dim_exists ? src_shape[src_dim_i] : 1;
bool ok = src_dim == 1 || target_dim == src_dim;
if (!ok) return false;
}
return true;
}
}
typedef uint8_t NDSliceType;
extern "C" {
const NDSliceType INPUT_SLICE_TYPE_INDEX = 0;
const NDSliceType INPUT_SLICE_TYPE_SLICE = 1;
}
struct NDSlice {
// A poor-man's `std::variant<int, UserRange>`
NDSliceType type;
/*
if type == INPUT_SLICE_TYPE_INDEX => `slice` points to a single `SizeT`
if type == INPUT_SLICE_TYPE_SLICE => `slice` points to a single `UserRange`
*/
uint8_t *slice;
};
namespace ndarray_util {
template<typename SizeT>
SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_slices, const NDSlice *slices) {
irrt_assert(num_slices <= ndims);
SizeT final_ndims = ndims;
for (SizeT i = 0; i < num_slices; i++) {
if (slices[i].type == INPUT_SLICE_TYPE_INDEX) {
final_ndims--; // An integer slice demotes the rank by 1
}
}
return final_ndims;
}
}
template <typename SizeT>
struct NDArrayIndicesIter {
SizeT ndims;
const SizeT *shape;
SizeT *indices;
void set_indices_zero() {
__builtin_memset(indices, 0, sizeof(SizeT) * ndims);
}
void next() {
for (SizeT i = 0; i < ndims; i++) {
SizeT dim_i = ndims - i - 1;
indices[dim_i]++;
if (indices[dim_i] < shape[dim_i]) {
break;
} else {
indices[dim_i] = 0;
}
}
}
};
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
//
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
//
// Some resources you might find helpful:
// - The official numpy implementations:
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
// - On strides (about reshaping, slicing, C-contagiousness, etc)
// - https://ajcr.net/stride-guide-part-1/.
// - https://ajcr.net/stride-guide-part-2/.
// - https://ajcr.net/stride-guide-part-3/.
template <typename SizeT>
struct NDArray {
// The underlying data this `ndarray` is pointing to.
//
// NOTE: Formally this should be of type `void *`, but clang
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
// so we will put `uint8_t *` here for clarity.
uint8_t *data;
// The number of bytes of a single element in `data`.
//
// The `SizeT` is treated as `unsigned`.
SizeT itemsize;
// The number of dimensions of this shape.
//
// The `SizeT` is treated as `unsigned`.
SizeT ndims;
// Array shape, with length equal to `ndims`.
//
// The `SizeT` is treated as `unsigned`.
//
// NOTE: `shape` can contain 0.
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
SizeT *shape;
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
//
// The `SizeT` is treated as `signed`.
//
// NOTE: `strides` can have negative numbers.
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
SizeT *strides;
// Calculate the size/# of elements of an `ndarray`.
// This function corresponds to `np.size(<ndarray>)` or `ndarray.size`
SizeT size() {
return ndarray_util::calc_size_from_shape(ndims, shape);
}
// Calculate the number of bytes of its content of an `ndarray` *in its view*.
// This function corresponds to `ndarray.nbytes`
SizeT nbytes() {
return this->size() * itemsize;
}
void set_pelement_value(uint8_t* pelement, const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, itemsize);
}
uint8_t* get_pelement_by_indices(const SizeT *indices) {
uint8_t* element = data;
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
element += indices[dim_i] * strides[dim_i];
return element;
}
uint8_t* get_nth_pelement(SizeT nth) {
irrt_assert(0 <= nth);
irrt_assert(nth < this->size());
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims);
ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth);
return get_pelement_by_indices(indices);
}
// Get pointer to the first element of this ndarray, assuming
// `this->size() > 0`, i.e., not "degenerate" due to zeroes in `this->shape`)
//
// This is particularly useful for when the ndarray is just containing a single scalar.
uint8_t* get_first_pelement() {
irrt_assert(this->size() > 0);
return this->data; // ...It is simply `this->data`
}
// Is the given `indices` valid/in-bounds?
bool in_bounds(const SizeT *indices) {
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) {
bool dim_ok = indices[dim_i] < shape[dim_i];
if (!dim_ok) return false;
}
return true;
}
// Fill the ndarray with a value
void fill_generic(const uint8_t* pvalue) {
NDArrayIndicesIter<SizeT> iter;
iter.ndims = this->ndims;
iter.shape = this->shape;
iter.indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndims);
iter.set_indices_zero();
for (SizeT i = 0; i < this->size(); i++, iter.next()) {
uint8_t* pelement = get_pelement_by_indices(iter.indices);
set_pelement_value(pelement, pvalue);
}
}
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
void set_strides_by_shape() {
ndarray_util::set_strides_by_shape(itemsize, ndims, strides, shape);
}
// https://numpy.org/doc/stable/reference/generated/numpy.eye.html
void set_to_eye(SizeT k, const uint8_t* zero_pvalue, const uint8_t* one_pvalue) {
__builtin_assume(ndims == 2);
// TODO: Better implementation
fill_generic(zero_pvalue);
for (SizeT i = 0; i < min(shape[0], shape[1]); i++) {
SizeT row = i;
SizeT col = i + k;
SizeT indices[2] = { row, col };
if (!in_bounds(indices)) continue;
uint8_t* pelement = get_pelement_by_indices(indices);
set_pelement_value(pelement, one_pvalue);
}
}
// To support numpy complex slices (e.g., `my_array[:50:2,4,:2:-1]`)
//
// Things assumed by this function:
// - `dst_ndarray` is allocated by the caller
// - `dst_ndarray.ndims` has the correct value (according to `ndarray_util::deduce_ndims_after_slicing`).
// - ... and `dst_ndarray.shape` and `dst_ndarray.strides` have been allocated by the caller as well
//
// Other notes:
// - `dst_ndarray->data` does not have to be set, it will be derived.
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->itemsize`
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
void slice(SizeT num_ndslices, NDSlice* ndslices, NDArray<SizeT>* dst_ndarray) {
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
irrt_assert(dst_ndarray->ndims == ndarray_util::deduce_ndims_after_slicing(this->ndims, num_ndslices, ndslices));
dst_ndarray->data = this->data;
SizeT this_axis = 0;
SizeT dst_axis = 0;
for (SizeT i = 0; i < num_ndslices; i++) {
NDSlice *ndslice = &ndslices[i];
if (ndslice->type == INPUT_SLICE_TYPE_INDEX) {
// Handle when the ndslice is just a single (possibly negative) integer
// e.g., `my_array[::2, -5, ::-1]`
// ^^------ like this
SizeT index_user = *((SizeT*) ndslice->slice);
SizeT index = resolve_index_in_length(this->shape[this_axis], index_user);
dst_ndarray->data += index * this->strides[this_axis]; // Add offset
// Next
this_axis++;
} else if (ndslice->type == INPUT_SLICE_TYPE_SLICE) {
// Handle when the ndslice is a slice (represented by UserSlice in IRRT)
// e.g., `my_array[::2, -5, ::-1]`
// ^^^------^^^^----- like these
UserSlice<SizeT>* user_slice = (UserSlice<SizeT>*) ndslice->slice;
Slice<SizeT> slice = user_slice->indices(this->shape[this_axis]); // To resolve negative indices and other funny stuff written by the user
// NOTE: There is no need to write special code to handle negative steps/strides.
// This simple implementation meticulously handles both positive and negative steps/strides.
// Check out the tinynumpy and IRRT's test cases if you are not convinced.
dst_ndarray->data += slice.start * this->strides[this_axis]; // Add offset (NOTE: no need to `* itemsize`, strides count in # of bytes)
dst_ndarray->strides[dst_axis] = slice.step * this->strides[this_axis]; // Determine stride
dst_ndarray->shape[dst_axis] = slice.len(); // Determine shape dimension
// Next
dst_axis++;
this_axis++;
} else {
__builtin_unreachable();
}
}
irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation
}
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
// Assumptions:
// - `this` has to be fully initialized.
// - `dst_ndarray->ndims` has to be set.
// - `dst_ndarray->shape` has to be set, this determines the shape `this` broadcasts to.
//
// Other notes:
// - `dst_ndarray->data` does not have to be set, it will be set to `this->data`.
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->data`.
// - `dst_ndarray->strides` does not have to be set, it will be overwritten.
//
// Cautions:
// ```
// xs = np.zeros((4,))
// ys = np.zero((4, 1))
// ys[:] = xs # ok
//
// xs = np.zeros((1, 4))
// ys = np.zero((4,))
// ys[:] = xs # allowed
// # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule.
// # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744
// # This implementation will NOT support this assignment.
// ```
void broadcast_to(NDArray<SizeT>* dst_ndarray) {
dst_ndarray->data = this->data;
dst_ndarray->itemsize = this->itemsize;
irrt_assert(
ndarray_util::can_broadcast_shape_to(
dst_ndarray->ndims,
dst_ndarray->shape,
this->ndims,
this->shape
)
);
SizeT stride_product = 1;
for (SizeT i = 0; i < max(this->ndims, dst_ndarray->ndims); i++) {
SizeT this_dim_i = this->ndims - i - 1;
SizeT dst_dim_i = dst_ndarray->ndims - i - 1;
bool this_dim_exists = this_dim_i >= 0;
bool dst_dim_exists = dst_dim_i >= 0;
// TODO: Explain how this works
bool c1 = this_dim_exists && this->shape[this_dim_i] == 1;
bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1;
if (!this_dim_exists || (c1 && c2)) {
dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place
} else {
dst_ndarray->strides[dst_dim_i] = stride_product * this->itemsize;
stride_product *= this->shape[this_dim_i]; // NOTE: this_dim_exist must be true here.
}
}
}
// Simulates `this_ndarray[:] = src_ndarray`, with automatic broadcasting.
// Caution on https://github.com/numpy/numpy/issues/21744
// Also see `NDArray::broadcast_to`
void assign_with(NDArray<SizeT>* src_ndarray) {
irrt_assert(
ndarray_util::can_broadcast_shape_to(
this->ndims,
this->shape,
src_ndarray->ndims,
src_ndarray->shape
)
);
// Broadcast the `src_ndarray` to make the reading process *much* easier
SizeT* broadcasted_src_ndarray_strides = __builtin_alloca(sizeof(SizeT) * this->ndims); // Remember to allocate strides beforehand
NDArray<SizeT> broadcasted_src_ndarray = {
.ndims = this->ndims,
.shape = this->shape,
.strides = broadcasted_src_ndarray_strides
};
src_ndarray->broadcast_to(&broadcasted_src_ndarray);
// Using iter instead of `get_nth_pelement` because it is slightly faster
SizeT* indices = __builtin_alloca(sizeof(SizeT) * this->ndims);
auto iter = NDArrayIndicesIter<SizeT> {
.ndims = this->ndims,
.shape = this->shape,
.indices = indices
};
const SizeT this_size = this->size();
for (SizeT i = 0; i < this_size; i++, iter.next()) {
uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_pelement_by_indices(indices);
uint8_t* this_pelement = this->get_pelement_by_indices(indices);
this->set_pelement_value(src_pelement, src_pelement);
}
}
};
}
extern "C" {
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
return ndarray->size();
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
return ndarray->size();
}
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
ndarray->fill_generic(pvalue);
}
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
ndarray->fill_generic(pvalue);
}
// void __nac3_ndarray_slice(NDArray<int32_t>* ndarray, int32_t num_slices, NDSlice<int32_t> *slices, NDArray<int32_t> *dst_ndarray) {
// // ndarray->slice(num_slices, slices, dst_ndarray);
// }
}

View File

@ -0,0 +1,80 @@
#pragma once
#include "irrt_utils.hpp"
#include "irrt_typedefs.hpp"
namespace {
// A proper slice in IRRT, all negative indices have be resolved to absolute values.
// Even though nac3core's slices are always `int32_t`, we will template slice anyway
// since this struct is used as a general utility.
template <typename T>
struct Slice {
T start;
T stop;
T step;
// The length/The number of elements of the slice if it were a range,
// i.e., the value of `len(range(this->start, this->stop, this->end))`
T len() {
T diff = stop - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
};
template<typename T>
T resolve_index_in_length(T length, T index) {
irrt_assert(length >= 0);
if (index < 0) {
// Remember that index is negative, so do a plus here
return max(length + index, 0);
} else {
return min(length, index);
}
}
// NOTE: using a bitfield for the `*_defined` is better, at the
// cost of a more annoying implementation in nac3core inkwell
template <typename T>
struct UserSlice {
uint8_t start_defined;
T start;
uint8_t stop_defined;
T stop;
uint8_t step_defined;
T step;
// Like Python's `slice(start, stop, step).indices(length)`
Slice<T> indices(T length) {
// NOTE: This function implements Python's `slice.indices` *FAITHFULLY*.
// SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546
irrt_assert(length >= 0);
irrt_assert(!step_defined || step != 0); // step_defined -> step != 0; step cannot be zero if specified by user
Slice<T> result;
result.step = step_defined ? step : 1;
bool step_is_negative = result.step < 0;
if (start_defined) {
result.start = resolve_index_in_length(length, start);
} else {
result.start = step_is_negative ? length - 1 : 0;
}
if (stop_defined) {
result.stop = resolve_index_in_length(length, stop);
} else {
result.stop = step_is_negative ? -1 : length;
}
return result;
}
};
}

648
nac3core/irrt/irrt_test.cpp Normal file
View File

@ -0,0 +1,648 @@
// This file will be compiled like a real C++ program,
// and we do have the luxury to use the standard libraries.
// That is if the nix flakes do not have issues... especially on msys2...
#include <cstdint>
#include <cstdio>
#include <cstdlib>
// Set `IRRT_DONT_TYPEDEF_INTS` because `cstdint` defines them
#define IRRT_DONT_TYPEDEF_INTS
#include "irrt_everything.hpp"
void test_fail() {
printf("[!] Test failed\n");
exit(1);
}
void __begin_test(const char* function_name, const char* file, int line) {
printf("######### Running %s @ %s:%d\n", function_name, file, line);
}
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
template <typename T>
void debug_print_array(const char* format, int len, T* as) {
printf("[");
for (int i = 0; i < len; i++) {
if (i != 0) printf(", ");
printf(format, as[i]);
}
printf("]");
}
template <typename T>
void assert_arrays_match(const char* label, const char* format, int len, T* expected, T* got) {
if (!arrays_match(len, expected, got)) {
printf(">>>>>>> %s\n", label);
printf(" Expecting = ");
debug_print_array(format, len, expected);
printf("\n");
printf(" Got = ");
debug_print_array(format, len, got);
printf("\n");
test_fail();
}
}
template <typename T>
void assert_values_match(const char* label, const char* format, T expected, T got) {
if (expected != got) {
printf(">>>>>>> %s\n", label);
printf(" Expecting = ");
printf(format, expected);
printf("\n");
printf(" Got = ");
printf(format, got);
printf("\n");
test_fail();
}
}
void print_repeated(const char *str, int count) {
for (int i = 0; i < count; i++) {
printf("%s", str);
}
}
template<typename SizeT, typename ElementT>
void __print_ndarray_aux(const char *format, bool first, bool last, SizeT* cursor, SizeT depth, NDArray<SizeT>* ndarray) {
// A really lazy recursive implementation
// Add left padding unless its the first entry (since there would be "[[[" before it)
if (!first) {
print_repeated(" ", depth);
}
const SizeT dim = ndarray->shape[depth];
if (depth + 1 == ndarray->ndims) {
// Recursed down to last dimension, print the values in a nice list
printf("[");
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
for (SizeT i = 0; i < dim; i++) {
ndarray_util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor);
ElementT* pelement = (ElementT*) ndarray->get_pelement_by_indices(indices);
ElementT element = *pelement;
if (i != 0) printf(", "); // List delimiter
printf(format, element);
printf("(@");
debug_print_array("%d", ndarray->ndims, indices);
printf(")");
(*cursor)++;
}
printf("]");
} else {
printf("[");
for (SizeT i = 0; i < ndarray->shape[depth]; i++) {
__print_ndarray_aux<SizeT, ElementT>(
format,
i == 0, // first?
i + 1 == dim, // last?
cursor,
depth + 1,
ndarray
);
}
printf("]");
}
// Add newline unless its the last entry (since there will be "]]]" after it)
if (!last) {
print_repeated("\n", depth);
}
}
template<typename SizeT, typename ElementT>
void print_ndarray(const char *format, NDArray<SizeT>* ndarray) {
if (ndarray->ndims == 0) {
printf("<empty ndarray>");
} else {
SizeT cursor = 0;
__print_ndarray_aux<SizeT, ElementT>(format, true, true, &cursor, 0, ndarray);
}
printf("\n");
}
void test_calc_size_from_shape_normal() {
// Test shapes with normal values
BEGIN_TEST();
int32_t shape[4] = { 2, 3, 5, 7 };
assert_values_match("size", "%d", 210, ndarray_util::calc_size_from_shape<int32_t>(4, shape));
}
void test_calc_size_from_shape_has_zero() {
// Test shapes with 0 in them
BEGIN_TEST();
int32_t shape[4] = { 2, 0, 5, 7 };
assert_values_match("size", "%d", 0, ndarray_util::calc_size_from_shape<int32_t>(4, shape));
}
void test_set_strides_by_shape() {
// Test `set_strides_by_shape()`
BEGIN_TEST();
int32_t shape[4] = { 99, 3, 5, 7 };
int32_t strides[4] = { 0 };
ndarray_util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape);
int32_t expected_strides[4] = {
105 * sizeof(int32_t),
35 * sizeof(int32_t),
7 * sizeof(int32_t),
1 * sizeof(int32_t)
};
assert_arrays_match("strides", "%u", 4u, expected_strides, strides);
}
void test_ndarray_indices_iter_normal() {
// Test NDArrayIndicesIter normal behavior
BEGIN_TEST();
int32_t shape[3] = { 1, 2, 3 };
int32_t indices[3] = { 0, 0, 0 };
auto iter = NDArrayIndicesIter<int32_t> {
.ndims = 3,
.shape = shape,
.indices = indices
};
assert_arrays_match("indices #0", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 });
iter.next();
assert_arrays_match("indices #1", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 });
iter.next();
assert_arrays_match("indices #2", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 2 });
iter.next();
assert_arrays_match("indices #3", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 0 });
iter.next();
assert_arrays_match("indices #4", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 1 });
iter.next();
assert_arrays_match("indices #5", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 2 });
iter.next();
assert_arrays_match("indices #6", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 }); // Loops back
iter.next();
assert_arrays_match("indices #7", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 });
}
void test_ndarray_fill_generic() {
// Test ndarray fill_generic
BEGIN_TEST();
// Choose a type that's neither int32_t nor uint64_t (candidates of SizeT) to spice it up
// Also make all the octets non-zero, to see if `memcpy` in `fill_generic` is working perfectly.
uint16_t fill_value = 0xFACE;
uint16_t in_data[6] = { 100, 101, 102, 103, 104, 105 }; // Fill `data` with values that != `999`
int32_t in_itemsize = sizeof(uint16_t);
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = { 2, 3 };
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = in_itemsize,
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides,
};
ndarray.set_strides_by_shape();
ndarray.fill_generic((uint8_t*) &fill_value); // `fill_generic` here
uint16_t expected_data[6] = { fill_value, fill_value, fill_value, fill_value, fill_value, fill_value };
assert_arrays_match("data", "0x%hX", 6, expected_data, in_data);
}
void test_ndarray_set_to_eye() {
// Test `set_to_eye` behavior (helper function to implement `np.eye()`)
BEGIN_TEST();
double in_data[9] = { 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0 };
int32_t in_itemsize = sizeof(double);
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = { 3, 3 };
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = in_itemsize,
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides,
};
ndarray.set_strides_by_shape();
double zero = 0.0;
double one = 1.0;
ndarray.set_to_eye(1, (uint8_t*) &zero, (uint8_t*) &one);
assert_values_match("in_data[0]", "%f", 0.0, in_data[0]);
assert_values_match("in_data[1]", "%f", 1.0, in_data[1]);
assert_values_match("in_data[2]", "%f", 0.0, in_data[2]);
assert_values_match("in_data[3]", "%f", 0.0, in_data[3]);
assert_values_match("in_data[4]", "%f", 0.0, in_data[4]);
assert_values_match("in_data[5]", "%f", 1.0, in_data[5]);
assert_values_match("in_data[6]", "%f", 0.0, in_data[6]);
assert_values_match("in_data[7]", "%f", 0.0, in_data[7]);
assert_values_match("in_data[8]", "%f", 0.0, in_data[8]);
}
void test_slice_1() {
// Test `slice(5, None, None).indices(100) == slice(5, 100, 1)`
BEGIN_TEST();
UserSlice<int> user_slice = {
.start_defined = 1,
.start = 5,
.stop_defined = 0,
.step_defined = 0,
};
auto slice = user_slice.indices(100);
assert_values_match("start", "%d", 5, slice.start);
assert_values_match("stop", "%d", 100, slice.stop);
assert_values_match("step", "%d", 1, slice.step);
}
void test_slice_2() {
// Test `slice(400, 999, None).indices(100) == slice(100, 100, 1)`
BEGIN_TEST();
UserSlice<int> user_slice = {
.start_defined = 1,
.start = 400,
.stop_defined = 0,
.step_defined = 0,
};
auto slice = user_slice.indices(100);
assert_values_match("start", "%d", 100, slice.start);
assert_values_match("stop", "%d", 100, slice.stop);
assert_values_match("step", "%d", 1, slice.step);
}
void test_slice_3() {
// Test `slice(-10, -5, None).indices(100) == slice(90, 95, 1)`
BEGIN_TEST();
UserSlice<int> user_slice = {
.start_defined = 1,
.start = -10,
.stop_defined = 1,
.stop = -5,
.step_defined = 0,
};
auto slice = user_slice.indices(100);
assert_values_match("start", "%d", 90, slice.start);
assert_values_match("stop", "%d", 95, slice.stop);
assert_values_match("step", "%d", 1, slice.step);
}
void test_slice_4() {
// Test `slice(None, None, -5).indices(100) == (99, -1, -5)`
BEGIN_TEST();
UserSlice<int> user_slice = {
.start_defined = 0,
.stop_defined = 0,
.step_defined = 1,
.step = -5
};
auto slice = user_slice.indices(100);
assert_values_match("start", "%d", 99, slice.start);
assert_values_match("stop", "%d", -1, slice.stop);
assert_values_match("step", "%d", -5, slice.step);
}
void test_ndslice_1() {
/*
Reference Python code:
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[-2:, 1::2]
# array([[ 5., 7.],
# [ 9., 11.]])
assert dst_ndarray.shape == (2, 2)
assert dst_ndarray.strides == (32, 16)
assert dst_ndarray[0, 0] == 5.0
assert dst_ndarray[0, 1] == 7.0
assert dst_ndarray[1, 0] == 9.0
assert dst_ndarray[1, 1] == 11.0
```
*/
BEGIN_TEST();
double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
int32_t in_itemsize = sizeof(double);
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = { 3, 4 };
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = in_itemsize,
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides
};
ndarray.set_strides_by_shape();
// Destination ndarray
// As documented, ndims and shape & strides must be allocated and determined by the caller.
const int32_t dst_ndims = 2;
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
NDArray<int32_t> dst_ndarray = {
.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides
};
// Create the slice in `ndarray[-2::, 1::2]`
UserSlice<int32_t> user_slice_1 = {
.start_defined = 1,
.start = -2,
.stop_defined = 0,
.step_defined = 0
};
UserSlice<int32_t> user_slice_2 = {
.start_defined = 1,
.start = 1,
.stop_defined = 0,
.step_defined = 1,
.step = 2
};
const int32_t num_ndslices = 2;
NDSlice ndslices[num_ndslices] = {
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_1 },
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 }
};
ndarray.slice(num_ndslices, ndslices, &dst_ndarray);
int32_t expected_shape[dst_ndims] = { 2, 2 };
int32_t expected_strides[dst_ndims] = { 32, 16 };
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0, 0 })));
assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0, 1 })));
assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1, 0 })));
assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1, 1 })));
}
void test_ndslice_2() {
/*
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[2, ::-2]
# array([11., 9.])
assert dst_ndarray.shape == (2,)
assert dst_ndarray.strides == (-16,)
assert dst_ndarray[0] == 11.0
assert dst_ndarray[1] == 9.0
dst_ndarray[1, 0] == 99 # If you write to `dst_ndarray`
assert ndarray[1, 3] == 99 # `ndarray` also updates!!
```
*/
BEGIN_TEST();
double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
int32_t in_itemsize = sizeof(double);
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = { 3, 4 };
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = in_itemsize,
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides
};
ndarray.set_strides_by_shape();
// Destination ndarray
// As documented, ndims and shape & strides must be allocated and determined by the caller.
const int32_t dst_ndims = 1;
int32_t dst_shape[dst_ndims] = {999}; // Empty values
int32_t dst_strides[dst_ndims] = {999}; // Empty values
NDArray<int32_t> dst_ndarray = {
.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides
};
// Create the slice in `ndarray[2, ::-2]`
int32_t user_slice_1 = 2;
UserSlice<int32_t> user_slice_2 = {
.start_defined = 0,
.stop_defined = 0,
.step_defined = 1,
.step = -2
};
const int32_t num_ndslices = 2;
NDSlice ndslices[num_ndslices] = {
{ .type = INPUT_SLICE_TYPE_INDEX, .slice = (uint8_t*) &user_slice_1 },
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 }
};
ndarray.slice(num_ndslices, ndslices, &dst_ndarray);
int32_t expected_shape[dst_ndims] = { 2 };
int32_t expected_strides[dst_ndims] = { -16 };
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
// [5.0, 3.0]
assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0 })));
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1 })));
}
void test_can_broadcast_shape() {
BEGIN_TEST();
assert_values_match(
"can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 5, (int32_t[]) { 1, 1, 1, 1, 3 })
);
assert_values_match(
"can_broadcast_shape_to([3], [3, 1]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 2, (int32_t[]) { 3, 1 }));
assert_values_match(
"can_broadcast_shape_to([3], [3]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 1, (int32_t[]) { 3 }));
assert_values_match(
"can_broadcast_shape_to([1], [3]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 1 }, 1, (int32_t[]) { 3 }));
assert_values_match(
"can_broadcast_shape_to([1], [1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 1 }, 1, (int32_t[]) { 1 }));
assert_values_match(
"can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 3, (int32_t[]) { 256, 1, 3 })
);
assert_values_match(
"can_broadcast_shape_to([256, 256, 3], [3]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 3 })
);
assert_values_match(
"can_broadcast_shape_to([256, 256, 3], [2]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 2 })
);
assert_values_match(
"can_broadcast_shape_to([256, 256, 3], [1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 1 })
);
// In cases when the shapes contain zero(es)
assert_values_match(
"can_broadcast_shape_to([0], [1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 0 }, 1, (int32_t[]) { 1 })
);
assert_values_match(
"can_broadcast_shape_to([0], [2]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 0 }, 1, (int32_t[]) { 2 })
);
assert_values_match(
"can_broadcast_shape_to([0, 4, 0, 0], [1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 1, (int32_t[]) { 1 })
);
assert_values_match(
"can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 4, (int32_t[]) { 1, 1, 1, 1 })
);
assert_values_match(
"can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 4, (int32_t[]) { 1, 4, 1, 1 })
);
assert_values_match(
"can_broadcast_shape_to([4, 3], [0, 3]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(2, (int32_t[]) { 4, 3 }, 2, (int32_t[]) { 0, 3 })
);
assert_values_match(
"can_broadcast_shape_to([4, 3], [0, 0]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(2, (int32_t[]) { 4, 3 }, 2, (int32_t[]) { 0, 0 })
);
}
void test_ndarray_broadcast_1() {
/*
# array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
# >>> [[19.9 29.9 39.9 49.9]]
#
# array = np.broadcast_to(array, (2, 3, 4))
# >>> [[[19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]]
# >>> [[19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]]]
#
# assery array.strides == (0, 0, 8)
*/
BEGIN_TEST();
double in_data[4] = { 19.9, 29.9, 39.9, 49.9 };
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = {1, 4};
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = sizeof(double),
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides
};
ndarray.set_strides_by_shape();
const int32_t dst_ndims = 3;
int32_t dst_shape[dst_ndims] = {2, 3, 4};
int32_t dst_strides[dst_ndims] = {};
NDArray<int32_t> dst_ndarray = {
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides
};
ndarray.broadcast_to(&dst_ndarray);
assert_arrays_match("dst_ndarray->strides", "%d", dst_ndims, (int32_t[]) { 0, 0, 8 }, dst_ndarray.strides);
assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 0})));
assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 1})));
assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 2})));
assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 3})));
assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 0})));
assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 1})));
assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 2})));
assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 3})));
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {1, 2, 3})));
}
int main() {
test_calc_size_from_shape_normal();
test_calc_size_from_shape_has_zero();
test_set_strides_by_shape();
test_ndarray_indices_iter_normal();
test_ndarray_fill_generic();
test_ndarray_set_to_eye();
test_slice_1();
test_slice_2();
test_slice_3();
test_slice_4();
test_ndslice_1();
test_ndslice_2();
test_can_broadcast_shape();
test_ndarray_broadcast_1();
return 0;
}

View File

@ -0,0 +1,14 @@
#pragma once
// This is made toggleable since `irrt_test.cpp` itself would include
// headers that define the `int_t` family.
#ifndef IRRT_DONT_TYPEDEF_INTS
typedef _BitInt(8) int8_t;
typedef unsigned _BitInt(8) uint8_t;
typedef _BitInt(32) int32_t;
typedef unsigned _BitInt(32) uint32_t;
typedef _BitInt(64) int64_t;
typedef unsigned _BitInt(64) uint64_t;
#endif
typedef int32_t SliceIndex;

View File

@ -0,0 +1,37 @@
#pragma once
#include "irrt_typedefs.hpp"
namespace {
template <typename T>
T max(T a, T b) {
return a > b ? a : b;
}
template <typename T>
T min(T a, T b) {
return a > b ? b : a;
}
template <typename T>
bool arrays_match(int len, T *as, T *bs) {
for (int i = 0; i < len; i++) {
if (as[i] != bs[i]) return false;
}
return true;
}
void irrt_panic() {
// Crash the program for now.
// TODO: Don't crash the program
// ... or at least produce a good message when doing testing IRRT
uint8_t* death = nullptr;
*death = 0; // TODO: address 0 on hardware might be writable?
}
// TODO: Make this a macro and allow it to be toggled on/off (e.g., debug vs release)
void irrt_assert(bool condition) {
if (!condition) irrt_panic();
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,16 +1,19 @@
use inkwell::{ use crate::codegen::{
context::Context, // irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
types::{AnyTypeEnum, ArrayType, BasicType, BasicTypeEnum, IntType, PointerType, StructType},
values::{ArrayValue, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace, IntPredicate,
};
use super::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin, llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator, CodeGenContext,
CodeGenerator,
}; };
use inkwell::context::Context;
use inkwell::types::{ArrayType, BasicType, StructType};
use inkwell::values::{ArrayValue, BasicValue, StructValue};
use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use itertools::Itertools;
/// A LLVM type that is used to represent a non-primitive type in NAC3. /// A LLVM type that is used to represent a non-primitive type in NAC3.
pub trait ProxyType<'ctx>: Into<Self::Base> { pub trait ProxyType<'ctx>: Into<Self::Base> {
@ -1207,25 +1210,27 @@ impl<'ctx> NDArrayType<'ctx> {
ctx: &'ctx Context, ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
) -> Self { ) -> Self {
let llvm_usize = generator.get_size_type(ctx); todo!()
// struct NDArray { num_dims: size_t, dims: size_t*, data: T* } // let llvm_usize = generator.get_size_type(ctx);
//
// * num_dims: Number of dimensions in the array
// * dims: Pointer to an array containing the size of each dimension
// * data: Pointer to an array containing the array data
let llvm_ndarray = ctx
.struct_type(
&[
llvm_usize.into(),
llvm_usize.ptr_type(AddressSpace::default()).into(),
dtype.ptr_type(AddressSpace::default()).into(),
],
false,
)
.ptr_type(AddressSpace::default());
NDArrayType::from_type(llvm_ndarray, llvm_usize) // // struct NDArray { num_dims: size_t, dims: size_t*, data: T* }
// //
// // * num_dims: Number of dimensions in the array
// // * dims: Pointer to an array containing the size of each dimension
// // * data: Pointer to an array containing the array data
// let llvm_ndarray = ctx
// .struct_type(
// &[
// llvm_usize.into(),
// llvm_usize.ptr_type(AddressSpace::default()).into(),
// dtype.ptr_type(AddressSpace::default()).into(),
// ],
// false,
// )
// .ptr_type(AddressSpace::default());
// NDArrayType::from_type(llvm_ndarray, llvm_usize)
} }
/// Creates an [`NDArrayType`] from a [`PointerType`]. /// Creates an [`NDArrayType`] from a [`PointerType`].
@ -1249,13 +1254,11 @@ impl<'ctx> NDArrayType<'ctx> {
/// Returns the element type of this `ndarray` type. /// Returns the element type of this `ndarray` type.
#[must_use] #[must_use]
pub fn element_type(&self) -> AnyTypeEnum<'ctx> { pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.as_base_type() self.as_base_type()
.get_element_type() .get_element_type()
.into_struct_type() .into_struct_type()
.get_field_type_at_index(2) .get_field_type_at_index(2)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap() .unwrap()
} }
} }
@ -1405,7 +1408,7 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
@ -1602,7 +1605,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
generator: &G, generator: &G,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None)) todo!()
// call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
} }
} }
@ -1660,33 +1664,34 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
indices: &Index, indices: &Index,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx); todo!()
// let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_elem_ty = indices // let indices_elem_ty = indices
.ptr_offset(ctx, generator, &llvm_usize.const_zero(), None) // .ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
.get_type() // .get_type()
.get_element_type(); // .get_element_type();
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { // let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}") // panic!("Expected list[int32] but got {indices_elem_ty}")
}; // };
assert_eq!( // assert_eq!(
indices_elem_ty.get_bit_width(), // indices_elem_ty.get_bit_width(),
32, // 32,
"Expected list[int32] but got list[int{}]", // "Expected list[int32] but got list[int{}]",
indices_elem_ty.get_bit_width() // indices_elem_ty.get_bit_width()
); // );
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); // let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
unsafe { // unsafe {
ctx.builder // ctx.builder
.build_in_bounds_gep( // .build_in_bounds_gep(
self.base_ptr(ctx, generator), // self.base_ptr(ctx, generator),
&[index], // &[index],
name.unwrap_or_default(), // name.unwrap_or_default(),
) // )
.unwrap() // .unwrap()
} // }
} }
fn ptr_offset<G: CodeGenerator + ?Sized>( fn ptr_offset<G: CodeGenerator + ?Sized>(
@ -1718,7 +1723,6 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(len, false), (len, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -1763,3 +1767,341 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
for NDArrayDataProxy<'ctx, '_> for NDArrayDataProxy<'ctx, '_>
{ {
} }
#[derive(Debug, Clone, Copy)]
pub struct StructField<'ctx> {
/// The GEP index of this struct field.
pub gep_index: u32,
/// Name of this struct field.
///
/// Used for generating names.
pub name: &'static str,
/// The type of this struct field.
pub ty: BasicTypeEnum<'ctx>,
}
pub struct StructFields<'ctx> {
/// Name of the struct.
///
/// Used for generating names.
pub name: &'static str,
/// All the [`StructField`]s of this struct.
///
/// **NOTE:** The index position of a [`StructField`]
/// matches the element's [`StructField::index`].
pub fields: Vec<StructField<'ctx>>,
}
struct StructFieldsBuilder<'ctx> {
gep_index_counter: u32,
/// Name of the struct to be built.
name: &'static str,
fields: Vec<StructField<'ctx>>,
}
impl<'ctx> StructField<'ctx> {
/// TODO: DOCUMENT ME
pub fn gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
struct_ptr: PointerValue<'ctx>,
) -> PointerValue<'ctx> {
let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to use i32 for GEP like that
unsafe {
ctx.builder
.build_in_bounds_gep(
struct_ptr,
&[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)],
self.name,
)
.unwrap()
}
}
/// TODO: DOCUMENT ME
pub fn load(
&self,
ctx: &CodeGenContext<'ctx, '_>,
struct_ptr: PointerValue<'ctx>,
) -> BasicValueEnum<'ctx> {
ctx.builder.build_load(self.gep(ctx, struct_ptr), self.name).unwrap()
}
/// TODO: DOCUMENT ME
pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, struct_ptr: PointerValue<'ctx>, value: V)
where
V: BasicValue<'ctx>,
{
ctx.builder.build_store(self.gep(ctx, struct_ptr), value).unwrap();
}
}
type IsInstanceError = String;
type IsInstanceResult = Result<(), IsInstanceError>;
pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult
where
A: BasicType<'ctx>,
B: BasicType<'ctx>,
{
let expected = expected.as_basic_type_enum();
let got = got.as_basic_type_enum();
// Put those logic into here,
// otherwise there is always a fallback reporting on any kind of mismatch
match (expected, got) {
(BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
if expected.get_bit_width() != got.get_bit_width() {
return Err(format!(
"Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
));
}
}
(expected, got) => {
if expected != got {
return Err(format!("Expected {expected}, got {got}"));
}
}
}
Ok(())
}
impl<'ctx> StructFields<'ctx> {
pub fn num_fields(&self) -> u32 {
self.fields.len() as u32
}
pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec();
ctx.struct_type(llvm_fields.as_slice(), false)
}
pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult {
// Check scrutinee's number of struct fields
if scrutinee.count_fields() != self.num_fields() {
return Err(format!(
"Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}",
struct_name = self.name,
expected_count = self.num_fields(),
got_count = scrutinee.count_fields(),
));
}
// Check the scrutinee's field types
for field in self.fields.iter() {
let expected_field_ty = field.ty;
let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap();
if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) {
return Err(format!(
"Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}",
gep_index = field.gep_index,
struct_name = self.name,
field_name = field.name,
));
}
}
// Done
Ok(())
}
}
impl<'ctx> StructFieldsBuilder<'ctx> {
fn start(name: &'static str) -> Self {
StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() }
}
fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> {
let index = self.gep_index_counter;
self.gep_index_counter += 1;
let field = StructField { gep_index: index, name, ty };
self.fields.push(field); // Register into self.fields
field // Return to the caller to conveniently let them do whatever they want
}
fn end(self) -> StructFields<'ctx> {
StructFields { name: self.name, fields: self.fields }
}
}
#[derive(Debug, Clone, Copy)]
pub struct NpArrayType<'ctx> {
pub size_type: IntType<'ctx>,
pub elem_type: BasicTypeEnum<'ctx>,
}
pub struct NpArrayStructFields<'ctx> {
pub whole_struct: StructFields<'ctx>,
pub data: StructField<'ctx>,
pub itemsize: StructField<'ctx>,
pub ndims: StructField<'ctx>,
pub shape: StructField<'ctx>,
pub strides: StructField<'ctx>,
}
impl<'ctx> NpArrayType<'ctx> {
pub fn new_opaque_elem(
ctx: &CodeGenContext<'ctx, '_>,
size_type: IntType<'ctx>,
) -> NpArrayType<'ctx> {
NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() }
}
pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
self.fields().whole_struct.get_struct_type(ctx)
}
pub fn fields(&self) -> NpArrayStructFields<'ctx> {
let mut builder = StructFieldsBuilder::start("NpArray");
let addrspace = AddressSpace::default();
let byte_type = self.size_type.get_context().i8_type();
// Make sure the struct matches PERFECTLY with that defined in `nac3core/irrt`.
let data = builder.add_field("data", byte_type.ptr_type(addrspace).into());
let itemsize = builder.add_field("itemsize", self.size_type.into());
let ndims = builder.add_field("ndims", self.size_type.into());
let shape = builder.add_field("shape", self.size_type.ptr_type(addrspace).into());
let strides = builder.add_field("strides", self.size_type.ptr_type(addrspace).into());
NpArrayStructFields { whole_struct: builder.end(), data, itemsize, ndims, shape, strides }
}
/// Allocate an `ndarray` on stack, with the following notes:
///
/// - `ndarray.ndims` will be initialized to `in_ndims`.
/// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`.
/// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`,
/// all with empty/uninitialized values.
pub fn var_alloc<G>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
in_ndims: IntValue<'ctx>,
name: Option<&str>,
) -> NpArrayValue<'ctx>
where
G: CodeGenerator + ?Sized,
{
let ptr = generator
.gen_var_alloc(ctx, self.get_struct_type(ctx.ctx).as_basic_type_enum(), name)
.unwrap();
// Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides`
let allocated_shape = generator
.gen_array_var_alloc(
ctx,
self.size_type.as_basic_type_enum(),
in_ndims,
Some("allocated_shape"),
)
.unwrap();
let allocated_strides = generator
.gen_array_var_alloc(
ctx,
self.size_type.as_basic_type_enum(),
in_ndims,
Some("allocated_strides"),
)
.unwrap();
let value = NpArrayValue { ty: *self, ptr };
value.store_ndims(ctx, in_ndims);
value.store_itemsize(ctx, self.elem_type.size_of().unwrap());
value.store_shape(ctx, allocated_shape.base_ptr(ctx, generator));
value.store_strides(ctx, allocated_strides.base_ptr(ctx, generator));
return value;
}
}
#[derive(Debug, Clone, Copy)]
pub struct NpArrayValue<'ctx> {
pub ty: NpArrayType<'ctx>,
pub ptr: PointerValue<'ctx>,
}
impl<'ctx> NpArrayValue<'ctx> {
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let field = self.ty.fields().ndims;
field.load(ctx, self.ptr).into_int_value()
}
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
let field = self.ty.fields().ndims;
field.store(ctx, self.ptr, value);
}
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let field = self.ty.fields().itemsize;
field.load(ctx, self.ptr).into_int_value()
}
pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
let field = self.ty.fields().itemsize;
field.store(ctx, self.ptr, value);
}
pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let field = self.ty.fields().shape;
field.load(ctx, self.ptr).into_pointer_value()
}
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
let field = self.ty.fields().shape;
field.store(ctx, self.ptr, value);
}
pub fn load_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let field = self.ty.fields().strides;
field.load(ctx, self.ptr).into_pointer_value()
}
pub fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
let field = self.ty.fields().strides;
field.store(ctx, self.ptr, value);
}
/// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!!
pub fn shape_slice(
&self,
ctx: &CodeGenContext<'ctx, '_>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// Get the pointer to `shape`
let field = self.ty.fields().shape;
let shape = field.load(ctx, self.ptr).into_pointer_value();
// Load `ndims`
let ndims = self.load_ndims(ctx);
TypedArrayLikeAdapter {
adapted: ArraySliceValue(shape, ndims, Some(field.name)),
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
}
}
/// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!!
pub fn strides_slice(
&self,
ctx: &CodeGenContext<'ctx, '_>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// Get the pointer to `strides`
let field = self.ty.fields().strides;
let strides = field.load(ctx, self.ptr).into_pointer_value();
// Load `ndims`
let ndims = self.load_ndims(ctx);
TypedArrayLikeAdapter {
adapted: ArraySliceValue(strides, ndims, Some(field.name)),
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
}
}
}

View File

@ -1,9 +1,3 @@
use std::collections::HashMap;
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use crate::{ use crate::{
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::DefinitionId, toplevel::DefinitionId,
@ -15,6 +9,10 @@ use crate::{
}, },
}; };
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use std::collections::HashMap;
pub struct ConcreteTypeStore { pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>, store: Vec<ConcreteTypeEnum>,
} }
@ -27,7 +25,6 @@ pub struct ConcreteFuncArg {
pub name: StrRef, pub name: StrRef,
pub ty: ConcreteType, pub ty: ConcreteType,
pub default_value: Option<SymbolValue>, pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -49,7 +46,6 @@ pub enum ConcreteTypeEnum {
TPrimitive(Primitive), TPrimitive(Primitive),
TTuple { TTuple {
ty: Vec<ConcreteType>, ty: Vec<ConcreteType>,
is_vararg_ctx: bool,
}, },
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
@ -106,16 +102,8 @@ impl ConcreteTypeStore {
.iter() .iter()
.map(|arg| ConcreteFuncArg { .map(|arg| ConcreteFuncArg {
name: arg.name, name: arg.name,
ty: if arg.is_vararg { ty: self.from_unifier_type(unifier, primitives, arg.ty, cache),
let tuple_ty = unifier
.add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true });
self.from_unifier_type(unifier, primitives, tuple_ty, cache)
} else {
self.from_unifier_type(unifier, primitives, arg.ty, cache)
},
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
}) })
.collect(), .collect(),
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache), ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
@ -170,12 +158,11 @@ impl ConcreteTypeStore {
cache.insert(ty, None); cache.insert(ty, None);
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple { TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple {
ty: ty ty: ty
.iter() .iter()
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache)) .map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(), .collect(),
is_vararg_ctx: *is_vararg_ctx,
}, },
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id, obj_id: *obj_id,
@ -261,12 +248,11 @@ impl ConcreteTypeStore {
*cache.get_mut(&cty).unwrap() = Some(ty); *cache.get_mut(&cty).unwrap() = Some(ty);
return ty; return ty;
} }
ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple { ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple {
ty: ty ty: ty
.iter() .iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache)) .map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(), .collect(),
is_vararg_ctx: *is_vararg_ctx,
}, },
ConcreteTypeEnum::TVirtual { ty } => { ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
@ -291,7 +277,6 @@ impl ConcreteTypeStore {
name: arg.name, name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache), ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: false,
}) })
.collect(), .collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache), ret: self.to_unifier_type(unifier, primitives, *ret, cache),

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,8 @@
use inkwell::{ use inkwell::attributes::{Attribute, AttributeLoc};
attributes::{Attribute, AttributeLoc}, use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
};
use itertools::Either; use itertools::Either;
use super::CodeGenContext; use crate::codegen::CodeGenContext;
/// Macro to generate extern function /// Macro to generate extern function
/// Both function return type and function parameter type are `FloatValue` /// Both function return type and function parameter type are `FloatValue`
@ -15,11 +13,11 @@ use super::CodeGenContext;
/// * `$extern_fn:literal`: Name of underlying extern function /// * `$extern_fn:literal`: Name of underlying extern function
/// ///
/// Optional Arguments: /// Optional Arguments:
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function. /// * `$(,$attributes:literal)*)`: Attributes linked with the extern function
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly". /// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly"
/// These will be used unless other attributes are specified /// These will be used unless other attributes are specified
/// * `$(,$args:ident)*`: Operands of the extern function /// * `$(,$args:ident)*`: Operands of the extern function
/// The data type of these operands will be set to `FloatValue` /// The data type of these operands will be set to `FloatValue`
/// ///
macro_rules! generate_extern_fn { macro_rules! generate_extern_fn {
("unary", $fn_name:ident, $extern_fn:literal) => { ("unary", $fn_name:ident, $extern_fn:literal) => {
@ -132,62 +130,3 @@ pub fn call_ldexp<'ctx>(
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap() .unwrap()
} }
/// Macro to generate `np_linalg` and `sp_linalg` functions
/// The function takes as input `NDArray` and returns ()
///
/// Arguments:
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
/// * (2/3/4): Number of `NDArray` that function takes as input
///
/// Note:
/// The operands and resulting `NDArray` are both passed as input to the funcion
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
/// The function changes the content of the output `NDArray` in-place
macro_rules! generate_linalg_extern_fn {
($fn_name:ident, $extern_fn:literal, 2) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
};
($fn_name:ident, $extern_fn:literal, 3) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $extern_fn:literal, 4) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: BasicValueEnum<'ctx>)*,
name: Option<&str>,
){
const FN_NAME: &str = $extern_fn;
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
}
};
}
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);

View File

@ -1,18 +1,16 @@
use crate::{
codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
use inkwell::{ use inkwell::{
context::Context, context::Context,
types::{BasicTypeEnum, IntType}, types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
}; };
use nac3parser::ast::{Expr, Stmt, StrRef}; use nac3parser::ast::{Expr, Stmt, StrRef};
use super::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext};
use crate::{
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
pub trait CodeGenerator { pub trait CodeGenerator {
/// Return the module name for the code generator. /// Return the module name for the code generator.
fn get_name(&self) -> &str; fn get_name(&self) -> &str;
@ -59,7 +57,6 @@ pub trait CodeGenerator {
/// - fun: Function signature, definition ID and the substitution key. /// - fun: Function signature, definition ID and the substitution key.
/// - params: Function parameters. Note that this does not include the object even if the /// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method. /// function is a class method.
///
/// Note that this function should check if the function is generated in another thread (due to /// Note that this function should check if the function is generated in another thread (due to
/// possible race condition), see the default implementation for an example. /// possible race condition), see the default implementation for an example.
fn gen_func_instance<'ctx>( fn gen_func_instance<'ctx>(
@ -126,45 +123,11 @@ pub trait CodeGenerator {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> ) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {
gen_assign(self, ctx, target, value, value_ty) gen_assign(self, ctx, target, value)
}
/// Generate code for an assignment expression where LHS is a `"target_list"`.
///
/// See <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
fn gen_assign_target_list<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
targets: &Vec<Expr<Option<Type>>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign_target_list(self, ctx, targets, value, value_ty)
}
/// Generate code for an item assignment.
///
/// i.e., `target[key] = value`
fn gen_setitem<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
key: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_setitem(self, ctx, target, key, value, value_ty)
} }
/// Generate code for a while expression. /// Generate code for a while expression.

View File

@ -1,30 +1,30 @@
use crate::{typecheck::typedef::Type, util::SizeVariant};
mod test;
use super::{
classes::{
check_basic_types_match, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue,
NDArrayValue, NpArrayType, NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics, CodeGenContext, CodeGenerator,
};
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
context::Context, context::Context,
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer,
module::Module, module::Module,
types::{BasicTypeEnum, IntType}, types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType},
values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue}, values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use itertools::Either; use itertools::Either;
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
#[must_use] #[must_use]
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> { pub fn load_irrt(ctx: &Context) -> Module {
let bitcode_buf = MemoryBuffer::create_from_memory_range( let bitcode_buf = MemoryBuffer::create_from_memory_range(
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")), include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
"irrt_bitcode_buffer", "irrt_bitcode_buffer",
@ -40,25 +40,6 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
let function = irrt_mod.get_function(symbol).unwrap(); let function = irrt_mod.get_function(symbol).unwrap();
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0)); function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
} }
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
let exn_id_type = ctx.i32_type();
let errors = &[
("EXN_INDEX_ERROR", "0:IndexError"),
("EXN_VALUE_ERROR", "0:ValueError"),
("EXN_ASSERTION_ERROR", "0:AssertionError"),
("EXN_TYPE_ERROR", "0:TypeError"),
];
for (irrt_name, symbol_name) in errors {
let exn_id = symbol_resolver.get_string_id(symbol_name);
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
});
global.set_initializer(&exn_id);
}
irrt_mod irrt_mod
} }
@ -76,7 +57,7 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
(64, 64, true) => "__nac3_int_exp_int64_t", (64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t", (32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t", (64, 64, false) => "__nac3_int_exp_uint64_t",
_ => codegen_unreachable!(ctx), _ => unreachable!(),
}; };
let base_type = base.get_type(); let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
@ -462,7 +443,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
BasicTypeEnum::IntType(t) => t.size_of(), BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(), BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(), BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => codegen_unreachable!(ctx), _ => unreachable!(),
}; };
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
} }
@ -584,369 +565,475 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo
.unwrap() .unwrap()
} }
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the // /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size. // /// calculated total size.
/// // ///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. // /// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, // /// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension // /// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
/// respectively. // pub fn call_ndarray_calc_size<'ctx, G, Dims>(
pub fn call_ndarray_calc_size<'ctx, G, Dims>( // generator: &G,
generator: &G, // ctx: &CodeGenContext<'ctx, '_>,
// dims: &Dims,
// (begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
// ) -> IntValue<'ctx>
// where
// G: CodeGenerator + ?Sized,
// Dims: ArrayLikeIndexer<'ctx>,
// {
// let llvm_usize = generator.get_size_type(ctx.ctx);
// let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
// let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
// 32 => "__nac3_ndarray_calc_size",
// 64 => "__nac3_ndarray_calc_size64",
// bw => unreachable!("Unsupported size type bit width: {}", bw),
// };
// let ndarray_calc_size_fn_t = llvm_usize.fn_type(
// &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
// false,
// );
// let ndarray_calc_size_fn =
// ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
// ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
// });
//
// let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
// let end = end.unwrap_or_else(|| dims.size(ctx, generator));
// ctx.builder
// .build_call(
// ndarray_calc_size_fn,
// &[
// dims.base_ptr(ctx, generator).into(),
// dims.size(ctx, generator).into(),
// begin.into(),
// end.into(),
// ],
// "",
// )
// .map(CallSiteValue::try_as_basic_value)
// .map(|v| v.map_left(BasicValueEnum::into_int_value))
// .map(Either::unwrap_left)
// .unwrap()
// }
//
// /// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
// /// containing `i32` indices of the flattened index.
// ///
// /// * `index` - The index to compute the multidimensional index for.
// /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
// /// `NDArray`.
// pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
// generator: &G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// index: IntValue<'ctx>,
// ndarray: NDArrayValue<'ctx>,
// ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// let llvm_void = ctx.ctx.void_type();
// let llvm_i32 = ctx.ctx.i32_type();
// let llvm_usize = generator.get_size_type(ctx.ctx);
// let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
// let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
// let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
// 32 => "__nac3_ndarray_calc_nd_indices",
// 64 => "__nac3_ndarray_calc_nd_indices64",
// bw => unreachable!("Unsupported size type bit width: {}", bw),
// };
// let ndarray_calc_nd_indices_fn =
// ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
// let fn_type = llvm_void.fn_type(
// &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
// false,
// );
//
// ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
// });
//
// let ndarray_num_dims = ndarray.load_ndims(ctx);
// let ndarray_dims = ndarray.dim_sizes();
//
// let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
//
// ctx.builder
// .build_call(
// ndarray_calc_nd_indices_fn,
// &[
// index.into(),
// ndarray_dims.base_ptr(ctx, generator).into(),
// ndarray_num_dims.into(),
// indices.into(),
// ],
// "",
// )
// .unwrap();
//
// TypedArrayLikeAdapter::from(
// ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
// Box::new(|_, v| v.into_int_value()),
// Box::new(|_, v| v.into()),
// )
// }
//
// fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
// generator: &G,
// ctx: &CodeGenContext<'ctx, '_>,
// ndarray: NDArrayValue<'ctx>,
// indices: &Indices,
// ) -> IntValue<'ctx>
// where
// G: CodeGenerator + ?Sized,
// Indices: ArrayLikeIndexer<'ctx>,
// {
// let llvm_i32 = ctx.ctx.i32_type();
// let llvm_usize = generator.get_size_type(ctx.ctx);
//
// let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
// let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
// debug_assert_eq!(
// IntType::try_from(indices.element_type(ctx, generator))
// .map(IntType::get_bit_width)
// .unwrap_or_default(),
// llvm_i32.get_bit_width(),
// "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
// );
// debug_assert_eq!(
// indices.size(ctx, generator).get_type().get_bit_width(),
// llvm_usize.get_bit_width(),
// "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
// );
//
// let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
// 32 => "__nac3_ndarray_flatten_index",
// 64 => "__nac3_ndarray_flatten_index64",
// bw => unreachable!("Unsupported size type bit width: {}", bw),
// };
// let ndarray_flatten_index_fn =
// ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
// let fn_type = llvm_usize.fn_type(
// &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
// false,
// );
//
// ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
// });
//
// let ndarray_num_dims = ndarray.load_ndims(ctx);
// let ndarray_dims = ndarray.dim_sizes();
//
// let index = ctx
// .builder
// .build_call(
// ndarray_flatten_index_fn,
// &[
// ndarray_dims.base_ptr(ctx, generator).into(),
// ndarray_num_dims.into(),
// indices.base_ptr(ctx, generator).into(),
// indices.size(ctx, generator).into(),
// ],
// "",
// )
// .map(CallSiteValue::try_as_basic_value)
// .map(|v| v.map_left(BasicValueEnum::into_int_value))
// .map(Either::unwrap_left)
// .unwrap();
//
// index
// }
//
// /// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
// /// multidimensional index.
// ///
// /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
// /// `NDArray`.
// /// * `indices` - The multidimensional index to compute the flattened index for.
// pub fn call_ndarray_flatten_index<'ctx, G, Index>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: NDArrayValue<'ctx>,
// indices: &Index,
// ) -> IntValue<'ctx>
// where
// G: CodeGenerator + ?Sized,
// Index: ArrayLikeIndexer<'ctx>,
// {
// call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
// }
//
// /// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
// /// dimension and size of each dimension of the resultant `ndarray`.
// pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// lhs: NDArrayValue<'ctx>,
// rhs: NDArrayValue<'ctx>,
// ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// let llvm_usize = generator.get_size_type(ctx.ctx);
// let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
// let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
// 32 => "__nac3_ndarray_calc_broadcast",
// 64 => "__nac3_ndarray_calc_broadcast64",
// bw => unreachable!("Unsupported size type bit width: {}", bw),
// };
// let ndarray_calc_broadcast_fn =
// ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
// let fn_type = llvm_usize.fn_type(
// &[
// llvm_pusize.into(),
// llvm_usize.into(),
// llvm_pusize.into(),
// llvm_usize.into(),
// llvm_pusize.into(),
// ],
// false,
// );
//
// ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
// });
//
// let lhs_ndims = lhs.load_ndims(ctx);
// let rhs_ndims = rhs.load_ndims(ctx);
// let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
//
// gen_for_callback_incrementing(
// generator,
// ctx,
// llvm_usize.const_zero(),
// (min_ndims, false),
// |generator, ctx, _, idx| {
// let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
// let (lhs_dim_sz, rhs_dim_sz) = unsafe {
// (
// lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
// rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
// )
// };
//
// let llvm_usize_const_one = llvm_usize.const_int(1, false);
// let lhs_eqz = ctx
// .builder
// .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
// .unwrap();
// let rhs_eqz = ctx
// .builder
// .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
// .unwrap();
// let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
//
// let lhs_eq_rhs = ctx
// .builder
// .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
// .unwrap();
//
// let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
//
// ctx.make_assert(
// generator,
// is_compatible,
// "0:ValueError",
// "operands could not be broadcast together",
// [None, None, None],
// ctx.current_loc,
// );
//
// Ok(())
// },
// llvm_usize.const_int(1, false),
// )
// .unwrap();
//
// let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
// let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
// let lhs_ndims = lhs.load_ndims(ctx);
// let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
// let rhs_ndims = rhs.load_ndims(ctx);
// let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
// let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
//
// ctx.builder
// .build_call(
// ndarray_calc_broadcast_fn,
// &[
// lhs_dims.into(),
// lhs_ndims.into(),
// rhs_dims.into(),
// rhs_ndims.into(),
// out_dims.base_ptr(ctx, generator).into(),
// ],
// "",
// )
// .unwrap();
//
// TypedArrayLikeAdapter::from(
// out_dims,
// Box::new(|_, v| v.into_int_value()),
// Box::new(|_, v| v.into()),
// )
// }
//
// /// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
// /// containing the indices used for accessing `array` corresponding to the index of the broadcasted
// /// array `broadcast_idx`.
// pub fn call_ndarray_calc_broadcast_index<
// 'ctx,
// G: CodeGenerator + ?Sized,
// BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
// >(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// array: NDArrayValue<'ctx>,
// broadcast_idx: &BroadcastIdx,
// ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// let llvm_i32 = ctx.ctx.i32_type();
// let llvm_usize = generator.get_size_type(ctx.ctx);
// let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
// let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
// let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
// 32 => "__nac3_ndarray_calc_broadcast_idx",
// 64 => "__nac3_ndarray_calc_broadcast_idx64",
// bw => unreachable!("Unsupported size type bit width: {}", bw),
// };
// let ndarray_calc_broadcast_fn =
// ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
// let fn_type = llvm_usize.fn_type(
// &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
// false,
// );
//
// ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
// });
//
// let broadcast_size = broadcast_idx.size(ctx, generator);
// let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
//
// let array_dims = array.dim_sizes().base_ptr(ctx, generator);
// let array_ndims = array.load_ndims(ctx);
// let broadcast_idx_ptr = unsafe {
// broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
// };
//
// ctx.builder
// .build_call(
// ndarray_calc_broadcast_fn,
// &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
// "",
// )
// .unwrap();
//
// TypedArrayLikeAdapter::from(
// ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
// Box::new(|_, v| v.into_int_value()),
// Box::new(|_, v| v.into()),
// )
// }
fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant {
match ty.get_bit_width() {
32 => SizeVariant::Bits32,
64 => SizeVariant::Bits64,
_ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()),
}
}
fn get_size_type_dependent_function<'ctx, BuildFuncTypeFn>(
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims, size_type: IntType<'ctx>,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>), base_name: &str,
) -> IntValue<'ctx> build_func_type: BuildFuncTypeFn,
) -> FunctionValue<'ctx>
where where
G: CodeGenerator + ?Sized, BuildFuncTypeFn: Fn() -> FunctionType<'ctx>,
Dims: ArrayLikeIndexer<'ctx>,
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let mut fn_name = base_name.to_owned();
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); match get_size_variant(size_type) {
SizeVariant::Bits32 => {
// The original fn_name is the correct function name
}
SizeVariant::Bits64 => {
// Append "64" at the end, this is the naming convention for 64-bit
fn_name.push_str("64");
}
}
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { // Get (or declare then get if does not exist) the corresponding function
32 => "__nac3_ndarray_calc_size", ctx.module.get_function(&fn_name).unwrap_or_else(|| {
64 => "__nac3_ndarray_calc_size64", let fn_type = build_func_type();
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), ctx.module.add_function(&fn_name, fn_type, None)
}; })
let ndarray_calc_size_fn_t = llvm_usize.fn_type( }
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false, fn get_irrt_ndarray_ptr_type<'ctx>(
); ctx: &CodeGenContext<'ctx, '_>,
let ndarray_calc_size_fn = size_type: IntType<'ctx>,
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { ) -> PointerType<'ctx> {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) let i8_type = ctx.ctx.i8_type();
});
let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() };
let struct_ty = ndarray_ty.get_struct_type(ctx.ctx);
struct_ty.ptr_type(AddressSpace::default())
}
fn get_irrt_opaque_uint8_ptr_type<'ctx>(ctx: &CodeGenContext<'ctx, '_>) -> PointerType<'ctx> {
ctx.ctx.i8_type().ptr_type(AddressSpace::default())
}
pub fn call_nac3_ndarray_size<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
) -> IntValue<'ctx> {
let size_type = ndarray.ty.size_type;
let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_size", || {
size_type.fn_type(&[get_irrt_ndarray_ptr_type(ctx, size_type).into()], false)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder ctx.builder
.build_call( .build_call(function, &[ndarray.ptr.into()], "size")
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap() .unwrap()
.try_as_basic_value()
.unwrap_left()
.into_int_value()
} }
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] pub fn call_nac3_ndarray_fill_generic<'ctx>(
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NpArrayValue<'ctx>,
indices: &Indices, fill_value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> ) {
where // Sanity check on type of `fill_value`
G: CodeGenerator + ?Sized, check_basic_types_match(ndarray.ty.elem_type, fill_value.get_type().as_basic_type_enum())
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap(); .unwrap();
index let size_type = ndarray.ty.size_type;
} let function =
get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_fill_generic", || {
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the ctx.ctx.void_type().fn_type(
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[ &[
llvm_pusize.into(), get_irrt_ndarray_ptr_type(ctx, size_type).into(), // NDArray<SizeT>* ndarray
llvm_usize.into(), get_irrt_opaque_uint8_ptr_type(ctx).into(), // uint8_t* pvalue
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
], ],
false, false,
); )
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
}); });
let lhs_ndims = lhs.load_ndims(ctx); // Put `fill_value` onto the stack and get a pointer to it, and that pointer will be `pvalue`
let rhs_ndims = rhs.load_ndims(ctx); let pvalue = ctx.builder.build_alloca(ndarray.ty.elem_type, "fill_value").unwrap();
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); ctx.builder.build_store(pvalue, fill_value).unwrap();
gen_for_callback_incrementing( // Cast pvalue to `uint8_t*`
generator, let pvalue = ctx.builder.build_pointer_cast(pvalue, get_irrt_opaque_uint8_ptr_type(ctx), "").unwrap();
ctx,
None,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
// Call the IRRT function
ctx.builder ctx.builder
.build_call( .build_call(
ndarray_calc_broadcast_fn, function,
&[ &[
lhs_dims.into(), ndarray.ptr.into(), // ndarray
lhs_ndims.into(), pvalue.into(), // pvalue
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
], ],
"", "",
) )
.unwrap(); .unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
} }

View File

@ -0,0 +1,26 @@
#[cfg(test)]
mod tests {
use std::{path::Path, process::Command};
#[test]
fn run_irrt_test() {
assert!(
cfg!(feature = "test"),
"Please do `cargo test -F test` to compile `irrt_test.out` and run test"
);
let irrt_test_out_path = Path::new(concat!(env!("OUT_DIR"), "/irrt_test.out"));
let output = Command::new(irrt_test_out_path.to_str().unwrap()).output().unwrap();
if !output.status.success() {
eprintln!("irrt_test failed with status {}:", output.status);
eprintln!("====== stdout ======");
eprintln!("{}", String::from_utf8(output.stdout).unwrap());
eprintln!("====== stderr ======");
eprintln!("{}", String::from_utf8(output.stderr).unwrap());
eprintln!("====================");
panic!("irrt_test failed");
}
}
}

View File

@ -1,14 +1,12 @@
use inkwell::{ use crate::codegen::CodeGenContext;
context::Context, use inkwell::context::Context;
intrinsics::Intrinsic, use inkwell::intrinsics::Intrinsic;
types::{AnyTypeEnum::IntType, FloatType}, use inkwell::types::AnyTypeEnum::IntType;
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue}, use inkwell::types::FloatType;
AddressSpace, use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
}; use inkwell::AddressSpace;
use itertools::Either; use itertools::Either;
use super::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic /// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions. /// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str { fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
@ -37,40 +35,6 @@ fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
unreachable!() unreachable!()
} }
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_start";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_end";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic) /// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic. /// intrinsic.
pub fn call_stacksave<'ctx>( pub fn call_stacksave<'ctx>(
@ -185,7 +149,7 @@ pub fn call_memcpy_generic<'ctx>(
dest dest
} else { } else {
ctx.builder ctx.builder
.build_bit_cast(dest, llvm_p0i8, "") .build_bitcast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap() .unwrap()
}; };
@ -193,7 +157,7 @@ pub fn call_memcpy_generic<'ctx>(
src src
} else { } else {
ctx.builder ctx.builder
.build_bit_cast(src, llvm_p0i8, "") .build_bitcast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap() .unwrap()
}; };
@ -207,9 +171,8 @@ pub fn call_memcpy_generic<'ctx>(
/// * `$ctx:ident`: Reference to the current Code Generation Context /// * `$ctx:ident`: Reference to the current Code Generation Context
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>) /// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function /// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type). /// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type)
/// Use `BasicValueEnum::into_int_value` for Integer return type and /// Use `BasicValueEnum::into_int_value` for Integer return type and `BasicValueEnum::into_float_value` for Float return type
/// `BasicValueEnum::into_float_value` for Float return type
/// * `$llvm_ty:ident`: Type of first operand /// * `$llvm_ty:ident`: Type of first operand
/// * `,($val:ident)*`: Comma separated list of operands /// * `,($val:ident)*`: Comma separated list of operands
macro_rules! generate_llvm_intrinsic_fn_body { macro_rules! generate_llvm_intrinsic_fn_body {
@ -225,8 +188,8 @@ macro_rules! generate_llvm_intrinsic_fn_body {
/// Arguments: /// Arguments:
/// * `float/int`: Indicates the return and argument type of the function /// * `float/int`: Indicates the return and argument type of the function
/// * `$fn_name:ident`: The identifier of the rust function to be generated /// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function. /// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil" /// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
/// * `$val:ident`: The operand for unary operations /// * `$val:ident`: The operand for unary operations
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations /// * `$val1:ident`, `$val2:ident`: The operands for binary operations
macro_rules! generate_llvm_intrinsic_fn { macro_rules! generate_llvm_intrinsic_fn {

View File

@ -1,12 +1,13 @@
use std::{ use crate::{
collections::{HashMap, HashSet}, codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
sync::{ symbol_resolver::{StaticValue, SymbolResolver},
atomic::{AtomicBool, Ordering}, toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
Arc, typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
}, },
thread,
}; };
use classes::NpArrayType;
use crossbeam::channel::{unbounded, Receiver, Sender}; use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -24,21 +25,14 @@ use inkwell::{
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use nac3parser::ast::{Location, Stmt, StrRef}; use nac3parser::ast::{Location, Stmt, StrRef};
use parking_lot::{Condvar, Mutex};
use crate::{ use std::collections::{HashMap, HashSet};
symbol_resolver::{StaticValue, SymbolResolver}, use std::sync::{
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, atomic::{AtomicBool, Ordering},
typecheck::{ Arc,
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
},
}; };
use classes::{ListType, NDArrayType, ProxyType, RangeType}; use std::thread;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
pub mod builtin_fns; pub mod builtin_fns;
pub mod classes; pub mod classes;
@ -54,21 +48,8 @@ pub mod stmt;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
mod macros { use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
/// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as pub use generator::{CodeGenerator, DefaultCodeGenerator};
/// its first argument to provide Python source information to indicate the codegen location
/// causing the assertion.
macro_rules! codegen_unreachable {
($ctx:expr $(,)?) => {
std::unreachable!("unreachable code while processing {}", &$ctx.current_loc)
};
($ctx:expr, $($arg:tt)*) => {
std::unreachable!("unreachable code while processing {}: {}", &$ctx.current_loc, std::format!("{}", std::format_args!($($arg)+)))
};
}
pub(crate) use codegen_unreachable;
}
#[derive(Default)] #[derive(Default)]
pub struct StaticValueStore { pub struct StaticValueStore {
@ -88,16 +69,6 @@ pub struct CodeGenLLVMOptions {
pub target: CodeGenTargetMachineOptions, pub target: CodeGenTargetMachineOptions,
} }
impl CodeGenLLVMOptions {
/// Creates a [`TargetMachine`] using the target options specified by this struct.
///
/// See [`Target::create_target_machine`].
#[must_use]
pub fn create_target_machine(&self) -> Option<TargetMachine> {
self.target.create_target_machine(self.opt_level)
}
}
/// Additional options for code generation for the target machine. /// Additional options for code generation for the target machine.
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub struct CodeGenTargetMachineOptions { pub struct CodeGenTargetMachineOptions {
@ -368,10 +339,6 @@ impl WorkerRegistry {
let mut builder = context.create_builder(); let mut builder = context.create_builder();
let mut module = context.create_module(generator.get_name()); let mut module = context.create_module(generator.get_name());
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag( module.add_basic_value_flag(
"Debug Info Version", "Debug Info Version",
inkwell::module::FlagBehavior::Warning, inkwell::module::FlagBehavior::Warning,
@ -395,10 +362,6 @@ impl WorkerRegistry {
errors.insert(e); errors.insert(e);
// create a new empty module just to continue codegen and collect errors // create a new empty module just to continue codegen and collect errors
module = context.create_module(&format!("{}_recover", generator.get_name())); module = context.create_module(&format!("{}_recover", generator.get_name()));
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
} }
} }
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
@ -464,7 +427,7 @@ pub struct CodeGenTask {
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &G, generator: &mut G,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -514,7 +477,11 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx, module, generator, unifier, top_level, type_cache, dtype, ctx, module, generator, unifier, top_level, type_cache, dtype,
); );
NDArrayType::new(generator, ctx, element_type).as_base_type().into() let ndarray_ty = NpArrayType {
size_type: generator.get_size_type(ctx),
elem_type: element_type,
};
ndarray_ty.get_struct_type(ctx).as_basic_type_enum()
} }
_ => unreachable!( _ => unreachable!(
@ -558,10 +525,8 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
}; };
return ty; return ty;
} }
TTuple { ty, is_vararg_ctx } => { TTuple { ty } => {
// a struct with fields in the order present in the tuple // a struct with fields in the order present in the tuple
assert!(!is_vararg_ctx, "Tuples in vararg context must be instantiated with the correct number of arguments before calling get_llvm_type");
let fields = ty let fields = ty
.iter() .iter()
.map(|ty| { .map(|ty| {
@ -591,7 +556,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>( fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &G, generator: &mut G,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -600,11 +565,11 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
) -> BasicTypeEnum<'ctx> { ) -> BasicTypeEnum<'ctx> {
// If the type is used in the definition of a function, return `i1` instead of `i8` for ABI // If the type is used in the definition of a function, return `i1` instead of `i8` for ABI
// consistency. // consistency.
if unifier.unioned(ty, primitives.bool) { return if unifier.unioned(ty, primitives.bool) {
ctx.bool_type().into() ctx.bool_type().into()
} else { } else {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty) get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
} };
} }
/// Whether `sret` is needed for a return value with type `ty`. /// Whether `sret` is needed for a return value with type `ty`.
@ -629,40 +594,6 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
need_sret_impl(ty, true) need_sret_impl(ty, true)
} }
/// Returns the [`BasicTypeEnum`] representing a `va_list` struct for variadic arguments.
fn get_llvm_valist_type<'ctx>(ctx: &'ctx Context, triple: &TargetTriple) -> BasicTypeEnum<'ctx> {
let triple = TargetMachine::normalize_triple(triple);
let triple = triple.as_str().to_str().unwrap();
let arch = triple.split('-').next().unwrap();
let llvm_pi8 = ctx.i8_type().ptr_type(AddressSpace::default());
// Referenced from parseArch() in llvm/lib/Support/Triple.cpp
match arch {
"i386" | "i486" | "i586" | "i686" | "riscv32" => {
ctx.i8_type().ptr_type(AddressSpace::default()).into()
}
"amd64" | "x86_64" | "x86_64h" => {
let llvm_i32 = ctx.i32_type();
let va_list_tag = ctx.opaque_struct_type("struct.__va_list_tag");
va_list_tag.set_body(
&[llvm_i32.into(), llvm_i32.into(), llvm_pi8.into(), llvm_pi8.into()],
false,
);
va_list_tag.into()
}
"armv7" => {
let va_list = ctx.opaque_struct_type("struct.__va_list");
va_list.set_body(&[llvm_pi8.into()], false);
va_list.into()
}
triple => {
todo!("Unsupported platform for varargs: {triple}")
}
}
}
/// Implementation for generating LLVM IR for a function. /// Implementation for generating LLVM IR for a function.
pub fn gen_func_impl< pub fn gen_func_impl<
'ctx, 'ctx,
@ -774,7 +705,6 @@ pub fn gen_func_impl<
name: arg.name, name: arg.name,
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache), ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache),
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
}) })
.collect_vec(), .collect_vec(),
task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache), task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache),
@ -797,10 +727,7 @@ pub fn gen_func_impl<
let has_sret = ret_type.map_or(false, |ty| need_sret(ty)); let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
let mut params = args let mut params = args
.iter() .iter()
.filter(|arg| !arg.is_vararg)
.map(|arg| { .map(|arg| {
debug_assert!(!arg.is_vararg);
get_llvm_abi_type( get_llvm_abi_type(
context, context,
&module, &module,
@ -819,12 +746,9 @@ pub fn gen_func_impl<
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into()); params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
} }
debug_assert!(matches!(args.iter().filter(|arg| arg.is_vararg).count(), 0..=1));
let vararg_arg = args.iter().find(|arg| arg.is_vararg);
let fn_type = match ret_type { let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, vararg_arg.is_some()), Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => context.void_type().fn_type(&params, vararg_arg.is_some()), _ => context.void_type().fn_type(&params, false),
}; };
let symbol = &task.symbol_name; let symbol = &task.symbol_name;
@ -852,10 +776,9 @@ pub fn gen_func_impl<
builder.position_at_end(init_bb); builder.position_at_end(init_bb);
let body_bb = context.append_basic_block(fn_val, "body"); let body_bb = context.append_basic_block(fn_val, "body");
// Store non-vararg argument values into local variables
let mut var_assignment = HashMap::new(); let mut var_assignment = HashMap::new();
let offset = u32::from(has_sret); let offset = u32::from(has_sret);
for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) { for (n, arg) in args.iter().enumerate() {
let param = fn_val.get_nth_param((n as u32) + offset).unwrap(); let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
let local_type = get_llvm_type( let local_type = get_llvm_type(
context, context,
@ -888,8 +811,6 @@ pub fn gen_func_impl<
var_assignment.insert(arg.name, (alloca, None, 0)); var_assignment.insert(arg.name, (alloca, None, 0));
} }
// TODO: Save vararg parameters as list
let return_buffer = if has_sret { let return_buffer = if has_sret {
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value()) Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
} else { } else {
@ -1112,9 +1033,3 @@ fn gen_in_range_check<'ctx>(
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap() ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
} }
/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments
/// passed to the variadic function.
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
format!("__{}_va_count", &arg_name).into()
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,37 +1,34 @@
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use indexmap::IndexMap;
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel,
};
use nac3parser::{
ast::{fold::Fold, FileName, StrRef},
parser::parse_program,
};
use parking_lot::RwLock;
use super::{
classes::{ListType, NDArrayType, ProxyType, RangeType},
concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator,
DefaultCodeGenerator, WithCall, WorkerRegistry,
};
use crate::{ use crate::{
codegen::{
classes::{ListType, NDArrayType, ProxyType, RangeType},
concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask,
CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
symbol_resolver::{SymbolResolver, ValueEnum}, symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{ toplevel::{
composer::{ComposerConfig, TopLevelComposer}, composer::{ComposerConfig, TopLevelComposer},
DefinitionId, FunInstance, TopLevelContext, TopLevelDef, DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
}, },
typecheck::{ typecheck::{
type_inferencer::{FunctionData, IdentifierInfo, Inferencer, PrimitiveStore}, type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
}, },
}; };
use indexmap::IndexMap;
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel,
};
use nac3parser::ast::FileName;
use nac3parser::{
ast::{fold::Fold, StrRef},
parser::parse_program,
};
use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
struct Resolver { struct Resolver {
id_to_type: HashMap<StrRef, Type>, id_to_type: HashMap<StrRef, Type>,
@ -67,7 +64,6 @@ impl SymbolResolver for Resolver {
&self, &self,
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, '_>, _: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() unimplemented!()
} }
@ -98,7 +94,7 @@ fn test_primitives() {
"}; "};
let statements = parse_program(source, FileName::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -113,18 +109,8 @@ fn test_primitives() {
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature { let signature = FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg { name: "a".into(), ty: primitives.int32, default_value: None },
name: "a".into(), FuncArg { name: "b".into(), ty: primitives.int32, default_value: None },
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "b".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
], ],
ret: primitives.int32, ret: primitives.int32,
vars: VarMap::new(), vars: VarMap::new(),
@ -142,8 +128,7 @@ fn test_primitives() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers: HashMap<_, _> = let mut identifiers: HashSet<_> = ["a".into(), "b".into()].into();
["a".into(), "b".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
@ -204,8 +189,6 @@ fn test_primitives() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 { define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
@ -263,19 +246,14 @@ fn test_simple_call() {
"}; "};
let statements_2 = parse_program(source_2, FileName::default()).unwrap(); let statements_2 = parse_program(source_2, FileName::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone()); unifier.top_level = Some(top_level.clone());
let signature = FunSignature { let signature = FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }],
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
}],
ret: primitives.int32, ret: primitives.int32,
vars: VarMap::new(), vars: VarMap::new(),
}; };
@ -322,8 +300,7 @@ fn test_simple_call() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers: HashMap<_, _> = let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].into();
["a".into(), "foo".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
@ -391,8 +368,6 @@ fn test_simple_call() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 { define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {

View File

@ -19,11 +19,8 @@
clippy::wildcard_imports clippy::wildcard_imports
)] )]
// users of nac3core need to use the same version of these dependencies, so expose them as nac3core::*
pub use inkwell;
pub use nac3parser;
pub mod codegen; pub mod codegen;
pub mod symbol_resolver; pub mod symbol_resolver;
pub mod toplevel; pub mod toplevel;
pub mod typecheck; pub mod typecheck;
pub mod util;

View File

@ -1,15 +1,7 @@
use std::{ use std::fmt::Debug;
collections::{HashMap, HashSet}, use std::rc::Rc;
fmt::{Debug, Display}, use std::sync::Arc;
rc::Rc, use std::{collections::HashMap, collections::HashSet, fmt::Display};
sync::Arc,
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip, Itertools};
use parking_lot::RwLock;
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use crate::{ use crate::{
codegen::{CodeGenContext, CodeGenerator}, codegen::{CodeGenContext, CodeGenerator},
@ -19,6 +11,10 @@ use crate::{
typedef::{Type, TypeEnum, Unifier, VarMap}, typedef::{Type, TypeEnum, Unifier, VarMap},
}, },
}; };
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip, Itertools};
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
pub enum SymbolValue { pub enum SymbolValue {
@ -82,14 +78,14 @@ impl SymbolValue {
} }
Constant::Tuple(t) => { Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty); let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx } = expected_ty.as_ref() else { let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
return Err(format!( return Err(format!(
"Expected {:?}, but got Tuple", "Expected {:?}, but got Tuple",
expected_ty.get_type_name() expected_ty.get_type_name()
)); ));
}; };
assert!(*is_vararg_ctx || ty.len() == t.len()); assert_eq!(ty.len(), t.len());
let elems = t let elems = t
.iter() .iter()
@ -159,7 +155,7 @@ impl SymbolValue {
SymbolValue::Bool(_) => primitives.bool, SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => { SymbolValue::Tuple(vs) => {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>(); let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys, is_vararg_ctx: false }) unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
} }
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
} }
@ -369,7 +365,6 @@ pub trait SymbolResolver {
&self, &self,
str: StrRef, str: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>>; ) -> Option<ValueEnum<'ctx>>;
fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>; fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;
@ -487,7 +482,7 @@ pub fn parse_type_annotation<T>(
parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt) parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false })) Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else { } else {
Err(HashSet::from(["Expected multiple elements for tuple".into()])) Err(HashSet::from(["Expected multiple elements for tuple".into()]))
} }

View File

@ -1,5 +1,7 @@
use std::iter::once; use std::iter::once;
use crate::util::SizeVariant;
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
use indexmap::IndexMap; use indexmap::IndexMap;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -10,22 +12,22 @@ use inkwell::{
use itertools::Either; use itertools::Either;
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use super::{
helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDef, PrimDefDetails},
numpy::make_ndarray_ty,
*,
};
use crate::{ use crate::{
codegen::{ codegen::{
builtin_fns, builtin_fns,
classes::{ProxyValue, RangeValue}, classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor},
expr::destructure_range,
irrt::*,
numpy::*, numpy::*,
stmt::exn_constructor, stmt::exn_constructor,
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
}; };
use super::*;
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>; type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
pub fn get_exn_constructor( pub fn get_exn_constructor(
@ -44,26 +46,10 @@ pub fn get_exn_constructor(
name: "msg".into(), name: "msg".into(),
ty: string, ty: string,
default_value: Some(SymbolValue::Str(String::new())), default_value: Some(SymbolValue::Str(String::new())),
is_vararg: false,
},
FuncArg {
name: "param0".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
},
FuncArg {
name: "param1".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
},
FuncArg {
name: "param2".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
}, },
FuncArg { name: "param0".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
FuncArg { name: "param1".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
FuncArg { name: "param2".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
]; ];
let exn_type = unifier.add_ty(TypeEnum::TObj { let exn_type = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(class_id), obj_id: DefinitionId(class_id),
@ -113,7 +99,7 @@ pub fn get_exn_constructor(
/// * `name`: The name of the implemented NumPy function. /// * `name`: The name of the implemented NumPy function.
/// * `ret_ty`: The return type of this function. /// * `ret_ty`: The return type of this function.
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the /// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
/// [parameter type][Type] and the parameter symbol name. /// [parameter type][Type] and the parameter symbol name.
/// * `codegen_callback`: A lambda generating LLVM IR for the implementation of this function. /// * `codegen_callback`: A lambda generating LLVM IR for the implementation of this function.
fn create_fn_by_codegen( fn create_fn_by_codegen(
unifier: &mut Unifier, unifier: &mut Unifier,
@ -129,12 +115,7 @@ fn create_fn_by_codegen(
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty args: param_ty
.iter() .iter()
.map(|p| FuncArg { .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
name: p.1.into(),
ty: p.0,
default_value: None,
is_vararg: false,
})
.collect(), .collect(),
ret: ret_ty, ret: ret_ty,
vars: var_map.clone(), vars: var_map.clone(),
@ -153,7 +134,7 @@ fn create_fn_by_codegen(
/// * `name`: The name of the implemented NumPy function. /// * `name`: The name of the implemented NumPy function.
/// * `ret_ty`: The return type of this function. /// * `ret_ty`: The return type of this function.
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the /// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
/// [parameter type][Type] and the parameter symbol name. /// [parameter type][Type] and the parameter symbol name.
/// * `intrinsic_fn`: The fully-qualified name of the LLVM intrinsic function. /// * `intrinsic_fn`: The fully-qualified name of the LLVM intrinsic function.
fn create_fn_by_intrinsic( fn create_fn_by_intrinsic(
unifier: &mut Unifier, unifier: &mut Unifier,
@ -215,10 +196,10 @@ fn create_fn_by_intrinsic(
/// * `name`: The name of the implemented NumPy function. /// * `name`: The name of the implemented NumPy function.
/// * `ret_ty`: The return type of this function. /// * `ret_ty`: The return type of this function.
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the /// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
/// [parameter type][Type] and the parameter symbol name. /// [parameter type][Type] and the parameter symbol name.
/// * `extern_fn`: The fully-qualified name of the extern function used as the implementation. /// * `extern_fn`: The fully-qualified name of the extern function used as the implementation.
/// * `attrs`: The list of attributes to apply to this function declaration. Note that `nounwind` is /// * `attrs`: The list of attributes to apply to this function declaration. Note that `nounwind` is
/// already implied by the C ABI. /// already implied by the C ABI.
fn create_fn_by_extern( fn create_fn_by_extern(
unifier: &mut Unifier, unifier: &mut Unifier,
var_map: &VarMap, var_map: &VarMap,
@ -298,19 +279,10 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
.collect() .collect()
} }
/// A helper enum used by [`BuiltinBuilder`] fn size_variant_to_int_type(variant: SizeVariant, primitives: &PrimitiveStore) -> Type {
#[derive(Clone, Copy)] match variant {
enum SizeVariant { SizeVariant::Bits32 => primitives.int32,
Bits32, SizeVariant::Bits64 => primitives.int64,
Bits64,
}
impl SizeVariant {
fn of_int(self, primitives: &PrimitiveStore) -> Type {
match self {
SizeVariant::Bits32 => primitives.int32,
SizeVariant::Bits64 => primitives.int64,
}
} }
} }
@ -366,8 +338,8 @@ impl<'a> BuiltinBuilder<'a> {
let (is_some_ty, unwrap_ty, option_tvar) = let (is_some_ty, unwrap_ty, option_tvar) =
if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() { if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() {
( (
*fields.get(&PrimDef::FunOptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
*fields.get(&PrimDef::FunOptionUnwrap.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
iter_type_vars(params).next().unwrap(), iter_type_vars(params).next().unwrap(),
) )
} else { } else {
@ -382,9 +354,9 @@ impl<'a> BuiltinBuilder<'a> {
let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap(); let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap();
let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap(); let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap();
let ndarray_copy_ty = let ndarray_copy_ty =
*ndarray_fields.get(&PrimDef::FunNDArrayCopy.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
let ndarray_fill_ty = let ndarray_fill_ty =
*ndarray_fields.get(&PrimDef::FunNDArrayFill.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::NDArrayFill.simple_name().into()).unwrap();
let num_ty = unifier.get_fresh_var_with_range( let num_ty = unifier.get_fresh_var_with_range(
&[int32, int64, float, boolean, uint32, uint64], &[int32, int64, float, boolean, uint32, uint64],
@ -484,14 +456,14 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Exception => self.build_exception_class_related(prim), PrimDef::Exception => self.build_exception_class_related(prim),
PrimDef::Option PrimDef::Option
| PrimDef::FunOptionIsSome | PrimDef::OptionIsSome
| PrimDef::FunOptionIsNone | PrimDef::OptionIsNone
| PrimDef::FunOptionUnwrap | PrimDef::OptionUnwrap
| PrimDef::FunSome => self.build_option_class_related(prim), | PrimDef::FunSome => self.build_option_class_related(prim),
PrimDef::List => self.build_list_class_related(prim), PrimDef::List => self.build_list_class_related(prim),
PrimDef::NDArray | PrimDef::FunNDArrayCopy | PrimDef::FunNDArrayFill => { PrimDef::NDArray | PrimDef::NDArrayCopy | PrimDef::NDArrayFill => {
self.build_ndarray_class_related(prim) self.build_ndarray_class_related(prim)
} }
@ -530,9 +502,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim), PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim),
PrimDef::FunNpArgmin | PrimDef::FunNpArgmax | PrimDef::FunNpMin | PrimDef::FunNpMax => { PrimDef::FunNpMin | PrimDef::FunNpMax => self.build_np_min_max_function(prim),
self.build_np_max_min_function(prim)
}
PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => { PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => {
self.build_np_minimum_maximum_function(prim) self.build_np_minimum_maximum_function(prim)
@ -576,22 +546,6 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpLdExp | PrimDef::FunNpLdExp
| PrimDef::FunNpHypot | PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
self.build_np_sp_ndarray_function(prim)
}
PrimDef::FunNpDot
| PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr
| PrimDef::FunNpLinalgSvd
| PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv
| PrimDef::FunNpLinalgMatrixPower
| PrimDef::FunNpLinalgDet
| PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim),
}; };
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
@ -600,7 +554,7 @@ impl<'a> BuiltinBuilder<'a> {
match (&tld, prim.details()) { match (&tld, prim.details()) {
( (
TopLevelDef::Class { name, object_id, .. }, TopLevelDef::Class { name, object_id, .. },
PrimDefDetails::PrimClass { name: exp_name, .. }, PrimDefDetails::PrimClass { name: exp_name },
) => { ) => {
let exp_object_id = prim.id(); let exp_object_id = prim.id();
assert_eq!(name, &exp_name.into()); assert_eq!(name, &exp_name.into());
@ -649,24 +603,17 @@ impl<'a> BuiltinBuilder<'a> {
let make_ctor_signature = |unifier: &mut Unifier| { let make_ctor_signature = |unifier: &mut Unifier| {
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg { name: "start".into(), ty: int32, default_value: None },
name: "start".into(),
ty: int32,
default_value: None,
is_vararg: false,
},
FuncArg { FuncArg {
name: "stop".into(), name: "stop".into(),
ty: int32, ty: int32,
// placeholder // placeholder
default_value: Some(SymbolValue::I32(0)), default_value: Some(SymbolValue::I32(0)),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "step".into(), name: "step".into(),
ty: int32, ty: int32,
default_value: Some(SymbolValue::I32(1)), default_value: Some(SymbolValue::I32(1)),
is_vararg: false,
}, },
], ],
ret: range, ret: range,
@ -837,9 +784,9 @@ impl<'a> BuiltinBuilder<'a> {
prim, prim,
&[ &[
PrimDef::Option, PrimDef::Option,
PrimDef::FunOptionIsSome, PrimDef::OptionIsSome,
PrimDef::FunOptionIsNone, PrimDef::OptionIsNone,
PrimDef::FunOptionUnwrap, PrimDef::OptionUnwrap,
PrimDef::FunSome, PrimDef::FunSome,
], ],
); );
@ -852,9 +799,9 @@ impl<'a> BuiltinBuilder<'a> {
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
Self::create_method(PrimDef::FunOptionIsSome, self.is_some_ty.0), Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
Self::create_method(PrimDef::FunOptionIsNone, self.is_some_ty.0), Self::create_method(PrimDef::OptionIsNone, self.is_some_ty.0),
Self::create_method(PrimDef::FunOptionUnwrap, self.unwrap_ty.0), Self::create_method(PrimDef::OptionUnwrap, self.unwrap_ty.0),
], ],
ancestors: vec![TypeAnnotation::CustomClass { ancestors: vec![TypeAnnotation::CustomClass {
id: prim.id(), id: prim.id(),
@ -865,7 +812,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::FunOptionUnwrap => TopLevelDef::Function { PrimDef::OptionUnwrap => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unwrap_ty.0, signature: self.unwrap_ty.0,
@ -879,7 +826,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::FunOptionIsNone | PrimDef::FunOptionIsSome => TopLevelDef::Function { PrimDef::OptionIsNone | PrimDef::OptionIsSome => TopLevelDef::Function {
name: prim.name().to_string(), name: prim.name().to_string(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.is_some_ty.0, signature: self.is_some_ty.0,
@ -900,10 +847,10 @@ impl<'a> BuiltinBuilder<'a> {
}; };
let returned_int = match prim { let returned_int = match prim {
PrimDef::FunOptionIsNone => { PrimDef::OptionIsNone => {
ctx.builder.build_is_null(ptr, prim.simple_name()) ctx.builder.build_is_null(ptr, prim.simple_name())
} }
PrimDef::FunOptionIsSome => { PrimDef::OptionIsSome => {
ctx.builder.build_is_not_null(ptr, prim.simple_name()) ctx.builder.build_is_not_null(ptr, prim.simple_name())
} }
_ => unreachable!(), _ => unreachable!(),
@ -922,7 +869,6 @@ impl<'a> BuiltinBuilder<'a> {
name: "n".into(), name: "n".into(),
ty: self.option_tvar.ty, ty: self.option_tvar.ty,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: self.primitives.option, ret: self.primitives.option,
vars: into_var_map([self.option_tvar]), vars: into_var_map([self.option_tvar]),
@ -977,7 +923,7 @@ impl<'a> BuiltinBuilder<'a> {
fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef { fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed( debug_assert_prim_is_allowed(
prim, prim,
&[PrimDef::NDArray, PrimDef::FunNDArrayCopy, PrimDef::FunNDArrayFill], &[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill],
); );
match prim { match prim {
@ -988,8 +934,8 @@ impl<'a> BuiltinBuilder<'a> {
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
Self::create_method(PrimDef::FunNDArrayCopy, self.ndarray_copy_ty.0), Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
Self::create_method(PrimDef::FunNDArrayFill, self.ndarray_fill_ty.0), Self::create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0),
], ],
ancestors: Vec::default(), ancestors: Vec::default(),
constructor: None, constructor: None,
@ -997,7 +943,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::FunNDArrayCopy => TopLevelDef::Function { PrimDef::NDArrayCopy => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_copy_ty.0, signature: self.ndarray_copy_ty.0,
@ -1007,14 +953,15 @@ impl<'a> BuiltinBuilder<'a> {
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| { |ctx, obj, fun, args, generator| {
gen_ndarray_copy(ctx, &obj, fun, &args, generator) todo!()
.map(|val| Some(val.as_basic_value_enum())) // gen_ndarray_copy(ctx, &obj, fun, &args, generator)
// .map(|val| Some(val.as_basic_value_enum()))
}, },
)))), )))),
loc: None, loc: None,
}, },
PrimDef::FunNDArrayFill => TopLevelDef::Function { PrimDef::NDArrayFill => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_fill_ty.0, signature: self.ndarray_fill_ty.0,
@ -1024,8 +971,9 @@ impl<'a> BuiltinBuilder<'a> {
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| { |ctx, obj, fun, args, generator| {
gen_ndarray_fill(ctx, &obj, fun, &args, generator)?; todo!()
Ok(None) // gen_ndarray_fill(ctx, &obj, fun, &args, generator)?;
// Ok(None)
}, },
)))), )))),
loc: None, loc: None,
@ -1057,7 +1005,6 @@ impl<'a> BuiltinBuilder<'a> {
name: "n".into(), name: "n".into(),
ty: self.num_or_ndarray_ty.ty, ty: self.num_or_ndarray_ty.ty,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: self.num_or_ndarray_ty.ty, ret: self.num_or_ndarray_ty.ty,
vars: self.num_or_ndarray_var_map.clone(), vars: self.num_or_ndarray_var_map.clone(),
@ -1106,7 +1053,7 @@ impl<'a> BuiltinBuilder<'a> {
); );
// The size variant of the function determines the size of the returned int. // The size variant of the function determines the size of the returned int.
let int_sized = size_variant.of_int(self.primitives); let int_sized = size_variant_to_int_type(size_variant, self.primitives);
let ndarray_int_sized = let ndarray_int_sized =
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
@ -1131,7 +1078,7 @@ impl<'a> BuiltinBuilder<'a> {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let ret_elem_ty = size_variant.of_int(&ctx.primitives); let ret_elem_ty = size_variant_to_int_type(size_variant, &ctx.primitives);
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
}), }),
) )
@ -1172,7 +1119,7 @@ impl<'a> BuiltinBuilder<'a> {
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
// The size variant of the function determines the type of int returned // The size variant of the function determines the type of int returned
let int_sized = size_variant.of_int(self.primitives); let int_sized = size_variant_to_int_type(size_variant, self.primitives);
let ndarray_int_sized = let ndarray_int_sized =
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
@ -1195,7 +1142,7 @@ impl<'a> BuiltinBuilder<'a> {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let ret_elem_ty = size_variant.of_int(&ctx.primitives); let ret_elem_ty = size_variant_to_int_type(size_variant, &ctx.primitives);
let func = match kind { let func = match kind {
Kind::Ceil => builtin_fns::call_ceil, Kind::Ceil => builtin_fns::call_ceil,
Kind::Floor => builtin_fns::call_floor, Kind::Floor => builtin_fns::call_floor,
@ -1249,7 +1196,7 @@ impl<'a> BuiltinBuilder<'a> {
let func = match prim { let func = match prim {
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
PrimDef::FunNpZeros => gen_ndarray_zeros, PrimDef::FunNpZeros => gen_ndarray_zeros,
PrimDef::FunNpOnes => gen_ndarray_ones, PrimDef::FunNpOnes => todo!(), // gen_ndarray_ones,
_ => unreachable!(), _ => unreachable!(),
}; };
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum())) func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
@ -1277,23 +1224,16 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg { name: "object".into(), ty: tv.ty, default_value: None },
name: "object".into(),
ty: tv.ty,
default_value: None,
is_vararg: false,
},
FuncArg { FuncArg {
name: "copy".into(), name: "copy".into(),
ty: bool, ty: bool,
default_value: Some(SymbolValue::Bool(true)), default_value: Some(SymbolValue::Bool(true)),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "ndmin".into(), name: "ndmin".into(),
ty: int32, ty: int32,
default_value: Some(SymbolValue::U32(0)), default_value: Some(SymbolValue::U32(0)),
is_vararg: false,
}, },
], ],
ret: ndarray, ret: ndarray,
@ -1305,8 +1245,9 @@ impl<'a> BuiltinBuilder<'a> {
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| { |ctx, obj, fun, args, generator| {
gen_ndarray_array(ctx, &obj, fun, &args, generator) todo!()
.map(|val| Some(val.as_basic_value_enum())) // gen_ndarray_array(ctx, &obj, fun, &args, generator)
// .map(|val| Some(val.as_basic_value_enum()))
}, },
)))), )))),
loc: None, loc: None,
@ -1324,8 +1265,9 @@ impl<'a> BuiltinBuilder<'a> {
// type variable // type variable
&[(self.list_int32, "shape"), (tv.ty, "fill_value")], &[(self.list_int32, "shape"), (tv.ty, "fill_value")],
Box::new(move |ctx, obj, fun, args, generator| { Box::new(move |ctx, obj, fun, args, generator| {
gen_ndarray_full(ctx, &obj, fun, &args, generator) todo!()
.map(|val| Some(val.as_basic_value_enum())) // gen_ndarray_full(ctx, &obj, fun, &args, generator)
// .map(|val| Some(val.as_basic_value_enum()))
}), }),
) )
} }
@ -1335,24 +1277,17 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg { name: "N".into(), ty: int32, default_value: None },
name: "N".into(),
ty: int32,
default_value: None,
is_vararg: false,
},
// TODO(Derppening): Default values current do not work? // TODO(Derppening): Default values current do not work?
FuncArg { FuncArg {
name: "M".into(), name: "M".into(),
ty: int32, ty: int32,
default_value: Some(SymbolValue::OptionNone), default_value: Some(SymbolValue::OptionNone),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "k".into(), name: "k".into(),
ty: int32, ty: int32,
default_value: Some(SymbolValue::I32(0)), default_value: Some(SymbolValue::I32(0)),
is_vararg: false,
}, },
], ],
ret: self.ndarray_float_2d, ret: self.ndarray_float_2d,
@ -1364,8 +1299,9 @@ impl<'a> BuiltinBuilder<'a> {
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| { |ctx, obj, fun, args, generator| {
gen_ndarray_eye(ctx, &obj, fun, &args, generator) todo!()
.map(|val| Some(val.as_basic_value_enum())) // gen_ndarray_eye(ctx, &obj, fun, &args, generator)
// .map(|val| Some(val.as_basic_value_enum()))
}, },
)))), )))),
loc: None, loc: None,
@ -1378,8 +1314,9 @@ impl<'a> BuiltinBuilder<'a> {
self.ndarray_float_2d, self.ndarray_float_2d,
&[(int32, "n")], &[(int32, "n")],
Box::new(|ctx, obj, fun, args, generator| { Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_identity(ctx, &obj, fun, &args, generator) todo!()
.map(|val| Some(val.as_basic_value_enum())) // gen_ndarray_identity(ctx, &obj, fun, &args, generator)
// .map(|val| Some(val.as_basic_value_enum()))
}), }),
), ),
_ => unreachable!(), _ => unreachable!(),
@ -1396,12 +1333,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "s".into(), ty: str, default_value: None }],
name: "s".into(),
ty: str,
default_value: None,
is_vararg: false,
}],
ret: str, ret: str,
vars: VarMap::default(), vars: VarMap::default(),
})), })),
@ -1465,21 +1397,31 @@ impl<'a> BuiltinBuilder<'a> {
fn build_len_function(&mut self) -> TopLevelDef { fn build_len_function(&mut self) -> TopLevelDef {
let prim = PrimDef::FunLen; let prim = PrimDef::FunLen;
// Type handled in [`Inferencer::try_fold_special_call`] let PrimitiveStore { uint64, int32, .. } = *self.primitives;
let arg_tvar = self.unifier.get_dummy_var();
let tvar = self.unifier.get_fresh_var(Some("L".into()), None);
let list = self
.unifier
.subst(
self.primitives.list,
&into_var_map([TypeVar { id: self.list_tvar.id, ty: tvar.ty }]),
)
.unwrap();
let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None);
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.ty), Some(ndims.ty));
let arg_ty = self.unifier.get_fresh_var_with_range(
&[list, ndarray, self.primitives.range],
Some("I".into()),
None,
);
TopLevelDef::Function { TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "ls".into(), ty: arg_ty.ty, default_value: None }],
name: "obj".into(), ret: int32,
ty: arg_tvar.ty, vars: into_var_map([tvar, arg_ty]),
default_value: None,
is_vararg: false,
}],
ret: self.primitives.int32,
vars: into_var_map([arg_tvar]),
})), })),
var_id: Vec::default(), var_id: Vec::default(),
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1487,10 +1429,86 @@ impl<'a> BuiltinBuilder<'a> {
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
move |ctx, _, fun, args, generator| { move |ctx, _, fun, args, generator| {
let range_ty = ctx.primitives.range;
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
let (start, end, step) = destructure_range(ctx, arg);
Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into())
} else {
match &*ctx.unifier.get_ty_immutable(arg_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let len = ctx
.build_gep_and_load(
arg.into_pointer_value(),
&[zero, int32.const_int(1, false)],
None,
)
.into_int_value();
if len.get_type().get_bit_width() == 32 {
Some(len.into())
} else {
Some(
ctx.builder
.build_int_truncate(len, int32, "len2i32")
.map(Into::into)
.unwrap(),
)
}
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
builtin_fns::call_len(generator, ctx, (arg_ty, arg)).map(|ret| Some(ret.into())) let arg = NDArrayValue::from_ptr_val(
arg.into_pointer_value(),
llvm_usize,
None,
);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::NE,
ndims,
llvm_usize.const_zero(),
"",
)
.unwrap(),
"0:TypeError",
&format!("{name}() of unsized object", name = prim.name()),
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
if len.get_type().get_bit_width() == 32 {
Some(len.into())
} else {
Some(
ctx.builder
.build_int_truncate(len, llvm_i32, "len")
.map(Into::into)
.unwrap(),
)
}
}
_ => unreachable!(),
}
})
}, },
)))), )))),
loc: None, loc: None,
@ -1506,18 +1524,8 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg { name: "m".into(), ty: self.num_ty.ty, default_value: None },
name: "m".into(), FuncArg { name: "n".into(), ty: self.num_ty.ty, default_value: None },
ty: self.num_ty.ty,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "n".into(),
ty: self.num_ty.ty,
default_value: None,
is_vararg: false,
},
], ],
ret: self.num_ty.ty, ret: self.num_ty.ty,
vars: self.num_var_map.clone(), vars: self.num_var_map.clone(),
@ -1545,45 +1553,39 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build the functions `np_max()`, `np_min()`, `np_argmax()` and `np_argmin()` /// Build the functions `np_min()` and `np_max()`.
/// Calls `call_numpy_max_min` with the function name fn build_np_min_max_function(&mut self, prim: PrimDef) -> TopLevelDef {
fn build_np_max_min_function(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMin, PrimDef::FunNpMax]);
debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax],
);
let (var_map, ret_ty) = match prim { let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None);
PrimDef::FunNpArgmax | PrimDef::FunNpArgmin => { let var_map = self
(self.num_or_ndarray_var_map.clone(), self.primitives.int64) .num_or_ndarray_var_map
} .clone()
PrimDef::FunNpMax | PrimDef::FunNpMin => { .into_iter()
let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None); .chain(once((ret_ty.id, ret_ty.ty)))
let var_map = self .collect::<IndexMap<_, _>>();
.num_or_ndarray_var_map
.clone()
.into_iter()
.chain(once((ret_ty.id, ret_ty.ty)))
.collect::<IndexMap<_, _>>();
(var_map, ret_ty.ty)
}
_ => unreachable!(),
};
create_fn_by_codegen( create_fn_by_codegen(
self.unifier, self.unifier,
&var_map, &var_map,
prim.name(), prim.name(),
ret_ty, ret_ty.ty,
&[(self.num_or_ndarray_ty.ty, "a")], &[(self.float_or_ndarray_ty.ty, "a")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let a_ty = fun.0.args[0].ty; let a_ty = fun.0.args[0].ty;
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), prim.name())?)) let func = match prim {
PrimDef::FunNpMin => builtin_fns::call_numpy_min,
PrimDef::FunNpMax => builtin_fns::call_numpy_max,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (a_ty, a))?))
}), }),
) )
} }
/// Build the functions `np_minimum()` and `np_maximum()`. /// Build the functions `np_minimum()` and `np_maximum()`.
fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef { fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]); debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]);
@ -1599,12 +1601,7 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty args: param_ty
.iter() .iter()
.map(|p| FuncArg { .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
name: p.1.into(),
ty: p.0,
default_value: None,
is_vararg: false,
})
.collect(), .collect(),
ret: ret_ty.ty, ret: ret_ty.ty,
vars: into_var_map([x1_ty, x2_ty, ret_ty]), vars: into_var_map([x1_ty, x2_ty, ret_ty]),
@ -1645,7 +1642,6 @@ impl<'a> BuiltinBuilder<'a> {
name: "n".into(), name: "n".into(),
ty: self.num_or_ndarray_ty.ty, ty: self.num_or_ndarray_ty.ty,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: self.num_or_ndarray_ty.ty, ret: self.num_or_ndarray_ty.ty,
vars: self.num_or_ndarray_var_map.clone(), vars: self.num_or_ndarray_var_map.clone(),
@ -1834,12 +1830,7 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty args: param_ty
.iter() .iter()
.map(|p| FuncArg { .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
name: p.1.into(),
ty: p.0,
default_value: None,
is_vararg: false,
})
.collect(), .collect(),
ret: ret_ty.ty, ret: ret_ty.ty,
vars: into_var_map([x1_ty, x2_ty, ret_ty]), vars: into_var_map([x1_ty, x2_ty, ret_ty]),
@ -1873,207 +1864,6 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build np/sp functions that take as input `NDArray` only
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
match prim {
PrimDef::FunNpTranspose => {
let ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.ndarray_num_ty],
Some("T".into()),
None,
);
create_fn_by_codegen(
self.unifier,
&into_var_map([ndarray_ty]),
prim.name(),
ndarray_ty.ty,
&[(ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
}),
)
}
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.
//
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
PrimDef::FunNpReshape => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_num_ty,
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
),
_ => unreachable!(),
}
}
/// Build `np_linalg` and `sp_linalg` functions
///
/// The input to these functions must be floating point `NDArray`
fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[
PrimDef::FunNpDot,
PrimDef::FunNpLinalgCholesky,
PrimDef::FunNpLinalgQr,
PrimDef::FunNpLinalgSvd,
PrimDef::FunNpLinalgInv,
PrimDef::FunNpLinalgPinv,
PrimDef::FunNpLinalgMatrixPower,
PrimDef::FunNpLinalgDet,
PrimDef::FunSpLinalgLu,
PrimDef::FunSpLinalgSchur,
PrimDef::FunSpLinalgHessenberg,
],
);
match prim {
PrimDef::FunNpDot => create_fn_by_codegen(
self.unifier,
&self.num_or_ndarray_var_map,
prim.name(),
self.num_ty.ty,
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
),
PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => {
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky,
PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv,
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
PrimDef::FunNpLinalgQr
| PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
is_vararg_ctx: false,
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr,
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu,
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur,
PrimDef::FunSpLinalgHessenberg => {
builtin_fns::call_sp_linalg_hessenberg
}
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
PrimDef::FunNpLinalgSvd => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
is_vararg_ctx: false,
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
PrimDef::FunNpLinalgMatrixPower => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.primitives.int32, "power")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matrix_power(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgDet => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.primitives.float,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?))
}),
),
_ => unreachable!(),
}
}
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
(prim.simple_name().into(), method_ty, prim.id()) (prim.simple_name().into(), method_ty, prim.id())
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,13 @@
use std::convert::TryInto; use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use strum_macros::EnumIter; use strum_macros::EnumIter;
use nac3parser::ast::{Constant, ExprKind, Location}; use super::*;
use super::{numpy::unpack_ndarray_var_tys, *};
use crate::{
symbol_resolver::SymbolValue,
typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap},
};
/// All primitive types and functions in nac3core. /// All primitive types and functions in nac3core.
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)] #[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
@ -29,22 +27,17 @@ pub enum PrimDef {
List, List,
NDArray, NDArray,
// Option methods // Member Functions
FunOptionIsSome, OptionIsSome,
FunOptionIsNone, OptionIsNone,
FunOptionUnwrap, OptionUnwrap,
NDArrayCopy,
// Option-related functions NDArrayFill,
FunSome, FunInt32,
FunInt64,
// NDArray methods FunUInt32,
FunNDArrayCopy, FunUInt64,
FunNDArrayFill, FunFloat,
// Range methods
FunRangeInit,
// NumPy factory functions
FunNpNDArray, FunNpNDArray,
FunNpEmpty, FunNpEmpty,
FunNpZeros, FunNpZeros,
@ -53,17 +46,26 @@ pub enum PrimDef {
FunNpArray, FunNpArray,
FunNpEye, FunNpEye,
FunNpIdentity, FunNpIdentity,
FunRound,
// Miscellaneous NumPy & SciPy functions FunRound64,
FunNpRound, FunNpRound,
FunRangeInit,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunNpFloor, FunNpFloor,
FunCeil,
FunCeil64,
FunNpCeil, FunNpCeil,
FunLen,
FunMin,
FunNpMin, FunNpMin,
FunNpMinimum, FunNpMinimum,
FunNpArgmin, FunMax,
FunNpMax, FunNpMax,
FunNpMaximum, FunNpMaximum,
FunNpArgmax, FunAbs,
FunNpIsNan, FunNpIsNan,
FunNpIsInf, FunNpIsInf,
FunNpSin, FunNpSin,
@ -101,46 +103,15 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunNpTranspose,
FunNpReshape,
// Linalg functions // Top-Level Functions
FunNpDot, FunSome,
FunNpLinalgCholesky,
FunNpLinalgQr,
FunNpLinalgSvd,
FunNpLinalgInv,
FunNpLinalgPinv,
FunNpLinalgMatrixPower,
FunNpLinalgDet,
FunSpLinalgLu,
FunSpLinalgSchur,
FunSpLinalgHessenberg,
// Miscellaneous Python & NAC3 functions
FunInt32,
FunInt64,
FunUInt32,
FunUInt64,
FunFloat,
FunRound,
FunRound64,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunCeil,
FunCeil64,
FunLen,
FunMin,
FunMax,
FunAbs,
} }
/// Associated details of a [`PrimDef`] /// Associated details of a [`PrimDef`]
pub enum PrimDefDetails { pub enum PrimDefDetails {
PrimFunction { name: &'static str, simple_name: &'static str }, PrimFunction { name: &'static str, simple_name: &'static str },
PrimClass { name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type }, PrimClass { name: &'static str },
} }
impl PrimDef { impl PrimDef {
@ -182,17 +153,15 @@ impl PrimDef {
#[must_use] #[must_use]
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self.details() { match self.details() {
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name, .. } => { PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name,
name
}
} }
} }
/// Get the associated details of this [`PrimDef`] /// Get the associated details of this [`PrimDef`]
#[must_use] #[must_use]
pub fn details(self) -> PrimDefDetails { pub fn details(self) -> PrimDefDetails {
fn class(name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type) -> PrimDefDetails { fn class(name: &'static str) -> PrimDefDetails {
PrimDefDetails::PrimClass { name, get_ty_fn } PrimDefDetails::PrimClass { name }
} }
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails { fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
@ -200,37 +169,29 @@ impl PrimDef {
} }
match self { match self {
// Classes PrimDef::Int32 => class("int32"),
PrimDef::Int32 => class("int32", |primitives| primitives.int32), PrimDef::Int64 => class("int64"),
PrimDef::Int64 => class("int64", |primitives| primitives.int64), PrimDef::Float => class("float"),
PrimDef::Float => class("float", |primitives| primitives.float), PrimDef::Bool => class("bool"),
PrimDef::Bool => class("bool", |primitives| primitives.bool), PrimDef::None => class("none"),
PrimDef::None => class("none", |primitives| primitives.none), PrimDef::Range => class("range"),
PrimDef::Range => class("range", |primitives| primitives.range), PrimDef::Str => class("str"),
PrimDef::Str => class("str", |primitives| primitives.str), PrimDef::Exception => class("Exception"),
PrimDef::Exception => class("Exception", |primitives| primitives.exception), PrimDef::UInt32 => class("uint32"),
PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32), PrimDef::UInt64 => class("uint64"),
PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64), PrimDef::Option => class("Option"),
PrimDef::Option => class("Option", |primitives| primitives.option), PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::List => class("list", |primitives| primitives.list), PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray), PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::List => class("list"),
// Option methods PrimDef::NDArray => class("ndarray"),
PrimDef::FunOptionIsSome => fun("Option.is_some", Some("is_some")), PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::FunOptionIsNone => fun("Option.is_none", Some("is_none")), PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
PrimDef::FunOptionUnwrap => fun("Option.unwrap", Some("unwrap")), PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
// Option-related functions PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunSome => fun("Some", None), PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
// NDArray methods
PrimDef::FunNDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::FunNDArrayFill => fun("ndarray.fill", Some("fill")),
// Range methods
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
// NumPy factory functions
PrimDef::FunNpNDArray => fun("np_ndarray", None), PrimDef::FunNpNDArray => fun("np_ndarray", None),
PrimDef::FunNpEmpty => fun("np_empty", None), PrimDef::FunNpEmpty => fun("np_empty", None),
PrimDef::FunNpZeros => fun("np_zeros", None), PrimDef::FunNpZeros => fun("np_zeros", None),
@ -239,17 +200,26 @@ impl PrimDef {
PrimDef::FunNpArray => fun("np_array", None), PrimDef::FunNpArray => fun("np_array", None),
PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None), PrimDef::FunNpIdentity => fun("np_identity", None),
PrimDef::FunRound => fun("round", None),
// Miscellaneous NumPy & SciPy functions PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunNpFloor => fun("np_floor", None), PrimDef::FunNpFloor => fun("np_floor", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunNpCeil => fun("np_ceil", None), PrimDef::FunNpCeil => fun("np_ceil", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunNpMin => fun("np_min", None), PrimDef::FunNpMin => fun("np_min", None),
PrimDef::FunNpMinimum => fun("np_minimum", None), PrimDef::FunNpMinimum => fun("np_minimum", None),
PrimDef::FunNpArgmin => fun("np_argmin", None), PrimDef::FunMax => fun("max", None),
PrimDef::FunNpMax => fun("np_max", None), PrimDef::FunNpMax => fun("np_max", None),
PrimDef::FunNpMaximum => fun("np_maximum", None), PrimDef::FunNpMaximum => fun("np_maximum", None),
PrimDef::FunNpArgmax => fun("np_argmax", None), PrimDef::FunAbs => fun("abs", None),
PrimDef::FunNpIsNan => fun("np_isnan", None), PrimDef::FunNpIsNan => fun("np_isnan", None),
PrimDef::FunNpIsInf => fun("np_isinf", None), PrimDef::FunNpIsInf => fun("np_isinf", None),
PrimDef::FunNpSin => fun("np_sin", None), PrimDef::FunNpSin => fun("np_sin", None),
@ -287,40 +257,7 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpTranspose => fun("np_transpose", None), PrimDef::FunSome => fun("Some", None),
PrimDef::FunNpReshape => fun("np_reshape", None),
// Linalg functions
PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None),
PrimDef::FunNpLinalgDet => fun("np_linalg_det", None),
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
// Miscellaneous Python & NAC3 functions
PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunAbs => fun("abs", None),
} }
} }
} }
@ -389,9 +326,6 @@ impl TopLevelDef {
r r
} }
), ),
TopLevelDef::Variable { name, ty, .. } => {
format!("Variable {{ name: {name:?}, ty: {:?} }}", unifier.stringify(*ty),)
}
} }
} }
} }
@ -474,9 +408,9 @@ impl TopLevelComposer {
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(), obj_id: PrimDef::Option.id(),
fields: vec![ fields: vec![
(PrimDef::FunOptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::OptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::FunOptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::OptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::FunOptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)), (PrimDef::OptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)),
] ]
.into_iter() .into_iter()
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
@ -510,7 +444,6 @@ impl TopLevelComposer {
name: "value".into(), name: "value".into(),
ty: ndarray_dtype_tvar.ty, ty: ndarray_dtype_tvar.ty,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: none, ret: none,
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
@ -518,8 +451,8 @@ impl TopLevelComposer {
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
fields: Mapping::from([ fields: Mapping::from([
(PrimDef::FunNDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)), (PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)),
(PrimDef::FunNDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)), (PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)),
]), ]),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });
@ -593,18 +526,6 @@ impl TopLevelComposer {
} }
} }
#[must_use]
pub fn make_top_level_variable_def(
name: String,
simple_name: StrRef,
ty: Type,
ty_decl: Option<Expr>,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
loc: Option<Location>,
) -> TopLevelDef {
TopLevelDef::Variable { name, simple_name, ty, ty_decl, resolver, loc }
}
#[must_use] #[must_use]
pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String { pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String {
class_name.push('.'); class_name.push('.');
@ -750,16 +671,7 @@ impl TopLevelComposer {
) )
} }
/// This function returns the fields that have been initialized in the `__init__` function of a class pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> {
/// The function takes as input:
/// * `class_id`: The `object_id` of the class whose function is being evaluated (check `TopLevelDef::Class`)
/// * `definition_ast_list`: A list of ast definitions and statements defined in `TopLevelComposer`
/// * `stmts`: The body of function being parsed. Each statment is analyzed to check varaible initialization statements
pub fn get_all_assigned_field(
class_id: usize,
definition_ast_list: &Vec<DefAst>,
stmts: &[Stmt<()>],
) -> Result<HashSet<StrRef>, HashSet<String>> {
let mut result = HashSet::new(); let mut result = HashSet::new();
for s in stmts { for s in stmts {
match &s.node { match &s.node {
@ -795,138 +707,30 @@ impl TopLevelComposer {
// TODO: do not check for For and While? // TODO: do not check for For and While?
ast::StmtKind::For { body, orelse, .. } ast::StmtKind::For { body, orelse, .. }
| ast::StmtKind::While { body, orelse, .. } => { | ast::StmtKind::While { body, orelse, .. } => {
result.extend(Self::get_all_assigned_field( result.extend(Self::get_all_assigned_field(body.as_slice())?);
class_id, result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
definition_ast_list,
body.as_slice(),
)?);
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?);
} }
ast::StmtKind::If { body, orelse, .. } => { ast::StmtKind::If { body, orelse, .. } => {
let inited_for_sure = Self::get_all_assigned_field( let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
class_id, .intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
definition_ast_list, .copied()
body.as_slice(), .collect::<HashSet<_>>();
)?
.intersection(&Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure); result.extend(inited_for_sure);
} }
ast::StmtKind::Try { body, orelse, finalbody, .. } => { ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field( let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
class_id, .intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
definition_ast_list, .copied()
body.as_slice(), .collect::<HashSet<_>>();
)?
.intersection(&Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure); result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field( result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
class_id,
definition_ast_list,
finalbody.as_slice(),
)?);
} }
ast::StmtKind::With { body, .. } => { ast::StmtKind::With { body, .. } => {
result.extend(Self::get_all_assigned_field( result.extend(Self::get_all_assigned_field(body.as_slice())?);
class_id,
definition_ast_list,
body.as_slice(),
)?);
}
// Variables Initialized in function calls
ast::StmtKind::Expr { value, .. } => {
let ExprKind::Call { func, .. } = &value.node else {
continue;
};
let ExprKind::Attribute { value, attr, .. } = &func.node else {
continue;
};
let ExprKind::Name { id, .. } = &value.node else {
continue;
};
// Need to consider the two cases:
// Case 1) Call to class function i.e. id = `self`
// Case 2) Call to class ancestor function i.e. id = ancestor_name
// We leave checking whether function in case 2 belonged to class ancestor or not to type checker
//
// According to current handling of `self`, function definition are fixed and do not change regardless
// of which object is passed as `self` i.e. virtual polymorphism is not supported
// Therefore, we change class id for case 2 to reflect behavior of our compiler
let class_name = if *id == "self".into() {
let ast::StmtKind::ClassDef { name, .. } =
&definition_ast_list[class_id].1.as_ref().unwrap().node
else {
unreachable!()
};
name
} else {
id
};
let parent_method = definition_ast_list.iter().find_map(|def| {
let (
class_def,
Some(ast::Located {
node: ast::StmtKind::ClassDef { name, body, .. },
..
}),
) = &def
else {
return None;
};
let TopLevelDef::Class { object_id: class_id, .. } = &*class_def.read()
else {
unreachable!()
};
if name == class_name {
body.iter().find_map(|m| {
let ast::StmtKind::FunctionDef { name, body, .. } = &m.node else {
return None;
};
if *name == *attr {
return Some((body.clone(), class_id.0));
}
None
})
} else {
None
}
});
// If method body is none then method does not exist
if let Some((method_body, class_id)) = parent_method {
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
method_body.as_slice(),
)?);
} else {
return Err(HashSet::from([format!(
"{}.{} not found in class {class_name} at {}",
*id, *attr, value.location
)]));
}
} }
ast::StmtKind::Pass { .. } ast::StmtKind::Pass { .. }
| ast::StmtKind::Assert { .. } | ast::StmtKind::Assert { .. }
| ast::StmtKind::AnnAssign { .. } => {} | ast::StmtKind::Expr { .. } => {}
_ => { _ => {
unimplemented!() unimplemented!()

View File

@ -6,36 +6,36 @@ use std::{
sync::Arc, sync::Arc,
}; };
use inkwell::values::BasicValueEnum; use super::codegen::CodeGenContext;
use itertools::Itertools; use super::typecheck::type_inferencer::PrimitiveStore;
use parking_lot::RwLock; use super::typecheck::typedef::{
FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap,
use nac3parser::ast::{self, Expr, Location, Stmt, StrRef}; };
use crate::{ use crate::{
codegen::{CodeGenContext, CodeGenerator}, codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum}, symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::CodeLocation,
typedef::{ typedef::{CallId, TypeVarId},
CallId, FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, TypeVarId, Unifier,
VarMap,
},
}, },
}; };
use composer::*; use inkwell::values::BasicValueEnum;
use type_annotation::*; use itertools::Itertools;
use nac3parser::ast::{self, Location, Stmt, StrRef};
use parking_lot::RwLock;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize);
pub mod builtins; pub mod builtins;
pub mod composer; pub mod composer;
pub mod helper; pub mod helper;
pub mod numpy; pub mod numpy;
pub mod type_annotation;
use composer::*;
use type_annotation::*;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
pub mod type_annotation;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize);
type GenCallCallback = dyn for<'ctx, 'a> Fn( type GenCallCallback = dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
@ -130,14 +130,14 @@ pub enum TopLevelDef {
/// Function instance to symbol mapping /// Function instance to symbol mapping
/// ///
/// * Key: String representation of type variable values, sorted by variable ID in ascending /// * Key: String representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. /// order, including type variables associated with the class.
/// * Value: Function symbol name. /// * Value: Function symbol name.
instance_to_symbol: HashMap<String, String>, instance_to_symbol: HashMap<String, String>,
/// Function instances to annotated AST mapping /// Function instances to annotated AST mapping
/// ///
/// * Key: String representation of type variable values, sorted by variable ID in ascending /// * Key: String representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. Excluding rigid type /// order, including type variables associated with the class. Excluding rigid type
/// variables. /// variables.
/// ///
/// Rigid type variables that would be substituted when the function is instantiated. /// Rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, FunInstance>, instance_to_stmt: HashMap<String, FunInstance>,
@ -148,25 +148,6 @@ pub enum TopLevelDef {
/// Definition location. /// Definition location.
loc: Option<Location>, loc: Option<Location>,
}, },
Variable {
/// Qualified name of the global variable, should be unique globally.
name: String,
/// Simple name, the same as in method/function definition.
simple_name: StrRef,
/// Type of the global variable.
ty: Type,
/// The declared type of the global variable, or [`None`] if no type annotation is provided.
ty_decl: Option<Expr>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Definition location.
loc: Option<Location>,
},
} }
pub struct TopLevelContext { pub struct TopLevelContext {

View File

@ -1,17 +1,18 @@
use itertools::Itertools; use crate::{
toplevel::helper::PrimDef,
use super::helper::PrimDef; typecheck::{
use crate::typecheck::{ type_inferencer::PrimitiveStore,
type_inferencer::PrimitiveStore, typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap}, },
}; };
use itertools::Itertools;
/// Creates a `ndarray` [`Type`] with the given type arguments. /// Creates a `ndarray` [`Type`] with the given type arguments.
/// ///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not /// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized. /// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not /// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized. /// specialized.
pub fn make_ndarray_ty( pub fn make_ndarray_ty(
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
@ -24,9 +25,9 @@ pub fn make_ndarray_ty(
/// Substitutes type variables in `ndarray`. /// Substitutes type variables in `ndarray`.
/// ///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not /// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized. /// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not /// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized. /// specialized.
pub fn subst_ndarray_tvars( pub fn subst_ndarray_tvars(
unifier: &mut Unifier, unifier: &mut Unifier,
ndarray: Type, ndarray: Type,

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
] ]

View File

@ -1,23 +1,21 @@
use std::{collections::HashMap, sync::Arc}; use super::*;
use crate::toplevel::helper::PrimDef;
use indoc::indoc; use crate::typecheck::typedef::into_var_map;
use parking_lot::Mutex;
use test_case::test_case;
use nac3parser::{
ast::{fold::Fold, FileName},
parser::parse_program,
};
use super::{helper::PrimDef, DefinitionId, *};
use crate::{ use crate::{
codegen::CodeGenContext, codegen::CodeGenContext,
symbol_resolver::{SymbolResolver, ValueEnum}, symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::DefinitionId,
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{into_var_map, Type, Unifier}, typedef::{Type, Unifier},
}, },
}; };
use indoc::indoc;
use nac3parser::ast::FileName;
use nac3parser::{ast::fold::Fold, parser::parse_program};
use parking_lot::Mutex;
use std::{collections::HashMap, sync::Arc};
use test_case::test_case;
struct ResolverInternal { struct ResolverInternal {
id_to_type: Mutex<HashMap<StrRef, Type>>, id_to_type: Mutex<HashMap<StrRef, Type>>,
@ -64,7 +62,6 @@ impl SymbolResolver for Resolver {
&self, &self,
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, '_>, _: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() unimplemented!()
} }
@ -120,8 +117,7 @@ impl SymbolResolver for Resolver {
"register" "register"
)] )]
fn test_simple_register(source: Vec<&str>) { fn test_simple_register(source: Vec<&str>) {
let mut composer = let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
for s in source { for s in source {
let ast = parse_program(s, FileName::default()).unwrap(); let ast = parse_program(s, FileName::default()).unwrap();
@ -141,8 +137,7 @@ fn test_simple_register(source: Vec<&str>) {
"register" "register"
)] )]
fn test_simple_register_without_constructor(source: &str) { fn test_simple_register_without_constructor(source: &str) {
let mut composer = let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let ast = parse_program(source, FileName::default()).unwrap(); let ast = parse_program(source, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
composer.register_top_level(ast, None, "", true).unwrap(); composer.register_top_level(ast, None, "", true).unwrap();
@ -176,8 +171,7 @@ fn test_simple_register_without_constructor(source: &str) {
"function compose" "function compose"
)] )]
fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
let mut composer = let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = Arc::new(ResolverInternal { let internal_resolver = Arc::new(ResolverInternal {
id_to_def: Mutex::default(), id_to_def: Mutex::default(),
@ -525,8 +519,7 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
)] )]
fn test_analyze(source: &[&str], res: &[&str]) { fn test_analyze(source: &[&str], res: &[&str]) {
let print = false; let print = false;
let mut composer = let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar( let internal_resolver = make_internal_resolver_with_tvar(
vec![ vec![
@ -703,8 +696,7 @@ fn test_analyze(source: &[&str], res: &[&str]) {
)] )]
fn test_inference(source: Vec<&str>, res: &[&str]) { fn test_inference(source: Vec<&str>, res: &[&str]) {
let print = true; let print = true;
let mut composer = let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar( let internal_resolver = make_internal_resolver_with_tvar(
vec![ vec![

View File

@ -1,13 +1,9 @@
use strum::IntoEnumIterator; use super::*;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::VarMap;
use nac3parser::ast::Constant; use nac3parser::ast::Constant;
use super::{
helper::{PrimDef, PrimDefDetails},
*,
};
use crate::{symbol_resolver::SymbolValue, typecheck::typedef::VarMap};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum TypeAnnotation { pub enum TypeAnnotation {
Primitive(Type), Primitive(Type),
@ -67,9 +63,9 @@ impl TypeAnnotation {
/// Parses an AST expression `expr` into a [`TypeAnnotation`]. /// Parses an AST expression `expr` into a [`TypeAnnotation`].
/// ///
/// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all /// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all
/// generic variables associated with the definition. /// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass /// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [`None`] when this function is invoked externally. /// [`None`] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>( pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
resolver: &(dyn SymbolResolver + Send + Sync), resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
@ -361,7 +357,6 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
pub fn get_type_from_type_annotation_kinds( pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore,
ann: &TypeAnnotation, ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>, subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, HashSet<String>> { ) -> Result<Type, HashSet<String>> {
@ -384,141 +379,100 @@ pub fn get_type_from_type_annotation_kinds(
let param_ty = params let param_ty = params
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) { let subst = {
// Primitive TopLevelDefs do not contain all fields that are present in their Type // check for compatible range
// counterparts, so directly perform subst on the Type instead. // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else { for (tvar, p) in type_vars.iter().zip(param_ty) {
unreachable!() match unifier.get_ty(*tvar).as_ref() {
}; TypeEnum::TVar {
id,
let base_ty = get_ty_fn(primitives); range,
let params = fields: None,
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) { name,
params.clone() loc,
} else { is_const_generic: false,
unreachable!() } => {
}; let ok: bool = {
// create a temp type var and unify to check compatibility
unifier p == *tvar || {
.subst( let temp = unifier.get_fresh_var_with_range(
get_ty_fn(primitives), range.as_slice(),
&params *name,
.iter() *loc,
.zip(param_ty) );
.map(|(obj_tv, param)| (*obj_tv.0, param)) unifier.unify(temp.ty, p).is_ok()
.collect(),
)
.unwrap_or(base_ty)
} else {
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
} }
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
} }
TypeEnum::TVar {
id, range, name, loc, is_const_generic: true, ..
} => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp =
unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
} }
}
result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst { TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
if let Some(wl) = subst_list.as_mut() { let ty = range[0];
wl.push(ty); let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
} }
} }
result
ty
}; };
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
}
}
Ok(ty) Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
@ -536,7 +490,6 @@ pub fn get_type_from_type_annotation_kinds(
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
unifier, unifier,
primitives,
ty.as_ref(), ty.as_ref(),
subst_list, subst_list,
)?; )?;
@ -546,16 +499,10 @@ pub fn get_type_from_type_annotation_kinds(
let tys = tys let tys = tys
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys, is_vararg_ctx: false })) Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))
} }
} }
} }

View File

@ -1,19 +1,13 @@
use std::{ use crate::toplevel::helper::PrimDef;
collections::{HashMap, HashSet},
iter::once,
};
use super::type_inferencer::Inferencer;
use super::typedef::{Type, TypeEnum};
use nac3parser::ast::{ use nac3parser::ast::{
self, Constant, Expr, ExprKind, self, Constant, Expr, ExprKind,
Operator::{LShift, RShift}, Operator::{LShift, RShift},
Stmt, StmtKind, StrRef, Stmt, StmtKind, StrRef,
}; };
use std::{collections::HashSet, iter::once};
use super::{
type_inferencer::{DeclarationSource, IdentifierInfo, Inferencer},
typedef::{Type, TypeEnum},
};
use crate::toplevel::helper::PrimDef;
impl<'a> Inferencer<'a> { impl<'a> Inferencer<'a> {
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> { fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> {
@ -27,45 +21,26 @@ impl<'a> Inferencer<'a> {
fn check_pattern( fn check_pattern(
&mut self, &mut self,
pattern: &Expr<Option<Type>>, pattern: &Expr<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>, defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> { ) -> Result<(), HashSet<String>> {
match &pattern.node { match &pattern.node {
ExprKind::Name { id, .. } if id == &"none".into() => { ExprKind::Name { id, .. } if id == &"none".into() => {
Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)])) Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)]))
} }
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
// If `id` refers to a declared symbol, reject this assignment if it is used in the if !defined_identifiers.contains(id) {
// context of an (implicit) global variable defined_identifiers.insert(*id);
if let Some(id_info) = defined_identifiers.get(id) {
if matches!(
id_info.source,
DeclarationSource::Global { is_explicit: Some(false) }
) {
return Err(HashSet::from([format!(
"cannot access local variable '{id}' before it is declared (at {})",
pattern.location
)]));
}
}
if !defined_identifiers.contains_key(id) {
defined_identifiers.insert(*id, IdentifierInfo::default());
} }
self.should_have_value(pattern)?; self.should_have_value(pattern)?;
Ok(()) Ok(())
} }
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
for elt in elts { for elt in elts {
self.check_pattern(elt, defined_identifiers)?; self.check_pattern(elt, defined_identifiers)?;
self.should_have_value(elt)?; self.should_have_value(elt)?;
} }
Ok(()) Ok(())
} }
ExprKind::Starred { value, .. } => {
self.check_pattern(value, defined_identifiers)?;
self.should_have_value(value)?;
Ok(())
}
ExprKind::Subscript { value, slice, .. } => { ExprKind::Subscript { value, slice, .. } => {
self.check_expr(value, defined_identifiers)?; self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?; self.should_have_value(value)?;
@ -89,7 +64,7 @@ impl<'a> Inferencer<'a> {
fn check_expr( fn check_expr(
&mut self, &mut self,
expr: &Expr<Option<Type>>, expr: &Expr<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>, defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> { ) -> Result<(), HashSet<String>> {
// there are some cases where the custom field is None // there are some cases where the custom field is None
if let Some(ty) = &expr.custom { if let Some(ty) = &expr.custom {
@ -110,7 +85,7 @@ impl<'a> Inferencer<'a> {
return Ok(()); return Ok(());
} }
self.should_have_value(expr)?; self.should_have_value(expr)?;
if !defined_identifiers.contains_key(id) { if !defined_identifiers.contains(id) {
match self.function_data.resolver.get_symbol_type( match self.function_data.resolver.get_symbol_type(
self.unifier, self.unifier,
&self.top_level.definitions.read(), &self.top_level.definitions.read(),
@ -118,22 +93,7 @@ impl<'a> Inferencer<'a> {
*id, *id,
) { ) {
Ok(_) => { Ok(_) => {
let is_global = self.is_id_global(*id); self.defined_identifiers.insert(*id);
defined_identifiers.insert(
*id,
IdentifierInfo {
source: match is_global {
Some(true) => {
DeclarationSource::Global { is_explicit: Some(false) }
}
Some(false) => {
DeclarationSource::Global { is_explicit: None }
}
None => DeclarationSource::Local,
},
},
);
} }
Err(e) => { Err(e) => {
return Err(HashSet::from([format!( return Err(HashSet::from([format!(
@ -206,7 +166,9 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = defined_identifiers.clone(); let mut defined_identifiers = defined_identifiers.clone();
for arg in &args.args { for arg in &args.args {
// TODO: should we check the types here? // TODO: should we check the types here?
defined_identifiers.entry(arg.node.arg).or_default(); if !defined_identifiers.contains(&arg.node.arg) {
defined_identifiers.insert(arg.node.arg);
}
} }
self.check_expr(body, &mut defined_identifiers)?; self.check_expr(body, &mut defined_identifiers)?;
} }
@ -245,23 +207,19 @@ impl<'a> Inferencer<'a> {
/// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which /// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
/// is freed when the function returns. /// is freed when the function returns.
fn check_return_value_ty(&mut self, ret_ty: Type) -> bool { fn check_return_value_ty(&mut self, ret_ty: Type) -> bool {
if cfg!(feature = "no-escape-analysis") { match &*self.unifier.get_ty_immutable(ret_ty) {
true TypeEnum::TObj { .. } => [
} else { self.primitives.int32,
match &*self.unifier.get_ty_immutable(ret_ty) { self.primitives.int64,
TypeEnum::TObj { .. } => [ self.primitives.uint32,
self.primitives.int32, self.primitives.uint64,
self.primitives.int64, self.primitives.float,
self.primitives.uint32, self.primitives.bool,
self.primitives.uint64, ]
self.primitives.float, .iter()
self.primitives.bool, .any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
] TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)),
.iter() _ => false,
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty, .. } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false,
}
} }
} }
@ -269,7 +227,7 @@ impl<'a> Inferencer<'a> {
fn check_stmt( fn check_stmt(
&mut self, &mut self,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>, defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, HashSet<String>> { ) -> Result<bool, HashSet<String>> {
match &stmt.node { match &stmt.node {
StmtKind::For { target, iter, body, orelse, .. } => { StmtKind::For { target, iter, body, orelse, .. } => {
@ -295,11 +253,9 @@ impl<'a> Inferencer<'a> {
let body_returned = self.check_block(body, &mut body_identifiers)?; let body_returned = self.check_block(body, &mut body_identifiers)?;
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?; let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
for ident in body_identifiers.keys() { for ident in &body_identifiers {
if !defined_identifiers.contains_key(ident) if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
&& orelse_identifiers.contains_key(ident) defined_identifiers.insert(*ident);
{
defined_identifiers.insert(*ident, IdentifierInfo::default());
} }
} }
Ok(body_returned && orelse_returned) Ok(body_returned && orelse_returned)
@ -330,7 +286,7 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = defined_identifiers.clone(); let mut defined_identifiers = defined_identifiers.clone();
let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node; let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node;
if let Some(name) = name { if let Some(name) = name {
defined_identifiers.insert(*name, IdentifierInfo::default()); defined_identifiers.insert(*name);
} }
self.check_block(body, &mut defined_identifiers)?; self.check_block(body, &mut defined_identifiers)?;
} }
@ -394,44 +350,6 @@ impl<'a> Inferencer<'a> {
} }
Ok(true) Ok(true)
} }
StmtKind::Global { names, .. } => {
for id in names {
if let Some(id_info) = defined_identifiers.get(id) {
if id_info.source == DeclarationSource::Local {
return Err(HashSet::from([format!(
"name '{id}' is referenced prior to global declaration at {}",
stmt.location,
)]));
}
continue;
}
match self.function_data.resolver.get_symbol_type(
self.unifier,
&self.top_level.definitions.read(),
self.primitives,
*id,
) {
Ok(_) => {
defined_identifiers.insert(
*id,
IdentifierInfo {
source: DeclarationSource::Global { is_explicit: Some(true) },
},
);
}
Err(e) => {
return Err(HashSet::from([format!(
"type error at identifier `{}` ({}) at {}",
id, e, stmt.location
)]))
}
}
}
Ok(false)
}
// break, raise, etc. // break, raise, etc.
_ => Ok(false), _ => Ok(false),
} }
@ -440,7 +358,7 @@ impl<'a> Inferencer<'a> {
pub fn check_block( pub fn check_block(
&mut self, &mut self,
block: &[Stmt<Option<Type>>], block: &[Stmt<Option<Type>>],
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>, defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, HashSet<String>> { ) -> Result<bool, HashSet<String>> {
let mut ret = false; let mut ret = false;
for stmt in block { for stmt in block {

View File

@ -1,21 +1,17 @@
use std::{cmp::max, collections::HashMap, rc::Rc}; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use itertools::{iproduct, Itertools}; use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use strum::IntoEnumIterator; use crate::typecheck::{
use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop};
use super::{
type_inferencer::*, type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
}; };
use crate::{ use itertools::{iproduct, Itertools};
symbol_resolver::SymbolValue, use nac3parser::ast::StrRef;
toplevel::{ use nac3parser::ast::{Cmpop, Operator, Unaryop};
helper::PrimDef, use std::cmp::max;
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, use std::collections::HashMap;
}, use std::rc::Rc;
}; use strum::IntoEnumIterator;
/// The variant of a binary operator. /// The variant of a binary operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -201,7 +197,6 @@ pub fn impl_binop(
ty: other_ty, ty: other_ty,
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
is_vararg: false,
}], }],
})), })),
false, false,
@ -266,7 +261,6 @@ pub fn impl_cmpop(
ty: other_ty, ty: other_ty,
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
is_vararg: false,
}], }],
})), })),
false, false,
@ -524,23 +518,6 @@ pub fn typeof_binop(
} }
Operator::MatMult => { Operator::MatMult => {
// NOTE: NumPy matmul's LHS and RHS must both be ndarrays. Scalars are not allowed.
match (&*unifier.get_ty(lhs), &*unifier.get_ty(rhs)) {
(
TypeEnum::TObj { obj_id: lhs_obj_id, .. },
TypeEnum::TObj { obj_id: rhs_obj_id, .. },
) if *lhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap()
&& *rhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap() =>
{
// LHS and RHS have valid types
}
_ => {
let lhs_str = unifier.stringify(lhs);
let rhs_str = unifier.stringify(rhs);
return Err(format!("ndarray.__matmul__ only accepts ndarray operands, but left operand has type {lhs_str}, and right operand has type {rhs_str}"));
}
}
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {
@ -701,7 +678,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
bool: bool_t, bool: bool_t,
uint32: uint32_t, uint32: uint32_t,
uint64: uint64_t, uint64: uint64_t,
str: str_t,
list: list_t, list: list_t,
ndarray: ndarray_t, ndarray: ndarray_t,
.. ..
@ -747,9 +723,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_sign(unifier, store, bool_t, Some(int32_t)); impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* str ========= */
impl_cmpop(unifier, store, str_t, &[str_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
/* list ======== */ /* list ======== */
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]); impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]); impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);

View File

@ -1,13 +1,14 @@
use std::{collections::HashMap, fmt::Display}; use std::collections::HashMap;
use std::fmt::Display;
use itertools::Itertools; use crate::typecheck::{magic_methods::HasOpInfo, typedef::TypeEnum};
use nac3parser::ast::{Cmpop, Location, StrRef};
use super::{ use super::{
magic_methods::{Binop, HasOpInfo}, magic_methods::Binop,
typedef::{RecordKey, Type, TypeEnum, Unifier}, typedef::{RecordKey, Type, Unifier},
}; };
use itertools::Itertools;
use nac3parser::ast::{Cmpop, Location, StrRef};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum TypeErrorKind { pub enum TypeErrorKind {
@ -182,10 +183,9 @@ impl<'a> Display for DisplayTypeError<'a> {
} }
result result
} }
( (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 })
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: is_vararg1 }, if ty1.len() != ty2.len() =>
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: is_vararg2 }, {
) if !is_vararg1 && !is_vararg2 && ty1.len() != ty2.len() => {
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Tuple length mismatch: got {t1} and {t2}") write!(f, "Tuple length mismatch: got {t1} and {t2}")

File diff suppressed because it is too large Load Diff

View File

@ -1,19 +1,17 @@
use std::iter::zip; use super::super::{magic_methods::with_fields, typedef::*};
use indexmap::IndexMap;
use indoc::indoc;
use parking_lot::RwLock;
use test_case::test_case;
use nac3parser::{ast::FileName, parser::parse_program};
use super::*; use super::*;
use crate::{ use crate::{
codegen::{CodeGenContext, CodeGenerator}, codegen::CodeGenContext,
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef}, toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
typecheck::{magic_methods::with_fields, typedef::*},
}; };
use indexmap::IndexMap;
use indoc::indoc;
use nac3parser::ast::FileName;
use nac3parser::parser::parse_program;
use parking_lot::RwLock;
use std::iter::zip;
use test_case::test_case;
struct Resolver { struct Resolver {
id_to_type: HashMap<StrRef, Type>, id_to_type: HashMap<StrRef, Type>,
@ -43,7 +41,6 @@ impl SymbolResolver for Resolver {
&self, &self,
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, '_>, _: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() unimplemented!()
} }
@ -86,12 +83,7 @@ impl TestEnvironment {
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
ret: int32, ret: int32,
vars: VarMap::new(), vars: VarMap::new(),
})); }));
@ -232,12 +224,7 @@ impl TestEnvironment {
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
ret: int32, ret: int32,
vars: VarMap::new(), vars: VarMap::new(),
})); }));
@ -520,7 +507,7 @@ impl TestEnvironment {
primitives: &mut self.primitives, primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks, virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls, calls: &mut self.calls,
defined_identifiers: HashMap::default(), defined_identifiers: HashSet::default(),
in_handler: false, in_handler: false,
} }
} }
@ -596,9 +583,8 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s
println!("source:\n{source}"); println!("source:\n{source}");
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashMap<_, _> = let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect();
env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect(); defined_identifiers.insert("virtual".into());
defined_identifiers.insert("virtual".into(), IdentifierInfo::default());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers.clone_from(&defined_identifiers); inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, FileName::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
@ -743,9 +729,8 @@ fn test_primitive_magic_methods(source: &str, mapping: &HashMap<&str, &str>) {
println!("source:\n{source}"); println!("source:\n{source}");
let mut env = TestEnvironment::basic_test_env(); let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashMap<_, _> = let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect();
env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect(); defined_identifiers.insert("virtual".into());
defined_identifiers.insert("virtual".into(), IdentifierInfo::default());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers.clone_from(&defined_identifiers); inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, FileName::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();

View File

@ -1,28 +1,23 @@
use std::{
borrow::Cow,
cell::RefCell,
collections::{HashMap, HashSet},
fmt::{self, Display},
iter::{repeat, zip},
rc::Rc,
sync::{Arc, Mutex},
};
use indexmap::IndexMap; use indexmap::IndexMap;
use itertools::{repeat_n, Itertools}; use itertools::Itertools;
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::{self, Display};
use std::iter::zip;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet};
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop}; use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
use super::{ use super::magic_methods::Binop;
magic_methods::{Binop, HasOpInfo, OpInfo}, use super::type_error::{TypeError, TypeErrorKind};
type_error::{TypeError, TypeErrorKind}, use super::unification_table::{UnificationKey, UnificationTable};
type_inferencer::PrimitiveStore, use crate::symbol_resolver::SymbolValue;
unification_table::{UnificationKey, UnificationTable}, use crate::toplevel::helper::PrimDef;
}; use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
use crate::{ use crate::typecheck::magic_methods::OpInfo;
symbol_resolver::SymbolValue, use crate::typecheck::type_inferencer::PrimitiveStore;
toplevel::{helper::PrimDef, DefinitionId, TopLevelContext, TopLevelDef},
};
#[cfg(test)] #[cfg(test)]
mod test; mod test;
@ -120,7 +115,6 @@ pub struct FuncArg {
pub name: StrRef, pub name: StrRef,
pub ty: Type, pub ty: Type,
pub default_value: Option<SymbolValue>, pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
} }
impl FuncArg { impl FuncArg {
@ -239,12 +233,6 @@ pub enum TypeEnum {
TTuple { TTuple {
/// The types of elements present in this tuple. /// The types of elements present in this tuple.
ty: Vec<Type>, ty: Vec<Type>,
/// Whether this tuple is used in a vararg context.
///
/// If `true`, `ty` must only contain one type, and the tuple is assumed to contain any
/// number of `ty`-typed values.
is_vararg_ctx: bool,
}, },
/// An object type. /// An object type.
@ -539,7 +527,7 @@ impl Unifier {
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| { TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec() ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
}), }),
TypeEnum::TTuple { ty, is_vararg_ctx } => { TypeEnum::TTuple { ty } => {
let tuples = ty let tuples = ty
.iter() .iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
@ -549,12 +537,7 @@ impl Unifier {
None None
} else { } else {
Some( Some(
tuples tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(),
.into_iter()
.map(|ty| {
self.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: *is_vararg_ctx })
})
.collect(),
) )
} }
} }
@ -598,7 +581,7 @@ impl Unifier {
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty, .. } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => { TObj { params: vars, .. } => {
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
} }
@ -666,7 +649,6 @@ impl Unifier {
// Get details about the function signature/parameters. // Get details about the function signature/parameters.
let num_params = signature.args.len(); let num_params = signature.args.len();
let is_vararg = signature.args.iter().any(|arg| arg.is_vararg);
// Force the type vars in `b` and `signature' to be up-to-date. // Force the type vars in `b` and `signature' to be up-to-date.
let b = self.instantiate_fun(b, signature); let b = self.instantiate_fun(b, signature);
@ -677,8 +659,8 @@ impl Unifier {
let num_args = posargs.len() + kwargs.len(); let num_args = posargs.len() + kwargs.len();
// Now we check the arguments against the parameters, // Now we check the arguments against the parameters,
// and depending on what `call_info` is, we might change how `unify_call()` behaves // and depending on what `call_info` is, we might change how the behavior `unify_call()`
// to improve user error messages when type checking fails. // in hopes to improve user error messages when type checking fails.
match operator_info { match operator_info {
Some(OperatorInfo::IsBinaryOp { self_type, operator }) => { Some(OperatorInfo::IsBinaryOp { self_type, operator }) => {
// The call is written in the form of (say) `a + b`. // The call is written in the form of (say) `a + b`.
@ -755,7 +737,7 @@ impl Unifier {
}; };
// Check for "too many arguments" // Check for "too many arguments"
if !is_vararg && num_params < posargs.len() { if num_params < posargs.len() {
let expected_min_count = let expected_min_count =
signature.args.iter().filter(|param| param.is_required()).count(); signature.args.iter().filter(|param| param.is_required()).count();
let expected_max_count = num_params; let expected_max_count = num_params;
@ -788,19 +770,6 @@ impl Unifier {
type_check_arg(param.name, param.ty, arg_ty)?; type_check_arg(param.name, param.ty, arg_ty)?;
} }
if is_vararg {
debug_assert!(!signature.args.is_empty());
let vararg_args = posargs.iter().skip(signature.args.len());
let vararg_param = signature.args.last().unwrap();
for (&arg_ty, param) in zip(vararg_args, repeat(vararg_param)) {
// `param_info` for this argument would've already been marked as supplied
// during non-vararg posarg typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
}
// Now consume all keyword arguments and typecheck them. // Now consume all keyword arguments and typecheck them.
for (&param_name, &arg_ty) in kwargs { for (&param_name, &arg_ty) in kwargs {
// We will also use this opportunity to check if this keyword argument is "legal". // We will also use this opportunity to check if this keyword argument is "legal".
@ -990,10 +959,7 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
( (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
TVar { fields: Some(fields), range, is_const_generic: false, .. },
TTuple { ty, .. },
) => {
let len = i32::try_from(ty.len()).unwrap(); let len = i32::try_from(ty.len()).unwrap();
for (k, v) in fields { for (k, v) in fields {
match *k { match *k {
@ -1014,18 +980,8 @@ impl Unifier {
self.unify_impl(v.ty, ty[ind as usize], false) self.unify_impl(v.ty, ty[ind as usize], false)
.map_err(|e| e.at(v.loc))?; .map_err(|e| e.at(v.loc))?;
} }
RecordKey::Str(s) => { RecordKey::Str(_) => {
let tuple_fns = [ return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc))
Cmpop::Eq.op_info().method_name,
Cmpop::NotEq.op_info().method_name,
];
if !tuple_fns.into_iter().any(|op| s.to_string() == op) {
return Err(TypeError::new(
TypeErrorKind::NoSuchField(*k, b),
v.loc,
));
}
} }
} }
} }
@ -1100,47 +1056,15 @@ impl Unifier {
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
( (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
TTuple { ty: ty1, is_vararg_ctx: is_vararg1 }, if ty1.len() != ty2.len() {
TTuple { ty: ty2, is_vararg_ctx: is_vararg2 }, return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
) => { }
// Rules for Tuples: for (x, y) in ty1.iter().zip(ty2.iter()) {
// - ty1: is_vararg && ty2: is_vararg -> ty1[0] == ty2[0] if self.unify_impl(*x, *y, false).is_err() {
// - ty1: is_vararg && ty2: !is_vararg -> type error (not enough info to infer the correct number of arguments) return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
// - ty1: !is_vararg && ty2: is_vararg -> ty1[..] == ty2[0]
// - ty1: !is_vararg && ty2: !is_vararg -> ty1.len() == ty2.len() && ty1[i] == ty2[i]
debug_assert!(!is_vararg1 || ty1.len() == 1);
debug_assert!(!is_vararg2 || ty2.len() == 1);
match (*is_vararg1, *is_vararg2) {
(true, true) => {
if self.unify_impl(ty1[0], ty2[0], false).is_err() {
return Self::incompatible_types(a, b);
}
}
(true, false) => return Self::incompatible_types(a, b),
(false, true) => {
for y in ty2 {
if self.unify_impl(ty1[0], *y, false).is_err() {
return Self::incompatible_types(a, b);
}
}
}
(false, false) => {
if ty1.len() != ty2.len() {
return Self::incompatible_types(a, b);
}
for (x, y) in ty1.iter().zip(ty2.iter()) {
if self.unify_impl(*x, *y, false).is_err() {
return Self::incompatible_types(a, b);
}
}
} }
} }
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
(TVar { fields: Some(map), range, .. }, TObj { obj_id, fields, params }) => { (TVar { fields: Some(map), range, .. }, TObj { obj_id, fields, params }) => {
@ -1383,22 +1307,10 @@ impl Unifier {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {
format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", ")) format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", "))
} }
TypeEnum::TTuple { ty, is_vararg_ctx } => { TypeEnum::TTuple { ty } => {
if *is_vararg_ctx { let mut fields =
debug_assert_eq!(ty.len(), 1); ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
let field = self.internal_stringify( format!("tuple[{}]", fields.join(", "))
*ty.iter().next().unwrap(),
obj_to_name,
var_to_name,
notes,
);
format!("tuple[*{field}]")
} else {
let mut fields = ty
.iter()
.map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
format!("tuple[{}]", fields.join(", "))
}
} }
TypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty } => {
format!( format!(
@ -1423,21 +1335,17 @@ impl Unifier {
.args .args
.iter() .iter()
.map(|arg| { .map(|arg| {
let vararg_prefix = if arg.is_vararg { "*" } else { "" };
if let Some(dv) = &arg.default_value { if let Some(dv) = &arg.default_value {
format!( format!(
"{}:{}{}={}", "{}:{}={}",
arg.name, arg.name,
vararg_prefix,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes), self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes),
dv dv
) )
} else { } else {
format!( format!(
"{}:{}{}", "{}:{}",
arg.name, arg.name,
vararg_prefix,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes) self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes)
) )
} }
@ -1523,7 +1431,7 @@ impl Unifier {
match &*ty { match &*ty {
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None, TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None,
TypeEnum::TVar { id, .. } => mapping.get(id).copied(), TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
TypeEnum::TTuple { ty, is_vararg_ctx } => { TypeEnum::TTuple { ty } => {
let mut new_ty = Cow::from(ty); let mut new_ty = Cow::from(ty);
for (i, t) in ty.iter().enumerate() { for (i, t) in ty.iter().enumerate() {
if let Some(t1) = self.subst_impl(*t, mapping, cache) { if let Some(t1) = self.subst_impl(*t, mapping, cache) {
@ -1531,10 +1439,7 @@ impl Unifier {
} }
} }
if matches!(new_ty, Cow::Owned(_)) { if matches!(new_ty, Cow::Owned(_)) {
Some(self.add_ty(TypeEnum::TTuple { Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() }))
ty: new_ty.into_owned(),
is_vararg_ctx: *is_vararg_ctx,
}))
} else { } else {
None None
} }
@ -1694,37 +1599,16 @@ impl Unifier {
} }
} }
(TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())), (TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())),
( (TTuple { ty: ty1 }, TTuple { ty: ty2 }) if ty1.len() == ty2.len() => {
TTuple { ty: ty1, is_vararg_ctx: is_vararg1 }, let ty: Vec<_> = zip(ty1.iter(), ty2.iter())
TTuple { ty: ty2, is_vararg_ctx: is_vararg2 }, .map(|(a, b)| self.get_intersection(*a, *b))
) => { .try_collect()?;
if *is_vararg1 && *is_vararg2 { if ty.iter().any(Option::is_some) {
let isect_ty = self.get_intersection(ty1[0], ty2[0])?; Ok(Some(self.add_ty(TTuple {
Ok(isect_ty.map(|ty| self.add_ty(TTuple { ty: vec![ty], is_vararg_ctx: true }))) ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
})))
} else { } else {
let zip_iter: Box<dyn Iterator<Item = (&Type, &Type)>> = Ok(None)
match (*is_vararg1, *is_vararg2) {
(true, _) => Box::new(repeat_n(&ty1[0], ty2.len()).zip(ty2.iter())),
(_, false) => Box::new(ty1.iter().zip(repeat_n(&ty2[0], ty1.len()))),
_ => {
if ty1.len() != ty2.len() {
return Err(());
}
Box::new(ty1.iter().zip(ty2.iter()))
}
};
let ty: Vec<_> =
zip_iter.map(|(a, b)| self.get_intersection(*a, *b)).try_collect()?;
Ok(if ty.iter().any(Option::is_some) {
Some(self.add_ty(TTuple {
ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
is_vararg_ctx: false,
}))
} else {
None
})
} }
} }
// TODO(Derppening): #444 // TODO(Derppening): #444

View File

@ -1,12 +1,10 @@
use std::collections::HashMap; use super::super::magic_methods::with_fields;
use super::*;
use indoc::indoc; use indoc::indoc;
use itertools::Itertools; use itertools::Itertools;
use std::collections::HashMap;
use test_case::test_case; use test_case::test_case;
use super::*;
use crate::typecheck::magic_methods::with_fields;
impl Unifier { impl Unifier {
/// Check whether two types are equal. /// Check whether two types are equal.
fn eq(&mut self, a: Type, b: Type) -> bool { fn eq(&mut self, a: Type, b: Type) -> bool {
@ -30,10 +28,7 @@ impl Unifier {
TypeEnum::TVar { fields: Some(map1), .. }, TypeEnum::TVar { fields: Some(map1), .. },
TypeEnum::TVar { fields: Some(map2), .. }, TypeEnum::TVar { fields: Some(map2), .. },
) => self.map_eq2(map1, map2), ) => self.map_eq2(map1, map2),
( (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: false },
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: false },
) => {
ty1.len() == ty2.len() ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
} }
@ -183,7 +178,7 @@ impl TestEnvironment {
ty.push(result.0); ty.push(result.0);
s = result.1; s = result.1;
} }
(self.unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }), &s[1..]) (self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
} }
"Record" => { "Record" => {
let mut s = &typ[end..]; let mut s = &typ[end..];
@ -613,7 +608,7 @@ fn test_instantiation() {
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty; let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty; let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty;
let t = env.unifier.get_dummy_var().ty; let t = env.unifier.get_dummy_var().ty;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2], is_vararg_ctx: false }); let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty; let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty;
// t = TypeVar('t') // t = TypeVar('t')
// v = TypeVar('v', int, bool) // v = TypeVar('v', int, bool)

5
nac3core/src/util.rs Normal file
View File

@ -0,0 +1,5 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SizeVariant {
Bits32,
Bits64,
}

View File

@ -238,7 +238,7 @@ impl<'a> EH_Frame<'a> {
/// From the [specification](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html): /// From the [specification](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html):
/// ///
/// > Each CFI record contains a Common Information Entry (CIE) record followed by 1 or more Frame /// > Each CFI record contains a Common Information Entry (CIE) record followed by 1 or more Frame
/// > Description Entry (FDE) records. /// Description Entry (FDE) records.
pub struct CFI_Record<'a> { pub struct CFI_Record<'a> {
// It refers to the augmentation data that corresponds to 'R' in the augmentation string // It refers to the augmentation data that corresponds to 'R' in the augmentation string
fde_pointer_encoding: u8, fde_pointer_encoding: u8,

View File

@ -21,12 +21,13 @@
clippy::wildcard_imports clippy::wildcard_imports
)] )]
use std::{collections::HashMap, mem, ptr, slice, str};
use byteorder::{ByteOrder, LittleEndian};
use dwarf::*; use dwarf::*;
use elf::*; use elf::*;
use std::collections::HashMap;
use std::{mem, ptr, slice, str};
extern crate byteorder;
use byteorder::{ByteOrder, LittleEndian};
mod dwarf; mod dwarf;
mod elf; mod elf;

View File

@ -8,15 +8,15 @@ license = "MIT"
edition = "2021" edition = "2021"
[build-dependencies] [build-dependencies]
lalrpop = "0.22" lalrpop = "0.20"
[dependencies] [dependencies]
nac3ast = { path = "../nac3ast" } nac3ast = { path = "../nac3ast" }
lalrpop-util = "0.22" lalrpop-util = "0.20"
log = "0.4" log = "0.4"
unic-emoji-char = "0.9" unic-emoji-char = "0.9"
unic-ucd-ident = "0.9" unic-ucd-ident = "0.9"
unicode_names2 = "1.3" unicode_names2 = "1.2"
phf = { version = "0.11", features = ["macros"] } phf = { version = "0.11", features = ["macros"] }
ahash = "0.8" ahash = "0.8"

View File

@ -1,10 +1,8 @@
use crate::{ use crate::ast::Ident;
ast::{Ident, Location}, use crate::ast::Location;
error::*, use crate::error::*;
token::Tok, use crate::token::Tok;
};
use lalrpop_util::ParseError; use lalrpop_util::ParseError;
use nac3ast::*; use nac3ast::*;
pub fn make_config_comment( pub fn make_config_comment(

View File

@ -1,11 +1,12 @@
//! Define internal parse error types //! Define internal parse error types
//! The goal is to provide a matching and a safe error API, maksing errors from LALR //! The goal is to provide a matching and a safe error API, maksing errors from LALR
use std::error::Error;
use std::fmt;
use lalrpop_util::ParseError as LalrpopError; use lalrpop_util::ParseError as LalrpopError;
use crate::{ast::Location, token::Tok}; use crate::ast::Location;
use crate::token::Tok;
use std::error::Error;
use std::fmt;
/// Represents an error during lexical scanning. /// Represents an error during lexical scanning.
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]

View File

@ -1,11 +1,12 @@
use std::{iter, mem, str}; use std::iter;
use std::mem;
use std::str;
use crate::ast::{Constant, ConversionFlag, Expr, ExprKind, Location};
use crate::error::{FStringError, FStringErrorType, ParseError};
use crate::parser::parse_expression;
use self::FStringErrorType::*; use self::FStringErrorType::*;
use crate::{
ast::{Constant, ConversionFlag, Expr, ExprKind, Location},
error::{FStringError, FStringErrorType, ParseError},
parser::parse_expression,
};
struct FStringParser<'a> { struct FStringParser<'a> {
chars: iter::Peekable<str::Chars<'a>>, chars: iter::Peekable<str::Chars<'a>>,

View File

@ -1,11 +1,8 @@
use ahash::RandomState;
use std::collections::HashSet; use std::collections::HashSet;
use ahash::RandomState; use crate::ast;
use crate::error::{LexicalError, LexicalErrorType};
use crate::{
ast,
error::{LexicalError, LexicalErrorType},
};
pub struct ArgumentList { pub struct ArgumentList {
pub args: Vec<ast::Expr>, pub args: Vec<ast::Expr>,

View File

@ -1,16 +1,16 @@
//! This module takes care of lexing python source text. //! This module takes care of lexing python source text.
//! //!
//! This means source code is translated into separate tokens. //! This means source code is translated into separate tokens.
use std::{char, cmp::Ordering, num::IntErrorKind, str::FromStr};
use unic_emoji_char::is_emoji_presentation;
use unic_ucd_ident::{is_xid_continue, is_xid_start};
pub use super::token::Tok; pub use super::token::Tok;
use crate::{ use crate::ast::{FileName, Location};
ast::{FileName, Location}, use crate::error::{LexicalError, LexicalErrorType};
error::{LexicalError, LexicalErrorType}, use std::char;
}; use std::cmp::Ordering;
use std::num::IntErrorKind;
use std::str::FromStr;
use unic_emoji_char::is_emoji_presentation;
use unic_ucd_ident::{is_xid_continue, is_xid_start};
#[derive(Clone, Copy, PartialEq, Debug, Default)] #[derive(Clone, Copy, PartialEq, Debug, Default)]
struct IndentationLevel { struct IndentationLevel {

View File

@ -5,16 +5,14 @@
//! parse a whole program, a single statement, or a single //! parse a whole program, a single statement, or a single
//! expression. //! expression.
use nac3ast::Location;
use std::iter; use std::iter;
use nac3ast::Location; use crate::ast::{self, FileName};
use crate::error::ParseError;
use crate::lexer;
pub use crate::mode::Mode; pub use crate::mode::Mode;
use crate::{ use crate::python;
ast::{self, FileName},
error::ParseError,
lexer, python,
};
/* /*
* Parse python code. * Parse python code.

View File

@ -1,8 +1,7 @@
//! Different token definitions. //! Different token definitions.
//! Loosely based on token.h from CPython source: //! Loosely based on token.h from CPython source:
use std::fmt::{self, Write};
use crate::ast; use crate::ast;
use std::fmt::{self, Write};
/// Python source code can be tokenized in a sequence of these tokens. /// Python source code can be tokenized in a sequence of these tokens.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]

View File

@ -4,13 +4,16 @@ version = "0.1.0"
authors = ["M-Labs"] authors = ["M-Labs"]
edition = "2021" edition = "2021"
[features]
no-escape-analysis = ["nac3core/no-escape-analysis"]
[dependencies] [dependencies]
parking_lot = "0.12" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
[dependencies.clap] [dependencies.clap]
version = "4.5" version = "4.5"
features = ["derive"] features = ["derive"]
[dependencies.inkwell]
version = "0.4"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]

View File

@ -3,66 +3,23 @@
set -e set -e
if [ -z "$1" ]; then if [ -z "$1" ]; then
echo "No argument supplied" echo "Requires at least one argument"
exit 1 exit 1
fi fi
declare -a nac3args declare -a nac3args
while [ $# -gt 1 ]; do
case "$1" in
--help)
echo "Usage: check_demo.sh [--debug] [-i686] -- [NAC3ARGS...] demo"
exit
;;
--debug)
debug=1
;;
-i686)
i686=1
;;
--)
shift
break
;;
*)
echo "Unrecognized argument \"$1\""
exit 1
;;
esac
shift
done
while [ $# -gt 1 ]; do while [ $# -gt 1 ]; do
nac3args+=("$1") nac3args+=("$1")
shift shift
done done
demo="$1" demo="$1"
echo -n "Checking $demo... "
echo "### Checking $demo..."
echo ">>>>>> Running $demo with the Python interpreter"
./interpret_demo.py "$demo" > interpreted.log ./interpret_demo.py "$demo" > interpreted.log
./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run.log
diff -Nau interpreted.log run_lli.log
echo "ok"
if [ -n "$i686" ]; then rm -f interpreted.log run.log run_lli.log
echo "...... Trying NAC3's 32-bit code generator output"
if [ -n "$debug" ]; then
./run_demo.sh --debug -i686 --out run_32.log -- "${nac3args[@]}" "$demo"
else
./run_demo.sh -i686 --out run_32.log -- "${nac3args[@]}" "$demo"
fi
diff -Nau interpreted.log run_32.log
fi
echo "...... Trying NAC3's 64-bit code generator output"
if [ -n "$debug" ]; then
./run_demo.sh --debug --out run_64.log -- "${nac3args[@]}" "$demo"
else
./run_demo.sh --out run_64.log -- "${nac3args[@]}" "$demo"
fi
diff -Nau interpreted.log run_64.log
echo "...... OK"
rm -f interpreted.log \
run_32.log run_64.log

View File

@ -2,11 +2,6 @@
set -e set -e
if [ "$1" == "--help" ]; then
echo "Usage: check_demos.sh [CHECKARGS...] [--] [NAC3ARGS...]"
exit
fi
count=0 count=0
for demo in src/*.py; do for demo in src/*.py; do
./check_demo.sh "$@" "$demo" ./check_demo.sh "$@" "$demo"

View File

@ -6,12 +6,14 @@
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#define usize size_t
double dbl_nan(void) { double dbl_nan(void) {
return NAN; return NAN;
} }
double dbl_inf(void) { double dbl_inf(void) {
return INFINITY; return INFINITY;
} }
void output_bool(bool x) { void output_bool(bool x) {
@ -19,19 +21,19 @@ void output_bool(bool x) {
} }
void output_int32(int32_t x) { void output_int32(int32_t x) {
printf("%" PRId32 "\n", x); printf("%"PRId32"\n", x);
} }
void output_int64(int64_t x) { void output_int64(int64_t x) {
printf("%" PRId64 "\n", x); printf("%"PRId64"\n", x);
} }
void output_uint32(uint32_t x) { void output_uint32(uint32_t x) {
printf("%" PRIu32 "\n", x); printf("%"PRIu32"\n", x);
} }
void output_uint64(uint64_t x) { void output_uint64(uint64_t x) {
printf("%" PRIu64 "\n", x); printf("%"PRIu64"\n", x);
} }
void output_float64(double x) { void output_float64(double x) {
@ -52,7 +54,7 @@ void output_range(int32_t range[3]) {
} }
void output_asciiart(int32_t x) { void output_asciiart(int32_t x) {
static const char* chars = " .,-:;i+hHM$*#@ "; static const char *chars = " .,-:;i+hHM$*#@ ";
if (x < 0) { if (x < 0) {
putchar('\n'); putchar('\n');
} else { } else {
@ -61,15 +63,15 @@ void output_asciiart(int32_t x) {
} }
struct cslice { struct cslice {
void* data; void *data;
size_t len; usize len;
}; };
void output_int32_list(struct cslice* slice) { void output_int32_list(struct cslice *slice) {
const int32_t* data = (int32_t*)slice->data; const int32_t *data = (int32_t *) slice->data;
putchar('['); putchar('[');
for (size_t i = 0; i < slice->len; ++i) { for (usize i = 0; i < slice->len; ++i) {
if (i == slice->len - 1) { if (i == slice->len - 1) {
printf("%d", data[i]); printf("%d", data[i]);
} else { } else {
@ -80,23 +82,23 @@ void output_int32_list(struct cslice* slice) {
putchar('\n'); putchar('\n');
} }
void output_str(struct cslice* slice) { void output_str(struct cslice *slice) {
const char* data = (const char*)slice->data; const char *data = (const char *) slice->data;
for (size_t i = 0; i < slice->len; ++i) { for (usize i = 0; i < slice->len; ++i) {
putchar(data[i]); putchar(data[i]);
} }
} }
void output_strln(struct cslice* slice) { void output_strln(struct cslice *slice) {
output_str(slice); output_str(slice);
putchar('\n'); putchar('\n');
} }
uint64_t dbg_stack_address(__attribute__((unused)) struct cslice* slice) { uint64_t dbg_stack_address(__attribute__((unused)) struct cslice *slice) {
int i; int i;
void* ptr = (void*)&i; void *ptr = (void *) &i;
return (uintptr_t)ptr; return (uintptr_t) ptr;
} }
uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) { uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) {
@ -105,26 +107,8 @@ uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t
__builtin_unreachable(); __builtin_unreachable();
} }
// See `struct Exception<'a>` in uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) {
// https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs printf("__nac3_raise(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
struct Exception {
uint32_t id;
struct cslice file;
uint32_t line;
uint32_t column;
struct cslice function;
struct cslice message;
int64_t param[3];
};
uint32_t __nac3_raise(struct Exception* e) {
printf("__nac3_raise called. Exception details:\n");
printf(" ID: %" PRIu32 "\n", e->id);
printf(" Location: %*s:%" PRIu32 ":%" PRIu32 "\n", (int)e->file.len, (const char*)e->file.data, e->line,
e->column);
printf(" Function: %*s\n", (int)e->function.len, (const char*)e->function.data);
printf(" Message: \"%*s\"\n", (int)e->message.len, (const char*)e->message.data);
printf(" Params: {0}=%" PRId64 ", {1}=%" PRId64 ", {2}=%" PRId64 "\n", e->param[0], e->param[1], e->param[2]);
exit(101); exit(101);
__builtin_unreachable(); __builtin_unreachable();
} }

View File

@ -6,7 +6,6 @@ import importlib.machinery
import math import math
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import scipy as sp
import pathlib import pathlib
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
@ -168,7 +167,7 @@ def patch(module):
module.ceil64 = _ceil module.ceil64 = _ceil
module.np_ceil = np.ceil module.np_ceil = np.ceil
# NumPy NDArray factory functions # NumPy ndarray functions
module.ndarray = NDArray module.ndarray = NDArray
module.np_ndarray = np.ndarray module.np_ndarray = np.ndarray
module.np_empty = np.empty module.np_empty = np.empty
@ -184,10 +183,8 @@ def patch(module):
module.np_isinf = np.isinf module.np_isinf = np.isinf
module.np_min = np.min module.np_min = np.min
module.np_minimum = np.minimum module.np_minimum = np.minimum
module.np_argmin = np.argmin
module.np_max = np.max module.np_max = np.max
module.np_maximum = np.maximum module.np_maximum = np.maximum
module.np_argmax = np.argmax
module.np_sin = np.sin module.np_sin = np.sin
module.np_cos = np.cos module.np_cos = np.cos
module.np_exp = np.exp module.np_exp = np.exp
@ -218,10 +215,8 @@ def patch(module):
module.np_ldexp = np.ldexp module.np_ldexp = np.ldexp
module.np_hypot = np.hypot module.np_hypot = np.hypot
module.np_nextafter = np.nextafter module.np_nextafter = np.nextafter
module.np_transpose = np.transpose
module.np_reshape = np.reshape
# SciPy Math functions # SciPy Math Functions
module.sp_spec_erf = special.erf module.sp_spec_erf = special.erf
module.sp_spec_erfc = special.erfc module.sp_spec_erfc = special.erfc
module.sp_spec_gamma = special.gamma module.sp_spec_gamma = special.gamma
@ -229,19 +224,14 @@ def patch(module):
module.sp_spec_j0 = special.j0 module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1 module.sp_spec_j1 = special.j1
# Linalg functions # NumPy NDArray Functions
module.np_dot = np.dot module.np_ndarray = np.ndarray
module.np_linalg_cholesky = np.linalg.cholesky module.np_empty = np.empty
module.np_linalg_qr = np.linalg.qr module.np_zeros = np.zeros
module.np_linalg_svd = np.linalg.svd module.np_ones = np.ones
module.np_linalg_inv = np.linalg.inv module.np_full = np.full
module.np_linalg_pinv = np.linalg.pinv module.np_eye = np.eye
module.np_linalg_matrix_power = np.linalg.matrix_power module.np_identity = np.identity
module.np_linalg_det = np.linalg.det
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
module.sp_linalg_schur = sp.linalg.schur
module.sp_linalg_hessenberg = lambda x: sp.linalg.hessenberg(x, True)
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

@ -1,114 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]]
name = "autocfg"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "linalg"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"

View File

@ -1,13 +0,0 @@
[package]
name = "linalg"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["staticlib"]
[dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0"
[workspace]

View File

@ -1,406 +0,0 @@
// Uses `nalgebra` crate to invoke `np_linalg` and `sp_linalg` functions
// When converting between `nalgebra::Matrix` and `NDArray` following considerations are necessary
//
// * Both `nalgebra::Matrix` and `NDArray` require their content to be stored in row-major order
// * `NDArray` data pointer can be directly read and converted to `nalgebra::Matrix` (row and column number must be known)
// * `nalgebra::Matrix::as_slice` returns the content of matrix in column-major order and initial data needs to be transposed before storing it in `NDArray` data pointer
use core::slice;
use nalgebra::DMatrix;
fn report_error(
error_name: &str,
fn_name: &str,
file_name: &str,
line_num: u32,
col_num: u32,
err_msg: &str,
) -> ! {
panic!(
"Exception {} from {} in {}:{}:{}, message: {}",
error_name, fn_name, file_name, line_num, col_num, err_msg
);
}
pub struct InputMatrix {
pub ndims: usize,
pub dims: *const usize,
pub data: *mut f64,
}
impl InputMatrix {
fn get_dims(&mut self) -> Vec<usize> {
let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) };
dims.to_vec()
}
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix1.cholesky();
match result {
Some(res) => {
out_slice.copy_from_slice(res.unpack().transpose().as_slice());
}
None => {
report_error(
"LinAlgError",
"np_linalg_cholesky",
file!(),
line!(),
column!(),
"Matrix is not positive definite",
);
}
};
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_qr(
mat1: *mut InputMatrix,
out_q: *mut InputMatrix,
out_r: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_q = out_q.as_mut().unwrap();
let out_r = out_r.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outq_dim = (*out_q).get_dims();
let outr_dim = (*out_r).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, outq_dim[0] * outq_dim[1]) };
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r.data, outr_dim[0] * outr_dim[1]) };
// Refer to https://github.com/dimforge/nalgebra/issues/735
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let res = matrix1.qr();
let (q, r) = res.unpack();
// Uses different algo need to match numpy
out_q_slice.copy_from_slice(q.transpose().as_slice());
out_r_slice.copy_from_slice(r.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_svd(
mat1: *mut InputMatrix,
outu: *mut InputMatrix,
outs: *mut InputMatrix,
outvh: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let outu = outu.as_mut().unwrap();
let outs = outs.as_mut().unwrap();
let outvh = outvh.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outu_dim = (*outu).get_dims();
let outs_dim = (*outs).get_dims();
let outvh_dim = (*outvh).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) };
let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) };
let out_vh_slice =
unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix.svd(true, true);
out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice());
out_s_slice.copy_from_slice(result.singular_values.as_slice());
out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_invertible() {
report_error(
"LinAlgError",
"np_linalg_inv",
file!(),
line!(),
column!(),
"no inverse for Singular Matrix",
);
}
let inv = matrix.try_inverse().unwrap();
out_slice.copy_from_slice(inv.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let svd = matrix.svd(true, true);
let inv = svd.pseudo_inverse(1e-15);
match inv {
Ok(m) => {
out_slice.copy_from_slice(m.transpose().as_slice());
}
Err(err_msg) => {
report_error("LinAlgError", "np_linalg_pinv", file!(), line!(), column!(), err_msg);
}
}
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matrix_power(
mat1: *mut InputMatrix,
mat2: *mut InputMatrix,
out: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D", mat1.ndims);
report_error("ValueError", "np_linalg_matrix_power", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
let power = power[0];
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let abs_pow = power.abs();
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let mut result = matrix1.pow(abs_pow as u32);
if power < 0.0 {
if !result.is_invertible() {
report_error(
"LinAlgError",
"np_linalg_inv",
file!(),
line!(),
column!(),
"no inverse for Singular Matrix",
);
}
result = result.try_inverse().unwrap();
}
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_det", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_square() {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
out_slice[0] = matrix.determinant();
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_lu(
mat1: *mut InputMatrix,
out_l: *mut InputMatrix,
out_u: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_l = out_l.as_mut().unwrap();
let out_u = out_u.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outl_dim = (*out_l).get_dims();
let outu_dim = (*out_u).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l.data, outl_dim[0] * outl_dim[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u.data, outu_dim[0] * outu_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (_, l, u) = matrix.lu().unpack();
out_l_slice.copy_from_slice(l.transpose().as_slice());
out_u_slice.copy_from_slice(u.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_schur(
mat1: *mut InputMatrix,
out_t: *mut InputMatrix,
out_z: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_t = out_t.as_mut().unwrap();
let out_z = out_z.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_schur", file!(), line!(), column!(), &err_msg);
}
let out_t_dim = (*out_t).get_dims();
let out_z_dim = (*out_z).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t.data, out_t_dim[0] * out_t_dim[1]) };
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z.data, out_z_dim[0] * out_z_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (z, t) = matrix.schur().unpack();
out_t_slice.copy_from_slice(t.transpose().as_slice());
out_z_slice.copy_from_slice(z.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_hessenberg(
mat1: *mut InputMatrix,
out_h: *mut InputMatrix,
out_q: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_h = out_h.as_mut().unwrap();
let out_q = out_q.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {} != {}", dim1[0], dim1[1]);
report_error("LinAlgError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
}
let out_h_dim = (*out_h).get_dims();
let out_q_dim = (*out_q).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h.data, out_h_dim[0] * out_h_dim[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, out_q_dim[0] * out_q_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (q, h) = matrix.hessenberg().unpack();
out_h_slice.copy_from_slice(h.transpose().as_slice());
out_q_slice.copy_from_slice(q.transpose().as_slice());
}

View File

@ -2,9 +2,6 @@
set -e set -e
: "${DEMO_LINALG_STUB:=linalg/target/release/liblinalg.a}"
: "${DEMO_LINALG_STUB32:=linalg/target/i686-unknown-linux-gnu/release/liblinalg.a}"
if [ -z "$1" ]; then if [ -z "$1" ]; then
echo "No argument supplied" echo "No argument supplied"
exit 1 exit 1
@ -14,26 +11,25 @@ declare -a nac3args
while [ $# -ge 1 ]; do while [ $# -ge 1 ]; do
case "$1" in case "$1" in
--help) --help)
echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--debug] [-i686] -- [NAC3ARGS...] demo" echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--lli] [--debug] -- [NAC3ARGS...]"
exit exit
;; ;;
--out) --out)
shift shift
outfile="$1" outfile="$1"
;; ;;
--lli)
use_lli=1
;;
--debug) --debug)
debug=1 debug=1
;; ;;
-i686)
i686=1
;;
--) --)
shift shift
break break
;; ;;
*) *)
echo "Unrecognized argument \"$1\"" break
exit 1
;; ;;
esac esac
shift shift
@ -54,19 +50,29 @@ else
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo
if [ -z "$use_lli" ]; then
if [ -z "$i686" ]; then
$nac3standalone "${nac3args[@]}" $nac3standalone "${nac3args[@]}"
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -o demo module.o demo.o $DEMO_LINALG_STUB -lm -Wl,--no-warn-search-mismatch
else
$nac3standalone --triple i686-unknown-linux-gnu --target-features +sse2 "${nac3args[@]}"
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
clang -m32 -o demo module.o demo.o $DEMO_LINALG_STUB32 -lm -Wl,--no-warn-search-mismatch
fi
if [ -z "$outfile" ]; then clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
./demo clang -lm -o demo module.o demo.o
if [ -z "$outfile" ]; then
./demo
else
./demo > "$outfile"
fi
else else
./demo > "$outfile" $nac3standalone --emit-llvm "${nac3args[@]}"
clang -c -std=gnu11 -Wall -Wextra -O3 -emit-llvm -o demo.bc demo.c
shopt -s nullglob
llvm-link -o nac3out.bc module*.bc main.bc
shopt -u nullglob
if [ -z "$outfile" ]; then
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc
else
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile"
fi
fi fi

View File

@ -1,76 +0,0 @@
@extern
def output_int32(x: int32):
...
@extern
def output_bool(x: bool):
...
def example1():
x, *ys, z = (1, 2, 3, 4, 5)
output_int32(x)
output_int32(len(ys))
output_int32(ys[0])
output_int32(ys[1])
output_int32(ys[2])
output_int32(z)
def example2():
x, y, *zs = (1, 2, 3, 4, 5)
output_int32(x)
output_int32(y)
output_int32(len(zs))
output_int32(zs[0])
output_int32(zs[1])
output_int32(zs[2])
def example3():
*xs, y, z = (1, 2, 3, 4, 5)
output_int32(len(xs))
output_int32(xs[0])
output_int32(xs[1])
output_int32(xs[2])
output_int32(y)
output_int32(z)
def example4():
*xs, y, z = (4, 5)
output_int32(len(xs))
output_int32(y)
output_int32(z)
def example5():
# Example from: https://docs.python.org/3/reference/simple_stmts.html#assignment-statements
x = [0, 1]
i = 0
i, x[i] = 1, 2 # i is updated, then x[i] is updated
output_int32(i)
output_int32(x[0])
output_int32(x[1])
class A:
value: int32
def __init__(self):
self.value = 1000
def example6():
ws = [88, 7, 8]
a = A()
x, [y, *ys, a.value], ws[0], (ws[0],) = 1, (2, False, 4, 5), 99, (6,)
output_int32(x)
output_int32(y)
output_bool(ys[0])
output_int32(ys[1])
output_int32(a.value)
output_int32(ws[0])
output_int32(ws[1])
output_int32(ws[2])
def run() -> int32:
example1()
example2()
example3()
example4()
example5()
example6()
return 0

View File

@ -1,31 +0,0 @@
@extern
def output_int32(x: int32):
...
@extern
def output_int64(x: int64):
...
X: int32 = 0
Y = int64(1)
def f():
global X, Y
X = 1
Y = int64(2)
def run() -> int32:
global X, Y
output_int32(X)
output_int64(Y)
f()
output_int32(X)
output_int64(Y)
X = 0
Y = int64(0)
output_int32(X)
output_int64(Y)
return 0

View File

@ -10,58 +10,23 @@ class A:
def __init__(self, a: int32): def __init__(self, a: int32):
self.a = a self.a = a
def output_all_fields(self): def f1(self):
output_int32(self.a) self.f2()
def set_a(self, a: int32): def f2(self):
self.a = a output_int32(self.a)
class B(A): class B(A):
b: int32 b: int32
def __init__(self, b: int32): def __init__(self, b: int32):
A.__init__(self, b + 1) self.a = b + 1
self.set_b(b)
def output_parent_fields(self):
A.output_all_fields(self)
def output_all_fields(self):
A.output_all_fields(self)
output_int32(self.b)
def set_b(self, b: int32):
self.b = b self.b = b
class C(B):
c: int32
def __init__(self, c: int32):
B.__init__(self, c + 1)
self.c = c
def output_parent_fields(self):
B.output_all_fields(self)
def output_all_fields(self):
B.output_all_fields(self)
output_int32(self.c)
def set_c(self, c: int32):
self.c = c
def run() -> int32: def run() -> int32:
ccc = C(10) aaa = A(5)
ccc.output_all_fields() bbb = B(2)
ccc.set_a(1) aaa.f1()
ccc.set_b(2) bbb.f1()
ccc.set_c(3)
ccc.output_all_fields()
bbb = B(10)
bbb.set_a(9)
bbb.set_b(8)
bbb.output_all_fields()
ccc.output_all_fields()
return 0 return 0

View File

@ -0,0 +1,3 @@
def run() -> int32:
hello = np_zeros((3, 4))
return 0

View File

@ -114,22 +114,12 @@ def test_ndarray_ones():
n: ndarray[float, 1] = np_ones([1]) n: ndarray[float, 1] = np_ones([1])
output_ndarray_float_1(n) output_ndarray_float_1(n)
dim = (1,)
n_tup: ndarray[float, 1] = np_ones(dim)
output_ndarray_float_1(n_tup)
def test_ndarray_full(): def test_ndarray_full():
n_float: ndarray[float, 1] = np_full([1], 2.0) n_float: ndarray[float, 1] = np_full([1], 2.0)
output_ndarray_float_1(n_float) output_ndarray_float_1(n_float)
n_i32: ndarray[int32, 1] = np_full([1], 2) n_i32: ndarray[int32, 1] = np_full([1], 2)
output_ndarray_int32_1(n_i32) output_ndarray_int32_1(n_i32)
dim = (1,)
n_float_tup: ndarray[float, 1] = np_full(dim, 2.0)
output_ndarray_float_1(n_float_tup)
n_i32_tup: ndarray[int32, 1] = np_full(dim, 2)
output_ndarray_int32_1(n_i32_tup)
def test_ndarray_eye(): def test_ndarray_eye():
n: ndarray[float, 2] = np_eye(2) n: ndarray[float, 2] = np_eye(2)
output_ndarray_float_2(n) output_ndarray_float_2(n)
@ -877,13 +867,6 @@ def test_ndarray_minimum_broadcast_rhs_scalar():
output_ndarray_float_2(min_x_zeros) output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones) output_ndarray_float_2(min_x_ones)
def test_ndarray_argmin():
x = np_array([[1., 2.], [3., 4.]])
y = np_argmin(x)
output_ndarray_float_2(x)
output_int64(y)
def test_ndarray_max(): def test_ndarray_max():
x = np_identity(2) x = np_identity(2)
y = np_max(x) y = np_max(x)
@ -927,13 +910,6 @@ def test_ndarray_maximum_broadcast_rhs_scalar():
output_ndarray_float_2(max_x_zeros) output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones) output_ndarray_float_2(max_x_ones)
def test_ndarray_argmax():
x = np_array([[1., 2.], [3., 4.]])
y = np_argmax(x)
output_ndarray_float_2(x)
output_int64(y)
def test_ndarray_abs(): def test_ndarray_abs():
x = np_identity(2) x = np_identity(2)
y = abs(x) y = abs(x)
@ -1439,142 +1415,6 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_transpose():
x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
y = np_transpose(x)
z = np_transpose(y)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_reshape():
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
x = np_reshape(w, (1, 2, 1, -1))
y = np_reshape(x, [2, -1])
z = np_reshape(y, 10)
x1: ndarray[int32, 1] = np_array([1, 2, 3, 4])
x2: ndarray[int32, 2] = np_reshape(x1, (2, 2))
output_ndarray_float_1(w)
output_ndarray_float_2(y)
output_ndarray_float_1(z)
def test_ndarray_dot():
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
z1 = np_dot(x1, y1)
x2: ndarray[int32, 1] = np_array([5, 1, 4, 2])
y2: ndarray[int32, 1] = np_array([5, 1, 6, 6])
z2 = np_dot(x2, y2)
x3: ndarray[bool, 1] = np_array([True, True, True, True])
y3: ndarray[bool, 1] = np_array([True, True, True, True])
z3 = np_dot(x3, y3)
z4 = np_dot(2, 3)
z5 = np_dot(2., 3.)
z6 = np_dot(True, False)
output_float64(z1)
output_int32(z2)
output_bool(z3)
output_int32(z4)
output_float64(z5)
output_bool(z6)
def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = np_linalg_cholesky(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_qr():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y, z = np_linalg_qr(x)
output_ndarray_float_2(x)
# QR Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = y @ z
output_ndarray_float_2(a)
def test_ndarray_linalg_inv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_inv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pinv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
y = np_linalg_pinv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_matrix_power():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_matrix_power(x, -9)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_det():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_det(x)
output_ndarray_float_2(x)
output_float64(y)
def test_ndarray_schur():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
t, z = sp_linalg_schur(x)
output_ndarray_float_2(x)
# Schur Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = (z @ t) @ np_linalg_inv(z)
output_ndarray_float_2(a)
def test_ndarray_hessenberg():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 5.0, 8.5]])
h, q = sp_linalg_hessenberg(x)
output_ndarray_float_2(x)
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = (q @ h) @ np_linalg_inv(q)
output_ndarray_float_2(a)
def test_ndarray_lu():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
l, u = sp_linalg_lu(x)
output_ndarray_float_2(x)
output_ndarray_float_2(l)
output_ndarray_float_2(u)
def test_ndarray_svd():
w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
x, y, z = np_linalg_svd(w)
output_ndarray_float_2(w)
# SVD Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = x @ z
output_ndarray_float_2(a)
output_ndarray_float_1(y)
def run() -> int32: def run() -> int32:
test_ndarray_ctor() test_ndarray_ctor()
test_ndarray_empty() test_ndarray_empty()
@ -1679,19 +1519,16 @@ def run() -> int32:
test_ndarray_round() test_ndarray_round()
test_ndarray_floor() test_ndarray_floor()
test_ndarray_ceil()
test_ndarray_min() test_ndarray_min()
test_ndarray_minimum() test_ndarray_minimum()
test_ndarray_minimum_broadcast() test_ndarray_minimum_broadcast()
test_ndarray_minimum_broadcast_lhs_scalar() test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_minimum_broadcast_rhs_scalar() test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_argmin()
test_ndarray_max() test_ndarray_max()
test_ndarray_maximum() test_ndarray_maximum()
test_ndarray_maximum_broadcast() test_ndarray_maximum_broadcast()
test_ndarray_maximum_broadcast_lhs_scalar() test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_maximum_broadcast_rhs_scalar() test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_argmax()
test_ndarray_abs() test_ndarray_abs()
test_ndarray_isnan() test_ndarray_isnan()
test_ndarray_isinf() test_ndarray_isinf()
@ -1754,18 +1591,5 @@ def run() -> int32:
test_ndarray_nextafter_broadcast() test_ndarray_nextafter_broadcast()
test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_transpose()
test_ndarray_reshape()
test_ndarray_dot()
test_ndarray_cholesky()
test_ndarray_qr()
test_ndarray_svd()
test_ndarray_linalg_inv()
test_ndarray_pinv()
test_ndarray_matrix_power()
test_ndarray_det()
test_ndarray_lu()
test_ndarray_schur()
test_ndarray_hessenberg()
return 0 return 0

View File

@ -1,30 +0,0 @@
@extern
def output_bool(x: bool):
...
def str_eq():
output_bool("" == "")
output_bool("a" == "")
output_bool("a" == "b")
output_bool("b" == "a")
output_bool("a" == "a")
output_bool("test string" == "test string")
output_bool("test string1" == "test string2")
def str_ne():
output_bool("" != "")
output_bool("a" != "")
output_bool("a" != "b")
output_bool("b" != "a")
output_bool("a" != "a")
output_bool("test string" != "test string")
output_bool("test string1" != "test string2")
def run() -> int32:
str_eq()
str_ne()
return 0

View File

@ -1,7 +1,3 @@
@extern
def output_bool(b: bool):
...
@extern @extern
def output_int32_list(x: list[int32]): def output_int32_list(x: list[int32]):
... ...
@ -17,41 +13,6 @@ class A:
self.a = a self.a = a
self.b = b self.b = b
def test_tuple_eq():
# 0-len
output_bool(() == ())
# 1-len
output_bool((1,) == ())
output_bool(() == (1,))
output_bool((1,) == (1,))
output_bool((1,) == (2,))
# # 2-len
output_bool((1, 2) == ())
output_bool(() == (1, 2))
output_bool((1,) == (1, 2))
output_bool((1, 2) == (1,))
output_bool((2, 2) == (1, 2))
output_bool((1, 2) == (2, 2))
def test_tuple_ne():
# 0-len
output_bool(() != ())
# 1-len
output_bool((1,) != ())
output_bool(() != (1,))
output_bool((1,) != (1,))
output_bool((1,) != (2,))
# 2-len
output_bool((1, 2) != ())
output_bool(() != (1, 2))
output_bool((1,) != (1, 2))
output_bool((1, 2) != (1,))
output_bool((2, 2) != (1, 2))
output_bool((1, 2) != (2, 2))
def run() -> int32: def run() -> int32:
data = [0, 1, 2, 3] data = [0, 1, 2, 3]
@ -65,14 +26,4 @@ def run() -> int32:
output_int32(tl[0][1]) output_int32(tl[0][1])
output_int32(tl[1]) output_int32(tl[1])
output_int32(len(()))
output_int32(len((1,)))
output_int32(len((1, 2)))
output_int32(len((1, 2, 3)))
output_int32(len((1, 2, 3, 4)))
output_int32(len((1, 2, 3, 4, 5)))
test_tuple_eq()
test_tuple_ne()
return 0 return 0

View File

@ -1,11 +0,0 @@
def f(*args: int32):
pass
def run() -> int32:
f()
f(1)
f(1, 2)
f(1, 2, 3)
return 0

View File

@ -1,14 +1,5 @@
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use parking_lot::{Mutex, RwLock};
use nac3core::{ use nac3core::{
codegen::{CodeGenContext, CodeGenerator}, codegen::CodeGenContext,
inkwell::{module::Linkage, values::BasicValue},
nac3parser::ast::{self, StrRef},
symbol_resolver::{SymbolResolver, SymbolValue, ValueEnum}, symbol_resolver::{SymbolResolver, SymbolValue, ValueEnum},
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::{ typecheck::{
@ -16,10 +7,15 @@ use nac3core::{
typedef::{Type, Unifier}, typedef::{Type, Unifier},
}, },
}; };
use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::{collections::HashMap, sync::Arc};
pub struct ResolverInternal { pub struct ResolverInternal {
pub id_to_type: Mutex<HashMap<StrRef, Type>>, pub id_to_type: Mutex<HashMap<StrRef, Type>>,
pub id_to_def: Mutex<HashMap<StrRef, DefinitionId>>, pub id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
pub class_names: Mutex<HashMap<StrRef, Type>>,
pub module_globals: Mutex<HashMap<StrRef, SymbolValue>>, pub module_globals: Mutex<HashMap<StrRef, SymbolValue>>,
pub str_store: Mutex<HashMap<String, i32>>, pub str_store: Mutex<HashMap<String, i32>>,
} }
@ -50,51 +46,20 @@ impl SymbolResolver for Resolver {
fn get_symbol_type( fn get_symbol_type(
&self, &self,
unifier: &mut Unifier, _: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>], _: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore, _: &PrimitiveStore,
str: StrRef, str: StrRef,
) -> Result<Type, String> { ) -> Result<Type, String> {
self.0 self.0.id_to_type.lock().get(&str).copied().ok_or(format!("cannot get type of {str}"))
.id_to_type
.lock()
.get(&str)
.copied()
.or_else(|| {
self.0
.module_globals
.lock()
.get(&str)
.cloned()
.map(|v| v.get_type(primitives, unifier))
})
.ok_or(format!("cannot get type of {str}"))
} }
fn get_symbol_value<'ctx>( fn get_symbol_value<'ctx>(
&self, &self,
str: StrRef, _: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>, _: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
self.0.module_globals.lock().get(&str).cloned().map(|v| { unimplemented!()
ctx.module
.get_global(&str.to_string())
.unwrap_or_else(|| {
let ty = v.get_type(&ctx.primitives, &mut ctx.unifier);
let init_val = ctx.gen_symbol_val(generator, &v, ty);
let llvm_ty = init_val.get_type();
let global = ctx.module.add_global(llvm_ty, None, &str.to_string());
global.set_linkage(Linkage::LinkOnceAny);
global.set_initializer(&init_val);
global
})
.as_basic_value_enum()
.into()
})
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {

Some files were not shown because too many files have changed in this diff Show More