Compare commits

..

38 Commits

Author SHA1 Message Date
lyken 916a2b4993
core/ndstrides: implement np_reshape() 2024-08-30 14:41:54 +08:00
lyken c7c3cc21a8
core: categorize np_{transpose,reshape} as 'view functions' 2024-08-30 14:41:54 +08:00
lyken d2072d9248
core/ndstrides: implement np_size() 2024-08-30 14:41:00 +08:00
lyken be19165ead
core/ndstrides: implement np_shape() and np_strides()
These functions are not important, but they are handy for debugging.

`np.strides()` is not an actual NumPy function, but `ndarray.strides` is used.
2024-08-30 14:41:00 +08:00
lyken ee58cf3fc3
core/ndstrides: implement ndarray.fill() and .copy() 2024-08-30 14:41:00 +08:00
lyken 8fe8ccf200
core/ndstrides: implement np_identity() and np_eye() 2024-08-30 14:41:00 +08:00
lyken d222236492
core/ndstrides: implement np_array()
It also checks for inconsistent dimensions if the input is a list.
e.g., rejecting `[[1.0, 2.0], [3.0]]`.

However, currently only `np_array(<input>, copy=False)` and `np_array(<input>, copy=True)` are supported. In NumPy, copy could be false, true, or None. Right now, NAC3's `np_array(<input>, copy=False)` behaves like NumPy's `np.array(<input>, copy=None)`.
2024-08-30 14:40:15 +08:00
lyken 13715dbda9
core/irrt: add List
Needed for implementing np_array()
2024-08-30 14:20:19 +08:00
lyken 7910de10a1
core/ndstrides: add NDArrayObject::atleast_nd 2024-08-30 14:18:34 +08:00
lyken 6edc3f895b
core/ndstrides: add NDArrayObject::make_copy 2024-08-30 14:18:17 +08:00
lyken a40cdde8d2
core/ndstrides: implement ndarray indexing
The functionality for `...` and `np.newaxis` is there in IRRT, but there
is no implementation of them for @kernel Python expressions because of
#486.
2024-08-30 14:12:54 +08:00
lyken 853fa39537
core/irrt: rename NDIndex to NDIndexInt
Unfortunately the name `NDIndex` is used in later commits. Renaming this
typedef to `NDIndexInt` to avoid amending. `NDIndexInt` will be removed
anyway when ndarray strides is completed.
2024-08-30 14:00:19 +08:00
lyken b6a1880226
core/irrt: add Slice and Range
Needed for implementing general ndarray indexing.

Currently IRRT slice and range have nothing to do with NAC3's slice
and range. The IRRT slice and range are currently there to implement
ndarray specific features. However, in the future their definitions may
be used to replace that of NAC3's. (NAC3's range is a [i32 x 3], IRRT's
range is a proper struct. NAC3 does not have a slice struct).
2024-08-30 13:57:10 +08:00
lyken d1c75c7444
core/ndstrides: implement len(ndarray) & refactor len() 2024-08-30 13:45:25 +08:00
lyken 58c5bc56b9
core/ndstrides: implement np_{zeros,ones,full,empty} 2024-08-30 13:44:12 +08:00
lyken ddc0e44c61
core/model: add util::gen_for_model 2024-08-30 13:42:39 +08:00
lyken 549536f72c
core/object: add ListObject and TupleObject
Needed for implementing other ndarray utils.
2024-08-30 13:41:31 +08:00
lyken 40c42b571a
core/ndstrides: implement ndarray iterator NDIter
A necessary utility to iterate through all elements in a possibly strided ndarray.
2024-08-30 13:39:10 +08:00
lyken 92e7103ec7
core/ndstrides: introduce NDArray
NDArray with strides.
2024-08-30 13:24:45 +08:00
lyken 9bc5e96dba
core/irrt: fix exception.hpp C++ castings 2024-08-30 13:15:07 +08:00
lyken 78639b1030
core/toplevel/helper: add {extract,create}_ndims 2024-08-30 13:05:16 +08:00
lyken 9723c17e24
core/object: introduce object
A small abstraction to simplify implementations.
2024-08-30 13:04:54 +08:00
lyken d1c7a8ee50
StructKind::{traverse -> iter}_fields 2024-08-30 12:51:17 +08:00
lyken e0524c19eb
Newline "Otherwise, it will be caught..." 2024-08-30 12:51:17 +08:00
lyken 32822f9052
gep_index must be u32 2024-08-30 12:51:17 +08:00
lyken 6283036815
FieldTraversal::{Out -> Output} 2024-08-30 12:51:17 +08:00
lyken f167f5f215
Ptr::copy_from to use SizeT 2024-08-30 12:51:17 +08:00
lyken baf8ee2b3d
Ptr::offset_const offset i64, can be negative 2024-08-30 12:51:17 +08:00
lyken d68760447f
Int::const_int to have sign_extend 2024-08-30 12:51:17 +08:00
lyken fdd194ee2a
FnCall::{begin -> builder} 2024-08-30 12:51:17 +08:00
lyken 5fca81c68e
CallFunction -> FnCall 2024-08-30 12:51:17 +08:00
lyken 0562e9a385
Instance add newline 2024-08-30 12:51:17 +08:00
lyken 36af473816
unsafe Model::believe_value 2024-08-30 12:51:17 +08:00
lyken 7c7e1b3ab8
Model::{sizeof -> size_of} 2024-08-30 12:51:17 +08:00
lyken dbcfc9538a
ArrayLen::{get_length -> length} 2024-08-30 12:51:17 +08:00
lyken 5c4ba09e2f
LenKind -> ArrayLen 2024-08-30 12:51:17 +08:00
lyken eb34b99ee9
core/model: renaming and add notes on upgrading Ptr to LLVM 15 2024-08-30 12:51:17 +08:00
lyken d397b9ceaa
core/model: introduce models 2024-08-30 12:51:17 +08:00
94 changed files with 6448 additions and 2720 deletions

View File

@ -8,17 +8,17 @@ repos:
hooks:
- id: nac3-cargo-fmt
name: nac3 cargo format
entry: nix
entry: cargo
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo fmt on the codebase.
args: [develop, -c, cargo, fmt, --all]
args: [fmt]
- id: nac3-cargo-clippy
name: nac3 cargo clippy
entry: nix
entry: cargo
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo clippy on the codebase.
args: [develop, -c, cargo, clippy, --tests]
args: [clippy, --tests]

338
Cargo.lock generated
View File

@ -75,33 +75,33 @@ dependencies = [
[[package]]
name = "ascii-canvas"
version = "4.0.0"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef1e3e699d84ab1b0911a1010c5c106aa34ae89aeac103be5ce0c3859db1e891"
checksum = "8824ecca2e851cec16968d54a01dd372ef8f95b244fb84b84e70128be347c3c6"
dependencies = [
"term",
]
[[package]]
name = "autocfg"
version = "1.4.0"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "bit-set"
version = "0.8.0"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
dependencies = [
"bit-vec",
]
[[package]]
name = "bit-vec"
version = "0.8.0"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]]
name = "bitflags"
@ -109,15 +109,6 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]]
name = "byteorder"
version = "1.5.0"
@ -126,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "cc"
version = "1.1.30"
version = "1.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945"
checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6"
dependencies = [
"shlex",
]
@ -141,9 +132,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "clap"
version = "4.5.20"
version = "4.5.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8"
checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019"
dependencies = [
"clap_builder",
"clap_derive",
@ -151,9 +142,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.5.20"
version = "4.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54"
checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6"
dependencies = [
"anstream",
"anstyle",
@ -163,14 +154,14 @@ dependencies = [
[[package]]
name = "clap_derive"
version = "4.5.18"
version = "4.5.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.76",
]
[[package]]
@ -197,15 +188,6 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "cpufeatures"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0"
dependencies = [
"libc",
]
[[package]]
name = "crossbeam"
version = "0.8.4"
@ -263,23 +245,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
[[package]]
name = "crypto-common"
version = "0.1.6"
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "dirs-next"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1"
dependencies = [
"generic-array",
"typenum",
"cfg-if",
"dirs-sys-next",
]
[[package]]
name = "digest"
version = "0.10.7"
name = "dirs-sys-next"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d"
dependencies = [
"block-buffer",
"crypto-common",
"libc",
"redox_users",
"winapi",
]
[[package]]
@ -340,16 +329,6 @@ dependencies = [
"byteorder",
]
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "getopts"
version = "0.2.21"
@ -385,12 +364,6 @@ dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb"
[[package]]
name = "heck"
version = "0.4.1"
@ -403,15 +376,6 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "home"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "indexmap"
version = "1.9.3"
@ -424,12 +388,12 @@ dependencies = [
[[package]]
name = "indexmap"
version = "2.6.0"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c"
dependencies = [
"equivalent",
"hashbrown 0.15.0",
"hashbrown 0.14.5",
]
[[package]]
@ -440,9 +404,9 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]]
name = "inkwell"
version = "0.5.0"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40fb405537710d51f6bdbc8471365ddd4cd6d3a3c3ad6e0c8291691031ba94b2"
checksum = "b597a7b2cdf279aeef6d7149071e35e4bc87c2cf05a5b7f2d731300bffe587ea"
dependencies = [
"either",
"inkwell_internals",
@ -454,13 +418,13 @@ dependencies = [
[[package]]
name = "inkwell_internals"
version = "0.10.0"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44"
checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.76",
]
[[package]]
@ -483,6 +447,15 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.13.0"
@ -498,45 +471,35 @@ version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "keccak"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654"
dependencies = [
"cpufeatures",
]
[[package]]
name = "lalrpop"
version = "0.22.0"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06093b57658c723a21da679530e061a8c25340fa5a6f98e313b542268c7e2a1f"
checksum = "55cb077ad656299f160924eb2912aa147d7339ea7d69e1b5517326fdcec3c1ca"
dependencies = [
"ascii-canvas",
"bit-set",
"ena",
"itertools",
"itertools 0.11.0",
"lalrpop-util",
"petgraph",
"pico-args",
"regex",
"regex-syntax",
"sha3",
"string_cache",
"term",
"tiny-keccak",
"unicode-xid",
"walkdir",
]
[[package]]
name = "lalrpop-util"
version = "0.22.0"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "feee752d43abd0f4807a921958ab4131f692a44d4d599733d4419c5d586176ce"
checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553"
dependencies = [
"regex-automata",
"rustversion",
]
[[package]]
@ -547,9 +510,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.161"
version = "0.2.158"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1"
checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
[[package]]
name = "libloading"
@ -561,6 +524,16 @@ dependencies = [
"windows-targets",
]
[[package]]
name = "libredox"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
dependencies = [
"bitflags",
"libc",
]
[[package]]
name = "linked-hash-map"
version = "0.5.6"
@ -621,9 +594,11 @@ dependencies = [
name = "nac3artiq"
version = "0.1.0"
dependencies = [
"itertools",
"inkwell",
"itertools 0.13.0",
"nac3core",
"nac3ld",
"nac3parser",
"parking_lot",
"pyo3",
"tempfile",
@ -644,11 +619,11 @@ name = "nac3core"
version = "0.1.0"
dependencies = [
"crossbeam",
"indexmap 2.6.0",
"indexmap 2.4.0",
"indoc",
"inkwell",
"insta",
"itertools",
"itertools 0.13.0",
"nac3parser",
"parking_lot",
"rayon",
@ -686,7 +661,9 @@ name = "nac3standalone"
version = "0.1.0"
dependencies = [
"clap",
"inkwell",
"nac3core",
"nac3parser",
"parking_lot",
]
@ -698,9 +675,9 @@ checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]]
name = "once_cell"
version = "1.20.2"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "parking_lot"
@ -732,7 +709,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [
"fixedbitset",
"indexmap 2.6.0",
"indexmap 2.4.0",
]
[[package]]
@ -775,7 +752,7 @@ dependencies = [
"phf_shared 0.11.2",
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.76",
]
[[package]]
@ -804,9 +781,9 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315"
[[package]]
name = "portable-atomic"
version = "1.9.0"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
[[package]]
name = "ppv-lite86"
@ -825,9 +802,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]]
name = "proc-macro2"
version = "1.0.88"
version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c3a7fc5db1e57d5a779a352c8cdb57b29aa4c40cc69c3a68a7fedc815fbf2f9"
checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
dependencies = [
"unicode-ident",
]
@ -879,7 +856,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.79",
"syn 2.0.76",
]
[[package]]
@ -892,7 +869,7 @@ dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.79",
"syn 2.0.76",
]
[[package]]
@ -956,18 +933,29 @@ dependencies = [
[[package]]
name = "redox_syscall"
version = "0.5.7"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
dependencies = [
"bitflags",
]
[[package]]
name = "regex"
version = "1.11.0"
name = "redox_users"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [
"getrandom",
"libredox",
"thiserror",
]
[[package]]
name = "regex"
version = "1.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
dependencies = [
"aho-corasick",
"memchr",
@ -977,9 +965,9 @@ dependencies = [
[[package]]
name = "regex-automata"
version = "0.4.8"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
dependencies = [
"aho-corasick",
"memchr",
@ -988,9 +976,9 @@ dependencies = [
[[package]]
name = "regex-syntax"
version = "0.8.5"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]]
name = "runkernel"
@ -1001,9 +989,9 @@ dependencies = [
[[package]]
name = "rustix"
version = "0.38.37"
version = "0.38.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811"
checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f"
dependencies = [
"bitflags",
"errno",
@ -1014,9 +1002,9 @@ dependencies = [
[[package]]
name = "rustversion"
version = "1.0.18"
version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248"
checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6"
[[package]]
name = "ryu"
@ -1047,29 +1035,29 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]]
name = "serde"
version = "1.0.210"
version = "1.0.209"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a"
checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.210"
version = "1.0.209"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.76",
]
[[package]]
name = "serde_json"
version = "1.0.129"
version = "1.0.127"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbcf9b78a125ee667ae19388837dd12294b858d101fdd393cb9d5501ef09eb2"
checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad"
dependencies = [
"itoa",
"memchr",
@ -1089,16 +1077,6 @@ dependencies = [
"yaml-rust",
]
[[package]]
name = "sha3"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60"
dependencies = [
"digest",
"keccak",
]
[[package]]
name = "shlex"
version = "1.3.0"
@ -1169,7 +1147,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.79",
"syn 2.0.76",
]
[[package]]
@ -1185,9 +1163,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.79"
version = "2.0.76"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525"
dependencies = [
"proc-macro2",
"quote",
@ -1202,9 +1180,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "tempfile"
version = "3.13.0"
version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b"
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
dependencies = [
"cfg-if",
"fastrand",
@ -1215,12 +1193,13 @@ dependencies = [
[[package]]
name = "term"
version = "1.0.0"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4df4175de05129f31b80458c6df371a15e7fc3fd367272e6bf938e5c351c7ea0"
checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f"
dependencies = [
"home",
"windows-sys 0.52.0",
"dirs-next",
"rustversion",
"winapi",
]
[[package]]
@ -1238,29 +1217,32 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.64"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.64"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.76",
]
[[package]]
name = "typenum"
version = "1.17.0"
name = "tiny-keccak"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237"
dependencies = [
"crunchy",
]
[[package]]
name = "unic-char-property"
@ -1316,27 +1298,27 @@ dependencies = [
[[package]]
name = "unicode-ident"
version = "1.0.13"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "unicode-width"
version = "0.1.14"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
[[package]]
name = "unicode-xid"
version = "0.2.6"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a"
[[package]]
name = "unicode_names2"
version = "1.3.0"
version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1673eca9782c84de5f81b82e4109dcfb3611c8ba0d52930ec4a9478f547b2dd"
checksum = "addeebf294df7922a1164f729fb27ebbbcea99cc32b3bf08afab62757f707677"
dependencies = [
"phf",
"unicode_names2_generator",
@ -1344,9 +1326,9 @@ dependencies = [
[[package]]
name = "unicode_names2_generator"
version = "1.3.0"
version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91e5b84611016120197efd7dc93ef76774f4e084cd73c9fb3ea4a86c570c56e"
checksum = "f444b8bba042fe3c1251ffaca35c603f2dc2ccc08d595c65a8c4f76f3e8426c0"
dependencies = [
"getopts",
"log",
@ -1388,6 +1370,22 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.9"
@ -1397,6 +1395,12 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.52.0"
@ -1506,5 +1510,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"syn 2.0.76",
]

View File

@ -2,11 +2,11 @@
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1727348695,
"narHash": "sha256-J+PeFKSDV+pHL7ukkfpVzCOO7mBSrrpJ3svwBFABbhI=",
"lastModified": 1723637854,
"narHash": "sha256-med8+5DSWa2UnOqtdICndjDAEjxr5D7zaIiK4pn0Q7c=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "1925c603f17fc89f4c8f6bf6f631a802ad85d784",
"rev": "c3aa7b8938b17aebd2deecf7be0636000d62a2b9",
"type": "github"
},
"original": {

View File

@ -12,10 +12,16 @@ crate-type = ["cdylib"]
itertools = "0.13"
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
parking_lot = "0.12"
tempfile = "3.13"
tempfile = "3.10"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" }
nac3ld = { path = "../nac3ld" }
[dependencies.inkwell]
version = "0.4"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[features]
init-llvm-profile = []
no-escape-analysis = ["nac3core/no-escape-analysis"]

View File

@ -112,15 +112,10 @@ def extern(function):
register_function(function)
return function
def rpc(arg=None, flags={}):
"""Decorates a function or method to be executed on the host interpreter."""
if arg is None:
def inner_decorator(function):
return rpc(function, flags)
return inner_decorator
register_function(arg)
return arg
def rpc(function):
"""Decorates a function declaration defined by the core device runtime."""
register_function(function)
return function
def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device."""

View File

@ -1,17 +1,3 @@
use std::{
collections::{hash_map::DefaultHasher, HashMap},
hash::{Hash, Hasher},
iter::once,
mem,
sync::Arc,
};
use itertools::Itertools;
use pyo3::{
types::{PyDict, PyList},
PyObject, PyResult, Python,
};
use nac3core::{
codegen::{
classes::{
@ -24,20 +10,37 @@ use nac3core::{
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
CodeGenContext, CodeGenerator,
},
inkwell::{
context::Context,
module::Linkage,
types::{BasicType, IntType},
values::{BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace, IntPredicate, OptimizationLevel,
},
nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef},
symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall},
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
};
use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
use inkwell::{
context::Context,
module::Linkage,
types::{BasicType, IntType},
values::{BasicValueEnum, PointerValue, StructValue},
AddressSpace, IntPredicate, OptimizationLevel,
};
use pyo3::{
types::{PyDict, PyList},
PyObject, PyResult, Python,
};
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use inkwell::values::IntValue;
use itertools::Itertools;
use std::{
collections::{hash_map::DefaultHasher, HashMap},
hash::{Hash, Hasher},
iter::once,
mem,
sync::Arc,
};
/// The parallelism mode within a block.
#[derive(Copy, Clone, Eq, PartialEq)]
@ -458,8 +461,8 @@ fn format_rpc_arg<'ctx>(
let llvm_usize = generator.get_size_type(ctx.ctx);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
let llvm_arg_ty =
NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty));
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let llvm_usize_sizeof = ctx
@ -469,7 +472,7 @@ fn format_rpc_arg<'ctx>(
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_arg_ty.element_type().ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
@ -512,7 +515,7 @@ fn format_rpc_arg<'ctx>(
ctx.builder.build_store(arg_slot, arg).unwrap();
ctx.builder
.build_bit_cast(arg_slot, llvm_pi8, "rpc.arg")
.build_bitcast(arg_slot, llvm_pi8, "rpc.arg")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
@ -627,7 +630,7 @@ fn format_rpc_ret<'ctx>(
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_ret_ty.element_type().size_of().unwrap(),
llvm_usize,
"",
)
@ -659,7 +662,7 @@ fn format_rpc_ret<'ctx>(
.unwrap();
let buffer = ctx
.builder
.build_bit_cast(buffer, llvm_pi8, "")
.build_bitcast(buffer, llvm_pi8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
@ -782,7 +785,7 @@ fn format_rpc_ret<'ctx>(
_ => {
let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
let slotgen = ctx.builder.build_bit_cast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
let slotgen = ctx.builder.build_bitcast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(head_bb);
@ -803,7 +806,7 @@ fn format_rpc_ret<'ctx>(
let alloc_ptr =
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr =
ctx.builder.build_bit_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
@ -821,7 +824,6 @@ fn rpc_codegen_callback_fn<'ctx>(
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
is_async: bool,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let int8 = ctx.ctx.i8_type();
let int32 = ctx.ctx.i32_type();
@ -930,64 +932,35 @@ fn rpc_codegen_callback_fn<'ctx>(
}
// call
if is_async {
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send_async",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false,
),
None,
)
});
ctx.builder
.build_call(
rpc_send_async,
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
"rpc.send",
)
.unwrap();
} else {
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false,
),
None,
)
});
ctx.builder
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
.unwrap();
}
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false,
),
None,
)
});
ctx.builder
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
.unwrap();
// reclaim stack space used by arguments
call_stackrestore(ctx, stackptr);
if is_async {
// async RPCs do not return any values
Ok(None)
} else {
let result = format_rpc_ret(generator, ctx, fun.0.ret);
let result = format_rpc_ret(generator, ctx, fun.0.ret);
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
// An RPC returning an NDArray would not touch here.
call_stackrestore(ctx, stackptr);
}
Ok(result)
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
// An RPC returning an NDArray would not touch here.
call_stackrestore(ctx, stackptr);
}
Ok(result)
}
pub fn attributes_writeback(
@ -1082,7 +1055,7 @@ pub fn attributes_writeback(
let args: Vec<_> =
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
if let Err(e) =
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator, false)
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator)
{
return Ok(Err(e));
}
@ -1092,9 +1065,9 @@ pub fn attributes_writeback(
Ok(())
}
pub fn rpc_codegen_callback(is_async: bool) -> Arc<GenCall> {
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
rpc_codegen_callback_fn(ctx, obj, fun, args, generator, is_async)
pub fn rpc_codegen_callback() -> Arc<GenCall> {
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| {
rpc_codegen_callback_fn(ctx, obj, fun, args, generator)
})))
}

View File

@ -16,65 +16,64 @@
clippy::wildcard_imports
)]
use std::{
collections::{HashMap, HashSet},
fs,
io::Write,
process::Command,
rc::Rc,
sync::Arc,
};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::io::Write;
use std::process::Command;
use std::rc::Rc;
use std::sync::Arc;
use itertools::Itertools;
use parking_lot::{Mutex, RwLock};
use pyo3::{
create_exception, exceptions,
prelude::*,
types::{PyBytes, PyDict, PySet},
use inkwell::{
context::Context,
memory_buffer::MemoryBuffer,
module::{Linkage, Module},
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel,
};
use tempfile::{self, TempDir};
use itertools::Itertools;
use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap};
use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program,
};
use pyo3::create_exception;
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use parking_lot::{Mutex, RwLock};
use nac3core::{
codegen::{
concrete_type::ConcreteTypeStore, gen_func_impl, irrt::load_irrt, CodeGenLLVMOptions,
CodeGenTargetMachineOptions, CodeGenTask, WithCall, WorkerRegistry,
},
inkwell::{
context::Context,
memory_buffer::MemoryBuffer,
module::{Linkage, Module},
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel,
},
nac3parser::{
ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
parser::parse_program,
},
codegen::irrt::load_irrt,
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
symbol_resolver::SymbolResolver,
toplevel::{
builtins::get_exn_constructor,
composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef,
},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
typecheck::typedef::{FunSignature, FuncArg},
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
};
use nac3ld::Linker;
use codegen::{
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
use crate::{
codegen::{
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
},
symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
};
use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver};
use timeline::TimeFns;
use tempfile::{self, TempDir};
mod codegen;
mod symbol_resolver;
mod timeline;
use timeline::TimeFns;
#[derive(PartialEq, Clone, Copy)]
enum Isa {
Host,
@ -195,8 +194,10 @@ impl Nac3 {
body.retain(|stmt| {
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
decorator_list.iter().any(|decorator| {
if let Some(id) = decorator_id_string(decorator) {
id == "kernel" || id == "portable" || id == "rpc"
if let ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "kernel"
|| id.to_string() == "portable"
|| id.to_string() == "rpc"
} else {
false
}
@ -209,8 +210,9 @@ impl Nac3 {
}
StmtKind::FunctionDef { ref decorator_list, .. } => {
decorator_list.iter().any(|decorator| {
if let Some(id) = decorator_id_string(decorator) {
id == "extern" || id == "kernel" || id == "portable" || id == "rpc"
if let ExprKind::Name { id, .. } = decorator.node {
let id = id.to_string();
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
} else {
false
}
@ -476,25 +478,9 @@ impl Nac3 {
match &stmt.node {
StmtKind::FunctionDef { decorator_list, .. } => {
if decorator_list
.iter()
.any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string()))
{
store_fun
.call1(
py,
(
def_id.0.into_py(py),
module.getattr(py, name.to_string().as_str()).unwrap(),
),
)
.unwrap();
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
rpc_ids.push((None, def_id, is_async));
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap();
rpc_ids.push((None, def_id));
}
}
StmtKind::ClassDef { name, body, .. } => {
@ -502,26 +488,19 @@ impl Nac3 {
let class_obj = module.getattr(py, class_name.as_str()).unwrap();
for stmt in body {
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
if decorator_list.iter().any(|decorator| {
decorator_id_string(decorator) == Some("rpc".to_string())
}) {
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
if name == &"__init__".into() {
return Err(CompileError::new_err(format!(
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
class_name, stmt.location
)));
}
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async));
rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
}
}
}
}
_ => (),
_ => ()
}
let id = *name_to_pyid.get(&name).unwrap();
@ -577,7 +556,7 @@ impl Nac3 {
.unwrap();
// Process IRRT
let context = Context::create();
let context = inkwell::context::Context::create();
let irrt = load_irrt(&context, resolver.as_ref());
let fun_signature =
@ -617,12 +596,13 @@ impl Nac3 {
let top_level = Arc::new(composer.make_top_level_context());
{
let rpc_codegen = rpc_codegen_callback();
let defs = top_level.definitions.read();
for (class_data, id, is_async) in &rpc_ids {
for (class_data, id) in &rpc_ids {
let mut def = defs[id.0].write();
match &mut *def {
TopLevelDef::Function { codegen_callback, .. } => {
*codegen_callback = Some(rpc_codegen_callback(*is_async));
*codegen_callback = Some(rpc_codegen.clone());
}
TopLevelDef::Class { methods, .. } => {
let (class_def, method_name) = class_data.as_ref().unwrap();
@ -633,7 +613,7 @@ impl Nac3 {
if let TopLevelDef::Function { codegen_callback, .. } =
&mut *defs[id.0].write()
{
*codegen_callback = Some(rpc_codegen_callback(*is_async));
*codegen_callback = Some(rpc_codegen.clone());
store_fun
.call1(
py,
@ -648,11 +628,6 @@ impl Nac3 {
}
}
}
TopLevelDef::Variable { .. } => {
return Err(CompileError::new_err(String::from(
"Unsupported @rpc annotation on global variable",
)))
}
}
}
}
@ -712,7 +687,7 @@ impl Nac3 {
let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer);
})));
let size_t = context
let size_t = Context::create()
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width();
let num_threads = if is_multithreaded() { 4 } else { 1 };
@ -731,7 +706,7 @@ impl Nac3 {
let mut generator =
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = Context::create();
let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback");
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
@ -869,41 +844,6 @@ impl Nac3 {
}
}
/// Retrieves the Name.id from a decorator, supports decorators with arguments.
fn decorator_id_string(decorator: &Located<ExprKind>) -> Option<String> {
if let ExprKind::Name { id, .. } = decorator.node {
// Bare decorator
return Some(id.to_string());
} else if let ExprKind::Call { func, .. } = &decorator.node {
// Decorators that are calls (e.g. "@rpc()") have Call for the node,
// need to extract the id from within.
if let ExprKind::Name { id, .. } = func.node {
return Some(id.to_string());
}
}
None
}
/// Retrieves flags from a decorator, if any.
fn decorator_get_flags(decorator: &Located<ExprKind>) -> Vec<Constant> {
let mut flags = vec![];
if let ExprKind::Call { keywords, .. } = &decorator.node {
for keyword in keywords {
if keyword.node.arg != Some("flags".into()) {
continue;
}
if let ExprKind::Set { elts } = &keyword.node.value.node {
for elt in elts {
if let ExprKind::Constant { value, .. } = &elt.node {
flags.push(value.clone());
}
}
}
}
}
flags
}
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
let linker_args = vec![
"-shared".to_string(),

View File

@ -1,30 +1,16 @@
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc,
},
use crate::PrimitivePythonId;
use inkwell::{
module::Linkage,
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
};
use itertools::Itertools;
use parking_lot::RwLock;
use pyo3::{
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use nac3core::{
codegen::{
classes::{NDArrayType, ProxyType},
CodeGenContext, CodeGenerator,
},
inkwell::{
module::Linkage,
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
},
nac3parser::ast::{self, StrRef},
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
@ -36,8 +22,19 @@ use nac3core::{
typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap},
},
};
use super::PrimitivePythonId;
use nac3parser::ast::{self, StrRef};
use parking_lot::RwLock;
use pyo3::{
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc,
},
};
pub enum PrimitiveValue {
I32(i32),
@ -1470,7 +1467,6 @@ impl SymbolResolver for Resolver {
&self,
id: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
let sym_value = {
let id_to_val = self.0.id_to_pyval.read();

View File

@ -1,12 +1,9 @@
use itertools::Either;
use nac3core::{
codegen::CodeGenContext,
inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
},
use inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
};
use itertools::Either;
use nac3core::codegen::CodeGenContext;
/// Functions for manipulating the timeline.
pub trait TimeFns {
@ -34,7 +31,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
@ -83,7 +80,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
@ -112,7 +109,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
@ -210,7 +207,7 @@ impl TimeFns for NowPinningTimeFns {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
@ -261,7 +258,7 @@ impl TimeFns for NowPinningTimeFns {
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();

View File

@ -10,17 +10,17 @@ no-escape-analysis = []
[dependencies]
itertools = "0.13"
crossbeam = "0.8"
indexmap = "2.6"
indexmap = "2.2"
parking_lot = "0.12"
rayon = "1.10"
rayon = "1.8"
nac3parser = { path = "../nac3parser" }
strum = "0.26"
strum_macros = "0.26"
[dependencies.inkwell]
version = "0.5"
version = "0.4"
default-features = false
features = ["llvm14-0-prefer-dynamic", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[dev-dependencies]
test-case = "1.2.0"

View File

@ -1,3 +1,4 @@
use regex::Regex;
use std::{
env,
fs::File,
@ -6,8 +7,6 @@ use std::{
process::{Command, Stdio},
};
use regex::Regex;
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir);

View File

@ -3,4 +3,11 @@
#include "irrt/list.hpp"
#include "irrt/math.hpp"
#include "irrt/ndarray.hpp"
#include "irrt/range.hpp"
#include "irrt/slice.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/ndarray/iter.hpp"
#include "irrt/ndarray/indexing.hpp"
#include "irrt/ndarray/array.hpp"
#include "irrt/ndarray/reshape.hpp"

View File

@ -55,11 +55,14 @@ void _raise_exception_helper(ExceptionId id,
int64_t param2) {
Exception<SizeT> e = {
.id = id,
.filename = {.base = reinterpret_cast<const uint8_t*>(filename), .len = __builtin_strlen(filename)},
.filename = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(filename)),
.len = static_cast<int32_t>(__builtin_strlen(filename))},
.line = line,
.column = 0,
.function = {.base = reinterpret_cast<const uint8_t*>(function), .len = __builtin_strlen(function)},
.msg = {.base = reinterpret_cast<const uint8_t*>(msg), .len = __builtin_strlen(msg)},
.function = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(function)),
.len = static_cast<int32_t>(__builtin_strlen(function))},
.msg = {.base = reinterpret_cast<uint8_t*>(const_cast<char*>(msg)),
.len = static_cast<int32_t>(__builtin_strlen(msg))},
};
e.params[0] = param0;
e.params[1] = param1;

View File

@ -1,22 +1,13 @@
#pragma once
#if __STDC_VERSION__ >= 202000
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#else
using int8_t = _ExtInt(8);
using uint8_t = unsigned _ExtInt(8);
using int32_t = _ExtInt(32);
using uint32_t = unsigned _ExtInt(32);
using int64_t = _ExtInt(64);
using uint64_t = unsigned _ExtInt(64);
#endif
// NDArray indices are always `uint32_t`.
using NDIndex = uint32_t;
using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;
using SliceIndex = int32_t;

View File

@ -2,6 +2,21 @@
#include "irrt/int_types.hpp"
#include "irrt/math_util.hpp"
#include "irrt/slice.hpp"
namespace {
/**
* @brief A list in NAC3.
*
* The `items` field is opaque. You must rely on external contexts to
* know how to interpret it.
*/
template<typename SizeT>
struct List {
uint8_t* items;
SizeT len;
};
} // namespace
extern "C" {
// Handle list assignment and dropping part of the list when

View File

@ -2,6 +2,8 @@
#include "irrt/int_types.hpp"
// TODO: To be deleted since NDArray with strides is done.
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
@ -17,7 +19,7 @@ SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, Size
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
@ -28,7 +30,10 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n
}
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims,
SizeT num_dims,
const NDIndexInt* indices,
SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
@ -75,8 +80,8 @@ void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
@ -94,21 +99,23 @@ __nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims,
uint64_t num_dims,
const NDIndexInt* indices,
uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
@ -130,15 +137,15 @@ void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
}

View File

@ -0,0 +1,134 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/list.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
namespace ndarray {
namespace array {
/**
* @brief In the context of `np.array(<list>)`, deduce the ndarray's shape produced by `<list>` and raise
* an exception if there is anything wrong with `<shape>` (e.g., inconsistent dimensions `np.array([[1.0, 2.0],
* [3.0]])`)
*
* If this function finds no issues with `<list>`, the deduced shape is written to `shape`. The caller has the
* responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because
* of implementation details.
*/
template<typename SizeT>
void set_and_validate_list_shape_helper(SizeT axis, List<SizeT>* list, SizeT ndims, SizeT* shape) {
if (shape[axis] == -1) {
// Dimension is unspecified. Set it.
shape[axis] = list->len;
} else {
// Dimension is specified. Check.
if (shape[axis] != list->len) {
// Mismatch, throw an error.
// NOTE: NumPy's error message is more complex and needs more PARAMS to display.
raise_exception(SizeT, EXN_VALUE_ERROR,
"The requested array has an inhomogenous shape "
"after {0} dimension(s).",
axis, shape[axis], list->len);
}
}
if (axis + 1 == ndims) {
// `list` has type `list[ItemType]`
// Do nothing
} else {
// `list` has type `list[list[...]]`
List<SizeT>** lists = (List<SizeT>**)(list->items);
for (SizeT i = 0; i < list->len; i++) {
set_and_validate_list_shape_helper<SizeT>(axis + 1, lists[i], ndims, shape);
}
}
}
/**
* @brief See `set_and_validate_list_shape_helper`.
*/
template<typename SizeT>
void set_and_validate_list_shape(List<SizeT>* list, SizeT ndims, SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
shape[axis] = -1; // Sentinel to say this dimension is unspecified.
}
set_and_validate_list_shape_helper<SizeT>(0, list, ndims, shape);
}
/**
* @brief In the context of `np.array(<list>)`, copied the contents stored in `list` to `ndarray`.
*
* `list` is assumed to be "legal". (i.e., no inconsistent dimensions)
*
* # Notes on `ndarray`
* The caller is responsible for allocating space for `ndarray`.
* Here is what this function expects from `ndarray` when called:
* - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values.
* - `ndarray->itemsize` has to be initialized.
* - `ndarray->ndims` has to be initialized.
* - `ndarray->shape` has to be initialized.
* - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous.
* When this function call ends:
* - `ndarray->data` is written with contents from `<list>`.
*/
template<typename SizeT>
void write_list_to_array_helper(SizeT axis, SizeT* index, List<SizeT>* list, NDArray<SizeT>* ndarray) {
debug_assert_eq(SizeT, list->len, ndarray->shape[axis]);
if (IRRT_DEBUG_ASSERT_BOOL) {
if (!ndarray::basic::is_c_contiguous(ndarray)) {
raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1],
NO_PARAM);
}
}
if (axis + 1 == ndarray->ndims) {
// `list` has type `list[scalar]`
// `ndarray` is contiguous, so we can do this, and this is fast.
uint8_t* dst = ndarray->data + (ndarray->itemsize * (*index));
__builtin_memcpy(dst, list->items, ndarray->itemsize * list->len);
*index += list->len;
} else {
// `list` has type `list[list[...]]`
List<SizeT>** lists = (List<SizeT>**)(list->items);
for (SizeT i = 0; i < list->len; i++) {
write_list_to_array_helper<SizeT>(axis + 1, index, lists[i], ndarray);
}
}
}
/**
* @brief See `write_list_to_array_helper`.
*/
template<typename SizeT>
void write_list_to_array(List<SizeT>* list, NDArray<SizeT>* ndarray) {
SizeT index = 0;
write_list_to_array_helper<SizeT>((SizeT)0, &index, list, ndarray);
}
} // namespace array
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::array;
void __nac3_ndarray_array_set_and_validate_list_shape(List<int32_t>* list, int32_t ndims, int32_t* shape) {
set_and_validate_list_shape(list, ndims, shape);
}
void __nac3_ndarray_array_set_and_validate_list_shape64(List<int64_t>* list, int64_t ndims, int64_t* shape) {
set_and_validate_list_shape(list, ndims, shape);
}
void __nac3_ndarray_array_write_list_to_array(List<int32_t>* list, NDArray<int32_t>* ndarray) {
write_list_to_array(list, ndarray);
}
void __nac3_ndarray_array_write_list_to_array64(List<int64_t>* list, NDArray<int64_t>* ndarray) {
write_list_to_array(list, ndarray);
}
}

View File

@ -0,0 +1,341 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
namespace ndarray {
namespace basic {
/**
* @brief Assert that `shape` does not contain negative dimensions.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape to check on
*/
template<typename SizeT>
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
if (shape[axis] < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"negative dimensions are not allowed; axis {0} "
"has dimension {1}",
axis, shape[axis], NO_PARAM);
}
}
}
/**
* @brief Assert that two shapes are the same in the context of writing output to an ndarray.
*/
template<typename SizeT>
void assert_output_shape_same(SizeT ndarray_ndims,
const SizeT* ndarray_shape,
SizeT output_ndims,
const SizeT* output_shape) {
if (ndarray_ndims != output_ndims) {
// There is no corresponding NumPy error message like this.
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
output_ndims, ndarray_ndims, NO_PARAM);
}
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
if (ndarray_shape[axis] != output_shape[axis]) {
// There is no corresponding NumPy error message like this.
raise_exception(SizeT, EXN_VALUE_ERROR,
"Mismatched dimensions on axis {0}, output has "
"dimension {1}, but destination ndarray has dimension {2}.",
axis, output_shape[axis], ndarray_shape[axis]);
}
}
}
/**
* @brief Return the number of elements of an ndarray given its shape.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape of the ndarray
*/
template<typename SizeT>
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
SizeT size = 1;
for (SizeT axis = 0; axis < ndims; axis++)
size *= shape[axis];
return size;
}
/**
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
*
* @param ndims Number of elements in `shape` and `indices`
* @param shape The shape of the ndarray
* @param indices The returned indices indexing the ndarray with shape `shape`.
* @param nth The index of the element of interest.
*/
template<typename SizeT>
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = ndims - i - 1;
SizeT dim = shape[axis];
indices[axis] = nth % dim;
nth /= dim;
}
}
/**
* @brief Return the number of elements of an `ndarray`
*
* This function corresponds to `<an_ndarray>.size`
*/
template<typename SizeT>
SizeT size(const NDArray<SizeT>* ndarray) {
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
}
/**
* @brief Return of the number of its content of an `ndarray`.
*
* This function corresponds to `<an_ndarray>.nbytes`.
*/
template<typename SizeT>
SizeT nbytes(const NDArray<SizeT>* ndarray) {
return size(ndarray) * ndarray->itemsize;
}
/**
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
*
* This function corresponds to `<an_ndarray>.__len__`.
*
* @param dst_length The length.
*/
template<typename SizeT>
SizeT len(const NDArray<SizeT>* ndarray) {
// numpy prohibits `__len__` on unsized objects
if (ndarray->ndims == 0) {
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
} else {
return ndarray->shape[0];
}
}
/**
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
*
* You may want to see ndarray's rules for C-contiguity:
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
*/
template<typename SizeT>
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
// References:
// - tinynumpy's implementation:
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
// - ndarray's flags["C_CONTIGUOUS"]:
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
// - ndarray's rules for C-contiguity:
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
// From
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
//
// The traditional rule is that for an array to be flagged as C contiguous,
// the following must hold:
//
// strides[-1] == itemsize
// strides[i] == shape[i+1] * strides[i + 1]
// [...]
// According to these rules, a 0- or 1-dimensional array is either both
// C- and F-contiguous, or neither; and an array with 2+ dimensions
// can be C- or F- contiguous, or neither, but not both. Though there
// there are exceptions for arrays with zero or one item, in the first
// case the check is relaxed up to and including the first dimension
// with shape[i] == 0. In the second case `strides == itemsize` will
// can be true for all dimensions and both flags are set.
if (ndarray->ndims == 0) {
return true;
}
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
return false;
}
for (SizeT i = 1; i < ndarray->ndims; i++) {
SizeT axis_i = ndarray->ndims - i - 1;
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
return false;
}
}
return true;
}
/**
* @brief Return the pointer to the element indexed by `indices` along the ndarray's axes.
*
* This function does no bound check.
*/
template<typename SizeT>
uint8_t* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
uint8_t* element = ndarray->data;
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
element += indices[dim_i] * ndarray->strides[dim_i];
return element;
}
/**
* @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view.
*
* This function does no bound check.
*/
template<typename SizeT>
uint8_t* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
uint8_t* element = ndarray->data;
for (SizeT i = 0; i < ndarray->ndims; i++) {
SizeT axis = ndarray->ndims - i - 1;
SizeT dim = ndarray->shape[axis];
element += ndarray->strides[axis] * (nth % dim);
nth /= dim;
}
return element;
}
/**
* @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous.
*
* You might want to read https://ajcr.net/stride-guide-part-1/.
*/
template<typename SizeT>
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndarray->ndims; i++) {
SizeT axis = ndarray->ndims - i - 1;
ndarray->strides[axis] = stride_product * ndarray->itemsize;
stride_product *= ndarray->shape[axis];
}
}
/**
* @brief Set an element in `ndarray`.
*
* @param pelement Pointer to the element in `ndarray` to be set.
* @param pvalue Pointer to the value `pelement` will be set to.
*/
template<typename SizeT>
void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement, const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
}
/**
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
*
* Both ndarrays will be viewed in their flatten views when copying the elements.
*/
template<typename SizeT>
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// TODO: Make this faster with memcpy when we see a contiguous segment.
// TODO: Handle overlapping.
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
for (SizeT i = 0; i < size(src_ndarray); i++) {
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
}
}
} // namespace basic
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::basic;
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
const int32_t* ndarray_shape,
int32_t output_ndims,
const int32_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
}
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
const int64_t* ndarray_shape,
int64_t output_ndims,
const int64_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
}
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
return size(ndarray);
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
return size(ndarray);
}
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
return nbytes(ndarray);
}
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
return nbytes(ndarray);
}
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
return len(ndarray);
}
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
return len(ndarray);
}
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
return is_c_contiguous(ndarray);
}
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
return is_c_contiguous(ndarray);
}
uint8_t* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
return get_nth_pelement(ndarray, nth);
}
uint8_t* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
return get_nth_pelement(ndarray, nth);
}
uint8_t* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
return get_pelement_by_indices(ndarray, indices);
}
uint8_t* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
return get_pelement_by_indices(ndarray, indices);
}
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
}

View File

@ -0,0 +1,45 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
/**
* @brief The NDArray object
*
* Official numpy implementation:
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
*/
template<typename SizeT>
struct NDArray {
/**
* @brief The underlying data this `ndarray` is pointing to.
*/
uint8_t* data;
/**
* @brief The number of bytes of a single element in `data`.
*/
SizeT itemsize;
/**
* @brief The number of dimensions of this shape.
*/
SizeT ndims;
/**
* @brief The NDArray shape, with length equal to `ndims`.
*
* Note that it may contain 0.
*/
SizeT* shape;
/**
* @brief Array strides, with length equal to `ndims`
*
* The stride values are in units of bytes, not number of elements.
*
* Note that `strides` can have negative values or contain 0.
*/
SizeT* strides;
};
} // namespace

View File

@ -0,0 +1,220 @@
#pragma once
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/range.hpp"
#include "irrt/slice.hpp"
namespace {
typedef uint8_t NDIndexType;
/**
* @brief A single element index
*
* `data` points to a `int32_t`.
*/
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
/**
* @brief A slice index
*
* `data` points to a `Slice<int32_t>`.
*/
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
/**
* @brief `np.newaxis` / `None`
*
* `data` is unused.
*/
const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2;
/**
* @brief `Ellipsis` / `...`
*
* `data` is unused.
*/
const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3;
/**
* @brief An index used in ndarray indexing
*
* That is:
* ```
* my_ndarray[::-1, 3, ..., np.newaxis]
* ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex.
* ```
*/
struct NDIndex {
/**
* @brief Enum tag to specify the type of index.
*
* Please see the comment of each enum constant.
*/
NDIndexType type;
/**
* @brief The accompanying data associated with `type`.
*
* Please see the comment of each enum constant.
*/
uint8_t* data;
};
} // namespace
namespace {
namespace ndarray {
namespace indexing {
/**
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
*
* This function is very similar to performing `dst_ndarray = src_ndarray[indices]` in Python.
*
* This function also does proper assertions on `indices` to check for out of bounds access and more.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after
* indexing `src_ndarray` with `indices`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data`.
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`.
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed.
* - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works.
*
* @param indices indices to index `src_ndarray`, ordered in the same way you would write them in Python.
* @param src_ndarray The NDArray to be indexed.
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
*/
template<typename SizeT>
void index(SizeT num_indices, const NDIndex* indices, const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// Validate `indices`.
// Expected value of `dst_ndarray->ndims`.
SizeT expected_dst_ndims = src_ndarray->ndims;
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
SizeT num_indexed = 0;
// There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis.
SizeT num_ellipsis = 0;
for (SizeT i = 0; i < num_indices; i++) {
if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
expected_dst_ndims--;
num_indexed++;
} else if (indices[i].type == ND_INDEX_TYPE_SLICE) {
num_indexed++;
} else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS) {
expected_dst_ndims++;
} else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS) {
num_ellipsis++;
if (num_ellipsis > 1) {
raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM,
NO_PARAM, NO_PARAM);
}
} else {
__builtin_unreachable();
}
}
debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims);
if (src_ndarray->ndims - num_indexed < 0) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"too many indices for array: array is {0}-dimensional, "
"but {1} were indexed",
src_ndarray->ndims, num_indices, NO_PARAM);
}
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// Reference code:
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
SizeT src_axis = 0;
SizeT dst_axis = 0;
for (int32_t i = 0; i < num_indices; i++) {
const NDIndex* index = &indices[i];
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
SizeT input = (SizeT) * ((int32_t*)index->data);
SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
if (k == -1) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"index {0} is out of bounds for axis {1} "
"with size {2}",
input, src_axis, src_ndarray->shape[src_axis]);
}
dst_ndarray->data += k * src_ndarray->strides[src_axis];
src_axis++;
} else if (index->type == ND_INDEX_TYPE_SLICE) {
Slice<int32_t>* slice = (Slice<int32_t>*)index->data;
Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis];
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = (SizeT)range.len<SizeT>();
dst_axis++;
src_axis++;
} else if (index->type == ND_INDEX_TYPE_NEWAXIS) {
dst_ndarray->strides[dst_axis] = 0;
dst_ndarray->shape[dst_axis] = 1;
dst_axis++;
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
// The number of ':' entries this '...' implies.
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
for (SizeT j = 0; j < ellipsis_size; j++) {
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_axis++;
src_axis++;
}
} else {
__builtin_unreachable();
}
}
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
debug_assert_eq(SizeT, src_ndarray->ndims, src_axis);
debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis);
}
} // namespace indexing
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::indexing;
void __nac3_ndarray_index(int32_t num_indices,
NDIndex* indices,
NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
index(num_indices, indices, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_index64(int64_t num_indices,
NDIndex* indices,
NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
index(num_indices, indices, src_ndarray, dst_ndarray);
}
}

View File

@ -0,0 +1,146 @@
#pragma once
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
/**
* @brief Helper struct to enumerate through an ndarray *efficiently*.
*
* Example usage (in pseudo-code):
* ```
* // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double`
* NDIter nditer;
* nditer.initialize(my_ndarray);
* while (nditer.has_element()) {
* // This body is run 6 (= my_ndarray.size) times.
*
* // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end
* print(nditer.indices);
*
* // 0 -> 1 -> 2 -> 3 -> 4 -> 5
* print(nditer.nth);
*
* // <1st element> -> <2nd element> -> ... -> <6th element> -> end
* print(*((double *) nditer.element))
*
* nditer.next(); // Go to next element.
* }
* ```
*
* Interesting cases:
* - If `my_ndarray.ndims` == 0, there is one iteration.
* - If `my_ndarray.shape` contains zeroes, there are no iterations.
*/
template<typename SizeT>
struct NDIter {
// Information about the ndarray being iterated over.
SizeT ndims;
SizeT* shape;
SizeT* strides;
/**
* @brief The current indices.
*
* Must be allocated by the caller.
*/
SizeT* indices;
/**
* @brief The nth (0-based) index of the current indices.
*
* Initially this is 0.
*/
SizeT nth;
/**
* @brief Pointer to the current element.
*
* Initially this points to first element of the ndarray.
*/
uint8_t* element;
/**
* @brief Cache for the product of shape.
*
* Could be 0 if `shape` has 0s in it.
*/
SizeT size;
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element, SizeT* indices) {
this->ndims = ndims;
this->shape = shape;
this->strides = strides;
this->indices = indices;
this->element = element;
// Compute size
this->size = 1;
for (SizeT i = 0; i < ndims; i++) {
this->size *= shape[i];
}
// `indices` starts on all 0s.
for (SizeT axis = 0; axis < ndims; axis++)
indices[axis] = 0;
nth = 0;
}
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
// NOTE: ndarray->data is pointing to the first element, and `NDIter`'s `element` should also point to the first
// element as well.
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices);
}
// Is the current iteration valid?
// If true, then `element`, `indices` and `nth` contain details about the current element.
bool has_element() { return nth < size; }
// Go to the next element.
void next() {
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = ndims - i - 1;
indices[axis]++;
if (indices[axis] >= shape[axis]) {
indices[axis] = 0;
// TODO: There is something called backstrides to speedup iteration.
// See https://ajcr.net/stride-guide-part-1/, and
// https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
element -= strides[axis] * (shape[axis] - 1);
} else {
element += strides[axis];
break;
}
}
nth++;
}
};
} // namespace
extern "C" {
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray, int32_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
void __nac3_nditer_initialize64(NDIter<int64_t>* iter, NDArray<int64_t>* ndarray, int64_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
bool __nac3_nditer_has_element(NDIter<int32_t>* iter) {
return iter->has_element();
}
bool __nac3_nditer_has_element64(NDIter<int64_t>* iter) {
return iter->has_element();
}
void __nac3_nditer_next(NDIter<int32_t>* iter) {
iter->next();
}
void __nac3_nditer_next64(NDIter<int64_t>* iter) {
iter->next();
}
}

View File

@ -0,0 +1,99 @@
#pragma once
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
namespace ndarray {
namespace reshape {
/**
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
*
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
* modified to contain the resolved dimension.
*
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
*
* @param size The `.size` of `<ndarray>`
* @param new_ndims Number of elements in `new_shape`
* @param new_shape Target shape to reshape to
*/
template<typename SizeT>
void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) {
// Is there a -1 in `new_shape`?
bool neg1_exists = false;
// Location of -1, only initialized if `neg1_exists` is true
SizeT neg1_axis_i;
// The computed ndarray size of `new_shape`
SizeT new_size = 1;
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
SizeT dim = new_shape[axis_i];
if (dim < 0) {
if (dim == -1) {
if (neg1_exists) {
// Multiple `-1` found. Throw an error.
raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM,
NO_PARAM, NO_PARAM);
} else {
neg1_exists = true;
neg1_axis_i = axis_i;
}
} else {
// TODO: What? In `np.reshape` any negative dimensions is
// treated like its `-1`.
//
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
//
// It is not documented by numpy.
// Throw an error for now...
raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i,
NO_PARAM);
}
} else {
new_size *= dim;
}
}
bool can_reshape;
if (neg1_exists) {
// Let `x` be the unknown dimension
// Solve `x * <new_size> = <size>`
if (new_size == 0 && size == 0) {
// `x` has infinitely many solutions
can_reshape = false;
} else if (new_size == 0 && size != 0) {
// `x` has no solutions
can_reshape = false;
} else if (size % new_size != 0) {
// `x` has no integer solutions
can_reshape = false;
} else {
can_reshape = true;
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
}
} else {
can_reshape = (new_size == size);
}
if (!can_reshape) {
raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM,
NO_PARAM);
}
}
} // namespace reshape
} // namespace ndarray
} // namespace
extern "C" {
void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
}

View File

@ -0,0 +1,47 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/int_types.hpp"
namespace {
namespace range {
template<typename T>
T len(T start, T stop, T step) {
// Reference:
// https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
if (step > 0 && start < stop)
return 1 + (stop - 1 - start) / step;
else if (step < 0 && start > stop)
return 1 + (start - 1 - stop) / (-step);
else
return 0;
}
} // namespace range
/**
* @brief A Python range.
*/
template<typename T>
struct Range {
T start;
T stop;
T step;
/**
* @brief Calculate the `len()` of this range.
*/
template<typename SizeT>
T len() {
debug_assert(SizeT, step != 0);
return range::len(start, stop, step);
}
};
} // namespace
extern "C" {
using namespace range;
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
return len(start, end, step);
}
}

View File

@ -1,6 +1,145 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/math_util.hpp"
#include "irrt/range.hpp"
namespace {
namespace slice {
/**
* @brief Resolve a possibly negative index in a list of a known length.
*
* Returns -1 if the resolved index is out of the list's bounds.
*/
template<typename T>
T resolve_index_in_length(T length, T index) {
T resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length) {
return resolved;
} else {
return -1;
}
}
/**
* @brief Resolve a slice as a range.
*
* This is equivalent to `range(*slice(start, stop, step).indices(length))` in Python.
*/
template<typename T>
void indices(bool start_defined,
T start,
bool stop_defined,
T stop,
bool step_defined,
T step,
T length,
T* range_start,
T* range_stop,
T* range_step) {
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
*range_step = step_defined ? step : 1;
bool step_is_negative = *range_step < 0;
T lower, upper;
if (step_is_negative) {
lower = -1;
upper = length - 1;
} else {
lower = 0;
upper = length;
}
if (start_defined) {
*range_start = start < 0 ? max(lower, start + length) : min(upper, start);
} else {
*range_start = step_is_negative ? upper : lower;
}
if (stop_defined) {
*range_stop = stop < 0 ? max(lower, stop + length) : min(upper, stop);
} else {
*range_stop = step_is_negative ? lower : upper;
}
}
} // namespace slice
/**
* @brief A Python-like slice with **unresolved** indices.
*/
template<typename T>
struct Slice {
bool start_defined;
T start;
bool stop_defined;
T stop;
bool step_defined;
T step;
Slice() { this->reset(); }
void reset() {
this->start_defined = false;
this->stop_defined = false;
this->step_defined = false;
}
void set_start(T start) {
this->start_defined = true;
this->start = start;
}
void set_stop(T stop) {
this->stop_defined = true;
this->stop = stop;
}
void set_step(T step) {
this->step_defined = true;
this->step = step;
}
/**
* @brief Resolve this slice as a range.
*
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
*/
template<typename SizeT>
Range<T> indices(T length) {
// Reference:
// https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
debug_assert(SizeT, length >= 0);
Range<T> result;
slice::indices(start_defined, start, stop_defined, stop, step_defined, step, length, &result.start,
&result.stop, &result.step);
return result;
}
/**
* @brief Like `.indices()` but with assertions.
*/
template<typename SizeT>
Range<T> indices_checked(T length) {
// TODO: Switch to `SizeT length`
if (length < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM,
NO_PARAM);
}
if (this->step_defined && this->step == 0) {
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM);
}
return this->indices<SizeT>(length);
}
};
} // namespace
extern "C" {
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
@ -14,15 +153,4 @@ SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
}
return i;
}
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
}

View File

@ -1,29 +1,26 @@
use inkwell::{
types::BasicTypeEnum,
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
FloatPredicate, IntPredicate, OptimizationLevel,
};
use inkwell::types::BasicTypeEnum;
use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue};
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools;
use super::{
classes::{
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
},
expr::destructure_range,
extern_fns, irrt,
irrt::calculate_len_for_slice_range,
llvm_intrinsics,
macros::codegen_unreachable,
numpy,
numpy::ndarray_elementwise_unaryop_impl,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use crate::{
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys},
typecheck::typedef::{Type, TypeEnum},
use crate::codegen::classes::{
NDArrayValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::expr::destructure_range;
use crate::codegen::irrt::calculate_len_for_slice_range;
use crate::codegen::macros::codegen_unreachable;
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{Type, TypeEnum};
use super::model::*;
use super::object::any::AnyObject;
use super::object::list::ListObject;
use super::object::ndarray::NDArrayObject;
use super::object::tuple::TupleObject;
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
///
@ -42,58 +39,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let range_ty = ctx.primitives.range;
let (arg_ty, arg) = n;
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
Ok(if ctx.unifier.unioned(arg_ty, ctx.primitives.range) {
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
let (start, end, step) = destructure_range(ctx, arg);
calculate_len_for_slice_range(generator, ctx, start, end, step)
} else {
match &*ctx.unifier.get_ty_immutable(arg_ty) {
TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false),
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
let zero = llvm_i32.const_zero();
let len = ctx
.build_gep_and_load(
arg.into_pointer_value(),
&[zero, llvm_i32.const_int(1, false)],
None,
)
.into_int_value();
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
let arg = AnyObject { ty: arg_ty, value: arg };
let len: Instance<'ctx, Int<Int32>> = match &*ctx.unifier.get_ty(arg_ty) {
TypeEnum::TTuple { .. } => {
let tuple = TupleObject::from_object(ctx, arg);
tuple.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32)
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_usize = generator.get_size_type(ctx.ctx);
let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "")
.unwrap(),
"0:TypeError",
"len() of unsized object",
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_object(generator, ctx, arg);
ndarray.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32)
}
_ => codegen_unreachable!(ctx),
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let list = ListObject::from_object(generator, ctx, arg);
list.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32)
}
_ => unsupported_type(ctx, "len", &[arg_ty]),
};
len.value
})
}

View File

@ -1,16 +1,17 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, ArrayType, BasicType, BasicTypeEnum, IntType, PointerType, StructType},
values::{ArrayValue, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace, IntPredicate,
};
use super::{
use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use inkwell::context::Context;
use inkwell::types::{ArrayType, BasicType, StructType};
use inkwell::values::{ArrayValue, BasicValue, StructValue};
use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
/// A LLVM type that is used to represent a non-primitive type in NAC3.
pub trait ProxyType<'ctx>: Into<Self::Base> {
@ -1249,13 +1250,11 @@ impl<'ctx> NDArrayType<'ctx> {
/// Returns the element type of this `ndarray` type.
#[must_use]
pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(2)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap()
}
}

View File

@ -1,9 +1,3 @@
use std::collections::HashMap;
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use crate::{
symbol_resolver::SymbolValue,
toplevel::DefinitionId,
@ -15,6 +9,10 @@ use crate::{
},
};
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use std::collections::HashMap;
pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>,
}

View File

@ -1,10 +1,31 @@
use std::{
cmp::min,
collections::HashMap,
convert::TryInto,
iter::{once, repeat, repeat_with, zip},
use crate::{
codegen::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType,
ProxyValue, RangeValue, UntypedArrayLikeAccessor,
},
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
irrt::*,
llvm_intrinsics::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_int_umin, call_memcpy_generic,
},
macros::codegen_unreachable,
need_sret, numpy,
stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var,
},
CodeGenContext, CodeGenTask, CodeGenerator,
},
symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
typecheck::{
magic_methods::{Binop, BinopVariant, HasOpInfo},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
},
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
types::{AnyType, BasicType, BasicTypeEnum},
@ -12,43 +33,17 @@ use inkwell::{
AddressSpace, IntPredicate, OptimizationLevel,
};
use itertools::{chain, izip, Either, Itertools};
use nac3parser::ast::{
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
Unaryop,
};
use std::cmp::min;
use std::iter::{repeat, repeat_with};
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, ProxyValue,
RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
},
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
irrt::*,
llvm_intrinsics::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_int_umin, call_memcpy_generic,
},
macros::codegen_unreachable,
need_sret, numpy,
stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var,
},
CodeGenContext, CodeGenTask, CodeGenerator,
};
use crate::{
symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
typecheck::{
magic_methods::{Binop, BinopVariant, HasOpInfo},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
},
use super::object::{
any::AnyObject,
ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject},
};
pub fn get_subst_key(
@ -557,7 +552,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
&& val_ty.get_element_type().is_struct_type()
} =>
{
self.builder.build_bit_cast(*val, arg_ty, "call_arg_cast").unwrap()
self.builder.build_bitcast(*val, arg_ty, "call_arg_cast").unwrap()
}
_ => *val,
})
@ -977,7 +972,6 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
TopLevelDef::Class { .. } => {
return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?))
}
TopLevelDef::Variable { .. } => unreachable!(),
}
}
.or_else(|_: String| {
@ -2505,338 +2499,6 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
)
}
/// Generates code for a subscript expression on an `ndarray`.
///
/// * `ty` - The `Type` of the `NDArray` elements.
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `v` - The `NDArray` value.
/// * `slice` - The slice expression used to subscript into the `ndarray`.
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
ndims: Type,
v: NDArrayValue<'ctx>,
slice: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
codegen_unreachable!(ctx)
};
let ndims = values
.iter()
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
.collect::<Result<Vec<_>, _>>()
.map_err(|val| {
format!(
"Expected non-negative literal for ndarray.ndims, got {}",
i128::try_from(val).unwrap()
)
})?;
assert!(!ndims.is_empty());
// The number of dimensions subscripted by the index expression.
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
// dimension will remove a dimension.
let subscripted_dims = match &slice.node {
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
if let ExprKind::Slice { .. } = &value_subexpr.node {
acc
} else {
acc + 1
}
}),
ExprKind::Slice { .. } => 0,
_ => 1,
};
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None,
);
let ndarray_ty =
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap();
// Check that len is non-zero
let len = v.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(),
"0:IndexError",
"too many indices for array: array is {0}-dimensional but 1 were indexed",
[Some(len), None, None],
slice.location,
);
// Normalizes a possibly-negative index to its corresponding positive index
let normalize_index = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
dim: u64| {
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "")
.unwrap())
},
|_, _| Ok(Some(index)),
|generator, ctx| {
let llvm_i32 = ctx.ctx.i32_type();
let len = unsafe {
v.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, true),
None,
)
};
let index = ctx
.builder
.build_int_add(
len,
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(),
"",
)
.unwrap();
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
},
)
.map(|v| v.map(BasicValueEnum::into_int_value))
};
// Converts a slice expression into a slice-range tuple
let expr_to_slice = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
node: &ExprKind<Option<Type>>,
dim: u64| {
match node {
ExprKind::Constant { value: Constant::Int(v), .. } => {
let Some(index) =
normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)?
else {
return Ok(None);
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
ExprKind::Slice { lower, upper, step } => {
let dim_sz = unsafe {
v.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, false),
None,
)
};
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz)
}
_ => {
let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) };
let index = index
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
return Ok(None);
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
}
};
let make_indices_arr = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>|
-> Result<_, String> {
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(elts.len() as u64, false),
None,
)?;
for (i, elt) in elts.iter().enumerate() {
let Some(index) = generator.gen_expr(ctx, elt)? else {
return Ok(None);
};
let index = index
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
return Ok(None);
};
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
None,
)
};
ctx.builder.build_store(store_ptr, index).unwrap();
}
Some(index_addr)
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(1u64, false),
None,
)?;
let index =
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder.build_store(store_ptr, index).unwrap();
Some(index_addr)
} else {
None
})
};
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
v.data().get(ctx, generator, &index_addr, None).into()
} else {
match &slice.node {
ExprKind::Tuple { elts, .. } => {
let slices = elts
.iter()
.enumerate()
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
.collect::<Result<Vec<_>, _>>()?;
if slices.len() < elts.len() {
return Ok(None);
}
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
}
ExprKind::Slice { .. } => {
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
return Ok(None);
};
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
}
_ => {
// Accessing an element from a multi-dimensional `ndarray`
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray =
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
.unwrap(),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ctx
.builder
.build_int_z_extend_or_bit_cast(
ndarray.load_ndims(ctx),
llvm_usize.size_of().get_type(),
"",
)
.unwrap();
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
let ndarray_num_elems = ctx
.builder
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
.unwrap();
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr,
ctx.builder
.build_int_mul(
ndarray_num_elems,
llvm_ndarray_data_t.size_of().unwrap(),
"",
)
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
ndarray.as_base_value().into()
}
}
}))
}
/// See [`CodeGenerator::gen_expr`].
pub fn gen_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
@ -2886,31 +2548,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()),
None => {
let resolver = ctx.resolver.clone();
let value = resolver.get_symbol_value(*id, ctx, generator).unwrap();
let globals = ctx
.top_level
.definitions
.read()
.iter()
.filter_map(|def| {
if let TopLevelDef::Variable { simple_name, ty, .. } = &*def.read() {
Some((*simple_name, *ty))
} else {
None
}
})
.collect_vec();
if let Some((_, ty)) = globals.iter().find(|(name, _)| name == id) {
let ptr = value
.to_basic_value_enum(ctx, generator, *ty)
.map(BasicValueEnum::into_pointer_value)?;
ctx.builder.build_load(ptr, id.to_string().as_str()).map(Into::into).unwrap()
} else {
value
}
resolver.get_symbol_value(*id, ctx).unwrap()
}
},
ExprKind::List { elts, .. } => {
@ -3493,18 +3131,26 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
v.data().get(ctx, generator, &index, None).into()
}
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
.into_pointer_value()
} else {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let Some(ndarray) = generator.gen_expr(ctx, value)? else {
return Ok(None);
};
let v = NDArrayValue::from_ptr_val(v, usize, None);
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
let ndarray_ty = value.custom.unwrap();
let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = NDArrayObject::from_object(
generator,
ctx,
AnyObject { ty: ndarray_ty, value: ndarray },
);
let indices = gen_ndarray_subscript_ndindices(generator, ctx, slice)?;
let result = ndarray
.index(generator, ctx, &indices)
.split_unsized(generator, ctx)
.to_basic_value_enum();
return Ok(Some(ValueEnum::Dynamic(result)));
}
TypeEnum::TTuple { .. } => {
let index: u32 =

View File

@ -1,10 +1,8 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
};
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
use itertools::Either;
use super::CodeGenContext;
use crate::codegen::CodeGenContext;
/// Macro to generate extern function
/// Both function return type and function parameter type are `FloatValue`

View File

@ -1,18 +1,16 @@
use crate::{
codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
use inkwell::{
context::Context,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
};
use nac3parser::ast::{Expr, Stmt, StrRef};
use super::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext};
use crate::{
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
pub trait CodeGenerator {
/// Return the module name for the code generator.
fn get_name(&self) -> &str;

View File

@ -1,3 +1,21 @@
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics,
macros::codegen_unreachable,
model::*,
object::{
list::List,
ndarray::{indexing::NDIndex, nditer::NDIter, NDArray},
},
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use function::FnCall;
use inkwell::{
attributes::{Attribute, AttributeLoc},
context::Context,
@ -8,21 +26,8 @@ use inkwell::{
AddressSpace, IntPredicate,
};
use itertools::Either;
use nac3parser::ast::Expr;
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
#[must_use]
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
let bitcode_buf = MemoryBuffer::create_from_memory_range(
@ -950,3 +955,224 @@ pub fn call_ndarray_calc_broadcast_index<
Box::new(|_, v| v.into()),
)
}
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
#[must_use]
pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'_, '_>,
name: &str,
) -> String {
let mut name = name.to_owned();
match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => {}
64 => name.push_str("64"),
bit_width => {
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
}
}
name
}
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: Instance<'ctx, Int<SizeT>>,
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_util_assert_shape_no_negative",
);
FnCall::builder(generator, ctx, &name).arg(ndims).arg(shape).returning_void();
}
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ndims: Instance<'ctx, Int<SizeT>>,
ndarray_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
output_ndims: Instance<'ctx, Int<SizeT>>,
output_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_util_assert_output_shape_same",
);
FnCall::builder(generator, ctx, &name)
.arg(ndarray_ndims)
.arg(ndarray_shape)
.arg(output_ndims)
.arg(output_shape)
.returning_void();
}
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) -> Instance<'ctx, Int<SizeT>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("size")
}
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) -> Instance<'ctx, Int<SizeT>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("nbytes")
}
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) -> Instance<'ctx, Int<SizeT>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("len")
}
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) -> Instance<'ctx, Int<Bool>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("is_c_contiguous")
}
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
index: Instance<'ctx, Int<SizeT>>,
) -> Instance<'ctx, Ptr<Int<Byte>>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
FnCall::builder(generator, ctx, &name).arg(ndarray).arg(index).returning_auto("pelement")
}
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
) -> Instance<'ctx, Ptr<Int<Byte>>> {
let name =
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
FnCall::builder(generator, ctx, &name).arg(ndarray).arg(indices).returning_auto("pelement")
}
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) {
let name =
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
FnCall::builder(generator, ctx, &name).arg(ndarray).returning_void();
}
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void();
}
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
FnCall::builder(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void();
}
pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
) -> Instance<'ctx, Int<Bool>> {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_element");
FnCall::builder(generator, ctx, &name).arg(iter).returning_auto("has_element")
}
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
iter: Instance<'ctx, Ptr<Struct<NDIter>>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
FnCall::builder(generator, ctx, &name).arg(iter).returning_void();
}
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
num_indices: Instance<'ctx, Int<SizeT>>,
indices: Instance<'ctx, Ptr<Struct<NDIndex>>>,
src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_index");
FnCall::builder(generator, ctx, &name)
.arg(num_indices)
.arg(indices)
.arg(src_ndarray)
.arg(dst_ndarray)
.returning_void();
}
pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>>,
ndims: Instance<'ctx, Int<SizeT>>,
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_array_set_and_validate_list_shape",
);
FnCall::builder(generator, ctx, &name).arg(list).arg(ndims).arg(shape).returning_void();
}
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>>,
ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_array_write_list_to_array",
);
FnCall::builder(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
}
pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: Instance<'ctx, Int<SizeT>>,
new_ndims: Instance<'ctx, Int<SizeT>>,
new_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_reshape_resolve_and_check_new_shape",
);
FnCall::builder(generator, ctx, &name).arg(size).arg(new_ndims).arg(new_shape).returning_void();
}

View File

@ -1,14 +1,12 @@
use inkwell::{
context::Context,
intrinsics::Intrinsic,
types::{AnyTypeEnum::IntType, FloatType},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
AddressSpace,
};
use crate::codegen::CodeGenContext;
use inkwell::context::Context;
use inkwell::intrinsics::Intrinsic;
use inkwell::types::AnyTypeEnum::IntType;
use inkwell::types::FloatType;
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use inkwell::AddressSpace;
use itertools::Either;
use super::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
@ -185,7 +183,7 @@ pub fn call_memcpy_generic<'ctx>(
dest
} else {
ctx.builder
.build_bit_cast(dest, llvm_p0i8, "")
.build_bitcast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
@ -193,7 +191,7 @@ pub fn call_memcpy_generic<'ctx>(
src
} else {
ctx.builder
.build_bit_cast(src, llvm_p0i8, "")
.build_bitcast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};

View File

@ -1,12 +1,12 @@
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
use crate::{
codegen::classes::{ListType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
},
thread,
};
use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{
attributes::{Attribute, AttributeLoc},
@ -24,21 +24,16 @@ use inkwell::{
AddressSpace, IntPredicate, OptimizationLevel,
};
use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use model::*;
use nac3parser::ast::{Location, Stmt, StrRef};
use crate::{
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
},
use object::ndarray::NDArray;
use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use classes::{ListType, NDArrayType, ProxyType, RangeType};
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
use std::thread;
pub mod builtin_fns;
pub mod classes;
@ -48,12 +43,17 @@ pub mod extern_fns;
mod generator;
pub mod irrt;
pub mod llvm_intrinsics;
pub mod model;
pub mod numpy;
pub mod object;
pub mod stmt;
#[cfg(test)]
mod test;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
mod macros {
/// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
/// its first argument to provide Python source information to indicate the codegen location
@ -509,12 +509,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
Ptr(Struct(NDArray)).llvm_type(generator, ctx).as_basic_type_enum()
}
_ => unreachable!(
@ -852,9 +847,10 @@ pub fn gen_func_impl<
builder.position_at_end(init_bb);
let body_bb = context.append_basic_block(fn_val, "body");
// Store non-vararg argument values into local variables
let mut var_assignment = HashMap::new();
let offset = u32::from(has_sret);
// Store non-vararg argument values into local variables
for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) {
let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
let local_type = get_llvm_type(

View File

@ -0,0 +1,42 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
};
use crate::codegen::CodeGenerator;
use super::*;
/// A [`Model`] of any [`BasicTypeEnum`].
///
/// Use this when it is infeasible to use model abstractions.
#[derive(Debug, Clone, Copy)]
pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>);
impl<'ctx> Model<'ctx> for Any<'ctx> {
type Value = BasicValueEnum<'ctx>;
type Type = BasicTypeEnum<'ctx>;
fn llvm_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> Self::Type {
self.0
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
_generator: &mut G,
_ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
if ty == self.0 {
Ok(())
} else {
Err(ModelError(format!("Expecting {}, but got {}", self.0, ty)))
}
}
}

View File

@ -0,0 +1,147 @@
use std::fmt;
use inkwell::{
context::Context,
types::{ArrayType, BasicType, BasicTypeEnum},
values::{ArrayValue, IntValue},
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
/// Trait for Rust structs identifying length values for [`Array`].
pub trait ArrayLen: fmt::Debug + Clone + Copy {
fn length(&self) -> u32;
}
/// A statically known length.
#[derive(Debug, Clone, Copy, Default)]
pub struct Len<const N: u32>;
/// A dynamically known length.
#[derive(Debug, Clone, Copy)]
pub struct AnyLen(pub u32);
impl<const N: u32> ArrayLen for Len<N> {
fn length(&self) -> u32 {
N
}
}
impl ArrayLen for AnyLen {
fn length(&self) -> u32 {
self.0
}
}
/// A Model for an [`ArrayType`].
///
/// `Len` should be of a [`LenKind`] and `Item` should be a of [`Model`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Array<Len, Item> {
/// Length of this array.
pub len: Len,
/// [`Model`] of the array items.
pub item: Item,
}
impl<'ctx, Len: ArrayLen, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
type Value = ArrayValue<'ctx>;
type Type = ArrayType<'ctx>;
fn llvm_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
self.item.llvm_type(generator, ctx).array_type(self.len.length())
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let BasicTypeEnum::ArrayType(ty) = ty else {
return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}")));
};
if ty.len() != self.len.length() {
return Err(ModelError(format!(
"Expecting ArrayType with size {}, but got an ArrayType with size {}",
ty.len(),
self.len.length()
)));
}
self.item
.check_type(generator, ctx, ty.get_element_type())
.map_err(|err| err.under_context("an ArrayType"))?;
Ok(())
}
}
impl<'ctx, Len: ArrayLen, Item: Model<'ctx>> Instance<'ctx, Ptr<Array<Len, Item>>> {
/// Get the pointer to the `i`-th (0-based) array element.
pub fn gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
i: IntValue<'ctx>,
) -> Instance<'ctx, Ptr<Item>> {
let zero = ctx.ctx.i32_type().const_zero();
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], "").unwrap() };
unsafe { Ptr(self.model.0.item).believe_value(ptr) }
}
/// Like `gep` but `i` is a constant.
pub fn gep_const(&self, ctx: &CodeGenContext<'ctx, '_>, i: u64) -> Instance<'ctx, Ptr<Item>> {
assert!(
i < u64::from(self.model.0.len.length()),
"Index {i} is out of bounds. Array length = {}",
self.model.0.len.length()
);
let i = ctx.ctx.i32_type().const_int(i, false);
self.gep(ctx, i)
}
/// Convenience function equivalent to `.gep(...).load(...)`.
pub fn get<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
i: IntValue<'ctx>,
) -> Instance<'ctx, Item> {
self.gep(ctx, i).load(generator, ctx)
}
/// Like `get` but `i` is a constant.
pub fn get_const<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
i: u64,
) -> Instance<'ctx, Item> {
self.gep_const(ctx, i).load(generator, ctx)
}
/// Convenience function equivalent to `.gep(...).store(...)`.
pub fn set(
&self,
ctx: &CodeGenContext<'ctx, '_>,
i: IntValue<'ctx>,
value: Instance<'ctx, Item>,
) {
self.gep(ctx, i).store(ctx, value);
}
/// Like `set` but `i` is a constant.
pub fn set_const(&self, ctx: &CodeGenContext<'ctx, '_>, i: u64, value: Instance<'ctx, Item>) {
self.gep_const(ctx, i).store(ctx, value);
}
}

View File

@ -0,0 +1,207 @@
use std::fmt;
use inkwell::{context::Context, types::*, values::*};
use itertools::Itertools;
use super::*;
use crate::codegen::{CodeGenContext, CodeGenerator};
/// A error type for reporting any [`Model`]-related error (e.g., a [`BasicType`] mismatch).
#[derive(Debug, Clone)]
pub struct ModelError(pub String);
impl ModelError {
/// Append a context message to the error.
pub(super) fn under_context(mut self, context: &str) -> Self {
self.0.push_str(" ... in ");
self.0.push_str(context);
self
}
}
/// Trait for Rust structs identifying [`BasicType`]s in the context of a known [`CodeGenerator`] and [`CodeGenContext`].
///
/// For instance,
/// - [`Int<Int32>`] identifies an [`IntType`] with 32-bits.
/// - [`Int<SizeT>`] identifies an [`IntType`] with bit-width [`CodeGenerator::get_size_type`].
/// - [`Ptr<Int<SizeT>>`] identifies a [`PointerType`] that points to an [`IntType`] with bit-width [`CodeGenerator::get_size_type`].
/// - [`Int<AnyInt>`] identifies an [`IntType`] with bit-width of whatever is set in the [`AnyInt`] object.
/// - [`Any`] identifies a [`BasicType`] set in the [`Any`] object itself.
///
/// You can get the [`BasicType`] out of a model with [`Model::get_type`].
///
/// Furthermore, [`Instance<'ctx, M>`] is a simple structure that carries a [`BasicValue`] with [`BasicType`] identified by model `M`.
///
/// The main purpose of this abstraction is to have a more Rust type-safe way to use Inkwell and give type-hints for programmers.
///
/// ### Notes on `Default` trait
///
/// For some models like [`Int<Int32>`] or [`Int<SizeT>`], they have a [`Default`] trait since just by looking at their types, it is possible
/// to tell the [`BasicType`]s they are identifying.
///
/// This can be used to create strongly-typed interfaces accepting only values of a specific [`BasicType`] without having to worry about
/// writing debug assertions to check, for example, if the programmer has passed in an [`IntValue`] with the wrong bit-width.
/// ```ignore
/// fn give_me_i32_and_get_a_size_t_back<'ctx>(i32: Instance<'ctx, Int<Int32>>) -> Instance<'ctx, Int<SizeT>> {
/// // code...
/// }
/// ```
///
/// ### Notes on converting between Inkwell and model/ge.
///
/// Suppose you have an [`IntValue`], and you want to pass it into a function that takes a [`Instance<'ctx, Int<Int32>>`]. You can do use
/// [`Model::check_value`] or [`Model::believe_value`].
/// ```ignore
/// let my_value: IntValue<'ctx>;
///
/// let my_value = Int(Int32).check_value(my_value).unwrap(); // Panics if `my_value` is not 32-bit with a descriptive error message.
///
/// // or, if you are absolutely certain that `my_value` is 32-bit and doing extra checks is a waste of time:
/// let my_value = Int(Int32).believe_value(my_value);
/// ```
pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
/// The [`BasicType`] *variant* this model is identifying.
type Type: BasicType<'ctx>;
/// The [`BasicValue`] type of the [`BasicType`] of this model.
type Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>;
/// Return the [`BasicType`] of this model.
#[must_use]
fn llvm_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context)
-> Self::Type;
/// Get the number of bytes of the [`BasicType`] of this model.
fn size_of<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> IntValue<'ctx> {
self.llvm_type(generator, ctx).size_of().unwrap()
}
/// Check if a [`BasicType`] matches the [`BasicType`] of this model.
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError>;
/// Create an instance from a value.
///
/// # Safety
///
/// Caller must make sure the type of `value` and the type of this `model` are equivalent.
#[must_use]
unsafe fn believe_value(&self, value: Self::Value) -> Instance<'ctx, Self> {
Instance { model: *self, value }
}
/// Check if a [`BasicValue`]'s type is equivalent to the type of this model.
/// Wrap the [`BasicValue`] into an [`Instance`] if it is.
fn check_value<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
value: V,
) -> Result<Instance<'ctx, Self>, ModelError> {
let value = value.as_basic_value_enum();
self.check_type(generator, ctx, value.get_type())
.map_err(|err| err.under_context(format!("the value {value:?}").as_str()))?;
let Ok(value) = Self::Value::try_from(value) else {
unreachable!("check_type() has bad implementation")
};
unsafe { Ok(self.believe_value(value)) }
}
// Allocate a value on the stack and return its pointer.
fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Ptr<Self>> {
let p = ctx.builder.build_alloca(self.llvm_type(generator, ctx.ctx), "").unwrap();
unsafe { Ptr(*self).believe_value(p) }
}
// Allocate an array on the stack and return its pointer.
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
len: IntValue<'ctx>,
) -> Instance<'ctx, Ptr<Self>> {
let p =
ctx.builder.build_array_alloca(self.llvm_type(generator, ctx.ctx), len, "").unwrap();
unsafe { Ptr(*self).believe_value(p) }
}
fn var_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> Result<Instance<'ctx, Ptr<Self>>, String> {
let ty = self.llvm_type(generator, ctx.ctx).as_basic_type_enum();
let p = generator.gen_var_alloc(ctx, ty, name)?;
unsafe { Ok(Ptr(*self).believe_value(p)) }
}
fn array_var_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
len: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Result<Instance<'ctx, Ptr<Self>>, String> {
// TODO: Remove ArraySliceValue
let ty = self.llvm_type(generator, ctx.ctx).as_basic_type_enum();
let p = generator.gen_array_var_alloc(ctx, ty, len, name)?;
unsafe { Ok(Ptr(*self).believe_value(PointerValue::from(p))) }
}
/// Allocate a constant array.
fn const_array<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
values: &[Instance<'ctx, Self>],
) -> Instance<'ctx, Array<AnyLen, Self>> {
macro_rules! make {
($t:expr, $into_value:expr) => {
$t.const_array(
&values
.iter()
.map(|x| $into_value(x.value.as_basic_value_enum()))
.collect_vec(),
)
};
}
let value = match self.llvm_type(generator, ctx).as_basic_type_enum() {
BasicTypeEnum::ArrayType(t) => make!(t, BasicValueEnum::into_array_value),
BasicTypeEnum::IntType(t) => make!(t, BasicValueEnum::into_int_value),
BasicTypeEnum::FloatType(t) => make!(t, BasicValueEnum::into_float_value),
BasicTypeEnum::PointerType(t) => make!(t, BasicValueEnum::into_pointer_value),
BasicTypeEnum::StructType(t) => make!(t, BasicValueEnum::into_struct_value),
BasicTypeEnum::VectorType(t) => make!(t, BasicValueEnum::into_vector_value),
};
Array { len: AnyLen(values.len() as u32), item: *self }
.check_value(generator, ctx, value)
.unwrap()
}
}
#[derive(Debug, Clone, Copy)]
pub struct Instance<'ctx, M: Model<'ctx>> {
/// The model of this instance.
pub model: M,
/// The value of this instance.
///
/// It is guaranteed the [`BasicType`] of `value` is consistent with that of `model`.
pub value: M::Value,
}

View File

@ -0,0 +1,94 @@
use std::fmt;
use inkwell::{
context::Context,
types::{BasicType, FloatType},
values::FloatValue,
};
use crate::codegen::CodeGenerator;
use super::*;
pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx>;
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Float32;
#[derive(Debug, Clone, Copy, Default)]
pub struct Float64;
impl<'ctx> FloatKind<'ctx> for Float32 {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx> {
ctx.f32_type()
}
}
impl<'ctx> FloatKind<'ctx> for Float64 {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx> {
ctx.f64_type()
}
}
#[derive(Debug, Clone, Copy)]
pub struct AnyFloat<'ctx>(FloatType<'ctx>);
impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> FloatType<'ctx> {
self.0
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Float<N>(pub N);
impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float<N> {
type Value = FloatValue<'ctx>;
type Type = FloatType<'ctx>;
fn llvm_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
self.0.get_float_type(generator, ctx)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = FloatType::try_from(ty) else {
return Err(ModelError(format!("Expecting FloatType, but got {ty:?}")));
};
let exp_ty = self.0.get_float_type(generator, ctx);
// TODO: Inkwell does not have get_bit_width for FloatType?
if ty != exp_ty {
return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}")));
}
Ok(())
}
}

View File

@ -0,0 +1,122 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
types::{BasicMetadataTypeEnum, BasicType, FunctionType},
values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue},
};
use itertools::Itertools;
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
#[derive(Debug, Clone, Copy)]
struct Arg<'ctx> {
ty: BasicMetadataTypeEnum<'ctx>,
val: BasicMetadataValueEnum<'ctx>,
}
/// A convenience structure to construct & call an LLVM function.
///
/// ### Usage
///
/// The syntax is like this:
/// ```ignore
/// let result = CallFunction::begin("my_function_name")
/// .attrs(...)
/// .arg(arg1)
/// .arg(arg2)
/// .arg(arg3)
/// .returning("my_function_result", Int32);
/// ```
///
/// The function `my_function_name` is called when `.returning()` (or its variants) is called, returning
/// the result as an `Instance<'ctx, Int<Int32>>`.
///
/// If `my_function_name` has not been declared in `ctx.module`, once `.returning()` is called, a function
/// declaration of `my_function_name` is added to `ctx.module`, where the [`FunctionType`] is deduced from
/// the argument types and returning type.
pub struct FnCall<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> {
generator: &'d mut G,
ctx: &'b CodeGenContext<'ctx, 'a>,
/// Function name
name: &'c str,
/// Call arguments
args: Vec<Arg<'ctx>>,
/// LLVM function Attributes
attrs: Vec<&'static str>,
}
impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> FnCall<'ctx, 'a, 'b, 'c, 'd, G> {
pub fn builder(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self {
FnCall { generator, ctx, name, args: Vec::new(), attrs: Vec::new() }
}
/// Push a list of LLVM function attributes to the function declaration.
#[must_use]
pub fn attrs(mut self, attrs: Vec<&'static str>) -> Self {
self.attrs = attrs;
self
}
/// Push a call argument to the function call.
#[allow(clippy::needless_pass_by_value)]
#[must_use]
pub fn arg<M: Model<'ctx>>(mut self, arg: Instance<'ctx, M>) -> Self {
let arg = Arg {
ty: arg.model.llvm_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(),
val: arg.value.as_basic_value_enum().into(),
};
self.args.push(arg);
self
}
/// Call the function and expect the function to return a value of type of `return_model`.
#[must_use]
pub fn returning<M: Model<'ctx>>(self, name: &str, return_model: M) -> Instance<'ctx, M> {
let ret_ty = return_model.llvm_type(self.generator, self.ctx.ctx);
let ret = self.call(|tys| ret_ty.fn_type(tys, false), name);
let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work
let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work
ret
}
/// Like [`CallFunction::returning_`] but `return_model` is automatically inferred.
#[must_use]
pub fn returning_auto<M: Model<'ctx> + Default>(self, name: &str) -> Instance<'ctx, M> {
self.returning(name, M::default())
}
/// Call the function and expect the function to return a void-type.
pub fn returning_void(self) {
let ret_ty = self.ctx.ctx.void_type();
let _ = self.call(|tys| ret_ty.fn_type(tys, false), "");
}
fn call<F>(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx>
where
F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>,
{
// Get the LLVM function.
let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| {
// Declare the function if it doesn't exist.
let tys = self.args.iter().map(|arg| arg.ty).collect_vec();
let func_type = make_fn_type(&tys);
let func = self.ctx.module.add_function(self.name, func_type, None);
for attr in &self.attrs {
func.add_attribute(
AttributeLoc::Function,
self.ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
let vals = self.args.iter().map(|arg| arg.val).collect_vec();
self.ctx.builder.build_call(func, &vals, return_value_name).unwrap()
}
}

View File

@ -0,0 +1,422 @@
use std::{cmp::Ordering, fmt};
use inkwell::{
context::Context,
types::{BasicType, IntType},
values::IntValue,
IntPredicate,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
pub trait IntKind<'ctx>: fmt::Debug + Clone + Copy {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx>;
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Bool;
#[derive(Debug, Clone, Copy, Default)]
pub struct Byte;
#[derive(Debug, Clone, Copy, Default)]
pub struct Int32;
#[derive(Debug, Clone, Copy, Default)]
pub struct Int64;
#[derive(Debug, Clone, Copy, Default)]
pub struct SizeT;
impl<'ctx> IntKind<'ctx> for Bool {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.bool_type()
}
}
impl<'ctx> IntKind<'ctx> for Byte {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.i8_type()
}
}
impl<'ctx> IntKind<'ctx> for Int32 {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.i32_type()
}
}
impl<'ctx> IntKind<'ctx> for Int64 {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.i64_type()
}
}
impl<'ctx> IntKind<'ctx> for SizeT {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
generator.get_size_type(ctx)
}
}
#[derive(Debug, Clone, Copy)]
pub struct AnyInt<'ctx>(pub IntType<'ctx>);
impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> IntType<'ctx> {
self.0
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Int<N>(pub N);
impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int<N> {
type Value = IntValue<'ctx>;
type Type = IntType<'ctx>;
fn llvm_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
self.0.get_int_type(generator, ctx)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = IntType::try_from(ty) else {
return Err(ModelError(format!("Expecting IntType, but got {ty:?}")));
};
let exp_ty = self.0.get_int_type(generator, ctx);
if ty.get_bit_width() != exp_ty.get_bit_width() {
return Err(ModelError(format!(
"Expecting IntType to have {} bit(s), but got {} bit(s)",
exp_ty.get_bit_width(),
ty.get_bit_width()
)));
}
Ok(())
}
}
impl<'ctx, N: IntKind<'ctx>> Int<N> {
pub fn const_int<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
value: u64,
sign_extend: bool,
) -> Instance<'ctx, Self> {
let value = self.llvm_type(generator, ctx).const_int(value, sign_extend);
unsafe { self.believe_value(value) }
}
pub fn const_0<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Instance<'ctx, Self> {
let value = self.llvm_type(generator, ctx).const_zero();
unsafe { self.believe_value(value) }
}
pub fn const_1<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Instance<'ctx, Self> {
self.const_int(generator, ctx, 1, false)
}
pub fn const_all_ones<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Instance<'ctx, Self> {
let value = self.llvm_type(generator, ctx).const_all_ones();
unsafe { self.believe_value(value) }
}
pub fn s_extend_or_bit_cast<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
<= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
);
let value = ctx
.builder
.build_int_s_extend_or_bit_cast(value, self.llvm_type(generator, ctx.ctx), "")
.unwrap();
unsafe { self.believe_value(value) }
}
pub fn s_extend<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
< self.0.get_int_type(generator, ctx.ctx).get_bit_width()
);
let value =
ctx.builder.build_int_s_extend(value, self.llvm_type(generator, ctx.ctx), "").unwrap();
unsafe { self.believe_value(value) }
}
pub fn z_extend_or_bit_cast<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
<= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
);
let value = ctx
.builder
.build_int_z_extend_or_bit_cast(value, self.llvm_type(generator, ctx.ctx), "")
.unwrap();
unsafe { self.believe_value(value) }
}
pub fn z_extend<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
< self.0.get_int_type(generator, ctx.ctx).get_bit_width()
);
let value =
ctx.builder.build_int_z_extend(value, self.llvm_type(generator, ctx.ctx), "").unwrap();
unsafe { self.believe_value(value) }
}
pub fn truncate_or_bit_cast<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
>= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
);
let value = ctx
.builder
.build_int_truncate_or_bit_cast(value, self.llvm_type(generator, ctx.ctx), "")
.unwrap();
unsafe { self.believe_value(value) }
}
pub fn truncate<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
> self.0.get_int_type(generator, ctx.ctx).get_bit_width()
);
let value =
ctx.builder.build_int_truncate(value, self.llvm_type(generator, ctx.ctx), "").unwrap();
unsafe { self.believe_value(value) }
}
/// `sext` or `trunc` an int to this model's int type. Does nothing if equal bit-widths.
pub fn s_extend_or_truncate<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> Instance<'ctx, Self> {
let their_width = value.get_type().get_bit_width();
let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width();
match their_width.cmp(&our_width) {
Ordering::Less => self.s_extend(generator, ctx, value),
Ordering::Equal => unsafe { self.believe_value(value) },
Ordering::Greater => self.truncate(generator, ctx, value),
}
}
/// `zext` or `trunc` an int to this model's int type. Does nothing if equal bit-widths.
pub fn z_extend_or_truncate<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> Instance<'ctx, Self> {
let their_width = value.get_type().get_bit_width();
let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width();
match their_width.cmp(&our_width) {
Ordering::Less => self.z_extend(generator, ctx, value),
Ordering::Equal => unsafe { self.believe_value(value) },
Ordering::Greater => self.truncate(generator, ctx, value),
}
}
}
impl Int<Bool> {
#[must_use]
pub fn const_false<'ctx, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Instance<'ctx, Self> {
self.const_int(generator, ctx, 0, false)
}
#[must_use]
pub fn const_true<'ctx, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Instance<'ctx, Self> {
self.const_int(generator, ctx, 1, false)
}
}
impl<'ctx, N: IntKind<'ctx>> Instance<'ctx, Int<N>> {
pub fn s_extend_or_bit_cast<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
) -> Instance<'ctx, Int<NewN>> {
Int(to_int_kind).s_extend_or_bit_cast(generator, ctx, self.value)
}
pub fn s_extend<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
) -> Instance<'ctx, Int<NewN>> {
Int(to_int_kind).s_extend(generator, ctx, self.value)
}
pub fn z_extend_or_bit_cast<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
) -> Instance<'ctx, Int<NewN>> {
Int(to_int_kind).z_extend_or_bit_cast(generator, ctx, self.value)
}
pub fn z_extend<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
) -> Instance<'ctx, Int<NewN>> {
Int(to_int_kind).z_extend(generator, ctx, self.value)
}
pub fn truncate_or_bit_cast<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
) -> Instance<'ctx, Int<NewN>> {
Int(to_int_kind).truncate_or_bit_cast(generator, ctx, self.value)
}
pub fn truncate<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
) -> Instance<'ctx, Int<NewN>> {
Int(to_int_kind).truncate(generator, ctx, self.value)
}
pub fn s_extend_or_truncate<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
) -> Instance<'ctx, Int<NewN>> {
Int(to_int_kind).s_extend_or_truncate(generator, ctx, self.value)
}
pub fn z_extend_or_truncate<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
) -> Instance<'ctx, Int<NewN>> {
Int(to_int_kind).z_extend_or_truncate(generator, ctx, self.value)
}
#[must_use]
pub fn add(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self {
let value = ctx.builder.build_int_add(self.value, other.value, "").unwrap();
unsafe { self.model.believe_value(value) }
}
#[must_use]
pub fn sub(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self {
let value = ctx.builder.build_int_sub(self.value, other.value, "").unwrap();
unsafe { self.model.believe_value(value) }
}
#[must_use]
pub fn mul(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self {
let value = ctx.builder.build_int_mul(self.value, other.value, "").unwrap();
unsafe { self.model.believe_value(value) }
}
pub fn compare(
&self,
ctx: &CodeGenContext<'ctx, '_>,
op: IntPredicate,
other: Self,
) -> Instance<'ctx, Int<Bool>> {
let value = ctx.builder.build_int_compare(op, self.value, other.value, "").unwrap();
unsafe { Int(Bool).believe_value(value) }
}
}

View File

@ -0,0 +1,17 @@
mod any;
mod array;
mod core;
mod float;
pub mod function;
mod int;
mod ptr;
mod structure;
pub mod util;
pub use any::*;
pub use array::*;
pub use core::*;
pub use float::*;
pub use int::*;
pub use ptr::*;
pub use structure::*;

View File

@ -0,0 +1,223 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use crate::codegen::{llvm_intrinsics::call_memcpy_generic, CodeGenContext, CodeGenerator};
use super::*;
/// A model for [`PointerType`].
///
/// `Item` is the element type this pointer is pointing to, and should be of a [`Model`].
///
// TODO: LLVM 15: `Item` is a Rust type-hint for the LLVM type of value the `.store()/.load()` family
// of functions return. If a truly opaque pointer is needed, tell the programmer to use `OpaquePtr`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Ptr<Item>(pub Item);
/// An opaque pointer. Like [`Ptr`] but without any Rust type-hints about its element type.
///
/// `.load()/.store()` is not available for [`Instance`]s of opaque pointers.
pub type OpaquePtr = Ptr<()>;
// TODO: LLVM 15: `Item: Model<'ctx>` don't even need to be a model anymore. It will only be
// a type-hint for the `.load()/.store()` functions for the `pointee_ty`.
//
// See https://thedan64.github.io/inkwell/inkwell/builder/struct.Builder.html#method.build_load.
impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr<Item> {
type Value = PointerValue<'ctx>;
type Type = PointerType<'ctx>;
fn llvm_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
// TODO: LLVM 15: ctx.ptr_type(AddressSpace::default())
self.0.llvm_type(generator, ctx).ptr_type(AddressSpace::default())
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = PointerType::try_from(ty) else {
return Err(ModelError(format!("Expecting PointerType, but got {ty:?}")));
};
let elem_ty = ty.get_element_type();
let Ok(elem_ty) = BasicTypeEnum::try_from(elem_ty) else {
return Err(ModelError(format!(
"Expecting pointer element type to be a BasicTypeEnum, but got {elem_ty:?}"
)));
};
// TODO: inkwell `get_element_type()` will be deprecated.
// Remove the check for `get_element_type()` when the time comes.
self.0
.check_type(generator, ctx, elem_ty)
.map_err(|err| err.under_context("a PointerType"))?;
Ok(())
}
}
impl<'ctx, Item: Model<'ctx>> Ptr<Item> {
/// Return a ***constant*** nullptr.
pub fn nullptr<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Instance<'ctx, Ptr<Item>> {
let ptr = self.llvm_type(generator, ctx).const_null();
unsafe { self.believe_value(ptr) }
}
/// Cast a pointer into this model with [`inkwell::builder::Builder::build_pointer_cast`]
pub fn pointer_cast<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
ptr: PointerValue<'ctx>,
) -> Instance<'ctx, Ptr<Item>> {
// TODO: LLVM 15: Write in an impl where `Item` does not have to be `Model<'ctx>`.
// TODO: LLVM 15: This function will only have to be:
// ```
// return self.believe_value(ptr);
// ```
let t = self.llvm_type(generator, ctx.ctx);
let ptr = ctx.builder.build_pointer_cast(ptr, t, "").unwrap();
unsafe { self.believe_value(ptr) }
}
}
impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Item>> {
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`].
#[must_use]
pub fn offset(
&self,
ctx: &CodeGenContext<'ctx, '_>,
offset: IntValue<'ctx>,
) -> Instance<'ctx, Ptr<Item>> {
let p = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[offset], "").unwrap() };
unsafe { self.model.believe_value(p) }
}
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset.
#[must_use]
pub fn offset_const(
&self,
ctx: &CodeGenContext<'ctx, '_>,
offset: i64,
) -> Instance<'ctx, Ptr<Item>> {
let offset = ctx.ctx.i32_type().const_int(offset as u64, true);
self.offset(ctx, offset)
}
pub fn set_index(
&self,
ctx: &CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
value: Instance<'ctx, Item>,
) {
self.offset(ctx, index).store(ctx, value);
}
pub fn set_index_const(
&self,
ctx: &CodeGenContext<'ctx, '_>,
index: i64,
value: Instance<'ctx, Item>,
) {
self.offset_const(ctx, index).store(ctx, value);
}
pub fn get_index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
) -> Instance<'ctx, Item> {
self.offset(ctx, index).load(generator, ctx)
}
pub fn get_index_const<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
index: i64,
) -> Instance<'ctx, Item> {
self.offset_const(ctx, index).load(generator, ctx)
}
/// Load the value with [`inkwell::builder::Builder::build_load`].
pub fn load<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Item> {
let value = ctx.builder.build_load(self.value, "").unwrap();
self.model.0.check_value(generator, ctx.ctx, value).unwrap() // If unwrap() panics, there is a logic error.
}
/// Store a value with [`inkwell::builder::Builder::build_store`].
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: Instance<'ctx, Item>) {
ctx.builder.build_store(self.value, value.value).unwrap();
}
/// Return a casted pointer of element type `NewElement` with [`inkwell::builder::Builder::build_pointer_cast`].
pub fn pointer_cast<NewItem: Model<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
new_item: NewItem,
) -> Instance<'ctx, Ptr<NewItem>> {
// TODO: LLVM 15: Write in an impl where `Item` does not have to be `Model<'ctx>`.
Ptr(new_item).pointer_cast(generator, ctx, self.value)
}
/// Cast this pointer to `uint8_t*`
pub fn cast_to_pi8<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Ptr<Int<Byte>>> {
Ptr(Int(Byte)).pointer_cast(generator, ctx, self.value)
}
/// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`].
pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int<Bool>> {
let value = ctx.builder.build_is_null(self.value, "").unwrap();
unsafe { Int(Bool).believe_value(value) }
}
/// Check if the pointer is not null with [`inkwell::builder::Builder::build_is_not_null`].
pub fn is_not_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int<Bool>> {
let value = ctx.builder.build_is_not_null(self.value, "").unwrap();
unsafe { Int(Bool).believe_value(value) }
}
/// `memcpy` from another pointer.
pub fn copy_from<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
source: Self,
num_items: IntValue<'ctx>,
) {
// Force extend `num_items` and `itemsize` to `i64` so their types would match.
let itemsize = self.model.size_of(generator, ctx.ctx);
let itemsize = Int(SizeT).z_extend_or_truncate(generator, ctx, itemsize);
let num_items = Int(SizeT).z_extend_or_truncate(generator, ctx, num_items);
let totalsize = itemsize.mul(ctx, num_items);
let is_volatile = ctx.ctx.bool_type().const_zero(); // is_volatile = false
call_memcpy_generic(ctx, self.value, source.value, totalsize.value, is_volatile);
}
}

View File

@ -0,0 +1,364 @@
use std::fmt;
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, StructType},
values::{BasicValueEnum, StructValue},
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
/// A traveral that traverses a Rust `struct` that is used to declare an LLVM's struct's field types.
pub trait FieldTraversal<'ctx> {
/// Output type of [`FieldTraversal::add`].
type Output<M>;
/// Traverse through the type of a declared field and do something with it.
///
/// * `name` - The cosmetic name of the LLVM field. Used for debugging.
/// * `model` - The [`Model`] representing the LLVM type of this field.
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Output<M>;
/// Like [`FieldTraversal::add`] but [`Model`] is automatically inferred from its [`Default`] trait.
fn add_auto<M: Model<'ctx> + Default>(&mut self, name: &'static str) -> Self::Output<M> {
self.add(name, M::default())
}
}
/// Descriptor of an LLVM struct field.
#[derive(Debug, Clone, Copy)]
pub struct GepField<M> {
/// The GEP index of this field. This is the index to use with `build_gep`.
pub gep_index: u32,
/// The cosmetic name of this field.
pub name: &'static str,
/// The [`Model`] of this field's type.
pub model: M,
}
/// A traversal to calculate the GEP index of fields.
pub struct GepFieldTraversal {
/// The current GEP index.
gep_index_counter: u32,
}
impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal {
type Output<M> = GepField<M>;
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Output<M> {
let gep_index = self.gep_index_counter;
self.gep_index_counter += 1;
Self::Output { gep_index, name, model }
}
}
/// A traversal to collect the field types of a struct.
///
/// This is used to collect field types and construct the LLVM struct type with [`Context::struct_type`].
struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
generator: &'a G,
ctx: &'ctx Context,
/// The collected field types so far in exact order.
field_types: Vec<BasicTypeEnum<'ctx>>,
}
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> {
type Output<M> = (); // Checking types return nothing.
fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Output<M> {
let t = model.llvm_type(self.generator, self.ctx).as_basic_type_enum();
self.field_types.push(t);
}
}
/// A traversal to check the types of fields.
struct CheckTypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
generator: &'a mut G,
ctx: &'ctx Context,
/// The current GEP index, so we can tell the index of the field we are checking
/// and report the GEP index.
gep_index_counter: u32,
/// The [`StructType`] to check.
scrutinee: StructType<'ctx>,
/// The list of collected errors so far.
errors: Vec<ModelError>,
}
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx>
for CheckTypeFieldTraversal<'ctx, 'a, G>
{
type Output<M> = (); // Checking types return nothing.
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Output<M> {
let gep_index = self.gep_index_counter;
self.gep_index_counter += 1;
if let Some(t) = self.scrutinee.get_field_type_at_index(gep_index) {
if let Err(err) = model.check_type(self.generator, self.ctx, t) {
self.errors
.push(err.under_context(format!("field #{gep_index} '{name}'").as_str()));
}
}
// Otherwise, it will be caught by Struct's `check_type`.
}
}
/// A trait for Rust structs identifying LLVM structures.
///
/// ### Example
///
/// Suppose you want to define this structure:
/// ```c
/// template <typename T>
/// struct ContiguousNDArray {
/// size_t ndims;
/// size_t* shape;
/// T* data;
/// }
/// ```
///
/// This is how it should be done:
/// ```ignore
/// pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
/// pub ndims: F::Out<Int<SizeT>>,
/// pub shape: F::Out<Ptr<Int<SizeT>>>,
/// pub data: F::Out<Ptr<Item>>,
/// }
///
/// /// An ndarray without strides and non-opaque `data` field in NAC3.
/// #[derive(Debug, Clone, Copy)]
/// pub struct ContiguousNDArray<M> {
/// /// [`Model`] of the items.
/// pub item: M,
/// }
///
/// impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray<Item> {
/// type Fields<F: FieldTraversal<'ctx>> = ContiguousNDArrayFields<'ctx, F, Item>;
///
/// fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
/// // The order of `traversal.add*` is important
/// Self::Fields {
/// ndims: traversal.add_auto("ndims"),
/// shape: traversal.add_auto("shape"),
/// data: traversal.add("data", Ptr(self.item)),
/// }
/// }
/// }
/// ```
///
/// The [`FieldTraversal`] here is a mechanism to allow the fields of `ContiguousNDArrayFields` to be
/// traversed to do useful work such as:
///
/// - To create the [`StructType`] of `ContiguousNDArray` by collecting [`BasicType`]s of the fields.
/// - To enable the `.gep(ctx, |f| f.ndims).store(ctx, ...)` syntax.
///
/// Suppose now that you have defined `ContiguousNDArray` and you want to allocate a `ContiguousNDArray`
/// with dtype `float64` in LLVM, this is how you do it:
/// ```ignore
/// type F64NDArray = Struct<ContiguousNDArray<Float<Float64>>>; // Type alias for leaner documentation
/// let model: F64NDArray = Struct(ContigousNDArray { item: Float(Float64) });
/// let ndarray: Instance<'ctx, Ptr<F64NDArray>> = model.alloca(generator, ctx);
/// ```
///
/// ...and here is how you may manipulate/access `ndarray`:
///
/// (NOTE: some arguments have been omitted)
///
/// ```ignore
/// // Get `&ndarray->data`
/// ndarray.gep(|f| f.data); // type: Instance<'ctx, Ptr<Float<Float64>>>
///
/// // Get `ndarray->ndims`
/// ndarray.get(|f| f.ndims); // type: Instance<'ctx, Int<SizeT>>
///
/// // Get `&ndarray->ndims`
/// ndarray.gep(|f| f.ndims); // type: Instance<'ctx, Ptr<Int<SizeT>>>
///
/// // Get `ndarray->shape[0]`
/// ndarray.get(|f| f.shape).get_index_const(0); // Instance<'ctx, Int<SizeT>>
///
/// // Get `&ndarray->shape[2]`
/// ndarray.get(|f| f.shape).offset_const(2); // Instance<'ctx, Ptr<Int<SizeT>>>
///
/// // Do `ndarray->ndims = 3;`
/// let num_3 = Int(SizeT).const_int(3);
/// ndarray.set(|f| f.ndims, num_3);
/// ```
pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy {
/// The associated fields of this struct.
type Fields<F: FieldTraversal<'ctx>>;
/// Traverse through all fields of this [`StructKind`].
///
/// Only used internally in this module for implementing other components.
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F>;
/// Get a convenience structure to get a struct field's GEP index through its corresponding Rust field.
///
/// Only used internally in this module for implementing other components.
fn fields(&self) -> Self::Fields<GepFieldTraversal> {
self.iter_fields(&mut GepFieldTraversal { gep_index_counter: 0 })
}
/// Get the LLVM [`StructType`] of this [`StructKind`].
fn get_struct_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> StructType<'ctx> {
let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() };
self.iter_fields(&mut traversal);
ctx.struct_type(&traversal.field_types, false)
}
}
/// A model for LLVM struct.
///
/// `S` should be of a [`StructKind`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Struct<S>(pub S);
impl<'ctx, S: StructKind<'ctx>> Struct<S> {
/// Create a constant struct value from its fields.
///
/// This function also validates `fields` and panic when there is something wrong.
pub fn const_struct<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
fields: &[BasicValueEnum<'ctx>],
) -> Instance<'ctx, Self> {
// NOTE: There *could* have been a functor `F<M> = Instance<'ctx, M>` for `S::Fields<F>`
// to create a more user-friendly interface, but Rust's type system is not sophisticated enough
// and if you try doing that Rust would force you put lifetimes everywhere.
let val = ctx.const_struct(fields, false);
self.check_value(generator, ctx, val).unwrap()
}
}
impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct<S> {
type Value = StructValue<'ctx>;
type Type = StructType<'ctx>;
fn llvm_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> Self::Type {
self.0.get_struct_type(generator, ctx)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = StructType::try_from(ty) else {
return Err(ModelError(format!("Expecting StructType, but got {ty:?}")));
};
// Check each field individually.
let mut traversal = CheckTypeFieldTraversal {
generator,
ctx,
gep_index_counter: 0,
errors: Vec::new(),
scrutinee: ty,
};
self.0.iter_fields(&mut traversal);
// Check the number of fields.
let exp_num_fields = traversal.gep_index_counter;
let got_num_fields = u32::try_from(ty.get_field_types().len()).unwrap();
if exp_num_fields != got_num_fields {
return Err(ModelError(format!(
"Expecting StructType with {exp_num_fields} field(s), but got {got_num_fields}"
)));
}
if !traversal.errors.is_empty() {
// Currently, only the first error is reported.
return Err(traversal.errors[0].clone());
}
Ok(())
}
}
impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct<S>> {
/// Get a field with [`StructValue::get_field_at_index`].
pub fn get_field<G: CodeGenerator + ?Sized, M, GetField>(
&self,
generator: &mut G,
ctx: &'ctx Context,
get_field: GetField,
) -> Instance<'ctx, M>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
let field = get_field(self.model.0.fields());
let val = self.value.get_field_at_index(field.gep_index).unwrap();
field.model.check_value(generator, ctx, val).unwrap()
}
}
impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
/// Get a pointer to a field with [`Builder::build_in_bounds_gep`].
pub fn gep<M, GetField>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
get_field: GetField,
) -> Instance<'ctx, Ptr<M>>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
let field = get_field(self.model.0 .0.fields());
let llvm_i32 = ctx.ctx.i32_type();
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.value,
&[llvm_i32.const_zero(), llvm_i32.const_int(u64::from(field.gep_index), false)],
field.name,
)
.unwrap()
};
unsafe { Ptr(field.model).believe_value(ptr) }
}
/// Convenience function equivalent to `.gep(...).load(...)`.
pub fn get<M, GetField, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
get_field: GetField,
) -> Instance<'ctx, M>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
self.gep(ctx, get_field).load(generator, ctx)
}
/// Convenience function equivalent to `.gep(...).store(...)`.
pub fn set<M, GetField>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
get_field: GetField,
value: Instance<'ctx, M>,
) where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
self.gep(ctx, get_field).store(ctx, value);
}
}

View File

@ -0,0 +1,42 @@
use crate::codegen::{
stmt::{gen_for_callback_incrementing, BreakContinueHooks},
CodeGenContext, CodeGenerator,
};
use super::*;
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
///
/// `stop` is not included.
pub fn gen_for_model<'ctx, 'a, G, F, N>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
start: Instance<'ctx, Int<N>>,
stop: Instance<'ctx, Int<N>>,
step: Instance<'ctx, Int<N>>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Instance<'ctx, Int<N>>,
) -> Result<(), String>,
N: IntKind<'ctx> + Default,
{
let int_model = Int(N::default());
gen_for_callback_incrementing(
generator,
ctx,
None,
start.value,
(stop.value, false),
|g, ctx, hooks, i| {
let i = unsafe { int_model.believe_value(i) };
body(g, ctx, hooks, i)
},
step.value,
)
}

View File

@ -1,39 +1,47 @@
use inkwell::{
types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel,
};
use nac3parser::ast::{Operator, StrRef};
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue,
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
},
expr::gen_binop_expr_with_values,
irrt::{
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size,
},
llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable,
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator,
};
use crate::{
codegen::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue,
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
},
expr::gen_binop_expr_with_values,
irrt::{
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
call_ndarray_calc_size,
},
llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable,
model::*,
object::{
any::AnyObject,
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
},
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
toplevel::{
helper::PrimDef,
helper::extract_ndims,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId,
},
typecheck::{
magic_methods::Binop,
typedef::{FunSignature, Type, TypeEnum},
typedef::{FunSignature, Type},
},
};
use inkwell::{
types::BasicType,
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel,
};
use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
values::BasicValue,
};
use nac3parser::ast::{Operator, StrRef};
/// Creates an uninitialized `NDArray` instance.
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
@ -939,7 +947,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
.build_store(
lst,
ctx.builder
.build_bit_cast(object.as_base_value(), llvm_plist_i8, "")
.build_bitcast(object.as_base_value(), llvm_plist_i8, "")
.unwrap(),
)
.unwrap();
@ -961,7 +969,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
.builder
.build_load(lst, "")
.map(BasicValueEnum::into_pointer_value)
.map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap())
.map(|v| ctx.builder.build_bitcast(v, plist_plist_i8, "").unwrap())
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let this_dim = ListValue::from_ptr_val(this_dim, llvm_usize, None);
@ -980,9 +988,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder
.build_store(
lst,
ctx.builder
.build_bit_cast(next_dim, llvm_plist_i8, "")
.unwrap(),
ctx.builder.build_bitcast(next_dim, llvm_plist_i8, "").unwrap(),
)
.unwrap();
@ -1742,8 +1748,13 @@ pub fn gen_ndarray_empty<'ctx>(
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = AnyObject { value: shape_arg, ty: shape_ty };
let (_, shape) = parse_numpy_int_sequence(generator, context, shape);
let ndarray = NDArrayObject::make_np_empty(generator, context, dtype, ndims, shape);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.zeros`.
@ -1760,8 +1771,13 @@ pub fn gen_ndarray_zeros<'ctx>(
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = AnyObject { value: shape_arg, ty: shape_ty };
let (_, shape) = parse_numpy_int_sequence(generator, context, shape);
let ndarray = NDArrayObject::make_np_zeros(generator, context, dtype, ndims, shape);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.ones`.
@ -1778,8 +1794,13 @@ pub fn gen_ndarray_ones<'ctx>(
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = AnyObject { value: shape_arg, ty: shape_ty };
let (_, shape) = parse_numpy_int_sequence(generator, context, shape);
let ndarray = NDArrayObject::make_np_ones(generator, context, dtype, ndims, shape);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.full`.
@ -1799,8 +1820,14 @@ pub fn gen_ndarray_full<'ctx>(
let fill_value_arg =
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg)
.map(NDArrayValue::into)
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let ndims = extract_ndims(&context.unifier, ndims);
let shape = AnyObject { value: shape_arg, ty: shape_ty };
let (_, shape) = parse_numpy_int_sequence(generator, context, shape);
let ndarray =
NDArrayObject::make_np_full(generator, context, dtype, ndims, shape, fill_value_arg);
Ok(ndarray.instance.value)
}
pub fn gen_ndarray_array<'ctx>(
@ -1814,26 +1841,6 @@ pub fn gen_ndarray_array<'ctx>(
assert!(matches!(args.len(), 1..=3));
let obj_ty = fun.0.args[0].ty;
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let mut ty = *params.iter().next().unwrap().1;
while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty)
{
if *obj_id != PrimDef::List.id() {
break;
}
ty = *params.iter().next().unwrap().1;
}
ty
}
_ => obj_ty,
};
let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?;
let copy_arg = if let Some(arg) =
@ -1849,28 +1856,18 @@ pub fn gen_ndarray_array<'ctx>(
)
};
let ndmin_arg = if let Some(arg) =
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
{
let ndmin_ty = fun.0.args[2].ty;
arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)?
} else {
context.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
fun.0.args[2].ty,
)
};
// The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be
// the `ndims` of the function return type.
let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let ndims = extract_ndims(&context.unifier, ndims);
call_ndarray_array_impl(
generator,
context,
obj_elem_ty,
obj_arg,
copy_arg.into_int_value(),
ndmin_arg.into_int_value(),
)
.map(NDArrayValue::into)
let object = AnyObject { value: obj_arg, ty: obj_ty };
// NAC3 booleans are i8.
let copy = Int(Bool).truncate(generator, context, copy_arg.into_int_value());
let ndarray = NDArrayObject::make_np_array(generator, context, object, copy)
.atleast_nd(generator, context, ndims);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.eye`.
@ -1909,15 +1906,23 @@ pub fn gen_ndarray_eye<'ctx>(
))
}?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
nrows_arg.into_int_value(),
ncols_arg.into_int_value(),
offset_arg.into_int_value(),
)
.map(NDArrayValue::into)
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let nrows = Int(Int32)
.check_value(generator, context.ctx, nrows_arg)
.unwrap()
.s_extend_or_bit_cast(generator, context, SizeT);
let ncols = Int(Int32)
.check_value(generator, context.ctx, ncols_arg)
.unwrap()
.s_extend_or_bit_cast(generator, context, SizeT);
let offset = Int(Int32)
.check_value(generator, context.ctx, offset_arg)
.unwrap()
.s_extend_or_bit_cast(generator, context, SizeT);
let ndarray = NDArrayObject::make_np_eye(generator, context, dtype, nrows, ncols, offset);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.identity`.
@ -1931,20 +1936,15 @@ pub fn gen_ndarray_identity<'ctx>(
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let n_ty = fun.0.args[0].ty;
let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
n_arg.into_int_value(),
n_arg.into_int_value(),
llvm_usize.const_zero(),
)
.map(NDArrayValue::into)
let n = Int(Int32).check_value(generator, context.ctx, n_arg).unwrap();
let n = n.s_extend_or_bit_cast(generator, context, SizeT);
let ndarray = NDArrayObject::make_np_identity(generator, context, dtype, n);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.copy`.
@ -1958,20 +1958,14 @@ pub fn gen_ndarray_copy<'ctx>(
assert!(obj.is_some());
assert!(args.is_empty());
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
ndarray_copy_impl(
generator,
context,
this_elem_ty,
NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
let this = AnyObject { value: this_arg, ty: this_ty };
let this = NDArrayObject::from_object(generator, context, this);
let ndarray = this.make_copy(generator, context);
Ok(ndarray.instance.value)
}
/// Generates LLVM IR for `ndarray.fill`.
@ -1985,48 +1979,15 @@ pub fn gen_ndarray_fill<'ctx>(
assert!(obj.is_some());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj
.as_ref()
.unwrap()
.1
.clone()
.to_basic_value_enum(context, generator, this_ty)?
.into_pointer_value();
let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
let value_ty = fun.0.args[0].ty;
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
ndarray_fill_flattened(
generator,
context,
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None),
|generator, ctx, _| {
let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
value_arg.into_pointer_value(),
value_arg.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg
} else {
codegen_unreachable!(ctx)
};
Ok(value)
},
)?;
let this = AnyObject { value: this_arg, ty: this_ty };
let this = NDArrayObject::from_object(generator, context, this);
this.fill(generator, context, value_arg);
Ok(())
}
@ -2137,293 +2098,6 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
}
}
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
///
/// * `x1` - `NDArray` to reshape.
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
///
/// Note that unlike other generating functions, one of the dimensions in the shape can be negative.
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
shape: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_reshape";
let (x1_ty, x1) = x1;
let (_, shape) = shape;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
// Check for -1 in dimensions
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_list.load_size(ctx, None), false),
|generator, ctx, _, idx| {
let ele =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
ele,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_neg_value =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_neg_value = ctx
.builder
.build_int_add(
num_neg_value,
llvm_usize.const_int(1, false),
"",
)
.unwrap();
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
Ok(None)
},
|_, ctx| {
let acc_value =
ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_value =
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
ctx.builder.build_store(acc, acc_value).unwrap();
Ok(None)
},
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
// Generate the output shape by filling -1 with `rem`
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
let ndims = shape_tuple.get_type().count_fields();
// Check for -1 in dims
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_negs =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_negs = ctx
.builder
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
.unwrap();
ctx.builder.build_store(num_neg, num_negs).unwrap();
Ok(None)
},
|_, ctx| {
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(None)
},
)?;
}
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize);
// Reconstruct shape filling negatives with rem
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
let dim = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
let shape_int = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
shape_int,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(n_sz)),
|_, ctx| {
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
},
)?
.unwrap()
.into_int_value();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => codegen_unreachable!(ctx),
}
.unwrap();
// Only allow one dimension to be negative
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"can only specify one unknown dimension",
[None, None, None],
ctx.current_loc,
);
// The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError",
"cannot reshape array of size {0} into provided shape of size {1}",
[Some(n_sz), Some(out_sz), None],
ctx.current_loc,
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul`

View File

@ -0,0 +1,12 @@
use inkwell::values::BasicValueEnum;
use crate::typecheck::typedef::Type;
/// A NAC3 LLVM Python object of any type.
#[derive(Debug, Clone, Copy)]
pub struct AnyObject<'ctx> {
/// Typechecker type of the object.
pub ty: Type,
/// LLVM value of the object.
pub value: BasicValueEnum<'ctx>,
}

View File

@ -0,0 +1,87 @@
use crate::{
codegen::{model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
};
use super::any::AnyObject;
/// Fields of [`List`]
pub struct ListFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
/// Array pointer to content
pub items: F::Output<Ptr<Item>>,
/// Number of items in the array
pub len: F::Output<Int<SizeT>>,
}
/// A list in NAC3.
#[derive(Debug, Clone, Copy, Default)]
pub struct List<Item> {
/// Model of the list items
pub item: Item,
}
impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for List<Item> {
type Fields<F: FieldTraversal<'ctx>> = ListFields<'ctx, F, Item>;
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
items: traversal.add("items", Ptr(self.item)),
len: traversal.add_auto("len"),
}
}
}
impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr<Struct<List<Item>>>> {
/// Cast the items pointer to `uint8_t*`.
pub fn with_pi8_items<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Ptr<Struct<List<Int<Byte>>>>> {
self.pointer_cast(generator, ctx, Struct(List { item: Int(Byte) }))
}
}
/// A NAC3 Python List object.
#[derive(Debug, Clone, Copy)]
pub struct ListObject<'ctx> {
/// Typechecker type of the list items
pub item_type: Type,
pub instance: Instance<'ctx, Ptr<Struct<List<Any<'ctx>>>>>,
}
impl<'ctx> ListObject<'ctx> {
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> Self {
// Check typechecker type and extract `item_type`
let item_type = match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
iter_type_vars(params).next().unwrap().ty // Extract `item_type`
}
_ => {
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(object.ty))
}
};
let plist = Ptr(Struct(List { item: Any(ctx.get_llvm_type(generator, item_type)) }));
// Create object
let value = plist.check_value(generator, ctx.ctx, object.value).unwrap();
ListObject { item_type, instance: value }
}
/// Get the `len()` of this list.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<SizeT>> {
self.instance.get(generator, ctx, |f| f.len)
}
}

View File

@ -0,0 +1,5 @@
pub mod any;
pub mod list;
pub mod ndarray;
pub mod tuple;
pub mod utils;

View File

@ -0,0 +1,184 @@
use super::NDArrayObject;
use crate::{
codegen::{
irrt::{
call_nac3_ndarray_array_set_and_validate_list_shape,
call_nac3_ndarray_array_write_list_to_array,
},
model::*,
object::{any::AnyObject, list::ListObject},
stmt::gen_if_else_expr_callback,
CodeGenContext, CodeGenerator,
},
toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims},
typecheck::typedef::{Type, TypeEnum},
};
/// Get the expected `dtype` and `ndims` of the ndarray returned by `np_array(list)`.
fn get_list_object_dtype_and_ndims<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> (Type, u64) {
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list.item_type);
let ndims = arraylike_get_ndims(&mut ctx.unifier, list.item_type);
let ndims = ndims + 1; // To count `list` itself.
(dtype, ndims)
}
impl<'ctx> NDArrayObject<'ctx> {
/// Implementation of `np_array(<list>, copy=True)`
fn make_np_array_list_copy_true_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> Self {
let (dtype, ndims_int) = get_list_object_dtype_and_ndims(ctx, list);
let list_value = list.instance.with_pi8_items(generator, ctx);
// Validate `list` has a consistent shape.
// Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`.
// If `list` has a consistent shape, deduce the shape and write it to `shape`.
let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int, false);
let shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
call_nac3_ndarray_array_set_and_validate_list_shape(
generator, ctx, list_value, ndims, shape,
);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int);
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.create_data(generator, ctx);
// Copy all contents from the list.
call_nac3_ndarray_array_write_list_to_array(generator, ctx, list_value, ndarray.instance);
ndarray
}
/// Implementation of `np_array(<list>, copy=None)`
fn make_np_array_list_copy_none_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> Self {
// np_array without copying is only possible `list` is not nested.
//
// If `list` is `list[T]`, we can create an ndarray with `data` set
// to the array pointer of `list`.
//
// If `list` is `list[list[T]]` or worse, copy.
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
if ndims == 1 {
// `list` is not nested
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, 1);
// Set data
let data = list.instance.get(generator, ctx, |f| f.items).cast_to_pi8(generator, ctx);
ndarray.instance.set(ctx, |f| f.data, data);
// ndarray->shape[0] = list->len;
let shape = ndarray.instance.get(generator, ctx, |f| f.shape);
let list_len = list.instance.get(generator, ctx, |f| f.len);
shape.set_index_const(ctx, 0, list_len);
// Set strides, the `data` is contiguous
ndarray.set_strides_contiguous(generator, ctx);
ndarray
} else {
// `list` is nested, copy
NDArrayObject::make_np_array_list_copy_true_impl(generator, ctx, list)
}
}
/// Implementation of `np_array(<list>, copy=copy)`
fn make_np_array_list_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
copy: Instance<'ctx, Int<Bool>>,
) -> Self {
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
let ndarray = gen_if_else_expr_callback(
generator,
ctx,
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray =
NDArrayObject::make_np_array_list_copy_true_impl(generator, ctx, list);
Ok(Some(ndarray.instance.value))
},
|generator, ctx| {
let ndarray =
NDArrayObject::make_np_array_list_copy_none_impl(generator, ctx, list);
Ok(Some(ndarray.instance.value))
},
)
.unwrap()
.unwrap();
NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims)
}
/// Implementation of `np_array(<ndarray>, copy=copy)`.
pub fn make_np_array_ndarray_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
copy: Instance<'ctx, Int<Bool>>,
) -> Self {
let ndarray_val = gen_if_else_expr_callback(
generator,
ctx,
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray = ndarray.make_copy(generator, ctx); // Force copy
Ok(Some(ndarray.instance.value))
},
|_generator, _ctx| {
// No need to copy. Return `ndarray` itself.
Ok(Some(ndarray.instance.value))
},
)
.unwrap()
.unwrap();
NDArrayObject::from_value_and_unpacked_types(
generator,
ctx,
ndarray_val,
ndarray.dtype,
ndarray.ndims,
)
}
/// Create a new ndarray like `np.array()`.
///
/// NOTE: The `ndmin` argument is not here. You may want to
/// do [`NDArrayObject::atleast_nd`] to achieve that.
pub fn make_np_array<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
copy: Instance<'ctx, Int<Bool>>,
) -> Self {
match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let list = ListObject::from_object(generator, ctx, object);
NDArrayObject::make_np_array_list_impl(generator, ctx, list, copy)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_object(generator, ctx, object);
NDArrayObject::make_np_array_ndarray_impl(generator, ctx, ndarray, copy)
}
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this
}
}
}

View File

@ -0,0 +1,176 @@
use inkwell::{values::BasicValueEnum, IntPredicate};
use crate::{
codegen::{
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext,
CodeGenerator,
},
typecheck::typedef::Type,
};
use super::NDArrayObject;
/// Get the zero value in `np.zeros()` of a `dtype`.
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.ctx.i32_type().const_zero().into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.ctx.i64_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.ctx.f64_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string(generator, "").into()
} else {
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
}
}
/// Get the one value in `np.ones()` of a `dtype`.
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
ctx.ctx.i32_type().const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
ctx.ctx.i64_type().const_int(1, is_signed).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.ctx.f64_type().const_float(1.0).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string(generator, "1").into()
} else {
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Create an ndarray like `np.empty`.
pub fn make_np_empty<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) -> Self {
// Validate `shape`
let ndims_llvm = Int(SizeT).const_int(generator, ctx.ctx, ndims, false);
call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims);
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.create_data(generator, ctx);
ndarray
}
/// Create an ndarray like `np.full`.
pub fn make_np_full<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
fill_value: BasicValueEnum<'ctx>,
) -> Self {
let ndarray = NDArrayObject::make_np_empty(generator, ctx, dtype, ndims, shape);
ndarray.fill(generator, ctx, fill_value);
ndarray
}
/// Create an ndarray like `np.zero`.
pub fn make_np_zeros<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) -> Self {
let fill_value = ndarray_zero_value(generator, ctx, dtype);
NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}
/// Create an ndarray like `np.ones`.
pub fn make_np_ones<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) -> Self {
let fill_value = ndarray_one_value(generator, ctx, dtype);
NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}
/// Create an ndarray like `np.eye`.
pub fn make_np_eye<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
nrows: Instance<'ctx, Int<SizeT>>,
ncols: Instance<'ctx, Int<SizeT>>,
offset: Instance<'ctx, Int<SizeT>>,
) -> Self {
let ndzero = ndarray_zero_value(generator, ctx, dtype);
let ndone = ndarray_one_value(generator, ctx, dtype);
let ndarray = NDArrayObject::alloca_dynamic_shape(generator, ctx, dtype, &[nrows, ncols]);
// Create data and make the matrix like look np.eye()
ndarray.create_data(generator, ctx);
ndarray
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
// and this loop would not execute.
// Load up `row_i` and `col_i` from indices.
let row_i = nditer.get_indices().get_index_const(generator, ctx, 0);
let col_i = nditer.get_indices().get_index_const(generator, ctx, 1);
let be_one = row_i.add(ctx, offset).compare(ctx, IntPredicate::EQ, col_i);
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap();
Ok(())
})
.unwrap();
ndarray
}
/// Create an ndarray like `np.identity`.
pub fn make_np_identity<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
size: Instance<'ctx, Int<SizeT>>,
) -> Self {
// Convenient implementation
let offset = Int(SizeT).const_0(generator, ctx.ctx);
NDArrayObject::make_np_eye(generator, ctx, dtype, size, size, offset)
}
}

View File

@ -0,0 +1,227 @@
use crate::codegen::{
irrt::call_nac3_ndarray_index,
model::*,
object::utils::slice::{RustSlice, Slice},
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
pub type NDIndexType = Byte;
/// Fields of [`NDIndex`]
#[derive(Debug, Clone, Copy)]
pub struct NDIndexFields<'ctx, F: FieldTraversal<'ctx>> {
pub type_: F::Output<Int<NDIndexType>>,
pub data: F::Output<Ptr<Int<Byte>>>,
}
/// An IRRT representation of an ndarray subscript index.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NDIndex;
impl<'ctx> StructKind<'ctx> for NDIndex {
type Fields<F: FieldTraversal<'ctx>> = NDIndexFields<'ctx, F>;
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { type_: traversal.add_auto("type"), data: traversal.add_auto("data") }
}
}
// A convenience enum representing a [`NDIndex`].
#[derive(Debug, Clone)]
pub enum RustNDIndex<'ctx> {
SingleElement(Instance<'ctx, Int<Int32>>),
Slice(RustSlice<'ctx, Int32>),
NewAxis,
Ellipsis,
}
impl<'ctx> RustNDIndex<'ctx> {
/// Get the value to set `NDIndex::type` for this variant.
fn get_type_id(&self) -> u64 {
// Defined in IRRT, must be in sync
match self {
RustNDIndex::SingleElement(_) => 0,
RustNDIndex::Slice(_) => 1,
RustNDIndex::NewAxis => 2,
RustNDIndex::Ellipsis => 3,
}
}
/// Serialize this [`RustNDIndex`] by writing it into an LLVM [`NDIndex`].
fn write_to_ndindex<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_ndindex_ptr: Instance<'ctx, Ptr<Struct<NDIndex>>>,
) {
// Set `dst_ndindex_ptr->type`
dst_ndindex_ptr.gep(ctx, |f| f.type_).store(
ctx,
Int(NDIndexType::default()).const_int(generator, ctx.ctx, self.get_type_id(), false),
);
// Set `dst_ndindex_ptr->data`
match self {
RustNDIndex::SingleElement(in_index) => {
let index_ptr = Int(Int32).alloca(generator, ctx);
index_ptr.store(ctx, *in_index);
dst_ndindex_ptr
.gep(ctx, |f| f.data)
.store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte)));
}
RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr = Struct(Slice(Int32)).alloca(generator, ctx);
in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr);
dst_ndindex_ptr
.gep(ctx, |f| f.data)
.store(ctx, user_slice_ptr.pointer_cast(generator, ctx, Int(Byte)));
}
RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {}
}
}
/// Serialize a list of `RustNDIndex` as a newly allocated LLVM array of `NDIndex`.
pub fn make_ndindices<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
in_ndindices: &[RustNDIndex<'ctx>],
) -> (Instance<'ctx, Int<SizeT>>, Instance<'ctx, Ptr<Struct<NDIndex>>>) {
let ndindex_model = Struct(NDIndex);
// Allocate the LLVM ndindices.
let num_ndindices =
Int(SizeT).const_int(generator, ctx.ctx, in_ndindices.len() as u64, false);
let ndindices = ndindex_model.array_alloca(generator, ctx, num_ndindices.value);
// Initialize all of them.
for (i, in_ndindex) in in_ndindices.iter().enumerate() {
let pndindex = ndindices.offset_const(ctx, i64::try_from(i).unwrap());
in_ndindex.write_to_ndindex(generator, ctx, pndindex);
}
(num_ndindices, ndindices)
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Get the expected `ndims` after indexing with `indices`.
#[must_use]
fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 {
let mut ndims = self.ndims;
for index in indices {
match index {
RustNDIndex::SingleElement(_) => {
ndims -= 1; // Single elements decrements ndims
}
RustNDIndex::NewAxis => {
ndims += 1; // `np.newaxis` / `none` adds a new axis
}
RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {}
}
}
ndims
}
/// Index into the ndarray, and return a newly-allocated view on this ndarray.
///
/// This function behaves like NumPy's ndarray indexing, but if the indices index
/// into a single element, an unsized ndarray is returned.
#[must_use]
pub fn index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indices: &[RustNDIndex<'ctx>],
) -> Self {
let dst_ndims = self.deduce_ndims_after_indexing_with(indices);
let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims);
let (num_indices, indices) = RustNDIndex::make_ndindices(generator, ctx, indices);
call_nac3_ndarray_index(
generator,
ctx,
num_indices,
indices,
self.instance,
dst_ndarray.instance,
);
dst_ndarray
}
}
pub mod util {
use itertools::Itertools;
use nac3parser::ast::{Expr, ExprKind};
use crate::{
codegen::{model::*, object::utils::slice::util::gen_slice, CodeGenContext, CodeGenerator},
typecheck::typedef::Type,
};
use super::RustNDIndex;
/// Generate LLVM code to transform an ndarray subscript expression to
/// its list of [`RustNDIndex`]
///
/// i.e.,
/// ```python
/// my_ndarray[::3, 1, :2:]
/// ^^^^^^^^^^^ Then these into a three `RustNDIndex`es
/// ```
pub fn gen_ndarray_subscript_ndindices<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
subscript: &Expr<Option<Type>>,
) -> Result<Vec<RustNDIndex<'ctx>>, String> {
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
// Annoying notes about `slice`
// - `my_array[5]`
// - slice is a `Constant`
// - `my_array[:5]`
// - slice is a `Slice`
// - `my_array[:]`
// - slice is a `Slice`, but lower upper step would all be `Option::None`
// - `my_array[:, :]`
// - slice is now a `Tuple` of two `Slice`-s
//
// In summary:
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
//
// So we first "flatten" out the slice expression
let index_exprs = match &subscript.node {
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
_ => vec![subscript],
};
// Process all index expressions
let mut rust_ndindices: Vec<RustNDIndex> = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here.
for index_expr in index_exprs {
// NOTE: Currently nac3core's slices do not have an object representation,
// so the code/implementation looks awkward - we have to do pattern matching on the expression
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
// Handle slices
let slice = gen_slice(generator, ctx, lower, upper, step)?;
RustNDIndex::Slice(slice)
} else {
// Treat and handle everything else as a single element index.
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
ctx,
generator,
ctx.primitives.int32, // Must be int32, this checks for illegal values
)?;
let index = Int(Int32).check_value(generator, ctx.ctx, index).unwrap();
RustNDIndex::SingleElement(index)
};
rust_ndindices.push(ndindex);
}
Ok(rust_ndindices)
}
}

View File

@ -0,0 +1,504 @@
pub mod array;
pub mod factory;
pub mod indexing;
pub mod nditer;
pub mod shape_util;
pub mod view;
use inkwell::{
context::Context,
types::BasicType,
values::{BasicValue, BasicValueEnum, PointerValue},
AddressSpace,
};
use crate::{
codegen::{
irrt::{
call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement,
call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous,
call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
},
model::*,
CodeGenContext, CodeGenerator,
},
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
typecheck::typedef::Type,
};
use super::{any::AnyObject, tuple::TupleObject};
/// Fields of [`NDArray`]
pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> {
pub data: F::Output<Ptr<Int<Byte>>>,
pub itemsize: F::Output<Int<SizeT>>,
pub ndims: F::Output<Int<SizeT>>,
pub shape: F::Output<Ptr<Int<SizeT>>>,
pub strides: F::Output<Ptr<Int<SizeT>>>,
}
/// A strided ndarray in NAC3.
///
/// See IRRT implementation for details about its fields.
#[derive(Debug, Clone, Copy, Default)]
pub struct NDArray;
impl<'ctx> StructKind<'ctx> for NDArray {
type Fields<F: FieldTraversal<'ctx>> = NDArrayFields<'ctx, F>;
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
data: traversal.add_auto("data"),
itemsize: traversal.add_auto("itemsize"),
ndims: traversal.add_auto("ndims"),
shape: traversal.add_auto("shape"),
strides: traversal.add_auto("strides"),
}
}
}
/// A NAC3 Python ndarray object.
#[derive(Debug, Clone, Copy)]
pub struct NDArrayObject<'ctx> {
pub dtype: Type,
pub ndims: u64,
pub instance: Instance<'ctx, Ptr<Struct<NDArray>>>,
}
impl<'ctx> NDArrayObject<'ctx> {
/// Attempt to convert an [`AnyObject`] into an [`NDArrayObject`].
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> NDArrayObject<'ctx> {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims)
}
/// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's
/// `dtype` and `ndims`.
pub fn from_value_and_unpacked_types<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
value: V,
dtype: Type,
ndims: u64,
) -> Self {
let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, value).unwrap();
NDArrayObject { dtype, ndims, instance: value }
}
/// Get this ndarray's `ndims` as an LLVM constant.
pub fn ndims_llvm<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Instance<'ctx, Int<SizeT>> {
Int(SizeT).const_int(generator, ctx, self.ndims, false)
}
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
///
/// `shape` and `strides` will be automatically allocated onto the stack.
///
/// The returned ndarray's content will be:
/// - `data`: uninitialized.
/// - `itemsize`: set to the `sizeof()` of `dtype`.
/// - `ndims`: set to the value of `ndims`.
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
pub fn alloca<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
) -> Self {
let ndarray = Struct(NDArray).alloca(generator, ctx);
let itemsize = ctx.get_llvm_type(generator, dtype).size_of().unwrap();
let itemsize = Int(SizeT).z_extend_or_truncate(generator, ctx, itemsize);
ndarray.set(ctx, |f| f.itemsize, itemsize);
let ndims_val = Int(SizeT).const_int(generator, ctx.ctx, ndims, false);
ndarray.set(ctx, |f| f.ndims, ndims_val);
let shape = Int(SizeT).array_alloca(generator, ctx, ndims_val.value);
ndarray.set(ctx, |f| f.shape, shape);
let strides = Int(SizeT).array_alloca(generator, ctx, ndims_val.value);
ndarray.set(ctx, |f| f.strides, strides);
NDArrayObject { dtype, ndims, instance: ndarray }
}
/// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
pub fn alloca_constant_shape<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
shape: &[u64],
) -> Self {
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64);
// Write shape
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
for (i, dim) in shape.iter().enumerate() {
let dim = Int(SizeT).const_int(generator, ctx.ctx, *dim, false);
dst_shape.offset_const(ctx, i64::try_from(i).unwrap()).store(ctx, dim);
}
ndarray
}
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
pub fn alloca_dynamic_shape<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
shape: &[Instance<'ctx, Int<SizeT>>],
) -> Self {
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64);
// Write shape
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
for (i, dim) in shape.iter().enumerate() {
dst_shape.offset_const(ctx, i64::try_from(i).unwrap()).store(ctx, *dim);
}
ndarray
}
/// Initialize an ndarray's `data` by allocating a buffer on the stack.
/// The allocated data buffer is considered to be *owned* by the ndarray.
///
/// `strides` of the ndarray will also be updated with `set_strides_by_shape`.
///
/// `shape` and `itemsize` of the ndarray ***must*** be initialized first.
pub fn create_data<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
let nbytes = self.nbytes(generator, ctx);
let data = Int(Byte).array_alloca(generator, ctx, nbytes.value);
self.instance.set(ctx, |f| f.data, data);
self.set_strides_contiguous(generator, ctx);
}
/// Copy shape dimensions from an array.
pub fn copy_shape_from_array<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let num_items = self.ndims_llvm(generator, ctx.ctx).value;
self.instance.get(generator, ctx, |f| f.shape).copy_from(generator, ctx, shape, num_items);
}
/// Copy shape dimensions from an ndarray.
/// Panics if `ndims` mismatches.
pub fn copy_shape_from_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayObject<'ctx>,
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_shape = src_ndarray.instance.get(generator, ctx, |f| f.shape);
self.copy_shape_from_array(generator, ctx, src_shape);
}
/// Copy strides dimensions from an array.
pub fn copy_strides_from_array<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
strides: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let num_items = self.ndims_llvm(generator, ctx.ctx).value;
self.instance
.get(generator, ctx, |f| f.strides)
.copy_from(generator, ctx, strides, num_items);
}
/// Copy strides dimensions from an ndarray.
/// Panics if `ndims` mismatches.
pub fn copy_strides_from_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayObject<'ctx>,
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_strides = src_ndarray.instance.get(generator, ctx, |f| f.strides);
self.copy_strides_from_array(generator, ctx, src_strides);
}
/// Get the `np.size()` of this ndarray.
pub fn size<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<SizeT>> {
call_nac3_ndarray_size(generator, ctx, self.instance)
}
/// Get the `ndarray.nbytes` of this ndarray.
pub fn nbytes<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<SizeT>> {
call_nac3_ndarray_nbytes(generator, ctx, self.instance)
}
/// Get the `len()` of this ndarray.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<SizeT>> {
call_nac3_ndarray_len(generator, ctx, self.instance)
}
/// Check if this ndarray is C-contiguous.
///
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<Bool>> {
call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance)
}
/// Get the pointer to the n-th (0-based) element.
///
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
pub fn get_nth_pelement<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Instance<'ctx, Int<SizeT>>,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, nth);
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "")
.unwrap()
}
/// Get the n-th (0-based) scalar.
pub fn get_nth_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Instance<'ctx, Int<SizeT>>,
) -> AnyObject<'ctx> {
let ptr = self.get_nth_pelement(generator, ctx, nth);
let value = ctx.builder.build_load(ptr, "").unwrap();
AnyObject { ty: self.dtype, value }
}
/// Get the pointer to the element indexed by `indices`.
///
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
pub fn get_pelement_by_indices<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
let p = call_nac3_ndarray_get_pelement_by_indices(generator, ctx, self.instance, indices);
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "")
.unwrap()
}
/// Get the scalar indexed by `indices`.
pub fn get_scalar_by_indices<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
) -> AnyObject<'ctx> {
let ptr = self.get_pelement_by_indices(generator, ctx, indices);
let value = ctx.builder.build_load(ptr, "").unwrap();
AnyObject { ty: self.dtype, value }
}
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
///
/// Update the ndarray's strides to make the ndarray contiguous.
pub fn set_strides_contiguous<G: CodeGenerator + ?Sized>(
self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
}
/// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
///
/// The new ndarray will own its data and will be C-contiguous.
#[must_use]
pub fn make_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
let clone = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims);
let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx);
clone.copy_shape_from_array(generator, ctx, shape);
clone.create_data(generator, ctx);
clone.copy_data_from(generator, ctx, *self);
clone
}
/// Copy data from another ndarray.
///
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes
/// do not matter. The copying order is determined by how their flattened views look.
///
/// Panics if the `dtype`s of ndarrays are different.
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src: NDArrayObject<'ctx>,
) {
assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match");
call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance);
}
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
#[must_use]
pub fn is_unsized(&self) -> bool {
self.ndims == 0
}
/// If this ndarray is unsized, return its sole value as an [`AnyObject`].
/// Otherwise, do nothing and return the ndarray itself.
pub fn split_unsized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ScalarOrNDArray<'ctx> {
if self.is_unsized() {
// NOTE: `np.size(self) == 0` here is never possible.
let zero = Int(SizeT).const_0(generator, ctx.ctx);
let value = self.get_nth_scalar(generator, ctx, zero).value;
ScalarOrNDArray::Scalar(AnyObject { ty: self.dtype, value })
} else {
ScalarOrNDArray::NDArray(*self)
}
}
/// Fill the ndarray with a scalar.
///
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
pub fn fill<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) {
// TODO: It is possible to optimize this by exploiting contiguous strides with memset.
// Probably best to implement in IRRT.
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap();
Ok(())
})
.unwrap();
}
/// Create the shape tuple of this ndarray like `np.shape(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_shape_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Return a tuple of SizeT
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.shape)
.get_index_const(generator, ctx, i64::try_from(i).unwrap())
.truncate_or_bit_cast(generator, ctx, Int32);
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::from_objects(generator, ctx, objects)
}
/// Create the strides tuple of this ndarray like `<ndarray>.strides`.
///
/// The returned integers in the tuple are in int32.
pub fn make_strides_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Return a tuple of SizeT.
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.strides)
.get_index_const(generator, ctx, i64::try_from(i).unwrap())
.truncate_or_bit_cast(generator, ctx, Int32);
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::from_objects(generator, ctx, objects)
}
}
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
#[derive(Debug, Clone, Copy)]
pub enum ScalarOrNDArray<'ctx> {
Scalar(AnyObject<'ctx>),
NDArray(NDArrayObject<'ctx>),
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.value,
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
}
}
}

View File

@ -0,0 +1,179 @@
use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
use crate::codegen::{
irrt::{call_nac3_nditer_has_element, call_nac3_nditer_initialize, call_nac3_nditer_next},
model::*,
object::any::AnyObject,
stmt::{gen_for_callback, BreakContinueHooks},
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
/// Fields of [`NDIter`]
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Output<Int<SizeT>>,
pub shape: F::Output<Ptr<Int<SizeT>>>,
pub strides: F::Output<Ptr<Int<SizeT>>>,
pub indices: F::Output<Ptr<Int<SizeT>>>,
pub nth: F::Output<Int<SizeT>>,
pub element: F::Output<Ptr<Int<Byte>>>,
pub size: F::Output<Int<SizeT>>,
}
/// An IRRT helper structure used to iterate through an ndarray.
#[derive(Debug, Clone, Copy, Default)]
pub struct NDIter;
impl<'ctx> StructKind<'ctx> for NDIter {
type Fields<F: FieldTraversal<'ctx>> = NDIterFields<'ctx, F>;
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
ndims: traversal.add_auto("ndims"),
shape: traversal.add_auto("shape"),
strides: traversal.add_auto("strides"),
indices: traversal.add_auto("indices"),
nth: traversal.add_auto("nth"),
element: traversal.add_auto("element"),
size: traversal.add_auto("size"),
}
}
}
/// A helper structure with a convenient interface to interact with [`NDIter`].
#[derive(Debug, Clone)]
pub struct NDIterHandle<'ctx> {
instance: Instance<'ctx, Ptr<Struct<NDIter>>>,
/// The ndarray this [`NDIter`] to iterating over.
ndarray: NDArrayObject<'ctx>,
/// The current indices of [`NDIter`].
indices: Instance<'ctx, Ptr<Int<SizeT>>>,
}
impl<'ctx> NDIterHandle<'ctx> {
/// Allocate an [`NDIter`] that iterates through an ndarray.
pub fn new<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
) -> Self {
let nditer = Struct(NDIter).alloca(generator, ctx);
let ndims = ndarray.ndims_llvm(generator, ctx.ctx);
// The caller has the responsibility to allocate 'indices' for `NDIter`.
let indices = Int(SizeT).array_alloca(generator, ctx, ndims.value);
call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices);
NDIterHandle { ndarray, instance: nditer, indices }
}
/// Is the current iteration valid?
///
/// If true, then `element`, `indices` and `nth` contain details about the current element.
///
/// If `ndarray` is unsized, this returns true only for the first iteration.
/// If `ndarray` is 0-sized, this always returns false.
#[must_use]
pub fn has_element<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<Bool>> {
call_nac3_nditer_has_element(generator, ctx, self.instance)
}
/// Go to the next element. If `has_element()` is false, then this has undefined behavior.
///
/// If `ndarray` is unsized, this can only be called once.
/// If `ndarray` is 0-sized, this can never be called.
pub fn next<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_nditer_next(generator, ctx, self.instance);
}
/// Get pointer to the current element.
#[must_use]
pub fn get_pointer<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype);
let p = self.instance.get(generator, ctx, |f| f.element);
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element")
.unwrap()
}
/// Get the value of the current element.
#[must_use]
pub fn get_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
let p = self.get_pointer(generator, ctx);
let value = ctx.builder.build_load(p, "value").unwrap();
AnyObject { ty: self.ndarray.dtype, value }
}
/// Get the index of the current element if this ndarray were a flat ndarray.
#[must_use]
pub fn get_index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<SizeT>> {
self.instance.get(generator, ctx, |f| f.nth)
}
/// Get the indices of the current element.
#[must_use]
pub fn get_indices(&self) -> Instance<'ctx, Ptr<Int<SizeT>>> {
self.indices
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Iterate through every element in the ndarray.
///
/// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterHandle`] to
/// get properties of the current iteration (e.g., the current element, indices, etc.)
pub fn foreach<'a, G, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
NDIterHandle<'ctx>,
) -> Result<(), String>,
{
gen_for_callback(
generator,
ctx,
Some("ndarray_foreach"),
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|generator, ctx, nditer| Ok(nditer.has_element(generator, ctx).value),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|generator, ctx, nditer| {
nditer.next(generator, ctx);
Ok(())
},
)
}
}

View File

@ -0,0 +1,105 @@
use util::gen_for_model;
use crate::{
codegen::{
model::*,
object::{any::AnyObject, list::ListObject, tuple::TupleObject},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::TypeEnum,
};
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
///
/// * `sequence` - The `sequence` parameter.
/// * `sequence_ty` - The typechecker type of `sequence`
///
/// The `sequence` argument type may only be one of the following:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// All `int32` values will be sign-extended to `SizeT`.
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
input_sequence: AnyObject<'ctx>,
) -> (Instance<'ctx, Int<SizeT>>, Instance<'ctx, Ptr<Int<SizeT>>>) {
let zero = Int(SizeT).const_0(generator, ctx.ctx);
let one = Int(SizeT).const_1(generator, ctx.ctx);
// The result `list` to return.
match &*ctx.unifier.get_ty(input_sequence.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
// Check `input_sequence`
let input_sequence = ListObject::from_object(generator, ctx, input_sequence);
let len = input_sequence.instance.get(generator, ctx, |f| f.len);
let result = Int(SizeT).array_alloca(generator, ctx, len.value);
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
gen_for_model(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| {
// Load the i-th int32 in the input sequence
let int = input_sequence
.instance
.get(generator, ctx, |f| f.items)
.get_index(generator, ctx, i.value)
.value
.into_int_value();
// Cast to SizeT
let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int);
// Store
result.set_index(ctx, i.value, int);
Ok(())
})
.unwrap();
(len, result)
}
TypeEnum::TTuple { .. } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
let input_sequence = TupleObject::from_object(ctx, input_sequence);
let len = input_sequence.len(generator, ctx);
let result = Int(SizeT).array_alloca(generator, ctx, len.value);
for i in 0..input_sequence.num_elements() {
// Get the i-th element off of the tuple and load it into `result`.
let int = input_sequence.index(ctx, i).value.into_int_value();
let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int);
result.set_index_const(ctx, i64::try_from(i).unwrap(), int);
}
(len, result)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
{
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
let input_int = input_sequence.value.into_int_value();
let len = Int(SizeT).const_1(generator, ctx.ctx);
let result = Int(SizeT).array_alloca(generator, ctx, len.value);
let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, input_int);
// Storing into result[0]
result.store(ctx, int);
(len, result)
}
_ => panic!(
"encountered unknown sequence type: {}",
ctx.unifier.stringify(input_sequence.ty)
),
}
}

View File

@ -0,0 +1,89 @@
use crate::codegen::{
irrt::call_nac3_ndarray_reshape_resolve_and_check_new_shape, model::*, CodeGenContext,
CodeGenerator,
};
use super::{indexing::RustNDIndex, NDArrayObject};
impl<'ctx> NDArrayObject<'ctx> {
/// Make sure the ndarray is at least `ndmin`-dimensional.
///
/// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape.
/// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray.
#[must_use]
pub fn atleast_nd<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndmin: u64,
) -> Self {
if self.ndims < ndmin {
// Extend the dimensions with np.newaxis.
let mut indices = vec![];
for _ in self.ndims..ndmin {
indices.push(RustNDIndex::NewAxis);
}
indices.push(RustNDIndex::Ellipsis);
self.index(generator, ctx, &indices)
} else {
*self
}
}
/// Create a reshaped view on this ndarray like `np.reshape()`.
///
/// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be modified as a result.
///
/// If reshape without copying is impossible, this function will allocate a new ndarray and copy contents.
///
/// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`].
/// * `new_shape` - The target shape to do `np.reshape()`.
#[must_use]
pub fn reshape_or_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
new_ndims: u64,
new_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) -> Self {
// TODO: The current criterion for whether to do a full copy or not is by checking `is_c_contiguous`,
// but this is not optimal - there are cases when the ndarray is not contiguous but could be reshaped
// without copying data. Look into how numpy does it.
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims);
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
// Reolsve negative indices
let size = self.size(generator, ctx);
let dst_ndims = dst_ndarray.ndims_llvm(generator, ctx.ctx);
let dst_shape = dst_ndarray.instance.get(generator, ctx, |f| f.shape);
call_nac3_ndarray_reshape_resolve_and_check_new_shape(
generator, ctx, size, dst_ndims, dst_shape,
);
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
// Inserting into then_bb: reshape is possible without copying
ctx.builder.position_at_end(then_bb);
dst_ndarray.set_strides_contiguous(generator, ctx);
dst_ndarray.instance.set(ctx, |f| f.data, self.instance.get(generator, ctx, |f| f.data));
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Inserting into else_bb: reshape is impossible without copying
ctx.builder.position_at_end(else_bb);
dst_ndarray.create_data(generator, ctx);
dst_ndarray.copy_data_from(generator, ctx, *self);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition for continuation
ctx.builder.position_at_end(end_bb);
dst_ndarray
}
}

View File

@ -0,0 +1,99 @@
use inkwell::values::StructValue;
use itertools::Itertools;
use crate::{
codegen::{model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
};
use super::any::AnyObject;
/// A NAC3 tuple object.
///
/// NOTE: This struct has no copy trait.
#[derive(Debug, Clone)]
pub struct TupleObject<'ctx> {
/// The type of the tuple.
pub tys: Vec<Type>,
/// The underlying LLVM struct value of this tuple.
pub value: StructValue<'ctx>,
}
impl<'ctx> TupleObject<'ctx> {
pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self {
// TODO: Keep `is_vararg_ctx` from TTuple?
// Sanity check on object type.
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else {
panic!(
"Expected type to be a TypeEnum::TTuple, got {}",
ctx.unifier.stringify(object.ty)
);
};
// Check number of fields
let value = object.value.into_struct_value();
let value_num_fields = value.get_type().count_fields() as usize;
assert!(
value_num_fields == tys.len(),
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
tys.len(),
value_num_fields
);
TupleObject { tys: tys.clone(), value }
}
/// Convenience function. Create a [`TupleObject`] from an iterator of objects.
pub fn from_objects<I, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
objects: I,
) -> Self
where
I: IntoIterator<Item = AnyObject<'ctx>>,
{
let (values, tys): (Vec<_>, Vec<_>) =
objects.into_iter().map(|object| (object.value, object.ty)).unzip();
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
let llvm_tuple_ty = ctx.ctx.struct_type(&llvm_tys, false);
let pllvm_tuple = ctx.builder.build_alloca(llvm_tuple_ty, "tuple").unwrap();
for (i, val) in values.into_iter().enumerate() {
let pval = ctx.builder.build_struct_gep(pllvm_tuple, i as u32, "value").unwrap();
ctx.builder.build_store(pval, val).unwrap();
}
let value = ctx.builder.build_load(pllvm_tuple, "").unwrap().into_struct_value();
TupleObject { tys, value }
}
#[must_use]
pub fn num_elements(&self) -> usize {
self.tys.len()
}
/// Get the `len()` of this tuple.
#[must_use]
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Instance<'ctx, Int<SizeT>> {
Int(SizeT).const_int(generator, ctx.ctx, self.num_elements() as u64, false)
}
/// Get the `i`-th (0-based) object in this tuple.
pub fn index(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize) -> AnyObject<'ctx> {
assert!(
i < self.num_elements(),
"Tuple object with length {} have index {i}",
self.num_elements()
);
let value = ctx.builder.build_extract_value(self.value, i as u32, "tuple[{i}]").unwrap();
let ty = self.tys[i];
AnyObject { ty, value }
}
}

View File

@ -0,0 +1 @@
pub mod slice;

View File

@ -0,0 +1,125 @@
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
/// Fields of [`Slice`]
#[derive(Debug, Clone)]
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> {
pub start_defined: F::Output<Int<Bool>>,
pub start: F::Output<Int<N>>,
pub stop_defined: F::Output<Int<Bool>>,
pub stop: F::Output<Int<N>>,
pub step_defined: F::Output<Int<Bool>>,
pub step: F::Output<Int<N>>,
}
/// An IRRT representation of an (unresolved) slice.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Slice<N>(pub N);
impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice<N> {
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F, N>;
fn iter_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
start_defined: traversal.add_auto("start_defined"),
start: traversal.add("start", Int(self.0)),
stop_defined: traversal.add_auto("stop_defined"),
stop: traversal.add("stop", Int(self.0)),
step_defined: traversal.add_auto("step_defined"),
step: traversal.add("step", Int(self.0)),
}
}
}
/// A Rust structure that has [`Slice`] utilities and looks like a [`Slice`] but
/// `start`, `stop` and `step` are held by LLVM registers only and possibly
/// [`Option::None`] if unspecified.
#[derive(Debug, Clone)]
pub struct RustSlice<'ctx, N: IntKind<'ctx>> {
// It is possible that `start`, `stop`, and `step` are all `None`.
// We need to know the `int_kind` even when that is the case.
pub int_kind: N,
pub start: Option<Instance<'ctx, Int<N>>>,
pub stop: Option<Instance<'ctx, Int<N>>>,
pub step: Option<Instance<'ctx, Int<N>>>,
}
impl<'ctx, N: IntKind<'ctx>> RustSlice<'ctx, N> {
/// Write the contents to an LLVM [`Slice`].
pub fn write_to_slice<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice<N>>>>,
) {
let false_ = Int(Bool).const_false(generator, ctx.ctx);
let true_ = Int(Bool).const_true(generator, ctx.ctx);
match self.start {
Some(start) => {
dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start);
}
None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_),
}
match self.stop {
Some(stop) => {
dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop);
}
None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_),
}
match self.step {
Some(step) => {
dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step);
}
None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_),
}
}
}
pub mod util {
use nac3parser::ast::Expr;
use crate::{
codegen::{model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::Type,
};
use super::RustSlice;
/// Generate LLVM IR for an [`ExprKind::Slice`] and convert it into a [`RustSlice`].
#[allow(clippy::type_complexity)]
pub fn gen_slice<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lower: &Option<Box<Expr<Option<Type>>>>,
upper: &Option<Box<Expr<Option<Type>>>>,
step: &Option<Box<Expr<Option<Type>>>>,
) -> Result<RustSlice<'ctx, Int32>, String> {
let mut help = |value_expr: &Option<Box<Expr<Option<Type>>>>| -> Result<_, String> {
Ok(match value_expr {
None => None,
Some(value_expr) => {
let value_expr = generator
.gen_expr(ctx, value_expr)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
let value_expr =
Int(Int32).check_value(generator, ctx.ctx, value_expr).unwrap();
Some(value_expr)
}
})
};
let start = help(lower)?;
let stop = help(upper)?;
let step = help(step)?;
Ok(RustSlice { int_kind: Int32, start, stop, step })
}
}

View File

@ -1,16 +1,3 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
types::{BasicType, BasicTypeEnum},
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
IntPredicate,
};
use itertools::{izip, Itertools};
use nac3parser::ast::{
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
};
use super::{
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
expr::{destructure_range, gen_binop_expr},
@ -27,6 +14,17 @@ use crate::{
typedef::{iter_type_vars, FunSignature, Type, TypeEnum},
},
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
types::{BasicType, BasicTypeEnum},
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
IntPredicate,
};
use itertools::{izip, Itertools};
use nac3parser::ast::{
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
};
/// See [`CodeGenerator::gen_var_alloc`].
pub fn gen_var<'ctx>(
@ -1828,37 +1826,6 @@ pub fn gen_stmt<G: CodeGenerator>(
stmt.location,
);
}
StmtKind::Global { names, .. } => {
let registered_globals = ctx
.top_level
.definitions
.read()
.iter()
.filter_map(|def| {
if let TopLevelDef::Variable { simple_name, ty, .. } = &*def.read() {
Some((*simple_name, *ty))
} else {
None
}
})
.collect_vec();
for id in names {
let Some((_, ty)) = registered_globals.iter().find(|(name, _)| name == id) else {
return Err(format!("{id} is not a global at {}", stmt.location));
};
let resolver = ctx.resolver.clone();
let ptr = resolver
.get_symbol_value(*id, ctx, generator)
.map(|val| val.to_basic_value_enum(ctx, generator, *ty))
.transpose()?
.map(BasicValueEnum::into_pointer_value)
.unwrap();
ctx.var_assignment.insert(*id, (ptr, None, 0));
}
}
_ => unimplemented!(),
};
Ok(())

View File

@ -1,37 +1,34 @@
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use indexmap::IndexMap;
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel,
};
use nac3parser::{
ast::{fold::Fold, FileName, StrRef},
parser::parse_program,
};
use parking_lot::RwLock;
use super::{
classes::{ListType, NDArrayType, ProxyType, RangeType},
concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator,
DefaultCodeGenerator, WithCall, WorkerRegistry,
};
use crate::{
codegen::{
classes::{ListType, NDArrayType, ProxyType, RangeType},
concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask,
CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, IdentifierInfo, Inferencer, PrimitiveStore},
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
};
use indexmap::IndexMap;
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel,
};
use nac3parser::ast::FileName;
use nac3parser::{
ast::{fold::Fold, StrRef},
parser::parse_program,
};
use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
struct Resolver {
id_to_type: HashMap<StrRef, Type>,
@ -67,7 +64,6 @@ impl SymbolResolver for Resolver {
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
@ -142,8 +138,7 @@ fn test_primitives() {
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashMap<_, _> =
["a".into(), "b".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut identifiers: HashSet<_> = ["a".into(), "b".into()].into();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
@ -322,8 +317,7 @@ fn test_simple_call() {
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashMap<_, _> =
["a".into(), "foo".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].into();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,

View File

@ -19,10 +19,6 @@
clippy::wildcard_imports
)]
// users of nac3core need to use the same version of these dependencies, so expose them as nac3core::*
pub use inkwell;
pub use nac3parser;
pub mod codegen;
pub mod symbol_resolver;
pub mod toplevel;

View File

@ -1,15 +1,7 @@
use std::{
collections::{HashMap, HashSet},
fmt::{Debug, Display},
rc::Rc,
sync::Arc,
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip, Itertools};
use parking_lot::RwLock;
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use std::fmt::Debug;
use std::rc::Rc;
use std::sync::Arc;
use std::{collections::HashMap, collections::HashSet, fmt::Display};
use crate::{
codegen::{CodeGenContext, CodeGenerator},
@ -19,6 +11,10 @@ use crate::{
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip, Itertools};
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)]
pub enum SymbolValue {
@ -369,7 +365,6 @@ pub trait SymbolResolver {
&self,
str: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>>;
fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;

View File

@ -1,5 +1,6 @@
use std::iter::once;
use helper::{debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDefDetails};
use indexmap::IndexMap;
use inkwell::{
attributes::{Attribute, AttributeLoc},
@ -8,24 +9,28 @@ use inkwell::{
IntPredicate,
};
use itertools::Either;
use numpy::unpack_ndarray_var_tys;
use strum::IntoEnumIterator;
use super::{
helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDef, PrimDefDetails},
numpy::make_ndarray_ty,
*,
};
use crate::{
codegen::{
builtin_fns,
classes::{ProxyValue, RangeValue},
model::*,
numpy::*,
object::{
any::AnyObject,
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
},
stmt::exn_constructor,
},
symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
};
use super::*;
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
pub fn get_exn_constructor(
@ -512,6 +517,14 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpEye
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => {
self.build_ndarray_property_getter_function(prim)
}
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
self.build_ndarray_view_function(prim)
}
PrimDef::FunStr => self.build_str_function(),
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
@ -577,10 +590,6 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
self.build_np_sp_ndarray_function(prim)
}
PrimDef::FunNpDot
| PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr
@ -1386,6 +1395,149 @@ impl<'a> BuiltinBuilder<'a> {
}
}
fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides],
);
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.primitives.ndarray],
Some("T".into()),
None,
);
match prim {
PrimDef::FunNpSize => create_fn_by_codegen(
self.unifier,
&into_var_map([in_ndarray_ty]),
prim.name(),
self.primitives.int32,
&[(in_ndarray_ty.ty, "a")],
Box::new(|ctx, obj, fun, args, generator| {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let ndarray_ty = fun.0.args[0].ty;
let ndarray =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let size =
ndarray.size(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32);
Ok(Some(size.value.as_basic_value_enum()))
}),
),
PrimDef::FunNpShape | PrimDef::FunNpStrides => {
// The function signatures of `np_shape` an `np_size` are the same.
// Mixed together for convenience.
// The return type is a tuple of variable length depending on the ndims of the input ndarray.
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding
create_fn_by_codegen(
self.unifier,
&into_var_map([in_ndarray_ty]),
prim.name(),
ret_ty,
&[(in_ndarray_ty.ty, "a")],
Box::new(move |ctx, obj, fun, args, generator| {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let ndarray_ty = fun.0.args[0].ty;
let ndarray =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let result_tuple = match prim {
PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx),
PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx),
_ => unreachable!(),
};
Ok(Some(result_tuple.value.as_basic_value_enum()))
}),
)
}
_ => unreachable!(),
}
}
/// Build np/sp functions that take as input `NDArray` only
fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.primitives.ndarray],
Some("T".into()),
None,
);
match prim {
PrimDef::FunNpTranspose => create_fn_by_codegen(
self.unifier,
&into_var_map([in_ndarray_ty]),
prim.name(),
in_ndarray_ty.ty,
&[(in_ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
}),
),
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.
//
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
PrimDef::FunNpReshape => {
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[
(in_ndarray_ty.ty, "x"),
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), // Handled by special folding
],
Box::new(move |ctx, _, fun, args, generator| {
let ndarray_ty = fun.0.args[0].ty;
let ndarray_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let shape_ty = fun.0.args[1].ty;
let shape_val =
args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let ndarray = AnyObject { value: ndarray_val, ty: ndarray_ty };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let shape = AnyObject { value: shape_val, ty: shape_ty };
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
// The ndims after reshaping is gotten from the return type of the call.
let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let ndims = extract_ndims(&ctx.unifier, ndims);
let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, shape);
Ok(Some(new_ndarray.instance.value.as_basic_value_enum()))
}),
)
}
_ => unreachable!(),
}
}
/// Build the `str()` function.
fn build_str_function(&mut self) -> TopLevelDef {
let prim = PrimDef::FunStr;
@ -1873,57 +2025,6 @@ impl<'a> BuiltinBuilder<'a> {
}
}
/// Build np/sp functions that take as input `NDArray` only
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
match prim {
PrimDef::FunNpTranspose => {
let ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.ndarray_num_ty],
Some("T".into()),
None,
);
create_fn_by_codegen(
self.unifier,
&into_var_map([ndarray_ty]),
prim.name(),
ndarray_ty.ty,
&[(ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
}),
)
}
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.
//
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
PrimDef::FunNpReshape => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_num_ty,
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
),
_ => unreachable!(),
}
}
/// Build `np_linalg` and `sp_linalg` functions
///
/// The input to these functions must be floating point `NDArray`

View File

@ -1,17 +1,17 @@
use nac3parser::ast::fold::Fold;
use std::rc::Rc;
use nac3parser::ast::{fold::Fold, ExprKind, Ident};
use super::*;
use crate::{
codegen::{expr::get_subst_key, stmt::exn_constructor},
symbol_resolver::SymbolValue,
typecheck::{
type_inferencer::{FunctionData, IdentifierInfo, Inferencer},
type_inferencer::{FunctionData, Inferencer},
typedef::{TypeVar, VarMap},
},
};
use super::*;
pub struct ComposerConfig {
pub kernel_ann: Option<&'static str>,
pub kernel_invariant_ann: &'static str,
@ -101,8 +101,7 @@ impl TopLevelComposer {
.iter()
.map(|def_ast| match *def_ast.0.read() {
TopLevelDef::Class { name, .. } => name.to_string(),
TopLevelDef::Function { simple_name, .. }
| TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(),
TopLevelDef::Function { simple_name, .. } => simple_name.to_string(),
})
.collect_vec();
@ -382,87 +381,13 @@ impl TopLevelComposer {
))
}
ast::StmtKind::Assign { .. } => {
// Assignment statements can assign to (and therefore create) more than one
// variable, but this function only allows returning one set of symbol information.
// We want to avoid changing this to return a `Vec` of symbol info, as this would
// require `iter().next().unwrap()` on every variable created from a non-Assign
// statement.
//
// Make callers use `register_top_level_var` instead, as it provides more
// fine-grained control over which symbols to register, while also simplifying the
// usage of this function.
panic!("Registration of top-level Assign statements must use TopLevelComposer::register_top_level_var (at {})", ast.location);
}
ast::StmtKind::AnnAssign { target, annotation, .. } => {
let ExprKind::Name { id: name, .. } = target.node else {
return Err(format!(
"global variable declaration must be an identifier (at {})",
target.location
));
};
self.register_top_level_var(
name,
Some(annotation.as_ref().clone()),
resolver,
mod_path,
target.location,
)
}
_ => Err(format!(
"registrations of constructs other than top level classes/functions/variables are not supported (at {})",
"registrations of constructs other than top level classes/functions are not supported (at {})",
ast.location
)),
}
}
/// Registers a top-level variable with the given `name` into the composer.
///
/// `annotation` - The type annotation of the top-level variable, or [`None`] if no type
/// annotation is provided.
/// `location` - The location of the top-level variable.
pub fn register_top_level_var(
&mut self,
name: Ident,
annotation: Option<Expr>,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
mod_path: &str,
location: Location,
) -> Result<(StrRef, DefinitionId, Option<Type>), String> {
if self.keyword_list.contains(&name) {
return Err(format!("cannot use keyword `{name}` as a class name (at {location})"));
}
let global_var_name =
if mod_path.is_empty() { name.to_string() } else { format!("{mod_path}.{name}") };
if !self.defined_names.insert(global_var_name.clone()) {
return Err(format!(
"global variable `{global_var_name}` defined twice (at {location})"
));
}
let ty_to_be_unified = self.unifier.get_dummy_var().ty;
self.definition_ast_list.push((
RwLock::new(Self::make_top_level_variable_def(
global_var_name,
name,
// dummy here, unify with correct type later,
ty_to_be_unified,
annotation,
resolver,
Some(location),
))
.into(),
None,
));
Ok((name, DefinitionId(self.definition_ast_list.len() - 1), Some(ty_to_be_unified)))
}
pub fn start_analysis(&mut self, inference: bool) -> Result<(), HashSet<String>> {
self.analyze_top_level_class_type_var()?;
self.analyze_top_level_class_bases()?;
@ -471,7 +396,6 @@ impl TopLevelComposer {
if inference {
self.analyze_function_instance()?;
}
self.analyze_top_level_variables()?;
Ok(())
}
@ -509,7 +433,7 @@ impl TopLevelComposer {
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported
// i.e. only simple names are allowed in the subscript
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
ExprKind::Subscript { value, slice, .. }
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(
&value.node,
@ -525,9 +449,9 @@ impl TopLevelComposer {
}
is_generic = true;
let type_var_list: Vec<&Expr<()>>;
let type_var_list: Vec<&ast::Expr<()>>;
// if `class A(Generic[T, V, G])`
if let ExprKind::Tuple { elts, .. } = &slice.node {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
type_var_list = elts.iter().collect_vec();
// `class A(Generic[T])`
} else {
@ -576,7 +500,6 @@ impl TopLevelComposer {
}
Ok(())
};
let mut errors = HashSet::new();
for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) {
if class_ast.is_none() {
@ -930,6 +853,7 @@ impl TopLevelComposer {
let unifier = self.unifier.borrow_mut();
let primitives_store = &self.primitives_ty;
let mut errors = HashSet::new();
let mut analyze = |function_def: &Arc<RwLock<TopLevelDef>>, function_ast: &Option<Stmt>| {
let mut function_def = function_def.write();
let function_def = &mut *function_def;
@ -1038,18 +962,18 @@ impl TopLevelComposer {
}
}
let arg_with_default: Vec<(&ast::Located<ast::ArgData<()>>, Option<&Expr>)> = args
.args
.iter()
.rev()
.zip(
args.defaults
.iter()
.rev()
.map(|x| -> Option<&Expr> { Some(x) })
.chain(std::iter::repeat(None)),
)
.collect_vec();
let arg_with_default: Vec<(&ast::Located<ast::ArgData<()>>, Option<&ast::Expr>)> =
args.args
.iter()
.rev()
.zip(
args.defaults
.iter()
.rev()
.map(|x| -> Option<&ast::Expr> { Some(x) })
.chain(std::iter::repeat(None)),
)
.collect_vec();
arg_with_default
.iter()
@ -1204,8 +1128,6 @@ impl TopLevelComposer {
})?;
Ok(())
};
let mut errors = HashSet::new();
for (function_def, function_ast) in def_list.iter().skip(self.builtin_num) {
if function_ast.is_none() {
continue;
@ -1307,7 +1229,7 @@ impl TopLevelComposer {
let arg_with_default: Vec<(
&ast::Located<ast::ArgData<()>>,
Option<&Expr>,
Option<&ast::Expr>,
)> = args
.args
.iter()
@ -1316,7 +1238,7 @@ impl TopLevelComposer {
args.defaults
.iter()
.rev()
.map(|x| -> Option<&Expr> { Some(x) })
.map(|x| -> Option<&ast::Expr> { Some(x) })
.chain(std::iter::repeat(None)),
)
.collect_vec();
@ -1473,7 +1395,7 @@ impl TopLevelComposer {
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
}
ast::StmtKind::AnnAssign { target, annotation, value, .. } => {
if let ExprKind::Name { id: attr, .. } = &target.node {
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
if defined_fields.insert(attr.to_string()) {
let dummy_field_type = unifier.get_dummy_var().ty;
@ -1481,7 +1403,7 @@ impl TopLevelComposer {
None => {
// handle Kernel[T], KernelInvariant[T]
let (annotation, mutable) = match &annotation.node {
ExprKind::Subscript { value, slice, .. }
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into()
@ -1489,7 +1411,7 @@ impl TopLevelComposer {
{
(slice, false)
}
ExprKind::Subscript { value, slice, .. }
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into())
@ -1507,13 +1429,13 @@ impl TopLevelComposer {
Some(boxed_expr) => {
// Class attributes are set as immutable regardless
let (annotation, _) = match &annotation.node {
ExprKind::Subscript { slice, .. } => (slice, false),
ast::ExprKind::Subscript { slice, .. } => (slice, false),
_ if core_config.kernel_ann.is_none() => (annotation, false),
_ => continue,
};
match &**boxed_expr {
ast::Located {location: _, custom: (), node: ExprKind::Constant { value: v, kind: _ }} => {
ast::Located {location: _, custom: (), node: ast::ExprKind::Constant { value: v, kind: _ }} => {
// Restricting the types allowed to be defined as class attributes
match v {
ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {}
@ -1780,6 +1702,7 @@ impl TopLevelComposer {
}
}
let mut errors = HashSet::new();
let mut analyze = |i, def: &Arc<RwLock<TopLevelDef>>, ast: &Option<Stmt>| {
let class_def = def.read();
if let TopLevelDef::Class {
@ -1922,8 +1845,6 @@ impl TopLevelComposer {
}
Ok(())
};
let mut errors = HashSet::new();
for (i, (def, ast)) in definition_ast_list.iter().enumerate().skip(self.builtin_num) {
if ast.is_none() {
continue;
@ -1961,296 +1882,272 @@ impl TopLevelComposer {
if ast.is_none() {
return Ok(());
}
let (name, simple_name, signature, resolver) = {
let function_def = def.read();
let TopLevelDef::Function { name, simple_name, signature, resolver, .. } =
&*function_def
let mut function_def = def.write();
if let TopLevelDef::Function {
instance_to_stmt,
instance_to_symbol,
name,
simple_name,
signature,
resolver,
..
} = &mut *function_def
{
let signature_ty_enum = unifier.get_ty(*signature);
let TypeEnum::TFunc(FunSignature { args, ret, vars }) = signature_ty_enum.as_ref()
else {
return Ok(());
unreachable!("must be typeenum::tfunc")
};
(name.clone(), *simple_name, *signature, resolver.clone())
};
let mut vars = vars.clone();
// None if is not class method
let uninst_self_type = {
if let Some(class_id) = method_class.get(&DefinitionId(id)) {
let class_def = definition_ast_list.get(class_id.0).unwrap();
let class_def = class_def.0.read();
let TopLevelDef::Class { type_vars, .. } = &*class_def else {
unreachable!("must be class def")
};
let signature_ty_enum = unifier.get_ty(signature);
let TypeEnum::TFunc(FunSignature { args, ret, vars, .. }) = signature_ty_enum.as_ref()
else {
unreachable!("must be typeenum::tfunc")
};
let ty_ann = make_self_type_annotation(type_vars, *class_id);
let self_ty = get_type_from_type_annotation_kinds(
&def_list,
unifier,
primitives_ty,
&ty_ann,
&mut None,
)?;
vars.extend(type_vars.iter().map(|ty| {
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {
unreachable!()
};
let mut vars = vars.clone();
// None if is not class method
let uninst_self_type = {
if let Some(class_id) = method_class.get(&DefinitionId(id)) {
let class_def = definition_ast_list.get(class_id.0).unwrap();
let class_def = class_def.0.read();
let TopLevelDef::Class { type_vars, .. } = &*class_def else {
unreachable!("must be class def")
(*id, *ty)
}));
Some((self_ty, type_vars.clone()))
} else {
None
}
};
// carefully handle those with bounds, without bounds and no typevars
// if class methods, `vars` also contains all class typevars here
let (type_var_subst_comb, no_range_vars) = {
let mut no_ranges: Vec<Type> = Vec::new();
let var_combs = vars
.values()
.map(|ty| {
unifier.get_instantiations(*ty).unwrap_or_else(|| {
let TypeEnum::TVar { name, loc, is_const_generic: false, .. } =
&*unifier.get_ty(*ty)
else {
unreachable!()
};
let rigid = unifier.get_fresh_rigid_var(*name, *loc).ty;
no_ranges.push(rigid);
vec![rigid]
})
})
.multi_cartesian_product()
.collect_vec();
let mut result: Vec<VarMap> = Vec::default();
for comb in var_combs {
result.push(vars.keys().copied().zip(comb).collect());
}
// NOTE: if is empty, means no type var, append a empty subst, ok to do this?
if result.is_empty() {
result.push(VarMap::new());
}
(result, no_ranges)
};
for subst in type_var_subst_comb {
// for each instance
let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret);
let inst_args = {
args.iter()
.map(|a| FuncArg {
name: a.name,
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
default_value: a.default_value.clone(),
is_vararg: false,
})
.collect_vec()
};
let self_type = {
uninst_self_type.clone().map(|(self_type, type_vars)| {
let subst_for_self = {
let class_ty_var_ids = type_vars
.iter()
.map(|x| {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
*id
} else {
unreachable!("must be type var here");
}
})
.collect::<HashSet<_>>();
subst
.iter()
.filter_map(|(ty_var_id, ty_var_target)| {
if class_ty_var_ids.contains(ty_var_id) {
Some((*ty_var_id, *ty_var_target))
} else {
None
}
})
.collect::<VarMap>()
};
unifier.subst(self_type, &subst_for_self).unwrap_or(self_type)
})
};
let mut identifiers = {
let mut result: HashSet<_> = HashSet::new();
if self_type.is_some() {
result.insert("self".into());
}
result.extend(inst_args.iter().map(|x| x.name));
result
};
let mut calls: HashMap<CodeLocation, CallId> = HashMap::new();
let mut inferencer = Inferencer {
top_level: ctx.as_ref(),
defined_identifiers: identifiers.clone(),
function_data: &mut FunctionData {
resolver: resolver.as_ref().unwrap().clone(),
return_type: if unifier.unioned(inst_ret, primitives_ty.none) {
None
} else {
Some(inst_ret)
},
// NOTE: allowed type vars
bound_variables: no_range_vars.clone(),
},
unifier,
variable_mapping: {
let mut result: HashMap<StrRef, Type> = HashMap::new();
if let Some(self_ty) = self_type {
result.insert("self".into(), self_ty);
}
result.extend(inst_args.iter().map(|x| (x.name, x.ty)));
result
},
primitives: primitives_ty,
virtual_checks: &mut Vec::new(),
calls: &mut calls,
in_handler: false,
};
let ty_ann = make_self_type_annotation(type_vars, *class_id);
let self_ty = get_type_from_type_annotation_kinds(
&def_list,
unifier,
primitives_ty,
&ty_ann,
&mut None,
)?;
vars.extend(type_vars.iter().map(|ty| {
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {
unreachable!()
};
(*id, *ty)
}));
Some((self_ty, type_vars.clone()))
} else {
None
}
};
// carefully handle those with bounds, without bounds and no typevars
// if class methods, `vars` also contains all class typevars here
let (type_var_subst_comb, no_range_vars) = {
let mut no_ranges: Vec<Type> = Vec::new();
let var_combs = vars
.values()
.map(|ty| {
unifier.get_instantiations(*ty).unwrap_or_else(|| {
let TypeEnum::TVar { name, loc, is_const_generic: false, .. } =
&*unifier.get_ty(*ty)
else {
unreachable!()
};
let rigid = unifier.get_fresh_rigid_var(*name, *loc).ty;
no_ranges.push(rigid);
vec![rigid]
})
})
.multi_cartesian_product()
.collect_vec();
let mut result: Vec<VarMap> = Vec::default();
for comb in var_combs {
result.push(vars.keys().copied().zip(comb).collect());
}
// NOTE: if is empty, means no type var, append a empty subst, ok to do this?
if result.is_empty() {
result.push(VarMap::new());
}
(result, no_ranges)
};
for subst in type_var_subst_comb {
// for each instance
let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret);
let inst_args = {
args.iter()
.map(|a| FuncArg {
name: a.name,
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
default_value: a.default_value.clone(),
is_vararg: false,
})
.collect_vec()
};
let self_type = {
uninst_self_type.clone().map(|(self_type, type_vars)| {
let subst_for_self = {
let class_ty_var_ids = type_vars
.iter()
.map(|x| {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
*id
} else {
unreachable!("must be type var here");
}
})
.collect::<HashSet<_>>();
subst
.iter()
.filter_map(|(ty_var_id, ty_var_target)| {
if class_ty_var_ids.contains(ty_var_id) {
Some((*ty_var_id, *ty_var_target))
} else {
None
}
})
.collect::<VarMap>()
};
unifier.subst(self_type, &subst_for_self).unwrap_or(self_type)
})
};
let mut identifiers = {
let mut result = HashMap::new();
if self_type.is_some() {
result.insert("self".into(), IdentifierInfo::default());
}
result.extend(inst_args.iter().map(|x| (x.name, IdentifierInfo::default())));
result
};
let mut calls: HashMap<CodeLocation, CallId> = HashMap::new();
let mut inferencer = Inferencer {
top_level: ctx.as_ref(),
defined_identifiers: identifiers.clone(),
function_data: &mut FunctionData {
resolver: resolver.as_ref().unwrap().clone(),
return_type: if unifier.unioned(inst_ret, primitives_ty.none) {
None
} else {
Some(inst_ret)
},
// NOTE: allowed type vars
bound_variables: no_range_vars.clone(),
},
unifier,
variable_mapping: {
let mut result: HashMap<StrRef, Type> = HashMap::new();
if let Some(self_ty) = self_type {
result.insert("self".into(), self_ty);
}
result.extend(inst_args.iter().map(|x| (x.name, x.ty)));
result
},
primitives: primitives_ty,
virtual_checks: &mut Vec::new(),
calls: &mut calls,
in_handler: false,
};
let ast::StmtKind::FunctionDef { body, decorator_list, .. } =
ast.clone().unwrap().node
else {
unreachable!("must be function def ast")
};
if !decorator_list.is_empty() {
if matches!(&decorator_list[0].node, ExprKind::Name { id, .. } if id == &"extern".into())
let ast::StmtKind::FunctionDef { body, decorator_list, .. } =
ast.clone().unwrap().node
else {
unreachable!("must be function def ast")
};
if !decorator_list.is_empty()
&& matches!(&decorator_list[0].node,
ast::ExprKind::Name{ id, .. } if id == &"extern".into())
{
instance_to_symbol.insert(String::new(), simple_name.to_string());
continue;
}
if !decorator_list.is_empty()
&& matches!(&decorator_list[0].node,
ast::ExprKind::Name{ id, .. } if id == &"rpc".into())
{
let TopLevelDef::Function { instance_to_symbol, .. } = &mut *def.write()
else {
unreachable!()
};
instance_to_symbol.insert(String::new(), simple_name.to_string());
continue;
}
if matches!(&decorator_list[0].node, ExprKind::Name { id, .. } if id == &"rpc".into())
{
let TopLevelDef::Function { instance_to_symbol, .. } = &mut *def.write()
else {
unreachable!()
};
instance_to_symbol.insert(String::new(), simple_name.to_string());
continue;
}
if let ExprKind::Call { func, .. } = &decorator_list[0].node {
if matches!(&func.node, ExprKind::Name { id, .. } if id == &"rpc".into()) {
let TopLevelDef::Function { instance_to_symbol, .. } =
&mut *def.write()
else {
unreachable!()
};
instance_to_symbol.insert(String::new(), simple_name.to_string());
continue;
}
}
}
let fun_body =
body.into_iter()
let fun_body = body
.into_iter()
.map(|b| inferencer.fold_stmt(b))
.collect::<Result<Vec<_>, _>>()?;
let returned = inferencer.check_block(fun_body.as_slice(), &mut identifiers)?;
{
// check virtuals
let defs = ctx.definitions.read();
for (subtype, base, loc) in &*inferencer.virtual_checks {
let base_id = {
let base = inferencer.unifier.get_ty(*base);
if let TypeEnum::TObj { obj_id, .. } = &*base {
*obj_id
} else {
return Err(HashSet::from([format!(
"Base type should be a class (at {loc})"
)]));
}
};
let subtype_id = {
let ty = inferencer.unifier.get_ty(*subtype);
if let TypeEnum::TObj { obj_id, .. } = &*ty {
*obj_id
} else {
let returned = inferencer.check_block(fun_body.as_slice(), &mut identifiers)?;
{
// check virtuals
let defs = ctx.definitions.read();
for (subtype, base, loc) in &*inferencer.virtual_checks {
let base_id = {
let base = inferencer.unifier.get_ty(*base);
if let TypeEnum::TObj { obj_id, .. } = &*base {
*obj_id
} else {
return Err(HashSet::from([format!(
"Base type should be a class (at {loc})"
)]));
}
};
let subtype_id = {
let ty = inferencer.unifier.get_ty(*subtype);
if let TypeEnum::TObj { obj_id, .. } = &*ty {
*obj_id
} else {
let base_repr = inferencer.unifier.stringify(*base);
let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(HashSet::from([format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"),
]));
}
};
let subtype_entry = defs[subtype_id.0].read();
let TopLevelDef::Class { ancestors, .. } = &*subtype_entry else {
unreachable!()
};
let m = ancestors.iter()
.find(|kind| matches!(kind, TypeAnnotation::CustomClass { id, .. } if *id == base_id));
if m.is_none() {
let base_repr = inferencer.unifier.stringify(*base);
let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(HashSet::from([format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"),
]));
}
};
let subtype_entry = defs[subtype_id.0].read();
let TopLevelDef::Class { ancestors, .. } = &*subtype_entry else {
unreachable!()
};
let m = ancestors.iter()
.find(|kind| matches!(kind, TypeAnnotation::CustomClass { id, .. } if *id == base_id));
if m.is_none() {
let base_repr = inferencer.unifier.stringify(*base);
let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(HashSet::from([format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"),
]));
}
}
}
if !unifier.unioned(inst_ret, primitives_ty.none) && !returned {
let def_ast_list = &definition_ast_list;
let ret_str = unifier.internal_stringify(
inst_ret,
&mut |id| {
let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read()
else {
unreachable!("must be class id here")
};
if !unifier.unioned(inst_ret, primitives_ty.none) && !returned {
let def_ast_list = &definition_ast_list;
let ret_str = unifier.internal_stringify(
inst_ret,
&mut |id| {
let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read()
else {
unreachable!("must be class id here")
};
name.to_string()
name.to_string()
},
&mut |id| format!("typevar{id}"),
&mut None,
);
return Err(HashSet::from([format!(
"expected return type of `{}` in function `{}` (at {})",
ret_str,
name,
ast.as_ref().unwrap().location
)]));
}
instance_to_stmt.insert(
get_subst_key(
unifier,
self_type,
&subst,
Some(&vars.keys().copied().collect()),
),
FunInstance {
body: Arc::new(fun_body),
unifier_id: 0,
calls: Arc::new(calls),
subst,
},
&mut |id| format!("typevar{id}"),
&mut None,
);
return Err(HashSet::from([format!(
"expected return type of `{}` in function `{}` (at {})",
ret_str,
name,
ast.as_ref().unwrap().location
)]));
}
let TopLevelDef::Function { instance_to_stmt, .. } = &mut *def.write() else {
unreachable!()
};
instance_to_stmt.insert(
get_subst_key(
unifier,
self_type,
&subst,
Some(&vars.keys().copied().collect()),
),
FunInstance {
body: Arc::new(fun_body),
unifier_id: 0,
calls: Arc::new(calls),
subst,
},
);
}
Ok(())
};
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.builtin_num) {
if ast.is_none() {
continue;
@ -2264,59 +2161,4 @@ impl TopLevelComposer {
}
Ok(())
}
/// Step 6. Analyze and populate the types of global variables.
fn analyze_top_level_variables(&mut self) -> Result<(), HashSet<String>> {
let def_list = &self.definition_ast_list;
let temp_def_list = self.extract_def_list();
let unifier = &mut self.unifier;
let primitives_store = &self.primitives_ty;
let mut analyze = |variable_def: &Arc<RwLock<TopLevelDef>>| -> Result<_, HashSet<String>> {
let TopLevelDef::Variable { ty: dummy_ty, ty_decl, resolver, loc, .. } =
&*variable_def.read()
else {
// not top level variable def, skip
return Ok(());
};
let resolver = &**resolver.as_ref().unwrap();
if let Some(ty_decl) = ty_decl {
let ty_annotation = parse_ast_to_type_annotation_kinds(
resolver,
&temp_def_list,
unifier,
primitives_store,
ty_decl,
HashMap::new(),
)?;
let ty_from_ty_annotation = get_type_from_type_annotation_kinds(
&temp_def_list,
unifier,
primitives_store,
&ty_annotation,
&mut None,
)?;
unifier.unify(*dummy_ty, ty_from_ty_annotation).map_err(|e| {
HashSet::from([e.at(Some(loc.unwrap())).to_display(unifier).to_string()])
})?;
}
Ok(())
};
let mut errors = HashSet::new();
for (variable_def, _) in def_list.iter().skip(self.builtin_num) {
if let Err(e) = analyze(variable_def) {
errors.extend(e);
}
}
if !errors.is_empty() {
return Err(errors);
}
Ok(())
}
}

View File

@ -1,15 +1,14 @@
use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
use ast::ExprKind;
use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
use nac3parser::ast::{Constant, ExprKind, Location};
use super::{numpy::unpack_ndarray_var_tys, *};
use crate::{
symbol_resolver::SymbolValue,
typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap},
};
use super::*;
/// All primitive types and functions in nac3core.
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
@ -54,6 +53,15 @@ pub enum PrimDef {
FunNpEye,
FunNpIdentity,
// NumPy ndarray property getters
FunNpSize,
FunNpShape,
FunNpStrides,
// NumPy ndarray view functions
FunNpTranspose,
FunNpReshape,
// Miscellaneous NumPy & SciPy functions
FunNpRound,
FunNpFloor,
@ -101,8 +109,6 @@ pub enum PrimDef {
FunNpLdExp,
FunNpHypot,
FunNpNextAfter,
FunNpTranspose,
FunNpReshape,
// Linalg functions
FunNpDot,
@ -240,6 +246,15 @@ impl PrimDef {
PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None),
// NumPy NDArray property getters,
PrimDef::FunNpSize => fun("np_size", None),
PrimDef::FunNpShape => fun("np_shape", None),
PrimDef::FunNpStrides => fun("np_strides", None),
// NumPy NDArray view functions
PrimDef::FunNpTranspose => fun("np_transpose", None),
PrimDef::FunNpReshape => fun("np_reshape", None),
// Miscellaneous NumPy & SciPy functions
PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunNpFloor => fun("np_floor", None),
@ -287,8 +302,6 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpTranspose => fun("np_transpose", None),
PrimDef::FunNpReshape => fun("np_reshape", None),
// Linalg functions
PrimDef::FunNpDot => fun("np_dot", None),
@ -389,9 +402,6 @@ impl TopLevelDef {
r
}
),
TopLevelDef::Variable { name, ty, .. } => {
format!("Variable {{ name: {name:?}, ty: {:?} }}", unifier.stringify(*ty),)
}
}
}
}
@ -593,18 +603,6 @@ impl TopLevelComposer {
}
}
#[must_use]
pub fn make_top_level_variable_def(
name: String,
simple_name: StrRef,
ty: Type,
ty_decl: Option<Expr>,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
loc: Option<Location>,
) -> TopLevelDef {
TopLevelDef::Variable { name, simple_name, ty, ty_decl, resolver, loc }
}
#[must_use]
pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String {
class_name.push('.');
@ -1134,3 +1132,23 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
_ => 0,
}
}
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
/// The `ndims` must only contain 1 value.
#[must_use]
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
panic!("ndims_ty should be a TLiteral");
};
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
let ndims = values[0].clone();
u64::try_from(ndims).unwrap()
}
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
}

View File

@ -6,36 +6,36 @@ use std::{
sync::Arc,
};
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use parking_lot::RwLock;
use nac3parser::ast::{self, Expr, Location, Stmt, StrRef};
use super::codegen::CodeGenContext;
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{
FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap,
};
use crate::{
codegen::{CodeGenContext, CodeGenerator},
codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{
CallId, FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, TypeVarId, Unifier,
VarMap,
},
type_inferencer::CodeLocation,
typedef::{CallId, TypeVarId},
},
};
use composer::*;
use type_annotation::*;
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use nac3parser::ast::{self, Location, Stmt, StrRef};
use parking_lot::RwLock;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize);
pub mod builtins;
pub mod composer;
pub mod helper;
pub mod numpy;
pub mod type_annotation;
use composer::*;
use type_annotation::*;
#[cfg(test)]
mod test;
pub mod type_annotation;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize);
type GenCallCallback = dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>,
@ -148,25 +148,6 @@ pub enum TopLevelDef {
/// Definition location.
loc: Option<Location>,
},
Variable {
/// Qualified name of the global variable, should be unique globally.
name: String,
/// Simple name, the same as in method/function definition.
simple_name: StrRef,
/// Type of the global variable.
ty: Type,
/// The declared type of the global variable, or [`None`] if no type annotation is provided.
ty_decl: Option<Expr>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Definition location.
loc: Option<Location>,
},
}
pub struct TopLevelContext {

View File

@ -1,10 +1,11 @@
use itertools::Itertools;
use super::helper::PrimDef;
use crate::typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
use crate::{
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
},
};
use itertools::Itertools;
/// Creates a `ndarray` [`Type`] with the given type arguments.
///

View File

@ -1,23 +1,21 @@
use std::{collections::HashMap, sync::Arc};
use indoc::indoc;
use parking_lot::Mutex;
use test_case::test_case;
use nac3parser::{
ast::{fold::Fold, FileName},
parser::parse_program,
};
use super::{helper::PrimDef, DefinitionId, *};
use super::*;
use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::into_var_map;
use crate::{
codegen::CodeGenContext,
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::DefinitionId,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{into_var_map, Type, Unifier},
typedef::{Type, Unifier},
},
};
use indoc::indoc;
use nac3parser::ast::FileName;
use nac3parser::{ast::fold::Fold, parser::parse_program};
use parking_lot::Mutex;
use std::{collections::HashMap, sync::Arc};
use test_case::test_case;
struct ResolverInternal {
id_to_type: Mutex<HashMap<StrRef, Type>>,
@ -64,7 +62,6 @@ impl SymbolResolver for Resolver {
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}

View File

@ -1,12 +1,9 @@
use strum::IntoEnumIterator;
use super::*;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::{PrimDef, PrimDefDetails};
use crate::typecheck::typedef::VarMap;
use nac3parser::ast::Constant;
use super::{
helper::{PrimDef, PrimDefDetails},
*,
};
use crate::{symbol_resolver::SymbolValue, typecheck::typedef::VarMap};
use strum::IntoEnumIterator;
#[derive(Clone, Debug)]
pub enum TypeAnnotation {

View File

@ -1,19 +1,13 @@
use std::{
collections::{HashMap, HashSet},
iter::once,
};
use crate::toplevel::helper::PrimDef;
use super::type_inferencer::Inferencer;
use super::typedef::{Type, TypeEnum};
use nac3parser::ast::{
self, Constant, Expr, ExprKind,
Operator::{LShift, RShift},
Stmt, StmtKind, StrRef,
};
use super::{
type_inferencer::{DeclarationSource, IdentifierInfo, Inferencer},
typedef::{Type, TypeEnum},
};
use crate::toplevel::helper::PrimDef;
use std::{collections::HashSet, iter::once};
impl<'a> Inferencer<'a> {
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> {
@ -27,29 +21,15 @@ impl<'a> Inferencer<'a> {
fn check_pattern(
&mut self,
pattern: &Expr<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> {
match &pattern.node {
ExprKind::Name { id, .. } if id == &"none".into() => {
Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)]))
}
ExprKind::Name { id, .. } => {
// If `id` refers to a declared symbol, reject this assignment if it is used in the
// context of an (implicit) global variable
if let Some(id_info) = defined_identifiers.get(id) {
if matches!(
id_info.source,
DeclarationSource::Global { is_explicit: Some(false) }
) {
return Err(HashSet::from([format!(
"cannot access local variable '{id}' before it is declared (at {})",
pattern.location
)]));
}
}
if !defined_identifiers.contains_key(id) {
defined_identifiers.insert(*id, IdentifierInfo::default());
if !defined_identifiers.contains(id) {
defined_identifiers.insert(*id);
}
self.should_have_value(pattern)?;
Ok(())
@ -89,7 +69,7 @@ impl<'a> Inferencer<'a> {
fn check_expr(
&mut self,
expr: &Expr<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> {
// there are some cases where the custom field is None
if let Some(ty) = &expr.custom {
@ -110,7 +90,7 @@ impl<'a> Inferencer<'a> {
return Ok(());
}
self.should_have_value(expr)?;
if !defined_identifiers.contains_key(id) {
if !defined_identifiers.contains(id) {
match self.function_data.resolver.get_symbol_type(
self.unifier,
&self.top_level.definitions.read(),
@ -118,22 +98,7 @@ impl<'a> Inferencer<'a> {
*id,
) {
Ok(_) => {
let is_global = self.is_id_global(*id);
defined_identifiers.insert(
*id,
IdentifierInfo {
source: match is_global {
Some(true) => {
DeclarationSource::Global { is_explicit: Some(false) }
}
Some(false) => {
DeclarationSource::Global { is_explicit: None }
}
None => DeclarationSource::Local,
},
},
);
self.defined_identifiers.insert(*id);
}
Err(e) => {
return Err(HashSet::from([format!(
@ -206,7 +171,9 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = defined_identifiers.clone();
for arg in &args.args {
// TODO: should we check the types here?
defined_identifiers.entry(arg.node.arg).or_default();
if !defined_identifiers.contains(&arg.node.arg) {
defined_identifiers.insert(arg.node.arg);
}
}
self.check_expr(body, &mut defined_identifiers)?;
}
@ -269,7 +236,7 @@ impl<'a> Inferencer<'a> {
fn check_stmt(
&mut self,
stmt: &Stmt<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, HashSet<String>> {
match &stmt.node {
StmtKind::For { target, iter, body, orelse, .. } => {
@ -295,11 +262,9 @@ impl<'a> Inferencer<'a> {
let body_returned = self.check_block(body, &mut body_identifiers)?;
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
for ident in body_identifiers.keys() {
if !defined_identifiers.contains_key(ident)
&& orelse_identifiers.contains_key(ident)
{
defined_identifiers.insert(*ident, IdentifierInfo::default());
for ident in &body_identifiers {
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
defined_identifiers.insert(*ident);
}
}
Ok(body_returned && orelse_returned)
@ -330,7 +295,7 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = defined_identifiers.clone();
let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node;
if let Some(name) = name {
defined_identifiers.insert(*name, IdentifierInfo::default());
defined_identifiers.insert(*name);
}
self.check_block(body, &mut defined_identifiers)?;
}
@ -394,44 +359,6 @@ impl<'a> Inferencer<'a> {
}
Ok(true)
}
StmtKind::Global { names, .. } => {
for id in names {
if let Some(id_info) = defined_identifiers.get(id) {
if id_info.source == DeclarationSource::Local {
return Err(HashSet::from([format!(
"name '{id}' is referenced prior to global declaration at {}",
stmt.location,
)]));
}
continue;
}
match self.function_data.resolver.get_symbol_type(
self.unifier,
&self.top_level.definitions.read(),
self.primitives,
*id,
) {
Ok(_) => {
defined_identifiers.insert(
*id,
IdentifierInfo {
source: DeclarationSource::Global { is_explicit: Some(true) },
},
);
}
Err(e) => {
return Err(HashSet::from([format!(
"type error at identifier `{}` ({}) at {}",
id, e, stmt.location
)]))
}
}
}
Ok(false)
}
// break, raise, etc.
_ => Ok(false),
}
@ -440,7 +367,7 @@ impl<'a> Inferencer<'a> {
pub fn check_block(
&mut self,
block: &[Stmt<Option<Type>>],
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, HashSet<String>> {
let mut ret = false;
for stmt in block {

View File

@ -1,21 +1,17 @@
use std::{cmp::max, collections::HashMap, rc::Rc};
use itertools::{iproduct, Itertools};
use strum::IntoEnumIterator;
use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop};
use super::{
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
};
use crate::{
symbol_resolver::SymbolValue,
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
},
};
use itertools::{iproduct, Itertools};
use nac3parser::ast::StrRef;
use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::cmp::max;
use std::collections::HashMap;
use std::rc::Rc;
use strum::IntoEnumIterator;
/// The variant of a binary operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]

View File

@ -1,13 +1,14 @@
use std::{collections::HashMap, fmt::Display};
use std::collections::HashMap;
use std::fmt::Display;
use itertools::Itertools;
use nac3parser::ast::{Cmpop, Location, StrRef};
use crate::typecheck::{magic_methods::HasOpInfo, typedef::TypeEnum};
use super::{
magic_methods::{Binop, HasOpInfo},
typedef::{RecordKey, Type, TypeEnum, Unifier},
magic_methods::Binop,
typedef::{RecordKey, Type, Unifier},
};
use itertools::Itertools;
use nac3parser::ast::{Cmpop, Location, StrRef};
#[derive(Debug, Clone)]
pub enum TypeErrorKind {

View File

@ -1,36 +1,32 @@
use std::{
cell::RefCell,
cmp::max,
collections::{HashMap, HashSet},
convert::{From, TryInto},
iter::once,
sync::Arc,
};
use itertools::{izip, Itertools};
use nac3parser::ast::{
self,
fold::{self, Fold},
Arguments, Comprehension, ExprContext, ExprKind, Ident, Located, Location, StrRef,
};
use std::cmp::max;
use std::collections::{HashMap, HashSet};
use std::convert::{From, TryInto};
use std::iter::{self, once};
use std::{cell::RefCell, sync::Arc};
use super::{
magic_methods::*,
type_error::{TypeError, TypeErrorKind},
typedef::{
into_var_map, iter_type_vars, Call, CallId, FunSignature, FuncArg, Mapping, OperatorInfo,
into_var_map, iter_type_vars, Call, CallId, FunSignature, FuncArg, OperatorInfo,
RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap,
},
};
use crate::toplevel::type_annotation::TypeAnnotation;
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
type_annotation::TypeAnnotation,
TopLevelContext, TopLevelDef,
},
typecheck::typedef::Mapping,
};
use itertools::{izip, Itertools};
use nac3parser::ast::{
self,
fold::{self, Fold},
Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef,
};
#[cfg(test)]
@ -88,40 +84,6 @@ impl PrimitiveStore {
}
}
/// The location where an identifier declaration refers to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeclarationSource {
/// Local scope.
Local,
/// Global scope.
Global {
/// Whether the identifier is declared by the use of `global` statement. This field is
/// [`None`] if the identifier does not refer to a variable.
is_explicit: Option<bool>,
},
}
/// Information regarding a defined identifier.
#[derive(Clone, Copy, Debug)]
pub struct IdentifierInfo {
/// Whether this identifier refers to a global variable.
pub source: DeclarationSource,
}
impl Default for IdentifierInfo {
fn default() -> Self {
IdentifierInfo { source: DeclarationSource::Local }
}
}
impl IdentifierInfo {
#[must_use]
pub fn new() -> IdentifierInfo {
IdentifierInfo::default()
}
}
pub struct FunctionData {
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
pub return_type: Option<Type>,
@ -130,7 +92,7 @@ pub struct FunctionData {
pub struct Inferencer<'a> {
pub top_level: &'a TopLevelContext,
pub defined_identifiers: HashMap<StrRef, IdentifierInfo>,
pub defined_identifiers: HashSet<StrRef>,
pub function_data: &'a mut FunctionData,
pub unifier: &'a mut Unifier,
pub primitives: &'a PrimitiveStore,
@ -262,7 +224,9 @@ impl<'a> Fold<()> for Inferencer<'a> {
handler.location,
));
if let Some(name) = name {
self.defined_identifiers.entry(name).or_default();
if !self.defined_identifiers.contains(&name) {
self.defined_identifiers.insert(name);
}
if let Some(old_typ) = self.variable_mapping.insert(name, typ) {
let loc = handler.location;
self.unifier.unify(old_typ, typ).map_err(|e| {
@ -414,7 +378,6 @@ impl<'a> Fold<()> for Inferencer<'a> {
| ast::StmtKind::Continue { .. }
| ast::StmtKind::Expr { .. }
| ast::StmtKind::For { .. }
| ast::StmtKind::Global { .. }
| ast::StmtKind::Pass { .. }
| ast::StmtKind::Try { .. } => {}
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
@ -586,7 +549,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
unreachable!("must be tobj")
}
} else {
if !self.defined_identifiers.contains_key(id) {
if !self.defined_identifiers.contains(id) {
match self.function_data.resolver.get_symbol_type(
self.unifier,
&self.top_level.definitions.read(),
@ -594,22 +557,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
*id,
) {
Ok(_) => {
let is_global = self.is_id_global(*id);
self.defined_identifiers.insert(
*id,
IdentifierInfo {
source: match is_global {
Some(true) => DeclarationSource::Global {
is_explicit: Some(false),
},
Some(false) => {
DeclarationSource::Global { is_explicit: None }
}
None => DeclarationSource::Local,
},
},
);
self.defined_identifiers.insert(*id);
}
Err(e) => {
return report_error(
@ -674,8 +622,8 @@ impl<'a> Inferencer<'a> {
fn infer_pattern<T>(&mut self, pattern: &ast::Expr<T>) -> Result<(), InferenceError> {
match &pattern.node {
ExprKind::Name { id, .. } => {
if !self.defined_identifiers.contains_key(id) {
self.defined_identifiers.insert(*id, IdentifierInfo::default());
if !self.defined_identifiers.contains(id) {
self.defined_identifiers.insert(*id);
}
Ok(())
}
@ -784,8 +732,8 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = self.defined_identifiers.clone();
for arg in &args.args {
let name = &arg.node.arg;
if !defined_identifiers.contains_key(name) {
defined_identifiers.insert(*name, IdentifierInfo::default());
if !defined_identifiers.contains(name) {
defined_identifiers.insert(*name);
}
}
let fn_args: Vec<_> = args
@ -1235,6 +1183,45 @@ impl<'a> Inferencer<'a> {
}));
}
if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 {
let ndarray = self.fold_expr(args.remove(0))?;
let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
// Make a tuple of size `ndims` full of int32 (TODO: Make it usize)
let ret_ty = TypeEnum::TTuple {
ty: iter::repeat(self.primitives.int32).take(ndims as usize).collect_vec(),
is_vararg_ctx: false,
};
let ret_ty = self.unifier.add_ty(ret_ty);
let func_ty = TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "a".into(),
default_value: None,
ty: ndarray.custom.unwrap(),
is_vararg: false,
}],
ret: ret_ty,
vars: VarMap::new(),
});
let func_ty = self.unifier.add_ty(func_ty);
return Ok(Some(Located {
location,
custom: Some(ret_ty),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(func_ty),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![ndarray],
keywords: vec![],
},
}));
}
if id == &"np_dot".into() {
let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?;
@ -1602,29 +1589,36 @@ impl<'a> Inferencer<'a> {
}
// 2-argument ndarray n-dimensional creation functions
if id == &"np_full".into() && args.len() == 2 {
// Parse arguments
let shape_expr = args.remove(0);
let (ndims, shape) =
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
let ExprKind::List { elts, .. } = &args[0].node else {
return report_error(
format!(
"Expected List literal for first argument of {id}, got {}",
args[0].node.name()
)
.as_str(),
args[0].location,
);
};
let fill_value = self.fold_expr(args.remove(0))?;
let ndims = elts.len() as u64;
// Build the return type
let dtype = fill_value.custom.unwrap();
let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?;
let ty = arg1.custom.unwrap();
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(dtype), Some(ndims));
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "shape".into(),
ty: shape.custom.unwrap(),
ty: arg0.custom.unwrap(),
default_value: None,
is_vararg: false,
},
FuncArg {
name: "fill_value".into(),
ty: fill_value.custom.unwrap(),
ty: arg1.custom.unwrap(),
default_value: None,
is_vararg: false,
},
@ -1642,7 +1636,7 @@ impl<'a> Inferencer<'a> {
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![shape, fill_value],
args: vec![arg0, arg1],
keywords: vec![],
},
}));
@ -2685,22 +2679,4 @@ impl<'a> Inferencer<'a> {
self.constrain(body.custom.unwrap(), orelse.custom.unwrap(), &body.location)?;
Ok(body.custom.unwrap())
}
/// Determines whether the given `id` refers to a global symbol.
///
/// Returns `Some(true)` if `id` refers to a global variable, `Some(false)` if `id` refers to a
/// class/function, and `None` if `id` refers to a local symbol.
pub(super) fn is_id_global(&self, id: Ident) -> Option<bool> {
self.top_level
.definitions
.read()
.iter()
.map(|def| match *def.read() {
TopLevelDef::Class { name, .. } => (name, false),
TopLevelDef::Function { simple_name, .. } => (simple_name, false),
TopLevelDef::Variable { simple_name, .. } => (simple_name, true),
})
.find(|(global, _)| global == &id)
.map(|(_, has_explicit_prop)| has_explicit_prop)
}
}

View File

@ -1,19 +1,17 @@
use std::iter::zip;
use indexmap::IndexMap;
use indoc::indoc;
use parking_lot::RwLock;
use test_case::test_case;
use nac3parser::{ast::FileName, parser::parse_program};
use super::super::{magic_methods::with_fields, typedef::*};
use super::*;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
codegen::CodeGenContext,
symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
typecheck::{magic_methods::with_fields, typedef::*},
};
use indexmap::IndexMap;
use indoc::indoc;
use nac3parser::ast::FileName;
use nac3parser::parser::parse_program;
use parking_lot::RwLock;
use std::iter::zip;
use test_case::test_case;
struct Resolver {
id_to_type: HashMap<StrRef, Type>,
@ -43,7 +41,6 @@ impl SymbolResolver for Resolver {
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
@ -520,7 +517,7 @@ impl TestEnvironment {
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls,
defined_identifiers: HashMap::default(),
defined_identifiers: HashSet::default(),
in_handler: false,
}
}
@ -596,9 +593,8 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s
println!("source:\n{source}");
let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashMap<_, _> =
env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect();
defined_identifiers.insert("virtual".into(), IdentifierInfo::default());
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect();
defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, FileName::default()).unwrap();
@ -743,9 +739,8 @@ fn test_primitive_magic_methods(source: &str, mapping: &HashMap<&str, &str>) {
println!("source:\n{source}");
let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashMap<_, _> =
env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect();
defined_identifiers.insert("virtual".into(), IdentifierInfo::default());
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect();
defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, FileName::default()).unwrap();

View File

@ -1,28 +1,21 @@
use std::{
borrow::Cow,
cell::RefCell,
collections::{HashMap, HashSet},
fmt::{self, Display},
iter::{repeat, zip},
rc::Rc,
sync::{Arc, Mutex},
};
use super::magic_methods::{Binop, HasOpInfo};
use super::type_error::{TypeError, TypeErrorKind};
use super::unification_table::{UnificationKey, UnificationTable};
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
use crate::typecheck::magic_methods::OpInfo;
use crate::typecheck::type_inferencer::PrimitiveStore;
use indexmap::IndexMap;
use itertools::{repeat_n, Itertools};
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
use super::{
magic_methods::{Binop, HasOpInfo, OpInfo},
type_error::{TypeError, TypeErrorKind},
type_inferencer::PrimitiveStore,
unification_table::{UnificationKey, UnificationTable},
};
use crate::{
symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, DefinitionId, TopLevelContext, TopLevelDef},
};
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::{self, Display};
use std::iter::{repeat, zip};
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet};
#[cfg(test)]
mod test;
@ -677,8 +670,8 @@ impl Unifier {
let num_args = posargs.len() + kwargs.len();
// Now we check the arguments against the parameters,
// and depending on what `call_info` is, we might change how `unify_call()` behaves
// to improve user error messages when type checking fails.
// and depending on what `call_info` is, we might change how the behavior `unify_call()`
// in hopes to improve user error messages when type checking fails.
match operator_info {
Some(OperatorInfo::IsBinaryOp { self_type, operator }) => {
// The call is written in the form of (say) `a + b`.

View File

@ -1,12 +1,10 @@
use std::collections::HashMap;
use super::super::magic_methods::with_fields;
use super::*;
use indoc::indoc;
use itertools::Itertools;
use std::collections::HashMap;
use test_case::test_case;
use super::*;
use crate::typecheck::magic_methods::with_fields;
impl Unifier {
/// Check whether two types are equal.
fn eq(&mut self, a: Type, b: Type) -> bool {

View File

@ -21,12 +21,13 @@
clippy::wildcard_imports
)]
use std::{collections::HashMap, mem, ptr, slice, str};
use byteorder::{ByteOrder, LittleEndian};
use dwarf::*;
use elf::*;
use std::collections::HashMap;
use std::{mem, ptr, slice, str};
extern crate byteorder;
use byteorder::{ByteOrder, LittleEndian};
mod dwarf;
mod elf;

View File

@ -8,15 +8,15 @@ license = "MIT"
edition = "2021"
[build-dependencies]
lalrpop = "0.22"
lalrpop = "0.20"
[dependencies]
nac3ast = { path = "../nac3ast" }
lalrpop-util = "0.22"
lalrpop-util = "0.20"
log = "0.4"
unic-emoji-char = "0.9"
unic-ucd-ident = "0.9"
unicode_names2 = "1.3"
unicode_names2 = "1.2"
phf = { version = "0.11", features = ["macros"] }
ahash = "0.8"

View File

@ -1,10 +1,8 @@
use crate::{
ast::{Ident, Location},
error::*,
token::Tok,
};
use crate::ast::Ident;
use crate::ast::Location;
use crate::error::*;
use crate::token::Tok;
use lalrpop_util::ParseError;
use nac3ast::*;
pub fn make_config_comment(

View File

@ -1,11 +1,12 @@
//! Define internal parse error types
//! The goal is to provide a matching and a safe error API, maksing errors from LALR
use std::error::Error;
use std::fmt;
use lalrpop_util::ParseError as LalrpopError;
use crate::{ast::Location, token::Tok};
use crate::ast::Location;
use crate::token::Tok;
use std::error::Error;
use std::fmt;
/// Represents an error during lexical scanning.
#[derive(Debug, PartialEq)]

View File

@ -1,11 +1,12 @@
use std::{iter, mem, str};
use std::iter;
use std::mem;
use std::str;
use crate::ast::{Constant, ConversionFlag, Expr, ExprKind, Location};
use crate::error::{FStringError, FStringErrorType, ParseError};
use crate::parser::parse_expression;
use self::FStringErrorType::*;
use crate::{
ast::{Constant, ConversionFlag, Expr, ExprKind, Location},
error::{FStringError, FStringErrorType, ParseError},
parser::parse_expression,
};
struct FStringParser<'a> {
chars: iter::Peekable<str::Chars<'a>>,

View File

@ -1,11 +1,8 @@
use ahash::RandomState;
use std::collections::HashSet;
use ahash::RandomState;
use crate::{
ast,
error::{LexicalError, LexicalErrorType},
};
use crate::ast;
use crate::error::{LexicalError, LexicalErrorType};
pub struct ArgumentList {
pub args: Vec<ast::Expr>,

View File

@ -1,16 +1,16 @@
//! This module takes care of lexing python source text.
//!
//! This means source code is translated into separate tokens.
use std::{char, cmp::Ordering, num::IntErrorKind, str::FromStr};
use unic_emoji_char::is_emoji_presentation;
use unic_ucd_ident::{is_xid_continue, is_xid_start};
pub use super::token::Tok;
use crate::{
ast::{FileName, Location},
error::{LexicalError, LexicalErrorType},
};
use crate::ast::{FileName, Location};
use crate::error::{LexicalError, LexicalErrorType};
use std::char;
use std::cmp::Ordering;
use std::num::IntErrorKind;
use std::str::FromStr;
use unic_emoji_char::is_emoji_presentation;
use unic_ucd_ident::{is_xid_continue, is_xid_start};
#[derive(Clone, Copy, PartialEq, Debug, Default)]
struct IndentationLevel {

View File

@ -5,16 +5,14 @@
//! parse a whole program, a single statement, or a single
//! expression.
use nac3ast::Location;
use std::iter;
use nac3ast::Location;
use crate::ast::{self, FileName};
use crate::error::ParseError;
use crate::lexer;
pub use crate::mode::Mode;
use crate::{
ast::{self, FileName},
error::ParseError,
lexer, python,
};
use crate::python;
/*
* Parse python code.

View File

@ -1,8 +1,7 @@
//! Different token definitions.
//! Loosely based on token.h from CPython source:
use std::fmt::{self, Write};
use crate::ast;
use std::fmt::{self, Write};
/// Python source code can be tokenized in a sequence of these tokens.
#[derive(Clone, Debug, PartialEq)]

View File

@ -9,8 +9,14 @@ no-escape-analysis = ["nac3core/no-escape-analysis"]
[dependencies]
parking_lot = "0.12"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" }
[dependencies.clap]
version = "4.5"
features = ["derive"]
[dependencies.inkwell]
version = "0.4"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]

View File

@ -179,6 +179,15 @@ def patch(module):
module.np_identity = np.identity
module.np_array = np.array
# NumPy NDArray view functions
module.np_transpose = np.transpose
module.np_reshape = np.reshape
# NumPy NDArray property getters
module.np_size = np.size
module.np_shape = np.shape
module.np_strides = lambda ndarray: ndarray.strides
# NumPy Math functions
module.np_isnan = np.isnan
module.np_isinf = np.isinf
@ -218,8 +227,6 @@ def patch(module):
module.np_ldexp = np.ldexp
module.np_hypot = np.hypot
module.np_nextafter = np.nextafter
module.np_transpose = np.transpose
module.np_reshape = np.reshape
# SciPy Math functions
module.sp_spec_erf = special.erf

View File

@ -1,31 +0,0 @@
@extern
def output_int32(x: int32):
...
@extern
def output_int64(x: int64):
...
X: int32 = 0
Y = int64(1)
def f():
global X, Y
X = 1
Y = int64(2)
def run() -> int32:
global X, Y
output_int32(X)
output_int64(Y)
f()
output_int32(X)
output_int64(Y)
X = 0
Y = int64(0)
output_int32(X)
output_int64(Y)
return 0

View File

@ -114,22 +114,12 @@ def test_ndarray_ones():
n: ndarray[float, 1] = np_ones([1])
output_ndarray_float_1(n)
dim = (1,)
n_tup: ndarray[float, 1] = np_ones(dim)
output_ndarray_float_1(n_tup)
def test_ndarray_full():
n_float: ndarray[float, 1] = np_full([1], 2.0)
output_ndarray_float_1(n_float)
n_i32: ndarray[int32, 1] = np_full([1], 2)
output_ndarray_int32_1(n_i32)
dim = (1,)
n_float_tup: ndarray[float, 1] = np_full(dim, 2.0)
output_ndarray_float_1(n_float_tup)
n_i32_tup: ndarray[int32, 1] = np_full(dim, 2)
output_ndarray_int32_1(n_i32_tup)
def test_ndarray_eye():
n: ndarray[float, 2] = np_eye(2)
output_ndarray_float_2(n)

View File

@ -1,14 +1,5 @@
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use parking_lot::{Mutex, RwLock};
use nac3core::{
codegen::{CodeGenContext, CodeGenerator},
inkwell::{module::Linkage, values::BasicValue},
nac3parser::ast::{self, StrRef},
codegen::CodeGenContext,
symbol_resolver::{SymbolResolver, SymbolValue, ValueEnum},
toplevel::{DefinitionId, TopLevelDef},
typecheck::{
@ -16,6 +7,10 @@ use nac3core::{
typedef::{Type, Unifier},
},
};
use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::{collections::HashMap, sync::Arc};
pub struct ResolverInternal {
pub id_to_type: Mutex<HashMap<StrRef, Type>>,
@ -50,51 +45,20 @@ impl SymbolResolver for Resolver {
fn get_symbol_type(
&self,
unifier: &mut Unifier,
_: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
_: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String> {
self.0
.id_to_type
.lock()
.get(&str)
.copied()
.or_else(|| {
self.0
.module_globals
.lock()
.get(&str)
.cloned()
.map(|v| v.get_type(primitives, unifier))
})
.ok_or(format!("cannot get type of {str}"))
self.0.id_to_type.lock().get(&str).copied().ok_or(format!("cannot get type of {str}"))
}
fn get_symbol_value<'ctx>(
&self,
str: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>> {
self.0.module_globals.lock().get(&str).cloned().map(|v| {
ctx.module
.get_global(&str.to_string())
.unwrap_or_else(|| {
let ty = v.get_type(&ctx.primitives, &mut ctx.unifier);
let init_val = ctx.gen_symbol_val(generator, &v, ty);
let llvm_ty = init_val.get_type();
let global = ctx.module.add_global(llvm_ty, None, &str.to_string());
global.set_linkage(Linkage::LinkOnceAny);
global.set_initializer(&init_val);
global
})
.as_basic_value_enum()
.into()
})
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {

View File

@ -8,30 +8,17 @@
#![warn(clippy::pedantic)]
#![allow(clippy::too_many_lines, clippy::wildcard_imports)]
use std::{
collections::{HashMap, HashSet},
fs,
num::NonZeroUsize,
path::Path,
sync::Arc,
};
use clap::Parser;
use parking_lot::{Mutex, RwLock};
use inkwell::context::Context;
use inkwell::{
memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*,
OptimizationLevel,
};
use nac3core::{
codegen::{
concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions,
CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
inkwell::{
memory_buffer::MemoryBuffer, module::Linkage, passes::PassBuilderOptions,
support::is_multithreaded, targets::*, OptimizationLevel,
},
nac3parser::{
ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
parser,
},
symbol_resolver::SymbolResolver,
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
@ -44,10 +31,17 @@ use nac3core::{
typedef::{FunSignature, Type, Unifier, VarMap},
},
};
use basic_symbol_resolver::*;
use nac3parser::{
ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
parser,
};
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
mod basic_symbol_resolver;
use basic_symbol_resolver::*;
/// Command-line argument parser definition.
#[derive(Parser)]
@ -175,49 +169,46 @@ fn handle_typevar_definition(
fn handle_assignment_pattern(
targets: &[Expr],
value: &Expr,
resolver: Arc<dyn SymbolResolver + Send + Sync>,
resolver: &(dyn SymbolResolver + Send + Sync),
internal_resolver: &ResolverInternal,
composer: &mut TopLevelComposer,
def_list: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
) -> Result<(), String> {
if targets.len() == 1 {
let target = &targets[0];
match &target.node {
match &targets[0].node {
ExprKind::Name { id, .. } => {
let def_list = composer.extract_def_list();
let unifier = &mut composer.unifier;
let primitives = &composer.primitives_ty;
if let Ok(var) =
handle_typevar_definition(value, &*resolver, &def_list, unifier, primitives)
handle_typevar_definition(value, resolver, def_list, unifier, primitives)
{
internal_resolver.add_id_type(*id, var);
Ok(())
} else if let Ok(val) = parse_parameter_default_value(value, &*resolver) {
} else if let Ok(val) = parse_parameter_default_value(value, resolver) {
internal_resolver.add_module_global(*id, val);
let (name, def_id, _) = composer
.register_top_level_var(
*id,
None,
Some(resolver.clone()),
"__main__",
target.location,
)
.unwrap();
internal_resolver.add_id_def(name, def_id);
Ok(())
} else {
Err(format!("fails to evaluate this expression `{:?}` as a constant or generic parameter at {}",
target.node,
target.location,
targets[0].node,
targets[0].location,
))
}
}
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
handle_assignment_pattern(elts, value, resolver, internal_resolver, composer)?;
handle_assignment_pattern(
elts,
value,
resolver,
internal_resolver,
def_list,
unifier,
primitives,
)?;
Ok(())
}
_ => Err(format!("assignment to {target:?} is not supported at {}", target.location)),
_ => Err(format!(
"assignment to {:?} is not supported at {}",
targets[0], targets[0].location
)),
}
} else {
match &value.node {
@ -227,9 +218,11 @@ fn handle_assignment_pattern(
handle_assignment_pattern(
std::slice::from_ref(tar),
val,
resolver.clone(),
resolver,
internal_resolver,
composer,
def_list,
unifier,
primitives,
)?;
}
Ok(())
@ -247,39 +240,6 @@ fn handle_assignment_pattern(
}
}
fn handle_global_var(
target: &Expr,
value: Option<&Expr>,
resolver: &Arc<dyn SymbolResolver + Send + Sync>,
internal_resolver: &ResolverInternal,
composer: &mut TopLevelComposer,
) -> Result<(), String> {
let ExprKind::Name { id, .. } = target.node else {
return Err(format!(
"global variable declaration must be an identifier (at {})",
target.location,
));
};
let Some(value) = value else {
return Err(format!("global variable `{id}` must be initialized in its definition"));
};
if let Ok(val) = parse_parameter_default_value(value, &**resolver) {
internal_resolver.add_module_global(id, val);
let (name, def_id, _) = composer
.register_top_level_var(id, None, Some(resolver.clone()), "__main__", target.location)
.unwrap();
internal_resolver.add_id_def(name, def_id);
Ok(())
} else {
Err(format!(
"failed to evaluate this expression `{:?}` as a constant at {}",
target.node, target.location,
))
}
}
fn main() {
let cli = CommandLineArgs::parse();
let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } =
@ -320,19 +280,22 @@ fn main() {
reloc_mode: RelocMode::PIC,
..host_target_machine
};
let target_machine = target_machine_options
.create_target_machine(opt_level)
.expect("couldn't create target machine");
let context = nac3core::inkwell::context::Context::create();
let size_t =
context.ptr_sized_int_type(&target_machine.get_target_data(), None).get_bit_width();
let size_t = Context::create()
.ptr_sized_int_type(
&target_machine_options
.create_target_machine(opt_level)
.map(|tm| tm.get_target_data())
.unwrap(),
None,
)
.get_bit_width();
let program = match fs::read_to_string(file_name.clone()) {
Ok(program) => program,
Err(err) => {
panic!("Cannot open input file: {err}");
println!("Cannot open input file: {err}");
return;
}
};
@ -350,6 +313,8 @@ fn main() {
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
let context = inkwell::context::Context::create();
// Process IRRT
let irrt = load_irrt(&context, resolver.as_ref());
if emit_llvm {
@ -362,29 +327,22 @@ fn main() {
for stmt in parser_result {
match &stmt.node {
StmtKind::Assign { targets, value, .. } => {
let def_list = composer.extract_def_list();
let unifier = &mut composer.unifier;
let primitives = &composer.primitives_ty;
if let Err(err) = handle_assignment_pattern(
targets,
value,
resolver.clone(),
resolver.as_ref(),
internal_resolver.as_ref(),
&mut composer,
&def_list,
unifier,
primitives,
) {
panic!("{err}");
eprintln!("{err}");
return;
}
}
StmtKind::AnnAssign { target, value, .. } => {
if let Err(err) = handle_global_var(
target,
value.as_ref().map(Box::as_ref),
&resolver,
internal_resolver.as_ref(),
&mut composer,
) {
panic!("{err}");
}
}
// allow (and ignore) "from __future__ import annotations"
StmtKind::ImportFrom { module, names, .. }
if module == &Some("__future__".into())
@ -495,12 +453,17 @@ fn main() {
let mut function_iter = main.get_first_function();
while let Some(func) = function_iter {
if func.count_basic_blocks() > 0 && func.get_name().to_str().unwrap() != "run" {
func.set_linkage(Linkage::Private);
func.set_linkage(inkwell::module::Linkage::Private);
}
function_iter = func.get_next_function();
}
// Optimize `main`
let target_machine = llvm_options
.target
.create_target_machine(llvm_options.opt_level)
.expect("couldn't create target machine");
let pass_options = PassBuilderOptions::create();
pass_options.set_merge_functions(true);
let passes = format!("default<O{}>", opt_level as u32);

View File

@ -21,6 +21,6 @@ build() {
}
package() {
mkdir -p $pkgdir/clang64/lib/python3.12/site-packages
cp ${srcdir}/nac3artiq.pyd $pkgdir/clang64/lib/python3.12/site-packages
mkdir -p $pkgdir/clang64/lib/python3.11/site-packages
cp ${srcdir}/nac3artiq.pyd $pkgdir/clang64/lib/python3.11/site-packages
}

View File

@ -21,10 +21,10 @@ let
text =
''
implementation=CPython
version=3.12
version=3.11
shared=true
abi3=false
lib_name=python3.12
lib_name=python3.11
lib_dir=${msys2-env}/clang64/lib
pointer_width=64
build_flags=WITH_THREAD

View File

@ -6,11 +6,11 @@ cd $(dirname $0)
MSYS2DIR=`pwd`/msys2
mkdir -p $MSYS2DIR/var/lib/pacman $MSYS2DIR/msys/etc
curl -L https://repo.msys2.org/msys/x86_64/pacman-mirrors-20240523-1-any.pkg.tar.zst | tar xvf - -C $MSYS2DIR --zstd
curl -L https://mirror.msys2.org/msys/x86_64/pacman-mirrors-20220205-1-any.pkg.tar.zst | tar xvf - -C $MSYS2DIR --zstd
curl -L https://raw.githubusercontent.com/msys2/MSYS2-packages/master/pacman/pacman.conf | sed "s|SigLevel = Required|SigLevel = Never|g" | sed "s|/etc/pacman.d|$MSYS2DIR/etc/pacman.d|g" > $MSYS2DIR/etc/pacman.conf
fakeroot pacman --root $MSYS2DIR --config $MSYS2DIR/etc/pacman.conf -Syy
pacman --root $MSYS2DIR --config $MSYS2DIR/etc/pacman.conf --cachedir $MSYS2DIR/msys/cache -Sp mingw-w64-clang-x86_64-rust mingw-w64-clang-x86_64-cmake mingw-w64-clang-x86_64-ninja mingw-w64-clang-x86_64-python-numpy mingw-w64-clang-x86_64-python-setuptools > $MSYS2DIR/packages.txt
pacman --root $MSYS2DIR --config $MSYS2DIR/etc/pacman.conf --cachedir $MSYS2DIR/msys/cache -Sp mingw-w64-clang-x86_64-rust mingw-w64-clang-x86_64-cmake mingw-w64-clang-x86_64-ninja mingw-w64-clang-x86_64-python3.11 mingw-w64-clang-x86_64-python-numpy mingw-w64-clang-x86_64-python-setuptools > $MSYS2DIR/packages.txt
echo "{ pkgs } : [" > msys2_packages.nix
while read package; do

View File

@ -1,15 +1,15 @@
{ pkgs } : [
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-18.1.8-2-any.pkg.tar.zst";
sha256 = "0f9m76dx40iy794nfks0360gvjhdg6yngb2lyhwp4xd76rn5081m";
name = "mingw-w64-clang-x86_64-libunwind-18.1.8-2-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-18.1.8-1-any.pkg.tar.zst";
sha256 = "1v8zkfcbf1ga2ndpd1j0dwv5s1rassxs2b5pjhcsmqwjcvczba1m";
name = "mingw-w64-clang-x86_64-libunwind-18.1.8-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-18.1.8-2-any.pkg.tar.zst";
sha256 = "17savj9wys9my2ji7vyba7wwqkvzdjwnkb3k4858wxrjbzbfa6lk";
name = "mingw-w64-clang-x86_64-libc++-18.1.8-2-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-18.1.8-1-any.pkg.tar.zst";
sha256 = "0mfd8wrmgx12j5gf354j7pk1l3lg9ykxvq75xdk3jipsr6hbn846";
name = "mingw-w64-clang-x86_64-libc++-18.1.8-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -43,9 +43,9 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libxml2-2.12.9-1-any.pkg.tar.zst";
sha256 = "0cjz2vj9yz6k5xj601cp0yk631rrr0z94ciamwqrvclb0yhakf25";
name = "mingw-w64-clang-x86_64-libxml2-2.12.9-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libxml2-2.12.8-1-any.pkg.tar.zst";
sha256 = "1imipb0dz4w6x4n9arn22imyzzcwdlf2cqxvn7irqq7w9by6fy0b";
name = "mingw-w64-clang-x86_64-libxml2-2.12.8-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -79,15 +79,15 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r250.gc6bf4bdf6-1-any.pkg.tar.zst";
sha256 = "0163jzjlvq7inpafy3h48pkwag3ysk6x56xm84yfcz5q52fnfzq5";
name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r250.gc6bf4bdf6-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
sha256 = "1h3cdcajz29iq7vja908kkijz1vb9xn0f7w1lw1ima0q0zhinv4q";
name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r250.gc6bf4bdf6-1-any.pkg.tar.zst";
sha256 = "00cn1mi29mfys7qy4hvgnjd0smqvnkdn3ibnrr6a3wy1h2vaykgq";
name = "mingw-w64-clang-x86_64-crt-git-12.0.0.r250.gc6bf4bdf6-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
sha256 = "15kamyi3b0j6f5zxin4i2jgzjc7lzvwl4z5cz3dx0i8hg91aq0n7";
name = "mingw-w64-clang-x86_64-crt-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -97,15 +97,15 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r250.gc6bf4bdf6-1-any.pkg.tar.zst";
sha256 = "1zkzqqd31xpkv817wja3qssjjx891bsdxw07037hv2sk0qr4ffn9";
name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r250.gc6bf4bdf6-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
sha256 = "0qdvgs1rmjjhn9klf9kpw7l0ydz36rr5fasn4q9gpby2lgl11bkb";
name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r250.gc6bf4bdf6-1-any.pkg.tar.zst";
sha256 = "02ynia88ad3l03r08nyldmnajwqkyxcjd191lyamkbj4d6zck323";
name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r250.gc6bf4bdf6-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
sha256 = "0rh2mn078cifcmr4as4k57jxjln5lbnsmpx47h9d0s5d2i8sf2rc";
name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -115,15 +115,15 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.33.1-1-any.pkg.tar.zst";
sha256 = "14r6jjsvfbapbkv2zqp2yglva4vz4srzkgk7f186ri3kcafjspgq";
name = "mingw-w64-clang-x86_64-c-ares-1.33.1-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.29.0-1-any.pkg.tar.zst";
sha256 = "01xg1h1a8kda0kq2921w25ybvm1ms7lfdzday0hv93f3myq7briq";
name = "mingw-w64-clang-x86_64-c-ares-1.29.0-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-brotli-1.1.0-2-any.pkg.tar.zst";
sha256 = "1q01lz9lcyrjmkhv9rddgjazmk7warlcmwhc4qkq9y6h0yfsb71n";
name = "mingw-w64-clang-x86_64-brotli-1.1.0-2-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-brotli-1.1.0-1-any.pkg.tar.zst";
sha256 = "113mha41q53cx0hw13cq1xdf7zbsd58sh8cl1cd7xzg1q69n60w2";
name = "mingw-w64-clang-x86_64-brotli-1.1.0-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -163,9 +163,9 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openssl-3.3.2-1-any.pkg.tar.zst";
sha256 = "1djgpcz447yvhdy1yq5wh8l5d0821izxklx9afyszbw0pbr7f24y";
name = "mingw-w64-clang-x86_64-openssl-3.3.2-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openssl-3.3.1-1-any.pkg.tar.zst";
sha256 = "0ywhwm4kw3qjzv0872qwabnsq2rzbmqjb9m69q3fykjl0m9gigsa";
name = "mingw-w64-clang-x86_64-openssl-3.3.1-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -175,39 +175,39 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp2-1.63.0-1-any.pkg.tar.zst";
sha256 = "0lfqrlmapsc7ilxjmhr7hxi578vclqlhpqimbvzq0c70c0iwk864";
name = "mingw-w64-clang-x86_64-nghttp2-1.63.0-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp2-1.61.0-2-any.pkg.tar.zst";
sha256 = "07bkk98126gy4k6lb9rrqqnzjfz9j2rsr5dzr2djmzdkw0h4dr95";
name = "mingw-w64-clang-x86_64-nghttp2-1.61.0-2-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.5.0-1-any.pkg.tar.zst";
sha256 = "1ljl9kdasf91bxkqcmbbjchp5g00ahv8jn2zab38899z6j3x43nz";
name = "mingw-w64-clang-x86_64-nghttp3-1.5.0-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.4.0-1-any.pkg.tar.zst";
sha256 = "007w2252nzn274j4wjc1vf56xyzzh5vg3blj1hil7mlmffgvc923";
name = "mingw-w64-clang-x86_64-nghttp3-1.4.0-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-curl-8.9.1-2-any.pkg.tar.zst";
sha256 = "1zr6kgqp9i4qqrfckh3kfmz4x1cwv4xis9sfqsx7xji88priax64";
name = "mingw-w64-clang-x86_64-curl-8.9.1-2-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-curl-8.8.0-10-any.pkg.tar.zst";
sha256 = "024z5b1achkf448gxqy1i3gcw371x54kfl6igv08b5wb3rrw35a4";
name = "mingw-w64-clang-x86_64-curl-8.8.0-10-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.80.1-1-any.pkg.tar.zst";
sha256 = "1dm4vlrfi9m6xl09zpn0yjr7qcjjr4x738z1rjfwysfnc0awq4x8";
name = "mingw-w64-clang-x86_64-rust-1.80.1-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.79.0-1-any.pkg.tar.zst";
sha256 = "0i7s88hj8m4920xifkj7i4b4sq8cqq7p5cypp3jqx3dc44pwm19a";
name = "mingw-w64-clang-x86_64-rust-1.79.0-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cppdap-1.65-1-any.pkg.tar.zst";
sha256 = "0phhwkcqp30dsyj5vr6w99sgm1jfm5rzg0w5x5mv9md4x7lm9lmh";
name = "mingw-w64-clang-x86_64-cppdap-1.65-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-pkgconf-1~2.2.0-1-any.pkg.tar.zst";
sha256 = "1y44ijg3y8p80f1yn9972nshrnyrd06a9sh984ajhxg8bi8s5xyl";
name = "mingw-w64-clang-x86_64-pkgconf-12.2.0-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-expat-2.6.3-1-any.pkg.tar.zst";
sha256 = "19xfl1q78q1k8j0lr5aspcf668pmfg01fgib73zq7ff7y5y5fcyi";
name = "mingw-w64-clang-x86_64-expat-2.6.3-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-expat-2.6.2-1-any.pkg.tar.zst";
sha256 = "0kj1vzjh3qh7d2g47avlgk7a6j4nc62111hy1m63jwq0alc01k38";
name = "mingw-w64-clang-x86_64-expat-2.6.2-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -229,9 +229,9 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lz4-1.10.0-1-any.pkg.tar.zst";
sha256 = "0kznnw9z9zqxkmn8qbypm2rpsfaapbgls1ks3zzpfnfjz9cpw8py";
name = "mingw-w64-clang-x86_64-lz4-1.10.0-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lz4-1.9.4-1-any.pkg.tar.zst";
sha256 = "0nn7cy25j53q5ckkx4n4f77w00xdwwf5wjswm374shvvs58nlln0";
name = "mingw-w64-clang-x86_64-lz4-1.9.4-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -264,12 +264,6 @@
name = "mingw-w64-clang-x86_64-ninja-1.12.1-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-pkgconf-1~2.3.0-1-any.pkg.tar.zst";
sha256 = "15i7x6akkgs7aa7aa804k93p2iipnvygsy7z8hsafskka3h150fa";
name = "mingw-w64-clang-x86_64-pkgconf-12.3.0-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst";
sha256 = "1ysbxirpfr0yf7pvyps75lnwc897w2a2kcid3nb4j6ilw6n64jmc";
@ -277,9 +271,9 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.30.3-1-any.pkg.tar.zst";
sha256 = "0fjwf6xxzli6rcsbzr1razldmm538ibkyf5kw132lpaz5wma9bj8";
name = "mingw-w64-clang-x86_64-cmake-3.30.3-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.30.0-1-any.pkg.tar.zst";
sha256 = "07b7132hwhiqrf0l2lgw3g4zw9i2lln3kqc9kg2qijvkapbkmwqb";
name = "mingw-w64-clang-x86_64-cmake-3.30.0-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -301,9 +295,9 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-readline-8.2.013-1-any.pkg.tar.zst";
sha256 = "0pv1ypqfgm4mimzr0amq9anr1ysqmzrwv6gfk7rrlzhihadknsvr";
name = "mingw-w64-clang-x86_64-readline-8.2.013-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-readline-8.2.010-1-any.pkg.tar.zst";
sha256 = "1s47pd5iz8y3hspsxn4pnp0v3m05ccia40v5nfvx0rmwgvcaz82v";
name = "mingw-w64-clang-x86_64-readline-8.2.010-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -313,9 +307,9 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-sqlite3-3.46.1-1-any.pkg.tar.zst";
sha256 = "1axplxyjnaz411qzjjqwbj55fbrh4akq3plm2p1sx64jp844xpyq";
name = "mingw-w64-clang-x86_64-sqlite3-3.46.1-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-sqlite3-3.46.0-1-any.pkg.tar.zst";
sha256 = "0q676i2z5nr4c71jnd4z5qz9xa1xryl0cpi84w74yvd0p4qiz7y2";
name = "mingw-w64-clang-x86_64-sqlite3-3.46.0-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
@ -330,12 +324,6 @@
name = "mingw-w64-clang-x86_64-tzdata-2024a-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python3.12-3.12.1-2-any.pkg.tar.zst";
sha256 = "0wmd39wl9z237w093a7c6hl5pclca9yvwxn0kiw6i2njk3sjv51a";
name = "mingw-w64-clang-x86_64-python3.12-3.12.1-2-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.11.9-1-any.pkg.tar.zst";
sha256 = "0ah1idjqxg7jc07a1gz9z766rjjd0f0c6ri4hpcsimsrbj1zjd3c";
@ -349,20 +337,20 @@
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.28-1-any.pkg.tar.zst";
sha256 = "1pskcqc1lg9p8m8rk7bw3mz7mn7vw5fpl7zxa23bhjn02p5b79qq";
name = "mingw-w64-clang-x86_64-openblas-0.3.28-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.27-1-any.pkg.tar.zst";
sha256 = "06ygz1wa488wqvmxbn74b0fyan4wf3lb6kbwfampgikd1gijww2k";
name = "mingw-w64-clang-x86_64-openblas-0.3.27-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-2.0.1-1-any.pkg.tar.zst";
sha256 = "0ks6q8v58h4wmr2pzsjl2xm4f63g0psvfm0jwlz24mqfxp8gqfcc";
name = "mingw-w64-clang-x86_64-python-numpy-2.0.1-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-1.26.4-1-any.pkg.tar.zst";
sha256 = "00h0ap954cjwlsc3p01fjwy7s3nlzs90v0kmnrzxm0rljmvn4jkf";
name = "mingw-w64-clang-x86_64-python-numpy-1.26.4-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-74.0.0-1-any.pkg.tar.zst";
sha256 = "0xc95z5jzzjf5lw35bs4yn5rlwkrkmrh78yi5rranqpz5nn0wsa0";
name = "mingw-w64-clang-x86_64-python-setuptools-74.0.0-1-any.pkg.tar.zst";
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-70.2.0-1-any.pkg.tar.zst";
sha256 = "1q4r9bg2hn3jmshvq81xm5zvy9wn35yf0z2ayksrkwph1zzdkvkm";
name = "mingw-w64-clang-x86_64-python-setuptools-70.2.0-1-any.pkg.tar.zst";
})
]