forked from M-Labs/nac3
Compare commits
83 Commits
ndstrides-
...
master
Author | SHA1 | Date |
---|---|---|
David Mak | 01edd5af67 | |
occheung | 015714eee1 | |
occheung | 71dec251e3 | |
occheung | fce61f7b8c | |
abdul124 | babc081dbd | |
abdul124 | 5337dbe23b | |
abdul124 | f862c01412 | |
David Mak | 0c9705f5f1 | |
David Mak | 5f940f86d9 | |
Sebastien Bourdeauducq | 5651e00688 | |
Sebastien Bourdeauducq | f6745b987f | |
mwojcik | e0dedc6580 | |
David Mak | 28f574282c | |
David Mak | 144f0922db | |
David Mak | c58ce9c3a9 | |
David Mak | f7e296da53 | |
David Mak | b58c99369e | |
David Mak | 1a535db558 | |
David Mak | 1ba2e287a6 | |
lyken | f95f979ad3 | |
lyken | 48e2148c0f | |
David Mak | 88e57f7120 | |
David Mak | d7633c42bc | |
David Mak | a4f53b6e6b | |
David Mak | 9d9ead211e | |
David Mak | 26a1b85206 | |
David Mak | 2822074b2d | |
David Mak | fe67ed076c | |
David Mak | 94e2414df0 | |
Sebastien Bourdeauducq | 2cee760404 | |
Sebastien Bourdeauducq | 230982dc84 | |
occheung | 2bd3f63991 | |
occheung | b53266e9e6 | |
occheung | 86eb22bbf3 | |
occheung | beaa38047d | |
occheung | 705dc4ff1c | |
occheung | 979209a526 | |
David Mak | c3927d0ef6 | |
David Mak | 202a902cd0 | |
David Mak | b6e2644391 | |
David Mak | 45cd01556b | |
David Mak | b6cd2a6993 | |
David Mak | a98f33e6d1 | |
David Mak | 5839badadd | |
David Mak | 56c845aac4 | |
David Mak | 65a12d9ab3 | |
David Mak | 9c6685fa8f | |
David Mak | 2bb788e4bb | |
David Mak | 42a2f243b5 | |
David Mak | 3ce2eddcdc | |
David Mak | 51bf126a32 | |
David Mak | 1a197c67f6 | |
David Mak | 581b2f7bb2 | |
David Mak | 746329ec5d | |
David Mak | e60e8e837f | |
David Mak | 9fdbe9695d | |
David Mak | 8065e73598 | |
David Mak | 192290889b | |
David Mak | 1407553a2f | |
David Mak | c7697606e1 | |
David Mak | 88d0ccbf69 | |
David Mak | a43b59539c | |
David Mak | fe06b2806f | |
David Mak | 7f6c9a25ac | |
Sébastien Bourdeauducq | 6c8382219f | |
Sebastien Bourdeauducq | 9274a7b96b | |
Sébastien Bourdeauducq | d1c0fe2900 | |
mwojcik | f2c047ba57 | |
David Mak | 5e2e77a500 | |
David Mak | f3cc4702b9 | |
David Mak | 3e92c491f5 | |
lyken | 7f629f1579 | |
lyken | 5640a793e2 | |
David Mak | abbaa506ad | |
David Mak | f3dc02d646 | |
David Mak | ea217eaea1 | |
Sébastien Bourdeauducq | 5a34551905 | |
Sebastien Bourdeauducq | 6098b1b853 | |
Sebastien Bourdeauducq | 668ccb1c95 | |
Sebastien Bourdeauducq | a3c624d69d | |
Sébastien Bourdeauducq | bd06155f34 | |
David Mak | 9c33c4209c | |
Sebastien Bourdeauducq | 122983f11c |
|
@ -1,24 +1,24 @@
|
||||||
# See https://pre-commit.com for more information
|
# See https://pre-commit.com for more information
|
||||||
# See https://pre-commit.com/hooks.html for more hooks
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
|
|
||||||
default_stages: [commit]
|
default_stages: [pre-commit]
|
||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
- id: nac3-cargo-fmt
|
- id: nac3-cargo-fmt
|
||||||
name: nac3 cargo format
|
name: nac3 cargo format
|
||||||
entry: cargo
|
entry: nix
|
||||||
language: system
|
language: system
|
||||||
types: [file, rust]
|
types: [file, rust]
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
description: Runs cargo fmt on the codebase.
|
description: Runs cargo fmt on the codebase.
|
||||||
args: [fmt]
|
args: [develop, -c, cargo, fmt, --all]
|
||||||
- id: nac3-cargo-clippy
|
- id: nac3-cargo-clippy
|
||||||
name: nac3 cargo clippy
|
name: nac3 cargo clippy
|
||||||
entry: cargo
|
entry: nix
|
||||||
language: system
|
language: system
|
||||||
types: [file, rust]
|
types: [file, rust]
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
description: Runs cargo clippy on the codebase.
|
description: Runs cargo clippy on the codebase.
|
||||||
args: [clippy, --tests]
|
args: [develop, -c, cargo, clippy, --tests]
|
||||||
|
|
|
@ -26,9 +26,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anstream"
|
name = "anstream"
|
||||||
version = "0.6.15"
|
version = "0.6.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526"
|
checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anstyle",
|
"anstyle",
|
||||||
"anstyle-parse",
|
"anstyle-parse",
|
||||||
|
@ -41,67 +41,67 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anstyle"
|
name = "anstyle"
|
||||||
version = "1.0.8"
|
version = "1.0.10"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1"
|
checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anstyle-parse"
|
name = "anstyle-parse"
|
||||||
version = "0.2.5"
|
version = "0.2.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb"
|
checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"utf8parse",
|
"utf8parse",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anstyle-query"
|
name = "anstyle-query"
|
||||||
version = "1.1.1"
|
version = "1.1.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a"
|
checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anstyle-wincon"
|
name = "anstyle-wincon"
|
||||||
version = "3.0.4"
|
version = "3.0.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8"
|
checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anstyle",
|
"anstyle",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ascii-canvas"
|
name = "ascii-canvas"
|
||||||
version = "3.0.0"
|
version = "4.0.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8824ecca2e851cec16968d54a01dd372ef8f95b244fb84b84e70128be347c3c6"
|
checksum = "ef1e3e699d84ab1b0911a1010c5c106aa34ae89aeac103be5ce0c3859db1e891"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"term",
|
"term",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "autocfg"
|
name = "autocfg"
|
||||||
version = "1.3.0"
|
version = "1.4.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
|
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bit-set"
|
name = "bit-set"
|
||||||
version = "0.5.3"
|
version = "0.8.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
|
checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bit-vec",
|
"bit-vec",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bit-vec"
|
name = "bit-vec"
|
||||||
version = "0.6.3"
|
version = "0.8.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
|
checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bitflags"
|
name = "bitflags"
|
||||||
|
@ -109,6 +109,15 @@ version = "2.6.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
|
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "block-buffer"
|
||||||
|
version = "0.10.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||||
|
dependencies = [
|
||||||
|
"generic-array",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "byteorder"
|
name = "byteorder"
|
||||||
version = "1.5.0"
|
version = "1.5.0"
|
||||||
|
@ -117,9 +126,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.1.15"
|
version = "1.2.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6"
|
checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"shlex",
|
"shlex",
|
||||||
]
|
]
|
||||||
|
@ -132,9 +141,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap"
|
name = "clap"
|
||||||
version = "4.5.16"
|
version = "4.5.21"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019"
|
checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"clap_builder",
|
"clap_builder",
|
||||||
"clap_derive",
|
"clap_derive",
|
||||||
|
@ -142,9 +151,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap_builder"
|
name = "clap_builder"
|
||||||
version = "4.5.15"
|
version = "4.5.21"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6"
|
checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anstream",
|
"anstream",
|
||||||
"anstyle",
|
"anstyle",
|
||||||
|
@ -154,27 +163,27 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap_derive"
|
name = "clap_derive"
|
||||||
version = "4.5.13"
|
version = "4.5.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0"
|
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"heck 0.5.0",
|
"heck 0.5.0",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.76",
|
"syn 2.0.87",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap_lex"
|
name = "clap_lex"
|
||||||
version = "0.7.2"
|
version = "0.7.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97"
|
checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "colorchoice"
|
name = "colorchoice"
|
||||||
version = "1.0.2"
|
version = "1.0.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0"
|
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "console"
|
name = "console"
|
||||||
|
@ -188,6 +197,15 @@ dependencies = [
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cpufeatures"
|
||||||
|
version = "0.2.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crossbeam"
|
name = "crossbeam"
|
||||||
version = "0.8.4"
|
version = "0.8.4"
|
||||||
|
@ -245,32 +263,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
|
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crunchy"
|
name = "crypto-common"
|
||||||
version = "0.2.2"
|
version = "0.1.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "dirs-next"
|
|
||||||
version = "2.0.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"generic-array",
|
||||||
"dirs-sys-next",
|
"typenum",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dirs-sys-next"
|
name = "digest"
|
||||||
version = "0.1.2"
|
version = "0.10.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d"
|
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"block-buffer",
|
||||||
"redox_users",
|
"crypto-common",
|
||||||
"winapi",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dissimilar"
|
||||||
|
version = "1.0.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "59f8e79d1fbf76bdfbde321e902714bf6c49df88a7dda6fc682fc2979226962d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "either"
|
name = "either"
|
||||||
version = "1.13.0"
|
version = "1.13.0"
|
||||||
|
@ -310,9 +327,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastrand"
|
name = "fastrand"
|
||||||
version = "2.1.1"
|
version = "2.2.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
|
checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fixedbitset"
|
name = "fixedbitset"
|
||||||
|
@ -329,6 +346,16 @@ dependencies = [
|
||||||
"byteorder",
|
"byteorder",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "generic-array"
|
||||||
|
version = "0.14.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||||
|
dependencies = [
|
||||||
|
"typenum",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getopts"
|
name = "getopts"
|
||||||
version = "0.2.21"
|
version = "0.2.21"
|
||||||
|
@ -349,6 +376,12 @@ dependencies = [
|
||||||
"wasi",
|
"wasi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "glob"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hashbrown"
|
name = "hashbrown"
|
||||||
version = "0.12.3"
|
version = "0.12.3"
|
||||||
|
@ -364,6 +397,12 @@ dependencies = [
|
||||||
"ahash",
|
"ahash",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hashbrown"
|
||||||
|
version = "0.15.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "heck"
|
name = "heck"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
|
@ -376,6 +415,15 @@ version = "0.5.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "home"
|
||||||
|
version = "0.5.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
|
||||||
|
dependencies = [
|
||||||
|
"windows-sys 0.52.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indexmap"
|
name = "indexmap"
|
||||||
version = "1.9.3"
|
version = "1.9.3"
|
||||||
|
@ -388,12 +436,12 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indexmap"
|
name = "indexmap"
|
||||||
version = "2.4.0"
|
version = "2.6.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c"
|
checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"equivalent",
|
"equivalent",
|
||||||
"hashbrown 0.14.5",
|
"hashbrown 0.15.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -404,9 +452,9 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "inkwell"
|
name = "inkwell"
|
||||||
version = "0.4.0"
|
version = "0.5.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b597a7b2cdf279aeef6d7149071e35e4bc87c2cf05a5b7f2d731300bffe587ea"
|
checksum = "40fb405537710d51f6bdbc8471365ddd4cd6d3a3c3ad6e0c8291691031ba94b2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"either",
|
"either",
|
||||||
"inkwell_internals",
|
"inkwell_internals",
|
||||||
|
@ -418,13 +466,13 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "inkwell_internals"
|
name = "inkwell_internals"
|
||||||
version = "0.9.0"
|
version = "0.10.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
|
checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.76",
|
"syn 2.0.87",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -447,15 +495,6 @@ version = "1.70.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
|
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "itertools"
|
|
||||||
version = "0.11.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
|
|
||||||
dependencies = [
|
|
||||||
"either",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "itertools"
|
name = "itertools"
|
||||||
version = "0.13.0"
|
version = "0.13.0"
|
||||||
|
@ -472,34 +511,44 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
|
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lalrpop"
|
name = "keccak"
|
||||||
version = "0.20.2"
|
version = "0.1.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "55cb077ad656299f160924eb2912aa147d7339ea7d69e1b5517326fdcec3c1ca"
|
checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654"
|
||||||
|
dependencies = [
|
||||||
|
"cpufeatures",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lalrpop"
|
||||||
|
version = "0.22.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "06093b57658c723a21da679530e061a8c25340fa5a6f98e313b542268c7e2a1f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ascii-canvas",
|
"ascii-canvas",
|
||||||
"bit-set",
|
"bit-set",
|
||||||
"ena",
|
"ena",
|
||||||
"itertools 0.11.0",
|
"itertools",
|
||||||
"lalrpop-util",
|
"lalrpop-util",
|
||||||
"petgraph",
|
"petgraph",
|
||||||
"pico-args",
|
"pico-args",
|
||||||
"regex",
|
"regex",
|
||||||
"regex-syntax",
|
"regex-syntax",
|
||||||
|
"sha3",
|
||||||
"string_cache",
|
"string_cache",
|
||||||
"term",
|
"term",
|
||||||
"tiny-keccak",
|
|
||||||
"unicode-xid",
|
"unicode-xid",
|
||||||
"walkdir",
|
"walkdir",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lalrpop-util"
|
name = "lalrpop-util"
|
||||||
version = "0.20.2"
|
version = "0.22.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553"
|
checksum = "feee752d43abd0f4807a921958ab4131f692a44d4d599733d4419c5d586176ce"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"regex-automata",
|
"regex-automata",
|
||||||
|
"rustversion",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -510,9 +559,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libc"
|
name = "libc"
|
||||||
version = "0.2.158"
|
version = "0.2.164"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
|
checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libloading"
|
name = "libloading"
|
||||||
|
@ -524,16 +573,6 @@ dependencies = [
|
||||||
"windows-targets",
|
"windows-targets",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "libredox"
|
|
||||||
version = "0.1.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
|
|
||||||
dependencies = [
|
|
||||||
"bitflags",
|
|
||||||
"libc",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linked-hash-map"
|
name = "linked-hash-map"
|
||||||
version = "0.5.6"
|
version = "0.5.6"
|
||||||
|
@ -594,11 +633,9 @@ dependencies = [
|
||||||
name = "nac3artiq"
|
name = "nac3artiq"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"inkwell",
|
"itertools",
|
||||||
"itertools 0.13.0",
|
|
||||||
"nac3core",
|
"nac3core",
|
||||||
"nac3ld",
|
"nac3ld",
|
||||||
"nac3parser",
|
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"pyo3",
|
"pyo3",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
|
@ -609,7 +646,6 @@ name = "nac3ast"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fxhash",
|
"fxhash",
|
||||||
"lazy_static",
|
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"string-interner",
|
"string-interner",
|
||||||
]
|
]
|
||||||
|
@ -619,11 +655,12 @@ name = "nac3core"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"crossbeam",
|
"crossbeam",
|
||||||
"indexmap 2.4.0",
|
"indexmap 2.6.0",
|
||||||
"indoc",
|
"indoc",
|
||||||
"inkwell",
|
"inkwell",
|
||||||
"insta",
|
"insta",
|
||||||
"itertools 0.13.0",
|
"itertools",
|
||||||
|
"nac3core_derive",
|
||||||
"nac3parser",
|
"nac3parser",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"rayon",
|
"rayon",
|
||||||
|
@ -633,6 +670,18 @@ dependencies = [
|
||||||
"test-case",
|
"test-case",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nac3core_derive"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"nac3core",
|
||||||
|
"proc-macro-error",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.87",
|
||||||
|
"trybuild",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "nac3ld"
|
name = "nac3ld"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
@ -661,9 +710,7 @@ name = "nac3standalone"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"clap",
|
"clap",
|
||||||
"inkwell",
|
|
||||||
"nac3core",
|
"nac3core",
|
||||||
"nac3parser",
|
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -675,9 +722,9 @@ checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "once_cell"
|
name = "once_cell"
|
||||||
version = "1.19.0"
|
version = "1.20.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "parking_lot"
|
name = "parking_lot"
|
||||||
|
@ -709,7 +756,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
|
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fixedbitset",
|
"fixedbitset",
|
||||||
"indexmap 2.4.0",
|
"indexmap 2.6.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -752,7 +799,7 @@ dependencies = [
|
||||||
"phf_shared 0.11.2",
|
"phf_shared 0.11.2",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.76",
|
"syn 2.0.87",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -781,9 +828,9 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "portable-atomic"
|
name = "portable-atomic"
|
||||||
version = "1.7.0"
|
version = "1.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
|
checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ppv-lite86"
|
name = "ppv-lite86"
|
||||||
|
@ -801,10 +848,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
|
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "proc-macro2"
|
name = "proc-macro-error"
|
||||||
version = "1.0.86"
|
version = "1.0.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
|
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro-error-attr",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 1.0.109",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro-error-attr"
|
||||||
|
version = "1.0.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro2"
|
||||||
|
version = "1.0.89"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
@ -856,7 +927,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"pyo3-macros-backend",
|
"pyo3-macros-backend",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.76",
|
"syn 2.0.87",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -869,7 +940,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"pyo3-build-config",
|
"pyo3-build-config",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.76",
|
"syn 2.0.87",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -933,29 +1004,18 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.5.3"
|
version = "0.5.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
|
checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags",
|
"bitflags",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "redox_users"
|
|
||||||
version = "0.4.6"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
|
|
||||||
dependencies = [
|
|
||||||
"getrandom",
|
|
||||||
"libredox",
|
|
||||||
"thiserror",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex"
|
name = "regex"
|
||||||
version = "1.10.6"
|
version = "1.11.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
|
checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aho-corasick",
|
"aho-corasick",
|
||||||
"memchr",
|
"memchr",
|
||||||
|
@ -965,9 +1025,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex-automata"
|
name = "regex-automata"
|
||||||
version = "0.4.7"
|
version = "0.4.9"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
|
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aho-corasick",
|
"aho-corasick",
|
||||||
"memchr",
|
"memchr",
|
||||||
|
@ -976,9 +1036,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex-syntax"
|
name = "regex-syntax"
|
||||||
version = "0.8.4"
|
version = "0.8.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
|
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "runkernel"
|
name = "runkernel"
|
||||||
|
@ -989,9 +1049,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustix"
|
name = "rustix"
|
||||||
version = "0.38.35"
|
version = "0.38.41"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f"
|
checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags",
|
"bitflags",
|
||||||
"errno",
|
"errno",
|
||||||
|
@ -1002,9 +1062,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustversion"
|
name = "rustversion"
|
||||||
version = "1.0.17"
|
version = "1.0.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6"
|
checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ryu"
|
name = "ryu"
|
||||||
|
@ -1035,29 +1095,29 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde"
|
name = "serde"
|
||||||
version = "1.0.209"
|
version = "1.0.215"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09"
|
checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_derive"
|
name = "serde_derive"
|
||||||
version = "1.0.209"
|
version = "1.0.215"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170"
|
checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.76",
|
"syn 2.0.87",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_json"
|
name = "serde_json"
|
||||||
version = "1.0.127"
|
version = "1.0.133"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad"
|
checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"itoa",
|
"itoa",
|
||||||
"memchr",
|
"memchr",
|
||||||
|
@ -1065,6 +1125,15 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_spanned"
|
||||||
|
version = "0.6.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_yaml"
|
name = "serde_yaml"
|
||||||
version = "0.8.26"
|
version = "0.8.26"
|
||||||
|
@ -1077,6 +1146,16 @@ dependencies = [
|
||||||
"yaml-rust",
|
"yaml-rust",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sha3"
|
||||||
|
version = "0.10.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60"
|
||||||
|
dependencies = [
|
||||||
|
"digest",
|
||||||
|
"keccak",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "shlex"
|
name = "shlex"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
|
@ -1147,7 +1226,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"rustversion",
|
"rustversion",
|
||||||
"syn 2.0.76",
|
"syn 2.0.87",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1163,9 +1242,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "2.0.76"
|
version = "2.0.87"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525"
|
checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
@ -1179,10 +1258,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tempfile"
|
name = "target-triple"
|
||||||
version = "3.12.0"
|
version = "0.1.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
|
checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tempfile"
|
||||||
|
version = "3.14.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"fastrand",
|
"fastrand",
|
||||||
|
@ -1193,13 +1278,21 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "term"
|
name = "term"
|
||||||
version = "0.7.0"
|
version = "1.0.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f"
|
checksum = "4df4175de05129f31b80458c6df371a15e7fc3fd367272e6bf938e5c351c7ea0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"dirs-next",
|
"home",
|
||||||
"rustversion",
|
"windows-sys 0.52.0",
|
||||||
"winapi",
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "termcolor"
|
||||||
|
version = "1.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-util",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1217,33 +1310,80 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.63"
|
version = "1.0.69"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
|
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"thiserror-impl",
|
"thiserror-impl",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror-impl"
|
name = "thiserror-impl"
|
||||||
version = "1.0.63"
|
version = "1.0.69"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
|
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.76",
|
"syn 2.0.87",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tiny-keccak"
|
name = "toml"
|
||||||
version = "2.0.2"
|
version = "0.8.19"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237"
|
checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"crunchy",
|
"serde",
|
||||||
|
"serde_spanned",
|
||||||
|
"toml_datetime",
|
||||||
|
"toml_edit",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml_datetime"
|
||||||
|
version = "0.6.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml_edit"
|
||||||
|
version = "0.22.22"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5"
|
||||||
|
dependencies = [
|
||||||
|
"indexmap 2.6.0",
|
||||||
|
"serde",
|
||||||
|
"serde_spanned",
|
||||||
|
"toml_datetime",
|
||||||
|
"winnow",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "trybuild"
|
||||||
|
version = "1.0.101"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8dcd332a5496c026f1e14b7f3d2b7bd98e509660c04239c58b0ba38a12daded4"
|
||||||
|
dependencies = [
|
||||||
|
"dissimilar",
|
||||||
|
"glob",
|
||||||
|
"serde",
|
||||||
|
"serde_derive",
|
||||||
|
"serde_json",
|
||||||
|
"target-triple",
|
||||||
|
"termcolor",
|
||||||
|
"toml",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typenum"
|
||||||
|
version = "1.17.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unic-char-property"
|
name = "unic-char-property"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
|
@ -1298,27 +1438,27 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
version = "1.0.12"
|
version = "1.0.13"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-width"
|
name = "unicode-width"
|
||||||
version = "0.1.13"
|
version = "0.1.14"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
|
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-xid"
|
name = "unicode-xid"
|
||||||
version = "0.2.5"
|
version = "0.2.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a"
|
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode_names2"
|
name = "unicode_names2"
|
||||||
version = "1.2.2"
|
version = "1.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "addeebf294df7922a1164f729fb27ebbbcea99cc32b3bf08afab62757f707677"
|
checksum = "d1673eca9782c84de5f81b82e4109dcfb3611c8ba0d52930ec4a9478f547b2dd"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"phf",
|
"phf",
|
||||||
"unicode_names2_generator",
|
"unicode_names2_generator",
|
||||||
|
@ -1326,9 +1466,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode_names2_generator"
|
name = "unicode_names2_generator"
|
||||||
version = "1.2.2"
|
version = "1.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f444b8bba042fe3c1251ffaca35c603f2dc2ccc08d595c65a8c4f76f3e8426c0"
|
checksum = "b91e5b84611016120197efd7dc93ef76774f4e084cd73c9fb3ea4a86c570c56e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"getopts",
|
"getopts",
|
||||||
"log",
|
"log",
|
||||||
|
@ -1370,22 +1510,6 @@ version = "0.11.0+wasi-snapshot-preview1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "winapi"
|
|
||||||
version = "0.3.9"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
|
||||||
dependencies = [
|
|
||||||
"winapi-i686-pc-windows-gnu",
|
|
||||||
"winapi-x86_64-pc-windows-gnu",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "winapi-i686-pc-windows-gnu"
|
|
||||||
version = "0.4.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "winapi-util"
|
name = "winapi-util"
|
||||||
version = "0.1.9"
|
version = "0.1.9"
|
||||||
|
@ -1395,12 +1519,6 @@ dependencies = [
|
||||||
"windows-sys 0.59.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "winapi-x86_64-pc-windows-gnu"
|
|
||||||
version = "0.4.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-sys"
|
name = "windows-sys"
|
||||||
version = "0.52.0"
|
version = "0.52.0"
|
||||||
|
@ -1483,6 +1601,15 @@ version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winnow"
|
||||||
|
version = "0.6.20"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "yaml-rust"
|
name = "yaml-rust"
|
||||||
version = "0.4.5"
|
version = "0.4.5"
|
||||||
|
@ -1510,5 +1637,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.76",
|
"syn 2.0.87",
|
||||||
]
|
]
|
||||||
|
|
|
@ -4,6 +4,7 @@ members = [
|
||||||
"nac3ast",
|
"nac3ast",
|
||||||
"nac3parser",
|
"nac3parser",
|
||||||
"nac3core",
|
"nac3core",
|
||||||
|
"nac3core/nac3core_derive",
|
||||||
"nac3standalone",
|
"nac3standalone",
|
||||||
"nac3artiq",
|
"nac3artiq",
|
||||||
"runkernel",
|
"runkernel",
|
||||||
|
|
|
@ -2,11 +2,11 @@
|
||||||
"nodes": {
|
"nodes": {
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1723637854,
|
"lastModified": 1731319897,
|
||||||
"narHash": "sha256-med8+5DSWa2UnOqtdICndjDAEjxr5D7zaIiK4pn0Q7c=",
|
"narHash": "sha256-PbABj4tnbWFMfBp6OcUK5iGy1QY+/Z96ZcLpooIbuEI=",
|
||||||
"owner": "NixOS",
|
"owner": "NixOS",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "c3aa7b8938b17aebd2deecf7be0636000d62a2b9",
|
"rev": "dc460ec76cbff0e66e269457d7b728432263166c",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
10
flake.nix
10
flake.nix
|
@ -107,18 +107,18 @@
|
||||||
(pkgs.fetchFromGitHub {
|
(pkgs.fetchFromGitHub {
|
||||||
owner = "m-labs";
|
owner = "m-labs";
|
||||||
repo = "sipyco";
|
repo = "sipyco";
|
||||||
rev = "939f84f9b5eef7efbf7423c735d1834783b6140e";
|
rev = "094a6cd63ffa980ef63698920170e50dc9ba77fd";
|
||||||
sha256 = "sha256-15Nun4EY35j+6SPZkjzZtyH/ncxLS60KuGJjFh5kSTc=";
|
sha256 = "sha256-PPnAyDedUQ7Og/Cby9x5OT9wMkNGTP8GS53V6N/dk4w=";
|
||||||
})
|
})
|
||||||
(pkgs.fetchFromGitHub {
|
(pkgs.fetchFromGitHub {
|
||||||
owner = "m-labs";
|
owner = "m-labs";
|
||||||
repo = "artiq";
|
repo = "artiq";
|
||||||
rev = "923ca3377d42c815f979983134ec549dc39d3ca0";
|
rev = "28c9de3e251daa89a8c9fd79d5ab64a3ec03bac6";
|
||||||
sha256 = "sha256-oJoEeNEeNFSUyh6jXG8Tzp6qHVikeHS0CzfE+mODPgw=";
|
sha256 = "sha256-vAvpbHc5B+1wtG8zqN7j9dQE1ON+i22v+uqA+tw6Gak=";
|
||||||
})
|
})
|
||||||
];
|
];
|
||||||
buildInputs = [
|
buildInputs = [
|
||||||
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb nac3artiq-instrumented ]))
|
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ]))
|
||||||
pkgs.llvmPackages_14.llvm.out
|
pkgs.llvmPackages_14.llvm.out
|
||||||
];
|
];
|
||||||
phases = [ "buildPhase" "installPhase" ];
|
phases = [ "buildPhase" "installPhase" ];
|
||||||
|
|
|
@ -12,16 +12,10 @@ crate-type = ["cdylib"]
|
||||||
itertools = "0.13"
|
itertools = "0.13"
|
||||||
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
|
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
tempfile = "3.10"
|
tempfile = "3.13"
|
||||||
nac3parser = { path = "../nac3parser" }
|
|
||||||
nac3core = { path = "../nac3core" }
|
nac3core = { path = "../nac3core" }
|
||||||
nac3ld = { path = "../nac3ld" }
|
nac3ld = { path = "../nac3ld" }
|
||||||
|
|
||||||
[dependencies.inkwell]
|
|
||||||
version = "0.4"
|
|
||||||
default-features = false
|
|
||||||
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
init-llvm-profile = []
|
init-llvm-profile = []
|
||||||
no-escape-analysis = ["nac3core/no-escape-analysis"]
|
no-escape-analysis = ["nac3core/no-escape-analysis"]
|
||||||
|
|
|
@ -112,10 +112,15 @@ def extern(function):
|
||||||
register_function(function)
|
register_function(function)
|
||||||
return function
|
return function
|
||||||
|
|
||||||
def rpc(function):
|
|
||||||
"""Decorates a function declaration defined by the core device runtime."""
|
def rpc(arg=None, flags={}):
|
||||||
register_function(function)
|
"""Decorates a function or method to be executed on the host interpreter."""
|
||||||
return function
|
if arg is None:
|
||||||
|
def inner_decorator(function):
|
||||||
|
return rpc(function, flags)
|
||||||
|
return inner_decorator
|
||||||
|
register_function(arg)
|
||||||
|
return arg
|
||||||
|
|
||||||
def kernel(function_or_method):
|
def kernel(function_or_method):
|
||||||
"""Decorates a function or method to be executed on the core device."""
|
"""Decorates a function or method to be executed on the core device."""
|
||||||
|
@ -201,7 +206,7 @@ class Core:
|
||||||
embedding = EmbeddingMap()
|
embedding = EmbeddingMap()
|
||||||
|
|
||||||
if allow_registration:
|
if allow_registration:
|
||||||
compiler.analyze(registered_functions, registered_classes)
|
compiler.analyze(registered_functions, registered_classes, set())
|
||||||
allow_registration = False
|
allow_registration = False
|
||||||
|
|
||||||
if hasattr(method, "__self__"):
|
if hasattr(method, "__self__"):
|
||||||
|
|
|
@ -1,41 +1,3 @@
|
||||||
use nac3core::{
|
|
||||||
codegen::{
|
|
||||||
classes::{ListValue, RangeValue, UntypedArrayLikeAccessor},
|
|
||||||
expr::{destructure_range, gen_call},
|
|
||||||
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
|
|
||||||
model::*,
|
|
||||||
object::{any::AnyObject, ndarray::NDArrayObject},
|
|
||||||
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
},
|
|
||||||
symbol_resolver::ValueEnum,
|
|
||||||
toplevel::{
|
|
||||||
helper::{extract_ndims, PrimDef},
|
|
||||||
numpy::unpack_ndarray_var_tys,
|
|
||||||
DefinitionId, GenCall,
|
|
||||||
},
|
|
||||||
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
|
|
||||||
};
|
|
||||||
|
|
||||||
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
|
||||||
|
|
||||||
use inkwell::{
|
|
||||||
context::Context,
|
|
||||||
module::Linkage,
|
|
||||||
types::IntType,
|
|
||||||
values::{BasicValue, BasicValueEnum, PointerValue, StructValue},
|
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
|
||||||
};
|
|
||||||
|
|
||||||
use pyo3::{
|
|
||||||
types::{PyDict, PyList},
|
|
||||||
PyObject, PyResult, Python,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
|
||||||
|
|
||||||
use inkwell::values::IntValue;
|
|
||||||
use itertools::Itertools;
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::{hash_map::DefaultHasher, HashMap},
|
collections::{hash_map::DefaultHasher, HashMap},
|
||||||
hash::{Hash, Hasher},
|
hash::{Hash, Hasher},
|
||||||
|
@ -44,6 +6,40 @@ use std::{
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use itertools::Itertools;
|
||||||
|
use pyo3::{
|
||||||
|
types::{PyDict, PyList},
|
||||||
|
PyObject, PyResult, Python,
|
||||||
|
};
|
||||||
|
|
||||||
|
use nac3core::{
|
||||||
|
codegen::{
|
||||||
|
expr::{destructure_range, gen_call},
|
||||||
|
irrt::call_ndarray_calc_size,
|
||||||
|
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
|
||||||
|
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
|
||||||
|
types::{NDArrayType, ProxyType},
|
||||||
|
values::{
|
||||||
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue,
|
||||||
|
RangeValue, UntypedArrayLikeAccessor,
|
||||||
|
},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
inkwell::{
|
||||||
|
context::Context,
|
||||||
|
module::Linkage,
|
||||||
|
types::{BasicType, IntType},
|
||||||
|
values::{BasicValueEnum, IntValue, PointerValue, StructValue},
|
||||||
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
|
},
|
||||||
|
nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef},
|
||||||
|
symbol_resolver::ValueEnum,
|
||||||
|
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall},
|
||||||
|
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||||
|
|
||||||
/// The parallelism mode within a block.
|
/// The parallelism mode within a block.
|
||||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||||
enum ParallelMode {
|
enum ParallelMode {
|
||||||
|
@ -459,42 +455,55 @@ fn format_rpc_arg<'ctx>(
|
||||||
// NAC3: NDArray = { usize, usize*, T* }
|
// NAC3: NDArray = { usize, usize*, T* }
|
||||||
// libproto_artiq: NDArray = [data[..], dim_sz[..]]
|
// libproto_artiq: NDArray = [data[..], dim_sz[..]]
|
||||||
|
|
||||||
let ndarray = AnyObject { ty: arg_ty, value: arg };
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let dtype = ctx.get_llvm_type(generator, ndarray.dtype);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||||
let ndims = ndarray.ndims_llvm(generator, ctx.ctx);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
||||||
|
let llvm_arg = llvm_arg_ty.map_value(arg.into_pointer_value(), None);
|
||||||
|
|
||||||
// `ndarray.data` is possibly not contiguous, and we need it to be contiguous for
|
let llvm_usize_sizeof = ctx
|
||||||
// the reader.
|
.builder
|
||||||
// Turning it into a ContiguousNDArray to get a `data` that is contiguous.
|
.build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "")
|
||||||
let carray = ndarray.make_contiguous_ndarray(generator, ctx, Any(dtype));
|
.unwrap();
|
||||||
|
let llvm_pdata_sizeof = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_truncate_or_bit_cast(
|
||||||
|
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
|
||||||
|
llvm_usize,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let sizeof_sizet = Int(SizeT).size_of(generator, ctx.ctx);
|
let dims_buf_sz =
|
||||||
let sizeof_sizet = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_sizet);
|
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
||||||
|
|
||||||
let sizeof_pdata = Ptr(Any(dtype)).size_of(generator, ctx.ctx);
|
let buffer_size =
|
||||||
let sizeof_pdata = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_pdata);
|
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
|
||||||
|
|
||||||
let sizeof_buf_shape = sizeof_sizet.mul(ctx, ndims);
|
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
|
||||||
let sizeof_buf = sizeof_buf_shape.add(ctx, sizeof_pdata);
|
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
|
||||||
|
|
||||||
// buf = { data: void*, shape: [size_t; ndims]; }
|
call_memcpy_generic(
|
||||||
let buf = Int(Byte).array_alloca(generator, ctx, sizeof_buf.value);
|
ctx,
|
||||||
let buf_data = buf;
|
buffer.base_ptr(ctx, generator),
|
||||||
let buf_shape = buf_data.offset(ctx, sizeof_pdata.value);
|
llvm_arg.ptr_to_data(ctx),
|
||||||
|
llvm_pdata_sizeof,
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
// Write to `buf->data`
|
let pbuffer_dims_begin =
|
||||||
let carray_data = carray.get(generator, ctx, |f| f.data); // has type Ptr<Any>
|
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
||||||
let carray_data = carray_data.pointer_cast(generator, ctx, Int(Byte));
|
call_memcpy_generic(
|
||||||
buf_data.copy_from(generator, ctx, carray_data, sizeof_pdata.value);
|
ctx,
|
||||||
|
pbuffer_dims_begin,
|
||||||
|
llvm_arg.shape().base_ptr(ctx, generator),
|
||||||
|
dims_buf_sz,
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
// Write to `buf->shape`
|
buffer.base_ptr(ctx, generator)
|
||||||
let carray_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
let carray_shape_i8 = carray_shape.pointer_cast(generator, ctx, Int(Byte));
|
|
||||||
buf_shape.copy_from(generator, ctx, carray_shape_i8, sizeof_buf_shape.value);
|
|
||||||
|
|
||||||
buf.value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => {
|
_ => {
|
||||||
|
@ -504,7 +513,7 @@ fn format_rpc_arg<'ctx>(
|
||||||
ctx.builder.build_store(arg_slot, arg).unwrap();
|
ctx.builder.build_store(arg_slot, arg).unwrap();
|
||||||
|
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_bitcast(arg_slot, llvm_pi8, "rpc.arg")
|
.build_bit_cast(arg_slot, llvm_pi8, "rpc.arg")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
@ -555,10 +564,8 @@ fn format_rpc_ret<'ctx>(
|
||||||
|
|
||||||
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
|
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
// FIXME: It is possible to rewrite everything more neatly with `Model<'ctx>`, but this is not too important.
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let num_0 = Int(SizeT).const_0(generator, ctx.ctx);
|
|
||||||
let num_8 = Int(SizeT).const_int(generator, ctx.ctx, 8, false);
|
|
||||||
|
|
||||||
// Round `val` up to its modulo `power_of_two`
|
// Round `val` up to its modulo `power_of_two`
|
||||||
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
|
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -584,36 +591,60 @@ fn format_rpc_ret<'ctx>(
|
||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Setup types
|
||||||
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
||||||
|
|
||||||
// Allocate the resulting ndarray
|
// Allocate the resulting ndarray
|
||||||
// A condition after format_rpc_ret ensures this will not be popped this off.
|
// A condition after format_rpc_ret ensures this will not be popped this off.
|
||||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result"));
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims);
|
|
||||||
|
|
||||||
// NOTE: Current content of `ndarray`:
|
// Setup ndims
|
||||||
// - * `data` - **NOT YET** allocated.
|
let ndims =
|
||||||
// - * `itemsize` - initialized to be size_of(dtype).
|
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
|
||||||
// - * `ndims` - initialized.
|
assert_eq!(values.len(), 1);
|
||||||
// - * `shape` - allocated; has uninitialized values.
|
|
||||||
// - * `strides` - allocated; has uninitialized values.
|
|
||||||
|
|
||||||
let itemsize = ndarray.instance.get(generator, ctx, |f| f.itemsize); // Same as doing a `ctx.get_llvm_type` on `dtype` and get its `size_of()`.
|
u64::try_from(values[0].clone()).unwrap()
|
||||||
let dtype_llvm = ctx.get_llvm_type(generator, dtype);
|
} else {
|
||||||
|
unreachable!();
|
||||||
|
};
|
||||||
|
// Set `ndarray.ndims`
|
||||||
|
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
|
||||||
|
// Allocate `ndarray.shape` [size_t; ndims]
|
||||||
|
ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx));
|
||||||
|
|
||||||
|
/*
|
||||||
|
ndarray now:
|
||||||
|
- .ndims: initialized
|
||||||
|
- .shape: allocated but uninitialized .shape
|
||||||
|
- .data: uninitialized
|
||||||
|
*/
|
||||||
|
|
||||||
|
let llvm_usize_sizeof = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
|
||||||
|
.unwrap();
|
||||||
|
let llvm_pdata_sizeof = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_truncate_or_bit_cast(
|
||||||
|
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
|
||||||
|
llvm_usize,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let llvm_elem_sizeof = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
|
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
|
||||||
// (4 + 4 * ndims) bytes with 8-byte alignment
|
// (4 + 4 * ndims) bytes with 8-byte alignment
|
||||||
let sizeof_size_t = Int(SizeT).size_of(generator, ctx.ctx);
|
let sizeof_dims =
|
||||||
let sizeof_size_t = Int(SizeT).z_extend_or_truncate(generator, ctx, sizeof_size_t); // sizeof(size_t)
|
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
||||||
|
let unaligned_buffer_size =
|
||||||
let sizeof_ptr = Ptr(Int(Byte)).size_of(generator, ctx.ctx);
|
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
|
||||||
let sizeof_ptr = Int(SizeT).z_extend_or_truncate(generator, ctx, sizeof_ptr); // sizeof(uint8_t*)
|
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
|
||||||
|
|
||||||
let sizeof_shape = ndarray.ndims_llvm(generator, ctx.ctx).mul(ctx, sizeof_size_t); // sizeof([size_t; ndims]); same as the # of bytes of `ndarray.shape`.
|
|
||||||
|
|
||||||
// Size of the buffer for the initial `rpc_recv()`.
|
|
||||||
let unaligned_buffer_size = sizeof_ptr.add(ctx, sizeof_shape); // sizeof(uint8_t*) + sizeof([size_t; ndims]).
|
|
||||||
let buffer_size = round_up(ctx, unaligned_buffer_size.value, num_8.value);
|
|
||||||
let buffer_size = unsafe { Int(SizeT).believe_value(buffer_size) };
|
|
||||||
|
|
||||||
let stackptr = call_stacksave(ctx, None);
|
let stackptr = call_stacksave(ctx, None);
|
||||||
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
|
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
|
||||||
|
@ -621,16 +652,18 @@ fn format_rpc_ret<'ctx>(
|
||||||
.builder
|
.builder
|
||||||
.build_array_alloca(
|
.build_array_alloca(
|
||||||
llvm_i8_8,
|
llvm_i8_8,
|
||||||
ctx.builder.build_int_unsigned_div(buffer_size.value, num_8.value, "").unwrap(),
|
ctx.builder
|
||||||
|
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
|
||||||
|
.unwrap(),
|
||||||
"rpc.buffer",
|
"rpc.buffer",
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let buffer = ctx
|
let buffer = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bitcast(buffer, llvm_pi8, "")
|
.build_bit_cast(buffer, llvm_pi8, "")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let buffer = unsafe { Ptr(Int(Byte)).believe_value(buffer) };
|
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
|
||||||
|
|
||||||
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
|
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
|
||||||
//
|
//
|
||||||
|
@ -638,20 +671,24 @@ fn format_rpc_ret<'ctx>(
|
||||||
let ndarray_nbytes = ctx
|
let ndarray_nbytes = ctx
|
||||||
.build_call_or_invoke(
|
.build_call_or_invoke(
|
||||||
rpc_recv,
|
rpc_recv,
|
||||||
&[buffer.value.into()], // Reads [usize; ndims]
|
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims].
|
||||||
"rpc.size.next",
|
"rpc.size.next",
|
||||||
)
|
)
|
||||||
.map(BasicValueEnum::into_int_value)
|
.map(BasicValueEnum::into_int_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let ndarray_nbytes = unsafe { Int(SizeT).believe_value(ndarray_nbytes) };
|
|
||||||
|
|
||||||
// debug_assert(ndarray_nbytes > 0)
|
// debug_assert(ndarray_nbytes > 0)
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
let cmp = ndarray_nbytes.compare(ctx, IntPredicate::UGT, num_0);
|
|
||||||
|
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
cmp.value,
|
ctx.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::UGT,
|
||||||
|
ndarray_nbytes,
|
||||||
|
ndarray_nbytes.get_type().const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
"0:AssertionError",
|
"0:AssertionError",
|
||||||
"Unexpected RPC termination for ndarray - Expected data buffer next",
|
"Unexpected RPC termination for ndarray - Expected data buffer next",
|
||||||
[None, None, None],
|
[None, None, None],
|
||||||
|
@ -660,39 +697,49 @@ fn format_rpc_ret<'ctx>(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy shape from the buffer to `ndarray.shape`.
|
// Copy shape from the buffer to `ndarray.shape`.
|
||||||
// We need to skip the first `sizeof(uint8_t*)` bytes to skip the `pdata` in `[pdata, shape]`.
|
let pbuffer_dims =
|
||||||
let pbuffer_shape = buffer.offset(ctx, sizeof_ptr.value);
|
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
||||||
let pbuffer_shape = pbuffer_shape.pointer_cast(generator, ctx, Int(SizeT));
|
|
||||||
|
|
||||||
// Copy shape from buffer to `ndarray.shape`
|
|
||||||
ndarray.copy_shape_from_array(generator, ctx, pbuffer_shape);
|
|
||||||
|
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
ndarray.shape().base_ptr(ctx, generator),
|
||||||
|
pbuffer_dims,
|
||||||
|
sizeof_dims,
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
// Restore stack from before allocation of buffer
|
// Restore stack from before allocation of buffer
|
||||||
call_stackrestore(ctx, stackptr);
|
call_stackrestore(ctx, stackptr);
|
||||||
|
|
||||||
// Allocate `ndarray.data`.
|
// Allocate `ndarray.data`.
|
||||||
// `ndarray.shape` must be initialized beforehand in this implementation
|
// `ndarray.shape` must be initialized beforehand in this implementation
|
||||||
// (for ndarray.create_data() to know how many elements to allocate)
|
// (for ndarray.create_data() to know how many elements to allocate)
|
||||||
ndarray.create_data(generator, ctx); // NOTE: the strides of `ndarray` has also been set to contiguous in `::create_data()`.
|
let num_elements =
|
||||||
|
call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None));
|
||||||
|
|
||||||
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
|
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
let num_elements = ndarray.size(generator, ctx);
|
let sizeof_data =
|
||||||
|
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
|
||||||
let expected_ndarray_nbytes = num_elements.mul(ctx, itemsize);
|
|
||||||
let cmp = expected_ndarray_nbytes.compare(ctx, IntPredicate::UGE, ndarray_nbytes);
|
|
||||||
|
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
cmp.value,
|
ctx.builder.build_int_compare(IntPredicate::UGE,
|
||||||
|
sizeof_data,
|
||||||
|
ndarray_nbytes,
|
||||||
|
"",
|
||||||
|
).unwrap(),
|
||||||
"0:AssertionError",
|
"0:AssertionError",
|
||||||
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
|
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
|
||||||
[Some(expected_ndarray_nbytes.value), Some(ndarray_nbytes.value), None],
|
[Some(sizeof_data), Some(ndarray_nbytes), None],
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndarray_data = ndarray.instance.get(generator, ctx, |f| f.data);
|
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
|
||||||
|
|
||||||
|
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
|
||||||
|
let ndarray_data_i8 =
|
||||||
|
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
|
||||||
|
|
||||||
// NOTE: Currently on `prehead_bb`
|
// NOTE: Currently on `prehead_bb`
|
||||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||||
|
@ -701,7 +748,7 @@ fn format_rpc_ret<'ctx>(
|
||||||
ctx.builder.position_at_end(head_bb);
|
ctx.builder.position_at_end(head_bb);
|
||||||
|
|
||||||
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
||||||
phi.add_incoming(&[(&ndarray_data.value, prehead_bb)]);
|
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
|
||||||
|
|
||||||
let alloc_size = ctx
|
let alloc_size = ctx
|
||||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||||
|
@ -716,12 +763,12 @@ fn format_rpc_ret<'ctx>(
|
||||||
|
|
||||||
ctx.builder.position_at_end(alloc_bb);
|
ctx.builder.position_at_end(alloc_bb);
|
||||||
// Align the allocation to sizeof(T)
|
// Align the allocation to sizeof(T)
|
||||||
let alloc_size = round_up(ctx, alloc_size, itemsize.value);
|
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
|
||||||
let alloc_ptr = ctx
|
let alloc_ptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_array_alloca(
|
.build_array_alloca(
|
||||||
dtype_llvm,
|
llvm_elem_ty,
|
||||||
ctx.builder.build_int_unsigned_div(alloc_size, itemsize.value, "").unwrap(),
|
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
|
||||||
"rpc.alloc",
|
"rpc.alloc",
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -731,12 +778,12 @@ fn format_rpc_ret<'ctx>(
|
||||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(tail_bb);
|
ctx.builder.position_at_end(tail_bb);
|
||||||
ndarray.instance.value.as_basic_value_enum()
|
ndarray.as_base_value().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => {
|
_ => {
|
||||||
let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
|
let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
|
||||||
let slotgen = ctx.builder.build_bitcast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
|
let slotgen = ctx.builder.build_bit_cast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
|
||||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||||
ctx.builder.position_at_end(head_bb);
|
ctx.builder.position_at_end(head_bb);
|
||||||
|
|
||||||
|
@ -757,7 +804,7 @@ fn format_rpc_ret<'ctx>(
|
||||||
let alloc_ptr =
|
let alloc_ptr =
|
||||||
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
||||||
let alloc_ptr =
|
let alloc_ptr =
|
||||||
ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
ctx.builder.build_bit_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
||||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||||
|
|
||||||
|
@ -775,6 +822,7 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
|
is_async: bool,
|
||||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||||
let int8 = ctx.ctx.i8_type();
|
let int8 = ctx.ctx.i8_type();
|
||||||
let int32 = ctx.ctx.i32_type();
|
let int32 = ctx.ctx.i32_type();
|
||||||
|
@ -883,6 +931,29 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||||
}
|
}
|
||||||
|
|
||||||
// call
|
// call
|
||||||
|
if is_async {
|
||||||
|
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
|
||||||
|
ctx.module.add_function(
|
||||||
|
"rpc_send_async",
|
||||||
|
ctx.ctx.void_type().fn_type(
|
||||||
|
&[
|
||||||
|
int32.into(),
|
||||||
|
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||||
|
ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||||
|
],
|
||||||
|
false,
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
rpc_send_async,
|
||||||
|
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
|
||||||
|
"rpc.send",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
} else {
|
||||||
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
|
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
|
||||||
ctx.module.add_function(
|
ctx.module.add_function(
|
||||||
"rpc_send",
|
"rpc_send",
|
||||||
|
@ -900,10 +971,15 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
|
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
// reclaim stack space used by arguments
|
// reclaim stack space used by arguments
|
||||||
call_stackrestore(ctx, stackptr);
|
call_stackrestore(ctx, stackptr);
|
||||||
|
|
||||||
|
if is_async {
|
||||||
|
// async RPCs do not return any values
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
||||||
|
|
||||||
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
||||||
|
@ -913,12 +989,14 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn attributes_writeback(
|
pub fn attributes_writeback<'ctx>(
|
||||||
ctx: &mut CodeGenContext<'_, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
inner_resolver: &InnerResolver,
|
inner_resolver: &InnerResolver,
|
||||||
host_attributes: &PyObject,
|
host_attributes: &PyObject,
|
||||||
|
return_obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
Python::with_gil(|py| -> PyResult<Result<(), String>> {
|
Python::with_gil(|py| -> PyResult<Result<(), String>> {
|
||||||
let host_attributes: &PyList = host_attributes.downcast(py)?;
|
let host_attributes: &PyList = host_attributes.downcast(py)?;
|
||||||
|
@ -928,6 +1006,11 @@ pub fn attributes_writeback(
|
||||||
let zero = int32.const_zero();
|
let zero = int32.const_zero();
|
||||||
let mut values = Vec::new();
|
let mut values = Vec::new();
|
||||||
let mut scratch_buffer = Vec::new();
|
let mut scratch_buffer = Vec::new();
|
||||||
|
|
||||||
|
if let Some((ty, obj)) = return_obj {
|
||||||
|
values.push((ty, obj.to_basic_value_enum(ctx, generator, ty).unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
for val in (*globals).values() {
|
for val in (*globals).values() {
|
||||||
let val = val.as_ref(py);
|
let val = val.as_ref(py);
|
||||||
let ty = inner_resolver.get_obj_type(
|
let ty = inner_resolver.get_obj_type(
|
||||||
|
@ -1006,7 +1089,7 @@ pub fn attributes_writeback(
|
||||||
let args: Vec<_> =
|
let args: Vec<_> =
|
||||||
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
||||||
if let Err(e) =
|
if let Err(e) =
|
||||||
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator)
|
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator, true)
|
||||||
{
|
{
|
||||||
return Ok(Err(e));
|
return Ok(Err(e));
|
||||||
}
|
}
|
||||||
|
@ -1016,9 +1099,9 @@ pub fn attributes_writeback(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rpc_codegen_callback() -> Arc<GenCall> {
|
pub fn rpc_codegen_callback(is_async: bool) -> Arc<GenCall> {
|
||||||
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| {
|
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
rpc_codegen_callback_fn(ctx, obj, fun, args, generator)
|
rpc_codegen_callback_fn(ctx, obj, fun, args, generator, is_async)
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1232,7 +1315,8 @@ fn polymorphic_print<'ctx>(
|
||||||
fmt.push('[');
|
fmt.push('[');
|
||||||
flush(ctx, generator, &mut fmt, &mut args);
|
flush(ctx, generator, &mut fmt, &mut args);
|
||||||
|
|
||||||
let val = ListValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None);
|
let val =
|
||||||
|
ListValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None);
|
||||||
let len = val.load_size(ctx, None);
|
let len = val.load_size(ctx, None);
|
||||||
let last =
|
let last =
|
||||||
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
||||||
|
@ -1283,46 +1367,62 @@ fn polymorphic_print<'ctx>(
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||||
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
fmt.push_str("array([");
|
fmt.push_str("array([");
|
||||||
flush(ctx, generator, &mut fmt, &mut args);
|
flush(ctx, generator, &mut fmt, &mut args);
|
||||||
|
|
||||||
let ndarray = AnyObject { ty, value };
|
let val = NDArrayValue::from_pointer_value(
|
||||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
value.into_pointer_value(),
|
||||||
|
llvm_elem_ty,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None));
|
||||||
|
let last =
|
||||||
|
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
||||||
|
|
||||||
let num_0 = Int(SizeT).const_0(generator, ctx.ctx);
|
gen_for_callback_incrementing(
|
||||||
|
|
||||||
// Print `ndarray` as a flat list delimited by interspersed with ", \0"
|
|
||||||
ndarray.foreach(generator, ctx, |generator, ctx, _, hdl| {
|
|
||||||
let i = hdl.get_index(generator, ctx);
|
|
||||||
let scalar = hdl.get_scalar(generator, ctx);
|
|
||||||
|
|
||||||
// if (i != 0) { puts(", "); }
|
|
||||||
gen_if_callback(
|
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|_, ctx| {
|
None,
|
||||||
let not_first = i.compare(ctx, IntPredicate::NE, num_0);
|
llvm_usize.const_zero(),
|
||||||
Ok(not_first.value)
|
(len, false),
|
||||||
},
|
|generator, ctx, _, i| {
|
||||||
|generator, ctx| {
|
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
|
||||||
printf(ctx, generator, ", \0".into(), Vec::default());
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
|_, _| Ok(()),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Print element
|
|
||||||
polymorphic_print(
|
polymorphic_print(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&[(scalar.ty, scalar.value.into())],
|
&[(elem_ty, elem.into())],
|
||||||
"",
|
"",
|
||||||
None,
|
None,
|
||||||
true,
|
true,
|
||||||
as_rtio,
|
as_rtio,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
gen_if_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::ULT, i, last, "")
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|generator, ctx| {
|
||||||
|
printf(ctx, generator, ", \0".into(), Vec::default());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
})?;
|
},
|
||||||
|
|_, _| Ok(()),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
fmt.push_str(")]");
|
fmt.push_str(")]");
|
||||||
flush(ctx, generator, &mut fmt, &mut args);
|
flush(ctx, generator, &mut fmt, &mut args);
|
||||||
|
@ -1332,7 +1432,7 @@ fn polymorphic_print<'ctx>(
|
||||||
fmt.push_str("range(");
|
fmt.push_str("range(");
|
||||||
flush(ctx, generator, &mut fmt, &mut args);
|
flush(ctx, generator, &mut fmt, &mut args);
|
||||||
|
|
||||||
let val = RangeValue::from_ptr_val(value.into_pointer_value(), None);
|
let val = RangeValue::from_pointer_value(value.into_pointer_value(), None);
|
||||||
|
|
||||||
let (start, stop, step) = destructure_range(ctx, val);
|
let (start, stop, step) = destructure_range(ctx, val);
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,4 @@
|
||||||
#![deny(
|
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
|
||||||
future_incompatible,
|
|
||||||
let_underscore,
|
|
||||||
nonstandard_style,
|
|
||||||
rust_2024_compatibility,
|
|
||||||
clippy::all
|
|
||||||
)]
|
|
||||||
#![warn(clippy::pedantic)]
|
#![warn(clippy::pedantic)]
|
||||||
#![allow(
|
#![allow(
|
||||||
unsafe_op_in_unsafe_fn,
|
unsafe_op_in_unsafe_fn,
|
||||||
|
@ -16,64 +10,65 @@
|
||||||
clippy::wildcard_imports
|
clippy::wildcard_imports
|
||||||
)]
|
)]
|
||||||
|
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::{
|
||||||
use std::fs;
|
collections::{HashMap, HashSet},
|
||||||
use std::io::Write;
|
fs,
|
||||||
use std::process::Command;
|
io::Write,
|
||||||
use std::rc::Rc;
|
process::Command,
|
||||||
use std::sync::Arc;
|
rc::Rc,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
use inkwell::{
|
use itertools::Itertools;
|
||||||
|
use parking_lot::{Mutex, RwLock};
|
||||||
|
use pyo3::{
|
||||||
|
create_exception, exceptions,
|
||||||
|
prelude::*,
|
||||||
|
types::{PyBytes, PyDict, PyNone, PySet},
|
||||||
|
};
|
||||||
|
use tempfile::{self, TempDir};
|
||||||
|
|
||||||
|
use nac3core::{
|
||||||
|
codegen::{
|
||||||
|
concrete_type::ConcreteTypeStore, gen_func_impl, irrt::load_irrt, CodeGenLLVMOptions,
|
||||||
|
CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, WithCall, WorkerRegistry,
|
||||||
|
},
|
||||||
|
inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
memory_buffer::MemoryBuffer,
|
memory_buffer::MemoryBuffer,
|
||||||
module::{Linkage, Module},
|
module::{FlagBehavior, Linkage, Module},
|
||||||
passes::PassBuilderOptions,
|
passes::PassBuilderOptions,
|
||||||
support::is_multithreaded,
|
support::is_multithreaded,
|
||||||
targets::*,
|
targets::*,
|
||||||
OptimizationLevel,
|
OptimizationLevel,
|
||||||
};
|
},
|
||||||
use itertools::Itertools;
|
nac3parser::{
|
||||||
use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
|
ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
|
||||||
use nac3core::toplevel::builtins::get_exn_constructor;
|
|
||||||
use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap};
|
|
||||||
use nac3parser::{
|
|
||||||
ast::{ExprKind, Stmt, StmtKind, StrRef},
|
|
||||||
parser::parse_program,
|
parser::parse_program,
|
||||||
};
|
},
|
||||||
use pyo3::create_exception;
|
|
||||||
use pyo3::prelude::*;
|
|
||||||
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
|
|
||||||
|
|
||||||
use parking_lot::{Mutex, RwLock};
|
|
||||||
|
|
||||||
use nac3core::{
|
|
||||||
codegen::irrt::load_irrt,
|
|
||||||
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
|
|
||||||
symbol_resolver::SymbolResolver,
|
symbol_resolver::SymbolResolver,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
|
builtins::get_exn_constructor,
|
||||||
composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
|
composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
|
||||||
DefinitionId, GenCall, TopLevelDef,
|
DefinitionId, GenCall, TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::typedef::{FunSignature, FuncArg},
|
typecheck::{
|
||||||
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
|
type_inferencer::PrimitiveStore,
|
||||||
|
typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use nac3ld::Linker;
|
use nac3ld::Linker;
|
||||||
|
|
||||||
use crate::{
|
use codegen::{
|
||||||
codegen::{
|
|
||||||
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
|
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
|
||||||
},
|
|
||||||
symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
|
|
||||||
};
|
};
|
||||||
use tempfile::{self, TempDir};
|
use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver};
|
||||||
|
use timeline::TimeFns;
|
||||||
|
|
||||||
mod codegen;
|
mod codegen;
|
||||||
mod symbol_resolver;
|
mod symbol_resolver;
|
||||||
mod timeline;
|
mod timeline;
|
||||||
|
|
||||||
use timeline::TimeFns;
|
|
||||||
|
|
||||||
#[derive(PartialEq, Clone, Copy)]
|
#[derive(PartialEq, Clone, Copy)]
|
||||||
enum Isa {
|
enum Isa {
|
||||||
Host,
|
Host,
|
||||||
|
@ -147,14 +142,32 @@ impl Nac3 {
|
||||||
module: &PyObject,
|
module: &PyObject,
|
||||||
registered_class_ids: &HashSet<u64>,
|
registered_class_ids: &HashSet<u64>,
|
||||||
) -> PyResult<()> {
|
) -> PyResult<()> {
|
||||||
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
|
let (module_name, source_file, source) =
|
||||||
|
Python::with_gil(|py| -> PyResult<(String, String, String)> {
|
||||||
let module: &PyAny = module.extract(py)?;
|
let module: &PyAny = module.extract(py)?;
|
||||||
Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?))
|
let source_file = module.getattr("__file__");
|
||||||
|
let (source_file, source) = if let Ok(source_file) = source_file {
|
||||||
|
let source_file = source_file.extract()?;
|
||||||
|
(
|
||||||
|
source_file,
|
||||||
|
fs::read_to_string(source_file).map_err(|e| {
|
||||||
|
exceptions::PyIOError::new_err(format!(
|
||||||
|
"failed to read input file: {e}"
|
||||||
|
))
|
||||||
|
})?,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// kernels submitted by content have no file
|
||||||
|
// but still can provide source by StringLoader
|
||||||
|
let get_src_fn = module
|
||||||
|
.getattr("__loader__")?
|
||||||
|
.extract::<PyObject>()?
|
||||||
|
.getattr(py, "get_source")?;
|
||||||
|
("<expcontent>", get_src_fn.call1(py, (PyNone::get(py),))?.extract(py)?)
|
||||||
|
};
|
||||||
|
Ok((module.getattr("__name__")?.extract()?, source_file.to_string(), source))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let source = fs::read_to_string(&source_file).map_err(|e| {
|
|
||||||
exceptions::PyIOError::new_err(format!("failed to read input file: {e}"))
|
|
||||||
})?;
|
|
||||||
let parser_result = parse_program(&source, source_file.into())
|
let parser_result = parse_program(&source, source_file.into())
|
||||||
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
|
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
|
||||||
|
|
||||||
|
@ -194,10 +207,8 @@ impl Nac3 {
|
||||||
body.retain(|stmt| {
|
body.retain(|stmt| {
|
||||||
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
|
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
|
||||||
decorator_list.iter().any(|decorator| {
|
decorator_list.iter().any(|decorator| {
|
||||||
if let ExprKind::Name { id, .. } = decorator.node {
|
if let Some(id) = decorator_id_string(decorator) {
|
||||||
id.to_string() == "kernel"
|
id == "kernel" || id == "portable" || id == "rpc"
|
||||||
|| id.to_string() == "portable"
|
|
||||||
|| id.to_string() == "rpc"
|
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
@ -210,9 +221,8 @@ impl Nac3 {
|
||||||
}
|
}
|
||||||
StmtKind::FunctionDef { ref decorator_list, .. } => {
|
StmtKind::FunctionDef { ref decorator_list, .. } => {
|
||||||
decorator_list.iter().any(|decorator| {
|
decorator_list.iter().any(|decorator| {
|
||||||
if let ExprKind::Name { id, .. } = decorator.node {
|
if let Some(id) = decorator_id_string(decorator) {
|
||||||
let id = id.to_string();
|
id == "extern" || id == "kernel" || id == "portable" || id == "rpc"
|
||||||
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
|
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
@ -478,9 +488,25 @@ impl Nac3 {
|
||||||
|
|
||||||
match &stmt.node {
|
match &stmt.node {
|
||||||
StmtKind::FunctionDef { decorator_list, .. } => {
|
StmtKind::FunctionDef { decorator_list, .. } => {
|
||||||
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
|
if decorator_list
|
||||||
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap();
|
.iter()
|
||||||
rpc_ids.push((None, def_id));
|
.any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string()))
|
||||||
|
{
|
||||||
|
store_fun
|
||||||
|
.call1(
|
||||||
|
py,
|
||||||
|
(
|
||||||
|
def_id.0.into_py(py),
|
||||||
|
module.getattr(py, name.to_string().as_str()).unwrap(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let is_async = decorator_list.iter().any(|decorator| {
|
||||||
|
decorator_get_flags(decorator)
|
||||||
|
.iter()
|
||||||
|
.any(|constant| *constant == Constant::Str("async".into()))
|
||||||
|
});
|
||||||
|
rpc_ids.push((None, def_id, is_async));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
StmtKind::ClassDef { name, body, .. } => {
|
StmtKind::ClassDef { name, body, .. } => {
|
||||||
|
@ -488,19 +514,26 @@ impl Nac3 {
|
||||||
let class_obj = module.getattr(py, class_name.as_str()).unwrap();
|
let class_obj = module.getattr(py, class_name.as_str()).unwrap();
|
||||||
for stmt in body {
|
for stmt in body {
|
||||||
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
|
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
|
||||||
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
|
if decorator_list.iter().any(|decorator| {
|
||||||
|
decorator_id_string(decorator) == Some("rpc".to_string())
|
||||||
|
}) {
|
||||||
|
let is_async = decorator_list.iter().any(|decorator| {
|
||||||
|
decorator_get_flags(decorator)
|
||||||
|
.iter()
|
||||||
|
.any(|constant| *constant == Constant::Str("async".into()))
|
||||||
|
});
|
||||||
if name == &"__init__".into() {
|
if name == &"__init__".into() {
|
||||||
return Err(CompileError::new_err(format!(
|
return Err(CompileError::new_err(format!(
|
||||||
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
|
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
|
||||||
class_name, stmt.location
|
class_name, stmt.location
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
|
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => ()
|
_ => (),
|
||||||
}
|
}
|
||||||
|
|
||||||
let id = *name_to_pyid.get(&name).unwrap();
|
let id = *name_to_pyid.get(&name).unwrap();
|
||||||
|
@ -556,7 +589,7 @@ impl Nac3 {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Process IRRT
|
// Process IRRT
|
||||||
let context = inkwell::context::Context::create();
|
let context = Context::create();
|
||||||
let irrt = load_irrt(&context, resolver.as_ref());
|
let irrt = load_irrt(&context, resolver.as_ref());
|
||||||
|
|
||||||
let fun_signature =
|
let fun_signature =
|
||||||
|
@ -596,13 +629,12 @@ impl Nac3 {
|
||||||
let top_level = Arc::new(composer.make_top_level_context());
|
let top_level = Arc::new(composer.make_top_level_context());
|
||||||
|
|
||||||
{
|
{
|
||||||
let rpc_codegen = rpc_codegen_callback();
|
|
||||||
let defs = top_level.definitions.read();
|
let defs = top_level.definitions.read();
|
||||||
for (class_data, id) in &rpc_ids {
|
for (class_data, id, is_async) in &rpc_ids {
|
||||||
let mut def = defs[id.0].write();
|
let mut def = defs[id.0].write();
|
||||||
match &mut *def {
|
match &mut *def {
|
||||||
TopLevelDef::Function { codegen_callback, .. } => {
|
TopLevelDef::Function { codegen_callback, .. } => {
|
||||||
*codegen_callback = Some(rpc_codegen.clone());
|
*codegen_callback = Some(rpc_codegen_callback(*is_async));
|
||||||
}
|
}
|
||||||
TopLevelDef::Class { methods, .. } => {
|
TopLevelDef::Class { methods, .. } => {
|
||||||
let (class_def, method_name) = class_data.as_ref().unwrap();
|
let (class_def, method_name) = class_data.as_ref().unwrap();
|
||||||
|
@ -613,7 +645,7 @@ impl Nac3 {
|
||||||
if let TopLevelDef::Function { codegen_callback, .. } =
|
if let TopLevelDef::Function { codegen_callback, .. } =
|
||||||
&mut *defs[id.0].write()
|
&mut *defs[id.0].write()
|
||||||
{
|
{
|
||||||
*codegen_callback = Some(rpc_codegen.clone());
|
*codegen_callback = Some(rpc_codegen_callback(*is_async));
|
||||||
store_fun
|
store_fun
|
||||||
.call1(
|
.call1(
|
||||||
py,
|
py,
|
||||||
|
@ -628,6 +660,11 @@ impl Nac3 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
TopLevelDef::Variable { .. } => {
|
||||||
|
return Err(CompileError::new_err(String::from(
|
||||||
|
"Unsupported @rpc annotation on global variable",
|
||||||
|
)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -648,33 +685,12 @@ impl Nac3 {
|
||||||
let task = CodeGenTask {
|
let task = CodeGenTask {
|
||||||
subst: Vec::default(),
|
subst: Vec::default(),
|
||||||
symbol_name: "__modinit__".to_string(),
|
symbol_name: "__modinit__".to_string(),
|
||||||
body: instance.body,
|
|
||||||
signature,
|
|
||||||
resolver: resolver.clone(),
|
|
||||||
store,
|
|
||||||
unifier_index: instance.unifier_id,
|
|
||||||
calls: instance.calls,
|
|
||||||
id: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut store = ConcreteTypeStore::new();
|
|
||||||
let mut cache = HashMap::new();
|
|
||||||
let signature = store.from_signature(
|
|
||||||
&mut composer.unifier,
|
|
||||||
&self.primitive,
|
|
||||||
&fun_signature,
|
|
||||||
&mut cache,
|
|
||||||
);
|
|
||||||
let signature = store.add_cty(signature);
|
|
||||||
let attributes_writeback_task = CodeGenTask {
|
|
||||||
subst: Vec::default(),
|
|
||||||
symbol_name: "attributes_writeback".to_string(),
|
|
||||||
body: Arc::new(Vec::default()),
|
body: Arc::new(Vec::default()),
|
||||||
signature,
|
signature,
|
||||||
resolver,
|
resolver,
|
||||||
store,
|
store,
|
||||||
unifier_index: instance.unifier_id,
|
unifier_index: instance.unifier_id,
|
||||||
calls: Arc::new(HashMap::default()),
|
calls: instance.calls,
|
||||||
id: 0,
|
id: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -687,7 +703,7 @@ impl Nac3 {
|
||||||
let buffer = buffer.as_slice().into();
|
let buffer = buffer.as_slice().into();
|
||||||
membuffer.lock().push(buffer);
|
membuffer.lock().push(buffer);
|
||||||
})));
|
})));
|
||||||
let size_t = Context::create()
|
let size_t = context
|
||||||
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
|
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
|
||||||
.get_bit_width();
|
.get_bit_width();
|
||||||
let num_threads = if is_multithreaded() { 4 } else { 1 };
|
let num_threads = if is_multithreaded() { 4 } else { 1 };
|
||||||
|
@ -698,19 +714,27 @@ impl Nac3 {
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let membuffer = membuffers.clone();
|
let membuffer = membuffers.clone();
|
||||||
|
let mut has_return = false;
|
||||||
py.allow_threads(|| {
|
py.allow_threads(|| {
|
||||||
let (registry, handles) =
|
let (registry, handles) =
|
||||||
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
|
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
|
||||||
registry.add_task(task);
|
|
||||||
registry.wait_tasks_complete(handles);
|
|
||||||
|
|
||||||
let mut generator =
|
let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns);
|
||||||
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
|
let context = Context::create();
|
||||||
let context = inkwell::context::Context::create();
|
let module = context.create_module("main");
|
||||||
let module = context.create_module("attributes_writeback");
|
|
||||||
let target_machine = self.llvm_options.create_target_machine().unwrap();
|
let target_machine = self.llvm_options.create_target_machine().unwrap();
|
||||||
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
|
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
|
||||||
module.set_triple(&target_machine.get_triple());
|
module.set_triple(&target_machine.get_triple());
|
||||||
|
module.add_basic_value_flag(
|
||||||
|
"Debug Info Version",
|
||||||
|
FlagBehavior::Warning,
|
||||||
|
context.i32_type().const_int(3, false),
|
||||||
|
);
|
||||||
|
module.add_basic_value_flag(
|
||||||
|
"Dwarf Version",
|
||||||
|
FlagBehavior::Warning,
|
||||||
|
context.i32_type().const_int(4, false),
|
||||||
|
);
|
||||||
let builder = context.create_builder();
|
let builder = context.create_builder();
|
||||||
let (_, module, _) = gen_func_impl(
|
let (_, module, _) = gen_func_impl(
|
||||||
&context,
|
&context,
|
||||||
|
@ -718,9 +742,27 @@ impl Nac3 {
|
||||||
®istry,
|
®istry,
|
||||||
builder,
|
builder,
|
||||||
module,
|
module,
|
||||||
attributes_writeback_task,
|
task,
|
||||||
|generator, ctx| {
|
|generator, ctx| {
|
||||||
attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes)
|
assert_eq!(instance.body.len(), 1, "toplevel module should have 1 statement");
|
||||||
|
let StmtKind::Expr { value: ref expr, .. } = instance.body[0].node else {
|
||||||
|
unreachable!("toplevel statement must be an expression")
|
||||||
|
};
|
||||||
|
let ExprKind::Call { .. } = expr.node else {
|
||||||
|
unreachable!("toplevel expression must be a function call")
|
||||||
|
};
|
||||||
|
|
||||||
|
let return_obj =
|
||||||
|
generator.gen_expr(ctx, expr)?.map(|value| (expr.custom.unwrap(), value));
|
||||||
|
has_return = return_obj.is_some();
|
||||||
|
registry.wait_tasks_complete(handles);
|
||||||
|
attributes_writeback(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
inner_resolver.as_ref(),
|
||||||
|
&host_attributes,
|
||||||
|
return_obj,
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -729,35 +771,23 @@ impl Nac3 {
|
||||||
membuffer.lock().push(buffer);
|
membuffer.lock().push(buffer);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
embedding_map.setattr("expects_return", has_return).unwrap();
|
||||||
|
|
||||||
// Link all modules into `main`.
|
// Link all modules into `main`.
|
||||||
let buffers = membuffers.lock();
|
let buffers = membuffers.lock();
|
||||||
let main = context
|
let main = context
|
||||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
.create_module_from_ir(MemoryBuffer::create_from_memory_range(
|
||||||
|
buffers.last().unwrap(),
|
||||||
|
"main",
|
||||||
|
))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
for buffer in buffers.iter().skip(1) {
|
for buffer in buffers.iter().rev().skip(1) {
|
||||||
let other = context
|
let other = context
|
||||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
|
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
|
main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||||
}
|
}
|
||||||
let builder = context.create_builder();
|
|
||||||
let modinit_return = main
|
|
||||||
.get_function("__modinit__")
|
|
||||||
.unwrap()
|
|
||||||
.get_last_basic_block()
|
|
||||||
.unwrap()
|
|
||||||
.get_terminator()
|
|
||||||
.unwrap();
|
|
||||||
builder.position_before(&modinit_return);
|
|
||||||
builder
|
|
||||||
.build_call(
|
|
||||||
main.get_function("attributes_writeback").unwrap(),
|
|
||||||
&[],
|
|
||||||
"attributes_writeback",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
|
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||||
|
|
||||||
let mut function_iter = main.get_first_function();
|
let mut function_iter = main.get_first_function();
|
||||||
|
@ -844,6 +874,41 @@ impl Nac3 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieves the Name.id from a decorator, supports decorators with arguments.
|
||||||
|
fn decorator_id_string(decorator: &Located<ExprKind>) -> Option<String> {
|
||||||
|
if let ExprKind::Name { id, .. } = decorator.node {
|
||||||
|
// Bare decorator
|
||||||
|
return Some(id.to_string());
|
||||||
|
} else if let ExprKind::Call { func, .. } = &decorator.node {
|
||||||
|
// Decorators that are calls (e.g. "@rpc()") have Call for the node,
|
||||||
|
// need to extract the id from within.
|
||||||
|
if let ExprKind::Name { id, .. } = func.node {
|
||||||
|
return Some(id.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieves flags from a decorator, if any.
|
||||||
|
fn decorator_get_flags(decorator: &Located<ExprKind>) -> Vec<Constant> {
|
||||||
|
let mut flags = vec![];
|
||||||
|
if let ExprKind::Call { keywords, .. } = &decorator.node {
|
||||||
|
for keyword in keywords {
|
||||||
|
if keyword.node.arg != Some("flags".into()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let ExprKind::Set { elts } = &keyword.node.value.node {
|
||||||
|
for elt in elts {
|
||||||
|
if let ExprKind::Constant { value, .. } = &elt.node {
|
||||||
|
flags.push(value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flags
|
||||||
|
}
|
||||||
|
|
||||||
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
|
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
|
||||||
let linker_args = vec![
|
let linker_args = vec![
|
||||||
"-shared".to_string(),
|
"-shared".to_string(),
|
||||||
|
@ -1025,7 +1090,12 @@ impl Nac3 {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
|
fn analyze(
|
||||||
|
&mut self,
|
||||||
|
functions: &PySet,
|
||||||
|
classes: &PySet,
|
||||||
|
content_modules: &PySet,
|
||||||
|
) -> PyResult<()> {
|
||||||
let (modules, class_ids) =
|
let (modules, class_ids) =
|
||||||
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
|
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
|
||||||
let mut modules: HashMap<u64, PyObject> = HashMap::new();
|
let mut modules: HashMap<u64, PyObject> = HashMap::new();
|
||||||
|
@ -1035,14 +1105,22 @@ impl Nac3 {
|
||||||
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
|
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
|
||||||
|
|
||||||
for function in functions {
|
for function in functions {
|
||||||
let module = getmodule_fn.call1((function,))?.extract()?;
|
let module: PyObject = getmodule_fn.call1((function,))?.extract()?;
|
||||||
|
if !module.is_none(py) {
|
||||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
for class in classes {
|
for class in classes {
|
||||||
let module = getmodule_fn.call1((class,))?.extract()?;
|
let module: PyObject = getmodule_fn.call1((class,))?.extract()?;
|
||||||
|
if !module.is_none(py) {
|
||||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||||
|
}
|
||||||
class_ids.insert(id_fn.call1((class,))?.extract()?);
|
class_ids.insert(id_fn.call1((class,))?.extract()?);
|
||||||
}
|
}
|
||||||
|
for module in content_modules {
|
||||||
|
let module: PyObject = module.extract()?;
|
||||||
|
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||||
|
}
|
||||||
Ok((modules, class_ids))
|
Ok((modules, class_ids))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,30 @@
|
||||||
use crate::PrimitivePythonId;
|
use std::{
|
||||||
use inkwell::{
|
collections::{HashMap, HashSet},
|
||||||
module::Linkage,
|
sync::{
|
||||||
types::BasicType,
|
atomic::{AtomicBool, Ordering::Relaxed},
|
||||||
values::{BasicValue, BasicValueEnum},
|
Arc,
|
||||||
AddressSpace,
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
use pyo3::{
|
||||||
|
types::{PyDict, PyTuple},
|
||||||
|
PyAny, PyObject, PyResult, Python,
|
||||||
|
};
|
||||||
|
|
||||||
use nac3core::{
|
use nac3core::{
|
||||||
codegen::{
|
codegen::{
|
||||||
model::*,
|
types::{NDArrayType, ProxyType},
|
||||||
object::ndarray::{make_contiguous_strides, NDArray},
|
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
|
inkwell::{
|
||||||
|
module::Linkage,
|
||||||
|
types::{BasicType, BasicTypeEnum},
|
||||||
|
values::BasicValueEnum,
|
||||||
|
AddressSpace,
|
||||||
|
},
|
||||||
|
nac3parser::ast::{self, StrRef},
|
||||||
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PrimDef,
|
helper::PrimDef,
|
||||||
|
@ -23,19 +36,8 @@ use nac3core::{
|
||||||
typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap},
|
typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use nac3parser::ast::{self, StrRef};
|
|
||||||
use parking_lot::RwLock;
|
use super::PrimitivePythonId;
|
||||||
use pyo3::{
|
|
||||||
types::{PyDict, PyTuple},
|
|
||||||
PyAny, PyErr, PyObject, PyResult, Python,
|
|
||||||
};
|
|
||||||
use std::{
|
|
||||||
collections::{HashMap, HashSet},
|
|
||||||
sync::{
|
|
||||||
atomic::{AtomicBool, Ordering::Relaxed},
|
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
pub enum PrimitiveValue {
|
pub enum PrimitiveValue {
|
||||||
I32(i32),
|
I32(i32),
|
||||||
|
@ -1086,12 +1088,15 @@ impl InnerResolver {
|
||||||
let (ndarray_dtype, ndarray_ndims) =
|
let (ndarray_dtype, ndarray_ndims) =
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
||||||
|
|
||||||
let dtype = Any(ctx.get_llvm_type(generator, ndarray_dtype));
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||||
|
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
|
||||||
|
|
||||||
{
|
{
|
||||||
if self.global_value_ids.read().contains_key(&id) {
|
if self.global_value_ids.read().contains_key(&id) {
|
||||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||||
ctx.module.add_global(
|
ctx.module.add_global(
|
||||||
Struct(NDArray).llvm_type(generator, ctx.ctx),
|
ndarray_llvm_ty.as_base_type().get_element_type().into_struct_type(),
|
||||||
Some(AddressSpace::default()),
|
Some(AddressSpace::default()),
|
||||||
&id_str,
|
&id_str,
|
||||||
)
|
)
|
||||||
|
@ -1111,138 +1116,104 @@ impl InnerResolver {
|
||||||
} else {
|
} else {
|
||||||
todo!("Unpacking literal of more than one element unimplemented")
|
todo!("Unpacking literal of more than one element unimplemented")
|
||||||
};
|
};
|
||||||
let Ok(ndims) = u64::try_from(ndarray_ndims) else {
|
let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else {
|
||||||
unreachable!("Expected u64 value for ndarray_ndims")
|
unreachable!("Expected u64 value for ndarray_ndims")
|
||||||
};
|
};
|
||||||
|
|
||||||
// Obtain the shape of the ndarray
|
// Obtain the shape of the ndarray
|
||||||
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
|
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
|
||||||
assert_eq!(shape_tuple.len(), ndims as usize);
|
assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
|
||||||
|
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
|
||||||
// The Rust type inferencer cannot figure this out
|
|
||||||
let shape_values: Result<Vec<Instance<'ctx, Int<SizeT>>>, PyErr> = shape_tuple
|
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, elem)| {
|
.map(|(i, elem)| {
|
||||||
let value = self
|
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err(
|
||||||
.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize())
|
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")),
|
||||||
.map_err(|e| {
|
)
|
||||||
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
|
|
||||||
})?
|
|
||||||
.unwrap();
|
|
||||||
let value = Int(SizeT).check_value(generator, ctx.ctx, value).unwrap();
|
|
||||||
Ok(value)
|
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let shape_values = shape_values?;
|
let shape_values = shape_values?.unwrap();
|
||||||
|
let shape_values = llvm_usize.const_array(
|
||||||
// Also use this opportunity to get the constant values of `shape_values` for calculating strides.
|
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
|
||||||
let shape_u64s = shape_values
|
);
|
||||||
.iter()
|
|
||||||
.map(|dim| {
|
|
||||||
assert!(dim.value.is_const());
|
|
||||||
dim.value.get_zero_extended_constant().unwrap()
|
|
||||||
})
|
|
||||||
.collect_vec();
|
|
||||||
let shape_values = Int(SizeT).const_array(generator, ctx.ctx, &shape_values);
|
|
||||||
|
|
||||||
// create a global for ndarray.shape and initialize it using the shape
|
// create a global for ndarray.shape and initialize it using the shape
|
||||||
let shape_global = ctx.module.add_global(
|
let shape_global = ctx.module.add_global(
|
||||||
Array { len: AnyLen(ndims as u32), item: Int(SizeT) }.llvm_type(generator, ctx.ctx),
|
llvm_usize.array_type(ndarray_ndims as u32),
|
||||||
Some(AddressSpace::default()),
|
Some(AddressSpace::default()),
|
||||||
&(id_str.clone() + ".shape"),
|
&(id_str.clone() + ".shape"),
|
||||||
);
|
);
|
||||||
shape_global.set_initializer(&shape_values.value);
|
shape_global.set_initializer(&shape_values);
|
||||||
|
|
||||||
// Obtain the (flattened) elements of the ndarray
|
// Obtain the (flattened) elements of the ndarray
|
||||||
let sz: usize = obj.getattr("size")?.extract()?;
|
let sz: usize = obj.getattr("size")?.extract()?;
|
||||||
let data_values: Vec<Instance<'ctx, Any>> = (0..sz)
|
let data: Result<Option<Vec<_>>, _> = (0..sz)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
obj.getattr("flat")?.get_item(i).and_then(|elem| {
|
obj.getattr("flat")?.get_item(i).and_then(|elem| {
|
||||||
let value = self
|
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| {
|
||||||
.get_obj_value(py, elem, ctx, generator, ndarray_dtype)
|
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
|
||||||
.map_err(|e| {
|
})
|
||||||
super::CompileError::new_err(format!(
|
})
|
||||||
"Error getting element {i}: {e}"
|
})
|
||||||
))
|
.collect();
|
||||||
})?
|
let data = data?.unwrap().into_iter();
|
||||||
.unwrap();
|
let data = match ndarray_dtype_llvm_ty {
|
||||||
|
BasicTypeEnum::ArrayType(ty) => {
|
||||||
|
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
|
||||||
|
}
|
||||||
|
|
||||||
let value = dtype.check_value(generator, ctx.ctx, value).unwrap();
|
BasicTypeEnum::FloatType(ty) => {
|
||||||
Ok(value)
|
ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec())
|
||||||
})
|
}
|
||||||
})
|
|
||||||
.try_collect()?;
|
BasicTypeEnum::IntType(ty) => {
|
||||||
let data = dtype.const_array(generator, ctx.ctx, &data_values);
|
ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
BasicTypeEnum::PointerType(ty) => {
|
||||||
|
ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
BasicTypeEnum::StructType(ty) => {
|
||||||
|
ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
BasicTypeEnum::VectorType(_) => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
// create a global for ndarray.data and initialize it using the elements
|
// create a global for ndarray.data and initialize it using the elements
|
||||||
//
|
|
||||||
// NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`.
|
|
||||||
// We will have to cast it to an `u8*` later.
|
|
||||||
let data_global = ctx.module.add_global(
|
let data_global = ctx.module.add_global(
|
||||||
Array { len: AnyLen(sz as u32), item: dtype }.llvm_type(generator, ctx.ctx),
|
ndarray_dtype_llvm_ty.array_type(sz as u32),
|
||||||
Some(AddressSpace::default()),
|
Some(AddressSpace::default()),
|
||||||
&(id_str.clone() + ".data"),
|
&(id_str.clone() + ".data"),
|
||||||
);
|
);
|
||||||
data_global.set_initializer(&data.value);
|
data_global.set_initializer(&data);
|
||||||
|
|
||||||
// Get the constant itemsize.
|
|
||||||
let itemsize = dtype.llvm_type(generator, ctx.ctx).size_of().unwrap();
|
|
||||||
let itemsize = itemsize.get_zero_extended_constant().unwrap();
|
|
||||||
|
|
||||||
// Create the strides needed for ndarray.strides
|
|
||||||
let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s);
|
|
||||||
let strides = strides
|
|
||||||
.into_iter()
|
|
||||||
.map(|stride| Int(SizeT).const_int(generator, ctx.ctx, stride, false))
|
|
||||||
.collect_vec();
|
|
||||||
let strides = Int(SizeT).const_array(generator, ctx.ctx, &strides);
|
|
||||||
|
|
||||||
// create a global for ndarray.strides and initialize it
|
|
||||||
let strides_global = ctx.module.add_global(
|
|
||||||
Array { len: AnyLen(ndims as u32), item: Int(Byte) }.llvm_type(generator, ctx.ctx),
|
|
||||||
Some(AddressSpace::default()),
|
|
||||||
&(id_str.clone() + ".strides"),
|
|
||||||
);
|
|
||||||
strides_global.set_initializer(&strides.value);
|
|
||||||
|
|
||||||
// create a global for the ndarray object and initialize it
|
// create a global for the ndarray object and initialize it
|
||||||
// We are also doing [`Model::check_value`] instead of [`Model::believe_value`] to catch bugs.
|
let value = ndarray_llvm_ty
|
||||||
|
.as_base_type()
|
||||||
|
.get_element_type()
|
||||||
|
.into_struct_type()
|
||||||
|
.const_named_struct(&[
|
||||||
|
llvm_usize.const_int(ndarray_ndims, false).into(),
|
||||||
|
shape_global
|
||||||
|
.as_pointer_value()
|
||||||
|
.const_cast(llvm_usize.ptr_type(AddressSpace::default()))
|
||||||
|
.into(),
|
||||||
|
data_global
|
||||||
|
.as_pointer_value()
|
||||||
|
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
|
||||||
|
.into(),
|
||||||
|
]);
|
||||||
|
|
||||||
// NOTE: data_global is an array of dtype, we want a `u8*`.
|
let ndarray = ctx.module.add_global(
|
||||||
let ndarray_data = Ptr(dtype).check_value(generator, ctx.ctx, data_global).unwrap();
|
ndarray_llvm_ty.as_base_type().get_element_type().into_struct_type(),
|
||||||
let ndarray_data = Ptr(Int(Byte)).pointer_cast(generator, ctx, ndarray_data.value);
|
|
||||||
|
|
||||||
let ndarray_itemsize = Int(SizeT).const_int(generator, ctx.ctx, itemsize, false);
|
|
||||||
|
|
||||||
let ndarray_ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims, false);
|
|
||||||
|
|
||||||
let ndarray_shape =
|
|
||||||
Ptr(Int(SizeT)).check_value(generator, ctx.ctx, shape_global).unwrap();
|
|
||||||
|
|
||||||
let ndarray_strides =
|
|
||||||
Ptr(Int(SizeT)).check_value(generator, ctx.ctx, strides_global).unwrap();
|
|
||||||
|
|
||||||
let ndarray = Struct(NDArray).const_struct(
|
|
||||||
generator,
|
|
||||||
ctx.ctx,
|
|
||||||
&[
|
|
||||||
ndarray_data.value.as_basic_value_enum(),
|
|
||||||
ndarray_itemsize.value.as_basic_value_enum(),
|
|
||||||
ndarray_ndims.value.as_basic_value_enum(),
|
|
||||||
ndarray_shape.value.as_basic_value_enum(),
|
|
||||||
ndarray_strides.value.as_basic_value_enum(),
|
|
||||||
],
|
|
||||||
);
|
|
||||||
|
|
||||||
let ndarray_global = ctx.module.add_global(
|
|
||||||
Struct(NDArray).llvm_type(generator, ctx.ctx),
|
|
||||||
Some(AddressSpace::default()),
|
Some(AddressSpace::default()),
|
||||||
&id_str,
|
&id_str,
|
||||||
);
|
);
|
||||||
ndarray_global.set_initializer(&ndarray.value);
|
ndarray.set_initializer(&value);
|
||||||
|
|
||||||
Ok(Some(ndarray_global.as_pointer_value().into()))
|
Ok(Some(ndarray.as_pointer_value().into()))
|
||||||
} else if ty_id == self.primitive_ids.tuple {
|
} else if ty_id == self.primitive_ids.tuple {
|
||||||
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
||||||
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {
|
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {
|
||||||
|
@ -1503,6 +1474,7 @@ impl SymbolResolver for Resolver {
|
||||||
&self,
|
&self,
|
||||||
id: StrRef,
|
id: StrRef,
|
||||||
_: &mut CodeGenContext<'ctx, '_>,
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
_: &mut dyn CodeGenerator,
|
||||||
) -> Option<ValueEnum<'ctx>> {
|
) -> Option<ValueEnum<'ctx>> {
|
||||||
let sym_value = {
|
let sym_value = {
|
||||||
let id_to_val = self.0.id_to_pyval.read();
|
let id_to_val = self.0.id_to_pyval.read();
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
use inkwell::{
|
use itertools::Either;
|
||||||
|
|
||||||
|
use nac3core::{
|
||||||
|
codegen::CodeGenContext,
|
||||||
|
inkwell::{
|
||||||
values::{BasicValueEnum, CallSiteValue},
|
values::{BasicValueEnum, CallSiteValue},
|
||||||
AddressSpace, AtomicOrdering,
|
AddressSpace, AtomicOrdering,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use itertools::Either;
|
|
||||||
use nac3core::codegen::CodeGenContext;
|
|
||||||
|
|
||||||
/// Functions for manipulating the timeline.
|
/// Functions for manipulating the timeline.
|
||||||
pub trait TimeFns {
|
pub trait TimeFns {
|
||||||
|
@ -31,7 +34,7 @@ impl TimeFns for NowPinningTimeFns64 {
|
||||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||||
let now_hiptr = ctx
|
let now_hiptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -80,7 +83,7 @@ impl TimeFns for NowPinningTimeFns64 {
|
||||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||||
let now_hiptr = ctx
|
let now_hiptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -109,7 +112,7 @@ impl TimeFns for NowPinningTimeFns64 {
|
||||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||||
let now_hiptr = ctx
|
let now_hiptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -207,7 +210,7 @@ impl TimeFns for NowPinningTimeFns {
|
||||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||||
let now_hiptr = ctx
|
let now_hiptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -258,7 +261,7 @@ impl TimeFns for NowPinningTimeFns {
|
||||||
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
|
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
|
||||||
let now_hiptr = ctx
|
let now_hiptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,6 @@ constant-optimization = ["fold"]
|
||||||
fold = []
|
fold = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
lazy_static = "1.5"
|
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
string-interner = "0.17"
|
string-interner = "0.17"
|
||||||
fxhash = "0.2"
|
fxhash = "0.2"
|
||||||
|
|
|
@ -5,14 +5,12 @@ pub use crate::location::Location;
|
||||||
|
|
||||||
use fxhash::FxBuildHasher;
|
use fxhash::FxBuildHasher;
|
||||||
use parking_lot::{Mutex, MutexGuard};
|
use parking_lot::{Mutex, MutexGuard};
|
||||||
use std::{cell::RefCell, collections::HashMap, fmt};
|
use std::{cell::RefCell, collections::HashMap, fmt, sync::LazyLock};
|
||||||
use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner};
|
use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner};
|
||||||
|
|
||||||
pub type Interner = StringInterner<DefaultBackend, FxBuildHasher>;
|
pub type Interner = StringInterner<DefaultBackend, FxBuildHasher>;
|
||||||
lazy_static! {
|
static INTERNER: LazyLock<Mutex<Interner>> =
|
||||||
static ref INTERNER: Mutex<Interner> =
|
LazyLock::new(|| Mutex::new(StringInterner::with_hasher(FxBuildHasher::default())));
|
||||||
Mutex::new(StringInterner::with_hasher(FxBuildHasher::default()));
|
|
||||||
}
|
|
||||||
|
|
||||||
thread_local! {
|
thread_local! {
|
||||||
static LOCAL_INTERNER: RefCell<HashMap<String, StrRef>> = RefCell::default();
|
static LOCAL_INTERNER: RefCell<HashMap<String, StrRef>> = RefCell::default();
|
||||||
|
|
|
@ -1,10 +1,4 @@
|
||||||
#![deny(
|
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
|
||||||
future_incompatible,
|
|
||||||
let_underscore,
|
|
||||||
nonstandard_style,
|
|
||||||
rust_2024_compatibility,
|
|
||||||
clippy::all
|
|
||||||
)]
|
|
||||||
#![warn(clippy::pedantic)]
|
#![warn(clippy::pedantic)]
|
||||||
#![allow(
|
#![allow(
|
||||||
clippy::missing_errors_doc,
|
clippy::missing_errors_doc,
|
||||||
|
@ -14,9 +8,6 @@
|
||||||
clippy::wildcard_imports
|
clippy::wildcard_imports
|
||||||
)]
|
)]
|
||||||
|
|
||||||
#[macro_use]
|
|
||||||
extern crate lazy_static;
|
|
||||||
|
|
||||||
mod ast_gen;
|
mod ast_gen;
|
||||||
mod constant;
|
mod constant;
|
||||||
#[cfg(feature = "fold")]
|
#[cfg(feature = "fold")]
|
||||||
|
|
|
@ -5,22 +5,25 @@ authors = ["M-Labs"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
|
default = ["derive"]
|
||||||
|
derive = ["dep:nac3core_derive"]
|
||||||
no-escape-analysis = []
|
no-escape-analysis = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
itertools = "0.13"
|
itertools = "0.13"
|
||||||
crossbeam = "0.8"
|
crossbeam = "0.8"
|
||||||
indexmap = "2.2"
|
indexmap = "2.6"
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
rayon = "1.8"
|
rayon = "1.10"
|
||||||
|
nac3core_derive = { path = "nac3core_derive", optional = true }
|
||||||
nac3parser = { path = "../nac3parser" }
|
nac3parser = { path = "../nac3parser" }
|
||||||
strum = "0.26"
|
strum = "0.26"
|
||||||
strum_macros = "0.26"
|
strum_macros = "0.26"
|
||||||
|
|
||||||
[dependencies.inkwell]
|
[dependencies.inkwell]
|
||||||
version = "0.4"
|
version = "0.5"
|
||||||
default-features = false
|
default-features = false
|
||||||
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
features = ["llvm14-0-prefer-dynamic", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
test-case = "1.2.0"
|
test-case = "1.2.0"
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use regex::Regex;
|
|
||||||
use std::{
|
use std::{
|
||||||
env,
|
env,
|
||||||
fs::File,
|
fs::File,
|
||||||
|
@ -7,6 +6,8 @@ use std::{
|
||||||
process::{Command, Stdio},
|
process::{Command, Stdio},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use regex::Regex;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let out_dir = env::var("OUT_DIR").unwrap();
|
let out_dir = env::var("OUT_DIR").unwrap();
|
||||||
let out_dir = Path::new(&out_dir);
|
let out_dir = Path::new(&out_dir);
|
||||||
|
@ -55,9 +56,8 @@ fn main() {
|
||||||
let output = Command::new("clang-irrt")
|
let output = Command::new("clang-irrt")
|
||||||
.args(flags)
|
.args(flags)
|
||||||
.output()
|
.output()
|
||||||
.map(|o| {
|
.inspect(|o| {
|
||||||
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
|
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
|
||||||
o
|
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,5 @@
|
||||||
#include "irrt/exception.hpp"
|
#include "irrt/exception.hpp"
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/list.hpp"
|
#include "irrt/list.hpp"
|
||||||
#include "irrt/math.hpp"
|
#include "irrt/math.hpp"
|
||||||
#include "irrt/range.hpp"
|
#include "irrt/ndarray.hpp"
|
||||||
#include "irrt/slice.hpp"
|
#include "irrt/slice.hpp"
|
||||||
#include "irrt/ndarray/basic.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
#include "irrt/ndarray/iter.hpp"
|
|
||||||
#include "irrt/ndarray/indexing.hpp"
|
|
||||||
#include "irrt/ndarray/array.hpp"
|
|
||||||
#include "irrt/ndarray/reshape.hpp"
|
|
||||||
#include "irrt/ndarray/broadcast.hpp"
|
|
||||||
#include "irrt/ndarray/transpose.hpp"
|
|
||||||
#include "irrt/ndarray/matmul.hpp"
|
|
|
@ -4,6 +4,6 @@
|
||||||
|
|
||||||
template<typename SizeT>
|
template<typename SizeT>
|
||||||
struct CSlice {
|
struct CSlice {
|
||||||
uint8_t* base;
|
void* base;
|
||||||
SizeT len;
|
SizeT len;
|
||||||
};
|
};
|
|
@ -6,7 +6,7 @@
|
||||||
/**
|
/**
|
||||||
* @brief The int type of ARTIQ exception IDs.
|
* @brief The int type of ARTIQ exception IDs.
|
||||||
*/
|
*/
|
||||||
typedef int32_t ExceptionId;
|
using ExceptionId = int32_t;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Set of exceptions C++ IRRT can use.
|
* Set of exceptions C++ IRRT can use.
|
||||||
|
@ -55,14 +55,14 @@ void _raise_exception_helper(ExceptionId id,
|
||||||
int64_t param2) {
|
int64_t param2) {
|
||||||
Exception<SizeT> e = {
|
Exception<SizeT> e = {
|
||||||
.id = id,
|
.id = id,
|
||||||
.filename = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(filename)),
|
.filename = {.base = reinterpret_cast<void*>(const_cast<char*>(filename)),
|
||||||
.len = static_cast<int32_t>(__builtin_strlen(filename))},
|
.len = static_cast<SizeT>(__builtin_strlen(filename))},
|
||||||
.line = line,
|
.line = line,
|
||||||
.column = 0,
|
.column = 0,
|
||||||
.function = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(function)),
|
.function = {.base = reinterpret_cast<void*>(const_cast<char*>(function)),
|
||||||
.len = static_cast<int32_t>(__builtin_strlen(function))},
|
.len = static_cast<SizeT>(__builtin_strlen(function))},
|
||||||
.msg = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(msg)),
|
.msg = {.base = reinterpret_cast<void*>(const_cast<char*>(msg)),
|
||||||
.len = static_cast<int32_t>(__builtin_strlen(msg))},
|
.len = static_cast<SizeT>(__builtin_strlen(msg))},
|
||||||
};
|
};
|
||||||
e.params[0] = param0;
|
e.params[0] = param0;
|
||||||
e.params[1] = param1;
|
e.params[1] = param1;
|
||||||
|
@ -70,6 +70,7 @@ void _raise_exception_helper(ExceptionId id,
|
||||||
__nac3_raise(reinterpret_cast<void*>(&e));
|
__nac3_raise(reinterpret_cast<void*>(&e));
|
||||||
__builtin_unreachable();
|
__builtin_unreachable();
|
||||||
}
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Raise an exception with location details (location in the IRRT source files).
|
* @brief Raise an exception with location details (location in the IRRT source files).
|
||||||
|
@ -82,4 +83,3 @@ void _raise_exception_helper(ExceptionId id,
|
||||||
*/
|
*/
|
||||||
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
|
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
|
||||||
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
|
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
|
||||||
} // namespace
|
|
|
@ -1,13 +1,27 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#if __STDC_VERSION__ >= 202000
|
||||||
using int8_t = _BitInt(8);
|
using int8_t = _BitInt(8);
|
||||||
using uint8_t = unsigned _BitInt(8);
|
using uint8_t = unsigned _BitInt(8);
|
||||||
using int32_t = _BitInt(32);
|
using int32_t = _BitInt(32);
|
||||||
using uint32_t = unsigned _BitInt(32);
|
using uint32_t = unsigned _BitInt(32);
|
||||||
using int64_t = _BitInt(64);
|
using int64_t = _BitInt(64);
|
||||||
using uint64_t = unsigned _BitInt(64);
|
using uint64_t = unsigned _BitInt(64);
|
||||||
|
#else
|
||||||
|
|
||||||
|
#pragma clang diagnostic push
|
||||||
|
#pragma clang diagnostic ignored "-Wdeprecated-type"
|
||||||
|
using int8_t = _ExtInt(8);
|
||||||
|
using uint8_t = unsigned _ExtInt(8);
|
||||||
|
using int32_t = _ExtInt(32);
|
||||||
|
using uint32_t = unsigned _ExtInt(32);
|
||||||
|
using int64_t = _ExtInt(64);
|
||||||
|
using uint64_t = unsigned _ExtInt(64);
|
||||||
|
#pragma clang diagnostic pop
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
// NDArray indices are always `uint32_t`.
|
// NDArray indices are always `uint32_t`.
|
||||||
using NDIndexInt = uint32_t;
|
using NDIndex = uint32_t;
|
||||||
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
|
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
|
||||||
using SliceIndex = int32_t;
|
using SliceIndex = int32_t;
|
|
@ -2,21 +2,6 @@
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
#include "irrt/int_types.hpp"
|
||||||
#include "irrt/math_util.hpp"
|
#include "irrt/math_util.hpp"
|
||||||
#include "irrt/slice.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/**
|
|
||||||
* @brief A list in NAC3.
|
|
||||||
*
|
|
||||||
* The `items` field is opaque. You must rely on external contexts to
|
|
||||||
* know how to interpret it.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
struct List {
|
|
||||||
uint8_t* items;
|
|
||||||
SizeT len;
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
// Handle list assignment and dropping part of the list when
|
// Handle list assignment and dropping part of the list when
|
||||||
|
@ -28,12 +13,12 @@ extern "C" {
|
||||||
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
||||||
SliceIndex dest_end,
|
SliceIndex dest_end,
|
||||||
SliceIndex dest_step,
|
SliceIndex dest_step,
|
||||||
uint8_t* dest_arr,
|
void* dest_arr,
|
||||||
SliceIndex dest_arr_len,
|
SliceIndex dest_arr_len,
|
||||||
SliceIndex src_start,
|
SliceIndex src_start,
|
||||||
SliceIndex src_end,
|
SliceIndex src_end,
|
||||||
SliceIndex src_step,
|
SliceIndex src_step,
|
||||||
uint8_t* src_arr,
|
void* src_arr,
|
||||||
SliceIndex src_arr_len,
|
SliceIndex src_arr_len,
|
||||||
const SliceIndex size) {
|
const SliceIndex size) {
|
||||||
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
||||||
|
@ -44,11 +29,13 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
||||||
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
||||||
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
||||||
if (src_len > 0) {
|
if (src_len > 0) {
|
||||||
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
|
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_start * size,
|
||||||
|
static_cast<uint8_t*>(src_arr) + src_start * size, src_len * size);
|
||||||
}
|
}
|
||||||
if (dest_len > 0) {
|
if (dest_len > 0) {
|
||||||
/* dropping */
|
/* dropping */
|
||||||
__builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
|
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + (dest_start + src_len) * size,
|
||||||
|
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
|
||||||
(dest_arr_len - dest_end - 1) * size);
|
(dest_arr_len - dest_end - 1) * size);
|
||||||
}
|
}
|
||||||
/* shrink size */
|
/* shrink size */
|
||||||
|
@ -59,7 +46,7 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
||||||
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|
||||||
|| max(src_start, src_end) < min(dest_start, dest_end));
|
|| max(src_start, src_end) < min(dest_start, dest_end));
|
||||||
if (need_alloca) {
|
if (need_alloca) {
|
||||||
uint8_t* tmp = reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
|
void* tmp = __builtin_alloca(src_arr_len * size);
|
||||||
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
||||||
src_arr = tmp;
|
src_arr = tmp;
|
||||||
}
|
}
|
||||||
|
@ -68,20 +55,24 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
||||||
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
|
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
|
||||||
/* for constant optimization */
|
/* for constant optimization */
|
||||||
if (size == 1) {
|
if (size == 1) {
|
||||||
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
|
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind, static_cast<uint8_t*>(src_arr) + src_ind, 1);
|
||||||
} else if (size == 4) {
|
} else if (size == 4) {
|
||||||
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
|
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 4,
|
||||||
|
static_cast<uint8_t*>(src_arr) + src_ind * 4, 4);
|
||||||
} else if (size == 8) {
|
} else if (size == 8) {
|
||||||
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
|
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 8,
|
||||||
|
static_cast<uint8_t*>(src_arr) + src_ind * 8, 8);
|
||||||
} else {
|
} else {
|
||||||
/* memcpy for var size, cannot overlap after previous alloca */
|
/* memcpy for var size, cannot overlap after previous alloca */
|
||||||
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
|
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
|
||||||
|
static_cast<uint8_t*>(src_arr) + src_ind * size, size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/* only dest_step == 1 can we shrink the dest list. */
|
/* only dest_step == 1 can we shrink the dest list. */
|
||||||
/* size should be ensured prior to calling this function */
|
/* size should be ensured prior to calling this function */
|
||||||
if (dest_step == 1 && dest_end >= dest_start) {
|
if (dest_step == 1 && dest_end >= dest_start) {
|
||||||
__builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
|
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
|
||||||
|
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
|
||||||
(dest_arr_len - dest_end - 1) * size);
|
(dest_arr_len - dest_end - 1) * size);
|
||||||
return dest_arr_len - (dest_end - dest_ind) - 1;
|
return dest_arr_len - (dest_end - dest_ind) - 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,4 +90,4 @@ double __nac3_j0(double x) {
|
||||||
|
|
||||||
return j0(x);
|
return j0(x);
|
||||||
}
|
}
|
||||||
}
|
} // namespace
|
|
@ -0,0 +1,144 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
||||||
|
__builtin_assume(end_idx <= list_len);
|
||||||
|
|
||||||
|
SizeT num_elems = 1;
|
||||||
|
for (SizeT i = begin_idx; i < end_idx; ++i) {
|
||||||
|
SizeT val = list_data[i];
|
||||||
|
__builtin_assume(val > 0);
|
||||||
|
num_elems *= val;
|
||||||
|
}
|
||||||
|
return num_elems;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
|
||||||
|
SizeT stride = 1;
|
||||||
|
for (SizeT dim = 0; dim < num_dims; dim++) {
|
||||||
|
SizeT i = num_dims - dim - 1;
|
||||||
|
__builtin_assume(dims[i] > 0);
|
||||||
|
idxs[i] = (index / stride) % dims[i];
|
||||||
|
stride *= dims[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
|
||||||
|
SizeT idx = 0;
|
||||||
|
SizeT stride = 1;
|
||||||
|
for (SizeT i = 0; i < num_dims; ++i) {
|
||||||
|
SizeT ri = num_dims - i - 1;
|
||||||
|
if (ri < num_indices) {
|
||||||
|
idx += stride * indices[ri];
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_assume(dims[i] > 0);
|
||||||
|
stride *= dims[ri];
|
||||||
|
}
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
|
||||||
|
SizeT lhs_ndims,
|
||||||
|
const SizeT* rhs_dims,
|
||||||
|
SizeT rhs_ndims,
|
||||||
|
SizeT* out_dims) {
|
||||||
|
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < max_ndims; ++i) {
|
||||||
|
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
||||||
|
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
||||||
|
SizeT* out_dim = &out_dims[max_ndims - i - 1];
|
||||||
|
|
||||||
|
if (lhs_dim_sz == nullptr) {
|
||||||
|
*out_dim = *rhs_dim_sz;
|
||||||
|
} else if (rhs_dim_sz == nullptr) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else if (*lhs_dim_sz == 1) {
|
||||||
|
*out_dim = *rhs_dim_sz;
|
||||||
|
} else if (*rhs_dim_sz == 1) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else {
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
|
||||||
|
SizeT src_ndims,
|
||||||
|
const NDIndex* in_idx,
|
||||||
|
NDIndex* out_idx) {
|
||||||
|
for (SizeT i = 0; i < src_ndims; ++i) {
|
||||||
|
SizeT src_i = src_ndims - i - 1;
|
||||||
|
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
|
||||||
|
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t
|
||||||
|
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
|
||||||
|
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
|
||||||
|
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
|
||||||
|
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t
|
||||||
|
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
|
||||||
|
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t
|
||||||
|
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
|
||||||
|
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
|
||||||
|
uint32_t lhs_ndims,
|
||||||
|
const uint32_t* rhs_dims,
|
||||||
|
uint32_t rhs_ndims,
|
||||||
|
uint32_t* out_dims) {
|
||||||
|
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
|
||||||
|
uint64_t lhs_ndims,
|
||||||
|
const uint64_t* rhs_dims,
|
||||||
|
uint64_t rhs_ndims,
|
||||||
|
uint64_t* out_dims) {
|
||||||
|
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
|
||||||
|
uint32_t src_ndims,
|
||||||
|
const NDIndex* in_idx,
|
||||||
|
NDIndex* out_idx) {
|
||||||
|
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
|
||||||
|
uint64_t src_ndims,
|
||||||
|
const NDIndex* in_idx,
|
||||||
|
NDIndex* out_idx) {
|
||||||
|
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||||
|
}
|
||||||
|
} // namespace
|
|
@ -1,134 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/list.hpp"
|
|
||||||
#include "irrt/ndarray/basic.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray {
|
|
||||||
namespace array {
|
|
||||||
/**
|
|
||||||
* @brief In the context of `np.array(<list>)`, deduce the ndarray's shape produced by `<list>` and raise
|
|
||||||
* an exception if there is anything wrong with `<shape>` (e.g., inconsistent dimensions `np.array([[1.0, 2.0],
|
|
||||||
* [3.0]])`)
|
|
||||||
*
|
|
||||||
* If this function finds no issues with `<list>`, the deduced shape is written to `shape`. The caller has the
|
|
||||||
* responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because
|
|
||||||
* of implementation details.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_and_validate_list_shape_helper(SizeT axis, List<SizeT>* list, SizeT ndims, SizeT* shape) {
|
|
||||||
if (shape[axis] == -1) {
|
|
||||||
// Dimension is unspecified. Set it.
|
|
||||||
shape[axis] = list->len;
|
|
||||||
} else {
|
|
||||||
// Dimension is specified. Check.
|
|
||||||
if (shape[axis] != list->len) {
|
|
||||||
// Mismatch, throw an error.
|
|
||||||
// NOTE: NumPy's error message is more complex and needs more PARAMS to display.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
|
||||||
"The requested array has an inhomogenous shape "
|
|
||||||
"after {0} dimension(s).",
|
|
||||||
axis, shape[axis], list->len);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (axis + 1 == ndims) {
|
|
||||||
// `list` has type `list[ItemType]`
|
|
||||||
// Do nothing
|
|
||||||
} else {
|
|
||||||
// `list` has type `list[list[...]]`
|
|
||||||
List<SizeT>** lists = (List<SizeT>**)(list->items);
|
|
||||||
for (SizeT i = 0; i < list->len; i++) {
|
|
||||||
set_and_validate_list_shape_helper<SizeT>(axis + 1, lists[i], ndims, shape);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief See `set_and_validate_list_shape_helper`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_and_validate_list_shape(List<SizeT>* list, SizeT ndims, SizeT* shape) {
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
|
||||||
shape[axis] = -1; // Sentinel to say this dimension is unspecified.
|
|
||||||
}
|
|
||||||
set_and_validate_list_shape_helper<SizeT>(0, list, ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief In the context of `np.array(<list>)`, copied the contents stored in `list` to `ndarray`.
|
|
||||||
*
|
|
||||||
* `list` is assumed to be "legal". (i.e., no inconsistent dimensions)
|
|
||||||
*
|
|
||||||
* # Notes on `ndarray`
|
|
||||||
* The caller is responsible for allocating space for `ndarray`.
|
|
||||||
* Here is what this function expects from `ndarray` when called:
|
|
||||||
* - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values.
|
|
||||||
* - `ndarray->itemsize` has to be initialized.
|
|
||||||
* - `ndarray->ndims` has to be initialized.
|
|
||||||
* - `ndarray->shape` has to be initialized.
|
|
||||||
* - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous.
|
|
||||||
* When this function call ends:
|
|
||||||
* - `ndarray->data` is written with contents from `<list>`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void write_list_to_array_helper(SizeT axis, SizeT* index, List<SizeT>* list, NDArray<SizeT>* ndarray) {
|
|
||||||
debug_assert_eq(SizeT, list->len, ndarray->shape[axis]);
|
|
||||||
if (IRRT_DEBUG_ASSERT_BOOL) {
|
|
||||||
if (!ndarray::basic::is_c_contiguous(ndarray)) {
|
|
||||||
raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1],
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (axis + 1 == ndarray->ndims) {
|
|
||||||
// `list` has type `list[scalar]`
|
|
||||||
// `ndarray` is contiguous, so we can do this, and this is fast.
|
|
||||||
uint8_t* dst = ndarray->data + (ndarray->itemsize * (*index));
|
|
||||||
__builtin_memcpy(dst, list->items, ndarray->itemsize * list->len);
|
|
||||||
*index += list->len;
|
|
||||||
} else {
|
|
||||||
// `list` has type `list[list[...]]`
|
|
||||||
List<SizeT>** lists = (List<SizeT>**)(list->items);
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < list->len; i++) {
|
|
||||||
write_list_to_array_helper<SizeT>(axis + 1, index, lists[i], ndarray);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief See `write_list_to_array_helper`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void write_list_to_array(List<SizeT>* list, NDArray<SizeT>* ndarray) {
|
|
||||||
SizeT index = 0;
|
|
||||||
write_list_to_array_helper<SizeT>((SizeT)0, &index, list, ndarray);
|
|
||||||
}
|
|
||||||
} // namespace array
|
|
||||||
} // namespace ndarray
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::array;
|
|
||||||
|
|
||||||
void __nac3_ndarray_array_set_and_validate_list_shape(List<int32_t>* list, int32_t ndims, int32_t* shape) {
|
|
||||||
set_and_validate_list_shape(list, ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_array_set_and_validate_list_shape64(List<int64_t>* list, int64_t ndims, int64_t* shape) {
|
|
||||||
set_and_validate_list_shape(list, ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_array_write_list_to_array(List<int32_t>* list, NDArray<int32_t>* ndarray) {
|
|
||||||
write_list_to_array(list, ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_array_write_list_to_array64(List<int64_t>* list, NDArray<int64_t>* ndarray) {
|
|
||||||
write_list_to_array(list, ndarray);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,341 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray {
|
|
||||||
namespace basic {
|
|
||||||
/**
|
|
||||||
* @brief Assert that `shape` does not contain negative dimensions.
|
|
||||||
*
|
|
||||||
* @param ndims Number of dimensions in `shape`
|
|
||||||
* @param shape The shape to check on
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
|
||||||
if (shape[axis] < 0) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
|
||||||
"negative dimensions are not allowed; axis {0} "
|
|
||||||
"has dimension {1}",
|
|
||||||
axis, shape[axis], NO_PARAM);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Assert that two shapes are the same in the context of writing output to an ndarray.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void assert_output_shape_same(SizeT ndarray_ndims,
|
|
||||||
const SizeT* ndarray_shape,
|
|
||||||
SizeT output_ndims,
|
|
||||||
const SizeT* output_shape) {
|
|
||||||
if (ndarray_ndims != output_ndims) {
|
|
||||||
// There is no corresponding NumPy error message like this.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
|
|
||||||
output_ndims, ndarray_ndims, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
|
|
||||||
if (ndarray_shape[axis] != output_shape[axis]) {
|
|
||||||
// There is no corresponding NumPy error message like this.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
|
||||||
"Mismatched dimensions on axis {0}, output has "
|
|
||||||
"dimension {1}, but destination ndarray has dimension {2}.",
|
|
||||||
axis, output_shape[axis], ndarray_shape[axis]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the number of elements of an ndarray given its shape.
|
|
||||||
*
|
|
||||||
* @param ndims Number of dimensions in `shape`
|
|
||||||
* @param shape The shape of the ndarray
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
|
||||||
SizeT size = 1;
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++)
|
|
||||||
size *= shape[axis];
|
|
||||||
return size;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
|
|
||||||
*
|
|
||||||
* @param ndims Number of elements in `shape` and `indices`
|
|
||||||
* @param shape The shape of the ndarray
|
|
||||||
* @param indices The returned indices indexing the ndarray with shape `shape`.
|
|
||||||
* @param nth The index of the element of interest.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
SizeT axis = ndims - i - 1;
|
|
||||||
SizeT dim = shape[axis];
|
|
||||||
|
|
||||||
indices[axis] = nth % dim;
|
|
||||||
nth /= dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the number of elements of an `ndarray`
|
|
||||||
*
|
|
||||||
* This function corresponds to `<an_ndarray>.size`
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT size(const NDArray<SizeT>* ndarray) {
|
|
||||||
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return of the number of its content of an `ndarray`.
|
|
||||||
*
|
|
||||||
* This function corresponds to `<an_ndarray>.nbytes`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT nbytes(const NDArray<SizeT>* ndarray) {
|
|
||||||
return size(ndarray) * ndarray->itemsize;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
|
|
||||||
*
|
|
||||||
* This function corresponds to `<an_ndarray>.__len__`.
|
|
||||||
*
|
|
||||||
* @param dst_length The length.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT len(const NDArray<SizeT>* ndarray) {
|
|
||||||
// numpy prohibits `__len__` on unsized objects
|
|
||||||
if (ndarray->ndims == 0) {
|
|
||||||
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
} else {
|
|
||||||
return ndarray->shape[0];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
|
|
||||||
*
|
|
||||||
* You may want to see ndarray's rules for C-contiguity:
|
|
||||||
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
|
|
||||||
// References:
|
|
||||||
// - tinynumpy's implementation:
|
|
||||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
|
|
||||||
// - ndarray's flags["C_CONTIGUOUS"]:
|
|
||||||
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
|
|
||||||
// - ndarray's rules for C-contiguity:
|
|
||||||
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
|
||||||
|
|
||||||
// From
|
|
||||||
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
|
|
||||||
//
|
|
||||||
// The traditional rule is that for an array to be flagged as C contiguous,
|
|
||||||
// the following must hold:
|
|
||||||
//
|
|
||||||
// strides[-1] == itemsize
|
|
||||||
// strides[i] == shape[i+1] * strides[i + 1]
|
|
||||||
// [...]
|
|
||||||
// According to these rules, a 0- or 1-dimensional array is either both
|
|
||||||
// C- and F-contiguous, or neither; and an array with 2+ dimensions
|
|
||||||
// can be C- or F- contiguous, or neither, but not both. Though there
|
|
||||||
// there are exceptions for arrays with zero or one item, in the first
|
|
||||||
// case the check is relaxed up to and including the first dimension
|
|
||||||
// with shape[i] == 0. In the second case `strides == itemsize` will
|
|
||||||
// can be true for all dimensions and both flags are set.
|
|
||||||
|
|
||||||
if (ndarray->ndims == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (SizeT i = 1; i < ndarray->ndims; i++) {
|
|
||||||
SizeT axis_i = ndarray->ndims - i - 1;
|
|
||||||
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the pointer to the element indexed by `indices` along the ndarray's axes.
|
|
||||||
*
|
|
||||||
* This function does no bound check.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
uint8_t* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
|
|
||||||
uint8_t* element = ndarray->data;
|
|
||||||
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
|
|
||||||
element += indices[dim_i] * ndarray->strides[dim_i];
|
|
||||||
return element;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view.
|
|
||||||
*
|
|
||||||
* This function does no bound check.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
uint8_t* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
|
|
||||||
uint8_t* element = ndarray->data;
|
|
||||||
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
|
||||||
SizeT axis = ndarray->ndims - i - 1;
|
|
||||||
SizeT dim = ndarray->shape[axis];
|
|
||||||
element += ndarray->strides[axis] * (nth % dim);
|
|
||||||
nth /= dim;
|
|
||||||
}
|
|
||||||
return element;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous.
|
|
||||||
*
|
|
||||||
* You might want to read https://ajcr.net/stride-guide-part-1/.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
|
|
||||||
SizeT stride_product = 1;
|
|
||||||
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
|
||||||
SizeT axis = ndarray->ndims - i - 1;
|
|
||||||
ndarray->strides[axis] = stride_product * ndarray->itemsize;
|
|
||||||
stride_product *= ndarray->shape[axis];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set an element in `ndarray`.
|
|
||||||
*
|
|
||||||
* @param pelement Pointer to the element in `ndarray` to be set.
|
|
||||||
* @param pvalue Pointer to the value `pelement` will be set to.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement, const uint8_t* pvalue) {
|
|
||||||
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
|
|
||||||
*
|
|
||||||
* Both ndarrays will be viewed in their flatten views when copying the elements.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
|
||||||
// TODO: Make this faster with memcpy when we see a contiguous segment.
|
|
||||||
// TODO: Handle overlapping.
|
|
||||||
|
|
||||||
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < size(src_ndarray); i++) {
|
|
||||||
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
|
|
||||||
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
|
|
||||||
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace basic
|
|
||||||
} // namespace ndarray
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::basic;
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
|
|
||||||
assert_shape_no_negative(ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
|
|
||||||
assert_shape_no_negative(ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
|
|
||||||
const int32_t* ndarray_shape,
|
|
||||||
int32_t output_ndims,
|
|
||||||
const int32_t* output_shape) {
|
|
||||||
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
|
|
||||||
const int64_t* ndarray_shape,
|
|
||||||
int64_t output_ndims,
|
|
||||||
const int64_t* output_shape) {
|
|
||||||
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
|
||||||
return size(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
|
||||||
return size(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
|
||||||
return nbytes(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
|
||||||
return nbytes(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
|
|
||||||
return len(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
|
|
||||||
return len(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
|
|
||||||
return is_c_contiguous(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
|
|
||||||
return is_c_contiguous(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
|
|
||||||
return get_nth_pelement(ndarray, nth);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
|
|
||||||
return get_nth_pelement(ndarray, nth);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
|
|
||||||
return get_pelement_by_indices(ndarray, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
|
|
||||||
return get_pelement_by_indices(ndarray, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
|
||||||
set_strides_by_shape(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
|
||||||
set_strides_by_shape(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
|
||||||
copy_data(src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
|
||||||
copy_data(src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,165 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
#include "irrt/slice.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
template<typename SizeT>
|
|
||||||
struct ShapeEntry {
|
|
||||||
SizeT ndims;
|
|
||||||
SizeT* shape;
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray {
|
|
||||||
namespace broadcast {
|
|
||||||
/**
|
|
||||||
* @brief Return true if `src_shape` can broadcast to `dst_shape`.
|
|
||||||
*
|
|
||||||
* See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) {
|
|
||||||
if (src_ndims > target_ndims) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < src_ndims; i++) {
|
|
||||||
SizeT target_dim = target_shape[target_ndims - i - 1];
|
|
||||||
SizeT src_dim = src_shape[src_ndims - i - 1];
|
|
||||||
if (!(src_dim == 1 || target_dim == src_dim)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Performs `np.broadcast_shapes(<shapes>)`
|
|
||||||
*
|
|
||||||
* @param num_shapes Number of entries in `shapes`
|
|
||||||
* @param shapes The list of shape to do `np.broadcast_shapes` on.
|
|
||||||
* @param dst_ndims The length of `dst_shape`.
|
|
||||||
* `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it.
|
|
||||||
* for this function since they should already know in order to allocate `dst_shape` in the first place.
|
|
||||||
* @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result
|
|
||||||
* of `np.broadcast_shapes` and write it here.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT>* shapes, SizeT dst_ndims, SizeT* dst_shape) {
|
|
||||||
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
|
|
||||||
dst_shape[dst_axis] = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef IRRT_DEBUG_ASSERT
|
|
||||||
SizeT max_ndims_found = 0;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < num_shapes; i++) {
|
|
||||||
ShapeEntry<SizeT> entry = shapes[i];
|
|
||||||
|
|
||||||
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
|
|
||||||
debug_assert(SizeT, entry.ndims <= dst_ndims);
|
|
||||||
|
|
||||||
#ifdef IRRT_DEBUG_ASSERT
|
|
||||||
max_ndims_found = max(max_ndims_found, entry.ndims);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (SizeT j = 0; j < entry.ndims; j++) {
|
|
||||||
SizeT entry_axis = entry.ndims - j - 1;
|
|
||||||
SizeT dst_axis = dst_ndims - j - 1;
|
|
||||||
|
|
||||||
SizeT entry_dim = entry.shape[entry_axis];
|
|
||||||
SizeT dst_dim = dst_shape[dst_axis];
|
|
||||||
|
|
||||||
if (dst_dim == 1) {
|
|
||||||
dst_shape[dst_axis] = entry_dim;
|
|
||||||
} else if (entry_dim == 1 || entry_dim == dst_dim) {
|
|
||||||
// Do nothing
|
|
||||||
} else {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
|
||||||
"shape mismatch: objects cannot be broadcast "
|
|
||||||
"to a single shape.",
|
|
||||||
NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
|
|
||||||
debug_assert_eq(SizeT, max_ndims_found, dst_ndims);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Perform `np.broadcast_to(<ndarray>, <target_shape>)` and appropriate assertions.
|
|
||||||
*
|
|
||||||
* This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
|
|
||||||
* and return the result by modifying `dst_ndarray`.
|
|
||||||
*
|
|
||||||
* # Notes on `dst_ndarray`
|
|
||||||
* The caller is responsible for allocating space for the resulting ndarray.
|
|
||||||
* Here is what this function expects from `dst_ndarray` when called:
|
|
||||||
* - `dst_ndarray->data` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape`
|
|
||||||
* - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape.
|
|
||||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
|
||||||
* When this function call ends:
|
|
||||||
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
|
||||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
|
||||||
* - `dst_ndarray->ndims` is unchanged.
|
|
||||||
* - `dst_ndarray->shape` is unchanged.
|
|
||||||
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void broadcast_to(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
|
||||||
if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
|
|
||||||
src_ndarray->shape)) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM,
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
dst_ndarray->data = src_ndarray->data;
|
|
||||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < dst_ndarray->ndims; i++) {
|
|
||||||
SizeT src_axis = src_ndarray->ndims - i - 1;
|
|
||||||
SizeT dst_axis = dst_ndarray->ndims - i - 1;
|
|
||||||
if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) {
|
|
||||||
// Freeze the steps in-place
|
|
||||||
dst_ndarray->strides[dst_axis] = 0;
|
|
||||||
} else {
|
|
||||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace broadcast
|
|
||||||
} // namespace ndarray
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::broadcast;
|
|
||||||
|
|
||||||
void __nac3_ndarray_broadcast_to(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
|
||||||
broadcast_to(src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_broadcast_to64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
|
||||||
broadcast_to(src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_broadcast_shapes(int32_t num_shapes,
|
|
||||||
const ShapeEntry<int32_t>* shapes,
|
|
||||||
int32_t dst_ndims,
|
|
||||||
int32_t* dst_shape) {
|
|
||||||
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes,
|
|
||||||
const ShapeEntry<int64_t>* shapes,
|
|
||||||
int64_t dst_ndims,
|
|
||||||
int64_t* dst_shape) {
|
|
||||||
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,45 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/**
|
|
||||||
* @brief The NDArray object
|
|
||||||
*
|
|
||||||
* Official numpy implementation:
|
|
||||||
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
struct NDArray {
|
|
||||||
/**
|
|
||||||
* @brief The underlying data this `ndarray` is pointing to.
|
|
||||||
*/
|
|
||||||
uint8_t* data;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The number of bytes of a single element in `data`.
|
|
||||||
*/
|
|
||||||
SizeT itemsize;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The number of dimensions of this shape.
|
|
||||||
*/
|
|
||||||
SizeT ndims;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The NDArray shape, with length equal to `ndims`.
|
|
||||||
*
|
|
||||||
* Note that it may contain 0.
|
|
||||||
*/
|
|
||||||
SizeT* shape;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Array strides, with length equal to `ndims`
|
|
||||||
*
|
|
||||||
* The stride values are in units of bytes, not number of elements.
|
|
||||||
*
|
|
||||||
* Note that `strides` can have negative values or contain 0.
|
|
||||||
*/
|
|
||||||
SizeT* strides;
|
|
||||||
};
|
|
||||||
} // namespace
|
|
|
@ -1,220 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/basic.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
#include "irrt/range.hpp"
|
|
||||||
#include "irrt/slice.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
typedef uint8_t NDIndexType;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A single element index
|
|
||||||
*
|
|
||||||
* `data` points to a `int32_t`.
|
|
||||||
*/
|
|
||||||
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A slice index
|
|
||||||
*
|
|
||||||
* `data` points to a `Slice<int32_t>`.
|
|
||||||
*/
|
|
||||||
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief `np.newaxis` / `None`
|
|
||||||
*
|
|
||||||
* `data` is unused.
|
|
||||||
*/
|
|
||||||
const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief `Ellipsis` / `...`
|
|
||||||
*
|
|
||||||
* `data` is unused.
|
|
||||||
*/
|
|
||||||
const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief An index used in ndarray indexing
|
|
||||||
*
|
|
||||||
* That is:
|
|
||||||
* ```
|
|
||||||
* my_ndarray[::-1, 3, ..., np.newaxis]
|
|
||||||
* ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex.
|
|
||||||
* ```
|
|
||||||
*/
|
|
||||||
struct NDIndex {
|
|
||||||
/**
|
|
||||||
* @brief Enum tag to specify the type of index.
|
|
||||||
*
|
|
||||||
* Please see the comment of each enum constant.
|
|
||||||
*/
|
|
||||||
NDIndexType type;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The accompanying data associated with `type`.
|
|
||||||
*
|
|
||||||
* Please see the comment of each enum constant.
|
|
||||||
*/
|
|
||||||
uint8_t* data;
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray {
|
|
||||||
namespace indexing {
|
|
||||||
/**
|
|
||||||
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
|
|
||||||
*
|
|
||||||
* This function is very similar to performing `dst_ndarray = src_ndarray[indices]` in Python.
|
|
||||||
*
|
|
||||||
* This function also does proper assertions on `indices` to check for out of bounds access and more.
|
|
||||||
*
|
|
||||||
* # Notes on `dst_ndarray`
|
|
||||||
* The caller is responsible for allocating space for the resulting ndarray.
|
|
||||||
* Here is what this function expects from `dst_ndarray` when called:
|
|
||||||
* - `dst_ndarray->data` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after
|
|
||||||
* indexing `src_ndarray` with `indices`.
|
|
||||||
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
|
|
||||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
|
||||||
* When this function call ends:
|
|
||||||
* - `dst_ndarray->data` is set to `src_ndarray->data`.
|
|
||||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`.
|
|
||||||
* - `dst_ndarray->ndims` is unchanged.
|
|
||||||
* - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed.
|
|
||||||
* - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works.
|
|
||||||
*
|
|
||||||
* @param indices indices to index `src_ndarray`, ordered in the same way you would write them in Python.
|
|
||||||
* @param src_ndarray The NDArray to be indexed.
|
|
||||||
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void index(SizeT num_indices, const NDIndex* indices, const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
|
||||||
// Validate `indices`.
|
|
||||||
|
|
||||||
// Expected value of `dst_ndarray->ndims`.
|
|
||||||
SizeT expected_dst_ndims = src_ndarray->ndims;
|
|
||||||
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
|
|
||||||
SizeT num_indexed = 0;
|
|
||||||
// There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis.
|
|
||||||
SizeT num_ellipsis = 0;
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < num_indices; i++) {
|
|
||||||
if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
|
|
||||||
expected_dst_ndims--;
|
|
||||||
num_indexed++;
|
|
||||||
} else if (indices[i].type == ND_INDEX_TYPE_SLICE) {
|
|
||||||
num_indexed++;
|
|
||||||
} else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS) {
|
|
||||||
expected_dst_ndims++;
|
|
||||||
} else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS) {
|
|
||||||
num_ellipsis++;
|
|
||||||
if (num_ellipsis > 1) {
|
|
||||||
raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM,
|
|
||||||
NO_PARAM, NO_PARAM);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims);
|
|
||||||
|
|
||||||
if (src_ndarray->ndims - num_indexed < 0) {
|
|
||||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
|
||||||
"too many indices for array: array is {0}-dimensional, "
|
|
||||||
"but {1} were indexed",
|
|
||||||
src_ndarray->ndims, num_indices, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
dst_ndarray->data = src_ndarray->data;
|
|
||||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
|
||||||
|
|
||||||
// Reference code:
|
|
||||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
|
||||||
SizeT src_axis = 0;
|
|
||||||
SizeT dst_axis = 0;
|
|
||||||
|
|
||||||
for (int32_t i = 0; i < num_indices; i++) {
|
|
||||||
const NDIndex* index = &indices[i];
|
|
||||||
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
|
|
||||||
SizeT input = (SizeT) * ((int32_t*)index->data);
|
|
||||||
|
|
||||||
SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
|
|
||||||
if (k == -1) {
|
|
||||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
|
||||||
"index {0} is out of bounds for axis {1} "
|
|
||||||
"with size {2}",
|
|
||||||
input, src_axis, src_ndarray->shape[src_axis]);
|
|
||||||
}
|
|
||||||
|
|
||||||
dst_ndarray->data += k * src_ndarray->strides[src_axis];
|
|
||||||
|
|
||||||
src_axis++;
|
|
||||||
} else if (index->type == ND_INDEX_TYPE_SLICE) {
|
|
||||||
Slice<int32_t>* slice = (Slice<int32_t>*)index->data;
|
|
||||||
|
|
||||||
Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
|
|
||||||
|
|
||||||
dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis];
|
|
||||||
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];
|
|
||||||
dst_ndarray->shape[dst_axis] = (SizeT)range.len<SizeT>();
|
|
||||||
|
|
||||||
dst_axis++;
|
|
||||||
src_axis++;
|
|
||||||
} else if (index->type == ND_INDEX_TYPE_NEWAXIS) {
|
|
||||||
dst_ndarray->strides[dst_axis] = 0;
|
|
||||||
dst_ndarray->shape[dst_axis] = 1;
|
|
||||||
|
|
||||||
dst_axis++;
|
|
||||||
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
|
|
||||||
// The number of ':' entries this '...' implies.
|
|
||||||
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
|
|
||||||
|
|
||||||
for (SizeT j = 0; j < ellipsis_size; j++) {
|
|
||||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
|
||||||
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
|
|
||||||
|
|
||||||
dst_axis++;
|
|
||||||
src_axis++;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
|
|
||||||
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
|
|
||||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
|
||||||
}
|
|
||||||
|
|
||||||
debug_assert_eq(SizeT, src_ndarray->ndims, src_axis);
|
|
||||||
debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis);
|
|
||||||
}
|
|
||||||
} // namespace indexing
|
|
||||||
} // namespace ndarray
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::indexing;
|
|
||||||
|
|
||||||
void __nac3_ndarray_index(int32_t num_indices,
|
|
||||||
NDIndex* indices,
|
|
||||||
NDArray<int32_t>* src_ndarray,
|
|
||||||
NDArray<int32_t>* dst_ndarray) {
|
|
||||||
index(num_indices, indices, src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_index64(int64_t num_indices,
|
|
||||||
NDIndex* indices,
|
|
||||||
NDArray<int64_t>* src_ndarray,
|
|
||||||
NDArray<int64_t>* dst_ndarray) {
|
|
||||||
index(num_indices, indices, src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,146 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/**
|
|
||||||
* @brief Helper struct to enumerate through an ndarray *efficiently*.
|
|
||||||
*
|
|
||||||
* Example usage (in pseudo-code):
|
|
||||||
* ```
|
|
||||||
* // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double`
|
|
||||||
* NDIter nditer;
|
|
||||||
* nditer.initialize(my_ndarray);
|
|
||||||
* while (nditer.has_element()) {
|
|
||||||
* // This body is run 6 (= my_ndarray.size) times.
|
|
||||||
*
|
|
||||||
* // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end
|
|
||||||
* print(nditer.indices);
|
|
||||||
*
|
|
||||||
* // 0 -> 1 -> 2 -> 3 -> 4 -> 5
|
|
||||||
* print(nditer.nth);
|
|
||||||
*
|
|
||||||
* // <1st element> -> <2nd element> -> ... -> <6th element> -> end
|
|
||||||
* print(*((double *) nditer.element))
|
|
||||||
*
|
|
||||||
* nditer.next(); // Go to next element.
|
|
||||||
* }
|
|
||||||
* ```
|
|
||||||
*
|
|
||||||
* Interesting cases:
|
|
||||||
* - If `my_ndarray.ndims` == 0, there is one iteration.
|
|
||||||
* - If `my_ndarray.shape` contains zeroes, there are no iterations.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
struct NDIter {
|
|
||||||
// Information about the ndarray being iterated over.
|
|
||||||
SizeT ndims;
|
|
||||||
SizeT* shape;
|
|
||||||
SizeT* strides;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The current indices.
|
|
||||||
*
|
|
||||||
* Must be allocated by the caller.
|
|
||||||
*/
|
|
||||||
SizeT* indices;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The nth (0-based) index of the current indices.
|
|
||||||
*
|
|
||||||
* Initially this is 0.
|
|
||||||
*/
|
|
||||||
SizeT nth;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Pointer to the current element.
|
|
||||||
*
|
|
||||||
* Initially this points to first element of the ndarray.
|
|
||||||
*/
|
|
||||||
uint8_t* element;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Cache for the product of shape.
|
|
||||||
*
|
|
||||||
* Could be 0 if `shape` has 0s in it.
|
|
||||||
*/
|
|
||||||
SizeT size;
|
|
||||||
|
|
||||||
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element, SizeT* indices) {
|
|
||||||
this->ndims = ndims;
|
|
||||||
this->shape = shape;
|
|
||||||
this->strides = strides;
|
|
||||||
|
|
||||||
this->indices = indices;
|
|
||||||
this->element = element;
|
|
||||||
|
|
||||||
// Compute size
|
|
||||||
this->size = 1;
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
this->size *= shape[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
// `indices` starts on all 0s.
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++)
|
|
||||||
indices[axis] = 0;
|
|
||||||
nth = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
|
|
||||||
// NOTE: ndarray->data is pointing to the first element, and `NDIter`'s `element` should also point to the first
|
|
||||||
// element as well.
|
|
||||||
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is the current iteration valid?
|
|
||||||
// If true, then `element`, `indices` and `nth` contain details about the current element.
|
|
||||||
bool has_element() { return nth < size; }
|
|
||||||
|
|
||||||
// Go to the next element.
|
|
||||||
void next() {
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
SizeT axis = ndims - i - 1;
|
|
||||||
indices[axis]++;
|
|
||||||
if (indices[axis] >= shape[axis]) {
|
|
||||||
indices[axis] = 0;
|
|
||||||
|
|
||||||
// TODO: There is something called backstrides to speedup iteration.
|
|
||||||
// See https://ajcr.net/stride-guide-part-1/, and
|
|
||||||
// https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
|
|
||||||
element -= strides[axis] * (shape[axis] - 1);
|
|
||||||
} else {
|
|
||||||
element += strides[axis];
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nth++;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray, int32_t* indices) {
|
|
||||||
iter->initialize_by_ndarray(ndarray, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_nditer_initialize64(NDIter<int64_t>* iter, NDArray<int64_t>* ndarray, int64_t* indices) {
|
|
||||||
iter->initialize_by_ndarray(ndarray, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __nac3_nditer_has_element(NDIter<int32_t>* iter) {
|
|
||||||
return iter->has_element();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __nac3_nditer_has_element64(NDIter<int64_t>* iter) {
|
|
||||||
return iter->has_element();
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_nditer_next(NDIter<int32_t>* iter) {
|
|
||||||
iter->next();
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_nditer_next64(NDIter<int64_t>* iter) {
|
|
||||||
iter->next();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,100 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/basic.hpp"
|
|
||||||
#include "irrt/ndarray/broadcast.hpp"
|
|
||||||
#include "irrt/ndarray/iter.hpp"
|
|
||||||
|
|
||||||
// NOTE: Everything would be much easier and elegant if einsum is implemented.
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray {
|
|
||||||
namespace matmul {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`.
|
|
||||||
*
|
|
||||||
* Example:
|
|
||||||
* Suppose `a_shape == [1, 97, 4, 2]`
|
|
||||||
* and `b_shape == [99, 98, 1, 2, 5]`,
|
|
||||||
*
|
|
||||||
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
|
|
||||||
* `new_b_shape == [99, 98, 97, 2, 5]`,
|
|
||||||
* and `dst_shape == [99, 98, 97, 4, 5]`.
|
|
||||||
* ^^^^^^^^^^ ^^^^
|
|
||||||
* (broadcasted) (4x2 @ 2x5 => 4x5)
|
|
||||||
*
|
|
||||||
* @param a_ndims Length of `a_shape`.
|
|
||||||
* @param a_shape Shape of `a`.
|
|
||||||
* @param b_ndims Length of `b_shape`.
|
|
||||||
* @param b_shape Shape of `b`.
|
|
||||||
* @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`,
|
|
||||||
* `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void calculate_shapes(SizeT a_ndims,
|
|
||||||
SizeT* a_shape,
|
|
||||||
SizeT b_ndims,
|
|
||||||
SizeT* b_shape,
|
|
||||||
SizeT final_ndims,
|
|
||||||
SizeT* new_a_shape,
|
|
||||||
SizeT* new_b_shape,
|
|
||||||
SizeT* dst_shape) {
|
|
||||||
debug_assert(SizeT, a_ndims >= 2);
|
|
||||||
debug_assert(SizeT, b_ndims >= 2);
|
|
||||||
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
|
|
||||||
|
|
||||||
// Check that a and b are compatible for matmul
|
|
||||||
if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) {
|
|
||||||
// This is a custom error message. Different from NumPy.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
|
|
||||||
a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
const SizeT num_entries = 2;
|
|
||||||
ShapeEntry<SizeT> entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape},
|
|
||||||
{.ndims = b_ndims - 2, .shape = b_shape}};
|
|
||||||
|
|
||||||
// TODO: Optimize this
|
|
||||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_a_shape);
|
|
||||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_b_shape);
|
|
||||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, dst_shape);
|
|
||||||
|
|
||||||
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
|
||||||
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
|
|
||||||
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
|
|
||||||
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
|
||||||
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
|
||||||
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
|
||||||
}
|
|
||||||
} // namespace matmul
|
|
||||||
} // namespace ndarray
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::matmul;
|
|
||||||
|
|
||||||
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims,
|
|
||||||
int32_t* a_shape,
|
|
||||||
int32_t b_ndims,
|
|
||||||
int32_t* b_shape,
|
|
||||||
int32_t final_ndims,
|
|
||||||
int32_t* new_a_shape,
|
|
||||||
int32_t* new_b_shape,
|
|
||||||
int32_t* dst_shape) {
|
|
||||||
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims,
|
|
||||||
int64_t* a_shape,
|
|
||||||
int64_t b_ndims,
|
|
||||||
int64_t* b_shape,
|
|
||||||
int64_t final_ndims,
|
|
||||||
int64_t* new_a_shape,
|
|
||||||
int64_t* new_b_shape,
|
|
||||||
int64_t* dst_shape) {
|
|
||||||
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,99 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray {
|
|
||||||
namespace reshape {
|
|
||||||
/**
|
|
||||||
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
|
|
||||||
*
|
|
||||||
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
|
|
||||||
* modified to contain the resolved dimension.
|
|
||||||
*
|
|
||||||
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
|
|
||||||
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
|
|
||||||
*
|
|
||||||
* @param size The `.size` of `<ndarray>`
|
|
||||||
* @param new_ndims Number of elements in `new_shape`
|
|
||||||
* @param new_shape Target shape to reshape to
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) {
|
|
||||||
// Is there a -1 in `new_shape`?
|
|
||||||
bool neg1_exists = false;
|
|
||||||
// Location of -1, only initialized if `neg1_exists` is true
|
|
||||||
SizeT neg1_axis_i;
|
|
||||||
// The computed ndarray size of `new_shape`
|
|
||||||
SizeT new_size = 1;
|
|
||||||
|
|
||||||
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
|
|
||||||
SizeT dim = new_shape[axis_i];
|
|
||||||
if (dim < 0) {
|
|
||||||
if (dim == -1) {
|
|
||||||
if (neg1_exists) {
|
|
||||||
// Multiple `-1` found. Throw an error.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM,
|
|
||||||
NO_PARAM, NO_PARAM);
|
|
||||||
} else {
|
|
||||||
neg1_exists = true;
|
|
||||||
neg1_axis_i = axis_i;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// TODO: What? In `np.reshape` any negative dimensions is
|
|
||||||
// treated like its `-1`.
|
|
||||||
//
|
|
||||||
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
|
|
||||||
//
|
|
||||||
// It is not documented by numpy.
|
|
||||||
// Throw an error for now...
|
|
||||||
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i,
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
new_size *= dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool can_reshape;
|
|
||||||
if (neg1_exists) {
|
|
||||||
// Let `x` be the unknown dimension
|
|
||||||
// Solve `x * <new_size> = <size>`
|
|
||||||
if (new_size == 0 && size == 0) {
|
|
||||||
// `x` has infinitely many solutions
|
|
||||||
can_reshape = false;
|
|
||||||
} else if (new_size == 0 && size != 0) {
|
|
||||||
// `x` has no solutions
|
|
||||||
can_reshape = false;
|
|
||||||
} else if (size % new_size != 0) {
|
|
||||||
// `x` has no integer solutions
|
|
||||||
can_reshape = false;
|
|
||||||
} else {
|
|
||||||
can_reshape = true;
|
|
||||||
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
can_reshape = (new_size == size);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!can_reshape) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM,
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace reshape
|
|
||||||
} // namespace ndarray
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) {
|
|
||||||
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) {
|
|
||||||
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,145 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
#include "irrt/slice.hpp"
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Notes on `np.transpose(<array>, <axes>)`
|
|
||||||
*
|
|
||||||
* TODO: `axes`, if specified, can actually contain negative indices,
|
|
||||||
* but it is not documented in numpy.
|
|
||||||
*
|
|
||||||
* Supporting it for now.
|
|
||||||
*/
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray {
|
|
||||||
namespace transpose {
|
|
||||||
/**
|
|
||||||
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
|
|
||||||
*
|
|
||||||
* Note that `np.transpose`'s `<axe>` argument is optional. If the argument
|
|
||||||
* is specified but the user, use this function to do assertions on it.
|
|
||||||
*
|
|
||||||
* @param ndims The number of dimensions of `<array>`
|
|
||||||
* @param num_axes Number of elements in `<axes>` as specified by the user.
|
|
||||||
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
|
|
||||||
* @param axes The user specified `<axes>`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) {
|
|
||||||
if (ndims != num_axes) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Optimize this
|
|
||||||
bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
|
|
||||||
for (SizeT i = 0; i < ndims; i++)
|
|
||||||
axe_specified[i] = false;
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
|
|
||||||
if (axis == -1) {
|
|
||||||
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (axe_specified[axis]) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
axe_specified[axis] = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Create a transpose view of `src_ndarray` and perform proper assertions.
|
|
||||||
*
|
|
||||||
* This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
|
|
||||||
* If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
|
|
||||||
*
|
|
||||||
* The transpose view created is returned by modifying `dst_ndarray`.
|
|
||||||
*
|
|
||||||
* The caller is responsible for setting up `dst_ndarray` before calling this function.
|
|
||||||
* Here is what this function expects from `dst_ndarray` when called:
|
|
||||||
* - `dst_ndarray->data` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
|
|
||||||
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
|
|
||||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
|
||||||
* When this function call ends:
|
|
||||||
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
|
||||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
|
||||||
* - `dst_ndarray->ndims` is unchanged
|
|
||||||
* - `dst_ndarray->shape` is updated according to how `np.transpose` works
|
|
||||||
* - `dst_ndarray->strides` is updated according to how `np.transpose` works
|
|
||||||
*
|
|
||||||
* @param src_ndarray The NDArray to build a transpose view on
|
|
||||||
* @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above,
|
|
||||||
* @param num_axes Number of elements in axes. Unused if `axes` is nullptr.
|
|
||||||
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void transpose(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
|
|
||||||
debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims);
|
|
||||||
const auto ndims = src_ndarray->ndims;
|
|
||||||
|
|
||||||
if (axes != nullptr)
|
|
||||||
assert_transpose_axes(ndims, num_axes, axes);
|
|
||||||
|
|
||||||
dst_ndarray->data = src_ndarray->data;
|
|
||||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
|
||||||
|
|
||||||
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
|
|
||||||
if (axes == nullptr) {
|
|
||||||
// `np.transpose(<array>, axes=None)`
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Minor note: `np.transpose(<array>, axes=None)` is equivalent to
|
|
||||||
* `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
|
|
||||||
* is reversing the order of strides and shape.
|
|
||||||
*
|
|
||||||
* This is a fast implementation to handle this special (but very common) case.
|
|
||||||
*/
|
|
||||||
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
|
||||||
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
|
|
||||||
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// `np.transpose(<array>, <axes>)`
|
|
||||||
|
|
||||||
// Permute strides and shape according to `axes`, while resolving negative indices in `axes`
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
|
||||||
// `i` cannot be OUT_OF_BOUNDS because of assertions
|
|
||||||
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
|
|
||||||
|
|
||||||
dst_ndarray->shape[axis] = src_ndarray->shape[i];
|
|
||||||
dst_ndarray->strides[axis] = src_ndarray->strides[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace transpose
|
|
||||||
} // namespace ndarray
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::transpose;
|
|
||||||
void __nac3_ndarray_transpose(const NDArray<int32_t>* src_ndarray,
|
|
||||||
NDArray<int32_t>* dst_ndarray,
|
|
||||||
int32_t num_axes,
|
|
||||||
const int32_t* axes) {
|
|
||||||
transpose(src_ndarray, dst_ndarray, num_axes, axes);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_transpose64(const NDArray<int64_t>* src_ndarray,
|
|
||||||
NDArray<int64_t>* dst_ndarray,
|
|
||||||
int64_t num_axes,
|
|
||||||
const int64_t* axes) {
|
|
||||||
transpose(src_ndarray, dst_ndarray, num_axes, axes);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,47 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace range {
|
|
||||||
template<typename T>
|
|
||||||
T len(T start, T stop, T step) {
|
|
||||||
// Reference:
|
|
||||||
// https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
|
|
||||||
if (step > 0 && start < stop)
|
|
||||||
return 1 + (stop - 1 - start) / step;
|
|
||||||
else if (step < 0 && start > stop)
|
|
||||||
return 1 + (start - 1 - stop) / (-step);
|
|
||||||
else
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
} // namespace range
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A Python range.
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
struct Range {
|
|
||||||
T start;
|
|
||||||
T stop;
|
|
||||||
T step;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Calculate the `len()` of this range.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
T len() {
|
|
||||||
debug_assert(SizeT, step != 0);
|
|
||||||
return range::len(start, stop, step);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace range;
|
|
||||||
|
|
||||||
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
|
|
||||||
return len(start, end, step);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,145 +1,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
#include "irrt/int_types.hpp"
|
||||||
#include "irrt/math_util.hpp"
|
|
||||||
#include "irrt/range.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace slice {
|
|
||||||
/**
|
|
||||||
* @brief Resolve a possibly negative index in a list of a known length.
|
|
||||||
*
|
|
||||||
* Returns -1 if the resolved index is out of the list's bounds.
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
T resolve_index_in_length(T length, T index) {
|
|
||||||
T resolved = index < 0 ? length + index : index;
|
|
||||||
if (0 <= resolved && resolved < length) {
|
|
||||||
return resolved;
|
|
||||||
} else {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Resolve a slice as a range.
|
|
||||||
*
|
|
||||||
* This is equivalent to `range(*slice(start, stop, step).indices(length))` in Python.
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
void indices(bool start_defined,
|
|
||||||
T start,
|
|
||||||
bool stop_defined,
|
|
||||||
T stop,
|
|
||||||
bool step_defined,
|
|
||||||
T step,
|
|
||||||
T length,
|
|
||||||
T* range_start,
|
|
||||||
T* range_stop,
|
|
||||||
T* range_step) {
|
|
||||||
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
|
|
||||||
*range_step = step_defined ? step : 1;
|
|
||||||
bool step_is_negative = *range_step < 0;
|
|
||||||
|
|
||||||
T lower, upper;
|
|
||||||
if (step_is_negative) {
|
|
||||||
lower = -1;
|
|
||||||
upper = length - 1;
|
|
||||||
} else {
|
|
||||||
lower = 0;
|
|
||||||
upper = length;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (start_defined) {
|
|
||||||
*range_start = start < 0 ? max(lower, start + length) : min(upper, start);
|
|
||||||
} else {
|
|
||||||
*range_start = step_is_negative ? upper : lower;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (stop_defined) {
|
|
||||||
*range_stop = stop < 0 ? max(lower, stop + length) : min(upper, stop);
|
|
||||||
} else {
|
|
||||||
*range_stop = step_is_negative ? lower : upper;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace slice
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A Python-like slice with **unresolved** indices.
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
struct Slice {
|
|
||||||
bool start_defined;
|
|
||||||
T start;
|
|
||||||
|
|
||||||
bool stop_defined;
|
|
||||||
T stop;
|
|
||||||
|
|
||||||
bool step_defined;
|
|
||||||
T step;
|
|
||||||
|
|
||||||
Slice() { this->reset(); }
|
|
||||||
|
|
||||||
void reset() {
|
|
||||||
this->start_defined = false;
|
|
||||||
this->stop_defined = false;
|
|
||||||
this->step_defined = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_start(T start) {
|
|
||||||
this->start_defined = true;
|
|
||||||
this->start = start;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_stop(T stop) {
|
|
||||||
this->stop_defined = true;
|
|
||||||
this->stop = stop;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_step(T step) {
|
|
||||||
this->step_defined = true;
|
|
||||||
this->step = step;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Resolve this slice as a range.
|
|
||||||
*
|
|
||||||
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
Range<T> indices(T length) {
|
|
||||||
// Reference:
|
|
||||||
// https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
|
|
||||||
debug_assert(SizeT, length >= 0);
|
|
||||||
|
|
||||||
Range<T> result;
|
|
||||||
slice::indices(start_defined, start, stop_defined, stop, step_defined, step, length, &result.start,
|
|
||||||
&result.stop, &result.step);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Like `.indices()` but with assertions.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
Range<T> indices_checked(T length) {
|
|
||||||
// TODO: Switch to `SizeT length`
|
|
||||||
|
|
||||||
if (length < 0) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM,
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this->step_defined && this->step == 0) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
return this->indices<SizeT>(length);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
||||||
|
@ -153,4 +14,15 @@ SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
||||||
}
|
}
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
|
||||||
|
SliceIndex diff = end - start;
|
||||||
|
if (diff > 0 && step > 0) {
|
||||||
|
return ((diff - 1) / step) + 1;
|
||||||
|
} else if (diff < 0 && step < 0) {
|
||||||
|
return ((diff + 1) / step) + 1;
|
||||||
|
} else {
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
|
@ -0,0 +1,21 @@
|
||||||
|
[package]
|
||||||
|
name = "nac3core_derive"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
proc-macro = true
|
||||||
|
|
||||||
|
[[test]]
|
||||||
|
name = "structfields_tests"
|
||||||
|
path = "tests/structfields_test.rs"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
nac3core = { path = ".." }
|
||||||
|
trybuild = { version = "1.0", features = ["diff"] }
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
proc-macro2 = "1.0"
|
||||||
|
proc-macro-error = "1.0"
|
||||||
|
syn = "2.0"
|
||||||
|
quote = "1.0"
|
|
@ -0,0 +1,320 @@
|
||||||
|
use proc_macro::TokenStream;
|
||||||
|
use proc_macro_error::{abort, proc_macro_error};
|
||||||
|
use quote::quote;
|
||||||
|
use syn::{
|
||||||
|
parse_macro_input, spanned::Spanned, Data, DataStruct, Expr, ExprField, ExprMethodCall,
|
||||||
|
ExprPath, GenericArgument, Ident, LitStr, Path, PathArguments, Type, TypePath,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Extracts all generic arguments of a [`Type`] into a [`Vec`].
|
||||||
|
///
|
||||||
|
/// Returns [`Some`] of a possibly-empty [`Vec`] if the path of `ty` matches with
|
||||||
|
/// `expected_ty_name`, otherwise returns [`None`].
|
||||||
|
fn extract_generic_args(expected_ty_name: &'static str, ty: &Type) -> Option<Vec<GenericArgument>> {
|
||||||
|
let Type::Path(TypePath { qself: None, path, .. }) = ty else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let segments = &path.segments;
|
||||||
|
if segments.len() != 1 {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let segment = segments.iter().next().unwrap();
|
||||||
|
if segment.ident != expected_ty_name {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let PathArguments::AngleBracketed(path_args) = &segment.arguments else {
|
||||||
|
return Some(Vec::new());
|
||||||
|
};
|
||||||
|
let args = &path_args.args;
|
||||||
|
|
||||||
|
Some(args.iter().cloned().collect::<Vec<_>>())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Maps a `path` matching one of the `target_idents` into the `replacement` [`Ident`].
|
||||||
|
fn map_path_to_ident(path: &Path, target_idents: &[&str], replacement: &str) -> Option<Ident> {
|
||||||
|
path.require_ident()
|
||||||
|
.ok()
|
||||||
|
.filter(|ident| target_idents.iter().any(|target| ident == target))
|
||||||
|
.map(|ident| Ident::new(replacement, ident.span()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts the left-hand side of a dot-expression.
|
||||||
|
fn extract_dot_operand(expr: &Expr) -> Option<&Expr> {
|
||||||
|
match expr {
|
||||||
|
Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
|
||||||
|
| Expr::Field(ExprField { base: operand, .. }) => Some(operand),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Replaces the top-level receiver of a dot-expression with an [`Ident`], returning `Some(&mut expr)` if the
|
||||||
|
/// replacement is performed.
|
||||||
|
///
|
||||||
|
/// The top-level receiver is the left-most receiver expression, e.g. the top-level receiver of `a.b.c.foo()` is `a`.
|
||||||
|
fn replace_top_level_receiver(expr: &mut Expr, ident: Ident) -> Option<&mut Expr> {
|
||||||
|
if let Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
|
||||||
|
| Expr::Field(ExprField { base: operand, .. }) = expr
|
||||||
|
{
|
||||||
|
return if extract_dot_operand(operand).is_some() {
|
||||||
|
if replace_top_level_receiver(operand, ident).is_some() {
|
||||||
|
Some(expr)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
*operand = Box::new(Expr::Path(ExprPath {
|
||||||
|
attrs: Vec::default(),
|
||||||
|
qself: None,
|
||||||
|
path: ident.into(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
Some(expr)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterates all operands to the left-hand side of the `.` of an [expression][`Expr`], i.e. the container operand of all
|
||||||
|
/// [`Expr::Field`] and the receiver operand of all [`Expr::MethodCall`].
|
||||||
|
///
|
||||||
|
/// The iterator will return the operand expressions in reverse order of appearance. For example, `a.b.c.func()` will
|
||||||
|
/// return `vec![c, b, a]`.
|
||||||
|
fn iter_dot_operands(expr: &Expr) -> impl Iterator<Item = &Expr> {
|
||||||
|
let mut o = extract_dot_operand(expr);
|
||||||
|
|
||||||
|
std::iter::from_fn(move || {
|
||||||
|
let this = o;
|
||||||
|
o = o.as_ref().and_then(|o| extract_dot_operand(o));
|
||||||
|
|
||||||
|
this
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Normalizes a value expression for use when creating an instance of this structure, returning a
|
||||||
|
/// [`proc_macro2::TokenStream`] of tokens representing the normalized expression.
|
||||||
|
fn normalize_value_expr(expr: &Expr) -> proc_macro2::TokenStream {
|
||||||
|
match &expr {
|
||||||
|
Expr::Path(ExprPath { qself: None, path, .. }) => {
|
||||||
|
if let Some(ident) = map_path_to_ident(path, &["usize", "size_t"], "llvm_usize") {
|
||||||
|
quote! { #ident }
|
||||||
|
} else {
|
||||||
|
abort!(
|
||||||
|
path,
|
||||||
|
format!(
|
||||||
|
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
|
||||||
|
quote!(#expr).to_string(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr::Call(_) => {
|
||||||
|
quote! { ctx.#expr }
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr::MethodCall(_) => {
|
||||||
|
let base_receiver = iter_dot_operands(expr).last();
|
||||||
|
|
||||||
|
match base_receiver {
|
||||||
|
// `usize.{...}`, `size_t.{...}` -> Rewrite the identifiers to `llvm_usize`
|
||||||
|
Some(Expr::Path(ExprPath { qself: None, path, .. }))
|
||||||
|
if map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").is_some() =>
|
||||||
|
{
|
||||||
|
let ident =
|
||||||
|
map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").unwrap();
|
||||||
|
|
||||||
|
let mut expr = expr.clone();
|
||||||
|
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
|
||||||
|
|
||||||
|
quote!(#expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// `ctx.{...}`, `context.{...}` -> Rewrite the identifiers to `ctx`
|
||||||
|
Some(Expr::Path(ExprPath { qself: None, path, .. }))
|
||||||
|
if map_path_to_ident(path, &["ctx", "context"], "ctx").is_some() =>
|
||||||
|
{
|
||||||
|
let ident = map_path_to_ident(path, &["ctx", "context"], "ctx").unwrap();
|
||||||
|
|
||||||
|
let mut expr = expr.clone();
|
||||||
|
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
|
||||||
|
|
||||||
|
quote!(#expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No reserved identifier prefix -> Prepend `ctx.` to the entire expression
|
||||||
|
_ => quote! { ctx.#expr },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
abort!(
|
||||||
|
expr,
|
||||||
|
format!(
|
||||||
|
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
|
||||||
|
quote!(#expr).to_string(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Derives an implementation of `codegen::types::structure::StructFields`.
|
||||||
|
///
|
||||||
|
/// The benefit of using `#[derive(StructFields)]` is that all index- or order-dependent logic required by
|
||||||
|
/// `impl StructFields` is automatically generated by this implementation, including the field index as required by
|
||||||
|
/// `StructField::new` and the fields as returned by `StructFields::to_vec`.
|
||||||
|
///
|
||||||
|
/// # Prerequisites
|
||||||
|
///
|
||||||
|
/// In order to derive from [`StructFields`], you must implement (or derive) [`Eq`] and [`Copy`] as required by
|
||||||
|
/// `StructFields`.
|
||||||
|
///
|
||||||
|
/// Moreover, `#[derive(StructFields)]` can only be used for `struct`s with named fields, and may only contain fields
|
||||||
|
/// with either `StructField` or [`PhantomData`] types.
|
||||||
|
///
|
||||||
|
/// # Attributes for [`StructFields`]
|
||||||
|
///
|
||||||
|
/// Each `StructField` field must be declared with the `#[value_type(...)]` attribute. The argument of `value_type`
|
||||||
|
/// accepts one of the following:
|
||||||
|
///
|
||||||
|
/// - An expression returning an instance of `inkwell::types::BasicType` (with or without the receiver `ctx`/`context`).
|
||||||
|
/// For example, `context.i8_type()`, `ctx.i8_type()`, and `i8_type()` all refer to `i8`.
|
||||||
|
/// - The reserved identifiers `usize` and `size_t` referring to an `inkwell::types::IntType` of the platform-dependent
|
||||||
|
/// integer size. `usize` and `size_t` can also be used as the receiver to other method calls, e.g.
|
||||||
|
/// `usize.array_type(3)`.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// The following is an example of an LLVM slice implemented using `#[derive(StructFields)]`.
|
||||||
|
///
|
||||||
|
/// ```rust,ignore
|
||||||
|
/// use nac3core::{
|
||||||
|
/// codegen::types::structure::StructField,
|
||||||
|
/// inkwell::{
|
||||||
|
/// values::{IntValue, PointerValue},
|
||||||
|
/// AddressSpace,
|
||||||
|
/// },
|
||||||
|
/// };
|
||||||
|
/// use nac3core_derive::StructFields;
|
||||||
|
///
|
||||||
|
/// // All classes that implement StructFields must also implement Eq and Copy
|
||||||
|
/// #[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
|
/// pub struct SliceValue<'ctx> {
|
||||||
|
/// // Declares ptr have a value type of i8*
|
||||||
|
/// //
|
||||||
|
/// // Can also be written as `ctx.i8_type().ptr_type(...)` or `context.i8_type().ptr_type(...)`
|
||||||
|
/// #[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||||
|
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
///
|
||||||
|
/// // Declares len have a value type of usize, depending on the target compilation platform
|
||||||
|
/// #[value_type(usize)]
|
||||||
|
/// len: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
#[proc_macro_derive(StructFields, attributes(value_type))]
|
||||||
|
#[proc_macro_error]
|
||||||
|
pub fn derive(input: TokenStream) -> TokenStream {
|
||||||
|
let input = parse_macro_input!(input as syn::DeriveInput);
|
||||||
|
let ident = &input.ident;
|
||||||
|
|
||||||
|
let Data::Struct(DataStruct { fields, .. }) = &input.data else {
|
||||||
|
abort!(input, "Only structs with named fields are supported");
|
||||||
|
};
|
||||||
|
if let Err(err_span) =
|
||||||
|
fields
|
||||||
|
.iter()
|
||||||
|
.try_for_each(|field| if field.ident.is_some() { Ok(()) } else { Err(field.span()) })
|
||||||
|
{
|
||||||
|
abort!(err_span, "Only structs with named fields are supported");
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if struct<'ctx>
|
||||||
|
if input.generics.params.len() != 1 {
|
||||||
|
abort!(input.generics, "Expected exactly 1 generic parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
let phantom_info = fields
|
||||||
|
.iter()
|
||||||
|
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_some())
|
||||||
|
.map(|field| field.ident.as_ref().unwrap())
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let field_info = fields
|
||||||
|
.iter()
|
||||||
|
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_none())
|
||||||
|
.map(|field| {
|
||||||
|
let ident = field.ident.as_ref().unwrap();
|
||||||
|
let ty = &field.ty;
|
||||||
|
|
||||||
|
let Some(_) = extract_generic_args("StructField", ty) else {
|
||||||
|
abort!(field, "Only StructField and PhantomData are allowed")
|
||||||
|
};
|
||||||
|
|
||||||
|
let attrs = &field.attrs;
|
||||||
|
let Some(value_type_attr) =
|
||||||
|
attrs.iter().find(|attr| attr.path().is_ident("value_type"))
|
||||||
|
else {
|
||||||
|
abort!(field, "Expected #[value_type(...)] attribute for field");
|
||||||
|
};
|
||||||
|
|
||||||
|
let Ok(value_type_expr) = value_type_attr.parse_args::<Expr>() else {
|
||||||
|
abort!(value_type_attr, "Expected expression in #[value_type(...)]");
|
||||||
|
};
|
||||||
|
|
||||||
|
let value_expr_toks = normalize_value_expr(&value_type_expr);
|
||||||
|
|
||||||
|
(ident.clone(), value_expr_toks)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// `<*>::new` impl of `StructField` and `PhantomData` for `StructFields::new`
|
||||||
|
let phantoms_create = phantom_info
|
||||||
|
.iter()
|
||||||
|
.map(|id| quote! { #id: ::std::marker::PhantomData })
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let fields_create = field_info
|
||||||
|
.iter()
|
||||||
|
.map(|(id, ty)| {
|
||||||
|
let id_lit = LitStr::new(&id.to_string(), id.span());
|
||||||
|
quote! {
|
||||||
|
#id: ::nac3core::codegen::types::structure::StructField::create(
|
||||||
|
&mut counter,
|
||||||
|
#id_lit,
|
||||||
|
#ty,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// `.into()` impl of `StructField` for `StructFields::to_vec`
|
||||||
|
let fields_into =
|
||||||
|
field_info.iter().map(|(id, _)| quote! { self.#id.into() }).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let impl_block = quote! {
|
||||||
|
impl<'ctx> ::nac3core::codegen::types::structure::StructFields<'ctx> for #ident<'ctx> {
|
||||||
|
fn new(ctx: impl ::nac3core::inkwell::context::AsContextRef<'ctx>, llvm_usize: ::nac3core::inkwell::types::IntType<'ctx>) -> Self {
|
||||||
|
let ctx = unsafe { ::nac3core::inkwell::context::ContextRef::new(ctx.as_ctx_ref()) };
|
||||||
|
|
||||||
|
let mut counter = ::nac3core::codegen::types::structure::FieldIndexCounter::default();
|
||||||
|
|
||||||
|
#ident {
|
||||||
|
#(#fields_create),*
|
||||||
|
#(#phantoms_create),*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_vec(&self) -> ::std::vec::Vec<(&'static str, ::nac3core::inkwell::types::BasicTypeEnum<'ctx>)> {
|
||||||
|
vec![
|
||||||
|
#(#fields_into),*
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
impl_block.into()
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
use nac3core_derive::StructFields;
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
|
pub struct EmptyValue<'ctx> {
|
||||||
|
_phantom: PhantomData<&'ctx ()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,20 @@
|
||||||
|
use nac3core::{
|
||||||
|
codegen::types::structure::StructField,
|
||||||
|
inkwell::{
|
||||||
|
values::{IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use nac3core_derive::StructFields;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
|
pub struct NDArrayValue<'ctx> {
|
||||||
|
#[value_type(usize)]
|
||||||
|
ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||||
|
shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||||
|
data: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,18 @@
|
||||||
|
use nac3core::{
|
||||||
|
codegen::types::structure::StructField,
|
||||||
|
inkwell::{
|
||||||
|
values::{IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use nac3core_derive::StructFields;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
|
pub struct SliceValue<'ctx> {
|
||||||
|
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||||
|
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
#[value_type(usize)]
|
||||||
|
len: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,18 @@
|
||||||
|
use nac3core::{
|
||||||
|
codegen::types::structure::StructField,
|
||||||
|
inkwell::{
|
||||||
|
values::{IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use nac3core_derive::StructFields;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
|
pub struct SliceValue<'ctx> {
|
||||||
|
#[value_type(context.i8_type().ptr_type(AddressSpace::default()))]
|
||||||
|
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
#[value_type(usize)]
|
||||||
|
len: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,18 @@
|
||||||
|
use nac3core::{
|
||||||
|
codegen::types::structure::StructField,
|
||||||
|
inkwell::{
|
||||||
|
values::{IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use nac3core_derive::StructFields;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
|
pub struct SliceValue<'ctx> {
|
||||||
|
#[value_type(ctx.i8_type().ptr_type(AddressSpace::default()))]
|
||||||
|
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
#[value_type(usize)]
|
||||||
|
len: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,18 @@
|
||||||
|
use nac3core::{
|
||||||
|
codegen::types::structure::StructField,
|
||||||
|
inkwell::{
|
||||||
|
values::{IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use nac3core_derive::StructFields;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
|
pub struct SliceValue<'ctx> {
|
||||||
|
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||||
|
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
#[value_type(size_t)]
|
||||||
|
len: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,10 @@
|
||||||
|
#[test]
|
||||||
|
fn test_parse_empty() {
|
||||||
|
let t = trybuild::TestCases::new();
|
||||||
|
t.pass("tests/structfields_empty.rs");
|
||||||
|
t.pass("tests/structfields_slice.rs");
|
||||||
|
t.pass("tests/structfields_slice_ctx.rs");
|
||||||
|
t.pass("tests/structfields_slice_context.rs");
|
||||||
|
t.pass("tests/structfields_slice_sizet.rs");
|
||||||
|
t.pass("tests/structfields_ndarray.rs");
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,3 +1,9 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use indexmap::IndexMap;
|
||||||
|
|
||||||
|
use nac3parser::ast::StrRef;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
toplevel::DefinitionId,
|
toplevel::DefinitionId,
|
||||||
|
@ -9,10 +15,6 @@ use crate::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use indexmap::IndexMap;
|
|
||||||
use nac3parser::ast::StrRef;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
pub struct ConcreteTypeStore {
|
pub struct ConcreteTypeStore {
|
||||||
store: Vec<ConcreteTypeEnum>,
|
store: Vec<ConcreteTypeEnum>,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,24 @@
|
||||||
use crate::{
|
use std::{
|
||||||
codegen::{
|
cmp::min,
|
||||||
classes::{
|
collections::HashMap,
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, ProxyType, ProxyValue,
|
convert::TryInto,
|
||||||
RangeValue, UntypedArrayLikeAccessor,
|
iter::{once, repeat, repeat_with, zip},
|
||||||
},
|
};
|
||||||
|
|
||||||
|
use inkwell::{
|
||||||
|
attributes::{Attribute, AttributeLoc},
|
||||||
|
types::{AnyType, BasicType, BasicTypeEnum},
|
||||||
|
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, StructValue},
|
||||||
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
|
};
|
||||||
|
use itertools::{chain, izip, Either, Itertools};
|
||||||
|
|
||||||
|
use nac3parser::ast::{
|
||||||
|
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||||
|
Unaryop,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||||
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
|
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
|
||||||
irrt::*,
|
irrt::*,
|
||||||
|
@ -12,43 +27,30 @@ use crate::{
|
||||||
call_int_umin, call_memcpy_generic,
|
call_int_umin, call_memcpy_generic,
|
||||||
},
|
},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
need_sret,
|
need_sret, numpy,
|
||||||
object::ndarray::{NDArrayOut, ScalarOrNDArray},
|
|
||||||
stmt::{
|
stmt::{
|
||||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||||
gen_var,
|
gen_var,
|
||||||
},
|
},
|
||||||
CodeGenContext, CodeGenTask, CodeGenerator,
|
types::{ListType, ProxyType},
|
||||||
|
values::{
|
||||||
|
ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue,
|
||||||
|
TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
|
CodeGenContext, CodeGenTask, CodeGenerator,
|
||||||
|
};
|
||||||
|
use crate::{
|
||||||
symbol_resolver::{SymbolValue, ValueEnum},
|
symbol_resolver::{SymbolValue, ValueEnum},
|
||||||
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
|
toplevel::{
|
||||||
|
helper::PrimDef,
|
||||||
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
|
DefinitionId, TopLevelDef,
|
||||||
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::{
|
|
||||||
attributes::{Attribute, AttributeLoc},
|
|
||||||
types::{AnyType, BasicType, BasicTypeEnum},
|
|
||||||
values::{
|
|
||||||
BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue,
|
|
||||||
StructValue,
|
|
||||||
},
|
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
|
||||||
};
|
|
||||||
use itertools::{chain, izip, Either, Itertools};
|
|
||||||
use nac3parser::ast::{
|
|
||||||
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
|
||||||
Unaryop,
|
|
||||||
};
|
|
||||||
use std::cmp::min;
|
|
||||||
use std::iter::{repeat, repeat_with};
|
|
||||||
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
|
||||||
|
|
||||||
use super::object::{
|
|
||||||
any::AnyObject,
|
|
||||||
ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject},
|
|
||||||
};
|
|
||||||
|
|
||||||
pub fn get_subst_key(
|
pub fn get_subst_key(
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
|
@ -556,7 +558,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
&& val_ty.get_element_type().is_struct_type()
|
&& val_ty.get_element_type().is_struct_type()
|
||||||
} =>
|
} =>
|
||||||
{
|
{
|
||||||
self.builder.build_bitcast(*val, arg_ty, "call_arg_cast").unwrap()
|
self.builder.build_bit_cast(*val, arg_ty, "call_arg_cast").unwrap()
|
||||||
}
|
}
|
||||||
_ => *val,
|
_ => *val,
|
||||||
})
|
})
|
||||||
|
@ -976,6 +978,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
|
||||||
TopLevelDef::Class { .. } => {
|
TopLevelDef::Class { .. } => {
|
||||||
return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?))
|
return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?))
|
||||||
}
|
}
|
||||||
|
TopLevelDef::Variable { .. } => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.or_else(|_: String| {
|
.or_else(|_: String| {
|
||||||
|
@ -1165,7 +1168,8 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
TypeEnum::TObj { obj_id, .. }
|
TypeEnum::TObj { obj_id, .. }
|
||||||
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
let iter_val =
|
||||||
|
RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range"));
|
||||||
let (start, stop, step) = destructure_range(ctx, iter_val);
|
let (start, stop, step) = destructure_range(ctx, iter_val);
|
||||||
let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap();
|
let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap();
|
||||||
// add 1 to the length as the value is rounded to zero
|
// add 1 to the length as the value is rounded to zero
|
||||||
|
@ -1396,8 +1400,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1);
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1);
|
||||||
let sizeof_elem = llvm_elem_ty.size_of().unwrap();
|
let sizeof_elem = llvm_elem_ty.size_of().unwrap();
|
||||||
|
|
||||||
let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
|
let lhs =
|
||||||
let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
|
ListValue::from_pointer_value(left_val.into_pointer_value(), llvm_usize, None);
|
||||||
|
let rhs =
|
||||||
|
ListValue::from_pointer_value(right_val.into_pointer_value(), llvm_usize, None);
|
||||||
|
|
||||||
let size = ctx
|
let size = ctx
|
||||||
.builder
|
.builder
|
||||||
|
@ -1480,7 +1486,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
codegen_unreachable!(ctx)
|
codegen_unreachable!(ctx)
|
||||||
};
|
};
|
||||||
let list_val =
|
let list_val =
|
||||||
ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
|
ListValue::from_pointer_value(list_val.into_pointer_value(), llvm_usize, None);
|
||||||
let int_val = ctx
|
let int_val = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_int_s_extend(int_val.into_int_value(), llvm_usize, "")
|
.build_int_s_extend(int_val.into_int_value(), llvm_usize, "")
|
||||||
|
@ -1547,75 +1553,112 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let left =
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty1, value: left_val });
|
|
||||||
let right =
|
|
||||||
ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty2, value: right_val });
|
|
||||||
|
|
||||||
// Inhomogeneous binary operations are not supported.
|
let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
assert!(ctx.unifier.unioned(left.get_dtype(), right.get_dtype()));
|
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let common_dtype = left.get_dtype();
|
if is_ndarray1 && is_ndarray2 {
|
||||||
|
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
||||||
|
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
|
||||||
|
|
||||||
let out = match op.variant {
|
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: common_dtype },
|
|
||||||
BinopVariant::AugAssign => {
|
|
||||||
// If this is an augmented assignment.
|
|
||||||
// `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it.
|
|
||||||
if let ScalarOrNDArray::NDArray(out_ndarray) = left {
|
|
||||||
NDArrayOut::WriteToNDArray { ndarray: out_ndarray }
|
|
||||||
} else {
|
|
||||||
panic!("left must be an ndarray")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if op.base == Operator::MatMult {
|
let llvm_ndarray_dtype1 = ctx.get_llvm_type(generator, ndarray_dtype1);
|
||||||
// Handle matrix multiplication.
|
let llvm_ndarray_dtype2 = ctx.get_llvm_type(generator, ndarray_dtype2);
|
||||||
let left = left.to_ndarray(generator, ctx);
|
|
||||||
let right = right.to_ndarray(generator, ctx);
|
|
||||||
let result = NDArrayObject::matmul(generator, ctx, left, right, out)
|
|
||||||
.split_unsized(generator, ctx);
|
|
||||||
Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())))
|
|
||||||
} else {
|
|
||||||
// For other operations, they are all elementwise operations.
|
|
||||||
|
|
||||||
// There are only three cases:
|
let left_val = NDArrayValue::from_pointer_value(
|
||||||
// - LHS is a scalar, RHS is an ndarray.
|
left_val.into_pointer_value(),
|
||||||
// - LHS is an ndarray, RHS is a scalar.
|
llvm_ndarray_dtype1,
|
||||||
// - LHS is an ndarray, RHS is an ndarray.
|
llvm_usize,
|
||||||
//
|
None,
|
||||||
// For all cases, the scalar operand is promoted to an ndarray,
|
);
|
||||||
// the two are then broadcasted, and starmapped through.
|
let right_val = NDArrayValue::from_pointer_value(
|
||||||
|
right_val.into_pointer_value(),
|
||||||
|
llvm_ndarray_dtype2,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
let left = left.to_ndarray(generator, ctx);
|
let res = if op.base == Operator::MatMult {
|
||||||
let right = right.to_ndarray(generator, ctx);
|
// MatMult is the only binop which is not an elementwise op
|
||||||
|
numpy::ndarray_matmul_2d(
|
||||||
let result = NDArrayObject::broadcast_starmap(
|
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&[left, right],
|
ndarray_dtype1,
|
||||||
out,
|
match op.variant {
|
||||||
|generator, ctx, scalars| {
|
BinopVariant::Normal => None,
|
||||||
let left_value = scalars[0];
|
BinopVariant::AugAssign => Some(left_val),
|
||||||
let right_value = scalars[1];
|
},
|
||||||
|
left_val,
|
||||||
let result = gen_binop_expr_with_values(
|
right_val,
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
numpy::ndarray_elementwise_binop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(&Some(left.dtype), left_value),
|
ndarray_dtype1,
|
||||||
|
match op.variant {
|
||||||
|
BinopVariant::Normal => None,
|
||||||
|
BinopVariant::AugAssign => Some(left_val),
|
||||||
|
},
|
||||||
|
(ty1, left_val.as_base_value().into(), false),
|
||||||
|
(ty2, right_val.as_base_value().into(), false),
|
||||||
|
|generator, ctx, (lhs, rhs)| {
|
||||||
|
gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(ndarray_dtype1), lhs),
|
||||||
op,
|
op,
|
||||||
(&Some(right.dtype), right_value),
|
(&Some(ndarray_dtype2), rhs),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, common_dtype)?;
|
.to_basic_value_enum(
|
||||||
|
ctx,
|
||||||
Ok(result)
|
generator,
|
||||||
},
|
ndarray_dtype1,
|
||||||
)
|
)
|
||||||
.unwrap();
|
},
|
||||||
Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum())))
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(res.as_base_value().into()))
|
||||||
|
} else {
|
||||||
|
let (ndarray_dtype, _) =
|
||||||
|
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
|
||||||
|
let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||||
|
let ndarray_val = NDArrayValue::from_pointer_value(
|
||||||
|
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
||||||
|
llvm_ndarray_dtype,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray_dtype,
|
||||||
|
match op.variant {
|
||||||
|
BinopVariant::Normal => None,
|
||||||
|
BinopVariant::AugAssign => Some(ndarray_val),
|
||||||
|
},
|
||||||
|
(ty1, left_val, !is_ndarray1),
|
||||||
|
(ty2, right_val, !is_ndarray2),
|
||||||
|
|generator, ctx, (lhs, rhs)| {
|
||||||
|
gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(ndarray_dtype), lhs),
|
||||||
|
op,
|
||||||
|
(&Some(ndarray_dtype), rhs),
|
||||||
|
ctx.current_loc,
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Some(res.as_base_value().into()))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||||
|
@ -1751,7 +1794,12 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
ast::Unaryop::Invert => ctx.builder.build_not(val, "not").map(Into::into).unwrap(),
|
ast::Unaryop::Invert => ctx.builder.build_not(val, "not").map(Into::into).unwrap(),
|
||||||
ast::Unaryop::Not => ctx
|
ast::Unaryop::Not => ctx
|
||||||
.builder
|
.builder
|
||||||
.build_xor(val, val.get_type().const_all_ones(), "not")
|
.build_int_compare(
|
||||||
|
inkwell::IntPredicate::EQ,
|
||||||
|
val,
|
||||||
|
val.get_type().const_zero(),
|
||||||
|
"not",
|
||||||
|
)
|
||||||
.map(Into::into)
|
.map(Into::into)
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
ast::Unaryop::UAdd => val.into(),
|
ast::Unaryop::UAdd => val.into(),
|
||||||
|
@ -1773,12 +1821,20 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
_ => val.into(),
|
_ => val.into(),
|
||||||
}
|
}
|
||||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let ndarray = AnyObject { value: val, ty };
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||||
|
let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||||
|
|
||||||
|
let val = NDArrayValue::from_pointer_value(
|
||||||
|
val.into_pointer_value(),
|
||||||
|
llvm_ndarray_dtype,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
||||||
// passing it to the elementwise codegen function
|
// passing it to the elementwise codegen function
|
||||||
let op = if ndarray.dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||||
if op == ast::Unaryop::Invert {
|
if op == ast::Unaryop::Invert {
|
||||||
ast::Unaryop::Not
|
ast::Unaryop::Not
|
||||||
} else {
|
} else {
|
||||||
|
@ -1792,18 +1848,20 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
op
|
op
|
||||||
};
|
};
|
||||||
|
|
||||||
let mapped_ndarray = ndarray.map(
|
let res = numpy::ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|
ndarray_dtype,
|
||||||
|generator, ctx, scalar| {
|
None,
|
||||||
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray.dtype), scalar))?
|
val,
|
||||||
|
|generator, ctx, val| {
|
||||||
|
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, ndarray.dtype)
|
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
ValueEnum::Dynamic(mapped_ndarray.instance.value.as_basic_value_enum())
|
res.as_base_value().into()
|
||||||
} else {
|
} else {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}))
|
}))
|
||||||
|
@ -1846,33 +1904,45 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) };
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) };
|
|
||||||
|
let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) };
|
||||||
|
let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) };
|
||||||
let op = ops[0];
|
let op = ops[0];
|
||||||
|
|
||||||
let left = AnyObject { value: left, ty: left_ty };
|
let is_ndarray1 =
|
||||||
let left =
|
left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
ScalarOrNDArray::split_object(generator, ctx, left).to_ndarray(generator, ctx);
|
let is_ndarray2 =
|
||||||
|
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let right = AnyObject { value: right, ty: right_ty };
|
return if is_ndarray1 && is_ndarray2 {
|
||||||
let right =
|
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
||||||
ScalarOrNDArray::split_object(generator, ctx, right).to_ndarray(generator, ctx);
|
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
||||||
|
|
||||||
let result_ndarray = NDArrayObject::broadcast_starmap(
|
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
|
let llvm_ndarray_dtype1 = ctx.get_llvm_type(generator, ndarray_dtype1);
|
||||||
|
|
||||||
|
let left_val = NDArrayValue::from_pointer_value(
|
||||||
|
lhs.into_pointer_value(),
|
||||||
|
llvm_ndarray_dtype1,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&[left, right],
|
ctx.primitives.bool,
|
||||||
NDArrayOut::NewNDArray { dtype: ctx.primitives.bool },
|
None,
|
||||||
|generator, ctx, scalars| {
|
(left_ty, left_val.as_base_value().into(), false),
|
||||||
let left_scalar = scalars[0];
|
(right_ty, rhs, false),
|
||||||
let right_scalar = scalars[1];
|
|generator, ctx, (lhs, rhs)| {
|
||||||
|
|
||||||
let val = gen_cmpop_expr_with_values(
|
let val = gen_cmpop_expr_with_values(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(Some(left.dtype), left_scalar),
|
(Some(ndarray_dtype1), lhs),
|
||||||
&[op],
|
&[op],
|
||||||
&[(Some(right.dtype), right_scalar)],
|
&[(Some(ndarray_dtype2), rhs)],
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(
|
.to_basic_value_enum(
|
||||||
|
@ -1885,7 +1955,40 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
return Ok(Some(result_ndarray.instance.value.into()));
|
Ok(Some(res.as_base_value().into()))
|
||||||
|
} else {
|
||||||
|
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
||||||
|
&mut ctx.unifier,
|
||||||
|
if is_ndarray1 { left_ty } else { right_ty },
|
||||||
|
);
|
||||||
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ctx.primitives.bool,
|
||||||
|
None,
|
||||||
|
(left_ty, lhs, !is_ndarray1),
|
||||||
|
(right_ty, rhs, !is_ndarray2),
|
||||||
|
|generator, ctx, (lhs, rhs)| {
|
||||||
|
let val = gen_cmpop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(Some(ndarray_dtype), lhs),
|
||||||
|
&[op],
|
||||||
|
&[(Some(ndarray_dtype), rhs)],
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
ctx.primitives.bool,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Some(res.as_base_value().into()))
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2127,9 +2230,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
|
|
||||||
let left_val =
|
let left_val =
|
||||||
ListValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None);
|
ListValue::from_pointer_value(lhs.into_pointer_value(), llvm_usize, None);
|
||||||
let right_val =
|
let right_val =
|
||||||
ListValue::from_ptr_val(rhs.into_pointer_value(), llvm_usize, None);
|
ListValue::from_pointer_value(rhs.into_pointer_value(), llvm_usize, None);
|
||||||
|
|
||||||
Ok(gen_if_else_expr_callback(
|
Ok(gen_if_else_expr_callback(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2436,6 +2539,343 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates code for a subscript expression on an `ndarray`.
|
||||||
|
///
|
||||||
|
/// * `ty` - The `Type` of the `NDArray` elements.
|
||||||
|
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
||||||
|
/// * `v` - The `NDArray` value.
|
||||||
|
/// * `slice` - The slice expression used to subscript into the `ndarray`.
|
||||||
|
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ty: Type,
|
||||||
|
ndims: Type,
|
||||||
|
v: NDArrayValue<'ctx>,
|
||||||
|
slice: &Expr<Option<Type>>,
|
||||||
|
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||||
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
|
||||||
|
codegen_unreachable!(ctx)
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndims = values
|
||||||
|
.iter()
|
||||||
|
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
.map_err(|val| {
|
||||||
|
format!(
|
||||||
|
"Expected non-negative literal for ndarray.ndims, got {}",
|
||||||
|
i128::try_from(val).unwrap()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
assert!(!ndims.is_empty());
|
||||||
|
|
||||||
|
// The number of dimensions subscripted by the index expression.
|
||||||
|
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
|
||||||
|
// dimension will remove a dimension.
|
||||||
|
let subscripted_dims = match &slice.node {
|
||||||
|
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
|
||||||
|
if let ExprKind::Slice { .. } = &value_subexpr.node {
|
||||||
|
acc
|
||||||
|
} else {
|
||||||
|
acc + 1
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
|
ExprKind::Slice { .. } => 0,
|
||||||
|
_ => 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
|
||||||
|
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let ndarray_ty =
|
||||||
|
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
|
||||||
|
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||||
|
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||||
|
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
||||||
|
let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap();
|
||||||
|
|
||||||
|
// Check that len is non-zero
|
||||||
|
let len = v.load_ndims(ctx);
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(),
|
||||||
|
"0:IndexError",
|
||||||
|
"too many indices for array: array is {0}-dimensional but 1 were indexed",
|
||||||
|
[Some(len), None, None],
|
||||||
|
slice.location,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Normalizes a possibly-negative index to its corresponding positive index
|
||||||
|
let normalize_index = |generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
index: IntValue<'ctx>,
|
||||||
|
dim: u64| {
|
||||||
|
gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "")
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, _| Ok(Some(index)),
|
||||||
|
|generator, ctx| {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
|
let len = unsafe {
|
||||||
|
v.shape().get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(dim, true),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let index = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_add(
|
||||||
|
len,
|
||||||
|
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.map(|v| v.map(BasicValueEnum::into_int_value))
|
||||||
|
};
|
||||||
|
|
||||||
|
// Converts a slice expression into a slice-range tuple
|
||||||
|
let expr_to_slice = |generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
node: &ExprKind<Option<Type>>,
|
||||||
|
dim: u64| {
|
||||||
|
match node {
|
||||||
|
ExprKind::Constant { value: Constant::Int(v), .. } => {
|
||||||
|
let Some(index) =
|
||||||
|
normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)?
|
||||||
|
else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some((index, index, llvm_i32.const_int(1, true))))
|
||||||
|
}
|
||||||
|
|
||||||
|
ExprKind::Slice { lower, upper, step } => {
|
||||||
|
let dim_sz = unsafe {
|
||||||
|
v.shape().get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(dim, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) };
|
||||||
|
let index = index
|
||||||
|
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
|
||||||
|
.into_int_value();
|
||||||
|
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some((index, index, llvm_i32.const_int(1, true))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let make_indices_arr = |generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>|
|
||||||
|
-> Result<_, String> {
|
||||||
|
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
|
||||||
|
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
|
||||||
|
let index_addr = generator.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_int_ty,
|
||||||
|
llvm_usize.const_int(elts.len() as u64, false),
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
for (i, elt) in elts.iter().enumerate() {
|
||||||
|
let Some(index) = generator.gen_expr(ctx, elt)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
let index = index
|
||||||
|
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
|
||||||
|
.into_int_value();
|
||||||
|
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
let store_ptr = unsafe {
|
||||||
|
index_addr.ptr_offset_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(i as u64, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(store_ptr, index).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(index_addr)
|
||||||
|
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
|
||||||
|
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
|
||||||
|
let index_addr = generator.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_int_ty,
|
||||||
|
llvm_usize.const_int(1u64, false),
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let index =
|
||||||
|
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
|
||||||
|
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
|
||||||
|
|
||||||
|
let store_ptr = unsafe {
|
||||||
|
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(store_ptr, index).unwrap();
|
||||||
|
|
||||||
|
Some(index_addr)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
|
||||||
|
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
|
||||||
|
|
||||||
|
v.data().get(ctx, generator, &index_addr, None).into()
|
||||||
|
} else {
|
||||||
|
match &slice.node {
|
||||||
|
ExprKind::Tuple { elts, .. } => {
|
||||||
|
let slices = elts
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
|
||||||
|
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
if slices.len() < elts.len() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
|
||||||
|
|
||||||
|
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
ExprKind::Slice { .. } => {
|
||||||
|
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
// Accessing an element from a multi-dimensional `ndarray`
|
||||||
|
|
||||||
|
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
|
||||||
|
|
||||||
|
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||||
|
// elements over
|
||||||
|
let subscripted_ndarray =
|
||||||
|
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
||||||
|
let ndarray = NDArrayValue::from_pointer_value(
|
||||||
|
subscripted_ndarray,
|
||||||
|
llvm_ndarray_data_t,
|
||||||
|
llvm_usize,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
let num_dims = v.load_ndims(ctx);
|
||||||
|
ndarray.store_ndims(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
|
||||||
|
let ndarray_num_dims = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_z_extend_or_bit_cast(
|
||||||
|
ndarray.load_ndims(ctx),
|
||||||
|
llvm_usize.size_of().get_type(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let v_dims_src_ptr = unsafe {
|
||||||
|
v.shape().ptr_offset_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(1, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
ndarray.shape().base_ptr(ctx, generator),
|
||||||
|
v_dims_src_ptr,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
|
||||||
|
.map(Into::into)
|
||||||
|
.unwrap(),
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&ndarray.shape().as_slice_value(ctx, generator),
|
||||||
|
(None, None),
|
||||||
|
);
|
||||||
|
let ndarray_num_elems = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
|
||||||
|
.unwrap();
|
||||||
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
|
|
||||||
|
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
ndarray.data().base_ptr(ctx, generator),
|
||||||
|
v_data_src_ptr,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_mul(
|
||||||
|
ndarray_num_elems,
|
||||||
|
llvm_ndarray_data_t.size_of().unwrap(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.map(Into::into)
|
||||||
|
.unwrap(),
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
|
ndarray.as_base_value().into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_expr`].
|
/// See [`CodeGenerator::gen_expr`].
|
||||||
pub fn gen_expr<'ctx, G: CodeGenerator>(
|
pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -2485,7 +2925,31 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()),
|
Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()),
|
||||||
None => {
|
None => {
|
||||||
let resolver = ctx.resolver.clone();
|
let resolver = ctx.resolver.clone();
|
||||||
resolver.get_symbol_value(*id, ctx).unwrap()
|
let value = resolver.get_symbol_value(*id, ctx, generator).unwrap();
|
||||||
|
|
||||||
|
let globals = ctx
|
||||||
|
.top_level
|
||||||
|
.definitions
|
||||||
|
.read()
|
||||||
|
.iter()
|
||||||
|
.filter_map(|def| {
|
||||||
|
if let TopLevelDef::Variable { simple_name, ty, .. } = &*def.read() {
|
||||||
|
Some((*simple_name, *ty))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
if let Some((_, ty)) = globals.iter().find(|(name, _)| name == id) {
|
||||||
|
let ptr = value
|
||||||
|
.to_basic_value_enum(ctx, generator, *ty)
|
||||||
|
.map(BasicValueEnum::into_pointer_value)?;
|
||||||
|
|
||||||
|
ctx.builder.build_load(ptr, id.to_string().as_str()).map(Into::into).unwrap()
|
||||||
|
} else {
|
||||||
|
value
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
ExprKind::List { elts, .. } => {
|
ExprKind::List { elts, .. } => {
|
||||||
|
@ -2678,48 +3142,53 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
};
|
};
|
||||||
let left = generator.bool_to_i1(ctx, left);
|
let left = generator.bool_to_i1(ctx, left);
|
||||||
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
|
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||||
let a_bb = ctx.ctx.append_basic_block(current, "a");
|
let a_begin_bb = ctx.ctx.append_basic_block(current, "a_begin");
|
||||||
let b_bb = ctx.ctx.append_basic_block(current, "b");
|
let a_end_bb = ctx.ctx.append_basic_block(current, "a_end");
|
||||||
|
let b_begin_bb = ctx.ctx.append_basic_block(current, "b_begin");
|
||||||
|
let b_end_bb = ctx.ctx.append_basic_block(current, "b_end");
|
||||||
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
|
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
|
||||||
ctx.builder.build_conditional_branch(left, a_bb, b_bb).unwrap();
|
ctx.builder.build_conditional_branch(left, a_begin_bb, b_begin_bb).unwrap();
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(a_end_bb);
|
||||||
|
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
||||||
|
ctx.builder.position_at_end(b_end_bb);
|
||||||
|
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
||||||
let (a, b) = match op {
|
let (a, b) = match op {
|
||||||
Boolop::Or => {
|
Boolop::Or => {
|
||||||
ctx.builder.position_at_end(a_bb);
|
ctx.builder.position_at_end(a_begin_bb);
|
||||||
let a = ctx.ctx.i8_type().const_int(1, false);
|
let a = ctx.ctx.i8_type().const_int(1, false);
|
||||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
ctx.builder.build_unconditional_branch(a_end_bb).unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(b_bb);
|
ctx.builder.position_at_end(b_begin_bb);
|
||||||
let b = if let Some(v) = generator.gen_expr(ctx, &values[1])? {
|
let b = if let Some(v) = generator.gen_expr(ctx, &values[1])? {
|
||||||
let b = v
|
let b = v
|
||||||
.to_basic_value_enum(ctx, generator, values[1].custom.unwrap())?
|
.to_basic_value_enum(ctx, generator, values[1].custom.unwrap())?
|
||||||
.into_int_value();
|
.into_int_value();
|
||||||
let b = generator.bool_to_i8(ctx, b);
|
let b = generator.bool_to_i8(ctx, b);
|
||||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
|
||||||
|
|
||||||
Some(b)
|
Some(b)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
ctx.builder.build_unconditional_branch(b_end_bb).unwrap();
|
||||||
|
|
||||||
(Some(a), b)
|
(Some(a), b)
|
||||||
}
|
}
|
||||||
Boolop::And => {
|
Boolop::And => {
|
||||||
ctx.builder.position_at_end(a_bb);
|
ctx.builder.position_at_end(a_begin_bb);
|
||||||
let a = if let Some(v) = generator.gen_expr(ctx, &values[1])? {
|
let a = if let Some(v) = generator.gen_expr(ctx, &values[1])? {
|
||||||
let a = v
|
let a = v
|
||||||
.to_basic_value_enum(ctx, generator, values[1].custom.unwrap())?
|
.to_basic_value_enum(ctx, generator, values[1].custom.unwrap())?
|
||||||
.into_int_value();
|
.into_int_value();
|
||||||
let a = generator.bool_to_i8(ctx, a);
|
let a = generator.bool_to_i8(ctx, a);
|
||||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
|
||||||
|
|
||||||
Some(a)
|
Some(a)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
ctx.builder.build_unconditional_branch(a_end_bb).unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(b_bb);
|
ctx.builder.position_at_end(b_begin_bb);
|
||||||
let b = ctx.ctx.i8_type().const_zero();
|
let b = ctx.ctx.i8_type().const_zero();
|
||||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
ctx.builder.build_unconditional_branch(b_end_bb).unwrap();
|
||||||
|
|
||||||
(a, Some(b))
|
(a, Some(b))
|
||||||
}
|
}
|
||||||
|
@ -2729,7 +3198,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
match (a, b) {
|
match (a, b) {
|
||||||
(Some(a), Some(b)) => {
|
(Some(a), Some(b)) => {
|
||||||
let phi = ctx.builder.build_phi(ctx.ctx.i8_type(), "").unwrap();
|
let phi = ctx.builder.build_phi(ctx.ctx.i8_type(), "").unwrap();
|
||||||
phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]);
|
phi.add_incoming(&[(&a, a_end_bb), (&b, b_end_bb)]);
|
||||||
phi.as_basic_value().into()
|
phi.as_basic_value().into()
|
||||||
}
|
}
|
||||||
(Some(a), None) => a.into(),
|
(Some(a), None) => a.into(),
|
||||||
|
@ -2967,7 +3436,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
} else {
|
} else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
let v = ListValue::from_ptr_val(v, usize, Some("arr"));
|
let v = ListValue::from_pointer_value(v, usize, Some("arr"));
|
||||||
let ty = ctx.get_llvm_type(generator, *ty);
|
let ty = ctx.get_llvm_type(generator, *ty);
|
||||||
if let ExprKind::Slice { lower, upper, step } = &slice.node {
|
if let ExprKind::Slice { lower, upper, step } = &slice.node {
|
||||||
let one = int32.const_int(1, false);
|
let one = int32.const_int(1, false);
|
||||||
|
@ -3068,26 +3537,19 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
v.data().get(ctx, generator, &index, None).into()
|
v.data().get(ctx, generator, &index, None).into()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let Some(ndarray) = generator.gen_expr(ctx, value)? else {
|
let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
||||||
|
let llvm_ty = ctx.get_llvm_type(generator, *ty);
|
||||||
|
|
||||||
|
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||||
|
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
||||||
|
.into_pointer_value()
|
||||||
|
} else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None);
|
||||||
|
|
||||||
let ndarray_ty = value.custom.unwrap();
|
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
|
||||||
let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
|
||||||
|
|
||||||
let ndarray = NDArrayObject::from_object(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
AnyObject { ty: ndarray_ty, value: ndarray },
|
|
||||||
);
|
|
||||||
|
|
||||||
let indices = gen_ndarray_subscript_ndindices(generator, ctx, slice)?;
|
|
||||||
let result = ndarray
|
|
||||||
.index(generator, ctx, &indices)
|
|
||||||
.split_unsized(generator, ctx)
|
|
||||||
.to_basic_value_enum();
|
|
||||||
return Ok(Some(ValueEnum::Dynamic(result)));
|
|
||||||
}
|
}
|
||||||
TypeEnum::TTuple { .. } => {
|
TypeEnum::TTuple { .. } => {
|
||||||
let index: u32 =
|
let index: u32 =
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
use inkwell::attributes::{Attribute, AttributeLoc};
|
use inkwell::{
|
||||||
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
|
attributes::{Attribute, AttributeLoc},
|
||||||
|
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
||||||
|
};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
|
|
||||||
use crate::codegen::CodeGenContext;
|
use super::CodeGenContext;
|
||||||
|
|
||||||
/// Macro to generate extern function
|
/// Macro to generate extern function
|
||||||
/// Both function return type and function parameter type are `FloatValue`
|
/// Both function return type and function parameter type are `FloatValue`
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
use crate::{
|
|
||||||
codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
|
|
||||||
symbol_resolver::ValueEnum,
|
|
||||||
toplevel::{DefinitionId, TopLevelDef},
|
|
||||||
typecheck::typedef::{FunSignature, Type},
|
|
||||||
};
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
types::{BasicTypeEnum, IntType},
|
types::{BasicTypeEnum, IntType},
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
};
|
};
|
||||||
|
|
||||||
use nac3parser::ast::{Expr, Stmt, StrRef};
|
use nac3parser::ast::{Expr, Stmt, StrRef};
|
||||||
|
|
||||||
|
use super::{bool_to_i1, bool_to_i8, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext};
|
||||||
|
use crate::{
|
||||||
|
symbol_resolver::ValueEnum,
|
||||||
|
toplevel::{DefinitionId, TopLevelDef},
|
||||||
|
typecheck::typedef::{FunSignature, Type},
|
||||||
|
};
|
||||||
|
|
||||||
pub trait CodeGenerator {
|
pub trait CodeGenerator {
|
||||||
/// Return the module name for the code generator.
|
/// Return the module name for the code generator.
|
||||||
fn get_name(&self) -> &str;
|
fn get_name(&self) -> &str;
|
||||||
|
|
|
@ -0,0 +1,162 @@
|
||||||
|
use inkwell::{
|
||||||
|
types::BasicTypeEnum,
|
||||||
|
values::{BasicValueEnum, CallSiteValue, IntValue},
|
||||||
|
AddressSpace, IntPredicate,
|
||||||
|
};
|
||||||
|
use itertools::Either;
|
||||||
|
|
||||||
|
use super::calculate_len_for_slice_range;
|
||||||
|
use crate::codegen::{
|
||||||
|
macros::codegen_unreachable,
|
||||||
|
values::{ArrayLikeValue, ListValue},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// This function handles 'end' **inclusively**.
|
||||||
|
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
|
||||||
|
/// Negative index should be handled before entering this function
|
||||||
|
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ty: BasicTypeEnum<'ctx>,
|
||||||
|
dest_arr: ListValue<'ctx>,
|
||||||
|
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||||
|
src_arr: ListValue<'ctx>,
|
||||||
|
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||||
|
) {
|
||||||
|
let size_ty = generator.get_size_type(ctx.ctx);
|
||||||
|
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||||
|
let int32 = ctx.ctx.i32_type();
|
||||||
|
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
|
||||||
|
let slice_assign_fun = {
|
||||||
|
let ty_vec = vec![
|
||||||
|
int32.into(), // dest start idx
|
||||||
|
int32.into(), // dest end idx
|
||||||
|
int32.into(), // dest step
|
||||||
|
elem_ptr_type.into(), // dest arr ptr
|
||||||
|
int32.into(), // dest arr len
|
||||||
|
int32.into(), // src start idx
|
||||||
|
int32.into(), // src end idx
|
||||||
|
int32.into(), // src step
|
||||||
|
elem_ptr_type.into(), // src arr ptr
|
||||||
|
int32.into(), // src arr len
|
||||||
|
int32.into(), // size
|
||||||
|
];
|
||||||
|
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
|
||||||
|
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
|
||||||
|
ctx.module.add_function(fun_symbol, fn_t, None)
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
let zero = int32.const_zero();
|
||||||
|
let one = int32.const_int(1, false);
|
||||||
|
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
|
||||||
|
let dest_arr_ptr =
|
||||||
|
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
|
||||||
|
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
|
||||||
|
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
|
||||||
|
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
|
||||||
|
let src_arr_ptr =
|
||||||
|
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
|
||||||
|
let src_len = src_arr.load_size(ctx, Some("src.len"));
|
||||||
|
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
|
||||||
|
|
||||||
|
// index in bound and positive should be done
|
||||||
|
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
|
||||||
|
// throw exception if not satisfied
|
||||||
|
let src_end = ctx
|
||||||
|
.builder
|
||||||
|
.build_select(
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
|
||||||
|
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
|
||||||
|
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
|
||||||
|
"final_e",
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let dest_end = ctx
|
||||||
|
.builder
|
||||||
|
.build_select(
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
|
||||||
|
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
|
||||||
|
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
|
||||||
|
"final_e",
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let src_slice_len =
|
||||||
|
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
|
||||||
|
let dest_slice_len =
|
||||||
|
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
|
||||||
|
let src_eq_dest = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
|
||||||
|
.unwrap();
|
||||||
|
let src_slt_dest = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
|
||||||
|
.unwrap();
|
||||||
|
let dest_step_eq_one = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
dest_idx.2,
|
||||||
|
dest_idx.2.get_type().const_int(1, false),
|
||||||
|
"slice_dest_step_eq_one",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
|
||||||
|
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
cond,
|
||||||
|
"0:ValueError",
|
||||||
|
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
|
||||||
|
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let new_len = {
|
||||||
|
let args = vec![
|
||||||
|
dest_idx.0.into(), // dest start idx
|
||||||
|
dest_idx.1.into(), // dest end idx
|
||||||
|
dest_idx.2.into(), // dest step
|
||||||
|
dest_arr_ptr.into(), // dest arr ptr
|
||||||
|
dest_len.into(), // dest arr len
|
||||||
|
src_idx.0.into(), // src start idx
|
||||||
|
src_idx.1.into(), // src end idx
|
||||||
|
src_idx.2.into(), // src step
|
||||||
|
src_arr_ptr.into(), // src arr ptr
|
||||||
|
src_len.into(), // src arr len
|
||||||
|
{
|
||||||
|
let s = match ty {
|
||||||
|
BasicTypeEnum::FloatType(t) => t.size_of(),
|
||||||
|
BasicTypeEnum::IntType(t) => t.size_of(),
|
||||||
|
BasicTypeEnum::PointerType(t) => t.size_of(),
|
||||||
|
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
|
||||||
|
_ => codegen_unreachable!(ctx),
|
||||||
|
};
|
||||||
|
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
];
|
||||||
|
ctx.builder
|
||||||
|
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
};
|
||||||
|
// update length
|
||||||
|
let need_update =
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
|
||||||
|
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||||
|
let update_bb = ctx.ctx.append_basic_block(current, "update");
|
||||||
|
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
|
||||||
|
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
|
||||||
|
ctx.builder.position_at_end(update_bb);
|
||||||
|
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
|
||||||
|
dest_arr.store_size(ctx, generator, new_len);
|
||||||
|
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
||||||
|
ctx.builder.position_at_end(cont_bb);
|
||||||
|
}
|
|
@ -0,0 +1,152 @@
|
||||||
|
use inkwell::{
|
||||||
|
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
||||||
|
IntPredicate,
|
||||||
|
};
|
||||||
|
use itertools::Either;
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
macros::codegen_unreachable,
|
||||||
|
{CodeGenContext, CodeGenerator},
|
||||||
|
};
|
||||||
|
|
||||||
|
// repeated squaring method adapted from GNU Scientific Library:
|
||||||
|
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||||
|
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
base: IntValue<'ctx>,
|
||||||
|
exp: IntValue<'ctx>,
|
||||||
|
signed: bool,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
|
||||||
|
(32, 32, true) => "__nac3_int_exp_int32_t",
|
||||||
|
(64, 64, true) => "__nac3_int_exp_int64_t",
|
||||||
|
(32, 32, false) => "__nac3_int_exp_uint32_t",
|
||||||
|
(64, 64, false) => "__nac3_int_exp_uint64_t",
|
||||||
|
_ => codegen_unreachable!(ctx),
|
||||||
|
};
|
||||||
|
let base_type = base.get_type();
|
||||||
|
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
|
||||||
|
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
|
||||||
|
ctx.module.add_function(symbol, fn_type, None)
|
||||||
|
});
|
||||||
|
// throw exception when exp < 0
|
||||||
|
let ge_zero = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SGE,
|
||||||
|
exp,
|
||||||
|
exp.get_type().const_zero(),
|
||||||
|
"assert_int_pow_ge_0",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ge_zero,
|
||||||
|
"0:ValueError",
|
||||||
|
"integer power must be positive or zero",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
ctx.builder
|
||||||
|
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
|
||||||
|
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
v: FloatValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
|
||||||
|
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
|
||||||
|
ctx.module.add_function("__nac3_isinf", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ret = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "isinf")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
generator.bool_to_i1(ctx, ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
|
||||||
|
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
v: FloatValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
|
||||||
|
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
|
||||||
|
ctx.module.add_function("__nac3_isnan", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ret = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "isnan")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
generator.bool_to_i1(ctx, ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
|
||||||
|
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
ctx.module.add_function("__nac3_gamma", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "gamma")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
|
||||||
|
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
ctx.module.add_function("__nac3_gammaln", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "gammaln")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
|
||||||
|
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
ctx.module.add_function("__nac3_j0", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "j0")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
|
@ -1,27 +1,26 @@
|
||||||
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
|
|
||||||
|
|
||||||
use super::{
|
|
||||||
classes::{ArrayLikeValue, ListValue},
|
|
||||||
macros::codegen_unreachable,
|
|
||||||
model::{function::FnCall, *},
|
|
||||||
object::{
|
|
||||||
list::List,
|
|
||||||
ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray},
|
|
||||||
},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
context::Context,
|
context::Context,
|
||||||
memory_buffer::MemoryBuffer,
|
memory_buffer::MemoryBuffer,
|
||||||
module::Module,
|
module::Module,
|
||||||
types::BasicTypeEnum,
|
values::{BasicValue, BasicValueEnum, IntValue},
|
||||||
values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
IntPredicate,
|
||||||
AddressSpace, IntPredicate,
|
|
||||||
};
|
};
|
||||||
use itertools::Either;
|
|
||||||
use nac3parser::ast::Expr;
|
use nac3parser::ast::Expr;
|
||||||
|
|
||||||
|
use super::{CodeGenContext, CodeGenerator};
|
||||||
|
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
|
||||||
|
pub use list::*;
|
||||||
|
pub use math::*;
|
||||||
|
pub use ndarray::*;
|
||||||
|
pub use slice::*;
|
||||||
|
|
||||||
|
mod list;
|
||||||
|
mod math;
|
||||||
|
mod ndarray;
|
||||||
|
mod slice;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
||||||
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
||||||
|
@ -61,88 +60,6 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
|
||||||
irrt_mod
|
irrt_mod
|
||||||
}
|
}
|
||||||
|
|
||||||
// repeated squaring method adapted from GNU Scientific Library:
|
|
||||||
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
|
||||||
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
base: IntValue<'ctx>,
|
|
||||||
exp: IntValue<'ctx>,
|
|
||||||
signed: bool,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
|
|
||||||
(32, 32, true) => "__nac3_int_exp_int32_t",
|
|
||||||
(64, 64, true) => "__nac3_int_exp_int64_t",
|
|
||||||
(32, 32, false) => "__nac3_int_exp_uint32_t",
|
|
||||||
(64, 64, false) => "__nac3_int_exp_uint64_t",
|
|
||||||
_ => codegen_unreachable!(ctx),
|
|
||||||
};
|
|
||||||
let base_type = base.get_type();
|
|
||||||
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
|
|
||||||
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
|
|
||||||
ctx.module.add_function(symbol, fn_type, None)
|
|
||||||
});
|
|
||||||
// throw exception when exp < 0
|
|
||||||
let ge_zero = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::SGE,
|
|
||||||
exp,
|
|
||||||
exp.get_type().const_zero(),
|
|
||||||
"assert_int_pow_ge_0",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ge_zero,
|
|
||||||
"0:ValueError",
|
|
||||||
"integer power must be positive or zero",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
ctx.builder
|
|
||||||
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
start: IntValue<'ctx>,
|
|
||||||
end: IntValue<'ctx>,
|
|
||||||
step: IntValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
const SYMBOL: &str = "__nac3_range_slice_len";
|
|
||||||
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
|
||||||
let i32_t = ctx.ctx.i32_type();
|
|
||||||
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
|
|
||||||
ctx.module.add_function(SYMBOL, fn_t, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
// assert step != 0, throw exception if not
|
|
||||||
let not_zero = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
|
|
||||||
.unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
not_zero,
|
|
||||||
"0:ValueError",
|
|
||||||
"step must not be zero",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
ctx.builder
|
|
||||||
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
||||||
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
|
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
|
||||||
/// NO numeric slice in python.
|
/// NO numeric slice in python.
|
||||||
|
@ -308,569 +225,3 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// this function allows index out of range, since python
|
|
||||||
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
|
|
||||||
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
|
|
||||||
i: &Expr<Option<Type>>,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &mut G,
|
|
||||||
length: IntValue<'ctx>,
|
|
||||||
) -> Result<Option<IntValue<'ctx>>, String> {
|
|
||||||
const SYMBOL: &str = "__nac3_slice_index_bound";
|
|
||||||
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
|
||||||
let i32_t = ctx.ctx.i32_type();
|
|
||||||
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
|
|
||||||
ctx.module.add_function(SYMBOL, fn_t, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
|
|
||||||
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
|
|
||||||
} else {
|
|
||||||
return Ok(None);
|
|
||||||
};
|
|
||||||
Ok(Some(
|
|
||||||
ctx.builder
|
|
||||||
.build_call(func, &[i.into(), length.into()], "bounded_ind")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This function handles 'end' **inclusively**.
|
|
||||||
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
|
|
||||||
/// Negative index should be handled before entering this function
|
|
||||||
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ty: BasicTypeEnum<'ctx>,
|
|
||||||
dest_arr: ListValue<'ctx>,
|
|
||||||
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
|
||||||
src_arr: ListValue<'ctx>,
|
|
||||||
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
|
||||||
) {
|
|
||||||
let size_ty = generator.get_size_type(ctx.ctx);
|
|
||||||
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
|
||||||
let int32 = ctx.ctx.i32_type();
|
|
||||||
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
|
|
||||||
let slice_assign_fun = {
|
|
||||||
let ty_vec = vec![
|
|
||||||
int32.into(), // dest start idx
|
|
||||||
int32.into(), // dest end idx
|
|
||||||
int32.into(), // dest step
|
|
||||||
elem_ptr_type.into(), // dest arr ptr
|
|
||||||
int32.into(), // dest arr len
|
|
||||||
int32.into(), // src start idx
|
|
||||||
int32.into(), // src end idx
|
|
||||||
int32.into(), // src step
|
|
||||||
elem_ptr_type.into(), // src arr ptr
|
|
||||||
int32.into(), // src arr len
|
|
||||||
int32.into(), // size
|
|
||||||
];
|
|
||||||
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
|
|
||||||
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
|
|
||||||
ctx.module.add_function(fun_symbol, fn_t, None)
|
|
||||||
})
|
|
||||||
};
|
|
||||||
|
|
||||||
let zero = int32.const_zero();
|
|
||||||
let one = int32.const_int(1, false);
|
|
||||||
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
|
|
||||||
let dest_arr_ptr =
|
|
||||||
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
|
|
||||||
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
|
|
||||||
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
|
|
||||||
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
|
|
||||||
let src_arr_ptr =
|
|
||||||
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
|
|
||||||
let src_len = src_arr.load_size(ctx, Some("src.len"));
|
|
||||||
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
|
|
||||||
|
|
||||||
// index in bound and positive should be done
|
|
||||||
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
|
|
||||||
// throw exception if not satisfied
|
|
||||||
let src_end = ctx
|
|
||||||
.builder
|
|
||||||
.build_select(
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
|
|
||||||
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
|
|
||||||
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
|
|
||||||
"final_e",
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let dest_end = ctx
|
|
||||||
.builder
|
|
||||||
.build_select(
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
|
|
||||||
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
|
|
||||||
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
|
|
||||||
"final_e",
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let src_slice_len =
|
|
||||||
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
|
|
||||||
let dest_slice_len =
|
|
||||||
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
|
|
||||||
let src_eq_dest = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
|
|
||||||
.unwrap();
|
|
||||||
let src_slt_dest = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
|
|
||||||
.unwrap();
|
|
||||||
let dest_step_eq_one = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::EQ,
|
|
||||||
dest_idx.2,
|
|
||||||
dest_idx.2.get_type().const_int(1, false),
|
|
||||||
"slice_dest_step_eq_one",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
|
|
||||||
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
cond,
|
|
||||||
"0:ValueError",
|
|
||||||
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
|
|
||||||
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
let new_len = {
|
|
||||||
let args = vec![
|
|
||||||
dest_idx.0.into(), // dest start idx
|
|
||||||
dest_idx.1.into(), // dest end idx
|
|
||||||
dest_idx.2.into(), // dest step
|
|
||||||
dest_arr_ptr.into(), // dest arr ptr
|
|
||||||
dest_len.into(), // dest arr len
|
|
||||||
src_idx.0.into(), // src start idx
|
|
||||||
src_idx.1.into(), // src end idx
|
|
||||||
src_idx.2.into(), // src step
|
|
||||||
src_arr_ptr.into(), // src arr ptr
|
|
||||||
src_len.into(), // src arr len
|
|
||||||
{
|
|
||||||
let s = match ty {
|
|
||||||
BasicTypeEnum::FloatType(t) => t.size_of(),
|
|
||||||
BasicTypeEnum::IntType(t) => t.size_of(),
|
|
||||||
BasicTypeEnum::PointerType(t) => t.size_of(),
|
|
||||||
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
|
|
||||||
_ => codegen_unreachable!(ctx),
|
|
||||||
};
|
|
||||||
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
|
|
||||||
}
|
|
||||||
.into(),
|
|
||||||
];
|
|
||||||
ctx.builder
|
|
||||||
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
// update length
|
|
||||||
let need_update =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
|
|
||||||
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
|
|
||||||
let update_bb = ctx.ctx.append_basic_block(current, "update");
|
|
||||||
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
|
|
||||||
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
|
|
||||||
ctx.builder.position_at_end(update_bb);
|
|
||||||
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
|
|
||||||
dest_arr.store_size(ctx, generator, new_len);
|
|
||||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
|
||||||
ctx.builder.position_at_end(cont_bb);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
|
|
||||||
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
v: FloatValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
|
|
||||||
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
|
|
||||||
ctx.module.add_function("__nac3_isinf", fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
let ret = ctx
|
|
||||||
.builder
|
|
||||||
.build_call(intrinsic_fn, &[v.into()], "isinf")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
generator.bool_to_i1(ctx, ret)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
|
|
||||||
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
v: FloatValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
|
|
||||||
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
|
|
||||||
ctx.module.add_function("__nac3_isnan", fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
let ret = ctx
|
|
||||||
.builder
|
|
||||||
.build_call(intrinsic_fn, &[v.into()], "isnan")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
generator.bool_to_i1(ctx, ret)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
|
|
||||||
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
|
||||||
ctx.module.add_function("__nac3_gamma", fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(intrinsic_fn, &[v.into()], "gamma")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
|
|
||||||
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
|
||||||
ctx.module.add_function("__nac3_gammaln", fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(intrinsic_fn, &[v.into()], "gammaln")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
|
|
||||||
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
|
||||||
ctx.module.add_function("__nac3_j0", fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(intrinsic_fn, &[v.into()], "j0")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
|
|
||||||
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'_, '_>,
|
|
||||||
name: &str,
|
|
||||||
) -> String {
|
|
||||||
let mut name = name.to_owned();
|
|
||||||
match generator.get_size_type(ctx.ctx).get_bit_width() {
|
|
||||||
32 => {}
|
|
||||||
64 => name.push_str("64"),
|
|
||||||
bit_width => {
|
|
||||||
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
name
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndims: Instance<'ctx, Int<SizeT>>,
|
|
||||||
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
"__nac3_ndarray_util_assert_shape_no_negative",
|
|
||||||
);
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(ndims).arg(shape).returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray_ndims: Instance<'ctx, Int<SizeT>>,
|
|
||||||
ndarray_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
output_ndims: Instance<'ctx, Int<SizeT>>,
|
|
||||||
output_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
"__nac3_ndarray_util_assert_output_shape_same",
|
|
||||||
);
|
|
||||||
FnCall::builder(generator, ctx, &name)
|
|
||||||
.arg(ndarray_ndims)
|
|
||||||
.arg(ndarray_shape)
|
|
||||||
.arg(output_ndims)
|
|
||||||
.arg(output_shape)
|
|
||||||
.returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("size")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("nbytes")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("len")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
) -> Instance<'ctx, Int<Bool>> {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("is_c_contiguous")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
index: Instance<'ctx, Int<SizeT>>,
|
|
||||||
) -> Instance<'ctx, Ptr<Int<Byte>>> {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(ndarray).arg(index).returning_auto("pelement")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) -> Instance<'ctx, Ptr<Int<Byte>>> {
|
|
||||||
let name =
|
|
||||||
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(ndarray).arg(indices).returning_auto("pelement")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
) {
|
|
||||||
let name =
|
|
||||||
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(ndarray).returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
|
||||||
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
|
||||||
) -> Instance<'ctx, Int<Bool>> {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_element");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(iter).returning_auto("has_element")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(iter).returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
num_indices: Instance<'ctx, Int<SizeT>>,
|
|
||||||
indices: Instance<'ctx, Ptr<Struct<NDIndex>>>,
|
|
||||||
src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_index");
|
|
||||||
FnCall::builder(generator, ctx, &name)
|
|
||||||
.arg(num_indices)
|
|
||||||
.arg(indices)
|
|
||||||
.arg(src_ndarray)
|
|
||||||
.arg(dst_ndarray)
|
|
||||||
.returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
list: Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>>,
|
|
||||||
ndims: Instance<'ctx, Int<SizeT>>,
|
|
||||||
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
"__nac3_ndarray_array_set_and_validate_list_shape",
|
|
||||||
);
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(list).arg(ndims).arg(shape).returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
list: Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>>,
|
|
||||||
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
"__nac3_ndarray_array_write_list_to_array",
|
|
||||||
);
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
size: Instance<'ctx, Int<SizeT>>,
|
|
||||||
new_ndims: Instance<'ctx, Int<SizeT>>,
|
|
||||||
new_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
"__nac3_ndarray_reshape_resolve_and_check_new_shape",
|
|
||||||
);
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(size).arg(new_ndims).arg(new_shape).returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to");
|
|
||||||
FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
num_shape_entries: Instance<'ctx, Int<SizeT>>,
|
|
||||||
shape_entries: Instance<'ctx, Ptr<Struct<ShapeEntry>>>,
|
|
||||||
dst_ndims: Instance<'ctx, Int<SizeT>>,
|
|
||||||
dst_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes");
|
|
||||||
FnCall::builder(generator, ctx, &name)
|
|
||||||
.arg(num_shape_entries)
|
|
||||||
.arg(shape_entries)
|
|
||||||
.arg(dst_ndims)
|
|
||||||
.arg(dst_shape)
|
|
||||||
.returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
num_axes: Instance<'ctx, Int<SizeT>>,
|
|
||||||
axes: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose");
|
|
||||||
FnCall::builder(generator, ctx, &name)
|
|
||||||
.arg(src_ndarray)
|
|
||||||
.arg(dst_ndarray)
|
|
||||||
.arg(num_axes)
|
|
||||||
.arg(axes)
|
|
||||||
.returning_void();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
a_ndims: Instance<'ctx, Int<SizeT>>,
|
|
||||||
a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
b_ndims: Instance<'ctx, Int<SizeT>>,
|
|
||||||
b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
final_ndims: Instance<'ctx, Int<SizeT>>,
|
|
||||||
new_a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
new_b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
dst_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let name =
|
|
||||||
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
|
|
||||||
FnCall::builder(generator, ctx, &name)
|
|
||||||
.arg(a_ndims)
|
|
||||||
.arg(a_shape)
|
|
||||||
.arg(b_ndims)
|
|
||||||
.arg(b_shape)
|
|
||||||
.arg(final_ndims)
|
|
||||||
.arg(new_a_shape)
|
|
||||||
.arg(new_b_shape)
|
|
||||||
.arg(dst_shape)
|
|
||||||
.returning_void();
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,384 @@
|
||||||
|
use inkwell::{
|
||||||
|
types::IntType,
|
||||||
|
values::{BasicValueEnum, CallSiteValue, IntValue},
|
||||||
|
AddressSpace, IntPredicate,
|
||||||
|
};
|
||||||
|
use itertools::Either;
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
llvm_intrinsics,
|
||||||
|
macros::codegen_unreachable,
|
||||||
|
stmt::gen_for_callback_incrementing,
|
||||||
|
values::{
|
||||||
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, NDArrayValue, TypedArrayLikeAccessor,
|
||||||
|
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
||||||
|
},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
||||||
|
/// calculated total size.
|
||||||
|
///
|
||||||
|
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
|
||||||
|
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
|
||||||
|
/// or [`None`] if starting from the first dimension and ending at the last dimension
|
||||||
|
/// respectively.
|
||||||
|
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
dims: &Dims,
|
||||||
|
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
|
||||||
|
) -> IntValue<'ctx>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Dims: ArrayLikeIndexer<'ctx>,
|
||||||
|
{
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_calc_size",
|
||||||
|
64 => "__nac3_ndarray_calc_size64",
|
||||||
|
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||||
|
};
|
||||||
|
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
|
||||||
|
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
let ndarray_calc_size_fn =
|
||||||
|
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
|
||||||
|
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
|
||||||
|
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_calc_size_fn,
|
||||||
|
&[
|
||||||
|
dims.base_ptr(ctx, generator).into(),
|
||||||
|
dims.size(ctx, generator).into(),
|
||||||
|
begin.into(),
|
||||||
|
end.into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
|
||||||
|
/// containing `i32` indices of the flattened index.
|
||||||
|
///
|
||||||
|
/// * `index` - The index to compute the multidimensional index for.
|
||||||
|
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||||
|
/// `NDArray`.
|
||||||
|
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
index: IntValue<'ctx>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||||
|
let llvm_void = ctx.ctx.void_type();
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_calc_nd_indices",
|
||||||
|
64 => "__nac3_ndarray_calc_nd_indices64",
|
||||||
|
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||||
|
};
|
||||||
|
let ndarray_calc_nd_indices_fn =
|
||||||
|
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_void.fn_type(
|
||||||
|
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
let ndarray_dims = ndarray.shape();
|
||||||
|
|
||||||
|
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_calc_nd_indices_fn,
|
||||||
|
&[
|
||||||
|
index.into(),
|
||||||
|
ndarray_dims.base_ptr(ctx, generator).into(),
|
||||||
|
ndarray_num_dims.into(),
|
||||||
|
indices.into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
TypedArrayLikeAdapter::from(
|
||||||
|
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
|
||||||
|
Box::new(|_, v| v.into_int_value()),
|
||||||
|
Box::new(|_, v| v.into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
indices: &Indices,
|
||||||
|
) -> IntValue<'ctx>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Indices: ArrayLikeIndexer<'ctx>,
|
||||||
|
{
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
debug_assert_eq!(
|
||||||
|
IntType::try_from(indices.element_type(ctx, generator))
|
||||||
|
.map(IntType::get_bit_width)
|
||||||
|
.unwrap_or_default(),
|
||||||
|
llvm_i32.get_bit_width(),
|
||||||
|
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
||||||
|
);
|
||||||
|
debug_assert_eq!(
|
||||||
|
indices.size(ctx, generator).get_type().get_bit_width(),
|
||||||
|
llvm_usize.get_bit_width(),
|
||||||
|
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_flatten_index",
|
||||||
|
64 => "__nac3_ndarray_flatten_index64",
|
||||||
|
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||||
|
};
|
||||||
|
let ndarray_flatten_index_fn =
|
||||||
|
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_usize.fn_type(
|
||||||
|
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
let ndarray_dims = ndarray.shape();
|
||||||
|
|
||||||
|
let index = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_flatten_index_fn,
|
||||||
|
&[
|
||||||
|
ndarray_dims.base_ptr(ctx, generator).into(),
|
||||||
|
ndarray_num_dims.into(),
|
||||||
|
indices.base_ptr(ctx, generator).into(),
|
||||||
|
indices.size(ctx, generator).into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
index
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||||
|
/// multidimensional index.
|
||||||
|
///
|
||||||
|
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||||
|
/// `NDArray`.
|
||||||
|
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||||
|
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
indices: &Index,
|
||||||
|
) -> IntValue<'ctx>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Index: ArrayLikeIndexer<'ctx>,
|
||||||
|
{
|
||||||
|
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
|
||||||
|
/// dimension and size of each dimension of the resultant `ndarray`.
|
||||||
|
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
lhs: NDArrayValue<'ctx>,
|
||||||
|
rhs: NDArrayValue<'ctx>,
|
||||||
|
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_calc_broadcast",
|
||||||
|
64 => "__nac3_ndarray_calc_broadcast64",
|
||||||
|
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||||
|
};
|
||||||
|
let ndarray_calc_broadcast_fn =
|
||||||
|
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_usize.fn_type(
|
||||||
|
&[
|
||||||
|
llvm_pusize.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_pusize.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_pusize.into(),
|
||||||
|
],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let lhs_ndims = lhs.load_ndims(ctx);
|
||||||
|
let rhs_ndims = rhs.load_ndims(ctx);
|
||||||
|
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(min_ndims, false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
||||||
|
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
||||||
|
(
|
||||||
|
lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
|
||||||
|
rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let llvm_usize_const_one = llvm_usize.const_int(1, false);
|
||||||
|
let lhs_eqz = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
|
||||||
|
.unwrap();
|
||||||
|
let rhs_eqz = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
|
||||||
|
.unwrap();
|
||||||
|
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
|
||||||
|
|
||||||
|
let lhs_eq_rhs = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
is_compatible,
|
||||||
|
"0:ValueError",
|
||||||
|
"operands could not be broadcast together",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
|
||||||
|
let lhs_dims = lhs.shape().base_ptr(ctx, generator);
|
||||||
|
let lhs_ndims = lhs.load_ndims(ctx);
|
||||||
|
let rhs_dims = rhs.shape().base_ptr(ctx, generator);
|
||||||
|
let rhs_ndims = rhs.load_ndims(ctx);
|
||||||
|
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
|
||||||
|
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_calc_broadcast_fn,
|
||||||
|
&[
|
||||||
|
lhs_dims.into(),
|
||||||
|
lhs_ndims.into(),
|
||||||
|
rhs_dims.into(),
|
||||||
|
rhs_ndims.into(),
|
||||||
|
out_dims.base_ptr(ctx, generator).into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
TypedArrayLikeAdapter::from(
|
||||||
|
out_dims,
|
||||||
|
Box::new(|_, v| v.into_int_value()),
|
||||||
|
Box::new(|_, v| v.into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
|
||||||
|
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
|
||||||
|
/// array `broadcast_idx`.
|
||||||
|
pub fn call_ndarray_calc_broadcast_index<
|
||||||
|
'ctx,
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
|
||||||
|
>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
array: NDArrayValue<'ctx>,
|
||||||
|
broadcast_idx: &BroadcastIdx,
|
||||||
|
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_calc_broadcast_idx",
|
||||||
|
64 => "__nac3_ndarray_calc_broadcast_idx64",
|
||||||
|
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||||
|
};
|
||||||
|
let ndarray_calc_broadcast_fn =
|
||||||
|
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_usize.fn_type(
|
||||||
|
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let broadcast_size = broadcast_idx.size(ctx, generator);
|
||||||
|
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
||||||
|
|
||||||
|
let array_dims = array.shape().base_ptr(ctx, generator);
|
||||||
|
let array_ndims = array.load_ndims(ctx);
|
||||||
|
let broadcast_idx_ptr = unsafe {
|
||||||
|
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
};
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_calc_broadcast_fn,
|
||||||
|
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
TypedArrayLikeAdapter::from(
|
||||||
|
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
|
||||||
|
Box::new(|_, v| v.into_int_value()),
|
||||||
|
Box::new(|_, v| v.into()),
|
||||||
|
)
|
||||||
|
}
|
|
@ -0,0 +1,76 @@
|
||||||
|
use inkwell::{
|
||||||
|
values::{BasicValueEnum, CallSiteValue, IntValue},
|
||||||
|
IntPredicate,
|
||||||
|
};
|
||||||
|
use itertools::Either;
|
||||||
|
use nac3parser::ast::Expr;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{CodeGenContext, CodeGenerator},
|
||||||
|
typecheck::typedef::Type,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// this function allows index out of range, since python
|
||||||
|
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
|
||||||
|
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
|
||||||
|
i: &Expr<Option<Type>>,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
length: IntValue<'ctx>,
|
||||||
|
) -> Result<Option<IntValue<'ctx>>, String> {
|
||||||
|
const SYMBOL: &str = "__nac3_slice_index_bound";
|
||||||
|
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
||||||
|
let i32_t = ctx.ctx.i32_type();
|
||||||
|
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
|
||||||
|
ctx.module.add_function(SYMBOL, fn_t, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
|
||||||
|
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
|
||||||
|
} else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
Ok(Some(
|
||||||
|
ctx.builder
|
||||||
|
.build_call(func, &[i.into(), length.into()], "bounded_ind")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
start: IntValue<'ctx>,
|
||||||
|
end: IntValue<'ctx>,
|
||||||
|
step: IntValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
const SYMBOL: &str = "__nac3_range_slice_len";
|
||||||
|
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
||||||
|
let i32_t = ctx.ctx.i32_type();
|
||||||
|
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
|
||||||
|
ctx.module.add_function(SYMBOL, fn_t, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
// assert step != 0, throw exception if not
|
||||||
|
let not_zero = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
|
||||||
|
.unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
not_zero,
|
||||||
|
"0:ValueError",
|
||||||
|
"step must not be zero",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
ctx.builder
|
||||||
|
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
|
@ -1,12 +1,14 @@
|
||||||
use crate::codegen::CodeGenContext;
|
use inkwell::{
|
||||||
use inkwell::context::Context;
|
context::Context,
|
||||||
use inkwell::intrinsics::Intrinsic;
|
intrinsics::Intrinsic,
|
||||||
use inkwell::types::AnyTypeEnum::IntType;
|
types::{AnyTypeEnum::IntType, FloatType},
|
||||||
use inkwell::types::FloatType;
|
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
|
||||||
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
|
AddressSpace,
|
||||||
use inkwell::AddressSpace;
|
};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
|
|
||||||
|
use super::CodeGenContext;
|
||||||
|
|
||||||
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
|
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
|
||||||
/// functions.
|
/// functions.
|
||||||
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
|
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
|
||||||
|
@ -183,7 +185,7 @@ pub fn call_memcpy_generic<'ctx>(
|
||||||
dest
|
dest
|
||||||
} else {
|
} else {
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_bitcast(dest, llvm_p0i8, "")
|
.build_bit_cast(dest, llvm_p0i8, "")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
|
@ -191,7 +193,7 @@ pub fn call_memcpy_generic<'ctx>(
|
||||||
src
|
src
|
||||||
} else {
|
} else {
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_bitcast(src, llvm_p0i8, "")
|
.build_bit_cast(src, llvm_p0i8, "")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
use crate::{
|
use std::{
|
||||||
codegen::classes::{ListType, ProxyType, RangeType},
|
collections::{HashMap, HashSet},
|
||||||
symbol_resolver::{StaticValue, SymbolResolver},
|
sync::{
|
||||||
toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef},
|
atomic::{AtomicBool, Ordering},
|
||||||
typecheck::{
|
Arc,
|
||||||
type_inferencer::{CodeLocation, PrimitiveStore},
|
|
||||||
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
|
||||||
},
|
},
|
||||||
|
thread,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crossbeam::channel::{unbounded, Receiver, Sender};
|
use crossbeam::channel::{unbounded, Receiver, Sender};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
|
@ -24,36 +24,37 @@ use inkwell::{
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use model::*;
|
|
||||||
use nac3parser::ast::{Location, Stmt, StrRef};
|
|
||||||
use object::ndarray::NDArray;
|
|
||||||
use parking_lot::{Condvar, Mutex};
|
use parking_lot::{Condvar, Mutex};
|
||||||
use std::collections::{HashMap, HashSet};
|
|
||||||
use std::sync::{
|
use nac3parser::ast::{Location, Stmt, StrRef};
|
||||||
atomic::{AtomicBool, Ordering},
|
|
||||||
Arc,
|
use crate::{
|
||||||
|
symbol_resolver::{StaticValue, SymbolResolver},
|
||||||
|
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
|
||||||
|
typecheck::{
|
||||||
|
type_inferencer::{CodeLocation, PrimitiveStore},
|
||||||
|
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use std::thread;
|
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
||||||
|
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
||||||
|
use types::{ListType, NDArrayType, ProxyType, RangeType};
|
||||||
|
|
||||||
pub mod builtin_fns;
|
pub mod builtin_fns;
|
||||||
pub mod classes;
|
|
||||||
pub mod concrete_type;
|
pub mod concrete_type;
|
||||||
pub mod expr;
|
pub mod expr;
|
||||||
pub mod extern_fns;
|
pub mod extern_fns;
|
||||||
mod generator;
|
mod generator;
|
||||||
pub mod irrt;
|
pub mod irrt;
|
||||||
pub mod llvm_intrinsics;
|
pub mod llvm_intrinsics;
|
||||||
pub mod model;
|
|
||||||
pub mod numpy;
|
pub mod numpy;
|
||||||
pub mod object;
|
|
||||||
pub mod stmt;
|
pub mod stmt;
|
||||||
|
pub mod types;
|
||||||
|
pub mod values;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
|
||||||
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
|
||||||
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
|
||||||
|
|
||||||
mod macros {
|
mod macros {
|
||||||
/// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
|
/// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
|
||||||
/// its first argument to provide Python source information to indicate the codegen location
|
/// its first argument to provide Python source information to indicate the codegen location
|
||||||
|
@ -509,7 +510,12 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
}
|
}
|
||||||
|
|
||||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
Ptr(Struct(NDArray)).llvm_type(generator, ctx).as_basic_type_enum()
|
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
|
||||||
|
let element_type = get_llvm_type(
|
||||||
|
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
||||||
|
);
|
||||||
|
|
||||||
|
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => unreachable!(
|
_ => unreachable!(
|
||||||
|
@ -847,10 +853,9 @@ pub fn gen_func_impl<
|
||||||
builder.position_at_end(init_bb);
|
builder.position_at_end(init_bb);
|
||||||
let body_bb = context.append_basic_block(fn_val, "body");
|
let body_bb = context.append_basic_block(fn_val, "body");
|
||||||
|
|
||||||
|
// Store non-vararg argument values into local variables
|
||||||
let mut var_assignment = HashMap::new();
|
let mut var_assignment = HashMap::new();
|
||||||
let offset = u32::from(has_sret);
|
let offset = u32::from(has_sret);
|
||||||
|
|
||||||
// Store non-vararg argument values into local variables
|
|
||||||
for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) {
|
for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) {
|
||||||
let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
|
let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
|
||||||
let local_type = get_llvm_type(
|
let local_type = get_llvm_type(
|
||||||
|
|
|
@ -1,42 +0,0 @@
|
||||||
use inkwell::{
|
|
||||||
context::Context,
|
|
||||||
types::{BasicType, BasicTypeEnum},
|
|
||||||
values::BasicValueEnum,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::codegen::CodeGenerator;
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
/// A [`Model`] of any [`BasicTypeEnum`].
|
|
||||||
///
|
|
||||||
/// Use this when it is infeasible to use model abstractions.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>);
|
|
||||||
|
|
||||||
impl<'ctx> Model<'ctx> for Any<'ctx> {
|
|
||||||
type Value = BasicValueEnum<'ctx>;
|
|
||||||
type Type = BasicTypeEnum<'ctx>;
|
|
||||||
|
|
||||||
fn llvm_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
_ctx: &'ctx Context,
|
|
||||||
) -> Self::Type {
|
|
||||||
self.0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &mut G,
|
|
||||||
_ctx: &'ctx Context,
|
|
||||||
ty: T,
|
|
||||||
) -> Result<(), ModelError> {
|
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
if ty == self.0 {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(ModelError(format!("Expecting {}, but got {}", self.0, ty)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,147 +0,0 @@
|
||||||
use std::fmt;
|
|
||||||
|
|
||||||
use inkwell::{
|
|
||||||
context::Context,
|
|
||||||
types::{ArrayType, BasicType, BasicTypeEnum},
|
|
||||||
values::{ArrayValue, IntValue},
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
/// Trait for Rust structs identifying length values for [`Array`].
|
|
||||||
pub trait ArrayLen: fmt::Debug + Clone + Copy {
|
|
||||||
fn length(&self) -> u32;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A statically known length.
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Len<const N: u32>;
|
|
||||||
|
|
||||||
/// A dynamically known length.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct AnyLen(pub u32);
|
|
||||||
|
|
||||||
impl<const N: u32> ArrayLen for Len<N> {
|
|
||||||
fn length(&self) -> u32 {
|
|
||||||
N
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ArrayLen for AnyLen {
|
|
||||||
fn length(&self) -> u32 {
|
|
||||||
self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A Model for an [`ArrayType`].
|
|
||||||
///
|
|
||||||
/// `Len` should be of a [`LenKind`] and `Item` should be a of [`Model`].
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Array<Len, Item> {
|
|
||||||
/// Length of this array.
|
|
||||||
pub len: Len,
|
|
||||||
/// [`Model`] of the array items.
|
|
||||||
pub item: Item,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, Len: ArrayLen, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
|
|
||||||
type Value = ArrayValue<'ctx>;
|
|
||||||
type Type = ArrayType<'ctx>;
|
|
||||||
|
|
||||||
fn llvm_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Self::Type {
|
|
||||||
self.item.llvm_type(generator, ctx).array_type(self.len.length())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
ty: T,
|
|
||||||
) -> Result<(), ModelError> {
|
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
let BasicTypeEnum::ArrayType(ty) = ty else {
|
|
||||||
return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}")));
|
|
||||||
};
|
|
||||||
|
|
||||||
if ty.len() != self.len.length() {
|
|
||||||
return Err(ModelError(format!(
|
|
||||||
"Expecting ArrayType with size {}, but got an ArrayType with size {}",
|
|
||||||
ty.len(),
|
|
||||||
self.len.length()
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
self.item
|
|
||||||
.check_type(generator, ctx, ty.get_element_type())
|
|
||||||
.map_err(|err| err.under_context("an ArrayType"))?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, Len: ArrayLen, Item: Model<'ctx>> Instance<'ctx, Ptr<Array<Len, Item>>> {
|
|
||||||
/// Get the pointer to the `i`-th (0-based) array element.
|
|
||||||
pub fn gep(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
i: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Ptr<Item>> {
|
|
||||||
let zero = ctx.ctx.i32_type().const_zero();
|
|
||||||
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], "").unwrap() };
|
|
||||||
|
|
||||||
unsafe { Ptr(self.model.0.item).believe_value(ptr) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Like `gep` but `i` is a constant.
|
|
||||||
pub fn gep_const(&self, ctx: &CodeGenContext<'ctx, '_>, i: u64) -> Instance<'ctx, Ptr<Item>> {
|
|
||||||
assert!(
|
|
||||||
i < u64::from(self.model.0.len.length()),
|
|
||||||
"Index {i} is out of bounds. Array length = {}",
|
|
||||||
self.model.0.len.length()
|
|
||||||
);
|
|
||||||
|
|
||||||
let i = ctx.ctx.i32_type().const_int(i, false);
|
|
||||||
self.gep(ctx, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience function equivalent to `.gep(...).load(...)`.
|
|
||||||
pub fn get<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
i: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Item> {
|
|
||||||
self.gep(ctx, i).load(generator, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Like `get` but `i` is a constant.
|
|
||||||
pub fn get_const<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
i: u64,
|
|
||||||
) -> Instance<'ctx, Item> {
|
|
||||||
self.gep_const(ctx, i).load(generator, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience function equivalent to `.gep(...).store(...)`.
|
|
||||||
pub fn set(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
i: IntValue<'ctx>,
|
|
||||||
value: Instance<'ctx, Item>,
|
|
||||||
) {
|
|
||||||
self.gep(ctx, i).store(ctx, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Like `set` but `i` is a constant.
|
|
||||||
pub fn set_const(&self, ctx: &CodeGenContext<'ctx, '_>, i: u64, value: Instance<'ctx, Item>) {
|
|
||||||
self.gep_const(ctx, i).store(ctx, value);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,207 +0,0 @@
|
||||||
use std::fmt;
|
|
||||||
|
|
||||||
use inkwell::{context::Context, types::*, values::*};
|
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
|
||||||
|
|
||||||
/// A error type for reporting any [`Model`]-related error (e.g., a [`BasicType`] mismatch).
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ModelError(pub String);
|
|
||||||
|
|
||||||
impl ModelError {
|
|
||||||
/// Append a context message to the error.
|
|
||||||
pub(super) fn under_context(mut self, context: &str) -> Self {
|
|
||||||
self.0.push_str(" ... in ");
|
|
||||||
self.0.push_str(context);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait for Rust structs identifying [`BasicType`]s in the context of a known [`CodeGenerator`] and [`CodeGenContext`].
|
|
||||||
///
|
|
||||||
/// For instance,
|
|
||||||
/// - [`Int<Int32>`] identifies an [`IntType`] with 32-bits.
|
|
||||||
/// - [`Int<SizeT>`] identifies an [`IntType`] with bit-width [`CodeGenerator::get_size_type`].
|
|
||||||
/// - [`Ptr<Int<SizeT>>`] identifies a [`PointerType`] that points to an [`IntType`] with bit-width [`CodeGenerator::get_size_type`].
|
|
||||||
/// - [`Int<AnyInt>`] identifies an [`IntType`] with bit-width of whatever is set in the [`AnyInt`] object.
|
|
||||||
/// - [`Any`] identifies a [`BasicType`] set in the [`Any`] object itself.
|
|
||||||
///
|
|
||||||
/// You can get the [`BasicType`] out of a model with [`Model::get_type`].
|
|
||||||
///
|
|
||||||
/// Furthermore, [`Instance<'ctx, M>`] is a simple structure that carries a [`BasicValue`] with [`BasicType`] identified by model `M`.
|
|
||||||
///
|
|
||||||
/// The main purpose of this abstraction is to have a more Rust type-safe way to use Inkwell and give type-hints for programmers.
|
|
||||||
///
|
|
||||||
/// ### Notes on `Default` trait
|
|
||||||
///
|
|
||||||
/// For some models like [`Int<Int32>`] or [`Int<SizeT>`], they have a [`Default`] trait since just by looking at their types, it is possible
|
|
||||||
/// to tell the [`BasicType`]s they are identifying.
|
|
||||||
///
|
|
||||||
/// This can be used to create strongly-typed interfaces accepting only values of a specific [`BasicType`] without having to worry about
|
|
||||||
/// writing debug assertions to check, for example, if the programmer has passed in an [`IntValue`] with the wrong bit-width.
|
|
||||||
/// ```ignore
|
|
||||||
/// fn give_me_i32_and_get_a_size_t_back<'ctx>(i32: Instance<'ctx, Int<Int32>>) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
/// // code...
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// ### Notes on converting between Inkwell and model/ge.
|
|
||||||
///
|
|
||||||
/// Suppose you have an [`IntValue`], and you want to pass it into a function that takes a [`Instance<'ctx, Int<Int32>>`]. You can do use
|
|
||||||
/// [`Model::check_value`] or [`Model::believe_value`].
|
|
||||||
/// ```ignore
|
|
||||||
/// let my_value: IntValue<'ctx>;
|
|
||||||
///
|
|
||||||
/// let my_value = Int(Int32).check_value(my_value).unwrap(); // Panics if `my_value` is not 32-bit with a descriptive error message.
|
|
||||||
///
|
|
||||||
/// // or, if you are absolutely certain that `my_value` is 32-bit and doing extra checks is a waste of time:
|
|
||||||
/// let my_value = Int(Int32).believe_value(my_value);
|
|
||||||
/// ```
|
|
||||||
pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
|
|
||||||
/// The [`BasicType`] *variant* this model is identifying.
|
|
||||||
type Type: BasicType<'ctx>;
|
|
||||||
|
|
||||||
/// The [`BasicValue`] type of the [`BasicType`] of this model.
|
|
||||||
type Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>;
|
|
||||||
|
|
||||||
/// Return the [`BasicType`] of this model.
|
|
||||||
#[must_use]
|
|
||||||
fn llvm_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context)
|
|
||||||
-> Self::Type;
|
|
||||||
|
|
||||||
/// Get the number of bytes of the [`BasicType`] of this model.
|
|
||||||
fn size_of<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
self.llvm_type(generator, ctx).size_of().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if a [`BasicType`] matches the [`BasicType`] of this model.
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
ty: T,
|
|
||||||
) -> Result<(), ModelError>;
|
|
||||||
|
|
||||||
/// Create an instance from a value.
|
|
||||||
///
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// Caller must make sure the type of `value` and the type of this `model` are equivalent.
|
|
||||||
#[must_use]
|
|
||||||
unsafe fn believe_value(&self, value: Self::Value) -> Instance<'ctx, Self> {
|
|
||||||
Instance { model: *self, value }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if a [`BasicValue`]'s type is equivalent to the type of this model.
|
|
||||||
/// Wrap the [`BasicValue`] into an [`Instance`] if it is.
|
|
||||||
fn check_value<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
value: V,
|
|
||||||
) -> Result<Instance<'ctx, Self>, ModelError> {
|
|
||||||
let value = value.as_basic_value_enum();
|
|
||||||
self.check_type(generator, ctx, value.get_type())
|
|
||||||
.map_err(|err| err.under_context(format!("the value {value:?}").as_str()))?;
|
|
||||||
|
|
||||||
let Ok(value) = Self::Value::try_from(value) else {
|
|
||||||
unreachable!("check_type() has bad implementation")
|
|
||||||
};
|
|
||||||
unsafe { Ok(self.believe_value(value)) }
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allocate a value on the stack and return its pointer.
|
|
||||||
fn alloca<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Ptr<Self>> {
|
|
||||||
let p = ctx.builder.build_alloca(self.llvm_type(generator, ctx.ctx), "").unwrap();
|
|
||||||
unsafe { Ptr(*self).believe_value(p) }
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allocate an array on the stack and return its pointer.
|
|
||||||
fn array_alloca<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
len: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Ptr<Self>> {
|
|
||||||
let p =
|
|
||||||
ctx.builder.build_array_alloca(self.llvm_type(generator, ctx.ctx), len, "").unwrap();
|
|
||||||
unsafe { Ptr(*self).believe_value(p) }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn var_alloca<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> Result<Instance<'ctx, Ptr<Self>>, String> {
|
|
||||||
let ty = self.llvm_type(generator, ctx.ctx).as_basic_type_enum();
|
|
||||||
let p = generator.gen_var_alloc(ctx, ty, name)?;
|
|
||||||
unsafe { Ok(Ptr(*self).believe_value(p)) }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn array_var_alloca<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
len: IntValue<'ctx>,
|
|
||||||
name: Option<&'ctx str>,
|
|
||||||
) -> Result<Instance<'ctx, Ptr<Self>>, String> {
|
|
||||||
// TODO: Remove ArraySliceValue
|
|
||||||
let ty = self.llvm_type(generator, ctx.ctx).as_basic_type_enum();
|
|
||||||
let p = generator.gen_array_var_alloc(ctx, ty, len, name)?;
|
|
||||||
unsafe { Ok(Ptr(*self).believe_value(PointerValue::from(p))) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Allocate a constant array.
|
|
||||||
fn const_array<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
values: &[Instance<'ctx, Self>],
|
|
||||||
) -> Instance<'ctx, Array<AnyLen, Self>> {
|
|
||||||
macro_rules! make {
|
|
||||||
($t:expr, $into_value:expr) => {
|
|
||||||
$t.const_array(
|
|
||||||
&values
|
|
||||||
.iter()
|
|
||||||
.map(|x| $into_value(x.value.as_basic_value_enum()))
|
|
||||||
.collect_vec(),
|
|
||||||
)
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
let value = match self.llvm_type(generator, ctx).as_basic_type_enum() {
|
|
||||||
BasicTypeEnum::ArrayType(t) => make!(t, BasicValueEnum::into_array_value),
|
|
||||||
BasicTypeEnum::IntType(t) => make!(t, BasicValueEnum::into_int_value),
|
|
||||||
BasicTypeEnum::FloatType(t) => make!(t, BasicValueEnum::into_float_value),
|
|
||||||
BasicTypeEnum::PointerType(t) => make!(t, BasicValueEnum::into_pointer_value),
|
|
||||||
BasicTypeEnum::StructType(t) => make!(t, BasicValueEnum::into_struct_value),
|
|
||||||
BasicTypeEnum::VectorType(t) => make!(t, BasicValueEnum::into_vector_value),
|
|
||||||
};
|
|
||||||
|
|
||||||
Array { len: AnyLen(values.len() as u32), item: *self }
|
|
||||||
.check_value(generator, ctx, value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct Instance<'ctx, M: Model<'ctx>> {
|
|
||||||
/// The model of this instance.
|
|
||||||
pub model: M,
|
|
||||||
|
|
||||||
/// The value of this instance.
|
|
||||||
///
|
|
||||||
/// It is guaranteed the [`BasicType`] of `value` is consistent with that of `model`.
|
|
||||||
pub value: M::Value,
|
|
||||||
}
|
|
|
@ -1,94 +0,0 @@
|
||||||
use std::fmt;
|
|
||||||
|
|
||||||
use inkwell::{
|
|
||||||
context::Context,
|
|
||||||
types::{BasicType, FloatType},
|
|
||||||
values::FloatValue,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::codegen::CodeGenerator;
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy {
|
|
||||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> FloatType<'ctx>;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Float32;
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Float64;
|
|
||||||
|
|
||||||
impl<'ctx> FloatKind<'ctx> for Float32 {
|
|
||||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> FloatType<'ctx> {
|
|
||||||
ctx.f32_type()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> FloatKind<'ctx> for Float64 {
|
|
||||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> FloatType<'ctx> {
|
|
||||||
ctx.f64_type()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct AnyFloat<'ctx>(FloatType<'ctx>);
|
|
||||||
|
|
||||||
impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
|
|
||||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
_ctx: &'ctx Context,
|
|
||||||
) -> FloatType<'ctx> {
|
|
||||||
self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Float<N>(pub N);
|
|
||||||
|
|
||||||
impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float<N> {
|
|
||||||
type Value = FloatValue<'ctx>;
|
|
||||||
type Type = FloatType<'ctx>;
|
|
||||||
|
|
||||||
fn llvm_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Self::Type {
|
|
||||||
self.0.get_float_type(generator, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
ty: T,
|
|
||||||
) -> Result<(), ModelError> {
|
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
let Ok(ty) = FloatType::try_from(ty) else {
|
|
||||||
return Err(ModelError(format!("Expecting FloatType, but got {ty:?}")));
|
|
||||||
};
|
|
||||||
|
|
||||||
let exp_ty = self.0.get_float_type(generator, ctx);
|
|
||||||
|
|
||||||
// TODO: Inkwell does not have get_bit_width for FloatType?
|
|
||||||
if ty != exp_ty {
|
|
||||||
return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}")));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,122 +0,0 @@
|
||||||
use inkwell::{
|
|
||||||
attributes::{Attribute, AttributeLoc},
|
|
||||||
types::{BasicMetadataTypeEnum, BasicType, FunctionType},
|
|
||||||
values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue},
|
|
||||||
};
|
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
struct Arg<'ctx> {
|
|
||||||
ty: BasicMetadataTypeEnum<'ctx>,
|
|
||||||
val: BasicMetadataValueEnum<'ctx>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A convenience structure to construct & call an LLVM function.
|
|
||||||
///
|
|
||||||
/// ### Usage
|
|
||||||
///
|
|
||||||
/// The syntax is like this:
|
|
||||||
/// ```ignore
|
|
||||||
/// let result = CallFunction::begin("my_function_name")
|
|
||||||
/// .attrs(...)
|
|
||||||
/// .arg(arg1)
|
|
||||||
/// .arg(arg2)
|
|
||||||
/// .arg(arg3)
|
|
||||||
/// .returning("my_function_result", Int32);
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// The function `my_function_name` is called when `.returning()` (or its variants) is called, returning
|
|
||||||
/// the result as an `Instance<'ctx, Int<Int32>>`.
|
|
||||||
///
|
|
||||||
/// If `my_function_name` has not been declared in `ctx.module`, once `.returning()` is called, a function
|
|
||||||
/// declaration of `my_function_name` is added to `ctx.module`, where the [`FunctionType`] is deduced from
|
|
||||||
/// the argument types and returning type.
|
|
||||||
pub struct FnCall<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> {
|
|
||||||
generator: &'d mut G,
|
|
||||||
ctx: &'b CodeGenContext<'ctx, 'a>,
|
|
||||||
/// Function name
|
|
||||||
name: &'c str,
|
|
||||||
/// Call arguments
|
|
||||||
args: Vec<Arg<'ctx>>,
|
|
||||||
/// LLVM function Attributes
|
|
||||||
attrs: Vec<&'static str>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> FnCall<'ctx, 'a, 'b, 'c, 'd, G> {
|
|
||||||
pub fn builder(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self {
|
|
||||||
FnCall { generator, ctx, name, args: Vec::new(), attrs: Vec::new() }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Push a list of LLVM function attributes to the function declaration.
|
|
||||||
#[must_use]
|
|
||||||
pub fn attrs(mut self, attrs: Vec<&'static str>) -> Self {
|
|
||||||
self.attrs = attrs;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Push a call argument to the function call.
|
|
||||||
#[allow(clippy::needless_pass_by_value)]
|
|
||||||
#[must_use]
|
|
||||||
pub fn arg<M: Model<'ctx>>(mut self, arg: Instance<'ctx, M>) -> Self {
|
|
||||||
let arg = Arg {
|
|
||||||
ty: arg.model.llvm_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(),
|
|
||||||
val: arg.value.as_basic_value_enum().into(),
|
|
||||||
};
|
|
||||||
self.args.push(arg);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Call the function and expect the function to return a value of type of `return_model`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn returning<M: Model<'ctx>>(self, name: &str, return_model: M) -> Instance<'ctx, M> {
|
|
||||||
let ret_ty = return_model.llvm_type(self.generator, self.ctx.ctx);
|
|
||||||
|
|
||||||
let ret = self.call(|tys| ret_ty.fn_type(tys, false), name);
|
|
||||||
let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work
|
|
||||||
let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work
|
|
||||||
ret
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Like [`CallFunction::returning_`] but `return_model` is automatically inferred.
|
|
||||||
#[must_use]
|
|
||||||
pub fn returning_auto<M: Model<'ctx> + Default>(self, name: &str) -> Instance<'ctx, M> {
|
|
||||||
self.returning(name, M::default())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Call the function and expect the function to return a void-type.
|
|
||||||
pub fn returning_void(self) {
|
|
||||||
let ret_ty = self.ctx.ctx.void_type();
|
|
||||||
|
|
||||||
let _ = self.call(|tys| ret_ty.fn_type(tys, false), "");
|
|
||||||
}
|
|
||||||
|
|
||||||
fn call<F>(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx>
|
|
||||||
where
|
|
||||||
F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>,
|
|
||||||
{
|
|
||||||
// Get the LLVM function.
|
|
||||||
let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| {
|
|
||||||
// Declare the function if it doesn't exist.
|
|
||||||
let tys = self.args.iter().map(|arg| arg.ty).collect_vec();
|
|
||||||
|
|
||||||
let func_type = make_fn_type(&tys);
|
|
||||||
let func = self.ctx.module.add_function(self.name, func_type, None);
|
|
||||||
|
|
||||||
for attr in &self.attrs {
|
|
||||||
func.add_attribute(
|
|
||||||
AttributeLoc::Function,
|
|
||||||
self.ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
func
|
|
||||||
});
|
|
||||||
|
|
||||||
let vals = self.args.iter().map(|arg| arg.val).collect_vec();
|
|
||||||
self.ctx.builder.build_call(func, &vals, return_value_name).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,422 +0,0 @@
|
||||||
use std::{cmp::Ordering, fmt};
|
|
||||||
|
|
||||||
use inkwell::{
|
|
||||||
context::Context,
|
|
||||||
types::{BasicType, IntType},
|
|
||||||
values::IntValue,
|
|
||||||
IntPredicate,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
pub trait IntKind<'ctx>: fmt::Debug + Clone + Copy {
|
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx>;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Bool;
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Byte;
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Int32;
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Int64;
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct SizeT;
|
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for Bool {
|
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
ctx.bool_type()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for Byte {
|
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
ctx.i8_type()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for Int32 {
|
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
ctx.i32_type()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for Int64 {
|
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
ctx.i64_type()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for SizeT {
|
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
generator.get_size_type(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct AnyInt<'ctx>(pub IntType<'ctx>);
|
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> {
|
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
_ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Int<N>(pub N);
|
|
||||||
|
|
||||||
impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int<N> {
|
|
||||||
type Value = IntValue<'ctx>;
|
|
||||||
type Type = IntType<'ctx>;
|
|
||||||
|
|
||||||
fn llvm_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Self::Type {
|
|
||||||
self.0.get_int_type(generator, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
ty: T,
|
|
||||||
) -> Result<(), ModelError> {
|
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
let Ok(ty) = IntType::try_from(ty) else {
|
|
||||||
return Err(ModelError(format!("Expecting IntType, but got {ty:?}")));
|
|
||||||
};
|
|
||||||
|
|
||||||
let exp_ty = self.0.get_int_type(generator, ctx);
|
|
||||||
if ty.get_bit_width() != exp_ty.get_bit_width() {
|
|
||||||
return Err(ModelError(format!(
|
|
||||||
"Expecting IntType to have {} bit(s), but got {} bit(s)",
|
|
||||||
exp_ty.get_bit_width(),
|
|
||||||
ty.get_bit_width()
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
|
||||||
pub fn const_int<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
value: u64,
|
|
||||||
sign_extend: bool,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
let value = self.llvm_type(generator, ctx).const_int(value, sign_extend);
|
|
||||||
unsafe { self.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn const_0<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
let value = self.llvm_type(generator, ctx).const_zero();
|
|
||||||
unsafe { self.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn const_1<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
self.const_int(generator, ctx, 1, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn const_all_ones<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
let value = self.llvm_type(generator, ctx).const_all_ones();
|
|
||||||
unsafe { self.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn s_extend_or_bit_cast<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
assert!(
|
|
||||||
value.get_type().get_bit_width()
|
|
||||||
<= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
|
||||||
);
|
|
||||||
let value = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_s_extend_or_bit_cast(value, self.llvm_type(generator, ctx.ctx), "")
|
|
||||||
.unwrap();
|
|
||||||
unsafe { self.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn s_extend<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
assert!(
|
|
||||||
value.get_type().get_bit_width()
|
|
||||||
< self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
|
||||||
);
|
|
||||||
let value =
|
|
||||||
ctx.builder.build_int_s_extend(value, self.llvm_type(generator, ctx.ctx), "").unwrap();
|
|
||||||
unsafe { self.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn z_extend_or_bit_cast<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
assert!(
|
|
||||||
value.get_type().get_bit_width()
|
|
||||||
<= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
|
||||||
);
|
|
||||||
let value = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_z_extend_or_bit_cast(value, self.llvm_type(generator, ctx.ctx), "")
|
|
||||||
.unwrap();
|
|
||||||
unsafe { self.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn z_extend<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
assert!(
|
|
||||||
value.get_type().get_bit_width()
|
|
||||||
< self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
|
||||||
);
|
|
||||||
let value =
|
|
||||||
ctx.builder.build_int_z_extend(value, self.llvm_type(generator, ctx.ctx), "").unwrap();
|
|
||||||
unsafe { self.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn truncate_or_bit_cast<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
assert!(
|
|
||||||
value.get_type().get_bit_width()
|
|
||||||
>= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
|
||||||
);
|
|
||||||
let value = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_truncate_or_bit_cast(value, self.llvm_type(generator, ctx.ctx), "")
|
|
||||||
.unwrap();
|
|
||||||
unsafe { self.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn truncate<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
assert!(
|
|
||||||
value.get_type().get_bit_width()
|
|
||||||
> self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
|
||||||
);
|
|
||||||
let value =
|
|
||||||
ctx.builder.build_int_truncate(value, self.llvm_type(generator, ctx.ctx), "").unwrap();
|
|
||||||
unsafe { self.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// `sext` or `trunc` an int to this model's int type. Does nothing if equal bit-widths.
|
|
||||||
pub fn s_extend_or_truncate<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
let their_width = value.get_type().get_bit_width();
|
|
||||||
let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width();
|
|
||||||
match their_width.cmp(&our_width) {
|
|
||||||
Ordering::Less => self.s_extend(generator, ctx, value),
|
|
||||||
Ordering::Equal => unsafe { self.believe_value(value) },
|
|
||||||
Ordering::Greater => self.truncate(generator, ctx, value),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// `zext` or `trunc` an int to this model's int type. Does nothing if equal bit-widths.
|
|
||||||
pub fn z_extend_or_truncate<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
value: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
let their_width = value.get_type().get_bit_width();
|
|
||||||
let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width();
|
|
||||||
match their_width.cmp(&our_width) {
|
|
||||||
Ordering::Less => self.z_extend(generator, ctx, value),
|
|
||||||
Ordering::Equal => unsafe { self.believe_value(value) },
|
|
||||||
Ordering::Greater => self.truncate(generator, ctx, value),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Int<Bool> {
|
|
||||||
#[must_use]
|
|
||||||
pub fn const_false<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
self.const_int(generator, ctx, 0, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn const_true<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
self.const_int(generator, ctx, 1, false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, N: IntKind<'ctx>> Instance<'ctx, Int<N>> {
|
|
||||||
pub fn s_extend_or_bit_cast<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
to_int_kind: NewN,
|
|
||||||
) -> Instance<'ctx, Int<NewN>> {
|
|
||||||
Int(to_int_kind).s_extend_or_bit_cast(generator, ctx, self.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn s_extend<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
to_int_kind: NewN,
|
|
||||||
) -> Instance<'ctx, Int<NewN>> {
|
|
||||||
Int(to_int_kind).s_extend(generator, ctx, self.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn z_extend_or_bit_cast<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
to_int_kind: NewN,
|
|
||||||
) -> Instance<'ctx, Int<NewN>> {
|
|
||||||
Int(to_int_kind).z_extend_or_bit_cast(generator, ctx, self.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn z_extend<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
to_int_kind: NewN,
|
|
||||||
) -> Instance<'ctx, Int<NewN>> {
|
|
||||||
Int(to_int_kind).z_extend(generator, ctx, self.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn truncate_or_bit_cast<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
to_int_kind: NewN,
|
|
||||||
) -> Instance<'ctx, Int<NewN>> {
|
|
||||||
Int(to_int_kind).truncate_or_bit_cast(generator, ctx, self.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn truncate<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
to_int_kind: NewN,
|
|
||||||
) -> Instance<'ctx, Int<NewN>> {
|
|
||||||
Int(to_int_kind).truncate(generator, ctx, self.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn s_extend_or_truncate<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
to_int_kind: NewN,
|
|
||||||
) -> Instance<'ctx, Int<NewN>> {
|
|
||||||
Int(to_int_kind).s_extend_or_truncate(generator, ctx, self.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn z_extend_or_truncate<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
to_int_kind: NewN,
|
|
||||||
) -> Instance<'ctx, Int<NewN>> {
|
|
||||||
Int(to_int_kind).z_extend_or_truncate(generator, ctx, self.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn add(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self {
|
|
||||||
let value = ctx.builder.build_int_add(self.value, other.value, "").unwrap();
|
|
||||||
unsafe { self.model.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn sub(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self {
|
|
||||||
let value = ctx.builder.build_int_sub(self.value, other.value, "").unwrap();
|
|
||||||
unsafe { self.model.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn mul(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self {
|
|
||||||
let value = ctx.builder.build_int_mul(self.value, other.value, "").unwrap();
|
|
||||||
unsafe { self.model.believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn compare(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
op: IntPredicate,
|
|
||||||
other: Self,
|
|
||||||
) -> Instance<'ctx, Int<Bool>> {
|
|
||||||
let value = ctx.builder.build_int_compare(op, self.value, other.value, "").unwrap();
|
|
||||||
unsafe { Int(Bool).believe_value(value) }
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,17 +0,0 @@
|
||||||
mod any;
|
|
||||||
mod array;
|
|
||||||
mod core;
|
|
||||||
mod float;
|
|
||||||
pub mod function;
|
|
||||||
mod int;
|
|
||||||
mod ptr;
|
|
||||||
mod structure;
|
|
||||||
pub mod util;
|
|
||||||
|
|
||||||
pub use any::*;
|
|
||||||
pub use array::*;
|
|
||||||
pub use core::*;
|
|
||||||
pub use float::*;
|
|
||||||
pub use int::*;
|
|
||||||
pub use ptr::*;
|
|
||||||
pub use structure::*;
|
|
|
@ -1,223 +0,0 @@
|
||||||
use inkwell::{
|
|
||||||
context::Context,
|
|
||||||
types::{BasicType, BasicTypeEnum, PointerType},
|
|
||||||
values::{IntValue, PointerValue},
|
|
||||||
AddressSpace,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::codegen::{llvm_intrinsics::call_memcpy_generic, CodeGenContext, CodeGenerator};
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
/// A model for [`PointerType`].
|
|
||||||
///
|
|
||||||
/// `Item` is the element type this pointer is pointing to, and should be of a [`Model`].
|
|
||||||
///
|
|
||||||
// TODO: LLVM 15: `Item` is a Rust type-hint for the LLVM type of value the `.store()/.load()` family
|
|
||||||
// of functions return. If a truly opaque pointer is needed, tell the programmer to use `OpaquePtr`.
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Ptr<Item>(pub Item);
|
|
||||||
|
|
||||||
/// An opaque pointer. Like [`Ptr`] but without any Rust type-hints about its element type.
|
|
||||||
///
|
|
||||||
/// `.load()/.store()` is not available for [`Instance`]s of opaque pointers.
|
|
||||||
pub type OpaquePtr = Ptr<()>;
|
|
||||||
|
|
||||||
// TODO: LLVM 15: `Item: Model<'ctx>` don't even need to be a model anymore. It will only be
|
|
||||||
// a type-hint for the `.load()/.store()` functions for the `pointee_ty`.
|
|
||||||
//
|
|
||||||
// See https://thedan64.github.io/inkwell/inkwell/builder/struct.Builder.html#method.build_load.
|
|
||||||
impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr<Item> {
|
|
||||||
type Value = PointerValue<'ctx>;
|
|
||||||
type Type = PointerType<'ctx>;
|
|
||||||
|
|
||||||
fn llvm_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Self::Type {
|
|
||||||
// TODO: LLVM 15: ctx.ptr_type(AddressSpace::default())
|
|
||||||
self.0.llvm_type(generator, ctx).ptr_type(AddressSpace::default())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
ty: T,
|
|
||||||
) -> Result<(), ModelError> {
|
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
let Ok(ty) = PointerType::try_from(ty) else {
|
|
||||||
return Err(ModelError(format!("Expecting PointerType, but got {ty:?}")));
|
|
||||||
};
|
|
||||||
|
|
||||||
let elem_ty = ty.get_element_type();
|
|
||||||
let Ok(elem_ty) = BasicTypeEnum::try_from(elem_ty) else {
|
|
||||||
return Err(ModelError(format!(
|
|
||||||
"Expecting pointer element type to be a BasicTypeEnum, but got {elem_ty:?}"
|
|
||||||
)));
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO: inkwell `get_element_type()` will be deprecated.
|
|
||||||
// Remove the check for `get_element_type()` when the time comes.
|
|
||||||
self.0
|
|
||||||
.check_type(generator, ctx, elem_ty)
|
|
||||||
.map_err(|err| err.under_context("a PointerType"))?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, Item: Model<'ctx>> Ptr<Item> {
|
|
||||||
/// Return a ***constant*** nullptr.
|
|
||||||
pub fn nullptr<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Instance<'ctx, Ptr<Item>> {
|
|
||||||
let ptr = self.llvm_type(generator, ctx).const_null();
|
|
||||||
unsafe { self.believe_value(ptr) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Cast a pointer into this model with [`inkwell::builder::Builder::build_pointer_cast`]
|
|
||||||
pub fn pointer_cast<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ptr: PointerValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Ptr<Item>> {
|
|
||||||
// TODO: LLVM 15: Write in an impl where `Item` does not have to be `Model<'ctx>`.
|
|
||||||
// TODO: LLVM 15: This function will only have to be:
|
|
||||||
// ```
|
|
||||||
// return self.believe_value(ptr);
|
|
||||||
// ```
|
|
||||||
let t = self.llvm_type(generator, ctx.ctx);
|
|
||||||
let ptr = ctx.builder.build_pointer_cast(ptr, t, "").unwrap();
|
|
||||||
unsafe { self.believe_value(ptr) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Item>> {
|
|
||||||
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`].
|
|
||||||
#[must_use]
|
|
||||||
pub fn offset(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
offset: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Ptr<Item>> {
|
|
||||||
let p = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[offset], "").unwrap() };
|
|
||||||
unsafe { self.model.believe_value(p) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset.
|
|
||||||
#[must_use]
|
|
||||||
pub fn offset_const(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
offset: i64,
|
|
||||||
) -> Instance<'ctx, Ptr<Item>> {
|
|
||||||
let offset = ctx.ctx.i32_type().const_int(offset as u64, true);
|
|
||||||
self.offset(ctx, offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_index(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
index: IntValue<'ctx>,
|
|
||||||
value: Instance<'ctx, Item>,
|
|
||||||
) {
|
|
||||||
self.offset(ctx, index).store(ctx, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_index_const(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
index: i64,
|
|
||||||
value: Instance<'ctx, Item>,
|
|
||||||
) {
|
|
||||||
self.offset_const(ctx, index).store(ctx, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_index<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
index: IntValue<'ctx>,
|
|
||||||
) -> Instance<'ctx, Item> {
|
|
||||||
self.offset(ctx, index).load(generator, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_index_const<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
index: i64,
|
|
||||||
) -> Instance<'ctx, Item> {
|
|
||||||
self.offset_const(ctx, index).load(generator, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Load the value with [`inkwell::builder::Builder::build_load`].
|
|
||||||
pub fn load<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Item> {
|
|
||||||
let value = ctx.builder.build_load(self.value, "").unwrap();
|
|
||||||
self.model.0.check_value(generator, ctx.ctx, value).unwrap() // If unwrap() panics, there is a logic error.
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Store a value with [`inkwell::builder::Builder::build_store`].
|
|
||||||
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: Instance<'ctx, Item>) {
|
|
||||||
ctx.builder.build_store(self.value, value.value).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return a casted pointer of element type `NewElement` with [`inkwell::builder::Builder::build_pointer_cast`].
|
|
||||||
pub fn pointer_cast<NewItem: Model<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
new_item: NewItem,
|
|
||||||
) -> Instance<'ctx, Ptr<NewItem>> {
|
|
||||||
// TODO: LLVM 15: Write in an impl where `Item` does not have to be `Model<'ctx>`.
|
|
||||||
Ptr(new_item).pointer_cast(generator, ctx, self.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Cast this pointer to `uint8_t*`
|
|
||||||
pub fn cast_to_pi8<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Ptr<Int<Byte>>> {
|
|
||||||
Ptr(Int(Byte)).pointer_cast(generator, ctx, self.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`].
|
|
||||||
pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int<Bool>> {
|
|
||||||
let value = ctx.builder.build_is_null(self.value, "").unwrap();
|
|
||||||
unsafe { Int(Bool).believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the pointer is not null with [`inkwell::builder::Builder::build_is_not_null`].
|
|
||||||
pub fn is_not_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int<Bool>> {
|
|
||||||
let value = ctx.builder.build_is_not_null(self.value, "").unwrap();
|
|
||||||
unsafe { Int(Bool).believe_value(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// `memcpy` from another pointer.
|
|
||||||
pub fn copy_from<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
source: Self,
|
|
||||||
num_items: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
// Force extend `num_items` and `itemsize` to `i64` so their types would match.
|
|
||||||
let itemsize = self.model.size_of(generator, ctx.ctx);
|
|
||||||
let itemsize = Int(SizeT).z_extend_or_truncate(generator, ctx, itemsize);
|
|
||||||
let num_items = Int(SizeT).z_extend_or_truncate(generator, ctx, num_items);
|
|
||||||
let totalsize = itemsize.mul(ctx, num_items);
|
|
||||||
|
|
||||||
let is_volatile = ctx.ctx.bool_type().const_zero(); // is_volatile = false
|
|
||||||
call_memcpy_generic(ctx, self.value, source.value, totalsize.value, is_volatile);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,364 +0,0 @@
|
||||||
use std::fmt;
|
|
||||||
|
|
||||||
use inkwell::{
|
|
||||||
context::Context,
|
|
||||||
types::{BasicType, BasicTypeEnum, StructType},
|
|
||||||
values::{BasicValueEnum, StructValue},
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
/// A traveral that traverses a Rust `struct` that is used to declare an LLVM's struct's field types.
|
|
||||||
pub trait FieldTraversal<'ctx> {
|
|
||||||
/// Output type of [`FieldTraversal::add`].
|
|
||||||
type Output<M>;
|
|
||||||
|
|
||||||
/// Traverse through the type of a declared field and do something with it.
|
|
||||||
///
|
|
||||||
/// * `name` - The cosmetic name of the LLVM field. Used for debugging.
|
|
||||||
/// * `model` - The [`Model`] representing the LLVM type of this field.
|
|
||||||
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Output<M>;
|
|
||||||
|
|
||||||
/// Like [`FieldTraversal::add`] but [`Model`] is automatically inferred from its [`Default`] trait.
|
|
||||||
fn add_auto<M: Model<'ctx> + Default>(&mut self, name: &'static str) -> Self::Output<M> {
|
|
||||||
self.add(name, M::default())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Descriptor of an LLVM struct field.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct GepField<M> {
|
|
||||||
/// The GEP index of this field. This is the index to use with `build_gep`.
|
|
||||||
pub gep_index: u32,
|
|
||||||
/// The cosmetic name of this field.
|
|
||||||
pub name: &'static str,
|
|
||||||
/// The [`Model`] of this field's type.
|
|
||||||
pub model: M,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A traversal to calculate the GEP index of fields.
|
|
||||||
pub struct GepFieldTraversal {
|
|
||||||
/// The current GEP index.
|
|
||||||
gep_index_counter: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal {
|
|
||||||
type Output<M> = GepField<M>;
|
|
||||||
|
|
||||||
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Output<M> {
|
|
||||||
let gep_index = self.gep_index_counter;
|
|
||||||
self.gep_index_counter += 1;
|
|
||||||
Self::Output { gep_index, name, model }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A traversal to collect the field types of a struct.
|
|
||||||
///
|
|
||||||
/// This is used to collect field types and construct the LLVM struct type with [`Context::struct_type`].
|
|
||||||
struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
|
|
||||||
generator: &'a G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
/// The collected field types so far in exact order.
|
|
||||||
field_types: Vec<BasicTypeEnum<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> {
|
|
||||||
type Output<M> = (); // Checking types return nothing.
|
|
||||||
|
|
||||||
fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Output<M> {
|
|
||||||
let t = model.llvm_type(self.generator, self.ctx).as_basic_type_enum();
|
|
||||||
self.field_types.push(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A traversal to check the types of fields.
|
|
||||||
struct CheckTypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
|
|
||||||
generator: &'a mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
/// The current GEP index, so we can tell the index of the field we are checking
|
|
||||||
/// and report the GEP index.
|
|
||||||
gep_index_counter: u32,
|
|
||||||
/// The [`StructType`] to check.
|
|
||||||
scrutinee: StructType<'ctx>,
|
|
||||||
/// The list of collected errors so far.
|
|
||||||
errors: Vec<ModelError>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx>
|
|
||||||
for CheckTypeFieldTraversal<'ctx, 'a, G>
|
|
||||||
{
|
|
||||||
type Output<M> = (); // Checking types return nothing.
|
|
||||||
|
|
||||||
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Output<M> {
|
|
||||||
let gep_index = self.gep_index_counter;
|
|
||||||
self.gep_index_counter += 1;
|
|
||||||
|
|
||||||
if let Some(t) = self.scrutinee.get_field_type_at_index(gep_index) {
|
|
||||||
if let Err(err) = model.check_type(self.generator, self.ctx, t) {
|
|
||||||
self.errors
|
|
||||||
.push(err.under_context(format!("field #{gep_index} '{name}'").as_str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Otherwise, it will be caught by Struct's `check_type`.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A trait for Rust structs identifying LLVM structures.
|
|
||||||
///
|
|
||||||
/// ### Example
|
|
||||||
///
|
|
||||||
/// Suppose you want to define this structure:
|
|
||||||
/// ```c
|
|
||||||
/// template <typename T>
|
|
||||||
/// struct ContiguousNDArray {
|
|
||||||
/// size_t ndims;
|
|
||||||
/// size_t* shape;
|
|
||||||
/// T* data;
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// This is how it should be done:
|
|
||||||
/// ```ignore
|
|
||||||
/// pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
|
|
||||||
/// pub ndims: F::Out<Int<SizeT>>,
|
|
||||||
/// pub shape: F::Out<Ptr<Int<SizeT>>>,
|
|
||||||
/// pub data: F::Out<Ptr<Item>>,
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// /// An ndarray without strides and non-opaque `data` field in NAC3.
|
|
||||||
/// #[derive(Debug, Clone, Copy)]
|
|
||||||
/// pub struct ContiguousNDArray<M> {
|
|
||||||
/// /// [`Model`] of the items.
|
|
||||||
/// pub item: M,
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray<Item> {
|
|
||||||
/// type Fields<F: FieldTraversal<'ctx>> = ContiguousNDArrayFields<'ctx, F, Item>;
|
|
||||||
///
|
|
||||||
/// fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
|
||||||
/// // The order of `traversal.add*` is important
|
|
||||||
/// Self::Fields {
|
|
||||||
/// ndims: traversal.add_auto("ndims"),
|
|
||||||
/// shape: traversal.add_auto("shape"),
|
|
||||||
/// data: traversal.add("data", Ptr(self.item)),
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// The [`FieldTraversal`] here is a mechanism to allow the fields of `ContiguousNDArrayFields` to be
|
|
||||||
/// traversed to do useful work such as:
|
|
||||||
///
|
|
||||||
/// - To create the [`StructType`] of `ContiguousNDArray` by collecting [`BasicType`]s of the fields.
|
|
||||||
/// - To enable the `.gep(ctx, |f| f.ndims).store(ctx, ...)` syntax.
|
|
||||||
///
|
|
||||||
/// Suppose now that you have defined `ContiguousNDArray` and you want to allocate a `ContiguousNDArray`
|
|
||||||
/// with dtype `float64` in LLVM, this is how you do it:
|
|
||||||
/// ```ignore
|
|
||||||
/// type F64NDArray = Struct<ContiguousNDArray<Float<Float64>>>; // Type alias for leaner documentation
|
|
||||||
/// let model: F64NDArray = Struct(ContigousNDArray { item: Float(Float64) });
|
|
||||||
/// let ndarray: Instance<'ctx, Ptr<F64NDArray>> = model.alloca(generator, ctx);
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// ...and here is how you may manipulate/access `ndarray`:
|
|
||||||
///
|
|
||||||
/// (NOTE: some arguments have been omitted)
|
|
||||||
///
|
|
||||||
/// ```ignore
|
|
||||||
/// // Get `&ndarray->data`
|
|
||||||
/// ndarray.gep(|f| f.data); // type: Instance<'ctx, Ptr<Float<Float64>>>
|
|
||||||
///
|
|
||||||
/// // Get `ndarray->ndims`
|
|
||||||
/// ndarray.get(|f| f.ndims); // type: Instance<'ctx, Int<SizeT>>
|
|
||||||
///
|
|
||||||
/// // Get `&ndarray->ndims`
|
|
||||||
/// ndarray.gep(|f| f.ndims); // type: Instance<'ctx, Ptr<Int<SizeT>>>
|
|
||||||
///
|
|
||||||
/// // Get `ndarray->shape[0]`
|
|
||||||
/// ndarray.get(|f| f.shape).get_index_const(0); // Instance<'ctx, Int<SizeT>>
|
|
||||||
///
|
|
||||||
/// // Get `&ndarray->shape[2]`
|
|
||||||
/// ndarray.get(|f| f.shape).offset_const(2); // Instance<'ctx, Ptr<Int<SizeT>>>
|
|
||||||
///
|
|
||||||
/// // Do `ndarray->ndims = 3;`
|
|
||||||
/// let num_3 = Int(SizeT).const_int(3);
|
|
||||||
/// ndarray.set(|f| f.ndims, num_3);
|
|
||||||
/// ```
|
|
||||||
pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy {
|
|
||||||
/// The associated fields of this struct.
|
|
||||||
type Fields<F: FieldTraversal<'ctx>>;
|
|
||||||
|
|
||||||
/// Traverse through all fields of this [`StructKind`].
|
|
||||||
///
|
|
||||||
/// Only used internally in this module for implementing other components.
|
|
||||||
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F>;
|
|
||||||
|
|
||||||
/// Get a convenience structure to get a struct field's GEP index through its corresponding Rust field.
|
|
||||||
///
|
|
||||||
/// Only used internally in this module for implementing other components.
|
|
||||||
fn fields(&self) -> Self::Fields<GepFieldTraversal> {
|
|
||||||
self.iter_fields(&mut GepFieldTraversal { gep_index_counter: 0 })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the LLVM [`StructType`] of this [`StructKind`].
|
|
||||||
fn get_struct_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> StructType<'ctx> {
|
|
||||||
let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() };
|
|
||||||
self.iter_fields(&mut traversal);
|
|
||||||
|
|
||||||
ctx.struct_type(&traversal.field_types, false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A model for LLVM struct.
|
|
||||||
///
|
|
||||||
/// `S` should be of a [`StructKind`].
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct Struct<S>(pub S);
|
|
||||||
|
|
||||||
impl<'ctx, S: StructKind<'ctx>> Struct<S> {
|
|
||||||
/// Create a constant struct value from its fields.
|
|
||||||
///
|
|
||||||
/// This function also validates `fields` and panic when there is something wrong.
|
|
||||||
pub fn const_struct<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
fields: &[BasicValueEnum<'ctx>],
|
|
||||||
) -> Instance<'ctx, Self> {
|
|
||||||
// NOTE: There *could* have been a functor `F<M> = Instance<'ctx, M>` for `S::Fields<F>`
|
|
||||||
// to create a more user-friendly interface, but Rust's type system is not sophisticated enough
|
|
||||||
// and if you try doing that Rust would force you put lifetimes everywhere.
|
|
||||||
let val = ctx.const_struct(fields, false);
|
|
||||||
self.check_value(generator, ctx, val).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct<S> {
|
|
||||||
type Value = StructValue<'ctx>;
|
|
||||||
type Type = StructType<'ctx>;
|
|
||||||
|
|
||||||
fn llvm_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Self::Type {
|
|
||||||
self.0.get_struct_type(generator, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
ty: T,
|
|
||||||
) -> Result<(), ModelError> {
|
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
let Ok(ty) = StructType::try_from(ty) else {
|
|
||||||
return Err(ModelError(format!("Expecting StructType, but got {ty:?}")));
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check each field individually.
|
|
||||||
let mut traversal = CheckTypeFieldTraversal {
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
gep_index_counter: 0,
|
|
||||||
errors: Vec::new(),
|
|
||||||
scrutinee: ty,
|
|
||||||
};
|
|
||||||
self.0.iter_fields(&mut traversal);
|
|
||||||
|
|
||||||
// Check the number of fields.
|
|
||||||
let exp_num_fields = traversal.gep_index_counter;
|
|
||||||
let got_num_fields = u32::try_from(ty.get_field_types().len()).unwrap();
|
|
||||||
if exp_num_fields != got_num_fields {
|
|
||||||
return Err(ModelError(format!(
|
|
||||||
"Expecting StructType with {exp_num_fields} field(s), but got {got_num_fields}"
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
if !traversal.errors.is_empty() {
|
|
||||||
// Currently, only the first error is reported.
|
|
||||||
return Err(traversal.errors[0].clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct<S>> {
|
|
||||||
/// Get a field with [`StructValue::get_field_at_index`].
|
|
||||||
pub fn get_field<G: CodeGenerator + ?Sized, M, GetField>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
get_field: GetField,
|
|
||||||
) -> Instance<'ctx, M>
|
|
||||||
where
|
|
||||||
M: Model<'ctx>,
|
|
||||||
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
|
|
||||||
{
|
|
||||||
let field = get_field(self.model.0.fields());
|
|
||||||
let val = self.value.get_field_at_index(field.gep_index).unwrap();
|
|
||||||
field.model.check_value(generator, ctx, val).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
|
|
||||||
/// Get a pointer to a field with [`Builder::build_in_bounds_gep`].
|
|
||||||
pub fn gep<M, GetField>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
get_field: GetField,
|
|
||||||
) -> Instance<'ctx, Ptr<M>>
|
|
||||||
where
|
|
||||||
M: Model<'ctx>,
|
|
||||||
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
|
|
||||||
{
|
|
||||||
let field = get_field(self.model.0 .0.fields());
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
|
|
||||||
let ptr = unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.value,
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(u64::from(field.gep_index), false)],
|
|
||||||
field.name,
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
unsafe { Ptr(field.model).believe_value(ptr) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience function equivalent to `.gep(...).load(...)`.
|
|
||||||
pub fn get<M, GetField, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
get_field: GetField,
|
|
||||||
) -> Instance<'ctx, M>
|
|
||||||
where
|
|
||||||
M: Model<'ctx>,
|
|
||||||
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
|
|
||||||
{
|
|
||||||
self.gep(ctx, get_field).load(generator, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience function equivalent to `.gep(...).store(...)`.
|
|
||||||
pub fn set<M, GetField>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
get_field: GetField,
|
|
||||||
value: Instance<'ctx, M>,
|
|
||||||
) where
|
|
||||||
M: Model<'ctx>,
|
|
||||||
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
|
|
||||||
{
|
|
||||||
self.gep(ctx, get_field).store(ctx, value);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,42 +0,0 @@
|
||||||
use crate::codegen::{
|
|
||||||
stmt::{gen_for_callback_incrementing, BreakContinueHooks},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
|
|
||||||
///
|
|
||||||
/// `stop` is not included.
|
|
||||||
pub fn gen_for_model<'ctx, 'a, G, F, N>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
start: Instance<'ctx, Int<N>>,
|
|
||||||
stop: Instance<'ctx, Int<N>>,
|
|
||||||
step: Instance<'ctx, Int<N>>,
|
|
||||||
body: F,
|
|
||||||
) -> Result<(), String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
F: FnOnce(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
BreakContinueHooks<'ctx>,
|
|
||||||
Instance<'ctx, Int<N>>,
|
|
||||||
) -> Result<(), String>,
|
|
||||||
N: IntKind<'ctx> + Default,
|
|
||||||
{
|
|
||||||
let int_model = Int(N::default());
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
start.value,
|
|
||||||
(stop.value, false),
|
|
||||||
|g, ctx, hooks, i| {
|
|
||||||
let i = unsafe { int_model.believe_value(i) };
|
|
||||||
body(g, ctx, hooks, i)
|
|
||||||
},
|
|
||||||
step.value,
|
|
||||||
)
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,12 +0,0 @@
|
||||||
use inkwell::values::BasicValueEnum;
|
|
||||||
|
|
||||||
use crate::typecheck::typedef::Type;
|
|
||||||
|
|
||||||
/// A NAC3 LLVM Python object of any type.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct AnyObject<'ctx> {
|
|
||||||
/// Typechecker type of the object.
|
|
||||||
pub ty: Type,
|
|
||||||
/// LLVM value of the object.
|
|
||||||
pub value: BasicValueEnum<'ctx>,
|
|
||||||
}
|
|
|
@ -1,87 +0,0 @@
|
||||||
use crate::{
|
|
||||||
codegen::{model::*, CodeGenContext, CodeGenerator},
|
|
||||||
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::any::AnyObject;
|
|
||||||
|
|
||||||
/// Fields of [`List`]
|
|
||||||
pub struct ListFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
|
|
||||||
/// Array pointer to content
|
|
||||||
pub items: F::Output<Ptr<Item>>,
|
|
||||||
/// Number of items in the array
|
|
||||||
pub len: F::Output<Int<SizeT>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A list in NAC3.
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct List<Item> {
|
|
||||||
/// Model of the list items
|
|
||||||
pub item: Item,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for List<Item> {
|
|
||||||
type Fields<F: FieldTraversal<'ctx>> = ListFields<'ctx, F, Item>;
|
|
||||||
|
|
||||||
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
|
||||||
Self::Fields {
|
|
||||||
items: traversal.add("items", Ptr(self.item)),
|
|
||||||
len: traversal.add_auto("len"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Struct<List<Item>>>> {
|
|
||||||
/// Cast the items pointer to `uint8_t*`.
|
|
||||||
pub fn with_pi8_items<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>> {
|
|
||||||
self.pointer_cast(generator, ctx, Struct(List { item: Int(Byte) }))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A NAC3 Python List object.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct ListObject<'ctx> {
|
|
||||||
/// Typechecker type of the list items
|
|
||||||
pub item_type: Type,
|
|
||||||
pub instance: Instance<'ctx, Ptr<Struct<List<Any<'ctx>>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> ListObject<'ctx> {
|
|
||||||
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
|
|
||||||
pub fn from_object<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
object: AnyObject<'ctx>,
|
|
||||||
) -> Self {
|
|
||||||
// Check typechecker type and extract `item_type`
|
|
||||||
let item_type = match &*ctx.unifier.get_ty(object.ty) {
|
|
||||||
TypeEnum::TObj { obj_id, params, .. }
|
|
||||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
|
||||||
{
|
|
||||||
iter_type_vars(params).next().unwrap().ty // Extract `item_type`
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(object.ty))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let plist = Ptr(Struct(List { item: Any(ctx.get_llvm_type(generator, item_type)) }));
|
|
||||||
|
|
||||||
// Create object
|
|
||||||
let value = plist.check_value(generator, ctx.ctx, object.value).unwrap();
|
|
||||||
ListObject { item_type, instance: value }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the `len()` of this list.
|
|
||||||
pub fn len<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
self.instance.get(generator, ctx, |f| f.len)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,5 +0,0 @@
|
||||||
pub mod any;
|
|
||||||
pub mod list;
|
|
||||||
pub mod ndarray;
|
|
||||||
pub mod tuple;
|
|
||||||
pub mod utils;
|
|
|
@ -1,184 +0,0 @@
|
||||||
use super::NDArrayObject;
|
|
||||||
use crate::{
|
|
||||||
codegen::{
|
|
||||||
irrt::{
|
|
||||||
call_nac3_ndarray_array_set_and_validate_list_shape,
|
|
||||||
call_nac3_ndarray_array_write_list_to_array,
|
|
||||||
},
|
|
||||||
model::*,
|
|
||||||
object::{any::AnyObject, list::ListObject},
|
|
||||||
stmt::gen_if_else_expr_callback,
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
},
|
|
||||||
toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims},
|
|
||||||
typecheck::typedef::{Type, TypeEnum},
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Get the expected `dtype` and `ndims` of the ndarray returned by `np_array(list)`.
|
|
||||||
fn get_list_object_dtype_and_ndims<'ctx>(
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
list: ListObject<'ctx>,
|
|
||||||
) -> (Type, u64) {
|
|
||||||
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list.item_type);
|
|
||||||
|
|
||||||
let ndims = arraylike_get_ndims(&mut ctx.unifier, list.item_type);
|
|
||||||
let ndims = ndims + 1; // To count `list` itself.
|
|
||||||
|
|
||||||
(dtype, ndims)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Implementation of `np_array(<list>, copy=True)`
|
|
||||||
fn make_np_array_list_copy_true_impl<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
list: ListObject<'ctx>,
|
|
||||||
) -> Self {
|
|
||||||
let (dtype, ndims_int) = get_list_object_dtype_and_ndims(ctx, list);
|
|
||||||
let list_value = list.instance.with_pi8_items(generator, ctx);
|
|
||||||
|
|
||||||
// Validate `list` has a consistent shape.
|
|
||||||
// Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`.
|
|
||||||
// If `list` has a consistent shape, deduce the shape and write it to `shape`.
|
|
||||||
let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int, false);
|
|
||||||
let shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
|
||||||
call_nac3_ndarray_array_set_and_validate_list_shape(
|
|
||||||
generator, ctx, list_value, ndims, shape,
|
|
||||||
);
|
|
||||||
|
|
||||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int);
|
|
||||||
ndarray.copy_shape_from_array(generator, ctx, shape);
|
|
||||||
ndarray.create_data(generator, ctx);
|
|
||||||
|
|
||||||
// Copy all contents from the list.
|
|
||||||
call_nac3_ndarray_array_write_list_to_array(generator, ctx, list_value, ndarray.instance);
|
|
||||||
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Implementation of `np_array(<list>, copy=None)`
|
|
||||||
fn make_np_array_list_copy_none_impl<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
list: ListObject<'ctx>,
|
|
||||||
) -> Self {
|
|
||||||
// np_array without copying is only possible `list` is not nested.
|
|
||||||
//
|
|
||||||
// If `list` is `list[T]`, we can create an ndarray with `data` set
|
|
||||||
// to the array pointer of `list`.
|
|
||||||
//
|
|
||||||
// If `list` is `list[list[T]]` or worse, copy.
|
|
||||||
|
|
||||||
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
|
|
||||||
if ndims == 1 {
|
|
||||||
// `list` is not nested
|
|
||||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, 1);
|
|
||||||
|
|
||||||
// Set data
|
|
||||||
let data = list.instance.get(generator, ctx, |f| f.items).cast_to_pi8(generator, ctx);
|
|
||||||
ndarray.instance.set(ctx, |f| f.data, data);
|
|
||||||
|
|
||||||
// ndarray->shape[0] = list->len;
|
|
||||||
let shape = ndarray.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
let list_len = list.instance.get(generator, ctx, |f| f.len);
|
|
||||||
shape.set_index_const(ctx, 0, list_len);
|
|
||||||
|
|
||||||
// Set strides, the `data` is contiguous
|
|
||||||
ndarray.set_strides_contiguous(generator, ctx);
|
|
||||||
|
|
||||||
ndarray
|
|
||||||
} else {
|
|
||||||
// `list` is nested, copy
|
|
||||||
NDArrayObject::make_np_array_list_copy_true_impl(generator, ctx, list)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Implementation of `np_array(<list>, copy=copy)`
|
|
||||||
fn make_np_array_list_impl<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
list: ListObject<'ctx>,
|
|
||||||
copy: Instance<'ctx, Int<Bool>>,
|
|
||||||
) -> Self {
|
|
||||||
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
|
|
||||||
|
|
||||||
let ndarray = gen_if_else_expr_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_generator, _ctx| Ok(copy.value),
|
|
||||||
|generator, ctx| {
|
|
||||||
let ndarray =
|
|
||||||
NDArrayObject::make_np_array_list_copy_true_impl(generator, ctx, list);
|
|
||||||
Ok(Some(ndarray.instance.value))
|
|
||||||
},
|
|
||||||
|generator, ctx| {
|
|
||||||
let ndarray =
|
|
||||||
NDArrayObject::make_np_array_list_copy_none_impl(generator, ctx, list);
|
|
||||||
Ok(Some(ndarray.instance.value))
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Implementation of `np_array(<ndarray>, copy=copy)`.
|
|
||||||
pub fn make_np_array_ndarray_impl<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayObject<'ctx>,
|
|
||||||
copy: Instance<'ctx, Int<Bool>>,
|
|
||||||
) -> Self {
|
|
||||||
let ndarray_val = gen_if_else_expr_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_generator, _ctx| Ok(copy.value),
|
|
||||||
|generator, ctx| {
|
|
||||||
let ndarray = ndarray.make_copy(generator, ctx); // Force copy
|
|
||||||
Ok(Some(ndarray.instance.value))
|
|
||||||
},
|
|
||||||
|_generator, _ctx| {
|
|
||||||
// No need to copy. Return `ndarray` itself.
|
|
||||||
Ok(Some(ndarray.instance.value))
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
NDArrayObject::from_value_and_unpacked_types(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ndarray_val,
|
|
||||||
ndarray.dtype,
|
|
||||||
ndarray.ndims,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new ndarray like `np.array()`.
|
|
||||||
///
|
|
||||||
/// NOTE: The `ndmin` argument is not here. You may want to
|
|
||||||
/// do [`NDArrayObject::atleast_nd`] to achieve that.
|
|
||||||
pub fn make_np_array<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
object: AnyObject<'ctx>,
|
|
||||||
copy: Instance<'ctx, Int<Bool>>,
|
|
||||||
) -> Self {
|
|
||||||
match &*ctx.unifier.get_ty(object.ty) {
|
|
||||||
TypeEnum::TObj { obj_id, .. }
|
|
||||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
|
||||||
{
|
|
||||||
let list = ListObject::from_object(generator, ctx, object);
|
|
||||||
NDArrayObject::make_np_array_list_impl(generator, ctx, list, copy)
|
|
||||||
}
|
|
||||||
TypeEnum::TObj { obj_id, .. }
|
|
||||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
|
||||||
{
|
|
||||||
let ndarray = NDArrayObject::from_object(generator, ctx, object);
|
|
||||||
NDArrayObject::make_np_array_ndarray_impl(generator, ctx, ndarray, copy)
|
|
||||||
}
|
|
||||||
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,139 +0,0 @@
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to},
|
|
||||||
model::*,
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::NDArrayObject;
|
|
||||||
|
|
||||||
/// Fields of [`ShapeEntry`]
|
|
||||||
pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> {
|
|
||||||
pub ndims: F::Output<Int<SizeT>>,
|
|
||||||
pub shape: F::Output<Ptr<Int<SizeT>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An IRRT structure used in broadcasting.
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct ShapeEntry;
|
|
||||||
|
|
||||||
impl<'ctx> StructKind<'ctx> for ShapeEntry {
|
|
||||||
type Fields<F: FieldTraversal<'ctx>> = ShapeEntryFields<'ctx, F>;
|
|
||||||
|
|
||||||
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
|
||||||
Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Create a broadcast view on this ndarray with a target shape.
|
|
||||||
///
|
|
||||||
/// The input shape will be checked to make sure that it contains no negative values.
|
|
||||||
///
|
|
||||||
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
|
|
||||||
/// The caller has to figure this out for this function.
|
|
||||||
/// * `target_shape` - An array pointer pointing to the target shape.
|
|
||||||
#[must_use]
|
|
||||||
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
target_ndims: u64,
|
|
||||||
target_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) -> Self {
|
|
||||||
let broadcast_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, target_ndims);
|
|
||||||
broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape);
|
|
||||||
|
|
||||||
call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance);
|
|
||||||
broadcast_ndarray
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/// A result produced by [`broadcast_all_ndarrays`]
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct BroadcastAllResult<'ctx> {
|
|
||||||
/// The statically known `ndims` of the broadcast result.
|
|
||||||
pub ndims: u64,
|
|
||||||
/// The broadcasting shape.
|
|
||||||
pub shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
/// Broadcasted views on the inputs.
|
|
||||||
///
|
|
||||||
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
|
|
||||||
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
|
|
||||||
/// is the same as the input.
|
|
||||||
pub ndarrays: Vec<NDArrayObject<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper function to call `call_nac3_ndarray_broadcast_shapes`
|
|
||||||
fn broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
in_shape_entries: &[(Instance<'ctx, Ptr<Int<SizeT>>>, u64)], // (shape, shape's length/ndims)
|
|
||||||
broadcast_ndims: u64,
|
|
||||||
broadcast_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
// Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`.
|
|
||||||
let num_shape_entries = Int(SizeT).const_int(
|
|
||||||
generator,
|
|
||||||
ctx.ctx,
|
|
||||||
u64::try_from(in_shape_entries.len()).unwrap(),
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
let shape_entries = Struct(ShapeEntry).array_alloca(generator, ctx, num_shape_entries.value);
|
|
||||||
for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() {
|
|
||||||
let pshape_entry = shape_entries.offset_const(ctx, i64::try_from(i).unwrap());
|
|
||||||
|
|
||||||
let in_ndims = Int(SizeT).const_int(generator, ctx.ctx, *in_ndims, false);
|
|
||||||
pshape_entry.set(ctx, |f| f.ndims, in_ndims);
|
|
||||||
|
|
||||||
pshape_entry.set(ctx, |f| f.shape, *in_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims, false);
|
|
||||||
call_nac3_ndarray_broadcast_shapes(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
num_shape_entries,
|
|
||||||
shape_entries,
|
|
||||||
broadcast_ndims,
|
|
||||||
broadcast_shape,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Broadcast all ndarrays according to `np.broadcast()` and return a [`BroadcastAllResult`]
|
|
||||||
/// containing all the information of the result of the broadcast operation.
|
|
||||||
pub fn broadcast<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarrays: &[Self],
|
|
||||||
) -> BroadcastAllResult<'ctx> {
|
|
||||||
assert!(!ndarrays.is_empty());
|
|
||||||
|
|
||||||
// Infer the broadcast output ndims.
|
|
||||||
let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap();
|
|
||||||
|
|
||||||
let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims_int, false);
|
|
||||||
let broadcast_shape = Int(SizeT).array_alloca(generator, ctx, broadcast_ndims.value);
|
|
||||||
|
|
||||||
let shape_entries = ndarrays
|
|
||||||
.iter()
|
|
||||||
.map(|ndarray| (ndarray.instance.get(generator, ctx, |f| f.shape), ndarray.ndims))
|
|
||||||
.collect_vec();
|
|
||||||
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape);
|
|
||||||
|
|
||||||
// Broadcast all the inputs to shape `dst_shape`.
|
|
||||||
let broadcast_ndarrays: Vec<_> = ndarrays
|
|
||||||
.iter()
|
|
||||||
.map(|ndarray| {
|
|
||||||
ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, broadcast_shape)
|
|
||||||
})
|
|
||||||
.collect_vec();
|
|
||||||
|
|
||||||
BroadcastAllResult {
|
|
||||||
ndims: broadcast_ndims_int,
|
|
||||||
shape: broadcast_shape,
|
|
||||||
ndarrays: broadcast_ndarrays,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,134 +0,0 @@
|
||||||
use crate::{
|
|
||||||
codegen::{model::*, CodeGenContext, CodeGenerator},
|
|
||||||
typecheck::typedef::Type,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::NDArrayObject;
|
|
||||||
|
|
||||||
/// Fields of [`ContiguousNDArray`]
|
|
||||||
pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
|
|
||||||
pub ndims: F::Output<Int<SizeT>>,
|
|
||||||
pub shape: F::Output<Ptr<Int<SizeT>>>,
|
|
||||||
pub data: F::Output<Ptr<Item>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An ndarray without strides and non-opaque `data` field in NAC3.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct ContiguousNDArray<M> {
|
|
||||||
/// [`Model`] of the items.
|
|
||||||
pub item: M,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray<Item> {
|
|
||||||
type Fields<F: FieldTraversal<'ctx>> = ContiguousNDArrayFields<'ctx, F, Item>;
|
|
||||||
|
|
||||||
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
|
||||||
Self::Fields {
|
|
||||||
ndims: traversal.add_auto("ndims"),
|
|
||||||
shape: traversal.add_auto("shape"),
|
|
||||||
data: traversal.add("data", Ptr(self.item)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Create a [`ContiguousNDArray`] from the contents of this ndarray.
|
|
||||||
///
|
|
||||||
/// This function may or may not be expensive depending on if this ndarray has contiguous data.
|
|
||||||
///
|
|
||||||
/// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the `data` field of
|
|
||||||
/// the returned [`ContiguousNDArray`] and copy contents of this ndarray to there.
|
|
||||||
///
|
|
||||||
/// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created [`ContiguousNDArray`]
|
|
||||||
/// will share memory with this ndarray.
|
|
||||||
///
|
|
||||||
/// The `item_model` sets the [`Model`] of the returned [`ContiguousNDArray`]'s `Item` model for type-safety, and
|
|
||||||
/// should match the `ctx.get_llvm_type()` of this ndarray's `dtype`. Otherwise this function panics. Use model [`Any`]
|
|
||||||
/// if you don't care/cannot know the [`Model`] in advance.
|
|
||||||
pub fn make_contiguous_ndarray<G: CodeGenerator + ?Sized, Item: Model<'ctx>>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
item_model: Item,
|
|
||||||
) -> Instance<'ctx, Ptr<Struct<ContiguousNDArray<Item>>>> {
|
|
||||||
// Sanity check on `self.dtype` and `item_model`.
|
|
||||||
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
|
|
||||||
item_model.check_type(generator, ctx.ctx, dtype_llvm).unwrap();
|
|
||||||
|
|
||||||
let cdarray_model = Struct(ContiguousNDArray { item: item_model });
|
|
||||||
|
|
||||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
|
||||||
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
|
|
||||||
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
|
|
||||||
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
|
||||||
|
|
||||||
// Allocate and setup the resulting [`ContiguousNDArray`].
|
|
||||||
let result = cdarray_model.alloca(generator, ctx);
|
|
||||||
|
|
||||||
// Set ndims and shape.
|
|
||||||
let ndims = self.ndims_llvm(generator, ctx.ctx);
|
|
||||||
result.set(ctx, |f| f.ndims, ndims);
|
|
||||||
|
|
||||||
let shape = self.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
result.set(ctx, |f| f.shape, shape);
|
|
||||||
|
|
||||||
let is_contiguous = self.is_c_contiguous(generator, ctx);
|
|
||||||
ctx.builder.build_conditional_branch(is_contiguous.value, then_bb, else_bb).unwrap();
|
|
||||||
|
|
||||||
// Inserting into then_bb; This ndarray is contiguous.
|
|
||||||
ctx.builder.position_at_end(then_bb);
|
|
||||||
let data = self.instance.get(generator, ctx, |f| f.data);
|
|
||||||
let data = data.pointer_cast(generator, ctx, item_model);
|
|
||||||
result.set(ctx, |f| f.data, data);
|
|
||||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
|
||||||
|
|
||||||
// Inserting into else_bb; This ndarray is not contiguous. Do a full-copy on `data`.
|
|
||||||
// `make_copy` produces an ndarray with contiguous `data`.
|
|
||||||
ctx.builder.position_at_end(else_bb);
|
|
||||||
let copied_ndarray = self.make_copy(generator, ctx);
|
|
||||||
let data = copied_ndarray.instance.get(generator, ctx, |f| f.data);
|
|
||||||
let data = data.pointer_cast(generator, ctx, item_model);
|
|
||||||
result.set(ctx, |f| f.data, data);
|
|
||||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
|
||||||
|
|
||||||
// Reposition to end_bb for continuation
|
|
||||||
ctx.builder.position_at_end(end_bb);
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an [`NDArrayObject`] from a [`ContiguousNDArray`].
|
|
||||||
///
|
|
||||||
/// The operation is super cheap. The newly created [`NDArrayObject`] will share the
|
|
||||||
/// same memory as the [`ContiguousNDArray`].
|
|
||||||
///
|
|
||||||
/// `ndims` has to be provided as [`NDArrayObject`] requires a statically known `ndims` value, despite
|
|
||||||
/// the fact that the information should be contained within the [`ContiguousNDArray`].
|
|
||||||
pub fn from_contiguous_ndarray<G: CodeGenerator + ?Sized, Item: Model<'ctx>>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
carray: Instance<'ctx, Ptr<Struct<ContiguousNDArray<Item>>>>,
|
|
||||||
dtype: Type,
|
|
||||||
ndims: u64,
|
|
||||||
) -> Self {
|
|
||||||
// Sanity check on `dtype` and `contiguous_array`'s `Item` model.
|
|
||||||
let dtype_llvm = ctx.get_llvm_type(generator, dtype);
|
|
||||||
carray.model.0 .0.item.check_type(generator, ctx.ctx, dtype_llvm).unwrap();
|
|
||||||
|
|
||||||
// TODO: Debug assert `ndims == carray.ndims` to catch bugs.
|
|
||||||
|
|
||||||
// Allocate the resulting ndarray.
|
|
||||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims);
|
|
||||||
|
|
||||||
// Copy shape and update strides
|
|
||||||
let shape = carray.get(generator, ctx, |f| f.shape);
|
|
||||||
ndarray.copy_shape_from_array(generator, ctx, shape);
|
|
||||||
ndarray.set_strides_contiguous(generator, ctx);
|
|
||||||
|
|
||||||
// Share data
|
|
||||||
let data = carray.get(generator, ctx, |f| f.data).pointer_cast(generator, ctx, Int(Byte));
|
|
||||||
ndarray.instance.set(ctx, |f| f.data, data);
|
|
||||||
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,176 +0,0 @@
|
||||||
use inkwell::{values::BasicValueEnum, IntPredicate};
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
codegen::{
|
|
||||||
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext,
|
|
||||||
CodeGenerator,
|
|
||||||
},
|
|
||||||
typecheck::typedef::Type,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::NDArrayObject;
|
|
||||||
|
|
||||||
/// Get the zero value in `np.zeros()` of a `dtype`.
|
|
||||||
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
) -> BasicValueEnum<'ctx> {
|
|
||||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
|
||||||
{
|
|
||||||
ctx.ctx.i32_type().const_zero().into()
|
|
||||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
|
||||||
{
|
|
||||||
ctx.ctx.i64_type().const_zero().into()
|
|
||||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
|
||||||
ctx.ctx.f64_type().const_zero().into()
|
|
||||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
|
||||||
ctx.ctx.bool_type().const_zero().into()
|
|
||||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
|
||||||
ctx.gen_string(generator, "").into()
|
|
||||||
} else {
|
|
||||||
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the one value in `np.ones()` of a `dtype`.
|
|
||||||
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
) -> BasicValueEnum<'ctx> {
|
|
||||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
|
||||||
{
|
|
||||||
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
|
|
||||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
|
||||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
|
||||||
{
|
|
||||||
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
|
|
||||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
|
||||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
|
||||||
ctx.ctx.f64_type().const_float(1.0).into()
|
|
||||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
|
||||||
ctx.ctx.bool_type().const_int(1, false).into()
|
|
||||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
|
||||||
ctx.gen_string(generator, "1").into()
|
|
||||||
} else {
|
|
||||||
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Create an ndarray like `np.empty`.
|
|
||||||
pub fn make_np_empty<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
ndims: u64,
|
|
||||||
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) -> Self {
|
|
||||||
// Validate `shape`
|
|
||||||
let ndims_llvm = Int(SizeT).const_int(generator, ctx.ctx, ndims, false);
|
|
||||||
call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape);
|
|
||||||
|
|
||||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims);
|
|
||||||
ndarray.copy_shape_from_array(generator, ctx, shape);
|
|
||||||
ndarray.create_data(generator, ctx);
|
|
||||||
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an ndarray like `np.full`.
|
|
||||||
pub fn make_np_full<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
ndims: u64,
|
|
||||||
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
fill_value: BasicValueEnum<'ctx>,
|
|
||||||
) -> Self {
|
|
||||||
let ndarray = NDArrayObject::make_np_empty(generator, ctx, dtype, ndims, shape);
|
|
||||||
ndarray.fill(generator, ctx, fill_value);
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an ndarray like `np.zero`.
|
|
||||||
pub fn make_np_zeros<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
ndims: u64,
|
|
||||||
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) -> Self {
|
|
||||||
let fill_value = ndarray_zero_value(generator, ctx, dtype);
|
|
||||||
NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an ndarray like `np.ones`.
|
|
||||||
pub fn make_np_ones<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
ndims: u64,
|
|
||||||
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) -> Self {
|
|
||||||
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
|
||||||
NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an ndarray like `np.eye`.
|
|
||||||
pub fn make_np_eye<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
nrows: Instance<'ctx, Int<SizeT>>,
|
|
||||||
ncols: Instance<'ctx, Int<SizeT>>,
|
|
||||||
offset: Instance<'ctx, Int<SizeT>>,
|
|
||||||
) -> Self {
|
|
||||||
let ndzero = ndarray_zero_value(generator, ctx, dtype);
|
|
||||||
let ndone = ndarray_one_value(generator, ctx, dtype);
|
|
||||||
|
|
||||||
let ndarray = NDArrayObject::alloca_dynamic_shape(generator, ctx, dtype, &[nrows, ncols]);
|
|
||||||
|
|
||||||
// Create data and make the matrix like look np.eye()
|
|
||||||
ndarray.create_data(generator, ctx);
|
|
||||||
ndarray
|
|
||||||
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
|
||||||
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
|
||||||
// and this loop would not execute.
|
|
||||||
|
|
||||||
// Load up `row_i` and `col_i` from indices.
|
|
||||||
let row_i = nditer.get_indices().get_index_const(generator, ctx, 0);
|
|
||||||
let col_i = nditer.get_indices().get_index_const(generator, ctx, 1);
|
|
||||||
|
|
||||||
let be_one = row_i.add(ctx, offset).compare(ctx, IntPredicate::EQ, col_i);
|
|
||||||
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
|
|
||||||
|
|
||||||
let p = nditer.get_pointer(generator, ctx);
|
|
||||||
ctx.builder.build_store(p, value).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an ndarray like `np.identity`.
|
|
||||||
pub fn make_np_identity<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
size: Instance<'ctx, Int<SizeT>>,
|
|
||||||
) -> Self {
|
|
||||||
// Convenient implementation
|
|
||||||
let offset = Int(SizeT).const_0(generator, ctx.ctx);
|
|
||||||
NDArrayObject::make_np_eye(generator, ctx, dtype, size, size, offset)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,227 +0,0 @@
|
||||||
use crate::codegen::{
|
|
||||||
irrt::call_nac3_ndarray_index,
|
|
||||||
model::*,
|
|
||||||
object::utils::slice::{RustSlice, Slice},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::NDArrayObject;
|
|
||||||
|
|
||||||
pub type NDIndexType = Byte;
|
|
||||||
|
|
||||||
/// Fields of [`NDIndex`]
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct NDIndexFields<'ctx, F: FieldTraversal<'ctx>> {
|
|
||||||
pub type_: F::Output<Int<NDIndexType>>,
|
|
||||||
pub data: F::Output<Ptr<Int<Byte>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An IRRT representation of an ndarray subscript index.
|
|
||||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
|
||||||
pub struct NDIndex;
|
|
||||||
|
|
||||||
impl<'ctx> StructKind<'ctx> for NDIndex {
|
|
||||||
type Fields<F: FieldTraversal<'ctx>> = NDIndexFields<'ctx, F>;
|
|
||||||
|
|
||||||
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
|
||||||
Self::Fields { type_: traversal.add_auto("type"), data: traversal.add_auto("data") }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A convenience enum representing a [`NDIndex`].
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum RustNDIndex<'ctx> {
|
|
||||||
SingleElement(Instance<'ctx, Int<Int32>>),
|
|
||||||
Slice(RustSlice<'ctx, Int32>),
|
|
||||||
NewAxis,
|
|
||||||
Ellipsis,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> RustNDIndex<'ctx> {
|
|
||||||
/// Get the value to set `NDIndex::type` for this variant.
|
|
||||||
fn get_type_id(&self) -> u64 {
|
|
||||||
// Defined in IRRT, must be in sync
|
|
||||||
match self {
|
|
||||||
RustNDIndex::SingleElement(_) => 0,
|
|
||||||
RustNDIndex::Slice(_) => 1,
|
|
||||||
RustNDIndex::NewAxis => 2,
|
|
||||||
RustNDIndex::Ellipsis => 3,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Serialize this [`RustNDIndex`] by writing it into an LLVM [`NDIndex`].
|
|
||||||
fn write_to_ndindex<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
dst_ndindex_ptr: Instance<'ctx, Ptr<Struct<NDIndex>>>,
|
|
||||||
) {
|
|
||||||
// Set `dst_ndindex_ptr->type`
|
|
||||||
dst_ndindex_ptr.gep(ctx, |f| f.type_).store(
|
|
||||||
ctx,
|
|
||||||
Int(NDIndexType::default()).const_int(generator, ctx.ctx, self.get_type_id(), false),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Set `dst_ndindex_ptr->data`
|
|
||||||
match self {
|
|
||||||
RustNDIndex::SingleElement(in_index) => {
|
|
||||||
let index_ptr = Int(Int32).alloca(generator, ctx);
|
|
||||||
index_ptr.store(ctx, *in_index);
|
|
||||||
|
|
||||||
dst_ndindex_ptr
|
|
||||||
.gep(ctx, |f| f.data)
|
|
||||||
.store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte)));
|
|
||||||
}
|
|
||||||
RustNDIndex::Slice(in_rust_slice) => {
|
|
||||||
let user_slice_ptr = Struct(Slice(Int32)).alloca(generator, ctx);
|
|
||||||
in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr);
|
|
||||||
|
|
||||||
dst_ndindex_ptr
|
|
||||||
.gep(ctx, |f| f.data)
|
|
||||||
.store(ctx, user_slice_ptr.pointer_cast(generator, ctx, Int(Byte)));
|
|
||||||
}
|
|
||||||
RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Serialize a list of `RustNDIndex` as a newly allocated LLVM array of `NDIndex`.
|
|
||||||
pub fn make_ndindices<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
in_ndindices: &[RustNDIndex<'ctx>],
|
|
||||||
) -> (Instance<'ctx, Int<SizeT>>, Instance<'ctx, Ptr<Struct<NDIndex>>>) {
|
|
||||||
let ndindex_model = Struct(NDIndex);
|
|
||||||
|
|
||||||
// Allocate the LLVM ndindices.
|
|
||||||
let num_ndindices =
|
|
||||||
Int(SizeT).const_int(generator, ctx.ctx, in_ndindices.len() as u64, false);
|
|
||||||
let ndindices = ndindex_model.array_alloca(generator, ctx, num_ndindices.value);
|
|
||||||
|
|
||||||
// Initialize all of them.
|
|
||||||
for (i, in_ndindex) in in_ndindices.iter().enumerate() {
|
|
||||||
let pndindex = ndindices.offset_const(ctx, i64::try_from(i).unwrap());
|
|
||||||
in_ndindex.write_to_ndindex(generator, ctx, pndindex);
|
|
||||||
}
|
|
||||||
|
|
||||||
(num_ndindices, ndindices)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Get the expected `ndims` after indexing with `indices`.
|
|
||||||
#[must_use]
|
|
||||||
fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 {
|
|
||||||
let mut ndims = self.ndims;
|
|
||||||
for index in indices {
|
|
||||||
match index {
|
|
||||||
RustNDIndex::SingleElement(_) => {
|
|
||||||
ndims -= 1; // Single elements decrements ndims
|
|
||||||
}
|
|
||||||
RustNDIndex::NewAxis => {
|
|
||||||
ndims += 1; // `np.newaxis` / `none` adds a new axis
|
|
||||||
}
|
|
||||||
RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ndims
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Index into the ndarray, and return a newly-allocated view on this ndarray.
|
|
||||||
///
|
|
||||||
/// This function behaves like NumPy's ndarray indexing, but if the indices index
|
|
||||||
/// into a single element, an unsized ndarray is returned.
|
|
||||||
#[must_use]
|
|
||||||
pub fn index<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
indices: &[RustNDIndex<'ctx>],
|
|
||||||
) -> Self {
|
|
||||||
let dst_ndims = self.deduce_ndims_after_indexing_with(indices);
|
|
||||||
let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims);
|
|
||||||
|
|
||||||
let (num_indices, indices) = RustNDIndex::make_ndindices(generator, ctx, indices);
|
|
||||||
call_nac3_ndarray_index(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
num_indices,
|
|
||||||
indices,
|
|
||||||
self.instance,
|
|
||||||
dst_ndarray.instance,
|
|
||||||
);
|
|
||||||
|
|
||||||
dst_ndarray
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod util {
|
|
||||||
use itertools::Itertools;
|
|
||||||
use nac3parser::ast::{Expr, ExprKind};
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
codegen::{model::*, object::utils::slice::util::gen_slice, CodeGenContext, CodeGenerator},
|
|
||||||
typecheck::typedef::Type,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::RustNDIndex;
|
|
||||||
|
|
||||||
/// Generate LLVM code to transform an ndarray subscript expression to
|
|
||||||
/// its list of [`RustNDIndex`]
|
|
||||||
///
|
|
||||||
/// i.e.,
|
|
||||||
/// ```python
|
|
||||||
/// my_ndarray[::3, 1, :2:]
|
|
||||||
/// ^^^^^^^^^^^ Then these into a three `RustNDIndex`es
|
|
||||||
/// ```
|
|
||||||
pub fn gen_ndarray_subscript_ndindices<'ctx, G: CodeGenerator>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
subscript: &Expr<Option<Type>>,
|
|
||||||
) -> Result<Vec<RustNDIndex<'ctx>>, String> {
|
|
||||||
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
|
|
||||||
|
|
||||||
// Annoying notes about `slice`
|
|
||||||
// - `my_array[5]`
|
|
||||||
// - slice is a `Constant`
|
|
||||||
// - `my_array[:5]`
|
|
||||||
// - slice is a `Slice`
|
|
||||||
// - `my_array[:]`
|
|
||||||
// - slice is a `Slice`, but lower upper step would all be `Option::None`
|
|
||||||
// - `my_array[:, :]`
|
|
||||||
// - slice is now a `Tuple` of two `Slice`-s
|
|
||||||
//
|
|
||||||
// In summary:
|
|
||||||
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
|
|
||||||
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
|
|
||||||
//
|
|
||||||
// So we first "flatten" out the slice expression
|
|
||||||
let index_exprs = match &subscript.node {
|
|
||||||
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
|
|
||||||
_ => vec![subscript],
|
|
||||||
};
|
|
||||||
|
|
||||||
// Process all index expressions
|
|
||||||
let mut rust_ndindices: Vec<RustNDIndex> = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here.
|
|
||||||
for index_expr in index_exprs {
|
|
||||||
// NOTE: Currently nac3core's slices do not have an object representation,
|
|
||||||
// so the code/implementation looks awkward - we have to do pattern matching on the expression
|
|
||||||
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
|
|
||||||
// Handle slices
|
|
||||||
let slice = gen_slice(generator, ctx, lower, upper, step)?;
|
|
||||||
RustNDIndex::Slice(slice)
|
|
||||||
} else {
|
|
||||||
// Treat and handle everything else as a single element index.
|
|
||||||
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ctx.primitives.int32, // Must be int32, this checks for illegal values
|
|
||||||
)?;
|
|
||||||
let index = Int(Int32).check_value(generator, ctx.ctx, index).unwrap();
|
|
||||||
|
|
||||||
RustNDIndex::SingleElement(index)
|
|
||||||
};
|
|
||||||
rust_ndindices.push(ndindex);
|
|
||||||
}
|
|
||||||
Ok(rust_ndindices)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,219 +0,0 @@
|
||||||
use inkwell::values::BasicValueEnum;
|
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
codegen::{
|
|
||||||
object::ndarray::{AnyObject, NDArrayObject},
|
|
||||||
stmt::gen_for_callback,
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
},
|
|
||||||
typecheck::typedef::Type,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::{nditer::NDIterHandle, NDArrayOut, ScalarOrNDArray};
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping` elementwise.
|
|
||||||
///
|
|
||||||
/// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when iterating through
|
|
||||||
/// the input `ndarrays` after broadcasting. The output of `mapping` is the result of the elementwise operation.
|
|
||||||
///
|
|
||||||
/// `out` specifies whether the result should be a new ndarray or to be written an existing ndarray.
|
|
||||||
pub fn broadcast_starmap<'a, G, MappingFn>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
ndarrays: &[Self],
|
|
||||||
out: NDArrayOut<'ctx>,
|
|
||||||
mapping: MappingFn,
|
|
||||||
) -> Result<Self, String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
MappingFn: FnOnce(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
&[BasicValueEnum<'ctx>],
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
|
||||||
{
|
|
||||||
// Broadcast inputs
|
|
||||||
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
|
|
||||||
|
|
||||||
let out_ndarray = match out {
|
|
||||||
NDArrayOut::NewNDArray { dtype } => {
|
|
||||||
// Create a new ndarray based on the broadcast shape.
|
|
||||||
let result_ndarray =
|
|
||||||
NDArrayObject::alloca(generator, ctx, dtype, broadcast_result.ndims);
|
|
||||||
result_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
|
|
||||||
result_ndarray.create_data(generator, ctx);
|
|
||||||
result_ndarray
|
|
||||||
}
|
|
||||||
NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => {
|
|
||||||
// Use an existing ndarray.
|
|
||||||
|
|
||||||
// Check that its shape is compatible with the broadcast shape.
|
|
||||||
result_ndarray.assert_can_be_written_by_out(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
broadcast_result.ndims,
|
|
||||||
broadcast_result.shape,
|
|
||||||
);
|
|
||||||
result_ndarray
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Map element-wise and store results into `mapped_ndarray`.
|
|
||||||
let nditer = NDIterHandle::new(generator, ctx, out_ndarray);
|
|
||||||
gen_for_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
Some("broadcast_starmap"),
|
|
||||||
|generator, ctx| {
|
|
||||||
// Create NDIters for all broadcasted input ndarrays.
|
|
||||||
let other_nditers = broadcast_result
|
|
||||||
.ndarrays
|
|
||||||
.iter()
|
|
||||||
.map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray))
|
|
||||||
.collect_vec();
|
|
||||||
Ok((nditer, other_nditers))
|
|
||||||
},
|
|
||||||
|generator, ctx, (out_nditer, _in_nditers)| {
|
|
||||||
// We can simply use `out_nditer`'s `has_element()`.
|
|
||||||
// `in_nditers`' `has_element()`s should return the same value.
|
|
||||||
Ok(out_nditer.has_element(generator, ctx).value)
|
|
||||||
},
|
|
||||||
|generator, ctx, _hooks, (out_nditer, in_nditers)| {
|
|
||||||
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
|
|
||||||
// and write to `out_ndarray`.
|
|
||||||
let in_scalars = in_nditers
|
|
||||||
.iter()
|
|
||||||
.map(|nditer| nditer.get_scalar(generator, ctx).value)
|
|
||||||
.collect_vec();
|
|
||||||
|
|
||||||
let result = mapping(generator, ctx, &in_scalars)?;
|
|
||||||
|
|
||||||
let p = out_nditer.get_pointer(generator, ctx);
|
|
||||||
ctx.builder.build_store(p, result).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
|generator, ctx, (out_nditer, in_nditers)| {
|
|
||||||
// Advance all iterators
|
|
||||||
out_nditer.next(generator, ctx);
|
|
||||||
in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx));
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(out_ndarray)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Map through this ndarray with an elementwise function.
|
|
||||||
pub fn map<'a, G, Mapping>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
out: NDArrayOut<'ctx>,
|
|
||||||
mapping: Mapping,
|
|
||||||
) -> Result<Self, String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
Mapping: FnOnce(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
BasicValueEnum<'ctx>,
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
|
||||||
{
|
|
||||||
NDArrayObject::broadcast_starmap(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
&[*self],
|
|
||||||
out,
|
|
||||||
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
|
||||||
/// Starmap through a list of inputs using `mapping`, where an input could be an ndarray, a scalar.
|
|
||||||
///
|
|
||||||
/// This function is very helpful when implementing NumPy functions that takes on either scalars or ndarrays or a mix of them
|
|
||||||
/// as their inputs and produces either an ndarray with broadcast, or a scalar if all its inputs are all scalars.
|
|
||||||
///
|
|
||||||
/// For example ,this function can be used to implement `np.add`, which has the following behaviors:
|
|
||||||
/// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar
|
|
||||||
/// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is converted into an ndarray and broadcasted.
|
|
||||||
/// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) -> ndarray; there is broadcasting.
|
|
||||||
///
|
|
||||||
/// ## Details:
|
|
||||||
///
|
|
||||||
/// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a [`ScalarOrNDArray::Scalar`] with type `ret_dtype`.
|
|
||||||
///
|
|
||||||
/// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be 'as-ndarray'-ed into ndarrays,
|
|
||||||
/// then all inputs (now all ndarrays) will be passed to [`NDArrayObject::broadcasting_starmap`] and **create** a new ndarray
|
|
||||||
/// with dtype `ret_dtype`.
|
|
||||||
pub fn broadcasting_starmap<'a, G, MappingFn>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
inputs: &[ScalarOrNDArray<'ctx>],
|
|
||||||
ret_dtype: Type,
|
|
||||||
mapping: MappingFn,
|
|
||||||
) -> Result<ScalarOrNDArray<'ctx>, String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
MappingFn: FnOnce(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
&[BasicValueEnum<'ctx>],
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
|
||||||
{
|
|
||||||
// Check if all inputs are Scalars
|
|
||||||
let all_scalars: Option<Vec<_>> = inputs.iter().map(AnyObject::try_from).try_collect().ok();
|
|
||||||
|
|
||||||
if let Some(scalars) = all_scalars {
|
|
||||||
let scalars = scalars.iter().map(|scalar| scalar.value).collect_vec();
|
|
||||||
let value = mapping(generator, ctx, &scalars)?;
|
|
||||||
|
|
||||||
Ok(ScalarOrNDArray::Scalar(AnyObject { ty: ret_dtype, value }))
|
|
||||||
} else {
|
|
||||||
// Promote all input to ndarrays and map through them.
|
|
||||||
let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec();
|
|
||||||
let ndarray = NDArrayObject::broadcast_starmap(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
&inputs,
|
|
||||||
NDArrayOut::NewNDArray { dtype: ret_dtype },
|
|
||||||
mapping,
|
|
||||||
)?;
|
|
||||||
Ok(ScalarOrNDArray::NDArray(ndarray))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Map through this [`ScalarOrNDArray`] with an elementwise function.
|
|
||||||
///
|
|
||||||
/// If this is a scalar, `mapping` will directly act on the scalar. This function will return a [`ScalarOrNDArray::Scalar`] of that result.
|
|
||||||
///
|
|
||||||
/// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new ndarray of the results will be created and
|
|
||||||
/// returned as a [`ScalarOrNDArray::NDArray`].
|
|
||||||
pub fn map<'a, G, Mapping>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
ret_dtype: Type,
|
|
||||||
mapping: Mapping,
|
|
||||||
) -> Result<ScalarOrNDArray<'ctx>, String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
Mapping: FnOnce(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
BasicValueEnum<'ctx>,
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
|
||||||
{
|
|
||||||
ScalarOrNDArray::broadcasting_starmap(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
&[*self],
|
|
||||||
ret_dtype,
|
|
||||||
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,218 +0,0 @@
|
||||||
use std::cmp::max;
|
|
||||||
|
|
||||||
use nac3parser::ast::Operator;
|
|
||||||
use util::gen_for_model;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
codegen::{
|
|
||||||
expr::gen_binop_expr_with_values, irrt::call_nac3_ndarray_matmul_calculate_shapes,
|
|
||||||
model::*, object::ndarray::indexing::RustNDIndex, CodeGenContext, CodeGenerator,
|
|
||||||
},
|
|
||||||
typecheck::{magic_methods::Binop, typedef::Type},
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::{NDArrayObject, NDArrayOut};
|
|
||||||
|
|
||||||
/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`.
|
|
||||||
///
|
|
||||||
/// `dst_dtype` defines the dtype of the returned ndarray.
|
|
||||||
fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dst_dtype: Type,
|
|
||||||
in_a: NDArrayObject<'ctx>,
|
|
||||||
in_b: NDArrayObject<'ctx>,
|
|
||||||
) -> NDArrayObject<'ctx> {
|
|
||||||
assert!(in_a.ndims >= 2);
|
|
||||||
assert!(in_b.ndims >= 2);
|
|
||||||
|
|
||||||
// Deduce ndims of the result of matmul.
|
|
||||||
let ndims_int = max(in_a.ndims, in_b.ndims);
|
|
||||||
let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int, false);
|
|
||||||
|
|
||||||
// Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the
|
|
||||||
// destination ndarray to store the result of matmul.
|
|
||||||
let (lhs, rhs, dst) = {
|
|
||||||
let in_lhs_ndims = in_a.ndims_llvm(generator, ctx.ctx);
|
|
||||||
let in_lhs_shape = in_a.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
let in_rhs_ndims = in_b.ndims_llvm(generator, ctx.ctx);
|
|
||||||
let in_rhs_shape = in_b.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
let lhs_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
|
||||||
let rhs_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
|
||||||
let dst_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
|
||||||
|
|
||||||
// Matmul dimension compatibility is checked here.
|
|
||||||
call_nac3_ndarray_matmul_calculate_shapes(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
in_lhs_ndims,
|
|
||||||
in_lhs_shape,
|
|
||||||
in_rhs_ndims,
|
|
||||||
in_rhs_shape,
|
|
||||||
ndims,
|
|
||||||
lhs_shape,
|
|
||||||
rhs_shape,
|
|
||||||
dst_shape,
|
|
||||||
);
|
|
||||||
|
|
||||||
let lhs = in_a.broadcast_to(generator, ctx, ndims_int, lhs_shape);
|
|
||||||
let rhs = in_b.broadcast_to(generator, ctx, ndims_int, rhs_shape);
|
|
||||||
|
|
||||||
let dst = NDArrayObject::alloca(generator, ctx, dst_dtype, ndims_int);
|
|
||||||
dst.copy_shape_from_array(generator, ctx, dst_shape);
|
|
||||||
dst.create_data(generator, ctx);
|
|
||||||
|
|
||||||
(lhs, rhs, dst)
|
|
||||||
};
|
|
||||||
|
|
||||||
let len = lhs.instance.get(generator, ctx, |f| f.shape).get_index_const(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
i64::try_from(ndims_int - 1).unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let at_row = i64::try_from(ndims_int - 2).unwrap();
|
|
||||||
let at_col = i64::try_from(ndims_int - 1).unwrap();
|
|
||||||
|
|
||||||
let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype);
|
|
||||||
let dst_zero = dst_dtype_llvm.const_zero();
|
|
||||||
|
|
||||||
dst.foreach(generator, ctx, |generator, ctx, _, hdl| {
|
|
||||||
let pdst_ij = hdl.get_pointer(generator, ctx);
|
|
||||||
|
|
||||||
ctx.builder.build_store(pdst_ij, dst_zero).unwrap();
|
|
||||||
|
|
||||||
let indices = hdl.get_indices();
|
|
||||||
let i = indices.get_index_const(generator, ctx, at_row);
|
|
||||||
let j = indices.get_index_const(generator, ctx, at_col);
|
|
||||||
|
|
||||||
let num_0 = Int(SizeT).const_int(generator, ctx.ctx, 0, false);
|
|
||||||
let num_1 = Int(SizeT).const_int(generator, ctx.ctx, 1, false);
|
|
||||||
|
|
||||||
gen_for_model(generator, ctx, num_0, len, num_1, |generator, ctx, _, k| {
|
|
||||||
// `indices` is modified to index into `a` and `b`, and restored.
|
|
||||||
indices.set_index_const(ctx, at_row, i);
|
|
||||||
indices.set_index_const(ctx, at_col, k);
|
|
||||||
let a_ik = lhs.get_scalar_by_indices(generator, ctx, indices);
|
|
||||||
|
|
||||||
indices.set_index_const(ctx, at_row, k);
|
|
||||||
indices.set_index_const(ctx, at_col, j);
|
|
||||||
let b_kj = rhs.get_scalar_by_indices(generator, ctx, indices);
|
|
||||||
|
|
||||||
// Restore `indices`.
|
|
||||||
indices.set_index_const(ctx, at_row, i);
|
|
||||||
indices.set_index_const(ctx, at_col, j);
|
|
||||||
|
|
||||||
// x = a_[...]ik * b_[...]kj
|
|
||||||
let x = gen_binop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(&Some(lhs.dtype), a_ik.value),
|
|
||||||
Binop::normal(Operator::Mult),
|
|
||||||
(&Some(rhs.dtype), b_kj.value),
|
|
||||||
ctx.current_loc,
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
|
||||||
|
|
||||||
// dst_[...]ij += x
|
|
||||||
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
|
|
||||||
let dst_ij = gen_binop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(&Some(dst_dtype), dst_ij),
|
|
||||||
Binop::normal(Operator::Add),
|
|
||||||
(&Some(dst_dtype), x),
|
|
||||||
ctx.current_loc,
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
|
||||||
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
dst
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Perform `np.matmul` according to the rules in
|
|
||||||
/// <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>.
|
|
||||||
///
|
|
||||||
/// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`]
|
|
||||||
/// to handle when the output could be a scalar.
|
|
||||||
///
|
|
||||||
/// `dst_dtype` defines the dtype of the returned ndarray.
|
|
||||||
pub fn matmul<G: CodeGenerator>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
a: Self,
|
|
||||||
b: Self,
|
|
||||||
out: NDArrayOut<'ctx>,
|
|
||||||
) -> Self {
|
|
||||||
// Sanity check, but type inference should prevent this.
|
|
||||||
assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input");
|
|
||||||
|
|
||||||
/*
|
|
||||||
If both arguments are 2-D they are multiplied like conventional matrices.
|
|
||||||
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly.
|
|
||||||
If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
|
|
||||||
If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
|
|
||||||
*/
|
|
||||||
|
|
||||||
let new_a = if a.ndims == 1 {
|
|
||||||
// Prepend 1 to its dimensions
|
|
||||||
a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
|
|
||||||
} else {
|
|
||||||
a
|
|
||||||
};
|
|
||||||
|
|
||||||
let new_b = if b.ndims == 1 {
|
|
||||||
// Append 1 to its dimensions
|
|
||||||
b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
|
|
||||||
} else {
|
|
||||||
b
|
|
||||||
};
|
|
||||||
|
|
||||||
// NOTE: `result` will always be a newly allocated ndarray.
|
|
||||||
// Current implementation cannot do in-place matrix muliplication.
|
|
||||||
let mut result = matmul_at_least_2d(generator, ctx, out.get_dtype(), new_a, new_b);
|
|
||||||
|
|
||||||
// Postprocessing on the result to remove prepended/appended axes.
|
|
||||||
let mut postindices = vec![];
|
|
||||||
let zero = Int(Int32).const_0(generator, ctx.ctx);
|
|
||||||
|
|
||||||
if a.ndims == 1 {
|
|
||||||
// Remove the prepended 1
|
|
||||||
postindices.push(RustNDIndex::SingleElement(zero));
|
|
||||||
}
|
|
||||||
|
|
||||||
if b.ndims == 1 {
|
|
||||||
// Remove the appended 1
|
|
||||||
postindices.push(RustNDIndex::Ellipsis);
|
|
||||||
postindices.push(RustNDIndex::SingleElement(zero));
|
|
||||||
}
|
|
||||||
|
|
||||||
if !postindices.is_empty() {
|
|
||||||
result = result.index(generator, ctx, &postindices);
|
|
||||||
}
|
|
||||||
|
|
||||||
match out {
|
|
||||||
NDArrayOut::NewNDArray { .. } => result,
|
|
||||||
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
|
|
||||||
let result_shape = result.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
out_ndarray.assert_can_be_written_by_out(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
result.ndims,
|
|
||||||
result_shape,
|
|
||||||
);
|
|
||||||
|
|
||||||
out_ndarray.copy_data_from(generator, ctx, result);
|
|
||||||
out_ndarray
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,670 +0,0 @@
|
||||||
pub mod array;
|
|
||||||
pub mod broadcast;
|
|
||||||
pub mod contiguous;
|
|
||||||
pub mod factory;
|
|
||||||
pub mod indexing;
|
|
||||||
pub mod map;
|
|
||||||
pub mod matmul;
|
|
||||||
pub mod nditer;
|
|
||||||
pub mod shape_util;
|
|
||||||
pub mod view;
|
|
||||||
|
|
||||||
use inkwell::{
|
|
||||||
context::Context,
|
|
||||||
types::BasicType,
|
|
||||||
values::{BasicValue, BasicValueEnum, PointerValue},
|
|
||||||
AddressSpace,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
codegen::{
|
|
||||||
irrt::{
|
|
||||||
call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement,
|
|
||||||
call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous,
|
|
||||||
call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
|
|
||||||
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
|
||||||
call_nac3_ndarray_util_assert_output_shape_same,
|
|
||||||
},
|
|
||||||
model::*,
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
},
|
|
||||||
toplevel::{
|
|
||||||
helper::{create_ndims, extract_ndims},
|
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
|
||||||
},
|
|
||||||
typecheck::typedef::{Type, TypeEnum},
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::{any::AnyObject, tuple::TupleObject};
|
|
||||||
|
|
||||||
/// Fields of [`NDArray`]
|
|
||||||
pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> {
|
|
||||||
pub data: F::Output<Ptr<Int<Byte>>>,
|
|
||||||
pub itemsize: F::Output<Int<SizeT>>,
|
|
||||||
pub ndims: F::Output<Int<SizeT>>,
|
|
||||||
pub shape: F::Output<Ptr<Int<SizeT>>>,
|
|
||||||
pub strides: F::Output<Ptr<Int<SizeT>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A strided ndarray in NAC3.
|
|
||||||
///
|
|
||||||
/// See IRRT implementation for details about its fields.
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct NDArray;
|
|
||||||
|
|
||||||
impl<'ctx> StructKind<'ctx> for NDArray {
|
|
||||||
type Fields<F: FieldTraversal<'ctx>> = NDArrayFields<'ctx, F>;
|
|
||||||
|
|
||||||
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
|
||||||
Self::Fields {
|
|
||||||
data: traversal.add_auto("data"),
|
|
||||||
itemsize: traversal.add_auto("itemsize"),
|
|
||||||
ndims: traversal.add_auto("ndims"),
|
|
||||||
shape: traversal.add_auto("shape"),
|
|
||||||
strides: traversal.add_auto("strides"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A NAC3 Python ndarray object.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct NDArrayObject<'ctx> {
|
|
||||||
pub dtype: Type,
|
|
||||||
pub ndims: u64,
|
|
||||||
pub instance: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Attempt to convert an [`AnyObject`] into an [`NDArrayObject`].
|
|
||||||
pub fn from_object<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
object: AnyObject<'ctx>,
|
|
||||||
) -> NDArrayObject<'ctx> {
|
|
||||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty);
|
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's
|
|
||||||
/// `dtype` and `ndims`.
|
|
||||||
pub fn from_value_and_unpacked_types<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
value: V,
|
|
||||||
dtype: Type,
|
|
||||||
ndims: u64,
|
|
||||||
) -> Self {
|
|
||||||
let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, value).unwrap();
|
|
||||||
NDArrayObject { dtype, ndims, instance: value }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get this ndarray's `ndims` as an LLVM constant.
|
|
||||||
pub fn ndims_llvm<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
Int(SizeT).const_int(generator, ctx, self.ndims, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the typechecker ndarray type of this [`NDArrayObject`].
|
|
||||||
pub fn get_type(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Type {
|
|
||||||
let ndims = create_ndims(&mut ctx.unifier, self.ndims);
|
|
||||||
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(self.dtype), Some(ndims))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Forget that this is an ndarray and convert into an [`AnyObject`].
|
|
||||||
pub fn to_any(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
|
|
||||||
let ty = self.get_type(ctx);
|
|
||||||
AnyObject { value: self.instance.value.as_basic_value_enum(), ty }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
|
|
||||||
///
|
|
||||||
/// `shape` and `strides` will be automatically allocated onto the stack.
|
|
||||||
///
|
|
||||||
/// The returned ndarray's content will be:
|
|
||||||
/// - `data`: uninitialized.
|
|
||||||
/// - `itemsize`: set to the `sizeof()` of `dtype`.
|
|
||||||
/// - `ndims`: set to the value of `ndims`.
|
|
||||||
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
|
|
||||||
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
|
|
||||||
pub fn alloca<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
ndims: u64,
|
|
||||||
) -> Self {
|
|
||||||
let ndarray = Struct(NDArray).alloca(generator, ctx);
|
|
||||||
|
|
||||||
let itemsize = ctx.get_llvm_type(generator, dtype).size_of().unwrap();
|
|
||||||
let itemsize = Int(SizeT).z_extend_or_truncate(generator, ctx, itemsize);
|
|
||||||
ndarray.set(ctx, |f| f.itemsize, itemsize);
|
|
||||||
|
|
||||||
let ndims_val = Int(SizeT).const_int(generator, ctx.ctx, ndims, false);
|
|
||||||
ndarray.set(ctx, |f| f.ndims, ndims_val);
|
|
||||||
|
|
||||||
let shape = Int(SizeT).array_alloca(generator, ctx, ndims_val.value);
|
|
||||||
ndarray.set(ctx, |f| f.shape, shape);
|
|
||||||
|
|
||||||
let strides = Int(SizeT).array_alloca(generator, ctx, ndims_val.value);
|
|
||||||
ndarray.set(ctx, |f| f.strides, strides);
|
|
||||||
|
|
||||||
NDArrayObject { dtype, ndims, instance: ndarray }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
|
|
||||||
///
|
|
||||||
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
|
|
||||||
pub fn alloca_constant_shape<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
shape: &[u64],
|
|
||||||
) -> Self {
|
|
||||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64);
|
|
||||||
|
|
||||||
// Write shape
|
|
||||||
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
for (i, dim) in shape.iter().enumerate() {
|
|
||||||
let dim = Int(SizeT).const_int(generator, ctx.ctx, *dim, false);
|
|
||||||
dst_shape.offset_const(ctx, i64::try_from(i).unwrap()).store(ctx, dim);
|
|
||||||
}
|
|
||||||
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
|
|
||||||
///
|
|
||||||
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
|
|
||||||
pub fn alloca_dynamic_shape<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
dtype: Type,
|
|
||||||
shape: &[Instance<'ctx, Int<SizeT>>],
|
|
||||||
) -> Self {
|
|
||||||
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64);
|
|
||||||
|
|
||||||
// Write shape
|
|
||||||
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
for (i, dim) in shape.iter().enumerate() {
|
|
||||||
dst_shape.offset_const(ctx, i64::try_from(i).unwrap()).store(ctx, *dim);
|
|
||||||
}
|
|
||||||
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Initialize an ndarray's `data` by allocating a buffer on the stack.
|
|
||||||
/// The allocated data buffer is considered to be *owned* by the ndarray.
|
|
||||||
///
|
|
||||||
/// `strides` of the ndarray will also be updated with `set_strides_by_shape`.
|
|
||||||
///
|
|
||||||
/// `shape` and `itemsize` of the ndarray ***must*** be initialized first.
|
|
||||||
pub fn create_data<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) {
|
|
||||||
let nbytes = self.nbytes(generator, ctx);
|
|
||||||
|
|
||||||
let data = Int(Byte).array_alloca(generator, ctx, nbytes.value);
|
|
||||||
self.instance.set(ctx, |f| f.data, data);
|
|
||||||
|
|
||||||
self.set_strides_contiguous(generator, ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copy shape dimensions from an array.
|
|
||||||
pub fn copy_shape_from_array<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let num_items = self.ndims_llvm(generator, ctx.ctx).value;
|
|
||||||
self.instance.get(generator, ctx, |f| f.shape).copy_from(generator, ctx, shape, num_items);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copy shape dimensions from an ndarray.
|
|
||||||
/// Panics if `ndims` mismatches.
|
|
||||||
pub fn copy_shape_from_ndarray<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: NDArrayObject<'ctx>,
|
|
||||||
) {
|
|
||||||
assert_eq!(self.ndims, src_ndarray.ndims);
|
|
||||||
let src_shape = src_ndarray.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
self.copy_shape_from_array(generator, ctx, src_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copy strides dimensions from an array.
|
|
||||||
pub fn copy_strides_from_array<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
strides: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let num_items = self.ndims_llvm(generator, ctx.ctx).value;
|
|
||||||
self.instance
|
|
||||||
.get(generator, ctx, |f| f.strides)
|
|
||||||
.copy_from(generator, ctx, strides, num_items);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copy strides dimensions from an ndarray.
|
|
||||||
/// Panics if `ndims` mismatches.
|
|
||||||
pub fn copy_strides_from_ndarray<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: NDArrayObject<'ctx>,
|
|
||||||
) {
|
|
||||||
assert_eq!(self.ndims, src_ndarray.ndims);
|
|
||||||
let src_strides = src_ndarray.instance.get(generator, ctx, |f| f.strides);
|
|
||||||
self.copy_strides_from_array(generator, ctx, src_strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the `np.size()` of this ndarray.
|
|
||||||
pub fn size<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
call_nac3_ndarray_size(generator, ctx, self.instance)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the `ndarray.nbytes` of this ndarray.
|
|
||||||
pub fn nbytes<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
call_nac3_ndarray_nbytes(generator, ctx, self.instance)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the `len()` of this ndarray.
|
|
||||||
pub fn len<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
call_nac3_ndarray_len(generator, ctx, self.instance)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if this ndarray is C-contiguous.
|
|
||||||
///
|
|
||||||
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
|
|
||||||
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Int<Bool>> {
|
|
||||||
call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the pointer to the n-th (0-based) element.
|
|
||||||
///
|
|
||||||
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
|
|
||||||
pub fn get_nth_pelement<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
nth: Instance<'ctx, Int<SizeT>>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
|
|
||||||
|
|
||||||
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, nth);
|
|
||||||
ctx.builder
|
|
||||||
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "")
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the n-th (0-based) scalar.
|
|
||||||
pub fn get_nth_scalar<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
nth: Instance<'ctx, Int<SizeT>>,
|
|
||||||
) -> AnyObject<'ctx> {
|
|
||||||
let ptr = self.get_nth_pelement(generator, ctx, nth);
|
|
||||||
let value = ctx.builder.build_load(ptr, "").unwrap();
|
|
||||||
AnyObject { ty: self.dtype, value }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the pointer to the element indexed by `indices`.
|
|
||||||
///
|
|
||||||
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
|
|
||||||
pub fn get_pelement_by_indices<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
|
|
||||||
|
|
||||||
let p = call_nac3_ndarray_get_pelement_by_indices(generator, ctx, self.instance, indices);
|
|
||||||
ctx.builder
|
|
||||||
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "")
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the scalar indexed by `indices`.
|
|
||||||
pub fn get_scalar_by_indices<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) -> AnyObject<'ctx> {
|
|
||||||
let ptr = self.get_pelement_by_indices(generator, ctx, indices);
|
|
||||||
let value = ctx.builder.build_load(ptr, "").unwrap();
|
|
||||||
AnyObject { ty: self.dtype, value }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
|
|
||||||
///
|
|
||||||
/// Update the ndarray's strides to make the ndarray contiguous.
|
|
||||||
pub fn set_strides_contiguous<G: CodeGenerator + ?Sized>(
|
|
||||||
self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) {
|
|
||||||
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
|
|
||||||
///
|
|
||||||
/// The new ndarray will own its data and will be C-contiguous.
|
|
||||||
#[must_use]
|
|
||||||
pub fn make_copy<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Self {
|
|
||||||
let clone = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims);
|
|
||||||
|
|
||||||
let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx);
|
|
||||||
clone.copy_shape_from_array(generator, ctx, shape);
|
|
||||||
clone.create_data(generator, ctx);
|
|
||||||
clone.copy_data_from(generator, ctx, *self);
|
|
||||||
clone
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copy data from another ndarray.
|
|
||||||
///
|
|
||||||
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes
|
|
||||||
/// do not matter. The copying order is determined by how their flattened views look.
|
|
||||||
///
|
|
||||||
/// Panics if the `dtype`s of ndarrays are different.
|
|
||||||
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
src: NDArrayObject<'ctx>,
|
|
||||||
) {
|
|
||||||
assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match");
|
|
||||||
call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
|
||||||
#[must_use]
|
|
||||||
pub fn is_unsized(&self) -> bool {
|
|
||||||
self.ndims == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
/// If this ndarray is unsized, return its sole value as an [`AnyObject`].
|
|
||||||
/// Otherwise, do nothing and return the ndarray itself.
|
|
||||||
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> ScalarOrNDArray<'ctx> {
|
|
||||||
if self.is_unsized() {
|
|
||||||
// NOTE: `np.size(self) == 0` here is never possible.
|
|
||||||
let zero = Int(SizeT).const_0(generator, ctx.ctx);
|
|
||||||
let value = self.get_nth_scalar(generator, ctx, zero).value;
|
|
||||||
|
|
||||||
ScalarOrNDArray::Scalar(AnyObject { ty: self.dtype, value })
|
|
||||||
} else {
|
|
||||||
ScalarOrNDArray::NDArray(*self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Fill the ndarray with a scalar.
|
|
||||||
///
|
|
||||||
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
|
|
||||||
pub fn fill<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
value: BasicValueEnum<'ctx>,
|
|
||||||
) {
|
|
||||||
// TODO: It is possible to optimize this by exploiting contiguous strides with memset.
|
|
||||||
// Probably best to implement in IRRT.
|
|
||||||
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
|
||||||
let p = nditer.get_pointer(generator, ctx);
|
|
||||||
ctx.builder.build_store(p, value).unwrap();
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create the shape tuple of this ndarray like `np.shape(<ndarray>)`.
|
|
||||||
///
|
|
||||||
/// The returned integers in the tuple are in int32.
|
|
||||||
pub fn make_shape_tuple<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> TupleObject<'ctx> {
|
|
||||||
// TODO: Return a tuple of SizeT
|
|
||||||
|
|
||||||
let mut objects = Vec::with_capacity(self.ndims as usize);
|
|
||||||
|
|
||||||
for i in 0..self.ndims {
|
|
||||||
let dim = self
|
|
||||||
.instance
|
|
||||||
.get(generator, ctx, |f| f.shape)
|
|
||||||
.get_index_const(generator, ctx, i64::try_from(i).unwrap())
|
|
||||||
.truncate_or_bit_cast(generator, ctx, Int32);
|
|
||||||
|
|
||||||
objects.push(AnyObject {
|
|
||||||
ty: ctx.primitives.int32,
|
|
||||||
value: dim.value.as_basic_value_enum(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
TupleObject::from_objects(generator, ctx, objects)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create the strides tuple of this ndarray like `<ndarray>.strides`.
|
|
||||||
///
|
|
||||||
/// The returned integers in the tuple are in int32.
|
|
||||||
pub fn make_strides_tuple<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> TupleObject<'ctx> {
|
|
||||||
// TODO: Return a tuple of SizeT.
|
|
||||||
|
|
||||||
let mut objects = Vec::with_capacity(self.ndims as usize);
|
|
||||||
|
|
||||||
for i in 0..self.ndims {
|
|
||||||
let dim = self
|
|
||||||
.instance
|
|
||||||
.get(generator, ctx, |f| f.strides)
|
|
||||||
.get_index_const(generator, ctx, i64::try_from(i).unwrap())
|
|
||||||
.truncate_or_bit_cast(generator, ctx, Int32);
|
|
||||||
|
|
||||||
objects.push(AnyObject {
|
|
||||||
ty: ctx.primitives.int32,
|
|
||||||
value: dim.value.as_basic_value_enum(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
TupleObject::from_objects(generator, ctx, objects)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an unsized ndarray to contain `object`.
|
|
||||||
pub fn make_unsized<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
object: AnyObject<'ctx>,
|
|
||||||
) -> NDArrayObject<'ctx> {
|
|
||||||
// We have to put the value on the stack to get a data pointer.
|
|
||||||
let data = ctx.builder.build_alloca(object.value.get_type(), "make_unsized").unwrap();
|
|
||||||
ctx.builder.build_store(data, object.value).unwrap();
|
|
||||||
let data = Ptr(Int(Byte)).pointer_cast(generator, ctx, data);
|
|
||||||
|
|
||||||
let ndarray = NDArrayObject::alloca(generator, ctx, object.ty, 0);
|
|
||||||
ndarray.instance.set(ctx, |f| f.data, data);
|
|
||||||
ndarray
|
|
||||||
}
|
|
||||||
/// Check if this `NDArray` can be used as an `out` ndarray for an operation.
|
|
||||||
///
|
|
||||||
/// Raise an exception if the shapes do not match.
|
|
||||||
pub fn assert_can_be_written_by_out<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
out_ndims: u64,
|
|
||||||
out_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) {
|
|
||||||
let ndarray_ndims = self.ndims_llvm(generator, ctx.ctx);
|
|
||||||
let ndarray_shape = self.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
|
|
||||||
let output_ndims = Int(SizeT).const_int(generator, ctx.ctx, out_ndims, false);
|
|
||||||
let output_shape = out_shape;
|
|
||||||
|
|
||||||
call_nac3_ndarray_util_assert_output_shape_same(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ndarray_ndims,
|
|
||||||
ndarray_shape,
|
|
||||||
output_ndims,
|
|
||||||
output_shape,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub enum ScalarOrNDArray<'ctx> {
|
|
||||||
Scalar(AnyObject<'ctx>),
|
|
||||||
NDArray(NDArrayObject<'ctx>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> {
|
|
||||||
type Error = ();
|
|
||||||
|
|
||||||
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
|
||||||
match value {
|
|
||||||
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
|
|
||||||
ScalarOrNDArray::NDArray(_ndarray) => Err(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
|
|
||||||
type Error = ();
|
|
||||||
|
|
||||||
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
|
||||||
match value {
|
|
||||||
ScalarOrNDArray::Scalar(_scalar) => Err(()),
|
|
||||||
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
|
||||||
/// Split on `object` either into a scalar or an ndarray.
|
|
||||||
///
|
|
||||||
/// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`].
|
|
||||||
///
|
|
||||||
/// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`].
|
|
||||||
pub fn split_object<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
object: AnyObject<'ctx>,
|
|
||||||
) -> ScalarOrNDArray<'ctx> {
|
|
||||||
match &*ctx.unifier.get_ty(object.ty) {
|
|
||||||
TypeEnum::TObj { obj_id, .. }
|
|
||||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
|
||||||
{
|
|
||||||
let ndarray = NDArrayObject::from_object(generator, ctx, object);
|
|
||||||
ScalarOrNDArray::NDArray(ndarray)
|
|
||||||
}
|
|
||||||
_ => ScalarOrNDArray::Scalar(object),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
|
||||||
#[must_use]
|
|
||||||
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
|
||||||
match self {
|
|
||||||
ScalarOrNDArray::Scalar(scalar) => scalar.value,
|
|
||||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
|
|
||||||
/// - If this is an ndarray, the ndarray is returned.
|
|
||||||
/// - If this is a scalar, this function returns new ndarray created with [`NDArrayObject::make_unsized`].
|
|
||||||
pub fn to_ndarray<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> NDArrayObject<'ctx> {
|
|
||||||
match self {
|
|
||||||
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
|
|
||||||
ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`].
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_dtype(&self) -> Type {
|
|
||||||
match self {
|
|
||||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
|
|
||||||
ScalarOrNDArray::Scalar(scalar) => scalar.ty,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An helper enum specifying how a function should produce its output.
|
|
||||||
///
|
|
||||||
/// Many functions in NumPy has an optional `out` parameter (e.g., `matmul`). If `out` is specified
|
|
||||||
/// with an ndarray, the result of a function will be written to `out`. If `out` is not specified, a function will
|
|
||||||
/// create a new ndarray and store the result in it.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub enum NDArrayOut<'ctx> {
|
|
||||||
/// Tell a function should create a new ndarray with the expected element type `dtype`.
|
|
||||||
NewNDArray { dtype: Type },
|
|
||||||
/// Tell a function to write the result to `ndarray`.
|
|
||||||
WriteToNDArray { ndarray: NDArrayObject<'ctx> },
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayOut<'ctx> {
|
|
||||||
/// Get the dtype of this output.
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_dtype(&self) -> Type {
|
|
||||||
match self {
|
|
||||||
NDArrayOut::NewNDArray { dtype } => *dtype,
|
|
||||||
NDArrayOut::WriteToNDArray { ndarray } => ndarray.dtype,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A version of [`call_nac3_ndarray_set_strides_by_shape`] in Rust.
|
|
||||||
///
|
|
||||||
/// This function is used generating strides for globally defined contiguous ndarrays.
|
|
||||||
#[must_use]
|
|
||||||
pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec<u64> {
|
|
||||||
let mut strides = Vec::with_capacity(ndims as usize);
|
|
||||||
let mut stride_product = 1u64;
|
|
||||||
for i in 0..ndims {
|
|
||||||
let axis = ndims - i - 1;
|
|
||||||
strides[axis as usize] = stride_product * itemsize;
|
|
||||||
stride_product *= shape[axis as usize];
|
|
||||||
}
|
|
||||||
strides
|
|
||||||
}
|
|
|
@ -1,179 +0,0 @@
|
||||||
use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
irrt::{call_nac3_nditer_has_element, call_nac3_nditer_initialize, call_nac3_nditer_next},
|
|
||||||
model::*,
|
|
||||||
object::any::AnyObject,
|
|
||||||
stmt::{gen_for_callback, BreakContinueHooks},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::NDArrayObject;
|
|
||||||
|
|
||||||
/// Fields of [`NDIter`]
|
|
||||||
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
|
|
||||||
pub ndims: F::Output<Int<SizeT>>,
|
|
||||||
pub shape: F::Output<Ptr<Int<SizeT>>>,
|
|
||||||
pub strides: F::Output<Ptr<Int<SizeT>>>,
|
|
||||||
|
|
||||||
pub indices: F::Output<Ptr<Int<SizeT>>>,
|
|
||||||
pub nth: F::Output<Int<SizeT>>,
|
|
||||||
pub element: F::Output<Ptr<Int<Byte>>>,
|
|
||||||
|
|
||||||
pub size: F::Output<Int<SizeT>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An IRRT helper structure used to iterate through an ndarray.
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct NDIter;
|
|
||||||
|
|
||||||
impl<'ctx> StructKind<'ctx> for NDIter {
|
|
||||||
type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>;
|
|
||||||
|
|
||||||
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
|
||||||
Self::Fields {
|
|
||||||
ndims: traversal.add_auto("ndims"),
|
|
||||||
shape: traversal.add_auto("shape"),
|
|
||||||
strides: traversal.add_auto("strides"),
|
|
||||||
|
|
||||||
indices: traversal.add_auto("indices"),
|
|
||||||
nth: traversal.add_auto("nth"),
|
|
||||||
element: traversal.add_auto("element"),
|
|
||||||
|
|
||||||
size: traversal.add_auto("size"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A helper structure with a convenient interface to interact with [`NDIter`].
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct NDIterHandle<'ctx> {
|
|
||||||
instance: Instance<'ctx, Ptr<Struct<NDIter>>>,
|
|
||||||
/// The ndarray this [`NDIter`] to iterating over.
|
|
||||||
ndarray: NDArrayObject<'ctx>,
|
|
||||||
/// The current indices of [`NDIter`].
|
|
||||||
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDIterHandle<'ctx> {
|
|
||||||
/// Allocate an [`NDIter`] that iterates through an ndarray.
|
|
||||||
pub fn new<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayObject<'ctx>,
|
|
||||||
) -> Self {
|
|
||||||
let nditer = Struct(NDIter).alloca(generator, ctx);
|
|
||||||
let ndims = ndarray.ndims_llvm(generator, ctx.ctx);
|
|
||||||
|
|
||||||
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
|
||||||
let indices = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
|
||||||
call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices);
|
|
||||||
|
|
||||||
NDIterHandle { ndarray, instance: nditer, indices }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Is the current iteration valid?
|
|
||||||
///
|
|
||||||
/// If true, then `element`, `indices` and `nth` contain details about the current element.
|
|
||||||
///
|
|
||||||
/// If `ndarray` is unsized, this returns true only for the first iteration.
|
|
||||||
/// If `ndarray` is 0-sized, this always returns false.
|
|
||||||
#[must_use]
|
|
||||||
pub fn has_element<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Int<Bool>> {
|
|
||||||
call_nac3_nditer_has_element(generator, ctx, self.instance)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Go to the next element. If `has_element()` is false, then this has undefined behavior.
|
|
||||||
///
|
|
||||||
/// If `ndarray` is unsized, this can only be called once.
|
|
||||||
/// If `ndarray` is 0-sized, this can never be called.
|
|
||||||
pub fn next<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) {
|
|
||||||
call_nac3_nditer_next(generator, ctx, self.instance);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get pointer to the current element.
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_pointer<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype);
|
|
||||||
|
|
||||||
let p = self.instance.get(generator, ctx, |f| f.element);
|
|
||||||
ctx.builder
|
|
||||||
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element")
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the value of the current element.
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_scalar<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> AnyObject<'ctx> {
|
|
||||||
let p = self.get_pointer(generator, ctx);
|
|
||||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
|
||||||
AnyObject { ty: self.ndarray.dtype, value }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the index of the current element if this ndarray were a flat ndarray.
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_index<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
self.instance.get(generator, ctx, |f| f.nth)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the indices of the current element.
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_indices(&self) -> Instance<'ctx, Ptr<Int<SizeT>>> {
|
|
||||||
self.indices
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Iterate through every element in the ndarray.
|
|
||||||
///
|
|
||||||
/// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterHandle`] to
|
|
||||||
/// get properties of the current iteration (e.g., the current element, indices, etc.)
|
|
||||||
pub fn foreach<'a, G, F>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
body: F,
|
|
||||||
) -> Result<(), String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
F: FnOnce(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
BreakContinueHooks<'ctx>,
|
|
||||||
NDIterHandle<'ctx>,
|
|
||||||
) -> Result<(), String>,
|
|
||||||
{
|
|
||||||
gen_for_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
Some("ndarray_foreach"),
|
|
||||||
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|
|
||||||
|generator, ctx, nditer| Ok(nditer.has_element(generator, ctx).value),
|
|
||||||
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|
|
||||||
|generator, ctx, nditer| {
|
|
||||||
nditer.next(generator, ctx);
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,105 +0,0 @@
|
||||||
use util::gen_for_model;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
codegen::{
|
|
||||||
model::*,
|
|
||||||
object::{any::AnyObject, list::ListObject, tuple::TupleObject},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
},
|
|
||||||
typecheck::typedef::TypeEnum,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
|
|
||||||
///
|
|
||||||
/// * `sequence` - The `sequence` parameter.
|
|
||||||
/// * `sequence_ty` - The typechecker type of `sequence`
|
|
||||||
///
|
|
||||||
/// The `sequence` argument type may only be one of the following:
|
|
||||||
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
|
||||||
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
|
||||||
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
|
||||||
///
|
|
||||||
/// All `int32` values will be sign-extended to `SizeT`.
|
|
||||||
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
input_sequence: AnyObject<'ctx>,
|
|
||||||
) -> (Instance<'ctx, Int<SizeT>>, Instance<'ctx, Ptr<Int<SizeT>>>) {
|
|
||||||
let zero = Int(SizeT).const_0(generator, ctx.ctx);
|
|
||||||
let one = Int(SizeT).const_1(generator, ctx.ctx);
|
|
||||||
|
|
||||||
// The result `list` to return.
|
|
||||||
match &*ctx.unifier.get_ty(input_sequence.ty) {
|
|
||||||
TypeEnum::TObj { obj_id, .. }
|
|
||||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
|
||||||
{
|
|
||||||
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
|
||||||
|
|
||||||
// Check `input_sequence`
|
|
||||||
let input_sequence = ListObject::from_object(generator, ctx, input_sequence);
|
|
||||||
|
|
||||||
let len = input_sequence.instance.get(generator, ctx, |f| f.len);
|
|
||||||
let result = Int(SizeT).array_alloca(generator, ctx, len.value);
|
|
||||||
|
|
||||||
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
|
|
||||||
gen_for_model(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| {
|
|
||||||
// Load the i-th int32 in the input sequence
|
|
||||||
let int = input_sequence
|
|
||||||
.instance
|
|
||||||
.get(generator, ctx, |f| f.items)
|
|
||||||
.get_index(generator, ctx, i.value)
|
|
||||||
.value
|
|
||||||
.into_int_value();
|
|
||||||
|
|
||||||
// Cast to SizeT
|
|
||||||
let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int);
|
|
||||||
|
|
||||||
// Store
|
|
||||||
result.set_index(ctx, i.value, int);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
(len, result)
|
|
||||||
}
|
|
||||||
TypeEnum::TTuple { .. } => {
|
|
||||||
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
|
||||||
|
|
||||||
let input_sequence = TupleObject::from_object(ctx, input_sequence);
|
|
||||||
|
|
||||||
let len = input_sequence.len(generator, ctx);
|
|
||||||
|
|
||||||
let result = Int(SizeT).array_alloca(generator, ctx, len.value);
|
|
||||||
|
|
||||||
for i in 0..input_sequence.num_elements() {
|
|
||||||
// Get the i-th element off of the tuple and load it into `result`.
|
|
||||||
let int = input_sequence.index(ctx, i).value.into_int_value();
|
|
||||||
let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int);
|
|
||||||
|
|
||||||
result.set_index_const(ctx, i64::try_from(i).unwrap(), int);
|
|
||||||
}
|
|
||||||
|
|
||||||
(len, result)
|
|
||||||
}
|
|
||||||
TypeEnum::TObj { obj_id, .. }
|
|
||||||
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
|
||||||
{
|
|
||||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
|
||||||
let input_int = input_sequence.value.into_int_value();
|
|
||||||
|
|
||||||
let len = Int(SizeT).const_1(generator, ctx.ctx);
|
|
||||||
let result = Int(SizeT).array_alloca(generator, ctx, len.value);
|
|
||||||
let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, input_int);
|
|
||||||
|
|
||||||
// Storing into result[0]
|
|
||||||
result.store(ctx, int);
|
|
||||||
|
|
||||||
(len, result)
|
|
||||||
}
|
|
||||||
_ => panic!(
|
|
||||||
"encountered unknown sequence type: {}",
|
|
||||||
ctx.unifier.stringify(input_sequence.ty)
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,119 +0,0 @@
|
||||||
use crate::codegen::{
|
|
||||||
irrt::{call_nac3_ndarray_reshape_resolve_and_check_new_shape, call_nac3_ndarray_transpose},
|
|
||||||
model::*,
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::{indexing::RustNDIndex, NDArrayObject};
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
|
||||||
/// Make sure the ndarray is at least `ndmin`-dimensional.
|
|
||||||
///
|
|
||||||
/// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape.
|
|
||||||
/// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray.
|
|
||||||
#[must_use]
|
|
||||||
pub fn atleast_nd<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndmin: u64,
|
|
||||||
) -> Self {
|
|
||||||
if self.ndims < ndmin {
|
|
||||||
// Extend the dimensions with np.newaxis.
|
|
||||||
let mut indices = vec![];
|
|
||||||
for _ in self.ndims..ndmin {
|
|
||||||
indices.push(RustNDIndex::NewAxis);
|
|
||||||
}
|
|
||||||
indices.push(RustNDIndex::Ellipsis);
|
|
||||||
self.index(generator, ctx, &indices)
|
|
||||||
} else {
|
|
||||||
*self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a reshaped view on this ndarray like `np.reshape()`.
|
|
||||||
///
|
|
||||||
/// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be modified as a result.
|
|
||||||
///
|
|
||||||
/// If reshape without copying is impossible, this function will allocate a new ndarray and copy contents.
|
|
||||||
///
|
|
||||||
/// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`].
|
|
||||||
/// * `new_shape` - The target shape to do `np.reshape()`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn reshape_or_copy<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
new_ndims: u64,
|
|
||||||
new_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
|
||||||
) -> Self {
|
|
||||||
// TODO: The current criterion for whether to do a full copy or not is by checking `is_c_contiguous`,
|
|
||||||
// but this is not optimal - there are cases when the ndarray is not contiguous but could be reshaped
|
|
||||||
// without copying data. Look into how numpy does it.
|
|
||||||
|
|
||||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
|
||||||
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
|
|
||||||
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
|
|
||||||
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
|
||||||
|
|
||||||
let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims);
|
|
||||||
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
|
|
||||||
|
|
||||||
// Reolsve negative indices
|
|
||||||
let size = self.size(generator, ctx);
|
|
||||||
let dst_ndims = dst_ndarray.ndims_llvm(generator, ctx.ctx);
|
|
||||||
let dst_shape = dst_ndarray.instance.get(generator, ctx, |f| f.shape);
|
|
||||||
call_nac3_ndarray_reshape_resolve_and_check_new_shape(
|
|
||||||
generator, ctx, size, dst_ndims, dst_shape,
|
|
||||||
);
|
|
||||||
|
|
||||||
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
|
|
||||||
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
|
|
||||||
|
|
||||||
// Inserting into then_bb: reshape is possible without copying
|
|
||||||
ctx.builder.position_at_end(then_bb);
|
|
||||||
dst_ndarray.set_strides_contiguous(generator, ctx);
|
|
||||||
dst_ndarray.instance.set(ctx, |f| f.data, self.instance.get(generator, ctx, |f| f.data));
|
|
||||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
|
||||||
|
|
||||||
// Inserting into else_bb: reshape is impossible without copying
|
|
||||||
ctx.builder.position_at_end(else_bb);
|
|
||||||
dst_ndarray.create_data(generator, ctx);
|
|
||||||
dst_ndarray.copy_data_from(generator, ctx, *self);
|
|
||||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
|
||||||
|
|
||||||
// Reposition for continuation
|
|
||||||
ctx.builder.position_at_end(end_bb);
|
|
||||||
|
|
||||||
dst_ndarray
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a transposed view on this ndarray like `np.transpose(<ndarray>, <axes> = None)`.
|
|
||||||
/// * `axes` - If specified, should be an array of the permutation (negative indices are **allowed**).
|
|
||||||
#[must_use]
|
|
||||||
pub fn transpose<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
axes: Option<Instance<'ctx, Ptr<Int<SizeT>>>>,
|
|
||||||
) -> Self {
|
|
||||||
// Define models
|
|
||||||
let transposed_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims);
|
|
||||||
|
|
||||||
let num_axes = self.ndims_llvm(generator, ctx.ctx);
|
|
||||||
|
|
||||||
// `axes = nullptr` if `axes` is unspecified.
|
|
||||||
let axes = axes.unwrap_or_else(|| Ptr(Int(SizeT)).nullptr(generator, ctx.ctx));
|
|
||||||
|
|
||||||
call_nac3_ndarray_transpose(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
self.instance,
|
|
||||||
transposed_ndarray.instance,
|
|
||||||
num_axes,
|
|
||||||
axes,
|
|
||||||
);
|
|
||||||
|
|
||||||
transposed_ndarray
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,99 +0,0 @@
|
||||||
use inkwell::values::StructValue;
|
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
codegen::{model::*, CodeGenContext, CodeGenerator},
|
|
||||||
typecheck::typedef::{Type, TypeEnum},
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::any::AnyObject;
|
|
||||||
|
|
||||||
/// A NAC3 tuple object.
|
|
||||||
///
|
|
||||||
/// NOTE: This struct has no copy trait.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct TupleObject<'ctx> {
|
|
||||||
/// The type of the tuple.
|
|
||||||
pub tys: Vec<Type>,
|
|
||||||
/// The underlying LLVM struct value of this tuple.
|
|
||||||
pub value: StructValue<'ctx>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> TupleObject<'ctx> {
|
|
||||||
pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self {
|
|
||||||
// TODO: Keep `is_vararg_ctx` from TTuple?
|
|
||||||
|
|
||||||
// Sanity check on object type.
|
|
||||||
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else {
|
|
||||||
panic!(
|
|
||||||
"Expected type to be a TypeEnum::TTuple, got {}",
|
|
||||||
ctx.unifier.stringify(object.ty)
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check number of fields
|
|
||||||
let value = object.value.into_struct_value();
|
|
||||||
let value_num_fields = value.get_type().count_fields() as usize;
|
|
||||||
assert!(
|
|
||||||
value_num_fields == tys.len(),
|
|
||||||
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
|
|
||||||
tys.len(),
|
|
||||||
value_num_fields
|
|
||||||
);
|
|
||||||
|
|
||||||
TupleObject { tys: tys.clone(), value }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience function. Create a [`TupleObject`] from an iterator of objects.
|
|
||||||
pub fn from_objects<I, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
objects: I,
|
|
||||||
) -> Self
|
|
||||||
where
|
|
||||||
I: IntoIterator<Item = AnyObject<'ctx>>,
|
|
||||||
{
|
|
||||||
let (values, tys): (Vec<_>, Vec<_>) =
|
|
||||||
objects.into_iter().map(|object| (object.value, object.ty)).unzip();
|
|
||||||
|
|
||||||
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
|
|
||||||
let llvm_tuple_ty = ctx.ctx.struct_type(&llvm_tys, false);
|
|
||||||
|
|
||||||
let pllvm_tuple = ctx.builder.build_alloca(llvm_tuple_ty, "tuple").unwrap();
|
|
||||||
for (i, val) in values.into_iter().enumerate() {
|
|
||||||
let pval = ctx.builder.build_struct_gep(pllvm_tuple, i as u32, "value").unwrap();
|
|
||||||
ctx.builder.build_store(pval, val).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
let value = ctx.builder.build_load(pllvm_tuple, "").unwrap().into_struct_value();
|
|
||||||
TupleObject { tys, value }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn num_elements(&self) -> usize {
|
|
||||||
self.tys.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the `len()` of this tuple.
|
|
||||||
#[must_use]
|
|
||||||
pub fn len<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Instance<'ctx, Int<SizeT>> {
|
|
||||||
Int(SizeT).const_int(generator, ctx.ctx, self.num_elements() as u64, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the `i`-th (0-based) object in this tuple.
|
|
||||||
pub fn index(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize) -> AnyObject<'ctx> {
|
|
||||||
assert!(
|
|
||||||
i < self.num_elements(),
|
|
||||||
"Tuple object with length {} have index {i}",
|
|
||||||
self.num_elements()
|
|
||||||
);
|
|
||||||
|
|
||||||
let value = ctx.builder.build_extract_value(self.value, i as u32, "tuple[{i}]").unwrap();
|
|
||||||
let ty = self.tys[i];
|
|
||||||
AnyObject { ty, value }
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1 +0,0 @@
|
||||||
pub mod slice;
|
|
|
@ -1,125 +0,0 @@
|
||||||
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
|
|
||||||
|
|
||||||
/// Fields of [`Slice`]
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> {
|
|
||||||
pub start_defined: F::Output<Int<Bool>>,
|
|
||||||
pub start: F::Output<Int<N>>,
|
|
||||||
pub stop_defined: F::Output<Int<Bool>>,
|
|
||||||
pub stop: F::Output<Int<N>>,
|
|
||||||
pub step_defined: F::Output<Int<Bool>>,
|
|
||||||
pub step: F::Output<Int<N>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An IRRT representation of an (unresolved) slice.
|
|
||||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
|
||||||
pub struct Slice<N>(pub N);
|
|
||||||
|
|
||||||
impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice<N> {
|
|
||||||
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F, N>;
|
|
||||||
|
|
||||||
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
|
||||||
Self::Fields {
|
|
||||||
start_defined: traversal.add_auto("start_defined"),
|
|
||||||
start: traversal.add("start", Int(self.0)),
|
|
||||||
stop_defined: traversal.add_auto("stop_defined"),
|
|
||||||
stop: traversal.add("stop", Int(self.0)),
|
|
||||||
step_defined: traversal.add_auto("step_defined"),
|
|
||||||
step: traversal.add("step", Int(self.0)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A Rust structure that has [`Slice`] utilities and looks like a [`Slice`] but
|
|
||||||
/// `start`, `stop` and `step` are held by LLVM registers only and possibly
|
|
||||||
/// [`Option::None`] if unspecified.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct RustSlice<'ctx, N: IntKind<'ctx>> {
|
|
||||||
// It is possible that `start`, `stop`, and `step` are all `None`.
|
|
||||||
// We need to know the `int_kind` even when that is the case.
|
|
||||||
pub int_kind: N,
|
|
||||||
pub start: Option<Instance<'ctx, Int<N>>>,
|
|
||||||
pub stop: Option<Instance<'ctx, Int<N>>>,
|
|
||||||
pub step: Option<Instance<'ctx, Int<N>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, N: IntKind<'ctx>> RustSlice<'ctx, N> {
|
|
||||||
/// Write the contents to an LLVM [`Slice`].
|
|
||||||
pub fn write_to_slice<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice<N>>>>,
|
|
||||||
) {
|
|
||||||
let false_ = Int(Bool).const_false(generator, ctx.ctx);
|
|
||||||
let true_ = Int(Bool).const_true(generator, ctx.ctx);
|
|
||||||
|
|
||||||
match self.start {
|
|
||||||
Some(start) => {
|
|
||||||
dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_);
|
|
||||||
dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start);
|
|
||||||
}
|
|
||||||
None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_),
|
|
||||||
}
|
|
||||||
|
|
||||||
match self.stop {
|
|
||||||
Some(stop) => {
|
|
||||||
dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_);
|
|
||||||
dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop);
|
|
||||||
}
|
|
||||||
None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_),
|
|
||||||
}
|
|
||||||
|
|
||||||
match self.step {
|
|
||||||
Some(step) => {
|
|
||||||
dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_);
|
|
||||||
dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step);
|
|
||||||
}
|
|
||||||
None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod util {
|
|
||||||
use nac3parser::ast::Expr;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
codegen::{model::*, CodeGenContext, CodeGenerator},
|
|
||||||
typecheck::typedef::Type,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::RustSlice;
|
|
||||||
|
|
||||||
/// Generate LLVM IR for an [`ExprKind::Slice`] and convert it into a [`RustSlice`].
|
|
||||||
#[allow(clippy::type_complexity)]
|
|
||||||
pub fn gen_slice<'ctx, G: CodeGenerator>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
lower: &Option<Box<Expr<Option<Type>>>>,
|
|
||||||
upper: &Option<Box<Expr<Option<Type>>>>,
|
|
||||||
step: &Option<Box<Expr<Option<Type>>>>,
|
|
||||||
) -> Result<RustSlice<'ctx, Int32>, String> {
|
|
||||||
let mut help = |value_expr: &Option<Box<Expr<Option<Type>>>>| -> Result<_, String> {
|
|
||||||
Ok(match value_expr {
|
|
||||||
None => None,
|
|
||||||
Some(value_expr) => {
|
|
||||||
let value_expr = generator
|
|
||||||
.gen_expr(ctx, value_expr)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
|
|
||||||
|
|
||||||
let value_expr =
|
|
||||||
Int(Int32).check_value(generator, ctx.ctx, value_expr).unwrap();
|
|
||||||
|
|
||||||
Some(value_expr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
};
|
|
||||||
|
|
||||||
let start = help(lower)?;
|
|
||||||
let stop = help(upper)?;
|
|
||||||
let step = help(step)?;
|
|
||||||
|
|
||||||
Ok(RustSlice { int_kind: Int32, start, stop, step })
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,15 +1,22 @@
|
||||||
|
use inkwell::{
|
||||||
|
attributes::{Attribute, AttributeLoc},
|
||||||
|
basic_block::BasicBlock,
|
||||||
|
types::{BasicType, BasicTypeEnum},
|
||||||
|
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
||||||
|
IntPredicate,
|
||||||
|
};
|
||||||
|
use itertools::{izip, Itertools};
|
||||||
|
|
||||||
|
use nac3parser::ast::{
|
||||||
|
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
|
||||||
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
|
|
||||||
expr::{destructure_range, gen_binop_expr},
|
expr::{destructure_range, gen_binop_expr},
|
||||||
gen_in_range_check,
|
gen_in_range_check,
|
||||||
irrt::{handle_slice_indices, list_slice_assignment},
|
irrt::{handle_slice_indices, list_slice_assignment},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
object::{
|
values::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
|
||||||
any::AnyObject,
|
|
||||||
ndarray::{
|
|
||||||
indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject, ScalarOrNDArray,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -20,17 +27,6 @@ use crate::{
|
||||||
typedef::{iter_type_vars, FunSignature, Type, TypeEnum},
|
typedef::{iter_type_vars, FunSignature, Type, TypeEnum},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::{
|
|
||||||
attributes::{Attribute, AttributeLoc},
|
|
||||||
basic_block::BasicBlock,
|
|
||||||
types::{BasicType, BasicTypeEnum},
|
|
||||||
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
|
||||||
IntPredicate,
|
|
||||||
};
|
|
||||||
use itertools::{izip, Itertools};
|
|
||||||
use nac3parser::ast::{
|
|
||||||
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_var_alloc`].
|
/// See [`CodeGenerator::gen_var_alloc`].
|
||||||
pub fn gen_var<'ctx>(
|
pub fn gen_var<'ctx>(
|
||||||
|
@ -314,7 +310,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, target_ty)?
|
.to_basic_value_enum(ctx, generator, target_ty)?
|
||||||
.into_pointer_value();
|
.into_pointer_value();
|
||||||
let target = ListValue::from_ptr_val(target, llvm_usize, None);
|
let target = ListValue::from_pointer_value(target, llvm_usize, None);
|
||||||
|
|
||||||
if let ExprKind::Slice { .. } = &key.node {
|
if let ExprKind::Slice { .. } = &key.node {
|
||||||
// Handle assigning to a slice
|
// Handle assigning to a slice
|
||||||
|
@ -335,7 +331,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
let value =
|
let value =
|
||||||
value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value();
|
value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value();
|
||||||
let value = ListValue::from_ptr_val(value, llvm_usize, None);
|
let value = ListValue::from_pointer_value(value, llvm_usize, None);
|
||||||
|
|
||||||
let target_item_ty = ctx.get_llvm_type(generator, target_item_ty);
|
let target_item_ty = ctx.get_llvm_type(generator, target_item_ty);
|
||||||
let Some(src_ind) = handle_slice_indices(
|
let Some(src_ind) = handle_slice_indices(
|
||||||
|
@ -415,47 +411,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
// Handle NDArray item assignment
|
// Handle NDArray item assignment
|
||||||
// Process target
|
todo!("ndarray subscript assignment is not yet implemented");
|
||||||
let target = generator
|
|
||||||
.gen_expr(ctx, target)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, target_ty)?;
|
|
||||||
let target = AnyObject { value: target, ty: target_ty };
|
|
||||||
|
|
||||||
// Process key
|
|
||||||
let key = gen_ndarray_subscript_ndindices(generator, ctx, key)?;
|
|
||||||
|
|
||||||
// Process value
|
|
||||||
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
|
|
||||||
let value = AnyObject { value, ty: value_ty };
|
|
||||||
|
|
||||||
/*
|
|
||||||
Reference code:
|
|
||||||
```python
|
|
||||||
target = target[key]
|
|
||||||
value = np.asarray(value)
|
|
||||||
|
|
||||||
shape = np.broadcast_shape((target, value))
|
|
||||||
|
|
||||||
target = np.broadcast_to(target, shape)
|
|
||||||
value = np.broadcast_to(value, shape)
|
|
||||||
|
|
||||||
...and finally copy 1-1 from value to target.
|
|
||||||
```
|
|
||||||
*/
|
|
||||||
|
|
||||||
let target = NDArrayObject::from_object(generator, ctx, target);
|
|
||||||
let target = target.index(generator, ctx, &key);
|
|
||||||
|
|
||||||
let value =
|
|
||||||
ScalarOrNDArray::split_object(generator, ctx, value).to_ndarray(generator, ctx);
|
|
||||||
|
|
||||||
let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]);
|
|
||||||
|
|
||||||
let target = broadcast_result.ndarrays[0];
|
|
||||||
let value = broadcast_result.ndarrays[1];
|
|
||||||
|
|
||||||
target.copy_data_from(generator, ctx, value);
|
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
|
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
|
||||||
|
@ -507,7 +463,8 @@ pub fn gen_for<G: CodeGenerator>(
|
||||||
TypeEnum::TObj { obj_id, .. }
|
TypeEnum::TObj { obj_id, .. }
|
||||||
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
let iter_val =
|
||||||
|
RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range"));
|
||||||
// Internal variable for loop; Cannot be assigned
|
// Internal variable for loop; Cannot be assigned
|
||||||
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
|
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
|
||||||
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
|
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
|
||||||
|
@ -1872,6 +1829,37 @@ pub fn gen_stmt<G: CodeGenerator>(
|
||||||
stmt.location,
|
stmt.location,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
StmtKind::Global { names, .. } => {
|
||||||
|
let registered_globals = ctx
|
||||||
|
.top_level
|
||||||
|
.definitions
|
||||||
|
.read()
|
||||||
|
.iter()
|
||||||
|
.filter_map(|def| {
|
||||||
|
if let TopLevelDef::Variable { simple_name, ty, .. } = &*def.read() {
|
||||||
|
Some((*simple_name, *ty))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
for id in names {
|
||||||
|
let Some((_, ty)) = registered_globals.iter().find(|(name, _)| name == id) else {
|
||||||
|
return Err(format!("{id} is not a global at {}", stmt.location));
|
||||||
|
};
|
||||||
|
|
||||||
|
let resolver = ctx.resolver.clone();
|
||||||
|
let ptr = resolver
|
||||||
|
.get_symbol_value(*id, ctx, generator)
|
||||||
|
.map(|val| val.to_basic_value_enum(ctx, generator, *ty))
|
||||||
|
.transpose()?
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ctx.var_assignment.insert(*id, (ptr, None, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => unimplemented!(),
|
_ => unimplemented!(),
|
||||||
};
|
};
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -1,34 +1,37 @@
|
||||||
use crate::{
|
use std::{
|
||||||
codegen::{
|
collections::{HashMap, HashSet},
|
||||||
classes::{ListType, ProxyType, RangeType},
|
sync::Arc,
|
||||||
concrete_type::ConcreteTypeStore,
|
|
||||||
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask,
|
|
||||||
CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry,
|
|
||||||
},
|
|
||||||
symbol_resolver::{SymbolResolver, ValueEnum},
|
|
||||||
toplevel::{
|
|
||||||
composer::{ComposerConfig, TopLevelComposer},
|
|
||||||
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
|
|
||||||
},
|
|
||||||
typecheck::{
|
|
||||||
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
|
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
targets::{InitializationConfig, Target},
|
targets::{InitializationConfig, Target},
|
||||||
OptimizationLevel,
|
OptimizationLevel,
|
||||||
};
|
};
|
||||||
use nac3parser::ast::FileName;
|
|
||||||
use nac3parser::{
|
use nac3parser::{
|
||||||
ast::{fold::Fold, StrRef},
|
ast::{fold::Fold, FileName, StrRef},
|
||||||
parser::parse_program,
|
parser::parse_program,
|
||||||
};
|
};
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
use std::collections::{HashMap, HashSet};
|
|
||||||
use std::sync::Arc;
|
use super::{
|
||||||
|
concrete_type::ConcreteTypeStore,
|
||||||
|
types::{ListType, NDArrayType, ProxyType, RangeType},
|
||||||
|
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator,
|
||||||
|
DefaultCodeGenerator, WithCall, WorkerRegistry,
|
||||||
|
};
|
||||||
|
use crate::{
|
||||||
|
symbol_resolver::{SymbolResolver, ValueEnum},
|
||||||
|
toplevel::{
|
||||||
|
composer::{ComposerConfig, TopLevelComposer},
|
||||||
|
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
|
||||||
|
},
|
||||||
|
typecheck::{
|
||||||
|
type_inferencer::{FunctionData, IdentifierInfo, Inferencer, PrimitiveStore},
|
||||||
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
struct Resolver {
|
struct Resolver {
|
||||||
id_to_type: HashMap<StrRef, Type>,
|
id_to_type: HashMap<StrRef, Type>,
|
||||||
|
@ -64,6 +67,7 @@ impl SymbolResolver for Resolver {
|
||||||
&self,
|
&self,
|
||||||
_: StrRef,
|
_: StrRef,
|
||||||
_: &mut CodeGenContext<'ctx, '_>,
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
_: &mut dyn CodeGenerator,
|
||||||
) -> Option<ValueEnum<'ctx>> {
|
) -> Option<ValueEnum<'ctx>> {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
|
@ -138,7 +142,8 @@ fn test_primitives() {
|
||||||
};
|
};
|
||||||
let mut virtual_checks = Vec::new();
|
let mut virtual_checks = Vec::new();
|
||||||
let mut calls = HashMap::new();
|
let mut calls = HashMap::new();
|
||||||
let mut identifiers: HashSet<_> = ["a".into(), "b".into()].into();
|
let mut identifiers: HashMap<_, _> =
|
||||||
|
["a".into(), "b".into()].map(|id| (id, IdentifierInfo::default())).into();
|
||||||
let mut inferencer = Inferencer {
|
let mut inferencer = Inferencer {
|
||||||
top_level: &top_level,
|
top_level: &top_level,
|
||||||
function_data: &mut function_data,
|
function_data: &mut function_data,
|
||||||
|
@ -317,7 +322,8 @@ fn test_simple_call() {
|
||||||
};
|
};
|
||||||
let mut virtual_checks = Vec::new();
|
let mut virtual_checks = Vec::new();
|
||||||
let mut calls = HashMap::new();
|
let mut calls = HashMap::new();
|
||||||
let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].into();
|
let mut identifiers: HashMap<_, _> =
|
||||||
|
["a".into(), "foo".into()].map(|id| (id, IdentifierInfo::default())).into();
|
||||||
let mut inferencer = Inferencer {
|
let mut inferencer = Inferencer {
|
||||||
top_level: &top_level,
|
top_level: &top_level,
|
||||||
function_data: &mut function_data,
|
function_data: &mut function_data,
|
||||||
|
@ -446,7 +452,7 @@ fn test_classes_list_type_new() {
|
||||||
let llvm_usize = generator.get_size_type(&ctx);
|
let llvm_usize = generator.get_size_type(&ctx);
|
||||||
|
|
||||||
let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
|
let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
|
||||||
assert!(ListType::is_type(llvm_list.as_base_type(), llvm_usize).is_ok());
|
assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -454,5 +460,17 @@ fn test_classes_range_type_new() {
|
||||||
let ctx = inkwell::context::Context::create();
|
let ctx = inkwell::context::Context::create();
|
||||||
|
|
||||||
let llvm_range = RangeType::new(&ctx);
|
let llvm_range = RangeType::new(&ctx);
|
||||||
assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok());
|
assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_classes_ndarray_type_new() {
|
||||||
|
let ctx = inkwell::context::Context::create();
|
||||||
|
let generator = DefaultCodeGenerator::new(String::new(), 64);
|
||||||
|
|
||||||
|
let llvm_i32 = ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(&ctx);
|
||||||
|
|
||||||
|
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
|
||||||
|
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,192 @@
|
||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
|
values::IntValue,
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::ProxyType;
|
||||||
|
use crate::codegen::{
|
||||||
|
values::{ArraySliceValue, ListValue, ProxyValue},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Proxy type for a `list` type in LLVM.
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub struct ListType<'ctx> {
|
||||||
|
ty: PointerType<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ListType<'ctx> {
|
||||||
|
/// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not.
|
||||||
|
pub fn is_representable(
|
||||||
|
llvm_ty: PointerType<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let llvm_list_ty = llvm_ty.get_element_type();
|
||||||
|
let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else {
|
||||||
|
return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}"));
|
||||||
|
};
|
||||||
|
if llvm_list_ty.count_fields() != 2 {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected 2 fields in `list`, got {}",
|
||||||
|
llvm_list_ty.count_fields()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap();
|
||||||
|
let Ok(_) = PointerType::try_from(list_size_ty) else {
|
||||||
|
return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}"));
|
||||||
|
};
|
||||||
|
|
||||||
|
let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap();
|
||||||
|
let Ok(list_data_ty) = IntType::try_from(list_data_ty) else {
|
||||||
|
return Err(format!("Expected int type for `list.1`, got {list_data_ty}"));
|
||||||
|
};
|
||||||
|
if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected {}-bit int type for `list.1`, got {}-bit int",
|
||||||
|
llvm_usize.get_bit_width(),
|
||||||
|
list_data_ty.get_bit_width()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an LLVM type corresponding to the expected structure of a `List`.
|
||||||
|
#[must_use]
|
||||||
|
fn llvm_type(
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
element_type: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> PointerType<'ctx> {
|
||||||
|
// struct List { data: T*, size: size_t }
|
||||||
|
let field_tys = [element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()];
|
||||||
|
|
||||||
|
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an instance of [`ListType`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn new<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
element_type: BasicTypeEnum<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx);
|
||||||
|
let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize);
|
||||||
|
|
||||||
|
ListType::from_type(llvm_list, llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an [`ListType`] from a [`PointerType`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||||
|
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
||||||
|
|
||||||
|
ListType { ty: ptr_ty, llvm_usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the type of the `size` field of this `list` type.
|
||||||
|
#[must_use]
|
||||||
|
pub fn size_type(&self) -> IntType<'ctx> {
|
||||||
|
self.as_base_type()
|
||||||
|
.get_element_type()
|
||||||
|
.into_struct_type()
|
||||||
|
.get_field_type_at_index(1)
|
||||||
|
.map(BasicTypeEnum::into_int_type)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the element type of this `list` type.
|
||||||
|
#[must_use]
|
||||||
|
pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
|
||||||
|
self.as_base_type()
|
||||||
|
.get_element_type()
|
||||||
|
.into_struct_type()
|
||||||
|
.get_field_type_at_index(0)
|
||||||
|
.map(BasicTypeEnum::into_pointer_type)
|
||||||
|
.map(PointerType::get_element_type)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
|
||||||
|
type Base = PointerType<'ctx>;
|
||||||
|
type Value = ListValue<'ctx>;
|
||||||
|
|
||||||
|
fn is_type<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: impl BasicType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||||
|
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
|
||||||
|
} else {
|
||||||
|
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_representable<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: Self::Base,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_value<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self::Value {
|
||||||
|
self.map_value(
|
||||||
|
generator
|
||||||
|
.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
self.as_base_type().get_element_type().into_struct_type().into(),
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_array_value<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> ArraySliceValue<'ctx> {
|
||||||
|
generator
|
||||||
|
.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
self.as_base_type().get_element_type().into_struct_type().into(),
|
||||||
|
size,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_value(
|
||||||
|
&self,
|
||||||
|
value: <Self::Value as ProxyValue<'ctx>>::Base,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self::Value {
|
||||||
|
Self::Value::from_pointer_value(value, self.llvm_usize, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_base_type(&self) -> Self::Base {
|
||||||
|
self.ty
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<ListType<'ctx>> for PointerType<'ctx> {
|
||||||
|
fn from(value: ListType<'ctx>) -> Self {
|
||||||
|
value.as_base_type()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,64 @@
|
||||||
|
use inkwell::{context::Context, types::BasicType, values::IntValue};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
values::{ArraySliceValue, ProxyValue},
|
||||||
|
{CodeGenContext, CodeGenerator},
|
||||||
|
};
|
||||||
|
pub use list::*;
|
||||||
|
pub use ndarray::*;
|
||||||
|
pub use range::*;
|
||||||
|
|
||||||
|
mod list;
|
||||||
|
mod ndarray;
|
||||||
|
mod range;
|
||||||
|
pub mod structure;
|
||||||
|
|
||||||
|
/// A LLVM type that is used to represent a corresponding type in NAC3.
|
||||||
|
pub trait ProxyType<'ctx>: Into<Self::Base> {
|
||||||
|
/// The LLVM type of which values of this type possess. This is usually a
|
||||||
|
/// [LLVM pointer type][PointerType] for any non-primitive types.
|
||||||
|
type Base: BasicType<'ctx>;
|
||||||
|
|
||||||
|
/// The type of values represented by this type.
|
||||||
|
type Value: ProxyValue<'ctx, Type = Self>;
|
||||||
|
|
||||||
|
fn is_type<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: impl BasicType<'ctx>,
|
||||||
|
) -> Result<(), String>;
|
||||||
|
|
||||||
|
/// Checks whether `llvm_ty` can be represented by this [`ProxyType`].
|
||||||
|
fn is_representable<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: Self::Base,
|
||||||
|
) -> Result<(), String>;
|
||||||
|
|
||||||
|
/// Creates a new value of this type.
|
||||||
|
fn new_value<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self::Value;
|
||||||
|
|
||||||
|
/// Creates a new array value of this type.
|
||||||
|
fn new_array_value<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> ArraySliceValue<'ctx>;
|
||||||
|
|
||||||
|
/// Converts an existing value into a [`ProxyValue`] of this type.
|
||||||
|
fn map_value(
|
||||||
|
&self,
|
||||||
|
value: <Self::Value as ProxyValue<'ctx>>::Base,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self::Value;
|
||||||
|
|
||||||
|
/// Returns the [base type][Self::Base] of this proxy.
|
||||||
|
fn as_base_type(&self) -> Self::Base;
|
||||||
|
}
|
|
@ -0,0 +1,258 @@
|
||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
|
values::{IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
use nac3core_derive::StructFields;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
structure::{StructField, StructFields},
|
||||||
|
ProxyType,
|
||||||
|
};
|
||||||
|
use crate::codegen::{
|
||||||
|
values::{ArraySliceValue, NDArrayValue, ProxyValue},
|
||||||
|
{CodeGenContext, CodeGenerator},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Proxy type for a `ndarray` type in LLVM.
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub struct NDArrayType<'ctx> {
|
||||||
|
ty: PointerType<'ctx>,
|
||||||
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
|
pub struct NDArrayStructFields<'ctx> {
|
||||||
|
#[value_type(usize)]
|
||||||
|
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||||
|
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||||
|
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayType<'ctx> {
|
||||||
|
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
|
||||||
|
pub fn is_representable(
|
||||||
|
llvm_ty: PointerType<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
||||||
|
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||||
|
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
|
||||||
|
};
|
||||||
|
if llvm_ndarray_ty.count_fields() != 3 {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected 3 fields in `NDArray`, got {}",
|
||||||
|
llvm_ndarray_ty.count_fields()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
|
||||||
|
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
|
||||||
|
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"));
|
||||||
|
};
|
||||||
|
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected {}-bit int type for `ndarray.0`, got {}-bit int",
|
||||||
|
llvm_usize.get_bit_width(),
|
||||||
|
ndarray_ndims_ty.get_bit_width()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
|
||||||
|
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
|
||||||
|
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
|
||||||
|
};
|
||||||
|
let ndarray_dims = ndarray_pdims.get_element_type();
|
||||||
|
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
|
||||||
|
));
|
||||||
|
};
|
||||||
|
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
|
||||||
|
llvm_usize.get_bit_width(),
|
||||||
|
ndarray_dims.get_bit_width()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
|
||||||
|
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
|
||||||
|
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
|
||||||
|
};
|
||||||
|
let ndarray_data = ndarray_pdata.get_element_type();
|
||||||
|
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}"
|
||||||
|
));
|
||||||
|
};
|
||||||
|
if ndarray_data.get_bit_width() != 8 {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
|
||||||
|
ndarray_data.get_bit_width()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Move this into e.g. StructProxyType
|
||||||
|
#[must_use]
|
||||||
|
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
|
||||||
|
NDArrayStructFields::new(ctx, llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Move this into e.g. StructProxyType
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_fields(
|
||||||
|
&self,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> NDArrayStructFields<'ctx> {
|
||||||
|
Self::fields(ctx, llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||||
|
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
|
||||||
|
//
|
||||||
|
// * data : Pointer to an array containing the array data
|
||||||
|
// * itemsize: The size of each NDArray elements in bytes
|
||||||
|
// * ndims : Number of dimensions in the array
|
||||||
|
// * shape : Pointer to an array containing the shape of the NDArray
|
||||||
|
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
|
||||||
|
let field_tys =
|
||||||
|
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||||
|
|
||||||
|
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an instance of [`NDArrayType`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn new<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx);
|
||||||
|
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
|
||||||
|
|
||||||
|
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_type(
|
||||||
|
ptr_ty: PointerType<'ctx>,
|
||||||
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
||||||
|
|
||||||
|
NDArrayType { ty: ptr_ty, dtype, llvm_usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the type of the `size` field of this `ndarray` type.
|
||||||
|
#[must_use]
|
||||||
|
pub fn size_type(&self) -> IntType<'ctx> {
|
||||||
|
self.as_base_type()
|
||||||
|
.get_element_type()
|
||||||
|
.into_struct_type()
|
||||||
|
.get_field_type_at_index(0)
|
||||||
|
.map(BasicTypeEnum::into_int_type)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the element type of this `ndarray` type.
|
||||||
|
#[must_use]
|
||||||
|
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
|
||||||
|
self.dtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
||||||
|
type Base = PointerType<'ctx>;
|
||||||
|
type Value = NDArrayValue<'ctx>;
|
||||||
|
|
||||||
|
fn is_type<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: impl BasicType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||||
|
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
|
||||||
|
} else {
|
||||||
|
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_representable<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: Self::Base,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_value<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self::Value {
|
||||||
|
self.map_value(
|
||||||
|
generator
|
||||||
|
.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
self.as_base_type().get_element_type().into_struct_type().into(),
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_array_value<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> ArraySliceValue<'ctx> {
|
||||||
|
generator
|
||||||
|
.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
self.as_base_type().get_element_type().into_struct_type().into(),
|
||||||
|
size,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_value(
|
||||||
|
&self,
|
||||||
|
value: <Self::Value as ProxyValue<'ctx>>::Base,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self::Value {
|
||||||
|
debug_assert_eq!(value.get_type(), self.as_base_type());
|
||||||
|
|
||||||
|
NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_base_type(&self) -> Self::Base {
|
||||||
|
self.ty
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<NDArrayType<'ctx>> for PointerType<'ctx> {
|
||||||
|
fn from(value: NDArrayType<'ctx>) -> Self {
|
||||||
|
value.as_base_type()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,159 @@
|
||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
|
values::IntValue,
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::ProxyType;
|
||||||
|
use crate::codegen::{
|
||||||
|
values::{ArraySliceValue, ProxyValue, RangeValue},
|
||||||
|
{CodeGenContext, CodeGenerator},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Proxy type for a `range` type in LLVM.
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub struct RangeType<'ctx> {
|
||||||
|
ty: PointerType<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> RangeType<'ctx> {
|
||||||
|
/// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not.
|
||||||
|
pub fn is_representable(llvm_ty: PointerType<'ctx>) -> Result<(), String> {
|
||||||
|
let llvm_range_ty = llvm_ty.get_element_type();
|
||||||
|
let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else {
|
||||||
|
return Err(format!("Expected array type for `range` type, got {llvm_range_ty}"));
|
||||||
|
};
|
||||||
|
if llvm_range_ty.len() != 3 {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected 3 elements for `range` type, got {}",
|
||||||
|
llvm_range_ty.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let llvm_range_elem_ty = llvm_range_ty.get_element_type();
|
||||||
|
let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected int type for `range` element type, got {llvm_range_elem_ty}"
|
||||||
|
));
|
||||||
|
};
|
||||||
|
if llvm_range_elem_ty.get_bit_width() != 32 {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected 32-bit int type for `range` element type, got {}",
|
||||||
|
llvm_range_elem_ty.get_bit_width()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an LLVM type corresponding to the expected structure of a `Range`.
|
||||||
|
#[must_use]
|
||||||
|
fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> {
|
||||||
|
// typedef int32_t Range[3];
|
||||||
|
let llvm_i32 = ctx.i32_type();
|
||||||
|
llvm_i32.array_type(3).ptr_type(AddressSpace::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an instance of [`RangeType`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(ctx: &'ctx Context) -> Self {
|
||||||
|
let llvm_range = Self::llvm_type(ctx);
|
||||||
|
|
||||||
|
RangeType::from_type(llvm_range)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an [`RangeType`] from a [`PointerType`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self {
|
||||||
|
debug_assert!(Self::is_representable(ptr_ty).is_ok());
|
||||||
|
|
||||||
|
RangeType { ty: ptr_ty }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the type of all fields of this `range` type.
|
||||||
|
#[must_use]
|
||||||
|
pub fn value_type(&self) -> IntType<'ctx> {
|
||||||
|
self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
|
||||||
|
type Base = PointerType<'ctx>;
|
||||||
|
type Value = RangeValue<'ctx>;
|
||||||
|
|
||||||
|
fn is_type<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: impl BasicType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||||
|
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
|
||||||
|
} else {
|
||||||
|
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_representable<G: CodeGenerator + ?Sized>(
|
||||||
|
_: &G,
|
||||||
|
_: &'ctx Context,
|
||||||
|
llvm_ty: Self::Base,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
Self::is_representable(llvm_ty)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_value<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self::Value {
|
||||||
|
self.map_value(
|
||||||
|
generator
|
||||||
|
.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
self.as_base_type().get_element_type().into_struct_type().into(),
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_array_value<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> ArraySliceValue<'ctx> {
|
||||||
|
generator
|
||||||
|
.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
self.as_base_type().get_element_type().into_struct_type().into(),
|
||||||
|
size,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_value(
|
||||||
|
&self,
|
||||||
|
value: <Self::Value as ProxyValue<'ctx>>::Base,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self::Value {
|
||||||
|
debug_assert_eq!(value.get_type(), self.as_base_type());
|
||||||
|
|
||||||
|
RangeValue::from_pointer_value(value, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_base_type(&self) -> Self::Base {
|
||||||
|
self.ty
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<RangeType<'ctx>> for PointerType<'ctx> {
|
||||||
|
fn from(value: RangeType<'ctx>) -> Self {
|
||||||
|
value.as_base_type()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,203 @@
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use inkwell::{
|
||||||
|
context::AsContextRef,
|
||||||
|
types::{BasicTypeEnum, IntType},
|
||||||
|
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
/// Trait indicating that the structure is a field-wise representation of an LLVM structure.
|
||||||
|
///
|
||||||
|
/// # Usage
|
||||||
|
///
|
||||||
|
/// For example, for a simple C-slice LLVM structure:
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// struct CSliceFields<'ctx> {
|
||||||
|
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
/// len: StructField<'ctx, IntValue<'ctx>>
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub trait StructFields<'ctx>: Eq + Copy {
|
||||||
|
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
|
||||||
|
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self;
|
||||||
|
|
||||||
|
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
|
||||||
|
/// the type definition.
|
||||||
|
#[must_use]
|
||||||
|
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>;
|
||||||
|
|
||||||
|
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
|
||||||
|
/// in the type definition.
|
||||||
|
#[must_use]
|
||||||
|
fn iter(&self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)> {
|
||||||
|
self.to_vec().into_iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
|
||||||
|
/// the type definition.
|
||||||
|
#[must_use]
|
||||||
|
fn into_vec(self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
self.to_vec()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
|
||||||
|
/// in the type definition.
|
||||||
|
#[must_use]
|
||||||
|
fn into_iter(self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
self.into_vec().into_iter()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single field of an LLVM structure.
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub struct StructField<'ctx, Value>
|
||||||
|
where
|
||||||
|
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
||||||
|
{
|
||||||
|
/// The index of this field within the structure.
|
||||||
|
index: u32,
|
||||||
|
|
||||||
|
/// The name of this field.
|
||||||
|
name: &'static str,
|
||||||
|
|
||||||
|
/// The type of this field.
|
||||||
|
ty: BasicTypeEnum<'ctx>,
|
||||||
|
|
||||||
|
/// Instance of [`PhantomData`] containing [`Value`], used to implement automatic downcasts.
|
||||||
|
_value_ty: PhantomData<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, Value> StructField<'ctx, Value>
|
||||||
|
where
|
||||||
|
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
||||||
|
{
|
||||||
|
/// Creates an instance of [`StructField`].
|
||||||
|
///
|
||||||
|
/// * `idx_counter` - The instance of [`FieldIndexCounter`] used to track the current field
|
||||||
|
/// index.
|
||||||
|
/// * `name` - Name of the field.
|
||||||
|
/// * `ty` - The type of this field.
|
||||||
|
pub fn create(
|
||||||
|
idx_counter: &mut FieldIndexCounter,
|
||||||
|
name: &'static str,
|
||||||
|
ty: impl Into<BasicTypeEnum<'ctx>>,
|
||||||
|
) -> Self {
|
||||||
|
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an instance of [`StructField`] with a given index.
|
||||||
|
///
|
||||||
|
/// * `index` - The index of this field within its enclosing structure.
|
||||||
|
/// * `name` - Name of the field.
|
||||||
|
/// * `ty` - The type of this field.
|
||||||
|
pub fn create_at(index: u32, name: &'static str, ty: impl Into<BasicTypeEnum<'ctx>>) -> Self {
|
||||||
|
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
|
||||||
|
/// {idx...}, i32 {self.index}`.
|
||||||
|
pub fn ptr_by_array_gep(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pobj: PointerValue<'ctx>,
|
||||||
|
idx: &[IntValue<'ctx>],
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
pobj,
|
||||||
|
&[idx, &[ctx.ctx.i32_type().const_int(u64::from(self.index), false)]].concat(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a pointer to this field in an arbitrary structure by performing the equivalent of
|
||||||
|
/// `getelementptr i32 0, i32 {self.index}`.
|
||||||
|
pub fn ptr_by_gep(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pobj: PointerValue<'ctx>,
|
||||||
|
obj_name: Option<&'ctx str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
ctx.builder
|
||||||
|
.build_struct_gep(
|
||||||
|
pobj,
|
||||||
|
self.index,
|
||||||
|
&obj_name.map(|name| format!("{name}.{}.addr", self.name)).unwrap_or_default(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gets the value of this field for a given `obj`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value {
|
||||||
|
obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the value of this field for a given `obj`.
|
||||||
|
pub fn set_from_value(&self, obj: StructValue<'ctx>, value: Value) {
|
||||||
|
obj.set_field_at_index(self.index, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gets the value of this field for a pointer-to-structure.
|
||||||
|
pub fn get(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pobj: PointerValue<'ctx>,
|
||||||
|
obj_name: Option<&'ctx str>,
|
||||||
|
) -> Value {
|
||||||
|
ctx.builder
|
||||||
|
.build_load(
|
||||||
|
self.ptr_by_gep(ctx, pobj, obj_name),
|
||||||
|
&obj_name.map(|name| format!("{name}.{}", self.name)).unwrap_or_default(),
|
||||||
|
)
|
||||||
|
.map_err(|_| ())
|
||||||
|
.and_then(|value| Value::try_from(value))
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the value of this field for a pointer-to-structure.
|
||||||
|
pub fn set(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
pobj: PointerValue<'ctx>,
|
||||||
|
value: Value,
|
||||||
|
obj_name: Option<&'ctx str>,
|
||||||
|
) {
|
||||||
|
ctx.builder.build_store(self.ptr_by_gep(ctx, pobj, obj_name), value).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, Value> From<StructField<'ctx, Value>> for (&'static str, BasicTypeEnum<'ctx>)
|
||||||
|
where
|
||||||
|
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
||||||
|
{
|
||||||
|
fn from(value: StructField<'ctx, Value>) -> Self {
|
||||||
|
(value.name, value.ty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A counter that tracks the next index of a field using a monotonically increasing counter.
|
||||||
|
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub struct FieldIndexCounter(u32);
|
||||||
|
|
||||||
|
impl FieldIndexCounter {
|
||||||
|
/// Increments the number stored by this counter, returning the previous value.
|
||||||
|
///
|
||||||
|
/// Functionally equivalent to `i++` in C-based languages.
|
||||||
|
pub fn increment(&mut self) -> u32 {
|
||||||
|
let v = self.0;
|
||||||
|
self.0 += 1;
|
||||||
|
v
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,426 @@
|
||||||
|
use inkwell::{
|
||||||
|
types::AnyTypeEnum,
|
||||||
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
|
IntPredicate,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
|
||||||
|
/// elements.
|
||||||
|
pub trait ArrayLikeValue<'ctx> {
|
||||||
|
/// Returns the element type of this array-like value.
|
||||||
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> AnyTypeEnum<'ctx>;
|
||||||
|
|
||||||
|
/// Returns the base pointer to the array.
|
||||||
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> PointerValue<'ctx>;
|
||||||
|
|
||||||
|
/// Returns the size of this array-like value.
|
||||||
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> IntValue<'ctx>;
|
||||||
|
|
||||||
|
/// Returns a [`ArraySliceValue`] representing this value.
|
||||||
|
fn as_slice_value<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> ArraySliceValue<'ctx> {
|
||||||
|
ArraySliceValue::from_ptr_val(
|
||||||
|
self.base_ptr(ctx, generator),
|
||||||
|
self.size(ctx, generator),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An array-like value that can be indexed by memory offset.
|
||||||
|
pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> {
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with a valid index.
|
||||||
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx>;
|
||||||
|
|
||||||
|
/// Returns the pointer to the data at the `idx`-th index.
|
||||||
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An array-like value that can have its array elements accessed as a [`BasicValueEnum`].
|
||||||
|
pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>:
|
||||||
|
ArrayLikeIndexer<'ctx, Index>
|
||||||
|
{
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with a valid index.
|
||||||
|
unsafe fn get_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) };
|
||||||
|
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the data at the `idx`-th index.
|
||||||
|
fn get<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
let ptr = self.ptr_offset(ctx, generator, idx, name);
|
||||||
|
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An array-like value that can have its array elements mutated as a [`BasicValueEnum`].
|
||||||
|
pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>:
|
||||||
|
ArrayLikeIndexer<'ctx, Index>
|
||||||
|
{
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with a valid index.
|
||||||
|
unsafe fn set_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
value: BasicValueEnum<'ctx>,
|
||||||
|
) {
|
||||||
|
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, None) };
|
||||||
|
ctx.builder.build_store(ptr, value).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the data at the `idx`-th index.
|
||||||
|
fn set<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
value: BasicValueEnum<'ctx>,
|
||||||
|
) {
|
||||||
|
let ptr = self.ptr_offset(ctx, generator, idx, None);
|
||||||
|
ctx.builder.build_store(ptr, value).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An array-like value that can have its array elements accessed as an arbitrary type `T`.
|
||||||
|
pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>:
|
||||||
|
UntypedArrayLikeAccessor<'ctx, Index>
|
||||||
|
{
|
||||||
|
/// Casts an element from [`BasicValueEnum`] into `T`.
|
||||||
|
fn downcast_to_type(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: BasicValueEnum<'ctx>,
|
||||||
|
) -> T;
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with a valid index.
|
||||||
|
unsafe fn get_typed_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> T {
|
||||||
|
let value = unsafe { self.get_unchecked(ctx, generator, idx, name) };
|
||||||
|
self.downcast_to_type(ctx, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the data at the `idx`-th index.
|
||||||
|
fn get_typed<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> T {
|
||||||
|
let value = self.get(ctx, generator, idx, name);
|
||||||
|
self.downcast_to_type(ctx, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An array-like value that can have its array elements mutated as an arbitrary type `T`.
|
||||||
|
pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>:
|
||||||
|
UntypedArrayLikeMutator<'ctx, Index>
|
||||||
|
{
|
||||||
|
/// Casts an element from T into [`BasicValueEnum`].
|
||||||
|
fn upcast_from_type(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: T,
|
||||||
|
) -> BasicValueEnum<'ctx>;
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function should be called with a valid index.
|
||||||
|
unsafe fn set_typed_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
value: T,
|
||||||
|
) {
|
||||||
|
let value = self.upcast_from_type(ctx, value);
|
||||||
|
unsafe { self.set_unchecked(ctx, generator, idx, value) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the data at the `idx`-th index.
|
||||||
|
fn set_typed<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
value: T,
|
||||||
|
) {
|
||||||
|
let value = self.upcast_from_type(ctx, value);
|
||||||
|
self.set(ctx, generator, idx, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`.
|
||||||
|
type ValueDowncastFn<'ctx, T> =
|
||||||
|
Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> T>;
|
||||||
|
/// Type alias for a function that casts a `T` into a [`BasicValueEnum`].
|
||||||
|
type ValueUpcastFn<'ctx, T> = Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, T) -> BasicValueEnum<'ctx>>;
|
||||||
|
|
||||||
|
/// An adapter for constraining untyped array values as typed values.
|
||||||
|
pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>> {
|
||||||
|
adapted: Adapted,
|
||||||
|
downcast_fn: ValueDowncastFn<'ctx, T>,
|
||||||
|
upcast_fn: ValueUpcastFn<'ctx, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted>
|
||||||
|
where
|
||||||
|
Adapted: ArrayLikeValue<'ctx>,
|
||||||
|
{
|
||||||
|
/// Creates a [`TypedArrayLikeAdapter`].
|
||||||
|
///
|
||||||
|
/// * `adapted` - The value to be adapted.
|
||||||
|
/// * `downcast_fn` - The function converting a [`BasicValueEnum`] into a `T`.
|
||||||
|
/// * `upcast_fn` - The function converting a T into a [`BasicValueEnum`].
|
||||||
|
pub fn from(
|
||||||
|
adapted: Adapted,
|
||||||
|
downcast_fn: ValueDowncastFn<'ctx, T>,
|
||||||
|
upcast_fn: ValueUpcastFn<'ctx, T>,
|
||||||
|
) -> Self {
|
||||||
|
TypedArrayLikeAdapter { adapted, downcast_fn, upcast_fn }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted>
|
||||||
|
where
|
||||||
|
Adapted: ArrayLikeValue<'ctx>,
|
||||||
|
{
|
||||||
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> AnyTypeEnum<'ctx> {
|
||||||
|
self.adapted.element_type(ctx, generator)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
self.adapted.base_ptr(ctx, generator)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
self.adapted.size(ctx, generator)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index>
|
||||||
|
for TypedArrayLikeAdapter<'ctx, T, Adapted>
|
||||||
|
where
|
||||||
|
Adapted: ArrayLikeIndexer<'ctx, Index>,
|
||||||
|
{
|
||||||
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
unsafe { self.adapted.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &Index,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
self.adapted.ptr_offset(ctx, generator, idx, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index>
|
||||||
|
for TypedArrayLikeAdapter<'ctx, T, Adapted>
|
||||||
|
where
|
||||||
|
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index>
|
||||||
|
for TypedArrayLikeAdapter<'ctx, T, Adapted>
|
||||||
|
where
|
||||||
|
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index>
|
||||||
|
for TypedArrayLikeAdapter<'ctx, T, Adapted>
|
||||||
|
where
|
||||||
|
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
|
||||||
|
{
|
||||||
|
fn downcast_to_type(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: BasicValueEnum<'ctx>,
|
||||||
|
) -> T {
|
||||||
|
(self.downcast_fn)(ctx, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index>
|
||||||
|
for TypedArrayLikeAdapter<'ctx, T, Adapted>
|
||||||
|
where
|
||||||
|
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
|
||||||
|
{
|
||||||
|
fn upcast_from_type(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: T,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
(self.upcast_fn)(ctx, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An LLVM value representing an array slice, consisting of a pointer to the data and the size of
|
||||||
|
/// the slice.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct ArraySliceValue<'ctx>(PointerValue<'ctx>, IntValue<'ctx>, Option<&'ctx str>);
|
||||||
|
|
||||||
|
impl<'ctx> ArraySliceValue<'ctx> {
|
||||||
|
/// Creates an [`ArraySliceValue`] from a [`PointerValue`] and its size.
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_ptr_val(
|
||||||
|
ptr: PointerValue<'ctx>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self {
|
||||||
|
ArraySliceValue(ptr, size, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<ArraySliceValue<'ctx>> for PointerValue<'ctx> {
|
||||||
|
fn from(value: ArraySliceValue<'ctx>) -> Self {
|
||||||
|
value.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeValue<'ctx> for ArraySliceValue<'ctx> {
|
||||||
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
_: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> AnyTypeEnum<'ctx> {
|
||||||
|
self.0.get_type().get_element_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
_: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
_: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
self.1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
|
||||||
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
|
||||||
|
|
||||||
|
let size = self.size(ctx, generator);
|
||||||
|
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
in_range,
|
||||||
|
"0:IndexError",
|
||||||
|
"list index out of range",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ArraySliceValue<'ctx> {}
|
||||||
|
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ArraySliceValue<'ctx> {}
|
|
@ -0,0 +1,241 @@
|
||||||
|
use inkwell::{
|
||||||
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
|
||||||
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
|
AddressSpace, IntPredicate,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
|
};
|
||||||
|
use crate::codegen::{
|
||||||
|
types::ListType,
|
||||||
|
{CodeGenContext, CodeGenerator},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Proxy type for accessing a `list` value in LLVM.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct ListValue<'ctx> {
|
||||||
|
value: PointerValue<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ListValue<'ctx> {
|
||||||
|
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
|
||||||
|
/// instance.
|
||||||
|
pub fn is_representable(
|
||||||
|
value: PointerValue<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
ListType::is_representable(value.get_type(), llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an [`ListValue`] from a [`PointerValue`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_pointer_value(
|
||||||
|
ptr: PointerValue<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self {
|
||||||
|
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
||||||
|
|
||||||
|
ListValue { value: ptr, llvm_usize, name }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
|
/// on the field.
|
||||||
|
fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.as_base_value(),
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
var_name.as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the pointer to the field storing the size of this `list`.
|
||||||
|
fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.as_base_value(),
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
var_name.as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the array of data elements `data` into this instance.
|
||||||
|
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
|
||||||
|
ctx.builder.build_store(self.pptr_to_data(ctx), data).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience method for creating a new array storing data elements with the given element
|
||||||
|
/// type `elem_ty` and `size`.
|
||||||
|
///
|
||||||
|
/// If `size` is [None], the size stored in the field of this instance is used instead.
|
||||||
|
pub fn create_data(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: BasicTypeEnum<'ctx>,
|
||||||
|
size: Option<IntValue<'ctx>>,
|
||||||
|
) {
|
||||||
|
let size = size.unwrap_or_else(|| self.load_size(ctx, None));
|
||||||
|
|
||||||
|
let data = ctx
|
||||||
|
.builder
|
||||||
|
.build_select(
|
||||||
|
ctx.builder
|
||||||
|
.build_int_compare(IntPredicate::NE, size, self.llvm_usize.const_zero(), "")
|
||||||
|
.unwrap(),
|
||||||
|
ctx.builder.build_array_alloca(elem_ty, size, "").unwrap(),
|
||||||
|
elem_ty.ptr_type(AddressSpace::default()).const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap();
|
||||||
|
self.store_data(ctx, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
|
/// on the field.
|
||||||
|
#[must_use]
|
||||||
|
pub fn data(&self) -> ListDataProxy<'ctx, '_> {
|
||||||
|
ListDataProxy(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the `size` of this `list` into this instance.
|
||||||
|
pub fn store_size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx));
|
||||||
|
|
||||||
|
let psize = self.ptr_to_size(ctx);
|
||||||
|
ctx.builder.build_store(psize, size).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the size of this `list` as a value.
|
||||||
|
pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
||||||
|
let psize = self.ptr_to_size(ctx);
|
||||||
|
let var_name = name
|
||||||
|
.map(ToString::to_string)
|
||||||
|
.or_else(|| self.name.map(|v| format!("{v}.size")))
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_load(psize, var_name.as_str())
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> {
|
||||||
|
type Base = PointerValue<'ctx>;
|
||||||
|
type Type = ListType<'ctx>;
|
||||||
|
|
||||||
|
fn get_type(&self) -> Self::Type {
|
||||||
|
ListType::from_type(self.as_base_value().get_type(), self.llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_base_value(&self) -> Self::Base {
|
||||||
|
self.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<ListValue<'ctx>> for PointerValue<'ctx> {
|
||||||
|
fn from(value: ListValue<'ctx>) -> Self {
|
||||||
|
value.as_base_value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Proxy type for accessing the `data` array of an `list` instance in LLVM.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct ListDataProxy<'ctx, 'a>(&'a ListValue<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> {
|
||||||
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
_: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> AnyTypeEnum<'ctx> {
|
||||||
|
self.0.value.get_type().get_element_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_load(self.0.pptr_to_data(ctx), var_name.as_str())
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
self.0.load_size(ctx, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
|
||||||
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
|
||||||
|
|
||||||
|
let size = self.size(ctx, generator);
|
||||||
|
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
in_range,
|
||||||
|
"0:IndexError",
|
||||||
|
"list index out of range",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ListDataProxy<'ctx, '_> {}
|
||||||
|
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ListDataProxy<'ctx, '_> {}
|
|
@ -0,0 +1,47 @@
|
||||||
|
use inkwell::{context::Context, values::BasicValue};
|
||||||
|
|
||||||
|
use super::types::ProxyType;
|
||||||
|
use crate::codegen::CodeGenerator;
|
||||||
|
pub use array::*;
|
||||||
|
pub use list::*;
|
||||||
|
pub use ndarray::*;
|
||||||
|
pub use range::*;
|
||||||
|
|
||||||
|
mod array;
|
||||||
|
mod list;
|
||||||
|
mod ndarray;
|
||||||
|
mod range;
|
||||||
|
|
||||||
|
/// A LLVM type that is used to represent a non-primitive value in NAC3.
|
||||||
|
pub trait ProxyValue<'ctx>: Into<Self::Base> {
|
||||||
|
/// The type of LLVM values represented by this instance. This is usually the
|
||||||
|
/// [LLVM pointer type][PointerValue].
|
||||||
|
type Base: BasicValue<'ctx>;
|
||||||
|
|
||||||
|
/// The type of this value.
|
||||||
|
type Type: ProxyType<'ctx, Value = Self>;
|
||||||
|
|
||||||
|
/// Checks whether `value` can be represented by this [`ProxyValue`].
|
||||||
|
fn is_instance<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
value: impl BasicValue<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
Self::Type::is_type(generator, ctx, value.as_basic_value_enum().get_type())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks whether `value` can be represented by this [`ProxyValue`].
|
||||||
|
fn is_representable<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
value: Self::Base,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
Self::is_instance(generator, ctx, value.as_basic_value_enum())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the [type][ProxyType] of this value.
|
||||||
|
fn get_type(&self) -> Self::Type;
|
||||||
|
|
||||||
|
/// Returns the [base value][Self::Base] of this proxy.
|
||||||
|
fn as_base_value(&self) -> Self::Base;
|
||||||
|
}
|
|
@ -0,0 +1,523 @@
|
||||||
|
use inkwell::{
|
||||||
|
types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
|
||||||
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
|
AddressSpace, IntPredicate,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
|
||||||
|
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
|
};
|
||||||
|
use crate::codegen::{
|
||||||
|
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||||
|
llvm_intrinsics::call_int_umin,
|
||||||
|
stmt::gen_for_callback_incrementing,
|
||||||
|
types::NDArrayType,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Proxy type for accessing an `NDArray` value in LLVM.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct NDArrayValue<'ctx> {
|
||||||
|
value: PointerValue<'ctx>,
|
||||||
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
|
||||||
|
/// instance.
|
||||||
|
pub fn is_representable(
|
||||||
|
value: PointerValue<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
NDArrayType::is_representable(value.get_type(), llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_pointer_value(
|
||||||
|
ptr: PointerValue<'ctx>,
|
||||||
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self {
|
||||||
|
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
||||||
|
|
||||||
|
NDArrayValue { value: ptr, dtype, llvm_usize, name }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||||
|
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.ndims
|
||||||
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the number of dimensions `ndims` into this instance.
|
||||||
|
pub fn store_ndims<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
||||||
|
|
||||||
|
let pndims = self.ptr_to_ndims(ctx);
|
||||||
|
ctx.builder.build_store(pndims, ndims).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the number of dimensions of this `NDArray` as a value.
|
||||||
|
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||||
|
let pndims = self.ptr_to_ndims(ctx);
|
||||||
|
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
||||||
|
/// `getelementptr` on the field.
|
||||||
|
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.shape
|
||||||
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the array of dimension sizes `dims` into this instance.
|
||||||
|
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||||
|
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
||||||
|
pub fn create_shape(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
NDArrayShapeProxy(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
|
/// on the field.
|
||||||
|
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.data
|
||||||
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the array of data elements `data` into this instance.
|
||||||
|
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
|
||||||
|
let data = ctx
|
||||||
|
.builder
|
||||||
|
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
|
||||||
|
.unwrap();
|
||||||
|
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience method for creating a new array storing data elements with the given element
|
||||||
|
/// type `elem_ty` and `size`.
|
||||||
|
pub fn create_data(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: BasicTypeEnum<'ctx>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let itemsize =
|
||||||
|
ctx.builder.build_int_cast(elem_ty.size_of().unwrap(), size.get_type(), "").unwrap();
|
||||||
|
let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap();
|
||||||
|
|
||||||
|
// TODO: What about alignment?
|
||||||
|
self.store_data(
|
||||||
|
ctx,
|
||||||
|
ctx.builder.build_array_alloca(ctx.ctx.i8_type(), nbytes, "").unwrap(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
|
||||||
|
NDArrayDataProxy(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
||||||
|
type Base = PointerValue<'ctx>;
|
||||||
|
type Type = NDArrayType<'ctx>;
|
||||||
|
|
||||||
|
fn get_type(&self) -> Self::Type {
|
||||||
|
NDArrayType::from_type(self.as_base_value().get_type(), self.dtype, self.llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_base_value(&self) -> Self::Base {
|
||||||
|
self.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
|
||||||
|
fn from(value: NDArrayValue<'ctx>) -> Self {
|
||||||
|
value.as_base_value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> AnyTypeEnum<'ctx> {
|
||||||
|
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
self.0.load_ndims(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let size = self.size(ctx, generator);
|
||||||
|
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
in_range,
|
||||||
|
"0:IndexError",
|
||||||
|
"index {0} is out of bounds for axis 0 with size {1}",
|
||||||
|
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||||
|
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
|
||||||
|
|
||||||
|
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
fn downcast_to_type(
|
||||||
|
&self,
|
||||||
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: BasicValueEnum<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
value.into_int_value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
fn upcast_from_type(
|
||||||
|
&self,
|
||||||
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: IntValue<'ctx>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
value.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||||
|
fn element_type<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
_: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> AnyTypeEnum<'ctx> {
|
||||||
|
self.0.dtype.as_any_type_enum()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
_: &G,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||||
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let sizeof_elem = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_truncate_or_bit_cast(
|
||||||
|
self.element_type(ctx, generator).size_of().unwrap(),
|
||||||
|
idx.get_type(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap();
|
||||||
|
let ptr = unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.base_ptr(ctx, generator),
|
||||||
|
&[idx],
|
||||||
|
name.unwrap_or_default(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Current implementation is transparent - The returned pointer type is
|
||||||
|
// already cast into the expected type, allowing for immediately
|
||||||
|
// load/store.
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(
|
||||||
|
ptr,
|
||||||
|
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||||
|
.unwrap()
|
||||||
|
.ptr_type(AddressSpace::default()),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
idx: &IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let data_sz = self.size(ctx, generator);
|
||||||
|
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
in_range,
|
||||||
|
"0:IndexError",
|
||||||
|
"index {0} is out of bounds with size {1}",
|
||||||
|
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) };
|
||||||
|
|
||||||
|
// Current implementation is transparent - The returned pointer type is
|
||||||
|
// already cast into the expected type, allowing for immediately
|
||||||
|
// load/store.
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(
|
||||||
|
ptr,
|
||||||
|
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||||
|
.unwrap()
|
||||||
|
.ptr_type(AddressSpace::default()),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
|
||||||
|
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
|
||||||
|
|
||||||
|
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
||||||
|
for NDArrayDataProxy<'ctx, '_>
|
||||||
|
{
|
||||||
|
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
indices: &Index,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let indices_elem_ty = indices
|
||||||
|
.ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.get_type()
|
||||||
|
.get_element_type();
|
||||||
|
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
||||||
|
panic!("Expected list[int32] but got {indices_elem_ty}")
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
indices_elem_ty.get_bit_width(),
|
||||||
|
32,
|
||||||
|
"Expected list[int32] but got list[int{}]",
|
||||||
|
indices_elem_ty.get_bit_width()
|
||||||
|
);
|
||||||
|
|
||||||
|
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
|
||||||
|
let sizeof_elem = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_truncate_or_bit_cast(
|
||||||
|
self.element_type(ctx, generator).size_of().unwrap(),
|
||||||
|
index.get_type(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap();
|
||||||
|
|
||||||
|
let ptr = unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.base_ptr(ctx, generator),
|
||||||
|
&[index],
|
||||||
|
name.unwrap_or_default(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
};
|
||||||
|
// TODO: Current implementation is transparent
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(
|
||||||
|
ptr,
|
||||||
|
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||||
|
.unwrap()
|
||||||
|
.ptr_type(AddressSpace::default()),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
indices: &Index,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let indices_size = indices.size(ctx, generator);
|
||||||
|
let nidx_leq_ndims = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "")
|
||||||
|
.unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
nidx_leq_ndims,
|
||||||
|
"0:IndexError",
|
||||||
|
"invalid index to scalar variable",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let indices_len = indices.size(ctx, generator);
|
||||||
|
let ndarray_len = self.0.load_ndims(ctx);
|
||||||
|
let len = call_int_umin(ctx, indices_len, ndarray_len, None);
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(len, false),
|
||||||
|
|generator, ctx, _, i| {
|
||||||
|
let (dim_idx, dim_sz) = unsafe {
|
||||||
|
(
|
||||||
|
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
|
||||||
|
self.0.shape().get_typed_unchecked(ctx, generator, &i, None),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let dim_idx = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let dim_lt =
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap();
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
dim_lt,
|
||||||
|
"0:IndexError",
|
||||||
|
"index {0} is out of bounds for axis 0 with size {1}",
|
||||||
|
[Some(dim_idx), Some(dim_sz), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) };
|
||||||
|
// TODO: Current implementation is transparent
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(
|
||||||
|
ptr,
|
||||||
|
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||||
|
.unwrap()
|
||||||
|
.ptr_type(AddressSpace::default()),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index>
|
||||||
|
for NDArrayDataProxy<'ctx, '_>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index>
|
||||||
|
for NDArrayDataProxy<'ctx, '_>
|
||||||
|
{
|
||||||
|
}
|
|
@ -0,0 +1,153 @@
|
||||||
|
use inkwell::values::{BasicValueEnum, IntValue, PointerValue};
|
||||||
|
|
||||||
|
use super::ProxyValue;
|
||||||
|
use crate::codegen::{types::RangeType, CodeGenContext};
|
||||||
|
|
||||||
|
/// Proxy type for accessing a `range` value in LLVM.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct RangeValue<'ctx> {
|
||||||
|
value: PointerValue<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> RangeValue<'ctx> {
|
||||||
|
/// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance.
|
||||||
|
pub fn is_representable(value: PointerValue<'ctx>) -> Result<(), String> {
|
||||||
|
RangeType::is_representable(value.get_type())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an [`RangeValue`] from a [`PointerValue`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_pointer_value(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
|
||||||
|
debug_assert!(Self::is_representable(ptr).is_ok());
|
||||||
|
|
||||||
|
RangeValue { value: ptr, name }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let var_name = self.name.map(|v| format!("{v}.start.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.as_base_value(),
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(0, false)],
|
||||||
|
var_name.as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_to_end(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let var_name = self.name.map(|v| format!("{v}.end.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.as_base_value(),
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||||
|
var_name.as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ptr_to_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let var_name = self.name.map(|v| format!("{v}.step.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.as_base_value(),
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(2, false)],
|
||||||
|
var_name.as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the `start` value into this instance.
|
||||||
|
pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, start: IntValue<'ctx>) {
|
||||||
|
debug_assert_eq!(start.get_type().get_bit_width(), 32);
|
||||||
|
|
||||||
|
let pstart = self.ptr_to_start(ctx);
|
||||||
|
ctx.builder.build_store(pstart, start).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the `start` value of this `range`.
|
||||||
|
pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
||||||
|
let pstart = self.ptr_to_start(ctx);
|
||||||
|
let var_name = name
|
||||||
|
.map(ToString::to_string)
|
||||||
|
.or_else(|| self.name.map(|v| format!("{v}.start")))
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_load(pstart, var_name.as_str())
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the `end` value into this instance.
|
||||||
|
pub fn store_end(&self, ctx: &CodeGenContext<'ctx, '_>, end: IntValue<'ctx>) {
|
||||||
|
debug_assert_eq!(end.get_type().get_bit_width(), 32);
|
||||||
|
|
||||||
|
let pend = self.ptr_to_end(ctx);
|
||||||
|
ctx.builder.build_store(pend, end).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the `end` value of this `range`.
|
||||||
|
pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
||||||
|
let pend = self.ptr_to_end(ctx);
|
||||||
|
let var_name = name
|
||||||
|
.map(ToString::to_string)
|
||||||
|
.or_else(|| self.name.map(|v| format!("{v}.end")))
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
ctx.builder.build_load(pend, var_name.as_str()).map(BasicValueEnum::into_int_value).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the `step` value into this instance.
|
||||||
|
pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, step: IntValue<'ctx>) {
|
||||||
|
debug_assert_eq!(step.get_type().get_bit_width(), 32);
|
||||||
|
|
||||||
|
let pstep = self.ptr_to_step(ctx);
|
||||||
|
ctx.builder.build_store(pstep, step).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the `step` value of this `range`.
|
||||||
|
pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
|
||||||
|
let pstep = self.ptr_to_step(ctx);
|
||||||
|
let var_name = name
|
||||||
|
.map(ToString::to_string)
|
||||||
|
.or_else(|| self.name.map(|v| format!("{v}.step")))
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_load(pstep, var_name.as_str())
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> {
|
||||||
|
type Base = PointerValue<'ctx>;
|
||||||
|
type Type = RangeType<'ctx>;
|
||||||
|
|
||||||
|
fn get_type(&self) -> Self::Type {
|
||||||
|
RangeType::from_type(self.value.get_type())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_base_value(&self) -> Self::Base {
|
||||||
|
self.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<RangeValue<'ctx>> for PointerValue<'ctx> {
|
||||||
|
fn from(value: RangeValue<'ctx>) -> Self {
|
||||||
|
value.as_base_value()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,10 +1,4 @@
|
||||||
#![deny(
|
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
|
||||||
future_incompatible,
|
|
||||||
let_underscore,
|
|
||||||
nonstandard_style,
|
|
||||||
rust_2024_compatibility,
|
|
||||||
clippy::all
|
|
||||||
)]
|
|
||||||
#![warn(clippy::pedantic)]
|
#![warn(clippy::pedantic)]
|
||||||
#![allow(
|
#![allow(
|
||||||
dead_code,
|
dead_code,
|
||||||
|
@ -19,7 +13,13 @@
|
||||||
clippy::wildcard_imports
|
clippy::wildcard_imports
|
||||||
)]
|
)]
|
||||||
|
|
||||||
|
// users of nac3core need to use the same version of these dependencies, so expose them as nac3core::*
|
||||||
|
pub use inkwell;
|
||||||
|
pub use nac3parser;
|
||||||
|
|
||||||
pub mod codegen;
|
pub mod codegen;
|
||||||
pub mod symbol_resolver;
|
pub mod symbol_resolver;
|
||||||
pub mod toplevel;
|
pub mod toplevel;
|
||||||
pub mod typecheck;
|
pub mod typecheck;
|
||||||
|
|
||||||
|
extern crate self as nac3core;
|
||||||
|
|
|
@ -1,7 +1,15 @@
|
||||||
use std::fmt::Debug;
|
use std::{
|
||||||
use std::rc::Rc;
|
collections::{HashMap, HashSet},
|
||||||
use std::sync::Arc;
|
fmt::{Debug, Display},
|
||||||
use std::{collections::HashMap, collections::HashSet, fmt::Display};
|
rc::Rc,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
|
||||||
|
use itertools::{chain, izip, Itertools};
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
|
||||||
|
use nac3parser::ast::{Constant, Expr, Location, StrRef};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{CodeGenContext, CodeGenerator},
|
codegen::{CodeGenContext, CodeGenerator},
|
||||||
|
@ -11,10 +19,6 @@ use crate::{
|
||||||
typedef::{Type, TypeEnum, Unifier, VarMap},
|
typedef::{Type, TypeEnum, Unifier, VarMap},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
|
|
||||||
use itertools::{chain, izip, Itertools};
|
|
||||||
use nac3parser::ast::{Constant, Expr, Location, StrRef};
|
|
||||||
use parking_lot::RwLock;
|
|
||||||
|
|
||||||
#[derive(Clone, PartialEq, Debug)]
|
#[derive(Clone, PartialEq, Debug)]
|
||||||
pub enum SymbolValue {
|
pub enum SymbolValue {
|
||||||
|
@ -365,6 +369,7 @@ pub trait SymbolResolver {
|
||||||
&self,
|
&self,
|
||||||
str: StrRef,
|
str: StrRef,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
) -> Option<ValueEnum<'ctx>>;
|
) -> Option<ValueEnum<'ctx>>;
|
||||||
|
|
||||||
fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;
|
fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
|
|
||||||
use helper::{debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDefDetails};
|
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
|
@ -9,28 +8,24 @@ use inkwell::{
|
||||||
IntPredicate,
|
IntPredicate,
|
||||||
};
|
};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
use numpy::unpack_ndarray_var_tys;
|
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDef, PrimDefDetails},
|
||||||
|
numpy::make_ndarray_ty,
|
||||||
|
*,
|
||||||
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
builtin_fns,
|
builtin_fns,
|
||||||
classes::{ProxyValue, RangeValue},
|
|
||||||
model::*,
|
|
||||||
numpy::*,
|
numpy::*,
|
||||||
object::{
|
|
||||||
any::AnyObject,
|
|
||||||
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
|
|
||||||
},
|
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
|
values::{ProxyValue, RangeValue},
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
|
|
||||||
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
|
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
|
||||||
|
|
||||||
pub fn get_exn_constructor(
|
pub fn get_exn_constructor(
|
||||||
|
@ -517,14 +512,6 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
| PrimDef::FunNpEye
|
| PrimDef::FunNpEye
|
||||||
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
|
||||||
self.build_ndarray_property_getter_function(prim)
|
|
||||||
}
|
|
||||||
|
|
||||||
PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
|
||||||
self.build_ndarray_view_function(prim)
|
|
||||||
}
|
|
||||||
|
|
||||||
PrimDef::FunStr => self.build_str_function(),
|
PrimDef::FunStr => self.build_str_function(),
|
||||||
|
|
||||||
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
|
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
|
||||||
|
@ -590,6 +577,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
| PrimDef::FunNpHypot
|
| PrimDef::FunNpHypot
|
||||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||||
|
|
||||||
|
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
||||||
|
self.build_np_sp_ndarray_function(prim)
|
||||||
|
}
|
||||||
|
|
||||||
PrimDef::FunNpDot
|
PrimDef::FunNpDot
|
||||||
| PrimDef::FunNpLinalgCholesky
|
| PrimDef::FunNpLinalgCholesky
|
||||||
| PrimDef::FunNpLinalgQr
|
| PrimDef::FunNpLinalgQr
|
||||||
|
@ -719,7 +710,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
let (zelf_ty, zelf) = obj.unwrap();
|
let (zelf_ty, zelf) = obj.unwrap();
|
||||||
let zelf =
|
let zelf =
|
||||||
zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value();
|
zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value();
|
||||||
let zelf = RangeValue::from_ptr_val(zelf, Some("range"));
|
let zelf = RangeValue::from_pointer_value(zelf, Some("range"));
|
||||||
|
|
||||||
let mut start = None;
|
let mut start = None;
|
||||||
let mut stop = None;
|
let mut stop = None;
|
||||||
|
@ -1395,171 +1386,6 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
|
||||||
debug_assert_prim_is_allowed(
|
|
||||||
prim,
|
|
||||||
&[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides],
|
|
||||||
);
|
|
||||||
|
|
||||||
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
|
|
||||||
&[self.primitives.ndarray],
|
|
||||||
Some("T".into()),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
match prim {
|
|
||||||
PrimDef::FunNpSize => create_fn_by_codegen(
|
|
||||||
self.unifier,
|
|
||||||
&into_var_map([in_ndarray_ty]),
|
|
||||||
prim.name(),
|
|
||||||
self.primitives.int32,
|
|
||||||
&[(in_ndarray_ty.ty, "a")],
|
|
||||||
Box::new(|ctx, obj, fun, args, generator| {
|
|
||||||
assert!(obj.is_none());
|
|
||||||
assert_eq!(args.len(), 1);
|
|
||||||
|
|
||||||
let ndarray_ty = fun.0.args[0].ty;
|
|
||||||
let ndarray =
|
|
||||||
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
|
||||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
|
||||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
|
||||||
|
|
||||||
let size =
|
|
||||||
ndarray.size(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32);
|
|
||||||
Ok(Some(size.value.as_basic_value_enum()))
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
|
|
||||||
// The function signatures of `np_shape` an `np_size` are the same.
|
|
||||||
// Mixed together for convenience.
|
|
||||||
|
|
||||||
// The return type is a tuple of variable length depending on the ndims of the input ndarray.
|
|
||||||
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding
|
|
||||||
|
|
||||||
create_fn_by_codegen(
|
|
||||||
self.unifier,
|
|
||||||
&into_var_map([in_ndarray_ty]),
|
|
||||||
prim.name(),
|
|
||||||
ret_ty,
|
|
||||||
&[(in_ndarray_ty.ty, "a")],
|
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
|
||||||
assert!(obj.is_none());
|
|
||||||
assert_eq!(args.len(), 1);
|
|
||||||
|
|
||||||
let ndarray_ty = fun.0.args[0].ty;
|
|
||||||
let ndarray =
|
|
||||||
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
|
||||||
|
|
||||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
|
||||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
|
||||||
|
|
||||||
let result_tuple = match prim {
|
|
||||||
PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx),
|
|
||||||
PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx),
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Some(result_tuple.value.as_basic_value_enum()))
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build np/sp functions that take as input `NDArray` only
|
|
||||||
fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
|
||||||
debug_assert_prim_is_allowed(
|
|
||||||
prim,
|
|
||||||
&[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape],
|
|
||||||
);
|
|
||||||
|
|
||||||
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
|
|
||||||
&[self.primitives.ndarray],
|
|
||||||
Some("T".into()),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
match prim {
|
|
||||||
PrimDef::FunNpTranspose => {
|
|
||||||
create_fn_by_codegen(
|
|
||||||
self.unifier,
|
|
||||||
&into_var_map([in_ndarray_ty]),
|
|
||||||
prim.name(),
|
|
||||||
in_ndarray_ty.ty,
|
|
||||||
&[(in_ndarray_ty.ty, "x")],
|
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
|
||||||
let arg_ty = fun.0.args[0].ty;
|
|
||||||
let arg_val =
|
|
||||||
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
|
||||||
|
|
||||||
let arg = AnyObject { ty: arg_ty, value: arg_val };
|
|
||||||
let ndarray = NDArrayObject::from_object(generator, ctx, arg);
|
|
||||||
|
|
||||||
let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument
|
|
||||||
Ok(Some(ndarray.instance.value.as_basic_value_enum()))
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
|
|
||||||
// the `param_ty` for `create_fn_by_codegen`.
|
|
||||||
//
|
|
||||||
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
|
||||||
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
|
||||||
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
|
||||||
PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => {
|
|
||||||
// These two functions have the same function signature.
|
|
||||||
// Mixed together for convenience.
|
|
||||||
|
|
||||||
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding
|
|
||||||
|
|
||||||
create_fn_by_codegen(
|
|
||||||
self.unifier,
|
|
||||||
&VarMap::new(),
|
|
||||||
prim.name(),
|
|
||||||
ret_ty,
|
|
||||||
&[
|
|
||||||
(in_ndarray_ty.ty, "x"),
|
|
||||||
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), // Handled by special folding
|
|
||||||
],
|
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
|
||||||
let ndarray_ty = fun.0.args[0].ty;
|
|
||||||
let ndarray_val =
|
|
||||||
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
|
||||||
|
|
||||||
let shape_ty = fun.0.args[1].ty;
|
|
||||||
let shape_val =
|
|
||||||
args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
|
||||||
|
|
||||||
let ndarray = AnyObject { value: ndarray_val, ty: ndarray_ty };
|
|
||||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
|
||||||
|
|
||||||
let shape = AnyObject { value: shape_val, ty: shape_ty };
|
|
||||||
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
|
||||||
|
|
||||||
// The ndims after reshaping is gotten from the return type of the call.
|
|
||||||
let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
|
|
||||||
let new_ndarray = match prim {
|
|
||||||
PrimDef::FunNpBroadcastTo => {
|
|
||||||
ndarray.broadcast_to(generator, ctx, ndims, shape)
|
|
||||||
}
|
|
||||||
PrimDef::FunNpReshape => {
|
|
||||||
ndarray.reshape_or_copy(generator, ctx, ndims, shape)
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
Ok(Some(new_ndarray.instance.value.as_basic_value_enum()))
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build the `str()` function.
|
/// Build the `str()` function.
|
||||||
fn build_str_function(&mut self) -> TopLevelDef {
|
fn build_str_function(&mut self) -> TopLevelDef {
|
||||||
let prim = PrimDef::FunStr;
|
let prim = PrimDef::FunStr;
|
||||||
|
@ -2047,6 +1873,57 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build np/sp functions that take as input `NDArray` only
|
||||||
|
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
||||||
|
|
||||||
|
match prim {
|
||||||
|
PrimDef::FunNpTranspose => {
|
||||||
|
let ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||||
|
&[self.ndarray_num_ty],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&into_var_map([ndarray_ty]),
|
||||||
|
prim.name(),
|
||||||
|
ndarray_ty.ty,
|
||||||
|
&[(ndarray_ty.ty, "x")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let arg_ty = fun.0.args[0].ty;
|
||||||
|
let arg_val =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
|
||||||
|
// the `param_ty` for `create_fn_by_codegen`.
|
||||||
|
//
|
||||||
|
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
||||||
|
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
||||||
|
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
||||||
|
PrimDef::FunNpReshape => create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
self.ndarray_num_ty,
|
||||||
|
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Build `np_linalg` and `sp_linalg` functions
|
/// Build `np_linalg` and `sp_linalg` functions
|
||||||
///
|
///
|
||||||
/// The input to these functions must be floating point `NDArray`
|
/// The input to these functions must be floating point `NDArray`
|
||||||
|
@ -2078,12 +1955,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let x1_ty = fun.0.args[0].ty;
|
let x1_ty = fun.0.args[0].ty;
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
|
||||||
let x2_ty = fun.0.args[1].ty;
|
let x2_ty = fun.0.args[1].ty;
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?;
|
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||||
Ok(Some(result))
|
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue