diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..188308d --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[target.x86_64-unknown-linux-gnu] +rustflags = ["-C", "link-arg=-fuse-ld=lld"] diff --git a/Cargo.lock b/Cargo.lock index ce59893..9435fed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "ahash" @@ -9,7 +9,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy", @@ -65,11 +65,12 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.6" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", + "once_cell", "windows-sys 0.59.0", ] @@ -105,9 +106,9 @@ checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "block-buffer" @@ -126,9 +127,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.2.4" +version = "1.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" +checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" dependencies = [ "shlex", ] @@ 
-141,9 +142,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.23" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", "clap_derive", @@ -151,9 +152,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.23" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", @@ -163,14 +164,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.18" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -187,21 +188,21 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "console" -version = "0.15.8" +version = "0.15.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" dependencies = [ "encode_unicode", - "lazy_static", "libc", - "windows-sys 0.52.0", + "once_cell", + "windows-sys 0.59.0", ] [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = 
"59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -221,18 +222,18 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -249,18 +250,18 @@ dependencies = [ [[package]] name = "crossbeam-queue" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crypto-common" @@ -305,9 +306,9 @@ dependencies = [ [[package]] name = "encode_unicode" -version = "0.3.6" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] name = "equivalent" @@ -333,9 +334,15 @@ checksum = 
"37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fixedbitset" -version = "0.4.2" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" [[package]] name = "fxhash" @@ -373,14 +380,26 @@ checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets", ] [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "hashbrown" @@ -388,20 +407,14 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash", -] - [[package]] name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" 
+dependencies = [ + "foldhash", +] [[package]] name = "heck" @@ -417,11 +430,11 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "home" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -436,9 +449,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -472,7 +485,7 @@ checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -497,9 +510,9 @@ checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "itertools" -version = "0.13.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" dependencies = [ "either", ] @@ -521,9 +534,9 @@ dependencies = [ [[package]] name = "lalrpop" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06093b57658c723a21da679530e061a8c25340fa5a6f98e313b542268c7e2a1f" +checksum = "7047a26de42016abf8f181b46b398aef0b77ad46711df41847f6ed869a2a1d5b" dependencies = [ "ascii-canvas", "bit-set", @@ -543,9 +556,9 @@ dependencies = [ [[package]] name = "lalrpop-util" -version = "0.22.0" +version = "0.22.1" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feee752d43abd0f4807a921958ab4131f692a44d4d599733d4419c5d586176ce" +checksum = "e8d05b3fe34b8bd562c338db725dfa9beb9451a48f65f129ccb9538b48d2c93b" dependencies = [ "regex-automata", "rustversion", @@ -559,9 +572,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.168" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libloading" @@ -581,9 +594,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "llvm-sys" @@ -610,9 +623,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "memchr" @@ -655,7 +668,7 @@ name = "nac3core" version = "0.1.0" dependencies = [ "crossbeam", - "indexmap 2.7.0", + "indexmap 2.7.1", "indoc", "inkwell", "insta", @@ -663,7 +676,6 @@ dependencies = [ "nac3core_derive", "nac3parser", "parking_lot", - "rayon", "regex", "strum", "strum_macros", @@ -678,7 +690,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", "trybuild", ] @@ -751,55 +763,55 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.6.5" +version = "0.7.1" source 
= "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.7.0", + "indexmap 2.7.1", ] [[package]] name = "phf" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ "phf_macros", - "phf_shared 0.11.2", + "phf_shared 0.11.3", ] [[package]] name = "phf_codegen" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ "phf_generator", - "phf_shared 0.11.2", + "phf_shared 0.11.3", ] [[package]] name = "phf_generator" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ - "phf_shared 0.11.2", + "phf_shared 0.11.3", "rand", ] [[package]] name = "phf_macros" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b" +checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" dependencies = [ "phf_generator", - "phf_shared 0.11.2", + "phf_shared 0.11.3", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -808,16 +820,16 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" dependencies = [ - "siphasher", + "siphasher 0.3.11", ] [[package]] name = "phf_shared" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ - "siphasher", + "siphasher 1.0.1", ] [[package]] @@ -873,9 +885,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] @@ -927,7 +939,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -940,14 +952,14 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -979,27 +991,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", -] - -[[package]] -name = "rayon" -version = "1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.12.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", + "getrandom 0.2.15", ] [[package]] @@ -1049,9 +1041,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.42" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags", "errno", @@ -1062,15 +1054,15 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "same-file" @@ -1089,35 +1081,35 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "serde" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] [[package]] name = 
"serde_derive" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] name = "serde_json" -version = "1.0.133" +version = "1.0.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +checksum = "930cfb6e6abf99298aaad7d29abbef7a9999a9a8806a40088f55f0dcec03146b" dependencies = [ "itoa", "memchr", @@ -1164,9 +1156,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "similar" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "siphasher" @@ -1174,6 +1166,12 @@ version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "smallvec" version = "1.13.2" @@ -1182,12 +1180,11 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "string-interner" -version = "0.17.0" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c6a0d765f5807e98a091107bae0a56ea3799f66a5de47b2c84c94a39c09974e" +checksum = "1a3275464d7a9f2d4cac57c89c2ef96a8524dba2864c8d6f82e3980baf136f9b" dependencies = [ - "cfg-if", - "hashbrown 0.14.5", + 
"hashbrown 0.15.2", "serde", ] @@ -1226,7 +1223,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1242,9 +1239,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.90" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -1265,12 +1262,13 @@ checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078" [[package]] name = "tempfile" -version = "3.14.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ "cfg-if", "fastrand", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys 0.59.0", @@ -1278,9 +1276,9 @@ dependencies = [ [[package]] name = "term" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4df4175de05129f31b80458c6df371a15e7fc3fd367272e6bf938e5c351c7ea0" +checksum = "a3bb6001afcea98122260987f8b7b5da969ecad46dbf0b5453702f776b491a41" dependencies = [ "home", "windows-sys 0.52.0", @@ -1325,7 +1323,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1355,7 +1353,7 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap 2.7.0", + "indexmap 2.7.1", "serde", "serde_spanned", "toml_datetime", @@ -1364,9 +1362,9 @@ dependencies = [ [[package]] name = "trybuild" -version = "1.0.101" +version = "1.0.103" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dcd332a5496c026f1e14b7f3d2b7bd98e509660c04239c58b0ba38a12daded4" +checksum = "b812699e0c4f813b872b373a4471717d9eb550da14b311058a4d9cf4173cbca6" dependencies = [ "dissimilar", "glob", @@ -1438,9 +1436,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-width" @@ -1510,6 +1508,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "winapi-util" version = "0.1.9" @@ -1603,13 +1610,22 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.20" +version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "ad699df48212c6cc6eb4435f35500ac6fd3b9913324f938aea302022ce19d310" dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + [[package]] name = "yaml-rust" version = "0.4.5" @@ -1637,5 +1653,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", 
+ "syn 2.0.96", ] diff --git a/flake.lock b/flake.lock index 67bb80e..7e36e53 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1733940404, - "narHash": "sha256-Pj39hSoUA86ZePPF/UXiYHHM7hMIkios8TYG29kQT4g=", + "lastModified": 1738680400, + "narHash": "sha256-ooLh+XW8jfa+91F1nhf9OF7qhuA/y1ChLx6lXDNeY5U=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "5d67ea6b4b63378b9c13be21e2ec9d1afc921713", + "rev": "799ba5bffed04ced7067a91798353d360788b30d", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index a20e1f1..51551c7 100644 --- a/flake.nix +++ b/flake.nix @@ -41,7 +41,7 @@ lockFile = ./Cargo.lock; }; passthru.cargoLock = cargoLock; - nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ]; + nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out pkgs.llvmPackages_14.bintools llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ]; checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ]; checkPhase = @@ -85,7 +85,7 @@ name = "nac3artiq-instrumented"; src = self; inherit (nac3artiq) cargoLock; - nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-instrumented ]; + nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt pkgs.llvmPackages_14.bintools llvm-nac3-instrumented ]; buildInputs = [ pkgs.python3 llvm-nac3-instrumented ]; cargoBuildFlags = [ "--package" "nac3artiq" "--features" "init-llvm-profile" ]; doCheck = false; @@ -120,6 +120,7 @@ buildInputs = [ (python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ])) pkgs.llvmPackages_14.llvm.out + pkgs.llvmPackages_14.bintools ]; phases = [ "buildPhase" "installPhase" ]; buildPhase = @@ -147,7 +148,7 @@ name = "nac3artiq-pgo"; src = self; inherit (nac3artiq) 
cargoLock; - nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-pgo ]; + nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt pkgs.llvmPackages_14.bintools llvm-nac3-pgo ]; buildInputs = [ pkgs.python3 llvm-nac3-pgo ]; cargoBuildFlags = [ "--package" "nac3artiq" ]; cargoTestFlags = [ "--package" "nac3ast" "--package" "nac3parser" "--package" "nac3core" "--package" "nac3artiq" ]; @@ -168,7 +169,7 @@ buildInputs = with pkgs; [ # build dependencies packages.x86_64-linux.llvm-nac3 - (pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos + (pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out llvmPackages_14.bintools # for running nac3standalone demos packages.x86_64-linux.llvm-tools-irrt cargo rustc diff --git a/nac3artiq/Cargo.toml b/nac3artiq/Cargo.toml index 2da812e..fa80465 100644 --- a/nac3artiq/Cargo.toml +++ b/nac3artiq/Cargo.toml @@ -9,10 +9,10 @@ name = "nac3artiq" crate-type = ["cdylib"] [dependencies] -itertools = "0.13" +itertools = "0.14" pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] } parking_lot = "0.12" -tempfile = "3.13" +tempfile = "3.16" nac3core = { path = "../nac3core" } nac3ld = { path = "../nac3ld" } diff --git a/nac3artiq/demo/embedding_map.py b/nac3artiq/demo/embedding_map.py deleted file mode 100644 index a43af69..0000000 --- a/nac3artiq/demo/embedding_map.py +++ /dev/null @@ -1,39 +0,0 @@ -class EmbeddingMap: - def __init__(self): - self.object_inverse_map = {} - self.object_map = {} - self.string_map = {} - self.string_reverse_map = {} - self.function_map = {} - self.attributes_writeback = [] - - def store_function(self, key, fun): - self.function_map[key] = fun - return key - - def store_object(self, obj): - obj_id = id(obj) - if obj_id in self.object_inverse_map: - return self.object_inverse_map[obj_id] - key = len(self.object_map) + 1 - self.object_map[key] = obj - self.object_inverse_map[obj_id] = 
key - return key - - def store_str(self, s): - if s in self.string_reverse_map: - return self.string_reverse_map[s] - key = len(self.string_map) - self.string_map[key] = s - self.string_reverse_map[s] = key - return key - - def retrieve_function(self, key): - return self.function_map[key] - - def retrieve_object(self, key): - return self.object_map[key] - - def retrieve_str(self, key): - return self.string_map[key] - diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 62d32cc..fef018b 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -6,7 +6,6 @@ from typing import Generic, TypeVar from math import floor, ceil import nac3artiq -from embedding_map import EmbeddingMap __all__ = [ @@ -193,6 +192,46 @@ def print_int64(x: int64): raise NotImplementedError("syscall not simulated") +class EmbeddingMap: + def __init__(self): + self.object_inverse_map = {} + self.object_map = {} + self.string_map = {} + self.string_reverse_map = {} + self.function_map = {} + self.attributes_writeback = [] + + def store_function(self, key, fun): + self.function_map[key] = fun + return key + + def store_object(self, obj): + obj_id = id(obj) + if obj_id in self.object_inverse_map: + return self.object_inverse_map[obj_id] + key = len(self.object_map) + 1 + self.object_map[key] = obj + self.object_inverse_map[obj_id] = key + return key + + def store_str(self, s): + if s in self.string_reverse_map: + return self.string_reverse_map[s] + key = len(self.string_map) + self.string_map[key] = s + self.string_reverse_map[s] = key + return key + + def retrieve_function(self, key): + return self.function_map[key] + + def retrieve_object(self, key): + return self.object_map[key] + + def retrieve_str(self, key): + return self.string_map[key] + + @nac3 class Core: ref_period: KernelInvariant[float] diff --git a/nac3artiq/demo/module.py b/nac3artiq/demo/module.py new file mode 100644 index 0000000..58f9245 --- /dev/null +++ b/nac3artiq/demo/module.py @@ 
-0,0 +1,26 @@ +from min_artiq import * +from numpy import int32 + +# Global Variable Definition +X: Kernel[int32] = 1 + +# TopLevelFunction Defintion +@kernel +def display_X(): + print_int32(X) + +# TopLevel Class Definition +@nac3 +class A: + @kernel + def __init__(self): + self.set_x(1) + + @kernel + def set_x(self, new_val: int32): + global X + X = new_val + + @kernel + def get_X(self) -> int32: + return X diff --git a/nac3artiq/demo/module_support.py b/nac3artiq/demo/module_support.py new file mode 100644 index 0000000..78ef656 --- /dev/null +++ b/nac3artiq/demo/module_support.py @@ -0,0 +1,26 @@ +from min_artiq import * +import module as module_definition + +@nac3 +class TestModuleSupport: + core: KernelInvariant[Core] + + def __init__(self): + self.core = Core() + + @kernel + def run(self): + # Accessing classes + obj = module_definition.A() + obj.get_X() + obj.set_x(2) + + # Calling functions + module_definition.display_X() + + # Updating global variables + module_definition.X = 9 + module_definition.display_X() + +if __name__ == "__main__": + TestModuleSupport().run() \ No newline at end of file diff --git a/nac3artiq/demo/numpy_primitives_decay.py b/nac3artiq/demo/numpy_primitives_decay.py new file mode 100644 index 0000000..957d363 --- /dev/null +++ b/nac3artiq/demo/numpy_primitives_decay.py @@ -0,0 +1,29 @@ +from min_artiq import * +import numpy +from numpy import int32 + + +@nac3 +class NumpyBoolDecay: + core: KernelInvariant[Core] + np_true: KernelInvariant[bool] + np_false: KernelInvariant[bool] + np_int: KernelInvariant[int32] + np_float: KernelInvariant[float] + np_str: KernelInvariant[str] + + def __init__(self): + self.core = Core() + self.np_true = numpy.True_ + self.np_false = numpy.False_ + self.np_int = numpy.int32(0) + self.np_float = numpy.float64(0.0) + self.np_str = numpy.str_("") + + @kernel + def run(self): + pass + + +if __name__ == "__main__": + NumpyBoolDecay().run() diff --git a/nac3artiq/demo/string_attribute_issue337.py 
b/nac3artiq/demo/string_attribute_issue337.py deleted file mode 100644 index 9749462..0000000 --- a/nac3artiq/demo/string_attribute_issue337.py +++ /dev/null @@ -1,24 +0,0 @@ -from min_artiq import * -from numpy import int32 - - -@nac3 -class Demo: - core: KernelInvariant[Core] - attr1: KernelInvariant[str] - attr2: KernelInvariant[int32] - - - def __init__(self): - self.core = Core() - self.attr2 = 32 - self.attr1 = "SAMPLE" - - @kernel - def run(self): - print_int32(self.attr2) - self.attr1 - - -if __name__ == "__main__": - Demo().run() diff --git a/nac3artiq/demo/support_class_attr_issue102.py b/nac3artiq/demo/support_class_attr_issue102.py deleted file mode 100644 index 1b93144..0000000 --- a/nac3artiq/demo/support_class_attr_issue102.py +++ /dev/null @@ -1,40 +0,0 @@ -from min_artiq import * -from numpy import int32 - - -@nac3 -class Demo: - attr1: KernelInvariant[int32] = 2 - attr2: int32 = 4 - attr3: Kernel[int32] - - @kernel - def __init__(self): - self.attr3 = 8 - - -@nac3 -class NAC3Devices: - core: KernelInvariant[Core] - attr4: KernelInvariant[int32] = 16 - - def __init__(self): - self.core = Core() - - @kernel - def run(self): - Demo.attr1 # Supported - # Demo.attr2 # Field not accessible on Kernel - # Demo.attr3 # Only attributes can be accessed in this way - # Demo.attr1 = 2 # Attributes are immutable - - self.attr4 # Attributes can be accessed within class - - obj = Demo() - obj.attr1 # Attributes can be accessed by class objects - - NAC3Devices.attr4 # Attributes accessible for classes without __init__ - - -if __name__ == "__main__": - NAC3Devices().run() diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 9f0211e..d05158b 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -19,9 +19,9 @@ use nac3core::{ llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, type_aligned_alloca, - 
types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, RangeType}, values::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, RangeValue, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, @@ -29,6 +29,7 @@ use nac3core::{ inkwell::{ context::Context, module::Linkage, + targets::TargetMachine, types::{BasicType, IntType}, values::{BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, OptimizationLevel, @@ -87,13 +88,13 @@ pub struct ArtiqCodeGenerator<'a> { impl<'a> ArtiqCodeGenerator<'a> { pub fn new( name: String, - size_t: u32, + size_t: IntType<'_>, timeline: &'a (dyn TimeFns + Sync), ) -> ArtiqCodeGenerator<'a> { - assert!(size_t == 32 || size_t == 64); + assert!(matches!(size_t.get_bit_width(), 32 | 64)); ArtiqCodeGenerator { name, - size_t, + size_t: size_t.get_bit_width(), name_counter: 0, start: None, end: None, @@ -102,6 +103,17 @@ impl<'a> ArtiqCodeGenerator<'a> { } } + #[must_use] + pub fn with_target_machine( + name: String, + ctx: &Context, + target_machine: &TargetMachine, + timeline: &'a (dyn TimeFns + Sync), + ) -> ArtiqCodeGenerator<'a> { + let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None); + Self::new(name, llvm_usize, timeline) + } + /// If the generator is currently in a direct-`parallel` block context, emits IR that resets the /// position of the timeline to the initial timeline position before entering the `parallel` /// block. 
@@ -162,7 +174,7 @@ impl<'a> ArtiqCodeGenerator<'a> { } } -impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { +impl CodeGenerator for ArtiqCodeGenerator<'_> { fn get_name(&self) -> &str { &self.name } @@ -464,13 +476,13 @@ fn format_rpc_arg<'ctx>( // libproto_artiq: NDArray = [data[..], dim_sz[..]] let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let dtype = ctx.get_llvm_type(generator, elem_ty); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)) - .map_value(arg.into_pointer_value(), None); + let ndarray = NDArrayType::new(ctx, dtype, ndims) + .map_pointer_value(arg.into_pointer_value(), None); let ndims = llvm_usize.const_int(ndims, false); @@ -549,7 +561,7 @@ fn format_rpc_ret<'ctx>( let llvm_i32 = ctx.ctx.i32_type(); let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { @@ -602,7 +614,7 @@ fn format_rpc_ret<'ctx>( let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); let dtype_llvm = ctx.get_llvm_type(generator, dtype); let ndims = extract_ndims(&ctx.unifier, ndims); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, Some(ndims)) + let ndarray = NDArrayType::new(ctx, dtype_llvm, ndims) .construct_uninitialized(generator, ctx, None); // NOTE: Current content of `ndarray`: @@ -690,7 +702,7 @@ fn format_rpc_ret<'ctx>( // debug_assert(nelems * sizeof(T) >= ndarray_nbytes) if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let num_elements = ndarray.size(generator, ctx); + let num_elements = 
ndarray.size(ctx); let expected_ndarray_nbytes = ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap(); @@ -754,7 +766,7 @@ fn format_rpc_ret<'ctx>( ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.position_at_end(tail_bb); - ndarray.as_base_value().into() + ndarray.as_abi_value(ctx).into() } _ => { @@ -802,7 +814,7 @@ pub fn rpc_codegen_callback_fn<'ctx>( ) -> Result>, String> { let int8 = ctx.ctx.i8_type(); let int32 = ctx.ctx.i32_type(); - let size_type = generator.get_size_type(ctx.ctx); + let size_type = ctx.get_size_type(); let ptr_type = int8.ptr_type(AddressSpace::default()); let tag_ptr_type = ctx .ctx @@ -1054,6 +1066,34 @@ pub fn attributes_writeback<'ctx>( )); } } + TypeEnum::TModule { attributes, .. } => { + let mut fields = Vec::new(); + let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); + + for (name, (field_ty, is_method)) in attributes { + if *is_method { + continue; + } + if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { + fields.push(name.to_string()); + let (index, _) = ctx.get_attr_index(ty, *name); + values.push(( + *field_ty, + ctx.build_gep_and_load( + obj.into_pointer_value(), + &[zero, int32.const_int(index as u64, false)], + None, + ), + )); + } + } + if !fields.is_empty() { + let pydict = PyDict::new(py); + pydict.set_item("obj", val)?; + pydict.set_item("fields", fields)?; + host_attributes.append(pydict)?; + } + } _ => {} } } @@ -1169,7 +1209,7 @@ fn polymorphic_print<'ctx>( let llvm_i32 = ctx.ctx.i32_type(); let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let suffix = suffix.unwrap_or_default(); @@ -1357,7 +1397,7 @@ fn polymorphic_print<'ctx>( let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty) - .map_value(value.into_pointer_value(), None); + .map_pointer_value(value.into_pointer_value(), None); let num_0 = 
llvm_usize.const_zero(); @@ -1405,7 +1445,7 @@ fn polymorphic_print<'ctx>( fmt.push_str("range("); flush(ctx, generator, &mut fmt, &mut args); - let val = RangeValue::from_pointer_value(value.into_pointer_value(), None); + let val = RangeType::new(ctx).map_pointer_value(value.into_pointer_value(), None); let (start, stop, step) = destructure_range(ctx, val); @@ -1519,7 +1559,7 @@ pub fn call_rtio_log_impl<'ctx>( /// Generates a call to `core_log`. pub fn gen_core_log<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - obj: &Option<(Type, ValueEnum<'ctx>)>, + obj: Option<&(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), args: &[(Option, ValueEnum<'ctx>)], generator: &mut dyn CodeGenerator, @@ -1536,7 +1576,7 @@ pub fn gen_core_log<'ctx>( /// Generates a call to `rtio_log`. pub fn gen_rtio_log<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - obj: &Option<(Type, ValueEnum<'ctx>)>, + obj: Option<&(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), args: &[(Option, ValueEnum<'ctx>)], generator: &mut dyn CodeGenerator, diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index ca2f2f1..ba6c4fa 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -43,7 +43,7 @@ use nac3core::{ OptimizationLevel, }, nac3parser::{ - ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, + ast::{self, Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, parser::parse_program, }, symbol_resolver::SymbolResolver, @@ -78,14 +78,62 @@ enum Isa { } impl Isa { - /// Returns the number of bits in `size_t` for the [`Isa`]. - fn get_size_type(self) -> u32 { - if self == Isa::Host { - 64u32 - } else { - 32u32 + /// Returns the [`TargetTriple`] used for compiling to this ISA. 
+ pub fn get_llvm_target_triple(self) -> TargetTriple { + match self { + Isa::Host => TargetMachine::get_default_triple(), + Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"), + Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"), } } + + /// Returns the [`String`] representing the target CPU used for compiling to this ISA. + pub fn get_llvm_target_cpu(self) -> String { + match self { + Isa::Host => TargetMachine::get_host_cpu_name().to_string(), + Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(), + Isa::CortexA9 => "cortex-a9".to_string(), + } + } + + /// Returns the [`String`] representing the target features used for compiling to this ISA. + pub fn get_llvm_target_features(self) -> String { + match self { + Isa::Host => TargetMachine::get_host_cpu_features().to_string(), + Isa::RiscV32G => "+a,+m,+f,+d".to_string(), + Isa::RiscV32IMA => "+a,+m".to_string(), + Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(), + } + } + + /// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine + /// options used for compiling to this ISA. + pub fn get_llvm_target_options(self) -> CodeGenTargetMachineOptions { + CodeGenTargetMachineOptions { + triple: self.get_llvm_target_triple().as_str().to_string_lossy().into_owned(), + cpu: self.get_llvm_target_cpu(), + features: self.get_llvm_target_features(), + reloc_mode: RelocMode::PIC, + ..CodeGenTargetMachineOptions::from_host() + } + } + + /// Returns an instance of [`TargetMachine`] used in compiling and linking of a program of this + /// ISA. + pub fn create_llvm_target_machine(self, opt_level: OptimizationLevel) -> TargetMachine { + self.get_llvm_target_options() + .create_target_machine(opt_level) + .expect("couldn't create target machine") + } + + /// Returns the number of bits in `size_t` for this ISA. 
+ fn get_size_type(self, ctx: &Context) -> u32 { + ctx.ptr_sized_int_type( + &self.create_llvm_target_machine(OptimizationLevel::Default).get_target_data(), + None, + ) + .get_bit_width() + } } #[derive(Clone)] @@ -111,6 +159,7 @@ pub struct PrimitivePythonId { generic_alias: (u64, u64), virtual_id: u64, option: u64, + module: u64, } type TopLevelComponent = (Stmt, String, PyObject); @@ -228,6 +277,10 @@ impl Nac3 { } }) } + // Allow global variable declaration with `Kernel` type annotation + StmtKind::AnnAssign { ref annotation, .. } => { + matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into())) + } _ => false, }; @@ -330,7 +383,7 @@ impl Nac3 { vars: into_var_map([arg_ty]), }, Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { - gen_core_log(ctx, &obj, fun, &args, generator)?; + gen_core_log(ctx, obj.as_ref(), fun, &args, generator)?; Ok(None) }))), @@ -360,7 +413,7 @@ impl Nac3 { vars: into_var_map([arg_ty]), }, Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { - gen_rtio_log(ctx, &obj, fun, &args, generator)?; + gen_rtio_log(ctx, obj.as_ref(), fun, &args, generator)?; Ok(None) }))), @@ -378,7 +431,7 @@ impl Nac3 { py: Python, link_fn: &dyn Fn(&Module) -> PyResult, ) -> PyResult { - let size_t = self.isa.get_size_type(); + let size_t = self.isa.get_size_type(&Context::create()); let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new( self.builtins.clone(), Self::get_lateinit_builtins(), @@ -421,12 +474,14 @@ impl Nac3 { ]; add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names); + // Stores a mapping from module id to attributes let mut module_to_resolver_cache: HashMap = HashMap::new(); let mut rpc_ids = vec![]; for (stmt, path, module) in &self.top_levels { let py_module: &PyAny = module.extract(py)?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?; + let module_name: String = 
py_module.getattr("__name__")?.extract()?; let helper = helper.clone(); let class_obj; if let StmtKind::ClassDef { name, .. } = &stmt.node { @@ -441,7 +496,7 @@ impl Nac3 { } else { class_obj = None; } - let (name_to_pyid, resolver) = + let (name_to_pyid, resolver, _, _) = module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| { let mut name_to_pyid: HashMap = HashMap::new(); let members: &PyDict = @@ -470,9 +525,17 @@ impl Nac3 { }))) as Arc; let name_to_pyid = Rc::new(name_to_pyid); - module_to_resolver_cache - .insert(module_id, (name_to_pyid.clone(), resolver.clone())); - (name_to_pyid, resolver) + let module_location = ast::Location::new(1, 1, stmt.location.file); + module_to_resolver_cache.insert( + module_id, + ( + name_to_pyid.clone(), + resolver.clone(), + module_name.clone(), + Some(module_location), + ), + ); + (name_to_pyid, resolver, module_name, Some(module_location)) }); let (name, def_id, ty) = composer @@ -546,6 +609,24 @@ impl Nac3 { } } + // Adding top level module definitions + for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in + module_to_resolver_cache + { + let def_id = composer + .register_top_level_module( + &module_name, + &module_name_to_pyid, + module_resolver, + module_location, + ) + .map_err(|e| { + CompileError::new_err(format!("compilation failed\n----------\n{e}")) + })?; + + self.pyid_to_def.write().insert(module_id, def_id); + } + let id_fun = PyModule::import(py, "builtins")?.getattr("id")?; let mut name_to_pyid: HashMap = HashMap::new(); let module = PyModule::new(py, "tmp")?; @@ -665,6 +746,9 @@ impl Nac3 { "Unsupported @rpc annotation on global variable", ))) } + TopLevelDef::Module { .. 
} => { + unreachable!("Type module cannot be decorated with @rpc") + } } } } @@ -703,14 +787,18 @@ impl Nac3 { let buffer = buffer.as_slice().into(); membuffer.lock().push(buffer); }))); - let size_t = context - .ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None) - .get_bit_width(); let num_threads = if is_multithreaded() { 4 } else { 1 }; let thread_names: Vec = (0..num_threads).map(|_| "main".to_string()).collect(); let threads: Vec<_> = thread_names .iter() - .map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns))) + .map(|s| { + Box::new(ArtiqCodeGenerator::with_target_machine( + s.to_string(), + &context, + &self.get_llvm_target_machine(), + self.time_fns, + )) + }) .collect(); let membuffer = membuffers.clone(); @@ -719,8 +807,13 @@ impl Nac3 { let (registry, handles) = WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f); - let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns); let context = Context::create(); + let mut generator = ArtiqCodeGenerator::with_target_machine( + "main".to_string(), + &context, + &self.get_llvm_target_machine(), + self.time_fns, + ); let module = context.create_module("main"); let target_machine = self.llvm_options.create_target_machine().unwrap(); module.set_data_layout(&target_machine.get_target_data().get_data_layout()); @@ -839,52 +932,10 @@ impl Nac3 { link_fn(&main) } - /// Returns the [`TargetTriple`] used for compiling to [isa]. - fn get_llvm_target_triple(isa: Isa) -> TargetTriple { - match isa { - Isa::Host => TargetMachine::get_default_triple(), - Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"), - Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"), - } - } - - /// Returns the [`String`] representing the target CPU used for compiling to [isa]. 
- fn get_llvm_target_cpu(isa: Isa) -> String { - match isa { - Isa::Host => TargetMachine::get_host_cpu_name().to_string(), - Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(), - Isa::CortexA9 => "cortex-a9".to_string(), - } - } - - /// Returns the [`String`] representing the target features used for compiling to [isa]. - fn get_llvm_target_features(isa: Isa) -> String { - match isa { - Isa::Host => TargetMachine::get_host_cpu_features().to_string(), - Isa::RiscV32G => "+a,+m,+f,+d".to_string(), - Isa::RiscV32IMA => "+a,+m".to_string(), - Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(), - } - } - - /// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine - /// options used for compiling to [isa]. - fn get_llvm_target_options(isa: Isa) -> CodeGenTargetMachineOptions { - CodeGenTargetMachineOptions { - triple: Nac3::get_llvm_target_triple(isa).as_str().to_string_lossy().into_owned(), - cpu: Nac3::get_llvm_target_cpu(isa), - features: Nac3::get_llvm_target_features(isa), - reloc_mode: RelocMode::PIC, - ..CodeGenTargetMachineOptions::from_host() - } - } - /// Returns an instance of [`TargetMachine`] used in compiling and linking of a program to the - /// target [isa]. + /// target [ISA][isa]. 
fn get_llvm_target_machine(&self) -> TargetMachine { - Nac3::get_llvm_target_options(self.isa) - .create_target_machine(self.llvm_options.opt_level) - .expect("couldn't create target machine") + self.isa.create_llvm_target_machine(self.llvm_options.opt_level) } } @@ -992,7 +1043,8 @@ impl Nac3 { Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS, Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS, }; - let (primitive, _) = TopLevelComposer::make_primitives(isa.get_size_type()); + let (primitive, _) = + TopLevelComposer::make_primitives(isa.get_size_type(&Context::create())); let builtins = vec![ ( "now_mu".into(), @@ -1080,6 +1132,7 @@ impl Nac3 { tuple: get_attr_id(builtins_mod, "tuple"), exception: get_attr_id(builtins_mod, "Exception"), option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()), + module: get_attr_id(types_mod, "ModuleType"), }; let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); @@ -1141,7 +1194,7 @@ impl Nac3 { deferred_eval_store: DeferredEvaluationStore::new(), llvm_options: CodeGenLLVMOptions { opt_level: OptimizationLevel::Default, - target: Nac3::get_llvm_target_options(isa), + target: isa.get_llvm_target_options(), }, }) } diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 6507dc2..f14a8ee 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -16,14 +16,14 @@ use pyo3::{ use super::PrimitivePythonId; use nac3core::{ codegen::{ - types::{ndarray::NDArrayType, ProxyType}, + types::{ndarray::NDArrayType, structure::StructProxyType, ProxyType}, values::ndarray::make_contiguous_strides, CodeGenContext, CodeGenerator, }, inkwell::{ module::Linkage, types::{BasicType, BasicTypeEnum}, - values::BasicValueEnum, + values::{BasicValue, BasicValueEnum}, AddressSpace, }, nac3parser::ast::{self, StrRef}, @@ -674,6 +674,48 @@ impl InnerResolver { }) }); + // check if obj is module + if self.helper.id_fn.call1(py, 
(ty.clone(),))?.extract::(py)? + == self.primitive_ids.module + && self.pyid_to_def.read().contains_key(&py_obj_id) + { + let def_id = self.pyid_to_def.read()[&py_obj_id]; + let def = defs[def_id.0].read(); + let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } = + &*def + else { + unreachable!("must be a module here"); + }; + // Construct the module return type + let mut module_attributes = HashMap::new(); + for (name, _) in attributes { + let attribute_obj = obj.getattr(name.to_string().as_str())?; + let attribute_ty = + self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?; + if let Ok(attribute_ty) = attribute_ty { + module_attributes.insert(*name, (attribute_ty, false)); + } else { + return Ok(Err(format!("Unable to resolve {module_name}.{name}"))); + } + } + + for name in methods.keys() { + let method_obj = obj.getattr(name.to_string().as_str())?; + let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?; + if let Ok(method_ty) = method_ty { + module_attributes.insert(*name, (method_ty, true)); + } else { + return Ok(Err(format!("Unable to resolve {module_name}.{name}"))); + } + } + + let module_ty = + TypeEnum::TModule { module_id: *module_id, attributes: module_attributes }; + + let ty = unifier.add_ty(module_ty); + return Ok(Ok(ty)); + } + if let Some(ty) = constructor_ty { self.pyid_to_type.write().insert(py_obj_id, ty); return Ok(Ok(ty)); @@ -931,10 +973,13 @@ impl InnerResolver { |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.bool) { - obj.extract::().map_or_else( - |_| Ok(Err(format!("{obj} is not in the range of bool"))), - |_| Ok(Ok(extracted_ty)), - ) + if obj.extract::().is_ok() + || obj.call_method("__bool__", (), None)?.extract::().is_ok() + { + Ok(Ok(extracted_ty)) + } else { + Ok(Err(format!("{obj} is not in the range of bool"))) + } } else if unifier.unioned(extracted_ty, primitives.float) { obj.extract::().map_or_else( |_| Ok(Err(format!("{obj} is 
not in the range of float64"))), @@ -974,10 +1019,14 @@ impl InnerResolver { let val: u64 = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val)); Ok(Some(ctx.ctx.i64_type().const_int(val, false).into())) - } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { + } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) + } else if ty_id == self.primitive_ids.np_bool_ { + let val: bool = obj.call_method("__bool__", (), None)?.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); + Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { let val: String = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); @@ -1000,7 +1049,7 @@ impl InnerResolver { } _ => unreachable!("must be list"), }; - let size_t = generator.get_size_type(ctx.ctx); + let size_t = ctx.get_size_type(); let ty = if len == 0 && matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. 
}) { @@ -1089,7 +1138,7 @@ impl InnerResolver { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty); let dtype = llvm_ndarray.element_type(); @@ -1097,7 +1146,7 @@ impl InnerResolver { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global( - llvm_ndarray.as_base_type().get_element_type().into_struct_type(), + llvm_ndarray.as_abi_type().get_element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ) @@ -1107,7 +1156,7 @@ impl InnerResolver { self.global_value_ids.write().insert(id, obj.into()); } - let ndims = llvm_ndarray.ndims().unwrap(); + let ndims = llvm_ndarray.ndims(); // Obtain the shape of the ndarray let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; @@ -1124,7 +1173,10 @@ impl InnerResolver { super::CompileError::new_err(format!("Error getting element {i}: {e}")) })? .unwrap(); - let value = value.into_int_value(); + let value = ctx + .builder + .build_int_z_extend(value.into_int_value(), llvm_usize, "") + .unwrap(); Ok(value) }) .collect::, PyErr>>()?; @@ -1203,8 +1255,16 @@ impl InnerResolver { data_global.set_initializer(&data); // Get the constant itemsize. - let itemsize = dtype.size_of().unwrap(); - let itemsize = itemsize.get_zero_extended_constant().unwrap(); + // + // NOTE: dtype.size_of() may return a non-constant, where `TargetData::get_store_size` + // will always return a constant size. 
+ let itemsize = ctx + .registry + .llvm_options + .create_target_machine() + .map(|tm| tm.get_target_data().get_store_size(&dtype)) + .unwrap(); + assert_ne!(itemsize, 0); // Create the strides needed for ndarray.strides let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s); @@ -1214,7 +1274,7 @@ impl InnerResolver { // create a global for ndarray.strides and initialize it let strides_global = ctx.module.add_global( - llvm_i8.array_type(ndims as u32), + llvm_usize.array_type(ndims as u32), Some(AddressSpace::default()), &format!("${id_str}.strides"), ); @@ -1230,24 +1290,41 @@ impl InnerResolver { let ndarray_ndims = llvm_usize.const_int(ndims, false); + // calling as_pointer_value on shape and strides returns [i64 x ndims]* + // convert into i64* to conform with expected layout of ndarray + let ndarray_shape = shape_global.as_pointer_value(); + let ndarray_shape = unsafe { + ctx.builder + .build_in_bounds_gep( + ndarray_shape, + &[llvm_usize.const_zero(), llvm_usize.const_zero()], + "", + ) + .unwrap() + }; let ndarray_strides = strides_global.as_pointer_value(); + let ndarray_strides = unsafe { + ctx.builder + .build_in_bounds_gep( + ndarray_strides, + &[llvm_usize.const_zero(), llvm_usize.const_zero()], + "", + ) + .unwrap() + }; - let ndarray = llvm_ndarray - .as_base_type() - .get_element_type() - .into_struct_type() - .const_named_struct(&[ - ndarray_itemsize.into(), - ndarray_ndims.into(), - ndarray_shape.into(), - ndarray_strides.into(), - ndarray_data.into(), - ]); + let ndarray = llvm_ndarray.get_struct_type().const_named_struct(&[ + ndarray_itemsize.into(), + ndarray_ndims.into(), + ndarray_shape.into(), + ndarray_strides.into(), + ndarray_data.into(), + ]); let ndarray_global = ctx.module.add_global( - llvm_ndarray.as_base_type().get_element_type().into_struct_type(), + llvm_ndarray.as_abi_type().get_element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ); @@ -1334,6 +1411,77 @@ impl InnerResolver { None => Ok(None), 
} } + } else if ty_id == self.primitive_ids.module { + let id_str = id.to_string(); + + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + + let top_level_defs = ctx.top_level.definitions.read(); + let ty = self + .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? + .unwrap(); + let ty = ctx + .get_llvm_type(generator, ty) + .into_pointer_type() + .get_element_type() + .into_struct_type(); + + { + if self.global_value_ids.read().contains_key(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + return Ok(Some(global.as_pointer_value().into())); + } + self.global_value_ids.write().insert(id, obj.into()); + } + + let fields = { + let definition = + top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read(); + let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() }; + attributes + .iter() + .filter_map(|f| { + let definition = top_level_defs.get(f.1 .0).unwrap().read(); + if let TopLevelDef::Variable { ty, .. 
} = &*definition { + Some((f.0, *ty)) + } else { + None + } + }) + .collect_vec() + }; + + let values: Result>, _> = fields + .iter() + .map(|(name, ty)| { + self.get_obj_value( + py, + obj.getattr(name.to_string().as_str())?, + ctx, + generator, + *ty, + ) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting field {name}: {e}")) + }) + }) + .collect(); + let values = values?; + + if let Some(values) = values { + let val = ty.const_named_struct(&values); + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + global.set_initializer(&val); + Ok(Some(global.as_pointer_value().into())) + } else { + Ok(None) + } } else { let id_str = id.to_string(); @@ -1413,9 +1561,12 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.uint64 { let val: u64 = obj.extract()?; Ok(SymbolValue::U64(val)) - } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { + } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract()?; Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.np_bool_ { + let val: bool = obj.call_method("__bool__", (), None)?.extract()?; + Ok(SymbolValue::Bool(val)) } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { let val: String = obj.extract()?; Ok(SymbolValue::Str(val)) @@ -1513,9 +1664,50 @@ impl SymbolResolver for Resolver { fn get_symbol_value<'ctx>( &self, id: StrRef, - _: &mut CodeGenContext<'ctx, '_>, - _: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, ) -> Option> { + if let Some(def_id) = self.0.id_to_def.read().get(&id) { + let top_levels = ctx.top_level.definitions.read(); + if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. 
}) { + let module_val = &self.0.module; + let ret = Python::with_gil(|py| -> PyResult> { + let module_val = module_val.as_ref(py); + + let ty = self.0.get_obj_type( + py, + module_val, + &mut ctx.unifier, + &top_levels, + &ctx.primitives, + )?; + if let Err(ty) = ty { + return Ok(Err(ty)); + } + let ty = ty.unwrap(); + let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap(); + let (idx, _) = ctx.get_attr_index(ty, id); + let ret = unsafe { + ctx.builder.build_gep( + obj.into_pointer_value(), + &[ + ctx.ctx.i32_type().const_zero(), + ctx.ctx.i32_type().const_int(idx as u64, false), + ], + id.to_string().as_str(), + ) + } + .unwrap(); + Ok(Ok(ret.as_basic_value_enum())) + }) + .unwrap(); + if ret.is_err() { + return None; + } + return Some(ret.unwrap().into()); + } + } + let sym_value = { let id_to_val = self.0.id_to_pyval.read(); id_to_val.get(&id).cloned() diff --git a/nac3ast/Cargo.toml b/nac3ast/Cargo.toml index dc2bd55..947be09 100644 --- a/nac3ast/Cargo.toml +++ b/nac3ast/Cargo.toml @@ -11,5 +11,5 @@ fold = [] [dependencies] parking_lot = "0.12" -string-interner = "0.17" +string-interner = "0.18" fxhash = "0.2" diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 6521a33..7badcee 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -10,11 +10,10 @@ derive = ["dep:nac3core_derive"] no-escape-analysis = [] [dependencies] -itertools = "0.13" +itertools = "0.14" crossbeam = "0.8" -indexmap = "2.6" +indexmap = "2.7" parking_lot = "0.12" -rayon = "1.10" nac3core_derive = { path = "nac3core_derive", optional = true } nac3parser = { path = "../nac3parser" } strum = "0.26" @@ -31,4 +30,4 @@ indoc = "2.0" insta = "=1.11.0" [build-dependencies] -regex = "1.10" +regex = "1.11" diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 8447fc5..87dcb42 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -1,10 +1,15 @@ #include "irrt/exception.hpp" #include "irrt/list.hpp" #include "irrt/math.hpp" -#include 
"irrt/ndarray.hpp" #include "irrt/range.hpp" #include "irrt/slice.hpp" +#include "irrt/string.hpp" #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" +#include "irrt/ndarray/array.hpp" +#include "irrt/ndarray/reshape.hpp" +#include "irrt/ndarray/broadcast.hpp" +#include "irrt/ndarray/transpose.hpp" +#include "irrt/ndarray/matmul.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/int_types.hpp b/nac3core/irrt/irrt/int_types.hpp index ed8a48b..17ccf60 100644 --- a/nac3core/irrt/irrt/int_types.hpp +++ b/nac3core/irrt/irrt/int_types.hpp @@ -21,7 +21,5 @@ using uint64_t = unsigned _ExtInt(64); #endif -// NDArray indices are always `uint32_t`. -using NDIndexInt = uint32_t; // The type of an index or a value describing the length of a range/slice is always `int32_t`. using SliceIndex = int32_t; diff --git a/nac3core/irrt/irrt/list.hpp b/nac3core/irrt/irrt/list.hpp index 2854394..1edfe49 100644 --- a/nac3core/irrt/irrt/list.hpp +++ b/nac3core/irrt/irrt/list.hpp @@ -2,6 +2,21 @@ #include "irrt/int_types.hpp" #include "irrt/math_util.hpp" +#include "irrt/slice.hpp" + +namespace { +/** + * @brief A list in NAC3. + * + * The `items` field is opaque. You must rely on external contexts to + * know how to interpret it. 
+ */ +template +struct List { + uint8_t* items; + SizeT len; +}; +} // namespace extern "C" { // Handle list assignment and dropping part of the list when diff --git a/nac3core/irrt/irrt/math.hpp b/nac3core/irrt/irrt/math.hpp index 1872f56..9dc1377 100644 --- a/nac3core/irrt/irrt/math.hpp +++ b/nac3core/irrt/irrt/math.hpp @@ -1,5 +1,7 @@ #pragma once +#include "irrt/int_types.hpp" + namespace { // adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c // need to make sure `exp >= 0` before calling this function diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp deleted file mode 100644 index 9d305aa..0000000 --- a/nac3core/irrt/irrt/ndarray.hpp +++ /dev/null @@ -1,151 +0,0 @@ -#pragma once - -#include "irrt/int_types.hpp" - -// TODO: To be deleted since NDArray with strides is done. - -namespace { -template -SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) { - __builtin_assume(end_idx <= list_len); - - SizeT num_elems = 1; - for (SizeT i = begin_idx; i < end_idx; ++i) { - SizeT val = list_data[i]; - __builtin_assume(val > 0); - num_elems *= val; - } - return num_elems; -} - -template -void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) { - SizeT stride = 1; - for (SizeT dim = 0; dim < num_dims; dim++) { - SizeT i = num_dims - dim - 1; - __builtin_assume(dims[i] > 0); - idxs[i] = (index / stride) % dims[i]; - stride *= dims[i]; - } -} - -template -SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, - SizeT num_dims, - const NDIndexInt* indices, - SizeT num_indices) { - SizeT idx = 0; - SizeT stride = 1; - for (SizeT i = 0; i < num_dims; ++i) { - SizeT ri = num_dims - i - 1; - if (ri < num_indices) { - idx += stride * indices[ri]; - } - - __builtin_assume(dims[i] > 0); - stride *= dims[ri]; - } - return idx; -} - -template -void __nac3_ndarray_calc_broadcast_impl(const SizeT* 
lhs_dims, - SizeT lhs_ndims, - const SizeT* rhs_dims, - SizeT rhs_ndims, - SizeT* out_dims) { - SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; - - for (SizeT i = 0; i < max_ndims; ++i) { - const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr; - const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr; - SizeT* out_dim = &out_dims[max_ndims - i - 1]; - - if (lhs_dim_sz == nullptr) { - *out_dim = *rhs_dim_sz; - } else if (rhs_dim_sz == nullptr) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == 1) { - *out_dim = *rhs_dim_sz; - } else if (*rhs_dim_sz == 1) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == *rhs_dim_sz) { - *out_dim = *lhs_dim_sz; - } else { - __builtin_unreachable(); - } - } -} - -template -void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims, - SizeT src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - for (SizeT i = 0; i < src_ndims; ++i) { - SizeT src_i = src_ndims - i - 1; - out_idx[src_i] = src_dims[src_i] == 1 ? 
0 : in_idx[src_i]; - } -} -} // namespace - -extern "C" { -uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -uint64_t -__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} - -void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} - -uint32_t -__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - -uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, - uint64_t num_dims, - const NDIndexInt* indices, - uint64_t num_indices) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - -void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, - uint32_t lhs_ndims, - const uint32_t* rhs_dims, - uint32_t rhs_ndims, - uint32_t* out_dims) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims, - uint64_t lhs_ndims, - const uint64_t* rhs_dims, - uint64_t rhs_ndims, - uint64_t* out_dims) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims, - uint32_t src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - 
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} - -void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, - uint64_t src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} -} // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/array.hpp b/nac3core/irrt/irrt/ndarray/array.hpp new file mode 100644 index 0000000..126669e --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/array.hpp @@ -0,0 +1,132 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/list.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray::array { +/** + * @brief In the context of `np.array()`, deduce the ndarray's shape produced by `` and raise + * an exception if there is anything wrong with `` (e.g., inconsistent dimensions `np.array([[1.0, 2.0], + * [3.0]])`) + * + * If this function finds no issues with ``, the deduced shape is written to `shape`. The caller has the + * responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because + * of implementation details. + */ +template +void set_and_validate_list_shape_helper(SizeT axis, List* list, SizeT ndims, SizeT* shape) { + if (shape[axis] == -1) { + // Dimension is unspecified. Set it. + shape[axis] = list->len; + } else { + // Dimension is specified. Check. + if (shape[axis] != list->len) { + // Mismatch, throw an error. + // NOTE: NumPy's error message is more complex and needs more PARAMS to display. 
+ raise_exception(SizeT, EXN_VALUE_ERROR, + "The requested array has an inhomogenous shape " + "after {0} dimension(s).", + axis, shape[axis], list->len); + } + } + + if (axis + 1 == ndims) { + // `list` has type `list[ItemType]` + // Do nothing + } else { + // `list` has type `list[list[...]]` + List** lists = (List**)(list->items); + for (SizeT i = 0; i < list->len; i++) { + set_and_validate_list_shape_helper(axis + 1, lists[i], ndims, shape); + } + } +} + +/** + * @brief See `set_and_validate_list_shape_helper`. + */ +template +void set_and_validate_list_shape(List* list, SizeT ndims, SizeT* shape) { + for (SizeT axis = 0; axis < ndims; axis++) { + shape[axis] = -1; // Sentinel to say this dimension is unspecified. + } + set_and_validate_list_shape_helper(0, list, ndims, shape); +} + +/** + * @brief In the context of `np.array()`, copied the contents stored in `list` to `ndarray`. + * + * `list` is assumed to be "legal". (i.e., no inconsistent dimensions) + * + * # Notes on `ndarray` + * The caller is responsible for allocating space for `ndarray`. + * Here is what this function expects from `ndarray` when called: + * - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values. + * - `ndarray->itemsize` has to be initialized. + * - `ndarray->ndims` has to be initialized. + * - `ndarray->shape` has to be initialized. + * - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous. + * When this function call ends: + * - `ndarray->data` is written with contents from ``. 
+ */ +template +void write_list_to_array_helper(SizeT axis, SizeT* index, List* list, NDArray* ndarray) { + debug_assert_eq(SizeT, list->len, ndarray->shape[axis]); + if (IRRT_DEBUG_ASSERT_BOOL) { + if (!ndarray::basic::is_c_contiguous(ndarray)) { + raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1], + NO_PARAM); + } + } + + if (axis + 1 == ndarray->ndims) { + // `list` has type `list[scalar]` + // `ndarray` is contiguous, so we can do this, and this is fast. + uint8_t* dst = static_cast(ndarray->data) + (ndarray->itemsize * (*index)); + __builtin_memcpy(dst, list->items, ndarray->itemsize * list->len); + *index += list->len; + } else { + // `list` has type `list[list[...]]` + List** lists = (List**)(list->items); + + for (SizeT i = 0; i < list->len; i++) { + write_list_to_array_helper(axis + 1, index, lists[i], ndarray); + } + } +} + +/** + * @brief See `write_list_to_array_helper`. + */ +template +void write_list_to_array(List* list, NDArray* ndarray) { + SizeT index = 0; + write_list_to_array_helper((SizeT)0, &index, list, ndarray); +} +} // namespace ndarray::array +} // namespace + +extern "C" { +using namespace ndarray::array; + +void __nac3_ndarray_array_set_and_validate_list_shape(List* list, int32_t ndims, int32_t* shape) { + set_and_validate_list_shape(list, ndims, shape); +} + +void __nac3_ndarray_array_set_and_validate_list_shape64(List* list, int64_t ndims, int64_t* shape) { + set_and_validate_list_shape(list, ndims, shape); +} + +void __nac3_ndarray_array_write_list_to_array(List* list, NDArray* ndarray) { + write_list_to_array(list, ndarray); +} + +void __nac3_ndarray_array_write_list_to_array64(List* list, NDArray* ndarray) { + write_list_to_array(list, ndarray); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/basic.hpp b/nac3core/irrt/irrt/ndarray/basic.hpp index 05ee30f..62c92ae 100644 --- a/nac3core/irrt/irrt/ndarray/basic.hpp +++ b/nac3core/irrt/irrt/ndarray/basic.hpp @@ 
-6,8 +6,7 @@ #include "irrt/ndarray/def.hpp" namespace { -namespace ndarray { -namespace basic { +namespace ndarray::basic { /** * @brief Assert that `shape` does not contain negative dimensions. * @@ -247,8 +246,7 @@ void copy_data(const NDArray* src_ndarray, NDArray* dst_ndarray) { ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element); } } -} // namespace basic -} // namespace ndarray +} // namespace ndarray::basic } // namespace extern "C" { diff --git a/nac3core/irrt/irrt/ndarray/broadcast.hpp b/nac3core/irrt/irrt/ndarray/broadcast.hpp new file mode 100644 index 0000000..6e54b1c --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/broadcast.hpp @@ -0,0 +1,165 @@ +#pragma once + +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/slice.hpp" + +namespace { +template +struct ShapeEntry { + SizeT ndims; + SizeT* shape; +}; +} // namespace + +namespace { +namespace ndarray::broadcast { +/** + * @brief Return true if `src_shape` can broadcast to `dst_shape`. + * + * See https://numpy.org/doc/stable/user/basics.broadcasting.html + */ +template +bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) { + if (src_ndims > target_ndims) { + return false; + } + + for (SizeT i = 0; i < src_ndims; i++) { + SizeT target_dim = target_shape[target_ndims - i - 1]; + SizeT src_dim = src_shape[src_ndims - i - 1]; + if (!(src_dim == 1 || target_dim == src_dim)) { + return false; + } + } + return true; +} + +/** + * @brief Performs `np.broadcast_shapes()` + * + * @param num_shapes Number of entries in `shapes` + * @param shapes The list of shape to do `np.broadcast_shapes` on. + * @param dst_ndims The length of `dst_shape`. + * `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it. + * for this function since they should already know in order to allocate `dst_shape` in the first place. + * @param dst_shape The resulting shape. 
Must be pre-allocated by the caller. This function calculate the result + * of `np.broadcast_shapes` and write it here. + */ +template +void broadcast_shapes(SizeT num_shapes, const ShapeEntry* shapes, SizeT dst_ndims, SizeT* dst_shape) { + for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) { + dst_shape[dst_axis] = 1; + } + +#ifdef IRRT_DEBUG_ASSERT + SizeT max_ndims_found = 0; +#endif + + for (SizeT i = 0; i < num_shapes; i++) { + ShapeEntry entry = shapes[i]; + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert(SizeT, entry.ndims <= dst_ndims); + +#ifdef IRRT_DEBUG_ASSERT + max_ndims_found = max(max_ndims_found, entry.ndims); +#endif + + for (SizeT j = 0; j < entry.ndims; j++) { + SizeT entry_axis = entry.ndims - j - 1; + SizeT dst_axis = dst_ndims - j - 1; + + SizeT entry_dim = entry.shape[entry_axis]; + SizeT dst_dim = dst_shape[dst_axis]; + + if (dst_dim == 1) { + dst_shape[dst_axis] = entry_dim; + } else if (entry_dim == 1 || entry_dim == dst_dim) { + // Do nothing + } else { + raise_exception(SizeT, EXN_VALUE_ERROR, + "shape mismatch: objects cannot be broadcast " + "to a single shape.", + NO_PARAM, NO_PARAM, NO_PARAM); + } + } + } + +#ifdef IRRT_DEBUG_ASSERT + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert_eq(SizeT, max_ndims_found, dst_ndims); +#endif +} + +/** + * @brief Perform `np.broadcast_to(, )` and appropriate assertions. + * + * This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`, + * and return the result by modifying `dst_ndarray`. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. 
+ * - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape` + * - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is unchanged. + * - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works. + */ +template +void broadcast_to(const NDArray* src_ndarray, NDArray* dst_ndarray) { + if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims, + src_ndarray->shape)) { + raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM, + NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + for (SizeT i = 0; i < dst_ndarray->ndims; i++) { + SizeT src_axis = src_ndarray->ndims - i - 1; + SizeT dst_axis = dst_ndarray->ndims - i - 1; + if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) { + // Freeze the steps in-place + dst_ndarray->strides[dst_axis] = 0; + } else { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + } +} +} // namespace ndarray::broadcast +} // namespace + +extern "C" { +using namespace ndarray::broadcast; + +void __nac3_ndarray_broadcast_to(NDArray* src_ndarray, NDArray* dst_ndarray) { + broadcast_to(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_broadcast_to64(NDArray* src_ndarray, NDArray* dst_ndarray) { + broadcast_to(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_broadcast_shapes(int32_t num_shapes, + const ShapeEntry* shapes, + int32_t dst_ndims, + int32_t* dst_shape) { + 
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); +} + +void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes, + const ShapeEntry* shapes, + int64_t dst_ndims, + int64_t* dst_shape) { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/indexing.hpp b/nac3core/irrt/irrt/ndarray/indexing.hpp index 9e9e7b6..76e7847 100644 --- a/nac3core/irrt/irrt/ndarray/indexing.hpp +++ b/nac3core/irrt/irrt/ndarray/indexing.hpp @@ -65,8 +65,7 @@ struct NDIndex { } // namespace namespace { -namespace ndarray { -namespace indexing { +namespace ndarray::indexing { /** * @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) * @@ -162,7 +161,8 @@ void index(SizeT num_indices, const NDIndex* indices, const NDArray* src_ Range range = slice->indices_checked(src_ndarray->shape[src_axis]); - dst_ndarray->data = static_cast(dst_ndarray->data) + (SizeT)range.start * src_ndarray->strides[src_axis]; + dst_ndarray->data = + static_cast(dst_ndarray->data) + (SizeT)range.start * src_ndarray->strides[src_axis]; dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis]; dst_ndarray->shape[dst_axis] = (SizeT)range.len(); @@ -197,8 +197,7 @@ void index(SizeT num_indices, const NDIndex* indices, const NDArray* src_ debug_assert_eq(SizeT, src_ndarray->ndims, src_axis); debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis); } -} // namespace indexing -} // namespace ndarray +} // namespace ndarray::indexing } // namespace extern "C" { diff --git a/nac3core/irrt/irrt/ndarray/matmul.hpp b/nac3core/irrt/irrt/ndarray/matmul.hpp new file mode 100644 index 0000000..b0fd4d8 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/matmul.hpp @@ -0,0 +1,98 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/broadcast.hpp" +#include 
"irrt/ndarray/iter.hpp" + +// NOTE: Everything would be much easier and elegant if einsum is implemented. + +namespace { +namespace ndarray::matmul { + +/** + * @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`. + * + * Example: + * Suppose `a_shape == [1, 97, 4, 2]` + * and `b_shape == [99, 98, 1, 2, 5]`, + * + * ...then `new_a_shape == [99, 98, 97, 4, 2]`, + * `new_b_shape == [99, 98, 97, 2, 5]`, + * and `dst_shape == [99, 98, 97, 4, 5]`. + * ^^^^^^^^^^ ^^^^ + * (broadcasted) (4x2 @ 2x5 => 4x5) + * + * @param a_ndims Length of `a_shape`. + * @param a_shape Shape of `a`. + * @param b_ndims Length of `b_shape`. + * @param b_shape Shape of `b`. + * @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`, + * `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting. + */ +template +void calculate_shapes(SizeT a_ndims, + SizeT* a_shape, + SizeT b_ndims, + SizeT* b_shape, + SizeT final_ndims, + SizeT* new_a_shape, + SizeT* new_b_shape, + SizeT* dst_shape) { + debug_assert(SizeT, a_ndims >= 2); + debug_assert(SizeT, b_ndims >= 2); + debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims); + + // Check that a and b are compatible for matmul + if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) { + // This is a custom error message. Different from NumPy. 
+ raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})", + a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM); + } + + const SizeT num_entries = 2; + ShapeEntry entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape}, + {.ndims = b_ndims - 2, .shape = b_shape}}; + + // TODO: Optimize this + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_a_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_b_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, dst_shape); + + new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1]; + new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2]; + new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1]; + dst_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + dst_shape[final_ndims - 1] = b_shape[b_ndims - 1]; +} +} // namespace ndarray::matmul +} // namespace + +extern "C" { +using namespace ndarray::matmul; + +void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, + int32_t* a_shape, + int32_t b_ndims, + int32_t* b_shape, + int32_t final_ndims, + int32_t* new_a_shape, + int32_t* new_b_shape, + int32_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} + +void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, + int64_t* a_shape, + int64_t b_ndims, + int64_t* b_shape, + int64_t final_ndims, + int64_t* new_a_shape, + int64_t* new_b_shape, + int64_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/reshape.hpp b/nac3core/irrt/irrt/ndarray/reshape.hpp new file mode 100644 index 0000000..b2ad2a5 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/reshape.hpp @@ -0,0 +1,97 @@ +#pragma once + +#include "irrt/exception.hpp" 
+#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray::reshape { +/** + * @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(, new_shape)` + * + * If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be + * modified to contain the resolved dimension. + * + * To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual + * `` object itself, but only the `.size` of the ``. + * + * @param size The `.size` of `` + * @param new_ndims Number of elements in `new_shape` + * @param new_shape Target shape to reshape to + */ +template +void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) { + // Is there a -1 in `new_shape`? + bool neg1_exists = false; + // Location of -1, only initialized if `neg1_exists` is true + SizeT neg1_axis_i; + // The computed ndarray size of `new_shape` + SizeT new_size = 1; + + for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) { + SizeT dim = new_shape[axis_i]; + if (dim < 0) { + if (dim == -1) { + if (neg1_exists) { + // Multiple `-1` found. Throw an error. + raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM, + NO_PARAM, NO_PARAM); + } else { + neg1_exists = true; + neg1_axis_i = axis_i; + } + } else { + // TODO: What? In `np.reshape` any negative dimensions is + // treated like its `-1`. + // + // Try running `np.zeros((3, 4)).reshape((-999, 2))` + // + // It is not documented by numpy. + // Throw an error for now... 
+ + raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i, + NO_PARAM); + } + } else { + new_size *= dim; + } + } + + bool can_reshape; + if (neg1_exists) { + // Let `x` be the unknown dimension + // Solve `x * = ` + if (new_size == 0 && size == 0) { + // `x` has infinitely many solutions + can_reshape = false; + } else if (new_size == 0 && size != 0) { + // `x` has no solutions + can_reshape = false; + } else if (size % new_size != 0) { + // `x` has no integer solutions + can_reshape = false; + } else { + can_reshape = true; + new_shape[neg1_axis_i] = size / new_size; // Resolve dimension + } + } else { + can_reshape = (new_size == size); + } + + if (!can_reshape) { + raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM, + NO_PARAM); + } +} +} // namespace ndarray::reshape +} // namespace + +extern "C" { +void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) { + ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape); +} + +void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) { + ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape); +} +} diff --git a/nac3core/irrt/irrt/ndarray/transpose.hpp b/nac3core/irrt/irrt/ndarray/transpose.hpp new file mode 100644 index 0000000..662ceb1 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/transpose.hpp @@ -0,0 +1,143 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/slice.hpp" + +/* + * Notes on `np.transpose(, )` + * + * TODO: `axes`, if specified, can actually contain negative indices, + * but it is not documented in numpy. + * + * Supporting it for now. + */ + +namespace { +namespace ndarray::transpose { +/** + * @brief Do assertions on `` in `np.transpose(, )`. 
+ * + * Note that `np.transpose`'s `` argument is optional. If the argument + * is specified but the user, use this function to do assertions on it. + * + * @param ndims The number of dimensions of `` + * @param num_axes Number of elements in `` as specified by the user. + * This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown. + * @param axes The user specified ``. + */ +template +void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) { + if (ndims != num_axes) { + raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM); + } + + // TODO: Optimize this + bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims); + for (SizeT i = 0; i < ndims; i++) + axe_specified[i] = false; + + for (SizeT i = 0; i < ndims; i++) { + SizeT axis = slice::resolve_index_in_length(ndims, axes[i]); + if (axis == -1) { + // TODO: numpy actually throws a `numpy.exceptions.AxisError` + raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims, + NO_PARAM); + } + + if (axe_specified[axis]) { + raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM); + } + + axe_specified[axis] = true; + } +} + +/** + * @brief Create a transpose view of `src_ndarray` and perform proper assertions. + * + * This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, )`. + * If `` is supposed to be `None`, caller can pass in a `nullptr` to ``. + * + * The transpose view created is returned by modifying `dst_ndarray`. + * + * The caller is responsible for setting up `dst_ndarray` before calling this function. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`. 
+ * - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged + * - `dst_ndarray->shape` is updated according to how `np.transpose` works + * - `dst_ndarray->strides` is updated according to how `np.transpose` works + * + * @param src_ndarray The NDArray to build a transpose view on + * @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above, + * @param num_axes Number of elements in axes. Unused if `axes` is nullptr. + * @param axes Axes permutation. Set it to `nullptr` if `` is `None`. + */ +template +void transpose(const NDArray* src_ndarray, NDArray* dst_ndarray, SizeT num_axes, const SizeT* axes) { + debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims); + const auto ndims = src_ndarray->ndims; + + if (axes != nullptr) + assert_transpose_axes(ndims, num_axes, axes); + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes. + if (axes == nullptr) { + // `np.transpose(, axes=None)` + + /* + * Minor note: `np.transpose(, axes=None)` is equivalent to + * `np.transpose(, axes=[N-1, N-2, ..., 0])` - basically it + * is reversing the order of strides and shape. + * + * This is a fast implementation to handle this special (but very common) case. 
+ */ + + for (SizeT axis = 0; axis < ndims; axis++) { + dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1]; + dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1]; + } + } else { + // `np.transpose(, )` + + // Permute strides and shape according to `axes`, while resolving negative indices in `axes` + for (SizeT axis = 0; axis < ndims; axis++) { + // `i` cannot be OUT_OF_BOUNDS because of assertions + SizeT i = slice::resolve_index_in_length(ndims, axes[axis]); + + dst_ndarray->shape[axis] = src_ndarray->shape[i]; + dst_ndarray->strides[axis] = src_ndarray->strides[i]; + } + } +} +} // namespace ndarray::transpose +} // namespace + +extern "C" { +using namespace ndarray::transpose; +void __nac3_ndarray_transpose(const NDArray* src_ndarray, + NDArray* dst_ndarray, + int32_t num_axes, + const int32_t* axes) { + transpose(src_ndarray, dst_ndarray, num_axes, axes); +} + +void __nac3_ndarray_transpose64(const NDArray* src_ndarray, + NDArray* dst_ndarray, + int64_t num_axes, + const int64_t* axes) { + transpose(src_ndarray, dst_ndarray, num_axes, axes); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp new file mode 100644 index 0000000..229b750 --- /dev/null +++ b/nac3core/irrt/irrt/string.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "irrt/int_types.hpp" + +namespace { +template +bool __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { + if (len1 != len2) { + return 0; + } + return __builtin_memcmp(str1, str2, static_cast(len1)) == 0; +} +} // namespace + +extern "C" { +bool nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} + +bool nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/builtin_fns.rs 
b/nac3core/src/codegen/builtin_fns.rs index a41b9f5..dfb9082 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,6 +1,6 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValueEnum, IntValue}, FloatPredicate, IntPredicate, OptimizationLevel, }; use itertools::Itertools; @@ -11,19 +11,16 @@ use super::{ irrt::calculate_len_for_slice_range, llvm_intrinsics, macros::codegen_unreachable, - numpy, - numpy::ndarray_elementwise_unaryop_impl, - stmt::gen_for_callback_incrementing, - types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, ListType, RangeType, TupleType}, values::{ - ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ProxyValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; use crate::{ toplevel::{ - helper::{extract_ndims, PrimDef}, + helper::{arraylike_flatten_element_type, extract_ndims, PrimDef}, numpy::unpack_ndarray_var_tys, }, typecheck::typedef::{Type, TypeEnum}, @@ -50,47 +47,38 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( let range_ty = ctx.primitives.range; Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = RangeValue::from_pointer_value(arg.into_pointer_value(), Some("range")); + let arg = RangeType::new(ctx).map_pointer_value(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); calculate_len_for_slice_range(generator, ctx, start, end, step) } else { match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false), - TypeEnum::TObj { obj_id, .. 
} if *obj_id == PrimDef::List.id() => { - let zero = llvm_i32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, llvm_i32.const_int(1, false)], - None, - ) - .into_int_value(); - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + TypeEnum::TTuple { .. } => { + let tuple = TupleType::from_unifier_type(generator, ctx, arg_ty) + .map_struct_value(arg.into_struct_value(), None); + llvm_i32.const_int(tuple.get_type().num_elements().into(), false) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let llvm_usize = generator.get_size_type(ctx.ctx); - let arg = NDArrayType::from_unifier_type(generator, ctx, arg_ty) - .map_value(arg.into_pointer_value(), None); - let ndims = arg.shape().size(ctx, generator); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "") - .unwrap(), - "0:TypeError", - "len() of unsized object", - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) + .map_pointer_value(arg.into_pointer_value(), None); + ctx.builder + .build_int_truncate_or_bit_cast(ndarray.len(ctx), llvm_i32, "len") + .unwrap() } - _ => codegen_unreachable!(ctx), + + TypeEnum::TObj { obj_id, .. 
} + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListType::from_unifier_type(generator, ctx, arg_ty) + .map_pointer_value(arg.into_pointer_value(), None); + ctx.builder + .build_int_truncate_or_bit_cast(list.load_size(ctx, None), llvm_i32, "len") + .unwrap() + } + + _ => unsupported_type(ctx, "len", &[arg_ty]), } }) } @@ -138,18 +126,19 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int32, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i32_type().into() }, + |generator, ctx, scalar| call_int32(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "int32", &[n_ty]), @@ -198,18 +187,19 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int64, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: 
ctx.ctx.i64_type().into() }, + |generator, ctx, scalar| call_int64(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "int64", &[n_ty]), @@ -274,18 +264,19 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint32, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i32_type().into() }, + |generator, ctx, scalar| call_uint32(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "uint32", &[n_ty]), @@ -339,18 +330,19 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint64, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i64_type().into() }, + |generator, ctx, scalar| call_uint64(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - 
ndarray.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "uint64", &[n_ty]), @@ -364,7 +356,6 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { @@ -403,20 +394,20 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndims = extract_ndims(&ctx.unifier, ndims); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None), - |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.f64_type().into() }, + |generator, ctx, scalar| call_float(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "float", &[n_ty]), @@ -449,18 +440,21 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); - let ndarray = 
ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty.into() }, + |generator, ctx, scalar| { + call_round(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -486,18 +480,19 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.f64_type().into() }, + |generator, ctx, scalar| call_numpy_round(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -548,22 +543,22 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - 
llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| { - let elem = call_bool(generator, ctx, (elem_ty, val))?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() }, + |generator, ctx, scalar| { + let elem = call_bool(generator, ctx, (elem_ty, scalar))?; + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + ) + .unwrap(); - Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) - }, - )?; - - ndarray.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -600,18 +595,21 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty }, + |generator, ctx, scalar| { + call_floor(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -648,18 +646,21 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - 
generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty }, + |generator, ctx, scalar| { + call_ceil(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -750,42 +751,33 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = + ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)).to_ndarray(generator, ctx); + let x2 = + ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + let x2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); + let llvm_common_dtype = x1.get_type().element_type(); - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; + let result = + NDArrayType::new_broadcast(ctx, llvm_common_dtype, &[x1.get_type(), x2.get_type()]) + .broadcast_starmap( + generator, + ctx, + &[x1, x2], + NDArrayOut::NewNDArray { dtype: 
llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) + }, + ) + .unwrap(); - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -849,7 +841,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); let llvm_int64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); Ok(match a { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { @@ -870,24 +862,27 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( _ => codegen_unreachable!(ctx), } } + BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty); - let n = llvm_ndarray_ty.map_value(n, None); - let n_sz = - irrt::ndarray::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None)); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, a_ty).map_pointer_value(n, None); + let llvm_dtype = ndarray.get_type().element_type(); + + let zero = llvm_usize.const_zero(); + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx + let size_nez = ctx .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") + 
.build_int_compare(IntPredicate::NE, ndarray.size(ctx), zero, "") .unwrap(); ctx.make_assert( generator, - n_sz_eqz, + size_nez, "0:ValueError", format!("zero-size array to reduction operation {fn_name}").as_str(), [None, None, None], @@ -895,45 +890,43 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ); } - let accumulator_addr = - generator.gen_var_alloc(ctx, llvm_ndarray_ty.element_type(), None)?; - let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; + let extremum = generator.gen_var_alloc(ctx, llvm_dtype, None)?; + let extremum_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap(); - } + let first_value = unsafe { ndarray.data().get_unchecked(ctx, generator, &zero, None) }; + ctx.builder.build_store(extremum, first_value).unwrap(); + ctx.builder.build_store(extremum_idx, zero).unwrap(); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_int64.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); + // The first element is iterated, but this doesn't matter. 
+ ndarray + .foreach(generator, ctx, |_, ctx, _, nditer| { + let old_extremum = ctx.builder.build_load(extremum, "").unwrap(); + let old_extremum_idx = ctx + .builder + .build_load(extremum_idx, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); - let result = match fn_name { + let curr_value = nditer.get_scalar(ctx); + let curr_idx = nditer.get_index(ctx); + + let new_extremum = match fn_name { "np_argmin" | "np_min" => { - call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)) + call_min(ctx, (elem_ty, old_extremum), (elem_ty, curr_value)) } "np_argmax" | "np_max" => { - call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)) + call_max(ctx, (elem_ty, old_extremum), (elem_ty, curr_value)) } _ => codegen_unreachable!(ctx), }; - let updated_idx = match (accumulator, result) { + let new_extremum_idx = match (old_extremum, new_extremum) { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx .builder .build_select( ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(), - idx.into(), - cur_idx, + curr_idx, + old_extremum_idx, "", ) .unwrap(), @@ -943,24 +936,35 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ctx.builder .build_float_compare(FloatPredicate::ONE, m, n, "") .unwrap(), - idx.into(), - cur_idx, + curr_idx, + old_extremum_idx, "", ) .unwrap(), _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), }; - ctx.builder.build_store(res_idx, updated_idx).unwrap(); - ctx.builder.build_store(accumulator_addr, result).unwrap(); + + ctx.builder.build_store(extremum, new_extremum).unwrap(); + ctx.builder.build_store(extremum_idx, new_extremum_idx).unwrap(); Ok(()) - }, - llvm_int64.const_int(1, false), - )?; + }) + .unwrap(); match fn_name { - "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), - "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), + "np_argmin" | "np_argmax" => ctx + .builder + .build_int_s_extend_or_bit_cast( + ctx.builder + .build_load(extremum_idx, "") + 
.map(BasicValueEnum::into_int_value) + .unwrap(), + ctx.ctx.i64_type(), + "", + ) + .unwrap() + .into(), + "np_max" | "np_min" => ctx.builder.build_load(extremum, "").unwrap(), _ => codegen_unreachable!(ctx), } } @@ -1007,42 +1011,33 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = + ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)).to_ndarray(generator, ctx); + let x2 = + ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + let x2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); + let llvm_common_dtype = x1.get_type().element_type(); - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; + let result = + NDArrayType::new_broadcast(ctx, llvm_common_dtype, &[x1.get_type(), x2.get_type()]) + .broadcast_starmap( + generator, + ctx, + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) + }, + ) + .unwrap(); - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype 
} else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1075,39 +1070,20 @@ where ) -> Option>, RetElemFn: Fn(&mut CodeGenContext<'ctx, '_>, Type) -> Type, { - let result = match arg_val { - BasicValueEnum::PointerValue(x) - if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, arg_ty); - let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); + let arg = ScalarOrNDArray::from_value(generator, ctx, (arg_ty, arg_val)); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(x, None), - |generator, ctx, elem_val| { - helper_call_numpy_unary_elementwise( - generator, - ctx, - (arg_elem_ty, elem_val), - fn_name, - get_ret_elem_type, - on_scalar, - ) - }, - )?; - ndarray.as_base_value().into() - } + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, arg_ty); - _ => on_scalar(generator, ctx, arg_ty, arg_val) - .unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])), - }; + let ret_ty = get_ret_elem_type(ctx, dtype); + let llvm_ret_ty = ctx.get_llvm_type(generator, ret_ty); + let result = arg.map(generator, ctx, llvm_ret_ty, |generator, ctx, scalar| { + let Some(result) = on_scalar(generator, ctx, dtype, scalar) else { + unsupported_type(ctx, fn_name, &[arg_ty]) + }; + Ok(result) + })?; - Ok(result) + Ok(result.to_basic_value_enum()) } pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( @@ -1432,59 +1408,29 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, 
String> { const FN_NAME: &str = "np_arctan2"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_atan2(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_atan2(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - 
|generator, ctx, (lhs, rhs)| { - call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_copysign` builtin function. @@ -1496,59 +1442,29 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_copysign"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - 
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmax` builtin function. @@ -1560,59 +1476,29 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_fmax"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_maxnum(ctx, x1, x2, 
None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmin` builtin function. 
@@ -1624,59 +1510,29 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_fmin"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - 
numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_ldexp` builtin function. @@ -1688,48 +1544,31 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_ldexp"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_ldexp(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1_scalar), BasicValueEnum::IntValue(x2_scalar)) => { + debug_assert_eq!(x1.get_dtype(), ctx.ctx.f64_type().into()); + debug_assert_eq!(x2.get_dtype(), ctx.ctx.i32_type().into()); + Ok(extern_fns::call_ldexp(ctx, x1_scalar, x2_scalar, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = - if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 
} else { x1_ty }; - - let x1_scalar_ty = dtype; - let x2_scalar_ty = - if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_hypot` builtin function. @@ -1741,59 +1580,29 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_hypot"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_hypot(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_hypot(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) 
= unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_nextafter` builtin function. @@ -1805,87 +1614,29 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_nextafter"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_nextafter(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let 
is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_nextafter(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? 
- .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) -} - -/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it -fn build_output_struct<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - out_matrices: &[BasicValueEnum<'ctx>], -) -> PointerValue<'ctx> { - let field_ty = out_matrices.iter().map(BasicValueEnum::get_type).collect_vec(); - let out_ty = ctx.ctx.struct_type(&field_ty, false); - let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap(); - - for (i, v) in out_matrices.iter().enumerate() { - unsafe { - let ptr = ctx - .builder - .build_in_bounds_gep( - out_ptr, - &[ - ctx.ctx.i32_type().const_zero(), - ctx.ctx.i32_type().const_int(i as u64, false), - ], - "", - ) - .unwrap(); - ctx.builder.build_store(ptr, *v).unwrap(); - } - } - out_ptr + Ok(result.to_basic_value_enum()) } /// Invokes the `np_linalg_cholesky` linalg function @@ -1898,13 +1649,13 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1913,11 +1664,11 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_cholesky( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, 
); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_qr` linalg function @@ -1928,11 +1679,11 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_qr"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1946,7 +1697,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { q.create_data(generator, ctx) }; @@ -1959,16 +1710,20 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_qr( ctx, - x1_c.as_base_value().into(), - q_c.as_base_value().into(), - r_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + q_c.as_abi_value(ctx).into(), + r_c.as_abi_value(ctx).into(), None, ); - let q = q.as_base_value().into(); - let r = r.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[q, r]); - Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) + let q = q.as_abi_value(ctx); + let r = r.as_abi_value(ctx); + let tuple = TupleType::new(ctx, &[q.get_type(), r.get_type()]).construct_from_objects( + ctx, + [q.into(), r.into()], + None, + ); + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_svd` linalg function @@ -1979,11 +1734,11 @@ pub 
fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_svd"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1997,8 +1752,8 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1)); - let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray1_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 1); + let out_ndarray2_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None); unsafe { u.create_data(generator, ctx) }; @@ -2016,19 +1771,19 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_svd( ctx, - x1_c.as_base_value().into(), - u_c.as_base_value().into(), - s_c.as_base_value().into(), - vh_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + u_c.as_abi_value(ctx).into(), + s_c.as_abi_value(ctx).into(), + vh_c.as_abi_value(ctx).into(), None, ); - let u = u.as_base_value().into(); - let s = s.as_base_value().into(); - let vh = vh.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[u, s, vh]); - - Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) + let u = u.as_abi_value(ctx); + let s = s.as_abi_value(ctx); + let vh = vh.as_abi_value(ctx); + let tuple = TupleType::new(ctx, &[u.get_type(), s.get_type(), 
vh.get_type()]) + .construct_from_objects(ctx, [u.into(), s.into(), vh.into()], None); + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_inv` linalg function @@ -2041,13 +1796,13 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -2056,12 +1811,12 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_inv( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_pinv` linalg function @@ -2072,11 +1827,11 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_pinv"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -2089,20 +1844,24 @@ pub fn 
call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) }; - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) - .construct_dyn_shape(generator, ctx, &[d0, d1], None); + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2).construct_dyn_shape( + generator, + ctx, + &[d0, d1], + None, + ); unsafe { out.create_data(generator, ctx) }; let x1_c = x1.make_contiguous_ndarray(generator, ctx); let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_pinv( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `sp_linalg_lu` linalg function @@ -2113,11 +1872,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "sp_linalg_lu"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -2131,7 +1890,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { l.create_data(generator, ctx) }; @@ -2144,16 +1903,20 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let u_c = 
u.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_lu( ctx, - x1_c.as_base_value().into(), - l_c.as_base_value().into(), - u_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + l_c.as_abi_value(ctx).into(), + u_c.as_abi_value(ctx).into(), None, ); - let l = l.as_base_value().into(); - let u = u.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[l, u]); - Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) + let l = l.as_abi_value(ctx); + let u = u.as_abi_value(ctx); + let tuple = TupleType::new(ctx, &[l.get_type(), u.get_type()]).construct_from_objects( + ctx, + [l.into(), u.into()], + None, + ); + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_matrix_power` linalg function @@ -2165,7 +1928,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_matrix_power"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) @@ -2174,7 +1937,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None); + let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, ndims, llvm_usize, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -2186,11 +1949,11 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) }; - let x2 = NDArrayType::new_unsized(generator, ctx.ctx, ctx.ctx.f64_type().into()) + let x2 = NDArrayType::new_unsized(ctx, ctx.ctx.f64_type().into()) .construct_unsized(generator, ctx, 
&x2, None); // x2.shape == [] let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -2201,13 +1964,13 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_matrix_power( ctx, - x1_c.as_base_value().into(), - x2_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + x2_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_det` linalg function @@ -2218,27 +1981,31 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_matrix_power"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. 
- let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1)) - .construct_const_shape(generator, ctx, &[1], None); + let det = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 1).construct_const_shape( + generator, + ctx, + &[1], + None, + ); unsafe { det.create_data(generator, ctx) }; let x1_c = x1.make_contiguous_ndarray(generator, ctx); let out_c = det.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_det( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); @@ -2257,14 +2024,14 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - assert_eq!(x1.get_type().ndims(), Some(2)); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); + assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None); t.copy_shape_from_ndarray(generator, ctx, x1); @@ -2279,16 +2046,20 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let z_c = z.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_schur( ctx, - x1_c.as_base_value().into(), - t_c.as_base_value().into(), - z_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + t_c.as_abi_value(ctx).into(), + z_c.as_abi_value(ctx).into(), None, ); - let t = t.as_base_value().into(); - let z = z.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[t, z]); - Ok(ctx.builder.build_load(out_ptr, 
"Schur_Factorization_result").map(Into::into).unwrap()) + let t = t.as_abi_value(ctx); + let z = z.as_abi_value(ctx); + let tuple = TupleType::new(ctx, &[t.get_type(), z.get_type()]).construct_from_objects( + ctx, + [t.into(), z.into()], + None, + ); + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `sp_linalg_hessenberg` linalg function @@ -2301,14 +2072,14 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - assert_eq!(x1.get_type().ndims(), Some(2)); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); + assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None); h.copy_shape_from_ndarray(generator, ctx, x1); @@ -2323,14 +2094,18 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let q_c = q.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_hessenberg( ctx, - x1_c.as_base_value().into(), - h_c.as_base_value().into(), - q_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + h_c.as_abi_value(ctx).into(), + q_c.as_abi_value(ctx).into(), None, ); - let h = h.as_base_value().into(); - let q = q.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[h, q]); - Ok(ctx.builder.build_load(out_ptr, "Hessenberg_decomposition_result").map(Into::into).unwrap()) + let h = h.as_abi_value(ctx); + let q = q.as_abi_value(ctx); + let tuple = TupleType::new(ctx, &[h.get_type(), q.get_type()]).construct_from_objects( + ctx, + [h.into(), q.into()], + None, + ); + 
Ok(tuple.as_abi_value(ctx).into()) } diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index d5c1fc3..503a4ae 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -56,6 +56,10 @@ pub enum ConcreteTypeEnum { fields: HashMap, params: IndexMap, }, + TModule { + module_id: DefinitionId, + methods: HashMap, + }, TVirtual { ty: ConcreteType, }, @@ -205,6 +209,19 @@ impl ConcreteTypeStore { }) .collect(), }, + TypeEnum::TModule { module_id, attributes } => ConcreteTypeEnum::TModule { + module_id: *module_id, + methods: attributes + .iter() + .filter_map(|(name, ty)| match &*unifier.get_ty(ty.0) { + TypeEnum::TFunc(..) | TypeEnum::TObj { .. } => None, + _ => Some(( + *name, + (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1), + )), + }) + .collect(), + }, TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual { ty: self.from_unifier_type(unifier, primitives, *ty, cache), }, @@ -284,6 +301,15 @@ impl ConcreteTypeStore { TypeVar { id, ty } })), }, + ConcreteTypeEnum::TModule { module_id, methods } => TypeEnum::TModule { + module_id: *module_id, + attributes: methods + .iter() + .map(|(name, cty)| { + (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) + }) + .collect::>(), + }, ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { args: args .iter() diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0118ca4..20d296e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -11,7 +11,7 @@ use inkwell::{ values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, OptimizationLevel, }; -use itertools::{chain, izip, Either, Itertools}; +use itertools::{izip, Either, Itertools}; use nac3parser::ast::{ self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, @@ -24,26 +24,26 @@ use super::{ 
irrt::*, llvm_intrinsics::{ call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, - call_int_umin, call_memcpy_generic, + call_memcpy_generic, }, macros::codegen_unreachable, - need_sret, numpy, + need_sret, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType}, + types::{ndarray::NDArrayType, ListType, RangeType}, values::{ - ndarray::{NDArrayValue, RustNDIndex}, + ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, - TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenTask, CodeGenerator, }; use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{ - helper::{extract_ndims, PrimDef}, + helper::{arraylike_flatten_element_type, PrimDef}, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef, }, @@ -61,8 +61,13 @@ pub fn get_subst_key( ) -> String { let mut vars = obj .map(|ty| { - let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; - params.clone() + if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) { + params.clone() + } else if let TypeEnum::TModule { .. } = &*unifier.get_ty(ty) { + indexmap::IndexMap::new() + } else { + unreachable!() + } }) .unwrap_or_default(); vars.extend(fun_vars); @@ -79,7 +84,7 @@ pub fn get_subst_key( .join(", ") } -impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { +impl<'ctx> CodeGenContext<'ctx, '_> { /// Builds a sequence of `getelementptr` and `load` instructions which stores the value of a /// struct field into an LLVM value. pub fn build_gep_and_load( @@ -120,6 +125,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option) { let obj_id = match &*self.unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } => *obj_id, + TypeEnum::TModule { module_id, .. 
} => *module_id, // we cannot have other types, virtual type should be handled by function calls _ => codegen_unreachable!(self), }; @@ -131,6 +137,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap(); (attribute_index.0, Some(attribute_index.1 .2.clone())) } + } else if let TopLevelDef::Module { attributes, .. } = &*def.read() { + (attributes.iter().find_position(|x| x.0 == attr).unwrap().0, None) } else { codegen_unreachable!(self) }; @@ -165,7 +173,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { .build_global_string_ptr(v, "const") .map(|v| v.as_pointer_value().into()) .unwrap(); - let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); + let size = self.get_size_type().const_int(v.len() as u64, false); let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type(); ty.const_named_struct(&[str_ptr, size.into()]).into() } @@ -318,7 +326,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { .build_global_string_ptr(v, "const") .map(|v| v.as_pointer_value().into()) .unwrap(); - let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); + let size = self.get_size_type().const_int(v.len() as u64, false); let ty = self.get_llvm_type(generator, self.primitives.str); let val = ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); @@ -820,7 +828,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( fun: (&FunSignature, DefinitionId), params: Vec<(Option, ValueEnum<'ctx>)>, ) -> Result>, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); let id; @@ -979,7 +987,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( TopLevelDef::Class { .. } => { return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?)) } - TopLevelDef::Variable { .. } => unreachable!(), + TopLevelDef::Variable { .. 
} | TopLevelDef::Module { .. } => unreachable!(), } } .or_else(|_: String| { @@ -1020,7 +1028,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( } let is_vararg = args.iter().any(|arg| arg.is_vararg); if is_vararg { - params.push(generator.get_size_type(ctx.ctx).into()); + params.push(ctx.get_size_type().into()); } let fun_ty = match ret_type { Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, is_vararg), @@ -1095,33 +1103,6 @@ pub fn destructure_range<'ctx>( (start, end, step) } -/// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting -/// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified. -/// -/// Setting `ty` to [`None`] implies that the list is empty **and** does not have a known element -/// type, and will therefore set the `list.data` type as `size_t*`. It is undefined behavior to -/// generate a sized list with an unknown element type. -pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: Option>, - length: IntValue<'ctx>, - name: Option<&'ctx str>, -) -> ListValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ty.unwrap_or(llvm_usize.into()); - - // List structure; type { ty*, size_t } - let arr_ty = ListType::new(generator, ctx.ctx, llvm_elem_ty); - let list = arr_ty.alloca(generator, ctx, name); - - let length = ctx.builder.build_int_z_extend(length, llvm_usize, "").unwrap(); - list.store_size(ctx, generator, length); - list.create_data(ctx, llvm_elem_ty, None); - - list -} - /// Generates LLVM IR for a [list comprehension expression][expr]. 
pub fn gen_comprehension<'ctx, G: CodeGenerator>( generator: &mut G, @@ -1155,7 +1136,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( return Ok(None); }; let int32 = ctx.ctx.i32_type(); - let size_t = generator.get_size_type(ctx.ctx); + let size_t = ctx.get_size_type(); let zero_size_t = size_t.const_zero(); let zero_32 = int32.const_zero(); @@ -1170,7 +1151,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { let iter_val = - RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range")); + RangeType::new(ctx).map_pointer_value(iter_val.into_pointer_value(), Some("range")); let (start, stop, step) = destructure_range(ctx, iter_val); let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); // add 1 to the length as the value is rounded to zero @@ -1194,12 +1175,11 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( "listcomp.alloc_size", ) .unwrap(); - list = allocate_list( + list = ListType::new(ctx, &elem_ty).construct( generator, ctx, - Some(elem_ty), list_alloc_size.into_int_value(), - Some("listcomp.addr"), + Some("listcomp"), ); let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); @@ -1246,7 +1226,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( Some("length"), ) .into_int_value(); - list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); + list = ListType::new(ctx, &elem_ty).construct(generator, ctx, length, Some("listcomp")); let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; // counter = -1 @@ -1281,15 +1261,13 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( } // Emits the content of `cont_bb` - let emit_cont_bb = - |ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| { - ctx.builder.position_at_end(cont_bb); - list.store_size( - ctx, - generator, - ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(), - 
); - }; + let emit_cont_bb = |ctx: &CodeGenContext<'ctx, '_>, list: ListValue<'ctx>| { + ctx.builder.position_at_end(cont_bb); + list.store_size( + ctx, + ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(), + ); + }; for cond in ifs { let result = if let Some(v) = generator.gen_expr(ctx, cond)? { @@ -1297,7 +1275,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( } else { // Bail if the predicate is an ellipsis - Emit cont_bb contents in case the // no element matches the predicate - emit_cont_bb(ctx, generator, list); + emit_cont_bb(ctx, list); return Ok(None); }; @@ -1310,7 +1288,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( let Some(elem) = generator.gen_expr(ctx, elt)? else { // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents - emit_cont_bb(ctx, generator, list); + emit_cont_bb(ctx, list); return Ok(None); }; @@ -1327,9 +1305,9 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( .unwrap(); ctx.builder.build_unconditional_branch(test_bb).unwrap(); - emit_cont_bb(ctx, generator, list); + emit_cont_bb(ctx, list); - Ok(Some(list.as_base_value().into())) + Ok(Some(list.as_abi_value(ctx).into())) } /// Generates LLVM IR for a binary operator expression using the [`Type`] and @@ -1373,7 +1351,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); if op.variant == BinopVariant::AugAssign { todo!("Augmented assignment operators not implemented for lists") @@ -1411,7 +1389,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( .build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "") .unwrap(); - let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None); + let new_list = + ListType::new(ctx, 
&llvm_elem_ty).construct(generator, ctx, size, None); let lhs_size = ctx .builder @@ -1458,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().const_zero(), ); - Ok(Some(new_list.as_base_value().into())) + Ok(Some(new_list.as_abi_value(ctx).into())) } Operator::Mult => { @@ -1498,10 +1477,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); let sizeof_elem = elem_llvm_ty.size_of().unwrap(); - let new_list = allocate_list( + let new_list = ListType::new(ctx, &elem_llvm_ty).construct( generator, ctx, - Some(elem_llvm_ty), ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), None, ); @@ -1546,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( llvm_usize.const_int(1, false), )?; - Ok(Some(new_list.as_base_value().into())) + Ok(Some(new_list.as_abi_value(ctx).into())) } _ => todo!("Operator not supported"), @@ -1554,98 +1532,76 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val)); + let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val)); - if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); + let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); + let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + // Inhomogeneous binary operations are not supported. 
+ assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype)); - let left_val = NDArrayType::from_unifier_type(generator, ctx, ty1) - .map_value(left_val.into_pointer_value(), None); - let right_val = NDArrayType::from_unifier_type(generator, ctx, ty2) - .map_value(right_val.into_pointer_value(), None); + let common_dtype = ty1_dtype; + let llvm_common_dtype = left.get_dtype(); - let res = if op.base == Operator::MatMult { - // MatMult is the only binop which is not an elementwise op - numpy::ndarray_matmul_2d( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - left_val, - right_val, - )? - } else { - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - (ty1, left_val.as_base_value().into(), false), - (ty2, right_val.as_base_value().into(), false), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype1), lhs), - op, - (&Some(ndarray_dtype2), rhs), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ndarray_dtype1, - ) - }, - )? - }; + let out = match op.variant { + BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + BinopVariant::AugAssign => { + // Augmented assignment - `left` has to be an ndarray. If it were a scalar then NAC3 + // simply doesn't support it. 
+ if let ScalarOrNDArray::NDArray(out_ndarray) = left { + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } + } else { + panic!("left must be an ndarray") + } + } + }; - Ok(Some(res.as_base_value().into())) + if op.base == Operator::MatMult { + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + let result = left + .matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out)) + .split_unsized(generator, ctx); + Ok(Some(result.to_basic_value_enum().into())) } else { - let (ndarray_dtype, _) = - unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); - let ndarray_val = - NDArrayType::from_unifier_type(generator, ctx, if is_ndarray1 { ty1 } else { ty2 }) - .map_value( - if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), - None, - ); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ndarray_dtype, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(ndarray_val), - }, - (ty1, left_val, !is_ndarray1), - (ty2, right_val, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype), lhs), - op, - (&Some(ndarray_dtype), rhs), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) - }, - )?; + // For other operations, they are all elementwise operations. - Ok(Some(res.as_base_value().into())) + // There are only three cases: + // - LHS is a scalar, RHS is an ndarray. + // - LHS is an ndarray, RHS is a scalar. + // - LHS is an ndarray, RHS is an ndarray. + // + // For all cases, the scalar operand is promoted to an ndarray, + // the two are then broadcasted, and starmapped through. 
+ + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + + let result = NDArrayType::new_broadcast( + ctx, + llvm_common_dtype, + &[left.get_type(), right.get_type()], + ) + .broadcast_starmap(generator, ctx, &[left, right], out, |generator, ctx, scalars| { + let left_value = scalars[0]; + let right_value = scalars[1]; + + let result = gen_binop_expr_with_values( + generator, + ctx, + (&Some(ty1_dtype), left_value), + op, + (&Some(ty2_dtype), right_value), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, common_dtype)?; + + Ok(result) + }) + .unwrap(); + Ok(Some(result.as_abi_value(ctx).into())) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); @@ -1808,10 +1764,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( _ => val.into(), } } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, ty); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - let val = llvm_ndarray_ty.map_value(val.into_pointer_value(), None); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty) + .map_pointer_value(val.into_pointer_value(), None); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function @@ -1829,20 +1785,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( op }; - let res = numpy::ndarray_elementwise_unaryop_impl( + let mapped_ndarray = ndarray.map( generator, ctx, - ndarray_dtype, - None, - val, - |generator, ctx, val| { - gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))? + NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() }, + |generator, ctx, scalar| { + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))? 
+ .map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype)) .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) }, )?; - res.as_base_value().into() + mapped_ndarray.as_abi_value(ctx).into() } else { unimplemented!() })) @@ -1885,87 +1839,55 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) }; - let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) }; + let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) }; + let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) }; let op = ops[0]; - let is_ndarray1 = - left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, left_ty); + let right_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, right_ty); - return if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); + let left = ScalarOrNDArray::from_value(generator, ctx, (left_ty, left)) + .to_ndarray(generator, ctx); + let right = ScalarOrNDArray::from_value(generator, ctx, (right_ty, right)) + .to_ndarray(generator, ctx); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let result_ndarray = NDArrayType::new_broadcast( + ctx, + ctx.ctx.i8_type().into(), + &[left.get_type(), right.get_type()], + ) + .broadcast_starmap( + generator, + ctx, + &[left, right], + NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() }, + |generator, ctx, scalars| { + let left_scalar = scalars[0]; + let right_scalar = scalars[1]; - let 
left_val = NDArrayType::from_unifier_type(generator, ctx, left_ty) - .map_value(lhs.into_pointer_value(), None); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_ty, left_val.as_base_value().into(), false), - (right_ty, rhs, false), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype1), lhs), - &[op], - &[(Some(ndarray_dtype2), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; + let val = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left_ty_dtype), left_scalar), + &[op], + &[(Some(right_ty_dtype), right_scalar)], + )? + .unwrap() + .to_basic_value_enum( + ctx, + generator, + ctx.primitives.bool, + )?; - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; - Ok(Some(res.as_base_value().into())) - } else { - let (ndarray_dtype, _) = unpack_ndarray_var_tys( - &mut ctx.unifier, - if is_ndarray1 { left_ty } else { right_ty }, - ); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_ty, lhs, !is_ndarray1), - (right_ty, rhs, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype), lhs), - &[op], - &[(Some(ndarray_dtype), rhs)], - )? 
- .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; - - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; - - Ok(Some(res.as_base_value().into())) - }; + return Ok(Some(result_ndarray.as_abi_value(ctx).into())); } } - let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),) + let cmp_val = izip!(once(&left).chain(comparators.iter()), comparators.iter(), ops.iter(),) .fold(Ok(None), |prev: Result, String>, (lhs, rhs, op)| { let (left_ty, lhs) = lhs; let (right_ty, rhs) = rhs; @@ -2045,117 +1967,49 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } else if left_ty == ctx.primitives.str { assert!(ctx.unifier.unioned(left_ty, right_ty)); - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let lhs = lhs.into_struct_value(); let rhs = rhs.into_struct_value(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = ctx.get_size_type(); + let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(plhs, lhs).unwrap(); let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(prhs, rhs).unwrap(); + let lhs_ptr = ctx.build_in_bounds_gep_and_load( + plhs, + &[llvm_usize.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); let lhs_len = ctx.build_in_bounds_gep_and_load( plhs, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); + + let rhs_ptr = ctx.build_in_bounds_gep_and_load( + prhs, + &[llvm_usize.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); let rhs_len = ctx.build_in_bounds_gep_and_load( prhs, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); - - let len = call_int_umin(ctx, lhs_len, 
rhs_len, None); - - let current_bb = ctx.builder.get_insert_block().unwrap(); - let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end"); - - ctx.builder.position_at_end(post_foreach_cmp); - let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap(); - ctx.builder.position_at_end(current_bb); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (len, false), - |generator, ctx, _, i| { - let lhs_char = { - let plhs_data = ctx.build_in_bounds_gep_and_load( - plhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - - ctx.build_in_bounds_gep_and_load( - plhs_data, - &[i], - None - ).into_int_value() - }; - let rhs_char = { - let prhs_data = ctx.build_in_bounds_gep_and_load( - prhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - - ctx.build_in_bounds_gep_and_load( - prhs_data, - &[i], - None - ).into_int_value() - }; - - gen_if_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap()) - }, - |_, ctx| { - let bb = ctx.builder.get_insert_block().unwrap(); - cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]); - ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); - - Ok(()) - }, - |_, _| Ok(()), - )?; - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let bb = ctx.builder.get_insert_block().unwrap(); - let is_len_eq = ctx.builder.build_int_compare( - IntPredicate::EQ, - lhs_len, - rhs_len, - "", - ).unwrap(); - cmp_phi.add_incoming(&[(&is_len_eq, bb)]); - ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); - - ctx.builder.position_at_end(post_foreach_cmp); - let cmp_phi = cmp_phi.as_basic_value().into_int_value(); - - // Invert the final value if __ne__ + let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { - ctx.builder.build_not(cmp_phi, "").unwrap() + 
ctx.builder.build_not(result, "").unwrap() } else { - cmp_phi + result } } else if [left_ty, right_ty] .iter() .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let gen_list_cmpop = |generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>| @@ -2276,9 +2130,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().const_zero(), ) .unwrap(); - ctx.builder - .build_unconditional_branch(hooks.exit_bb) - .unwrap(); + hooks.build_break_branch(&ctx.builder); Ok(()) }, @@ -2512,319 +2364,6 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( ) } -/// Generates code for a subscript expression on an `ndarray`. -/// -/// * `ty` - The `Type` of the `NDArray` elements. -/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. -/// * `v` - The `NDArray` value. -/// * `slice` - The slice expression used to subscript into the `ndarray`. -fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: Type, - ndims_ty: Type, - v: NDArrayValue<'ctx>, - slice: &Expr>, -) -> Result>, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims_ty) else { - codegen_unreachable!(ctx) - }; - - let ndims = values - .iter() - .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) - .collect::, _>>() - .map_err(|val| { - format!( - "Expected non-negative literal for ndarray.ndims, got {}", - i128::try_from(val).unwrap() - ) - })?; - - assert!(!ndims.is_empty()); - - // The number of dimensions subscripted by the index expression. - // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a - // dimension will remove a dimension. 
- let subscripted_dims = match &slice.node { - ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| { - if let ExprKind::Slice { .. } = &value_subexpr.node { - acc - } else { - acc + 1 - } - }), - - ExprKind::Slice { .. } => 0, - _ => 1, - }; - - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); - let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap(); - - // Check that len is non-zero - let len = v.load_ndims(ctx); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(), - "0:IndexError", - "too many indices for array: array is {0}-dimensional but 1 were indexed", - [Some(len), None, None], - slice.location, - ); - - // Normalizes a possibly-negative index to its corresponding positive index - let normalize_index = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - dim: u64| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "") - .unwrap()) - }, - |_, _| Ok(Some(index)), - |generator, ctx| { - let llvm_i32 = ctx.ctx.i32_type(); - - let len = unsafe { - v.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, true), - None, - ) - }; - - let index = ctx - .builder - .build_int_add( - len, - ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(), - "", - ) - .unwrap(); - - Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value)) - }; - - // Converts a slice expression into a slice-range tuple - let expr_to_slice = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - node: &ExprKind>, - dim: u64| { - match node { - ExprKind::Constant { value: Constant::Int(v), .. } => { - let Some(index) = - normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)? 
- else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - - ExprKind::Slice { lower, upper, step } => { - let dim_sz = unsafe { - v.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, false), - None, - ) - }; - - handle_slice_indices(lower, upper, step, ctx, generator, dim_sz) - } - - _ => { - let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) }; - let index = index - .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, dim)? else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - } - }; - - let make_indices_arr = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>| - -> Result<_, String> { - Ok(if let ExprKind::Tuple { elts, .. } = &slice.node { - let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(elts.len() as u64, false), - None, - )?; - - for (i, elt) in elts.iter().enumerate() { - let Some(index) = generator.gen_expr(ctx, elt)? else { - return Ok(None); - }; - - let index = index - .to_basic_value_enum(ctx, generator, elt.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { - return Ok(None); - }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - None, - ) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - } - - Some(index_addr) - } else if let Some(index) = generator.gen_expr(ctx, slice)? 
{ - let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(1u64, false), - None, - )?; - - let index = - index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - - Some(index_addr) - } else { - None - }) - }; - - Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 { - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - v.data().get(ctx, generator, &index_addr, None).into() - } else { - match &slice.node { - ExprKind::Tuple { elts, .. } => { - let slices = elts - .iter() - .enumerate() - .map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64)) - .take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some)) - .collect::, _>>()?; - if slices.len() < elts.len() { - return Ok(None); - } - - let slices = slices.into_iter().map(Option::unwrap).collect_vec(); - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into() - } - - ExprKind::Slice { .. } => { - let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { - return Ok(None); - }; - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into() - } - - _ => { - // Accessing an element from a multi-dimensional `ndarray` - let Some(index_addr) = make_indices_arr(generator, ctx)? 
else { return Ok(None) }; - - let num_dims = extract_ndims(&ctx.unifier, ndims_ty) - 1; - - // Create a new array, remove the top dimension from the dimension-size-list, and copy the - // elements over - let ndarray = - NDArrayType::new(generator, ctx.ctx, llvm_ndarray_data_t, Some(num_dims)) - .construct_uninitialized(generator, ctx, None); - - let ndarray_num_dims = ctx - .builder - .build_int_z_extend_or_bit_cast( - ndarray.load_ndims(ctx), - llvm_usize.size_of().get_type(), - "", - ) - .unwrap(); - let v_dims_src_ptr = unsafe { - v.shape().ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - call_memcpy_generic( - ctx, - ndarray.shape().base_ptr(ctx, generator), - v_dims_src_ptr, - ctx.builder - .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - let ndarray_num_elems = ndarray::call_ndarray_calc_size( - generator, - ctx, - &ndarray.shape().as_slice_value(ctx, generator), - (None, None), - ); - let ndarray_num_elems = ctx - .builder - .build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "") - .unwrap(); - unsafe { ndarray.create_data(generator, ctx) }; - - let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None); - call_memcpy_generic( - ctx, - ndarray.data().base_ptr(ctx, generator), - v_data_src_ptr, - ctx.builder - .build_int_mul( - ndarray_num_elems, - llvm_ndarray_data_t.size_of().unwrap(), - "", - ) - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - ndarray.as_base_value().into() - } - } - })) -} - /// See [`CodeGenerator::gen_expr`]. 
pub fn gen_expr<'ctx, G: CodeGenerator>( generator: &mut G, @@ -2833,7 +2372,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ) -> Result>, String> { ctx.current_loc = expr.location; let int32 = ctx.ctx.i32_type(); - let usize = generator.get_size_type(ctx.ctx); + let usize = ctx.get_size_type(); let zero = int32.const_int(0, false); let loc = ctx.debug_info.0.create_debug_location( @@ -2938,8 +2477,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { Some(elements[0].get_type()) }; - let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false); - let arr_str_ptr = allocate_list(generator, ctx, ty, length, Some("list")); + let length = ctx.get_size_type().const_int(elements.len() as u64, false); + let arr_str_ptr = if let Some(ty) = ty { + ListType::new(ctx, &ty).construct(generator, ctx, length, Some("list")) + } else { + ListType::new_untyped(ctx).construct_empty(generator, ctx, Some("list")) + }; let arr_ptr = arr_str_ptr.data(); for (i, v) in elements.iter().enumerate() { let elem_ptr = arr_ptr.ptr_offset( @@ -2950,7 +2493,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ); ctx.builder.build_store(elem_ptr, *v).unwrap(); } - arr_str_ptr.as_base_value().into() + arr_str_ptr.as_abi_value(ctx).into() } ExprKind::Tuple { elts, .. } => { let elements_val = elts @@ -3270,6 +2813,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( &*ctx.unifier.get_ty(value.custom.unwrap()) { *obj_id + } else if let TypeEnum::TModule { module_id, .. } = + &*ctx.unifier.get_ty(value.custom.unwrap()) + { + *module_id } else { codegen_unreachable!(ctx) }; @@ -3280,11 +2827,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); - let TopLevelDef::Class { methods, .. } = &*obj_def else { + if let TopLevelDef::Class { methods, .. } = &*obj_def { + methods.iter().find(|method| method.0 == *attr).unwrap().2 + } else if let TopLevelDef::Module { methods, .. 
} = &*obj_def { + *methods.iter().find(|method| method.0 == attr).unwrap().1 + } else { codegen_unreachable!(ctx) - }; - - methods.iter().find(|method| method.0 == *attr).unwrap().2 + } }; // directly generate code for option.unwrap // since it needs to return static value to optimize for kernel invariant @@ -3418,7 +2967,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( step, ); let res_array_ret = - allocate_list(generator, ctx, Some(ty), length, Some("ret")); + ListType::new(ctx, &ty).construct(generator, ctx, length, Some("ret")); let Some(res_ind) = handle_slice_indices( &None, &None, @@ -3439,7 +2988,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v, (start, end, step), ); - res_array_ret.as_base_value().into() + res_array_ret.as_abi_value(ctx).into() } else { let len = v.load_size(ctx, Some("len")); let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { @@ -3450,7 +2999,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; let raw_index = ctx .builder - .build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext") + .build_int_s_extend(raw_index, ctx.get_size_type(), "sext") .unwrap(); // handle negative index let is_negative = ctx @@ -3458,7 +3007,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .build_int_compare( IntPredicate::SLT, raw_index, - generator.get_size_type(ctx.ctx).const_zero(), + ctx.get_size_type().const_zero(), "is_neg", ) .unwrap(); @@ -3494,14 +3043,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let ndarray_ty = value.custom.unwrap(); let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) - .map_value(ndarray.into_pointer_value(), None); + .map_pointer_value(ndarray.into_pointer_value(), None); let indices = RustNDIndex::from_subscript_expr(generator, ctx, slice)?; let result = ndarray .index(generator, ctx, &indices) .split_unsized(generator, ctx) .to_basic_value_enum(); - return Ok(Some(ValueEnum::Dynamic(result))); + return 
Ok(Some(result.into())); } TypeEnum::TTuple { .. } => { let index: u32 = diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index f277ec9..620ede0 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -1,5 +1,6 @@ use inkwell::{ context::Context, + targets::TargetMachine, types::{BasicTypeEnum, IntType}, values::{BasicValueEnum, IntValue, PointerValue}, }; @@ -17,6 +18,10 @@ pub trait CodeGenerator { /// Return the module name for the code generator. fn get_name(&self) -> &str; + /// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance. + /// + /// Prefer using [`CodeGenContext::get_size_type`] if [`CodeGenContext`] is available, as it is + /// equivalent to this function in a more concise syntax. fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>; /// Generate function call and returns the function return value. @@ -269,19 +274,27 @@ pub struct DefaultCodeGenerator { impl DefaultCodeGenerator { #[must_use] - pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator { - assert!(matches!(size_t, 32 | 64)); - DefaultCodeGenerator { name, size_t } + pub fn new(name: String, size_t: IntType<'_>) -> DefaultCodeGenerator { + assert!(matches!(size_t.get_bit_width(), 32 | 64)); + DefaultCodeGenerator { name, size_t: size_t.get_bit_width() } + } + + #[must_use] + pub fn with_target_machine( + name: String, + ctx: &Context, + target_machine: &TargetMachine, + ) -> DefaultCodeGenerator { + let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None); + Self::new(name, llvm_usize) } } impl CodeGenerator for DefaultCodeGenerator { - /// Returns the name for this [`CodeGenerator`]. fn get_name(&self) -> &str { &self.name } - /// Returns an LLVM integer type representing `size_t`. 
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> { // it should be unsigned, but we don't really need unsigned and this could save us from // having to do a bit cast... diff --git a/nac3core/src/codegen/irrt/list.rs b/nac3core/src/codegen/irrt/list.rs index a7fec59..c01e2cb 100644 --- a/nac3core/src/codegen/irrt/list.rs +++ b/nac3core/src/codegen/irrt/list.rs @@ -24,42 +24,52 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( src_arr: ListValue<'ctx>, src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), ) { - let size_ty = generator.get_size_type(ctx.ctx); - let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let int32 = ctx.ctx.i32_type(); - let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); + let llvm_usize = ctx.get_size_type(); + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let llvm_i32 = ctx.ctx.i32_type(); + + assert_eq!(dest_idx.0.get_type(), llvm_i32); + assert_eq!(dest_idx.1.get_type(), llvm_i32); + assert_eq!(dest_idx.2.get_type(), llvm_i32); + assert_eq!(src_idx.0.get_type(), llvm_i32); + assert_eq!(src_idx.1.get_type(), llvm_i32); + assert_eq!(src_idx.2.get_type(), llvm_i32); + + let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8); let slice_assign_fun = { let ty_vec = vec![ - int32.into(), // dest start idx - int32.into(), // dest end idx - int32.into(), // dest step + llvm_i32.into(), // dest start idx + llvm_i32.into(), // dest end idx + llvm_i32.into(), // dest step elem_ptr_type.into(), // dest arr ptr - int32.into(), // dest arr len - int32.into(), // src start idx - int32.into(), // src end idx - int32.into(), // src step + llvm_i32.into(), // dest arr len + llvm_i32.into(), // src start idx + llvm_i32.into(), // src end idx + llvm_i32.into(), // src step elem_ptr_type.into(), // src arr ptr - int32.into(), // src arr len - int32.into(), // size + llvm_i32.into(), // src arr len + llvm_i32.into(), // size 
]; ctx.module.get_function(fun_symbol).unwrap_or_else(|| { - let fn_t = int32.fn_type(ty_vec.as_slice(), false); + let fn_t = llvm_i32.fn_type(ty_vec.as_slice(), false); ctx.module.add_function(fun_symbol, fn_t, None) }) }; - let zero = int32.const_zero(); - let one = int32.const_int(1, false); + let zero = llvm_i32.const_zero(); + let one = llvm_i32.const_int(1, false); let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); let dest_arr_ptr = ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap(); let dest_len = dest_arr.load_size(ctx, Some("dest.len")); - let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); + let dest_len = + ctx.builder.build_int_truncate_or_bit_cast(dest_len, llvm_i32, "srclen32").unwrap(); let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); let src_arr_ptr = ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap(); let src_len = src_arr.load_size(ctx, Some("src.len")); - let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); + let src_len = + ctx.builder.build_int_truncate_or_bit_cast(src_len, llvm_i32, "srclen32").unwrap(); // index in bound and positive should be done // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and @@ -136,7 +146,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( BasicTypeEnum::StructType(t) => t.size_of().unwrap(), _ => codegen_unreachable!(ctx), }; - ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() + ctx.builder.build_int_truncate_or_bit_cast(s, llvm_i32, "size").unwrap() } .into(), ]; @@ -147,6 +157,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( .map(Either::unwrap_left) .unwrap() }; + // update length let need_update = ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); @@ -155,8 +166,9 @@ pub fn list_slice_assignment<'ctx, 
G: CodeGenerator + ?Sized>( let cont_bb = ctx.ctx.append_basic_block(current, "cont"); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); ctx.builder.position_at_end(update_bb); - let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap(); - dest_arr.store_size(ctx, generator, new_len); + let new_len = + ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); + dest_arr.store_size(ctx, new_len); ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.position_at_end(cont_bb); } diff --git a/nac3core/src/codegen/irrt/math.rs b/nac3core/src/codegen/irrt/math.rs index 4bc9591..33445b2 100644 --- a/nac3core/src/codegen/irrt/math.rs +++ b/nac3core/src/codegen/irrt/math.rs @@ -62,8 +62,13 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>, ) -> IntValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_f64 = ctx.ctx.f64_type(); + + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_isinf", fn_type, None) }); @@ -84,8 +89,13 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>, ) -> IntValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_f64 = ctx.ctx.f64_type(); + + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_isnan", fn_type, None) }); @@ -104,6 +114,8 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: 
FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_gamma", fn_type, None) @@ -121,6 +133,8 @@ pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_gammaln", fn_type, None) @@ -138,6 +152,8 @@ pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) - pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { let llvm_f64 = ctx.ctx.f64_type(); + assert_eq!(v.get_type(), llvm_f64); + let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); ctx.module.add_function("__nac3_j0", fn_type, None) diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 824921c..8739178 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -15,12 +15,14 @@ pub use list::*; pub use math::*; pub use range::*; pub use slice::*; +pub use string::*; mod list; mod math; pub mod ndarray; mod range; mod slice; +mod string; #[must_use] pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> { @@ -66,13 +68,9 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) /// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`. /// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`. 
#[must_use] -pub fn get_usize_dependent_function_name( - generator: &G, - ctx: &CodeGenContext<'_, '_>, - name: &str, -) -> String { +pub fn get_usize_dependent_function_name(ctx: &CodeGenContext<'_, '_>, name: &str) -> String { let mut name = name.to_owned(); - match generator.get_size_type(ctx.ctx).get_bit_width() { + match ctx.get_size_type().get_bit_width() { 32 => {} 64 => name.push_str("64"), bit_width => { @@ -130,10 +128,11 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( generator: &mut G, length: IntValue<'ctx>, ) -> Result, IntValue<'ctx>, IntValue<'ctx>)>, String> { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let one = int32.const_int(1, false); - let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap(); + let llvm_i32 = ctx.ctx.i32_type(); + + let zero = llvm_i32.const_zero(); + let one = llvm_i32.const_int(1, false); + let length = ctx.builder.build_int_truncate_or_bit_cast(length, llvm_i32, "leni32").unwrap(); Ok(Some(match (start, end, step) { (s, e, None) => ( if let Some(s) = s.as_ref() { @@ -142,7 +141,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( None => return Ok(None), } } else { - int32.const_zero() + llvm_i32.const_zero() }, { let e = if let Some(s) = e.as_ref() { diff --git a/nac3core/src/codegen/irrt/ndarray/array.rs b/nac3core/src/codegen/irrt/ndarray/array.rs new file mode 100644 index 0000000..63a2ab0 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/array.rs @@ -0,0 +1,72 @@ +use inkwell::{types::BasicTypeEnum, values::IntValue}; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ndarray::NDArrayValue, ListValue, ProxyValue, TypedArrayLikeAccessor}, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_array_set_and_validate_list_shape`. 
+/// +/// Deduces the target shape of the `ndarray` from the provided `list`, raising an exception if +/// there is any issue with the resultant `shape`. +/// +/// `shape` must be pre-allocated by the caller of this function to `[usize; ndims]`, and must be +/// initialized to all `-1`s. +pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + list: ListValue<'ctx>, + ndims: IntValue<'ctx>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, +) { + let llvm_usize = ctx.get_size_type(); + assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); + assert_eq!(ndims.get_type(), llvm_usize); + assert_eq!( + BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = + get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_set_and_validate_list_shape"); + + infer_and_call_function( + ctx, + &name, + None, + &[list.as_abi_value(ctx).into(), ndims.into(), shape.base_ptr(ctx, generator).into()], + None, + None, + ); +} + +/// Generates a call to `__nac3_ndarray_array_write_list_to_array`. +/// +/// Copies the contents stored in `list` into `ndarray`. +/// +/// The `ndarray` must fulfill the following preconditions: +/// +/// - `ndarray.itemsize`: Must be initialized. +/// - `ndarray.ndims`: Must be initialized. +/// - `ndarray.shape`: Must be initialized. +/// - `ndarray.data`: Must be allocated and contiguous. 
+pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + list: ListValue<'ctx>, + ndarray: NDArrayValue<'ctx>, +) { + assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); + + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_write_list_to_array"); + + infer_and_call_function( + ctx, + &name, + None, + &[list.as_abi_value(ctx).into(), ndarray.as_abi_value(ctx).into()], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index 0daea1c..5f291c8 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -1,4 +1,5 @@ use inkwell::{ + types::BasicTypeEnum, values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, }; @@ -7,82 +8,100 @@ use crate::codegen::{ expr::{create_and_call_function, infer_and_call_function}, irrt::get_usize_dependent_function_name, types::ProxyType, - values::{ndarray::NDArrayValue, ProxyValue}, + values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor}, CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_ndarray_util_assert_shape_no_negative`. +/// +/// Asserts that `shape` does not contain negative dimensions. 
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, - ndims: IntValue<'ctx>, - shape: PointerValue<'ctx>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let name = get_usize_dependent_function_name( - generator, - ctx, - "__nac3_ndarray_util_assert_shape_no_negative", + assert_eq!( + BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() ); - create_and_call_function( - ctx, - &name, - Some(llvm_usize.into()), - &[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())], - None, - None, - ); -} - -pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray_ndims: IntValue<'ctx>, - ndarray_shape: PointerValue<'ctx>, - output_ndims: IntValue<'ctx>, - output_shape: IntValue<'ctx>, -) { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let name = get_usize_dependent_function_name( - generator, - ctx, - "__nac3_ndarray_util_assert_output_shape_same", - ); + let name = + get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative"); create_and_call_function( ctx, &name, Some(llvm_usize.into()), &[ - (llvm_usize.into(), ndarray_ndims.into()), - (llvm_pusize.into(), ndarray_shape.into()), - (llvm_usize.into(), output_ndims.into()), - (llvm_pusize.into(), output_shape.into()), + (llvm_usize.into(), shape.size(ctx, generator).into()), + (llvm_pusize.into(), shape.base_ptr(ctx, generator).into()), ], None, None, ); } -pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( +/// Generates a call to `__nac3_ndarray_util_assert_output_shape_same`. 
+/// +/// Asserts that `ndarray_shape` and `output_shape` are the same in the context of writing output to +/// an `ndarray`. +pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_ndarray = ndarray.get_type().as_base_type(); + ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, +) { + let llvm_usize = ctx.get_size_type(); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size"); + assert_eq!( + BasicTypeEnum::try_from(ndarray_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(output_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = + get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same"); create_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[ + (llvm_usize.into(), ndarray_shape.size(ctx, generator).into()), + (llvm_pusize.into(), ndarray_shape.base_ptr(ctx, generator).into()), + (llvm_usize.into(), output_shape.size(ctx, generator).into()), + (llvm_pusize.into(), output_shape.base_ptr(ctx, generator).into()), + ], + None, + None, + ); +} + +/// Generates a call to `__nac3_ndarray_size`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an +/// `ndarray`, corresponding to the value of `ndarray.size`. 
+pub fn call_nac3_ndarray_size<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, +) -> IntValue<'ctx> { + let llvm_usize = ctx.get_size_type(); + let llvm_ndarray = ndarray.get_type(); + + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size"); + + create_and_call_function( + ctx, + &name, + Some(llvm_usize.into()), + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("size"), None, ) @@ -90,21 +109,24 @@ pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } -pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +/// Generates a call to `__nac3_ndarray_nbytes`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the +/// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`. +pub fn call_nac3_ndarray_nbytes<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_usize = ctx.get_size_type(); + let llvm_ndarray = ndarray.get_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes"); create_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("nbytes"), None, ) @@ -112,21 +134,24 @@ pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } -pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +/// Generates a call to `__nac3_ndarray_len`. +/// +/// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of +/// the `ndarray`, corresponding to the value of `ndarray.__len__`. 
+pub fn call_nac3_ndarray_len<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_usize = ctx.get_size_type(); + let llvm_ndarray = ndarray.get_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len"); create_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("len"), None, ) @@ -134,21 +159,23 @@ pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } -pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +/// Generates a call to `__nac3_ndarray_is_c_contiguous`. +/// +/// Returns an `i1` value indicating whether the `ndarray` is C-contiguous. +pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous"); create_and_call_function( ctx, &name, Some(llvm_i1.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("is_c_contiguous"), None, ) @@ -156,53 +183,30 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } -pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +/// Generates a call to `__nac3_ndarray_get_nth_pelement`. 
+/// +/// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`. +pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, index: IntValue<'ctx>, ) -> PointerValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_usize = ctx.get_size_type(); + let llvm_ndarray = ndarray.get_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); + assert_eq!(index.get_type(), llvm_usize); - create_and_call_function( - ctx, - &name, - Some(llvm_pi8.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())], - Some("pelement"), - None, - ) - .map(BasicValueEnum::into_pointer_value) - .unwrap() -} - -pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: PointerValue<'ctx>, -) -> PointerValue<'ctx> { - let llvm_i8 = ctx.ctx.i8_type(); - let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let llvm_ndarray = ndarray.get_type().as_base_type(); - - let name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement"); create_and_call_function( ctx, &name, Some(llvm_pi8.into()), &[ - (llvm_ndarray.into(), ndarray.as_base_value().into()), - (llvm_pusize.into(), indices.into()), + (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), + (llvm_usize.into(), index.into()), ], Some("pelement"), None, @@ -211,39 +215,83 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, 
G: CodeGenerator + ?Sized .unwrap() } -pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( +/// Generates a call to `__nac3_ndarray_get_pelement_by_indices`. +/// +/// `indices` must have the same number of elements as the number of dimensions in `ndarray`. +/// +/// Returns a [`PointerValue`] to the element indexed by `indices`. +pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, -) { - let llvm_ndarray = ndarray.get_type().as_base_type(); + indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, +) -> PointerValue<'ctx> { + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + let llvm_usize = ctx.get_size_type(); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + let llvm_ndarray = ndarray.get_type(); - let name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape"); + assert_eq!( + BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices"); + + create_and_call_function( + ctx, + &name, + Some(llvm_pi8.into()), + &[ + (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), + (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), + ], + Some("pelement"), + None, + ) + .map(BasicValueEnum::into_pointer_value) + .unwrap() +} + +/// Generates a call to `__nac3_ndarray_set_strides_by_shape`. +/// +/// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous. 
+pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, +) { + let llvm_ndarray = ndarray.get_type(); + + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape"); create_and_call_function( ctx, &name, None, - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], None, None, ); } -pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +/// Generates a call to `__nac3_ndarray_copy_data`. +/// +/// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number +/// of elements in `src_ndarray` must be greater than or equal to the number of elements in +/// `dst_ndarray`. +pub fn call_nac3_ndarray_copy_data<'ctx>( ctx: &CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>, ) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_copy_data"); infer_and_call_function( ctx, &name, None, - &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + &[src_ndarray.as_abi_value(ctx).into(), dst_ndarray.as_abi_value(ctx).into()], None, None, ); diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs new file mode 100644 index 0000000..59b0e4c --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -0,0 +1,80 @@ +use inkwell::values::IntValue; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + types::{ndarray::ShapeEntryType, ProxyType}, + values::{ + ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, + TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_broadcast_to`. 
+/// +/// Attempts to broadcast `src_ndarray` to the new shape defined by `dst_ndarray`. +/// +/// `dst_ndarray` must meet the following preconditions: +/// +/// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`. +/// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape. +/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values. +pub fn call_nac3_ndarray_broadcast_to<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayValue<'ctx>, + dst_ndarray: NDArrayValue<'ctx>, +) { + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_to"); + infer_and_call_function( + ctx, + &name, + None, + &[src_ndarray.as_abi_value(ctx).into(), dst_ndarray.as_abi_value(ctx).into()], + None, + None, + ); +} + +/// Generates a call to `__nac3_ndarray_broadcast_shapes`. +/// +/// Attempts to calculate the resultant shape from broadcasting all shapes in `shape_entries`, +/// writing the result to `dst_shape`. 
+pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + num_shape_entries: IntValue<'ctx>, + shape_entries: ArraySliceValue<'ctx>, + dst_ndims: IntValue<'ctx>, + dst_shape: &Shape, +) where + G: CodeGenerator + ?Sized, + Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, +{ + let llvm_usize = ctx.get_size_type(); + + assert_eq!(num_shape_entries.get_type(), llvm_usize); + assert!(ShapeEntryType::is_representable( + shape_entries.base_ptr(ctx, generator).get_type(), + llvm_usize, + ) + .is_ok()); + assert_eq!(dst_ndims.get_type(), llvm_usize); + assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into()); + + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_shapes"); + infer_and_call_function( + ctx, + &name, + None, + &[ + num_shape_entries.into(), + shape_entries.base_ptr(ctx, generator).into(), + dst_ndims.into(), + dst_shape.base_ptr(ctx, generator).into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs index 0821b2c..0d5d920 100644 --- a/nac3core/src/codegen/irrt/ndarray/indexing.rs +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -5,6 +5,11 @@ use crate::codegen::{ CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_ndarray_index`. +/// +/// Performs [basic indexing](https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) +/// on `src_ndarray` using `indices`, writing the result to `dst_ndarray`, corresponding to the +/// operation `dst_ndarray = src_ndarray[indices]`. 
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, @@ -12,7 +17,7 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( src_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>, ) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_index"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_index"); infer_and_call_function( ctx, &name, @@ -20,8 +25,8 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( &[ indices.size(ctx, generator).into(), indices.base_ptr(ctx, generator).into(), - src_ndarray.as_base_value().into(), - dst_ndarray.as_base_value().into(), + src_ndarray.as_abi_value(ctx).into(), + dst_ndarray.as_abi_value(ctx).into(), ], None, None, diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index 966d660..e4424df 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -1,4 +1,5 @@ use inkwell::{ + types::BasicTypeEnum, values::{BasicValueEnum, IntValue}, AddressSpace, }; @@ -9,30 +10,38 @@ use crate::codegen::{ types::ProxyType, values::{ ndarray::{NDArrayValue, NDIterValue}, - ArrayLikeValue, ArraySliceValue, ProxyValue, + ProxyValue, TypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; +/// Generates a call to `__nac3_nditer_initialize`. +/// +/// Initializes the `iter` object. 
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>, ndarray: NDArrayValue<'ctx>, - indices: ArraySliceValue<'ctx>, + indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); + assert_eq!( + BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize"); create_and_call_function( ctx, &name, None, &[ - (iter.get_type().as_base_type().into(), iter.as_base_value().into()), - (ndarray.get_type().as_base_type().into(), ndarray.as_base_value().into()), + (iter.get_type().as_abi_type().into(), iter.as_abi_value(ctx).into()), + (ndarray.get_type().as_abi_type().into(), ndarray.as_abi_value(ctx).into()), (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), ], None, @@ -40,18 +49,21 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( ); } -pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +/// Generates a call to `__nac3_nditer_has_element`. +/// +/// Returns an `i1` value indicating whether there are elements left to traverse for the `iter` +/// object. 
+pub fn call_nac3_nditer_has_element<'ctx>( ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>, ) -> IntValue<'ctx> { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_has_element"); + let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_has_element"); infer_and_call_function( ctx, &name, Some(ctx.ctx.bool_type().into()), - &[iter.as_base_value().into()], + &[iter.as_abi_value(ctx).into()], None, None, ) @@ -59,12 +71,11 @@ pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } -pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - iter: NDIterValue<'ctx>, -) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_next"); +/// Generates a call to `__nac3_nditer_next`. +/// +/// Moves `iter` to point to the next element. +pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) { + let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next"); - infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None); + infer_and_call_function(ctx, &name, None, &[iter.as_abi_value(ctx).into()], None, None); } diff --git a/nac3core/src/codegen/irrt/ndarray/matmul.rs b/nac3core/src/codegen/irrt/ndarray/matmul.rs new file mode 100644 index 0000000..0df774f --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/matmul.rs @@ -0,0 +1,65 @@ +use inkwell::{types::BasicTypeEnum, values::IntValue}; + +use crate::codegen::{ + expr::infer_and_call_function, irrt::get_usize_dependent_function_name, + values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_matmul_calculate_shapes`. +/// +/// Calculates the broadcasted shapes for `a`, `b`, and the `ndarray` holding the final values of +/// `a @ b`. 
+#[allow(clippy::too_many_arguments)] +pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + final_ndims: IntValue<'ctx>, + new_a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, +) { + let llvm_usize = ctx.get_size_type(); + + assert_eq!( + BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes"); + + infer_and_call_function( + ctx, + &name, + None, + &[ + a_shape.size(ctx, generator).into(), + a_shape.base_ptr(ctx, generator).into(), + b_shape.size(ctx, generator).into(), + b_shape.base_ptr(ctx, generator).into(), + final_ndims.into(), + new_a_shape.base_ptr(ctx, generator).into(), + new_b_shape.base_ptr(ctx, generator).into(), + dst_shape.base_ptr(ctx, generator).into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index a05e0ce..b153068 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,391 +1,17 @@ -use inkwell::{ - types::IntType, - values::{BasicValueEnum, CallSiteValue, 
IntValue}, - AddressSpace, IntPredicate, -}; -use itertools::Either; - -use crate::codegen::{ - llvm_intrinsics, - macros::codegen_unreachable, - stmt::gen_for_callback_incrementing, - values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, - }, - CodeGenContext, CodeGenerator, -}; +pub use array::*; pub use basic::*; +pub use broadcast::*; pub use indexing::*; pub use iter::*; +pub use matmul::*; +pub use reshape::*; +pub use transpose::*; +mod array; mod basic; +mod broadcast; mod indexing; mod iter; - -/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the -/// calculated total size. -/// -/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. -/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, -/// or [`None`] if starting from the first dimension and ending at the last dimension -/// respectively. 
-pub fn call_ndarray_calc_size<'ctx, G, Dims>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - dims: &Dims, - (begin, end): (Option>, Option>), -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Dims: ArrayLikeIndexer<'ctx>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_size", - 64 => "__nac3_ndarray_calc_size64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_size_fn_t = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], - false, - ); - let ndarray_calc_size_fn = - ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) - }); - - let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); - let end = end.unwrap_or_else(|| dims.size(ctx, generator)); - ctx.builder - .build_call( - ndarray_calc_size_fn, - &[ - dims.base_ptr(ctx, generator).into(), - dims.size(ctx, generator).into(), - begin.into(), - end.into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] -/// containing `i32` indices of the flattened index. -/// -/// * `index` - The index to compute the multidimensional index for. -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. 
-pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_nd_indices", - 64 => "__nac3_ndarray_calc_nd_indices64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_nd_indices_fn = - ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.shape(); - - let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); - - ctx.builder - .build_call( - ndarray_calc_nd_indices_fn, - &[ - index.into(), - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - -fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Indices, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Indices: ArrayLikeIndexer<'ctx>, -{ - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pi32 = 
llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - debug_assert_eq!( - IntType::try_from(indices.element_type(ctx, generator)) - .map(IntType::get_bit_width) - .unwrap_or_default(), - llvm_i32.get_bit_width(), - "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" - ); - debug_assert_eq!( - indices.size(ctx, generator).get_type().get_bit_width(), - llvm_usize.get_bit_width(), - "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" - ); - - let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_flatten_index", - 64 => "__nac3_ndarray_flatten_index64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_flatten_index_fn = - ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], - false, - ); - - ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.shape(); - - let index = ctx - .builder - .build_call( - ndarray_flatten_index_fn, - &[ - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.base_ptr(ctx, generator).into(), - indices.size(ctx, generator).into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - index -} - -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the -/// multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. 
-pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. -pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast", - 64 => "__nac3_ndarray_calc_broadcast64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - ], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (min_ndims, false), - |generator, ctx, _, idx| { - let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); - let (lhs_dim_sz, rhs_dim_sz) = unsafe { - ( - lhs.shape().get_typed_unchecked(ctx, generator, &idx, None), - rhs.shape().get_typed_unchecked(ctx, generator, &idx, None), - ) - }; 
- - let llvm_usize_const_one = llvm_usize.const_int(1, false); - let lhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let rhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); - - let lhs_eq_rhs = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") - .unwrap(); - - let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); - - ctx.make_assert( - generator, - is_compatible, - "0:ValueError", - "operands could not be broadcast together", - [None, None, None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.shape().base_ptr(ctx, generator); - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.shape().base_ptr(ctx, generator); - let rhs_ndims = rhs.load_ndims(ctx); - let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); - let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[ - lhs_dims.into(), - lhs_ndims.into(), - rhs_dims.into(), - rhs_ndims.into(), - out_dims.base_ptr(ctx, generator).into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - out_dims, - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] -/// containing the indices used for accessing `array` corresponding to the index of the broadcasted -/// array `broadcast_idx`. 
-pub fn call_ndarray_calc_broadcast_index< - 'ctx, - G: CodeGenerator + ?Sized, - BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, ->( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - array: NDArrayValue<'ctx>, - broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast_idx", - 64 => "__nac3_ndarray_calc_broadcast_idx64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let broadcast_size = broadcast_idx.size(ctx, generator); - let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - - let array_dims = array.shape().base_ptr(ctx, generator); - let array_ndims = array.load_ndims(ctx); - let broadcast_idx_ptr = unsafe { - broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} +mod matmul; +mod reshape; +mod transpose; diff --git a/nac3core/src/codegen/irrt/ndarray/reshape.rs b/nac3core/src/codegen/irrt/ndarray/reshape.rs new file mode 100644 index 
0000000..66cbf13 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/reshape.rs @@ -0,0 +1,39 @@ +use inkwell::values::IntValue; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ArrayLikeValue, ArraySliceValue}, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_reshape_resolve_and_check_new_shape`. +/// +/// Resolves unknown dimensions in `new_shape` for `numpy.reshape(, new_shape)`, raising an +/// assertion if multiple dimensions are unknown (`-1`). +pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + new_ndims: IntValue<'ctx>, + new_shape: ArraySliceValue<'ctx>, +) { + let llvm_usize = ctx.get_size_type(); + + assert_eq!(size.get_type(), llvm_usize); + assert_eq!(new_ndims.get_type(), llvm_usize); + assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into()); + + let name = get_usize_dependent_function_name( + ctx, + "__nac3_ndarray_reshape_resolve_and_check_new_shape", + ); + infer_and_call_function( + ctx, + &name, + None, + &[size.into(), new_ndims.into(), new_shape.base_ptr(ctx, generator).into()], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/transpose.rs b/nac3core/src/codegen/irrt/ndarray/transpose.rs new file mode 100644 index 0000000..331611f --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/transpose.rs @@ -0,0 +1,48 @@ +use inkwell::{values::IntValue, AddressSpace}; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor}, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_transpose`. +/// +/// Creates a transpose view of `src_ndarray` and writes the result to `dst_ndarray`. 
+/// +/// `dst_ndarray` must fulfill the following preconditions: +/// +/// - `dst_ndarray.ndims` must be initialized and must be equal to `src_ndarray.ndims`. +/// - `dst_ndarray.shape` must be allocated and may contain uninitialized values. +/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values. +pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayValue<'ctx>, + dst_ndarray: NDArrayValue<'ctx>, + axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>, +) { + let llvm_usize = ctx.get_size_type(); + + assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize)); + assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into())); + + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_transpose"); + infer_and_call_function( + ctx, + &name, + None, + &[ + src_ndarray.as_abi_value(ctx).into(), + dst_ndarray.as_abi_value(ctx).into(), + axes.map_or(llvm_usize.const_zero(), |axes| axes.size(ctx, generator)).into(), + axes.map_or(llvm_usize.ptr_type(AddressSpace::default()).const_null(), |axes| { + axes.base_ptr(ctx, generator) + }) + .into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/range.rs b/nac3core/src/codegen/irrt/range.rs index 47c63c4..3b6bc31 100644 --- a/nac3core/src/codegen/irrt/range.rs +++ b/nac3core/src/codegen/irrt/range.rs @@ -6,6 +6,13 @@ use itertools::Either; use crate::codegen::{CodeGenContext, CodeGenerator}; +/// Invokes the `__nac3_range_slice_len` in IRRT. +/// +/// - `start`: The `i32` start value for the slice. +/// - `end`: The `i32` end value for the slice. +/// - `step`: The `i32` step value for the slice. +/// +/// Returns an `i32` value of the length of the slice. 
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -14,9 +21,15 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( step: IntValue<'ctx>, ) -> IntValue<'ctx> { const SYMBOL: &str = "__nac3_range_slice_len"; + + let llvm_i32 = ctx.ctx.i32_type(); + + assert_eq!(start.get_type(), llvm_i32); + assert_eq!(end.get_type(), llvm_i32); + assert_eq!(step.get_type(), llvm_i32); + let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let i32_t = ctx.ctx.i32_type(); - let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false); + let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false); ctx.module.add_function(SYMBOL, fn_t, None) }); @@ -33,6 +46,7 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( [None, None, None], ctx.current_loc, ); + ctx.builder .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") .map(CallSiteValue::try_as_basic_value) diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs new file mode 100644 index 0000000..e2fd8c0 --- /dev/null +++ b/nac3core/src/codegen/irrt/string.rs @@ -0,0 +1,45 @@ +use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; +use itertools::Either; + +use super::get_usize_dependent_function_name; +use crate::codegen::CodeGenContext; + +/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. 
+pub fn call_string_eq<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + str1_ptr: PointerValue<'ctx>, + str1_len: IntValue<'ctx>, + str2_ptr: PointerValue<'ctx>, + str2_len: IntValue<'ctx>, +) -> IntValue<'ctx> { + let llvm_i1 = ctx.ctx.bool_type(); + + let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq"); + + let func = ctx.module.get_function(&func_name).unwrap_or_else(|| { + ctx.module.add_function( + &func_name, + llvm_i1.fn_type( + &[ + str1_ptr.get_type().into(), + str1_len.get_type().into(), + str2_ptr.get_type().into(), + str2_len.get_type().into(), + ], + false, + ), + None, + ) + }); + + ctx.builder + .build_call( + func, + &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], + "str_eq_call", + ) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} diff --git a/nac3core/src/codegen/llvm_intrinsics.rs b/nac3core/src/codegen/llvm_intrinsics.rs index 895339d..9c360b6 100644 --- a/nac3core/src/codegen/llvm_intrinsics.rs +++ b/nac3core/src/codegen/llvm_intrinsics.rs @@ -1,7 +1,6 @@ use inkwell::{ - context::Context, intrinsics::Intrinsic, - types::{AnyTypeEnum::IntType, FloatType}, + types::AnyTypeEnum::IntType, values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue}, AddressSpace, }; @@ -9,34 +8,6 @@ use itertools::Either; use super::CodeGenContext; -/// Returns the string representation for the floating-point type `ft` when used in intrinsic -/// functions. 
-fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str { - // Standard LLVM floating-point types - if ft == ctx.f16_type() { - return "f16"; - } - if ft == ctx.f32_type() { - return "f32"; - } - if ft == ctx.f64_type() { - return "f64"; - } - if ft == ctx.f128_type() { - return "f128"; - } - - // Non-standard floating-point types - if ft == ctx.x86_f80_type() { - return "f80"; - } - if ft == ctx.ppc_f128_type() { - return "ppcf128"; - } - - unreachable!() -} - /// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic) /// intrinsic. pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) { @@ -54,7 +25,7 @@ pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap(); } -/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic) +/// Invokes the [`llvm.va_end`](https://llvm.org/docs/LangRef.html#llvm-va-end-intrinsic) /// intrinsic. 
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) { const FN_NAME: &str = "llvm.va_end"; diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 1e0fb26..a188d1c 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,4 +1,5 @@ use std::{ + cell::OnceCell, collections::{HashMap, HashSet}, sync::{ atomic::{AtomicBool, Ordering}, @@ -19,7 +20,7 @@ use inkwell::{ module::Module, passes::PassBuilderOptions, targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple}, - types::{AnyType, BasicType, BasicTypeEnum}, + types::{AnyType, BasicType, BasicTypeEnum, IntType}, values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; @@ -42,7 +43,7 @@ use crate::{ }; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; -use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType}; +use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, TupleType}; pub mod builtin_fns; pub mod concrete_type; @@ -226,14 +227,33 @@ pub struct CodeGenContext<'ctx, 'a> { /// The current source location. pub current_loc: Location, + + /// The cached type of `size_t`. + llvm_usize: OnceCell<IntType<'ctx>>, } -impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { +impl<'ctx> CodeGenContext<'ctx, '_> { /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// contains a [terminator statement][BasicBlock::get_terminator]. pub fn is_terminated(&self) -> bool { self.builder.get_insert_block().and_then(BasicBlock::get_terminator).is_some() } + + /// Returns an [`IntType`] representing `size_t` for the compilation target as specified by + /// [`self.registry`][WorkerRegistry]. 
+ pub fn get_size_type(&self) -> IntType<'ctx> { + *self.llvm_usize.get_or_init(|| { + self.ctx.ptr_sized_int_type( + &self + .registry + .llvm_options + .create_target_machine() + .map(|tm| tm.get_target_data()) + .unwrap(), + None, + ) + }) + } } type Fp = Box; @@ -481,6 +501,38 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| { let ty_enum = unifier.get_ty(ty); let result = match &*ty_enum { + TModule {module_id, attributes} => { + let top_level_defs = top_level.definitions.read(); + let definition = top_level_defs.get(module_id.0).unwrap(); + let TopLevelDef::Module { name, attributes: attribute_fields, .. } = &*definition.read() else { + unreachable!() + }; + let ty: BasicTypeEnum<'_> = if let Some(t) = module.get_struct_type(&name.to_string()) { + t.ptr_type(AddressSpace::default()).into() + } else { + let struct_type = ctx.opaque_struct_type(&name.to_string()); + type_cache.insert( + unifier.get_representative(ty), + struct_type.ptr_type(AddressSpace::default()).into(), + ); + let module_fields: Vec> = attribute_fields.iter() + .map(|f| { + get_llvm_type( + ctx, + module, + generator, + unifier, + top_level, + type_cache, + attributes[&f.0].0, + ) + }) + .collect_vec(); + struct_type.set_body(&module_fields, false); + struct_type.ptr_type(AddressSpace::default()).into() + }; + return ty; + }, TObj { obj_id, fields, .. } => { // check to avoid treating non-class primitives as classes if PrimDef::contains_id(*obj_id) { @@ -510,7 +562,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( *params.iter().next().unwrap().1, ); - ListType::new(generator, ctx, element_type).as_base_type().into() + ListType::new_with_generator(generator, ctx, element_type).as_abi_type().into() } TObj { obj_id, .. 
} if *obj_id == PrimDef::NDArray.id() => { @@ -520,7 +572,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new(generator, ctx, element_type, Some(ndims)).as_base_type().into() + NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_abi_type().into() } _ => unreachable!( @@ -574,7 +626,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); - ctx.struct_type(&fields, false).into() + TupleType::new_with_generator(generator, ctx, &fields).as_abi_type().into() } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), @@ -748,7 +800,7 @@ pub fn gen_func_impl< Some(t) => t.as_basic_type_enum(), } }), - (primitives.range, RangeType::new(context).as_base_type().into()), + (primitives.range, RangeType::new_with_generator(generator, context).as_abi_type().into()), (primitives.exception, { let name = "Exception"; if let Some(t) = module.get_struct_type(name) { @@ -987,8 +1039,20 @@ pub fn gen_func_impl< need_sret: has_sret, current_loc: Location::default(), debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()), + llvm_usize: OnceCell::default(), }; + let target_llvm_usize = context.ptr_sized_int_type( + &registry.llvm_options.create_target_machine().map(|tm| tm.get_target_data()).unwrap(), + None, + ); + let generator_llvm_usize = generator.get_size_type(context); + assert_eq!( + generator_llvm_usize, + target_llvm_usize, + "CodeGenerator (size_t = {generator_llvm_usize}) is not compatible with CodeGen Target (size_t = {target_llvm_usize})", + ); + let loc = code_gen_context.debug_info.0.create_debug_location( context, row as u32, @@ -1180,7 +1244,7 @@ pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>( let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = 
generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let align_ty = align_ty.into(); let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap(); diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index cd113aa..dfb1b4d 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,1714 +1,27 @@ use inkwell::{ - types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType}, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - AddressSpace, IntPredicate, OptimizationLevel, + values::{BasicValue, BasicValueEnum, PointerValue}, + IntPredicate, }; -use itertools::Itertools; -use nac3parser::ast::{Operator, StrRef}; +use nac3parser::ast::StrRef; use super::{ - expr::gen_binop_expr_with_values, - irrt::{ - calculate_len_for_slice_range, - ndarray::{ - call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, - call_ndarray_calc_nd_indices, call_ndarray_calc_size, - }, - }, - llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, - stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::{ndarray::NDArrayType, ListType, ProxyType}, - values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, - }, + stmt::gen_for_callback, + types::ndarray::{NDArrayType, NDIterType}, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue}, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId}, - typecheck::{ - magic_methods::Binop, - typedef::{FunSignature, Type, TypeEnum}, + toplevel::{ + helper::{arraylike_flatten_element_type, extract_ndims}, + numpy::unpack_ndarray_var_tys, + DefinitionId, }, + typecheck::typedef::{FunSignature, Type}, }; -/// Creates an 
`NDArray` instance from a dynamic shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`. -/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. -/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. -fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - shape: &V, - shape_len_fn: LenFn, - shape_data_fn: DataFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result, String>, - DataFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &V, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - // Assert that all dimensions are non-negative - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - - let shape_dim_gez = ctx - .builder - .build_int_compare( - IntPredicate::SGE, - shape_dim, - shape_dim.get_type().const_zero(), - "", - ) - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow shape > u32_MAX - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let num_dims = shape_len_fn(generator, ctx, shape)?; - - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None) - .construct_dyn_ndims(generator, ctx, num_dims, None); - - // Copy the dimension sizes from shape to ndarray.dims - let shape_len = shape_len_fn(generator, ctx, shape)?; - 
gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - - let ndarray_pdim = - unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) }; - - ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - unsafe { ndarray.create_data(generator, ctx) }; - - Ok(ndarray) -} - -/// Creates an `NDArray` instance from a constant shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. -pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: &[IntValue<'ctx>], -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - for &shape_dim in shape { - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - let shape_dim_gez = ctx - .builder - .build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "") - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow shape > u32_MAX - } - - let llvm_dtype = ctx.get_llvm_type(generator, elem_ty); - - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_dtype, Some(shape.len() as u64)) - .construct_dyn_shape(generator, ctx, shape, None); - unsafe { ndarray.create_data(generator, ctx) }; - - Ok(ndarray) -} - -fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if 
[ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i32_type().const_zero().into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "").into() - } else { - codegen_unreachable!(ctx) - } -} - -fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); - ctx.ctx.i32_type().const_int(1, is_signed).into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); - ctx.ctx.i64_type().const_int(1, is_signed).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_float(1.0).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_int(1, false).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "1").into() - } else { - codegen_unreachable!(ctx) - } -} - -/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. 
-/// -/// ### Notes on `shape` -/// -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` -/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` -/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` -/// -/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to -/// learn how `shape` gets from being a Python user expression to here. -fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. A list of ints; e.g., `np.empty([600, 800, 3])` - - let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None); - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` - // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. - - // Get the length/size of the tuple, which also happens to be the value of `ndims`. 
- let ndims = shape_tuple.get_type().count_fields(); - - let shape = (0..ndims) - .map(|dim_i| { - ctx.builder - .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) - .map(BasicValueEnum::into_int_value) - .map(|v| { - ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap() - }) - .unwrap() - }) - .collect_vec(); - - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` - let shape_int = - ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap(); - - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => codegen_unreachable!(ctx), - } -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as -/// its input. -fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.shape().as_slice_value(ctx, generator), - (None, None), - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (ndarray_num_elems, false), - |generator, ctx, _, i| { - let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; - - let value = value_fn(generator, ctx, i)?; - ctx.builder.build_store(elem, value).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices -/// as its input. 
-fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| { - let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray); - - value_fn(generator, ctx, &indices) - }) -} - -fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - src: NDArrayValue<'ctx>, - dest: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| { - let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) }; - - map_fn(generator, ctx, elem) - }) -} - -/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of -/// the target `ndarray`. -fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - target: NDArrayValue<'ctx>, - source: NDArrayValue<'ctx>, -) { - let array_ndims = source.load_ndims(ctx); - let broadcast_size = target.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(), - "0:ValueError", - "operands cannot be broadcast together", - [None, None, None], - ctx.current_loc, - ); -} - -/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value -/// with broadcast-compatible shapes. 
-fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - res: NDArrayValue<'ctx>, - (lhs_ty, lhs_val, lhs_scalar): (Type, BasicValueEnum<'ctx>, bool), - (rhs_ty, rhs_val, rhs_scalar): (Type, BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - // Assert that all ndarray operands are broadcastable to the target size - if !lhs_scalar { - let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); - } - - if !rhs_scalar { - let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); - } - - ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| { - let lhs_elem = if lhs_scalar { - lhs_val - } else { - let lhs = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); - - unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } - }; - - let rhs_elem = if rhs_scalar { - rhs_val - } else { - let rhs = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - - unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } - }; - - value_fn(generator, ctx, (lhs_elem, rhs_elem)) - })?; - - Ok(res) -} - -/// LLVM-typed implementation for generating the implementation 
for `ndarray.zeros`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_zero_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_one_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.full`. -/// -/// * `elem_ty` - The element type of the `NDArray`. 
-/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, - fill_value: BasicValueEnum<'ctx>, -) -> Result, String> { - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = if fill_value.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - fill_value.into_pointer_value(), - fill_value.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if fill_value.is_int_value() || fill_value.is_float_value() { - fill_value - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - })?; - - Ok(ndarray) -} - -/// Returns the number of dimensions for a multidimensional list as an [`IntValue`]. -fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ty: PointerType<'ctx>, -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_ty = ListType::from_type(ty, llvm_usize); - let list_elem_ty = list_ty.element_type(); - - let ndims = llvm_usize.const_int(1, false); - match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) - if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) - } - - AnyTypeEnum::PointerType(ptr_ty) - if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - todo!("Getting ndims for list[ndarray] not supported") - } - - _ => ndims, - } -} - -/// Returns the number of dimensions for an array-like object as an [`IntValue`]. 
-fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (ty, value): (Type, BasicValueEnum<'ctx>), -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match value { - BasicValueEnum::PointerValue(v) - if NDArrayValue::is_representable(v, llvm_usize).is_ok() => - { - NDArrayType::from_unifier_type(generator, ctx, ty).map_value(v, None).load_ndims(ctx) - } - - BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => { - llvm_ndlist_get_ndims(generator, ctx, v.get_type()) - } - - _ => llvm_usize.const_zero(), - } -} - -/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`]. -fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - src_lst: ListValue<'ctx>, - dim: u64, -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_elem_ty = src_lst.get_type().element_type(); - - match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) - if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - // The stride of elements in this dimension, i.e. 
the number of elements between arr[i] - // and arr[i + 1] in this dimension - let stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, ctx| Ok(src_lst.load_size(ctx, None)), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, i| { - let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); - let offset = ctx - .builder - .build_int_mul( - offset, - ctx.builder - .build_int_truncate_or_bit_cast( - dst_arr.get_type().element_type().size_of().unwrap(), - offset.get_type(), - "", - ) - .unwrap(), - "", - ) - .unwrap(); - - let dst_ptr = - unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; - - let nested_lst_elem = ListValue::from_pointer_value( - unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) } - .into_pointer_value(), - llvm_usize, - None, - ); - - ndarray_from_ndlist_impl( - generator, - ctx, - (dst_arr, dst_ptr), - nested_lst_elem, - dim + 1, - )?; - - Ok(()) - }, - )?; - } - - AnyTypeEnum::PointerType(ptr_ty) - if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - todo!("Not implemented for list[ndarray]") - } - - _ => { - let lst_len = src_lst.load_size(ctx, None); - let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); - let sizeof_elem = - ctx.builder.build_int_z_extend_or_bit_cast(sizeof_elem, llvm_usize, "").unwrap(); - - let cpy_len = ctx - .builder - .build_int_mul( - ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(), - sizeof_elem, - "", - ) - .unwrap(); - - call_memcpy_generic( - ctx, - dst_slice_ptr, - src_lst.data().base_ptr(ctx, generator), - cpy_len, - llvm_i1.const_zero(), - ); - } - } - - Ok(()) -} - -/// LLVM-typed implementation for `ndarray.array`. 
-fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - object: BasicValueEnum<'ctx>, - copy: IntValue<'ctx>, - ndmin: IntValue<'ctx>, -) -> Result, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndmin = ctx.builder.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "").unwrap(); - - // TODO(Derppening): Add assertions for sizes of different dimensions - - // object is not a pointer - 0-dim NDArray - if !object.is_pointer_value() { - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[])?; - - unsafe { - ndarray.data().set_unchecked(ctx, generator, &llvm_usize.const_zero(), object); - } - - return Ok(ndarray); - } - - let object = object.into_pointer_value(); - - // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims - if NDArrayValue::is_representable(object, llvm_usize).is_ok() { - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None); - - let ndarray = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - let copy_nez = ctx - .builder - .build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "") - .unwrap(); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx.builder.build_and(copy_nez, ndmin_gt_ndims, "").unwrap()) - }, - |generator, ctx| { - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |_, ctx, object| { - let ndims = object.load_ndims(ctx); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - let ndims = 
object.load_ndims(ctx); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - // The number of dimensions to prepend 1's to - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::UGE, idx, offset, "") - .unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |_, ctx| Ok(Some(ctx.builder.build_int_sub(idx, offset, "").unwrap())), - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (object, object.data().base_ptr(ctx, generator)), - 0, - &[], - )?; - - Ok(Some(ndarray.as_base_value())) - }, - |_, _| Ok(Some(object.as_base_value())), - )?; - - return Ok(NDArrayValue::from_pointer_value( - ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), - llvm_elem_ty, - None, - llvm_usize, - None, - )); - } - - // Remaining case: TList - assert!(ListValue::is_representable(object, llvm_usize).is_ok()); - let object = ListValue::from_pointer_value(object, llvm_usize, None); - - // The number of dimensions to prepend 1's to - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |generator, ctx, object| { - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin_gt_ndims = - ctx.builder.build_int_compare(IntPredicate::UGT, ndmin, ndims, "").unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - 
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, idx, offset, "").unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |generator, ctx| { - let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| { - ctx.ctx.struct_type( - &[elem_ty.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], - false, - ) - }; - - let llvm_i8 = ctx.ctx.i8_type(); - let llvm_list_i8 = make_llvm_list(llvm_i8.into()); - let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default()); - - // Cast list to { i8*, usize } since we only care about the size - let lst = generator - .gen_var_alloc( - ctx, - ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(), - None, - ) - .unwrap(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bit_cast(object.as_base_value(), llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap(); - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, _| Ok(stop), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, _| { - let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) - .ptr_type(AddressSpace::default()); - - let this_dim = ctx - .builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap()) - .map(BasicValueEnum::into_pointer_value) - .unwrap(); - let this_dim = - ListValue::from_pointer_value(this_dim, llvm_usize, None); - - // TODO: Assert this_dim.sz != 0 - - let next_dim = unsafe { - this_dim.data().get_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - } - .into_pointer_value(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bit_cast(next_dim, llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - Ok(()) - }, - )?; - - let lst = ListValue::from_pointer_value( - ctx.builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .unwrap(), - 
llvm_usize, - None, - ); - - Ok(Some(lst.load_size(ctx, None))) - }, - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_from_ndlist_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - object, - 0, - )?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.eye`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - nrows: IntValue<'ctx>, - ncols: IntValue<'ctx>, - offset: IntValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap(); - let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap(); - - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?; - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| { - let (row, col) = unsafe { - ( - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None), - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None), - ) - }; - - let col_with_offset = ctx - .builder - .build_int_add( - col, - ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(), - "", - ) - .unwrap(); - let is_on_diag = - ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap(); - - let zero = ndarray_zero_value(generator, ctx, elem_ty); - let one = ndarray_one_value(generator, ctx, elem_ty); - - let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap(); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// Copies a slice of an [`NDArrayValue`] to another. -/// -/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. 
The `ndims` and `shape` -/// fields should be populated before calling this function. -/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the destination array. -/// - `src_arr`: The [`NDArrayValue`] instance of the source array. -/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the source array. -/// - `dim`: The index of the currently processing dimension. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be non-negative indices. -fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - dim: u64, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type()); - - let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); - - // If there are no (remaining) slice expressions, memcpy the entire dimension - if slices.is_empty() { - let stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.shape(), - (Some(llvm_usize.const_int(dim, false)), None), - ); - let stride = - ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap(); - - let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); - - call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); - - return Ok(()); - } - - // The stride of elements in this dimension, i.e. 
the number of elements between arr[i] and - // arr[i + 1] in this dimension - let src_stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - let dst_stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - let (start, stop, step) = slices[0]; - let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap(); - let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap(); - let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap(); - - let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap(); - ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap(); - - gen_for_range_callback( - generator, - ctx, - None, - false, - |_, _| Ok(start), - (|_, _| Ok(stop), true), - |_, _| Ok(step), - |generator, ctx, _, src_i| { - // Calculate the offset of the active slice - let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); - let src_data_offset = ctx - .builder - .build_int_mul( - src_data_offset, - ctx.builder - .build_int_z_extend_or_bit_cast(sizeof_elem, src_data_offset.get_type(), "") - .unwrap(), - "", - ) - .unwrap(); - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); - let dst_data_offset = ctx - .builder - .build_int_mul( - dst_data_offset, - ctx.builder - .build_int_z_extend_or_bit_cast(sizeof_elem, dst_data_offset.get_type(), "") - .unwrap(), - "", - ) - .unwrap(); - - let (src_ptr, dst_ptr) = unsafe { - ( - ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(), - ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(), - ) - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (dst_arr, dst_ptr), - 
(src_arr, src_ptr), - dim + 1, - &slices[1..], - )?; - - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_i_add1 = - ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap(); - ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap(); - - Ok(()) - }, - )?; - - Ok(()) -} - -/// Copies a [`NDArrayValue`] using slices. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be positive indices. -pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let ndarray = - if slices.is_empty() { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &this, - |_, ctx, shape| Ok(shape.load_ndims(ctx)), - |generator, ctx, shape, idx| unsafe { - Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - )? 
- } else { - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None) - .construct_dyn_ndims(generator, ctx, this.load_ndims(ctx), None); - - // Populate the first slices.len() dimensions by computing the size of each dim slice - for (i, (start, stop, step)) in slices.iter().enumerate() { - // HACK: workaround calculate_len_for_slice_range requiring exclusive stop - let stop = ctx - .builder - .build_select( - ctx.builder - .build_int_compare( - IntPredicate::SLT, - *step, - llvm_i32.const_zero(), - "is_neg", - ) - .unwrap(), - ctx.builder - .build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one") - .unwrap(), - ctx.builder - .build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one") - .unwrap(), - "final_e", - ) - .map(BasicValueEnum::into_int_value) - .unwrap(); - - let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step); - let slice_len = - ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); - - unsafe { - ndarray.shape().set_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - slice_len, - ); - } - } - - // Populate the rest by directly copying the dim size from the source array - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_int(slices.len() as u64, false), - (this.load_ndims(ctx), false), - |generator, ctx, _, idx| { - unsafe { - let shape = this.shape().get_typed_unchecked(ctx, generator, &idx, None); - ndarray.shape().set_typed_unchecked(ctx, generator, &idx, shape); - } - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - unsafe { ndarray.create_data(generator, ctx) }; - - ndarray - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (this, this.data().base_ptr(ctx, generator)), - 0, - slices, - )?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.copy`. 
-/// -/// * `elem_ty` - The element type of the `NDArray`. -fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, -) -> Result, String> { - ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) -} - -pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - operand: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - let res = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &operand, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - }); - - ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| { - map_fn(generator, ctx, elem) - })?; - - Ok(res) -} - -/// LLVM-typed implementation for computing elementwise binary operations on two input operands. -/// -/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output -/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple. -/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the -/// `value_fn` arguments tuple for all output elements. -/// -/// The second element of the tuple indicates whether to treat the operand value as a `ndarray` -/// (which would be accessed by its broadcast index) or as a scalar value (which would be -/// broadcast to all elements). -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. 
-/// * `value_fn` - Function mapping the two input elements into the result. -/// -/// # Panic -/// -/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`. -pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - lhs: (Type, BasicValueEnum<'ctx>, bool), - rhs: (Type, BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - let (lhs_ty, lhs_val, lhs_scalar) = lhs; - let (rhs_ty, rhs_val, rhs_scalar) = rhs; - - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - let ndarray = res.unwrap_or_else(|| { - if lhs_scalar && rhs_scalar { - let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - - let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray_dims, - |generator, ctx, v| Ok(v.size(ctx, generator)), - |generator, ctx, v, idx| unsafe { - Ok(v.get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } else { - let ndarray = NDArrayType::from_unifier_type( - generator, - ctx, - if lhs_scalar { rhs_ty } else { lhs_ty }, - ) - .map_value(if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), None); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } - }); - - 
ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| { - value_fn(generator, ctx, elems) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - res: Option>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - if cfg!(debug_assertions) { - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - - // lhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // rhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - if let Some(res) = res { - let res_ndims = res.load_ndims(ctx); - let res_dim0 = unsafe { - res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let res_dim1 = unsafe { - res.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let lhs_dim0 = unsafe { - lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let rhs_dim1 = unsafe { - rhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - - // res.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare( - IntPredicate::EQ, - 
res_ndims, - llvm_usize.const_int(2, false), - "", - ) - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[0] == lhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - } - - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let lhs_dim1 = unsafe { - lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - }; - let rhs_dim0 = unsafe { - rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - // lhs.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - - let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) { - ndarray_copy_impl(generator, ctx, elem_ty, lhs)? 
- } else { - lhs - }; - - let ndarray = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &(lhs, rhs), - |_, _, _| Ok(llvm_usize.const_int(2, false)), - |generator, ctx, (lhs, rhs), idx| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "") - .unwrap()) - }, - |generator, ctx| { - Ok(Some(unsafe { - lhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - })) - }, - |generator, ctx| { - Ok(Some(unsafe { - rhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - })) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) - }, - ) - .unwrap() - }); - - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| { - llvm_intrinsics::call_expect( - ctx, - idx.size(ctx, generator).get_type().const_int(2, false), - idx.size(ctx, generator), - None, - ); - - let common_dim = { - let lhs_idx1 = unsafe { - lhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_idx0 = unsafe { - rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - - ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() - }; - - let idx0 = unsafe { - let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - - ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() - }; - let idx1 = unsafe { - let idx1 = - idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - - ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() - }; - - let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let result_identity = ndarray_zero_value(generator, ctx, elem_ty); - 
ctx.builder.build_store(result_addr, result_identity).unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_i32.const_zero(), - (common_dim, false), - |generator, ctx, _, i| { - let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); - - let ab_idx = generator.gen_array_var_alloc( - ctx, - llvm_i32.into(), - llvm_usize.const_int(2, false), - None, - )?; - - let a = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); - - lhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - let b = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); - ab_idx.set_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - idx1.into(), - ); - - rhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - - let a_mul_b = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), a), - Binop::normal(Operator::Mult), - (&Some(elem_ty), b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - let result = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), result), - Binop::normal(Operator::Add), - (&Some(elem_ty), a_mul_b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - ctx.builder.build_store(result_addr, result).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - Ok(result) - })?; - - Ok(ndarray) -} - /// Generates LLVM IR for `ndarray.empty`. 
pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, @@ -1723,8 +36,15 @@ pub fn gen_ndarray_empty<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) + .construct_numpy_empty(generator, context, &shape, None); + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.zeros`. @@ -1741,8 +61,15 @@ pub fn gen_ndarray_zeros<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) + .construct_numpy_zeros(generator, context, dtype, &shape, None); + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.ones`. 
@@ -1759,8 +86,15 @@ pub fn gen_ndarray_ones<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) + .construct_numpy_ones(generator, context, dtype, &shape, None); + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.full`. @@ -1780,8 +114,20 @@ pub fn gen_ndarray_full<'ctx>( let fill_value_arg = args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; - call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(context, llvm_dtype, ndims).construct_numpy_full( + generator, + context, + &shape, + fill_value_arg, + None, + ); + Ok(ndarray.as_abi_value(context)) } pub fn gen_ndarray_array<'ctx>( @@ -1795,26 +141,6 @@ pub fn gen_ndarray_array<'ctx>( assert!(matches!(args.len(), 1..=3)); let obj_ty = fun.0.args[0].ty; - let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 - } - - TypeEnum::TObj { obj_id, params, .. 
} if *obj_id == PrimDef::List.id() => { - let mut ty = *params.iter().next().unwrap().1; - while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty) - { - if *obj_id != PrimDef::List.id() { - break; - } - - ty = *params.iter().next().unwrap().1; - } - ty - } - - _ => obj_ty, - }; let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?; let copy_arg = if let Some(arg) = @@ -1830,28 +156,17 @@ pub fn gen_ndarray_array<'ctx>( ) }; - let ndmin_arg = if let Some(arg) = - args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) - { - let ndmin_ty = fun.0.args[2].ty; - arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)? - } else { - context.gen_symbol_val( - generator, - fun.0.args[2].default_value.as_ref().unwrap(), - fun.0.args[2].ty, - ) - }; + // The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be + // the `ndims` of the function return type. + let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); - call_ndarray_array_impl( - generator, - context, - obj_elem_ty, - obj_arg, - copy_arg.into_int_value(), - ndmin_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let copy = generator.bool_to_i1(context, copy_arg.into_int_value()); + let ndarray = NDArrayType::from_unifier_type(generator, context, fun.0.ret) + .construct_numpy_array(generator, context, (obj_ty, obj_arg), copy, None) + .atleast_nd(generator, context, ndims); + + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.eye`. 
@@ -1890,15 +205,27 @@ pub fn gen_ndarray_eye<'ctx>( )) }?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - nrows_arg.into_int_value(), - ncols_arg.into_int_value(), - offset_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let llvm_usize = context.get_size_type(); + let llvm_dtype = context.get_llvm_type(generator, dtype); + + let nrows = context + .builder + .build_int_s_extend_or_bit_cast(nrows_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let ncols = context + .builder + .build_int_s_extend_or_bit_cast(ncols_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let offset = context + .builder + .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") + .unwrap(); + + let ndarray = NDArrayType::new(context, llvm_dtype, 2) + .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.identity`. 
@@ -1912,20 +239,21 @@ pub fn gen_ndarray_identity<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); - let n_ty = fun.0.args[0].ty; let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - n_arg.into_int_value(), - n_arg.into_int_value(), - llvm_usize.const_zero(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let llvm_usize = context.get_size_type(); + let llvm_dtype = context.get_llvm_type(generator, dtype); + + let n = context + .builder + .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") + .unwrap(); + let ndarray = NDArrayType::new(context, llvm_dtype, 2) + .construct_numpy_identity(generator, context, dtype, n, None); + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.copy`. @@ -1940,19 +268,13 @@ pub fn gen_ndarray_copy<'ctx>( assert!(args.is_empty()); let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; - let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty); - - ndarray_copy_impl( - generator, - context, - this_elem_ty, - llvm_this_ty.map_value(this_arg.into_pointer_value(), None), - ) - .map(NDArrayValue::into) + let this = NDArrayType::from_unifier_type(generator, context, this_ty) + .map_pointer_value(this_arg.into_pointer_value(), None); + let ndarray = this.make_copy(generator, context); + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.fill`. @@ -1967,446 +289,17 @@ pub fn gen_ndarray_fill<'ctx>( assert_eq!(args.len(), 1); let this_ty = obj.as_ref().unwrap().0; - let this_arg = obj - .as_ref() - .unwrap() - .1 - .clone() - .to_basic_value_enum(context, generator, this_ty)? 
- .into_pointer_value(); + let this_arg = + obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; let value_ty = fun.0.args[0].ty; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; - let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty); - - ndarray_fill_flattened( - generator, - context, - llvm_this_ty.map_value(this_arg, None), - |generator, ctx, _| { - let value = if value_arg.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - value_arg.into_pointer_value(), - value_arg.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if value_arg.is_int_value() || value_arg.is_float_value() { - value_arg - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - }, - )?; - + let this = NDArrayType::from_unifier_type(generator, context, this_ty) + .map_pointer_value(this_arg.into_pointer_value(), None); + this.fill(generator, context, value_arg); Ok(()) } -/// Generates LLVM IR for `ndarray.transpose`. 
-pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (x1_ty, x1): (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_transpose"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); - - // Dimensions are reversed in the transposed array - let out = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &n1, - |_, ctx, n| Ok(n.load_ndims(ctx)), - |generator, ctx, n, idx| { - let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap(); - let new_idx = ctx - .builder - .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") - .unwrap(); - unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) } - }, - ) - .unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - - let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap(); - ctx.builder.build_store(rem_idx, idx).unwrap(); - - // Incrementally calculate the new index in the transposed array - // For each index, we first decompose it into the n-dims and use those to reconstruct the new index - // The formula used for indexing is: - // idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... 
) + dim_n - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n1.load_ndims(ctx), false), - |generator, ctx, _, ndim| { - let ndim_rev = - ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap(); - let ndim_rev = ctx - .builder - .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") - .unwrap(); - let dim = unsafe { - n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None) - }; - - let rem_idx_val = - ctx.builder.build_load(rem_idx, "").unwrap().into_int_value(); - let new_idx_val = - ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - - let add_component = - ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap(); - let rem_idx_val = - ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap(); - - let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap(); - let new_idx_val = - ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap(); - - ctx.builder.build_store(rem_idx, rem_idx_val).unwrap(); - ctx.builder.build_store(new_idx, new_idx_val).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - codegen_unreachable!( - ctx, - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ) - } -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`. -/// -/// * `x1` - `NDArray` to reshape. -/// * `shape` - The `shape` parameter used to construct the new `NDArray`. -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])` -/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` -/// 3. 
A scalar `int32`; e.g., `np.reshape(arr, 3)` -/// -/// Note that unlike other generating functions, one of the dimensions in the shape can be negative. -pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - shape: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_reshape"; - let (x1_ty, x1) = x1; - let (_, shape) = shape; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); - let n1 = llvm_ndarray_ty.map_value(n1, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); - - let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap(); - ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap(); - - let out = match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. 
A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])` - - let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None); - // Check for -1 in dimensions - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_list.load_size(ctx, None), false), - |generator, ctx, _, idx| { - let ele = - shape_list.data().get(ctx, generator, &idx, None).into_int_value(); - let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap(); - - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - ele, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, ctx| -> Result, String> { - let num_neg_value = - ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - let num_neg_value = ctx - .builder - .build_int_add( - num_neg_value, - llvm_usize.const_int(1, false), - "", - ) - .unwrap(); - ctx.builder.build_store(num_neg, num_neg_value).unwrap(); - Ok(None) - }, - |_, ctx| { - let acc_value = - ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let acc_value = - ctx.builder.build_int_mul(acc_value, ele, "").unwrap(); - ctx.builder.build_store(acc, acc_value).unwrap(); - Ok(None) - }, - )?; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); - // Generate the output shape by filling -1 with `rem` - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, _| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - let dim = - shape_list.data().get(ctx, generator, &idx, None).into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - 
llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(rem)), - |_, _| Ok(Some(dim)), - )? - .unwrap() - .into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` - - let ndims = shape_tuple.get_type().count_fields(); - // Check for -1 in dims - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, "") - .unwrap() - .into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, ctx| -> Result, String> { - let num_negs = - ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - let num_negs = ctx - .builder - .build_int_add(num_negs, llvm_usize.const_int(1, false), "") - .unwrap(); - ctx.builder.build_store(num_neg, num_negs).unwrap(); - Ok(None) - }, - |_, ctx| { - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap(); - ctx.builder.build_store(acc, acc_val).unwrap(); - Ok(None) - }, - )?; - } - - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); - let mut shape = Vec::with_capacity(ndims as usize); - - // Reconstruct shape filling negatives with rem - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, "") - .unwrap() - .into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - let dim = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(rem)), - |_, _| Ok(Some(dim)), - )? 
- .unwrap() - .into_int_value(); - shape.push(dim); - } - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` - let shape_int = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - shape_int, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(n_sz)), - |_, ctx| { - Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap())) - }, - )? - .unwrap() - .into_int_value(); - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => codegen_unreachable!(ctx), - } - .unwrap(); - - // Only allow one dimension to be negative - let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "can only specify one unknown dimension", - [None, None, None], - ctx.current_loc, - ); - - // The new shape must be compatible with the old shape - let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None)); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), - "0:ValueError", - "cannot reshape array of size {0} into provided shape of size {1}", - [Some(n_sz), Some(out_sz), None], - ctx.current_loc, - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - codegen_unreachable!( - ctx, - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) 
- ) - } -} - /// Generates LLVM IR for `ndarray.dot`. /// Calculate inner product of two vectors or literals /// For matrix multiplication use `np_matmul` @@ -2416,89 +309,103 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "ndarray_dot"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - - let llvm_usize = generator.get_size_type(ctx.ctx); match (x1, x2) { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); - let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); + let a = + NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(n1, None); + let b = + NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_pointer_value(n2, None); - let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); - let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. + assert_eq!(a.get_type().ndims(), 1); + assert_eq!(b.get_type().ndims(), 1); + let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + // Check shapes. 
+ let a_size = a.size(ctx); + let b_size = b.size(ctx); + let same_shape = + ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap(); ctx.make_assert( generator, - ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(), + same_shape, "0:ValueError", - "shapes ({0}), ({1}) not aligned", - [Some(n1_sz), Some(n2_sz), None], + "shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)", + [Some(a_size), Some(b_size), None], ctx.current_loc, ); - let identity = - unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap(); - ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap(); + let dtype_llvm = ctx.get_llvm_type(generator, common_dtype); - gen_for_callback_incrementing( + let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap(); + ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap(); + + // Do dot product. + gen_for_callback( generator, ctx, - None, - llvm_usize.const_zero(), - (n1_sz, false), - |generator, ctx, _, idx| { - let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) }; + Some("np_dot"), + |generator, ctx| { + let a_iter = NDIterType::new(ctx).construct(generator, ctx, a); + let b_iter = NDIterType::new(ctx).construct(generator, ctx, b); + Ok((a_iter, b_iter)) + }, + |_, ctx, (a_iter, _b_iter)| { + // Only a_iter drives the condition, b_iter should have the same status. 
+ Ok(a_iter.has_element(ctx)) + }, + |_, ctx, _hooks, (a_iter, b_iter)| { + let a_scalar = a_iter.get_scalar(ctx); + let b_scalar = b_iter.get_scalar(ctx); - let product = match elem1 { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_mul(e1, elem2.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_mul(e1, elem2.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()), - }; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - let acc_val = match acc_val { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_add(e1, product.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_add(e1, product.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()), - }; - ctx.builder.build_store(acc, acc_val).unwrap(); + let old_result = ctx.builder.build_load(result, "").unwrap(); + let new_result: BasicValueEnum<'ctx> = match old_result { + BasicValueEnum::IntValue(old_result) => { + let a_scalar = a_scalar.into_int_value(); + let b_scalar = b_scalar.into_int_value(); + let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_int_add(old_result, x, "").unwrap().into() + } + BasicValueEnum::FloatValue(old_result) => { + let a_scalar = a_scalar.into_float_value(); + let b_scalar = b_scalar.into_float_value(); + let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_float_add(old_result, x, "").unwrap().into() + } + + _ => { + panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype)); + } + }; + + ctx.builder.build_store(result, new_result).unwrap(); Ok(()) }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - Ok(acc_val) + |_, ctx, (a_iter, 
b_iter)| { + a_iter.next(ctx); + b_iter.next(ctx); + Ok(()) + }, + ) + .unwrap(); + + Ok(ctx.builder.build_load(result, "").unwrap()) } + (BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => { Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum()) } + (BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => { Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum()) } + _ => codegen_unreachable!( ctx, "{FN_NAME}() not supported for '{}'", diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 3595528..0c1b931 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1,6 +1,7 @@ use inkwell::{ attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, + builder::Builder, types::{BasicType, BasicTypeEnum}, values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, IntPredicate, @@ -16,7 +17,11 @@ use super::{ gen_in_range_check, irrt::{handle_slice_indices, list_slice_assignment}, macros::codegen_unreachable, - values::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue}, + types::{ndarray::NDArrayType, RangeType}, + values::{ + ndarray::{RustNDIndex, ScalarOrNDArray}, + ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, + }, CodeGenContext, CodeGenerator, }; use crate::{ @@ -302,7 +307,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { // Handle list item assignment - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let target_item_ty = iter_type_vars(list_params).next().unwrap().ty; let target = generator @@ -363,10 +368,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( .unwrap() .to_basic_value_enum(ctx, generator, key_ty)? 
.into_int_value(); - let index = ctx - .builder - .build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext") - .unwrap(); + let index = + ctx.builder.build_int_s_extend(index, ctx.get_size_type(), "sext").unwrap(); // handle negative index let is_negative = ctx @@ -374,7 +377,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( .build_int_compare( IntPredicate::SLT, index, - generator.get_size_type(ctx.ctx).const_zero(), + ctx.get_size_type().const_zero(), "is_neg", ) .unwrap(); @@ -411,7 +414,51 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { // Handle NDArray item assignment - todo!("ndarray subscript assignment is not yet implemented"); + // Process target + let target = generator + .gen_expr(ctx, target)? + .unwrap() + .to_basic_value_enum(ctx, generator, target_ty)?; + + // Process key + let key = RustNDIndex::from_subscript_expr(generator, ctx, key)?; + + // Process value + let value = value.to_basic_value_enum(ctx, generator, value_ty)?; + + // Reference code: + // ```python + // target = target[key] + // value = np.asarray(value) + // + // shape = np.broadcast_shape((target, value)) + // + // target = np.broadcast_to(target, shape) + // value = np.broadcast_to(value, shape) + // + // # ...and finally copy 1-1 from value to target. 
+ // ``` + + let target = NDArrayType::from_unifier_type(generator, ctx, target_ty) + .map_pointer_value(target.into_pointer_value(), None); + let target = target.index(generator, ctx, &key); + + let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value)) + .to_ndarray(generator, ctx); + + let broadcast_ndims = + [target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap(); + let broadcast_result = NDArrayType::new( + ctx, + value.get_type().element_type(), + broadcast_ndims, + ) + .broadcast(generator, ctx, &[target, value]); + + let target = broadcast_result.ndarrays[0]; + let value = broadcast_result.ndarrays[1]; + + target.copy_data_from(ctx, value); } _ => { panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); @@ -435,7 +482,7 @@ pub fn gen_for( let var_assignment = ctx.var_assignment.clone(); let int32 = ctx.ctx.i32_type(); - let size_t = generator.get_size_type(ctx.ctx); + let size_t = ctx.get_size_type(); let zero = int32.const_zero(); let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); let body_bb = ctx.ctx.append_basic_block(current, "for.body"); @@ -464,7 +511,7 @@ pub fn gen_for( if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { let iter_val = - RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range")); + RangeType::new(ctx).map_pointer_value(iter_val.into_pointer_value(), Some("range")); // Internal variable for loop; Cannot be assigned let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed @@ -616,11 +663,25 @@ pub fn gen_for( #[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)] pub struct BreakContinueHooks<'ctx> { /// The [exit block][`BasicBlock`] to branch to when `break`-ing out of a loop. 
- pub exit_bb: BasicBlock<'ctx>, + exit_bb: BasicBlock<'ctx>, /// The [latch basic block][`BasicBlock`] to branch to for `continue`-ing to the next iteration /// of the loop. - pub latch_bb: BasicBlock<'ctx>, + latch_bb: BasicBlock<'ctx>, +} + +impl<'ctx> BreakContinueHooks<'ctx> { + /// Creates a [`br` instruction][Builder::build_unconditional_branch] to the exit + /// [`BasicBlock`], as if by calling `break`. + pub fn build_break_branch(&self, builder: &Builder<'ctx>) { + builder.build_unconditional_branch(self.exit_bb).unwrap(); + } + + /// Creates a [`br` instruction][Builder::build_unconditional_branch] to the latch + /// [`BasicBlock`], as if by calling `continue`. + pub fn build_continue_branch(&self, builder: &Builder<'ctx>) { + builder.build_unconditional_branch(self.latch_bb).unwrap(); + } } /// Generates a C-style `for` construct using lambdas, similar to the following C code: diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 81c5836..15c4654 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -36,7 +36,6 @@ use crate::{ struct Resolver { id_to_type: HashMap, id_to_def: RwLock>, - class_names: HashMap, } impl Resolver { @@ -98,19 +97,18 @@ fn test_primitives() { "}; let statements = parse_program(source, FileName::default()).unwrap(); - let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; + let context = inkwell::context::Context::create(); + let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0; let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; let top_level = Arc::new(composer.make_top_level_context()); unifier.top_level = Some(top_level.clone()); - let resolver = Arc::new(Resolver { - id_to_type: HashMap::new(), - id_to_def: RwLock::new(HashMap::new()), - class_names: HashMap::default(), - }) as Arc; + let resolver = + Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: 
RwLock::new(HashMap::new()) }) + as Arc; - let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; + let threads = vec![DefaultCodeGenerator::new("test".into(), context.i64_type()).into()]; let signature = FunSignature { args: vec![ FuncArg { @@ -263,7 +261,8 @@ fn test_simple_call() { "}; let statements_2 = parse_program(source_2, FileName::default()).unwrap(); - let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; + let context = inkwell::context::Context::create(); + let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0; let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; let top_level = Arc::new(composer.make_top_level_context()); @@ -298,11 +297,7 @@ fn test_simple_call() { loc: None, }))); - let resolver = Resolver { - id_to_type: HashMap::new(), - id_to_def: RwLock::new(HashMap::new()), - class_names: HashMap::default(), - }; + let resolver = Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) }; resolver.add_id_def("foo".into(), DefinitionId(foo_id)); let resolver = Arc::new(resolver) as Arc; @@ -314,7 +309,7 @@ fn test_simple_call() { unreachable!() } - let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; + let threads = vec![DefaultCodeGenerator::new("test".into(), context.i64_type()).into()]; let mut function_data = FunctionData { resolver: resolver.clone(), bound_variables: Vec::new(), @@ -446,31 +441,34 @@ fn test_simple_call() { #[test] fn test_classes_list_type_new() { let ctx = inkwell::context::Context::create(); - let generator = DefaultCodeGenerator::new(String::new(), 64); + let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into()); - assert!(ListType::is_representable(llvm_list.as_base_type(), 
llvm_usize).is_ok()); + let llvm_list = ListType::new_with_generator(&generator, &ctx, llvm_i32.into()); + assert!(ListType::is_representable(llvm_list.as_abi_type(), llvm_usize).is_ok()); } #[test] fn test_classes_range_type_new() { let ctx = inkwell::context::Context::create(); + let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); - let llvm_range = RangeType::new(&ctx); - assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok()); + let llvm_usize = generator.get_size_type(&ctx); + + let llvm_range = RangeType::new_with_generator(&generator, &ctx); + assert!(RangeType::is_representable(llvm_range.as_abi_type(), llvm_usize).is_ok()); } #[test] fn test_classes_ndarray_type_new() { let ctx = inkwell::context::Context::create(); - let generator = DefaultCodeGenerator::new(String::new(), 64); + let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), None); - assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); + let llvm_ndarray = NDArrayType::new_with_generator(&generator, &ctx, llvm_i32.into(), 2); + assert!(NDArrayType::is_representable(llvm_ndarray.as_abi_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 3d04134..b4110da 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -1,127 +1,311 @@ use inkwell::{ context::Context, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, - AddressSpace, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, + AddressSpace, IntPredicate, OptimizationLevel, }; +use itertools::Itertools; + +use nac3core_derive::StructFields; use super::ProxyType; -use crate::codegen::{ 
- values::{ArraySliceValue, ListValue, ProxyValue}, - CodeGenContext, CodeGenerator, +use crate::{ + codegen::{ + types::structure::{ + check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + StructProxyType, + }, + values::ListValue, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{iter_type_vars, Type, TypeEnum}, }; /// Proxy type for a `list` type in LLVM. #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct ListType<'ctx> { ty: PointerType<'ctx>, + item: Option>, llvm_usize: IntType<'ctx>, } +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ListStructFields<'ctx> { + /// Array pointer to content. + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + pub items: StructField<'ctx, PointerValue<'ctx>>, + + /// Number of items in the array. + #[value_type(usize)] + pub len: StructField<'ctx, IntValue<'ctx>>, +} + +impl<'ctx> ListStructFields<'ctx> { + #[must_use] + pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let mut counter = FieldIndexCounter::default(); + + ListStructFields { + items: StructField::create( + &mut counter, + "items", + item.ptr_type(AddressSpace::default()), + ), + len: StructField::create(&mut counter, "len", llvm_usize), + } + } +} + impl<'ctx> ListType<'ctx> { - /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not. 
- pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let llvm_list_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { - return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")); - }; - if llvm_list_ty.count_fields() != 2 { - return Err(format!( - "Expected 2 fields in `list`, got {}", - llvm_list_ty.count_fields() - )); - } - - let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap(); - let Ok(_) = PointerType::try_from(list_size_ty) else { - return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}")); - }; - - let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap(); - let Ok(list_data_ty) = IntType::try_from(list_data_ty) else { - return Err(format!("Expected int type for `list.1`, got {list_data_ty}")); - }; - if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected {}-bit int type for `list.1`, got {}-bit int", - llvm_usize.get_bit_width(), - list_data_ty.get_bit_width() - )); - } - - Ok(()) + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> ListStructFields<'ctx> { + ListStructFields::new_typed(item, llvm_usize) } /// Creates an LLVM type corresponding to the expected structure of a `List`. 
#[must_use] fn llvm_type( ctx: &'ctx Context, - element_type: BasicTypeEnum<'ctx>, + element_type: Option>, llvm_usize: IntType<'ctx>, ) -> PointerType<'ctx> { - // struct List { data: T*, size: size_t } - let field_tys = [element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()]; + let element_type = element_type.map_or(llvm_usize.into(), |ty| ty.as_basic_type_enum()); + + let field_tys = + Self::fields(element_type, llvm_usize).into_iter().map(|field| field.1).collect_vec(); ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } + fn new_impl( + ctx: &'ctx Context, + element_type: Option>, + llvm_usize: IntType<'ctx>, + ) -> Self { + let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize); + + Self { ty: llvm_list, item: element_type, llvm_usize } + } + /// Creates an instance of [`ListType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, element_type: &impl BasicType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, Some(element_type.as_basic_type_enum()), ctx.get_size_type()) + } + + /// Creates an instance of [`ListType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, element_type: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize); + Self::new_impl(ctx, Some(element_type.as_basic_type_enum()), generator.get_size_type(ctx)) + } - ListType::from_type(llvm_list, llvm_usize) + /// Creates an instance of [`ListType`] with an unknown element type. + #[must_use] + pub fn new_untyped(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, None, ctx.get_size_type()) + } + + /// Creates an instance of [`ListType`] with an unknown element type. + #[must_use] + pub fn new_untyped_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, None, generator.get_size_type(ctx)) + } + + /// Creates an [`ListType`] from a [unifier type][Type]. 
+ #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + // Check unifier type and extract `item_type` + let elem_type = match &*ctx.unifier.get_ty_immutable(ty) { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + iter_type_vars(params).next().unwrap().ty + } + + _ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)), + }; + + let llvm_usize = ctx.get_size_type(); + let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) { + None + } else { + Some(ctx.get_llvm_type(generator, elem_type)) + }; + + Self::new_impl(ctx.ctx, llvm_elem_type, llvm_usize) + } + + /// Creates an [`ListType`] from a [`StructType`]. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) } /// Creates an [`ListType`] from a [`PointerType`]. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); - ListType { ty: ptr_ty, llvm_usize } + let ctx = ptr_ty.get_context(); + + // We are just searching for the index off a field - Slot an arbitrary element type in. 
+ let item_field_idx = + Self::fields(ctx.i8_type().into(), llvm_usize).index_of_field(|f| f.items); + let item = unsafe { + ptr_ty + .get_element_type() + .into_struct_type() + .get_field_type_at_index_unchecked(item_field_idx) + .into_pointer_type() + .get_element_type() + }; + let item = BasicTypeEnum::try_from(item).unwrap_or_else(|()| { + panic!( + "Expected BasicTypeEnum for list element type, got {}", + ptr_ty.get_element_type().print_to_string() + ) + }); + + ListType { ty: ptr_ty, item: Some(item), llvm_usize } } /// Returns the type of the `size` field of this `list` type. #[must_use] pub fn size_type(&self) -> IntType<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(1) - .map(BasicTypeEnum::into_int_type) - .unwrap() + self.llvm_usize } /// Returns the element type of this `list` type. #[must_use] - pub fn element_type(&self) -> AnyTypeEnum<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(0) - .map(BasicTypeEnum::into_pointer_type) - .map(PointerType::get_element_type) - .unwrap() + pub fn element_type(&self) -> Option> { + self.item } /// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. 
+ #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates a [`ListValue`] on the stack using `item` of this [`ListType`] instance. + /// + /// The returned list will contain: + /// + /// - `data`: Allocated with `len` number of elements. + /// - `len`: Initialized to the value of `len` passed to this function. + #[must_use] + pub fn construct( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + len: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let len = ctx.builder.build_int_z_extend(len, self.llvm_usize, "").unwrap(); + + // Generate a runtime assertion if allocating a non-empty list with unknown element type + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None && self.item.is_none() { + let len_eqz = ctx + .builder + .build_int_compare(IntPredicate::EQ, len, self.llvm_usize.const_zero(), "") + .unwrap(); + + ctx.make_assert( + generator, + len_eqz, + "0:AssertionError", + "Cannot allocate a non-empty list with unknown element type", + [None, None, None], + ctx.current_loc, + ); + } + + let plist = self.alloca_var(generator, ctx, name); + plist.store_size(ctx, len); + + let item = self.item.unwrap_or(self.llvm_usize.into()); + plist.create_data(ctx, item, None); + + plist + } + + /// Convenience function for creating a list with zero elements. + /// + /// This function is preferred over [`ListType::construct`] if the length is known to always be + /// 0, as this function avoids injecting an IR assertion for checking if a non-empty untyped + /// list is being allocated. + /// + /// The returned list will contain: + /// + /// - `data`: Initialized to `(T*) 0`. + /// - `len`: Initialized to `0`. 
+ #[must_use] + pub fn construct_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + let plist = self.alloca_var(generator, ctx, name); + + plist.store_size(ctx, self.llvm_usize.const_zero()); + plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None); + + plist + } + + /// Converts an existing value into a [`ListValue`]. + #[must_use] + pub fn map_struct_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, self.llvm_usize, name, ) @@ -129,9 +313,9 @@ impl<'ctx> ListType<'ctx> { /// Converts an existing value into a [`ListValue`]. #[must_use] - pub fn map_value( + pub fn map_pointer_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value(value, self.llvm_usize, name) @@ -139,64 +323,64 @@ impl<'ctx> ListType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ListValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return 
Err(format!("Expected struct type for `list` type, got {llvm_ty}")); + }; + + let fields = ListStructFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields( + fields, + llvm_ty, + "list", + &[(fields.items.name(), &|ty| { + if ty.is_pointer_type() { + Ok(()) + } else { + Err(format!("Expected T* for `list.items`, got {ty}")) + } + })], + ) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Base { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> StructProxyType<'ctx> for ListType<'ctx> { + type StructFields = ListStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.item.unwrap_or(self.llvm_usize.into()), self.llvm_usize) + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 022f897..abeab5b 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -16,7 +16,10 @@ //! the returned object. This is similar to a `new` expression in C++ but the object is allocated //! on the stack. 
-use inkwell::{context::Context, types::BasicType, values::IntValue}; +use inkwell::{ + types::{BasicType, IntType}, + values::{IntValue, PointerValue}, +}; use super::{ values::{ArraySliceValue, ProxyValue}, @@ -24,53 +27,103 @@ use super::{ }; pub use list::*; pub use range::*; +pub use tuple::*; mod list; pub mod ndarray; mod range; pub mod structure; +mod tuple; pub mod utils; /// A LLVM type that is used to represent a corresponding type in NAC3. pub trait ProxyType<'ctx>: Into { - /// The LLVM type of which values of this type possess. This is usually a - /// [LLVM pointer type][PointerType] for any non-primitive types. + /// The ABI type of which values of this type possess. + type ABI: BasicType<'ctx>; + + /// The LLVM type of which values of this type possess. type Base: BasicType<'ctx>; /// The type of values represented by this type. type Value: ProxyValue<'ctx, Type = Self>; - fn is_type( - generator: &G, - ctx: &'ctx Context, - llvm_ty: impl BasicType<'ctx>, - ) -> Result<(), String>; - /// Checks whether `llvm_ty` can be represented by this [`ProxyType`]. - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String>; - /// Creates a new value of this type, returning the LLVM instance of this value. - fn raw_alloca( + /// Checks whether the type represented by `ty` expresses the same type represented by this + /// [`ProxyType`]. + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String>; + + /// Returns the type that should be used in `alloca` IR statements. + fn alloca_type(&self) -> impl BasicType<'ctx>; + + /// Creates a new value of this type by invoking `alloca` at the current builder location, + /// returning a [`PointerValue`] instance representing the allocated value. 
+ fn raw_alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> PointerValue<'ctx> { + ctx.builder + .build_alloca(self.alloca_type().as_basic_type_enum(), name.unwrap_or_default()) + .unwrap() + } + + /// Creates a new value of this type by invoking `alloca` at the beginning of the function, + /// returning a [`PointerValue`] instance representing the allocated value. + fn raw_alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, - ) -> >::Base; + ) -> PointerValue<'ctx> { + generator.gen_var_alloc(ctx, self.alloca_type().as_basic_type_enum(), name).unwrap() + } - /// Creates a new array value of this type, returning an [`ArraySliceValue`] encapsulating the - /// resulting array. - fn array_alloca( + /// Creates a new array value of this type by invoking `alloca` at the current builder location, + /// returning an [`ArraySliceValue`] encapsulating the resulting array. + fn array_alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + ArraySliceValue::from_ptr_val( + ctx.builder + .build_array_alloca( + self.alloca_type().as_basic_type_enum(), + size, + name.unwrap_or_default(), + ) + .unwrap(), + size, + name, + ) + } + + /// Creates a new array value of this type by invoking `alloca` at the beginning of the + /// function, returning an [`ArraySliceValue`] encapsulating the resulting array. + fn array_alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, size: IntValue<'ctx>, name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx>; + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc(ctx, self.alloca_type().as_basic_type_enum(), size, name) + .unwrap() + } /// Returns the [base type][Self::Base] of this proxy. fn as_base_type(&self) -> Self::Base; + + /// Returns this proxy as its ABI type, i.e. 
the expected type representation if a value of this + /// [`ProxyType`] is being passed into or returned from a function. + /// + /// See [`CodeGenContext::get_llvm_abi_type`]. + fn as_abi_type(&self) -> Self::ABI; } diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs new file mode 100644 index 0000000..633a0b4 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -0,0 +1,240 @@ +use inkwell::{ + types::BasicTypeEnum, + values::{BasicValueEnum, IntValue}, + AddressSpace, +}; + +use crate::{ + codegen::{ + irrt, + stmt::gen_if_else_expr_callback, + types::{ndarray::NDArrayType, ListType, ProxyType}, + values::{ + ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, + TypedArrayLikeAdapter, TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, + }, + toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims}, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Get the expected `dtype` and `ndims` of the ndarray returned by `np_array()`. +fn get_list_object_dtype_and_ndims<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + list_ty: Type, +) -> (BasicTypeEnum<'ctx>, u64) { + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list_ty); + let ndims = arraylike_get_ndims(&mut ctx.unifier, list_ty); + + (ctx.get_llvm_type(generator, dtype), ndims) +} + +impl<'ctx> NDArrayType<'ctx> { + /// Implementation of `np_array(, copy=True)` + fn construct_numpy_array_from_list_copy_true_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + name: Option<&'ctx str>, + ) -> >::Value { + let (dtype, ndims_int) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + assert!(self.ndims >= ndims_int); + assert_eq!(dtype, self.dtype); + + let list_value = list.as_i8_list(ctx); + + // Validate `list` has a consistent shape. 
+ // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. + // If `list` has a consistent shape, deduce the shape and write it to `shape`. + let ndims = self.llvm_usize.const_int(ndims_int, false); + let shape = ctx.builder.build_array_alloca(self.llvm_usize, ndims, "").unwrap(); + let shape = ArraySliceValue::from_ptr_val(shape, ndims, None); + let shape = TypedArrayLikeAdapter::from( + shape, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + irrt::ndarray::call_nac3_ndarray_array_set_and_validate_list_shape( + generator, ctx, list_value, ndims, &shape, + ); + + let ndarray = + Self::new(ctx, dtype, ndims_int).construct_uninitialized(generator, ctx, name); + ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); + unsafe { ndarray.create_data(generator, ctx) }; + + // Copy all contents from the list. + irrt::ndarray::call_nac3_ndarray_array_write_list_to_array(ctx, list_value, ndarray); + + ndarray + } + + /// Implementation of `np_array(, copy=None)` + fn construct_numpy_array_from_list_copy_none_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + name: Option<&'ctx str>, + ) -> >::Value { + // np_array without copying is only possible `list` is not nested. + // + // If `list` is `list[T]`, we can create an ndarray with `data` set + // to the array pointer of `list`. + // + // If `list` is `list[list[T]]` or worse, copy. 
+ + let (dtype, ndims) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + if ndims == 1 { + // `list` is not nested + assert_eq!(ndims, 1); + assert!(self.ndims >= ndims); + assert_eq!(dtype, self.dtype); + + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + + let ndarray = Self::new(ctx, dtype, 1).construct_uninitialized(generator, ctx, name); + + // Set data + let data = ctx + .builder + .build_pointer_cast(list.data().base_ptr(ctx, generator), llvm_pi8, "") + .unwrap(); + ndarray.store_data(ctx, data); + + // ndarray->shape[0] = list->len; + let shape = ndarray.shape(); + let list_len = list.load_size(ctx, None); + unsafe { + shape.set_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), list_len); + } + + // Set strides, the `data` is contiguous + ndarray.set_strides_contiguous(ctx); + + ndarray + } else { + // `list` is nested, copy + self.construct_numpy_array_from_list_copy_true_impl( + generator, + ctx, + (list_ty, list), + name, + ) + } + } + + /// Implementation of `np_array(, copy=copy)` + fn construct_numpy_array_list_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!(copy.get_type(), ctx.ctx.bool_type()); + + let (dtype, ndims) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + + let ndarray = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy), + |generator, ctx| { + let ndarray = self.construct_numpy_array_from_list_copy_true_impl( + generator, + ctx, + (list_ty, list), + name, + ); + Ok(Some(ndarray.as_abi_value(ctx))) + }, + |generator, ctx| { + let ndarray = self.construct_numpy_array_from_list_copy_none_impl( + generator, + ctx, + (list_ty, list), + name, + ); + Ok(Some(ndarray.as_abi_value(ctx))) + }, + ) + .unwrap() + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + + NDArrayType::new(ctx, dtype, 
ndims).map_pointer_value(ndarray, None) + } + + /// Implementation of `np_array(, copy=copy)`. + pub fn construct_numpy_array_ndarray_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!(ndarray.get_type().dtype, self.dtype); + assert!(self.ndims >= ndarray.get_type().ndims); + assert_eq!(copy.get_type(), ctx.ctx.bool_type()); + + let ndarray_val = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy), + |generator, ctx| { + let ndarray = ndarray.make_copy(generator, ctx); // Force copy + Ok(Some(ndarray.as_abi_value(ctx))) + }, + |_generator, ctx| { + // No need to copy. Return `ndarray` itself. + Ok(Some(ndarray.as_abi_value(ctx))) + }, + ) + .unwrap() + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + + ndarray.get_type().map_pointer_value(ndarray_val, name) + } + + /// Create a new ndarray like + /// [`np.array()`](https://numpy.org/doc/stable/reference/generated/numpy.array.html). + /// + /// Note that the returned [`NDArrayValue`] may have fewer dimensions than is specified by this + /// instance. Use [`NDArrayValue::atleast_nd`] on the returned value if an `ndarray` instance + /// with the exact number of dimensions is needed. + pub fn construct_numpy_array( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (object_ty, object): (Type, BasicValueEnum<'ctx>), + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + match &*ctx.unifier.get_ty_immutable(object_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListType::from_unifier_type(generator, ctx, object_ty) + .map_pointer_value(object.into_pointer_value(), None); + self.construct_numpy_array_list_impl(generator, ctx, (object_ty, list), copy, name) + } + + TypeEnum::TObj { obj_id, .. 
} + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) + .map_pointer_value(object.into_pointer_value(), None); + self.construct_numpy_array_ndarray_impl(generator, ctx, ndarray, copy, name) + } + + _ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object_ty)), // Typechecker ensures this + } + } +} diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs new file mode 100644 index 0000000..fa532b4 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -0,0 +1,205 @@ +use inkwell::{ + context::{AsContextRef, Context}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use crate::codegen::{ + types::{ + structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType}, + ProxyType, + }, + values::ndarray::ShapeEntryValue, + CodeGenContext, CodeGenerator, +}; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ShapeEntryType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ShapeEntryStructFields<'ctx> { + #[value_type(usize)] + pub ndims: StructField<'ctx, IntValue<'ctx>>, + #[value_type(usize.ptr_type(AddressSpace::default()))] + pub shape: StructField<'ctx, PointerValue<'ctx>>, +} + +impl<'ctx> ShapeEntryType<'ctx> { + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields( + ctx: impl AsContextRef<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> ShapeEntryStructFields<'ctx> { + ShapeEntryStructFields::new(ctx, llvm_usize) + } + + /// Creates an LLVM type corresponding to the expected structure of a `ShapeEntry`. 
+ #[must_use] + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { + let field_tys = + Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec(); + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { + let llvm_ty = Self::llvm_type(ctx, llvm_usize); + + Self { ty: llvm_ty, llvm_usize } + } + + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + + /// Creates a [`ShapeEntryType`] from a [`StructType`] representing an `ShapeEntry`. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + + /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, llvm_usize } + } + + /// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type. + #[must_use] + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type. 
+ #[must_use] + pub fn alloca_var( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca_var(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ShapeEntryValue`]. + #[must_use] + pub fn map_struct_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ShapeEntryValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { + type ABI = PointerType<'ctx>; + type Base = PointerType<'ctx>; + type Value = ShapeEntryValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ndarray_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!( + "Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}" + )); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDArray", + &[], + ) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type().get_element_type().into_struct_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } + + fn as_abi_type(&self) -> Self::ABI { + 
self.as_base_type() + } +} + +impl<'ctx> StructProxyType<'ctx> for ShapeEntryType<'ctx> { + type StructFields = ShapeEntryStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: ShapeEntryType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 4401cb6..1857536 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -1,7 +1,7 @@ use inkwell::{ context::Context, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -13,10 +13,11 @@ use crate::{ types::{ structure::{ check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + StructProxyType, }, ProxyType, }, - values::{ndarray::ContiguousNDArrayValue, ArraySliceValue, ProxyValue}, + values::ndarray::ContiguousNDArrayValue, CodeGenContext, CodeGenerator, }, toplevel::numpy::unpack_ndarray_var_tys, @@ -31,7 +32,7 @@ pub struct ContiguousNDArrayType<'ctx> { } #[derive(PartialEq, Eq, Clone, Copy, StructFields)] -pub struct ContiguousNDArrayFields<'ctx> { +pub struct ContiguousNDArrayStructFields<'ctx> { #[value_type(usize)] pub ndims: StructField<'ctx, IntValue<'ctx>>, #[value_type(usize.ptr_type(AddressSpace::default()))] @@ -40,12 +41,12 @@ pub struct ContiguousNDArrayFields<'ctx> { pub data: StructField<'ctx, PointerValue<'ctx>>, } -impl<'ctx> ContiguousNDArrayFields<'ctx> { +impl<'ctx> ContiguousNDArrayStructFields<'ctx> { #[must_use] pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { let mut counter = 
FieldIndexCounter::default(); - ContiguousNDArrayFields { + ContiguousNDArrayStructFields { ndims: StructField::create(&mut counter, "ndims", llvm_usize), shape: StructField::create( &mut counter, @@ -58,50 +59,13 @@ impl<'ctx> ContiguousNDArrayFields<'ctx> { } impl<'ctx> ContiguousNDArrayType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { - return Err(format!( - "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" - )); - }; - - let fields = ContiguousNDArrayFields::new(ctx, llvm_usize); - - check_struct_type_matches_fields( - fields, - llvm_ty, - "ContiguousNDArray", - &[(fields.data.name(), &|ty| { - if ty.is_pointer_type() { - Ok(()) - } else { - Err(format!("Expected T* for `ContiguousNDArray.data`, got {ty}")) - } - })], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields( item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, - ) -> ContiguousNDArrayFields<'ctx> { - ContiguousNDArrayFields::new_typed(item, llvm_usize) - } - - /// See [`NDArrayType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self) -> ContiguousNDArrayFields<'ctx> { - Self::fields(self.item, self.llvm_usize) + ) -> ContiguousNDArrayStructFields<'ctx> { + ContiguousNDArrayStructFields::new_typed(item, llvm_usize) } /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. 
@@ -117,17 +81,26 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } + fn new_impl(ctx: &'ctx Context, item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize); + + Self { ty: llvm_cndarray, item, llvm_usize } + } + /// Creates an instance of [`ContiguousNDArrayType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, item: &impl BasicType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, item.as_basic_type_enum(), ctx.get_size_type()) + } + + /// Creates an instance of [`ContiguousNDArrayType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, item: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize); - - Self { ty: llvm_cndarray, item, llvm_usize } + Self::new_impl(ctx, item, generator.get_size_type(ctx)) } /// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type]. @@ -140,33 +113,63 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = generator.get_size_type(ctx.ctx); - Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize } + Self::new_impl(ctx.ctx, llvm_dtype, ctx.get_size_type()) + } + + /// Creates an [`ContiguousNDArrayType`] from a [`StructType`] representing an `NDArray`. + #[must_use] + pub fn from_struct_type( + ty: StructType<'ctx>, + item: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), item, llvm_usize) } /// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`. 
#[must_use] - pub fn from_type( + pub fn from_pointer_type( ptr_ty: PointerType<'ctx>, item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, ) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, item, llvm_usize } } - /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type. + /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base + /// type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.item, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base + /// type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.item, self.llvm_usize, name, @@ -175,9 +178,28 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { /// Converts an existing value into a [`ContiguousNDArrayValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.item, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ContiguousNDArrayValue`]. 
+ #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( @@ -190,64 +212,66 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ContiguousNDArrayValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!( + "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" + )); + }; + + let fields = ContiguousNDArrayStructFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields( + fields, + llvm_ty, + "ContiguousNDArray", + &[(fields.data.name(), &|ty| { + if ty.is_pointer_type() { + Ok(()) + } else { + Err(format!("Expected T* for `ContiguousNDArray.data`, got {ty}")) + } + })], + ) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Base { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> 
ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> StructProxyType<'ctx> for ContiguousNDArrayType<'ctx> { + type StructFields = ContiguousNDArrayStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.item, self.llvm_usize) + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs new file mode 100644 index 0000000..2d0dca7 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -0,0 +1,236 @@ +use inkwell::{ + values::{BasicValueEnum, IntValue}, + IntPredicate, +}; + +use super::NDArrayType; +use crate::{ + codegen::{ + irrt, types::ProxyType, values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator, + }, + typecheck::typedef::Type, +}; + +/// Get the zero value in `np.zeros()` of a `dtype`. 
+fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i32_type().const_zero().into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +/// Get the one value in `np.ones()` of a `dtype`. +fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32); + ctx.ctx.i32_type().const_int(1, is_signed).into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64); + ctx.ctx.i64_type().const_int(1, is_signed).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_float(1.0).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_int(1, false).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "1").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +impl<'ctx> NDArrayType<'ctx> { + /// Create an 
ndarray like + /// [`np.empty`](https://numpy.org/doc/stable/reference/generated/numpy.empty.html). + pub fn construct_numpy_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.construct_uninitialized(generator, ctx, name); + + // Validate `shape` + irrt::ndarray::call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, shape); + + ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); + unsafe { ndarray.create_data(generator, ctx) }; + + ndarray + } + + /// Create an ndarray like + /// [`np.full`](https://numpy.org/doc/stable/reference/generated/numpy.full.html). + pub fn construct_numpy_full( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + fill_value: BasicValueEnum<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.construct_numpy_empty(generator, ctx, shape, name); + ndarray.fill(generator, ctx, fill_value); + ndarray + } + + /// Create an ndarray like + /// [`np.zero`](https://numpy.org/doc/stable/reference/generated/numpy.zeros.html). + pub fn construct_numpy_zeros( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + + let fill_value = ndarray_zero_value(generator, ctx, dtype); + self.construct_numpy_full(generator, ctx, shape, fill_value, name) + } + + /// Create an ndarray like + /// [`np.ones`](https://numpy.org/doc/stable/reference/generated/numpy.ones.html). 
+ pub fn construct_numpy_ones( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + + let fill_value = ndarray_one_value(generator, ctx, dtype); + self.construct_numpy_full(generator, ctx, shape, fill_value, name) + } + + /// Create an ndarray like + /// [`np.eye`](https://numpy.org/doc/stable/reference/generated/numpy.eye.html). + #[allow(clippy::too_many_arguments)] + pub fn construct_numpy_eye( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + nrows: IntValue<'ctx>, + ncols: IntValue<'ctx>, + offset: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + assert_eq!(nrows.get_type(), self.llvm_usize); + assert_eq!(ncols.get_type(), self.llvm_usize); + assert_eq!(offset.get_type(), self.llvm_usize); + + let ndzero = ndarray_zero_value(generator, ctx, dtype); + let ndone = ndarray_one_value(generator, ctx, dtype); + + let ndarray = self.construct_dyn_shape(generator, ctx, &[nrows, ncols], name); + + // Create data and make the matrix like look np.eye() + unsafe { + ndarray.create_data(generator, ctx); + } + ndarray + .foreach(generator, ctx, |generator, ctx, _, nditer| { + // NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero + // and this loop would not execute. 
+ + let indices = nditer.get_indices(); + + let row_i = unsafe { + indices.get_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), None) + }; + let col_i = unsafe { + indices.get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(1, false), + None, + ) + }; + + let be_one = ctx + .builder + .build_int_compare( + IntPredicate::EQ, + ctx.builder.build_int_add(row_i, offset, "").unwrap(), + col_i, + "", + ) + .unwrap(); + let value = ctx.builder.build_select(be_one, ndone, ndzero, "value").unwrap(); + + let p = nditer.get_pointer(ctx); + ctx.builder.build_store(p, value).unwrap(); + + Ok(()) + }) + .unwrap(); + + ndarray + } + + /// Create an ndarray like + /// [`np.identity`](https://numpy.org/doc/stable/reference/generated/numpy.identity.html). + pub fn construct_numpy_identity( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let offset = self.llvm_usize.const_zero(); + self.construct_numpy_eye(generator, ctx, dtype, size, size, offset, name) + } +} diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 959d4f5..d00e0fb 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -10,12 +10,12 @@ use nac3core_derive::StructFields; use crate::codegen::{ types::{ - structure::{check_struct_type_matches_fields, StructField, StructFields}, + structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType}, ProxyType, }, values::{ ndarray::{NDIndexValue, 
RustNDIndex}, - ArrayLikeIndexer, ArraySliceValue, ProxyValue, + ArrayLikeIndexer, ArraySliceValue, }, CodeGenContext, CodeGenerator, }; @@ -35,25 +35,6 @@ pub struct NDIndexStructFields<'ctx> { } impl<'ctx> NDIndexType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndindex` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { - return Err(format!( - "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" - )); - }; - - let fields = NDIndexStructFields::new(ctx, llvm_usize); - - check_struct_type_matches_fields(fields, llvm_ty, "NDIndex", &[]) - } - #[must_use] fn fields( ctx: impl AsContextRef<'ctx>, @@ -62,11 +43,6 @@ impl<'ctx> NDIndexType<'ctx> { NDIndexStructFields::new(ctx, llvm_usize) } - #[must_use] - pub fn get_fields(&self) -> NDIndexStructFields<'ctx> { - Self::fields(self.ty.get_context(), self.llvm_usize) - } - #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { let field_tys = @@ -75,30 +51,64 @@ impl<'ctx> NDIndexType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_ndindex = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_ndindex, llvm_usize } } #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, 
generator.get_size_type(ctx)) + } + + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } } + /// Allocates an instance of [`NDIndexValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + /// Allocates an instance of [`NDIndexValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.llvm_usize, name, ) @@ -114,7 +124,7 @@ impl<'ctx> NDIndexType<'ctx> { ) -> ArraySliceValue<'ctx> { // Allocate the LLVM ndindices. let num_ndindices = self.llvm_usize.const_int(in_ndindices.len() as u64, false); - let ndindices = self.array_alloca(generator, ctx, num_ndindices, None); + let ndindices = self.array_alloca_var(generator, ctx, num_ndindices, None); // Initialize all of them. 
for (i, in_ndindex) in in_ndindices.iter().enumerate() { @@ -138,9 +148,26 @@ impl<'ctx> NDIndexType<'ctx> { } #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value(value, self.llvm_usize, name) @@ -148,64 +175,55 @@ impl<'ctx> NDIndexType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDIndexValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!( + "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" + )); + }; + + let fields = NDIndexStructFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields(fields, llvm_ty, "NDIndex", &[]) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Base { - generator - .gen_var_alloc( 
- ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> StructProxyType<'ctx> for NDIndexType<'ctx> { + type StructFields = NDIndexStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs new file mode 100644 index 0000000..ae24458 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -0,0 +1,183 @@ +use inkwell::{types::BasicTypeEnum, values::BasicValueEnum}; +use itertools::Itertools; + +use crate::codegen::{ + stmt::gen_for_callback, + types::{ + ndarray::{NDArrayType, NDIterType}, + ProxyType, + }, + values::{ + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ArrayLikeValue, ProxyValue, + }, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayType<'ctx> { + /// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping` + /// elementwise. + /// + /// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when + /// iterating through the input `ndarrays` after broadcasting. The output of `mapping` is the + /// result of the elementwise operation. + /// + /// `out` specifies whether the result should be a new ndarray or to be written an existing + /// ndarray. 
+ pub fn broadcast_starmap<'a, G, MappingFn>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarrays: &[NDArrayValue<'ctx>], + out: NDArrayOut<'ctx>, + mapping: MappingFn, + ) -> Result<>::Value, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Broadcast inputs + let broadcast_result = self.broadcast(generator, ctx, ndarrays); + + let out_ndarray = match out { + NDArrayOut::NewNDArray { dtype } => { + // Create a new ndarray based on the broadcast shape. + let result_ndarray = NDArrayType::new(ctx, dtype, broadcast_result.ndims) + .construct_uninitialized(generator, ctx, None); + result_ndarray.copy_shape_from_array( + generator, + ctx, + broadcast_result.shape.base_ptr(ctx, generator), + ); + unsafe { + result_ndarray.create_data(generator, ctx); + } + result_ndarray + } + + NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => { + // Use an existing ndarray. + + // Check that its shape is compatible with the broadcast shape. + result_ndarray.assert_can_be_written_by_out(generator, ctx, broadcast_result.shape); + result_ndarray + } + }; + + // Map element-wise and store results into `mapped_ndarray`. + let nditer = NDIterType::new(ctx).construct(generator, ctx, out_ndarray); + gen_for_callback( + generator, + ctx, + Some("broadcast_starmap"), + |generator, ctx| { + // Create NDIters for all broadcasted input ndarrays. + let other_nditers = broadcast_result + .ndarrays + .iter() + .map(|ndarray| NDIterType::new(ctx).construct(generator, ctx, *ndarray)) + .collect_vec(); + Ok((nditer, other_nditers)) + }, + |_, ctx, (out_nditer, _in_nditers)| { + // We can simply use `out_nditer`'s `has_element()`. + // `in_nditers`' `has_element()`s should return the same value. 
+ Ok(out_nditer.has_element(ctx)) + }, + |generator, ctx, _hooks, (out_nditer, in_nditers)| { + // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, + // and write to `out_ndarray`. + let in_scalars = + in_nditers.iter().map(|nditer| nditer.get_scalar(ctx)).collect_vec(); + + let result = mapping(generator, ctx, &in_scalars)?; + + let p = out_nditer.get_pointer(ctx); + ctx.builder.build_store(p, result).unwrap(); + + Ok(()) + }, + |_, ctx, (out_nditer, in_nditers)| { + // Advance all iterators + out_nditer.next(ctx); + in_nditers.iter().for_each(|nditer| nditer.next(ctx)); + Ok(()) + }, + )?; + + Ok(out_ndarray) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Starmap through a list of inputs using `mapping`, where an input could be an ndarray, a + /// scalar. + /// + /// This function is very helpful when implementing NumPy functions that takes on either scalars + /// or ndarrays or a mix of them as their inputs and produces either an ndarray with broadcast, + /// or a scalar if all its inputs are all scalars. + /// + /// For example ,this function can be used to implement `np.add`, which has the following + /// behaviors: + /// + /// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar + /// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is + /// converted into an ndarray and broadcasted. + /// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) -> + /// ndarray; there is broadcasting. + /// + /// ## Details: + /// + /// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a + /// [`ScalarOrNDArray::Scalar`] with type `ret_dtype`. + /// + /// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be + /// 'as-ndarray'-ed into ndarrays, then all inputs (now all ndarrays) will be passed to + /// [`NDArrayValue::broadcasting_starmap`] and **create** a new ndarray with dtype `ret_dtype`. 
+ pub fn broadcasting_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + inputs: &[ScalarOrNDArray<'ctx>], + ret_dtype: BasicTypeEnum<'ctx>, + mapping: MappingFn, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Check if all inputs are Scalars + let all_scalars: Option> = + inputs.iter().map(BasicValueEnum::<'ctx>::try_from).try_collect().ok(); + + if let Some(scalars) = all_scalars { + let scalars = scalars.iter().copied().collect_vec(); + let value = mapping(generator, ctx, &scalars)?; + + Ok(ScalarOrNDArray::Scalar(value)) + } else { + // Promote all input to ndarrays and map through them. + let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); + let ndarray = NDArrayType::new_broadcast( + ctx, + ret_dtype, + &inputs.iter().map(NDArrayValue::get_type).collect_vec(), + ) + .broadcast_starmap( + generator, + ctx, + &inputs, + NDArrayOut::NewNDArray { dtype: ret_dtype }, + mapping, + )?; + Ok(ScalarOrNDArray::NDArray(ndarray)) + } + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 4127ffa..28ea527 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{BasicValue, IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{BasicValue, IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -9,23 +9,28 @@ use itertools::Itertools; use nac3core_derive::StructFields; use super::{ - structure::{check_struct_type_matches_fields, StructField, StructFields}, + structure::{check_struct_type_matches_fields, StructField, StructFields, 
StructProxyType}, ProxyType, }; use crate::{ codegen::{ - values::{ndarray::NDArrayValue, ArraySliceValue, ProxyValue, TypedArrayLikeMutator}, + values::{ndarray::NDArrayValue, TypedArrayLikeMutator}, {CodeGenContext, CodeGenerator}, }, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, typecheck::typedef::Type, }; +pub use broadcast::*; pub use contiguous::*; pub use indexing::*; pub use nditer::*; +mod array; +mod broadcast; mod contiguous; +pub mod factory; mod indexing; +mod map; mod nditer; /// Proxy type for a `ndarray` type in LLVM. @@ -33,7 +38,7 @@ mod nditer; pub struct NDArrayType<'ctx> { ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, } @@ -57,26 +62,6 @@ pub struct NDArrayStructFields<'ctx> { } impl<'ctx> NDArrayType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ndarray_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); - }; - - check_struct_type_matches_fields( - Self::fields(ctx, llvm_usize), - llvm_ndarray_ty, - "NDArray", - &[], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields( @@ -86,13 +71,6 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayStructFields::new(ctx, llvm_usize) } - /// See [`NDArrayType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDArrayStructFields<'ctx> { - Self::fields(ctx, self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. 
#[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { @@ -102,31 +80,85 @@ impl<'ctx> NDArrayType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`NDArrayType`]. - #[must_use] - pub fn new( - generator: &G, + fn new_impl( ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, + llvm_usize: IntType<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize } } + /// Creates an instance of [`NDArrayType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>, ndims: u64) -> Self { + Self::new_impl(ctx.ctx, dtype, ndims, ctx.get_size_type()) + } + + /// Creates an instance of [`NDArrayType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + dtype: BasicTypeEnum<'ctx>, + ndims: u64, + ) -> Self { + Self::new_impl(ctx, dtype, ndims, generator.get_size_type(ctx)) + } + + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more + /// `ndarray` operands. + #[must_use] + pub fn new_broadcast( + ctx: &CodeGenContext<'ctx, '_>, + dtype: BasicTypeEnum<'ctx>, + inputs: &[NDArrayType<'ctx>], + ) -> Self { + assert!(!inputs.is_empty()); + + Self::new_impl( + ctx.ctx, + dtype, + inputs.iter().map(NDArrayType::ndims).max().unwrap(), + ctx.get_size_type(), + ) + } + + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more + /// `ndarray` operands. 
+ #[must_use] + pub fn new_broadcast_with_generator( + generator: &G, + ctx: &'ctx Context, + dtype: BasicTypeEnum<'ctx>, + inputs: &[NDArrayType<'ctx>], + ) -> Self { + assert!(!inputs.is_empty()); + + Self::new_impl( + ctx, + dtype, + inputs.iter().map(NDArrayType::ndims).max().unwrap(), + generator.get_size_type(ctx), + ) + } + /// Creates an instance of [`NDArrayType`] with `ndims` of 0. #[must_use] - pub fn new_unsized( + pub fn new_unsized(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>) -> Self { + Self::new_impl(ctx.ctx, dtype, 0, ctx.get_size_type()) + } + + /// Creates an instance of [`NDArrayType`] with `ndims` of 0. + #[must_use] + pub fn new_unsized_with_generator( generator: &G, ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); - - NDArrayType { ty: llvm_ndarray, dtype, ndims: Some(0), llvm_usize } + Self::new_impl(ctx, dtype, 0, generator.get_size_type(ctx)) } /// Creates an [`NDArrayType`] from a [unifier type][Type]. @@ -139,26 +171,31 @@ impl<'ctx> NDArrayType<'ctx> { let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = generator.get_size_type(ctx.ctx); let ndims = extract_ndims(&ctx.unifier, ndims); - NDArrayType { - ty: Self::llvm_type(ctx.ctx, llvm_usize), - dtype: llvm_dtype, - ndims: Some(ndims), - llvm_usize, - } + Self::new_impl(ctx.ctx, llvm_dtype, ndims, ctx.get_size_type()) + } + + /// Creates an [`NDArrayType`] from a [`StructType`] representing an `NDArray`. + #[must_use] + pub fn from_struct_type( + ty: StructType<'ctx>, + dtype: BasicTypeEnum<'ctx>, + ndims: u64, + llvm_usize: IntType<'ctx>, + ) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), dtype, ndims, llvm_usize) } /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. 
#[must_use] - pub fn from_type( + pub fn from_pointer_type( ptr_ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, ) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); NDArrayType { ty: ptr_ty, dtype, ndims, llvm_usize } } @@ -177,20 +214,40 @@ impl<'ctx> NDArrayType<'ctx> { /// Returns the number of dimensions of this `ndarray` type. #[must_use] - pub fn ndims(&self) -> Option { + pub fn ndims(&self) -> u64 { self.ndims } /// Allocates an instance of [`NDArrayValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.dtype, + self.ndims, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`NDArrayValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. 
+ #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), self.dtype, self.ndims, self.llvm_usize, @@ -214,15 +271,15 @@ impl<'ctx> NDArrayType<'ctx> { ndims: IntValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { - let ndarray = self.alloca(generator, ctx, name); + let ndarray = self.alloca_var(generator, ctx, name); let itemsize = ctx .builder .build_int_truncate_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "") .unwrap(); - ndarray.store_itemsize(ctx, generator, itemsize); + ndarray.store_itemsize(ctx, itemsize); - ndarray.store_ndims(ctx, generator, ndims); + ndarray.store_ndims(ctx, ndims); ndarray.create_shape(ctx, self.llvm_usize, ndims); ndarray.create_strides(ctx, self.llvm_usize, ndims); @@ -247,35 +304,7 @@ impl<'ctx> NDArrayType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_some(), "NDArrayType::construct can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); - - let Some(ndims) = self.ndims.map(|ndims| self.llvm_usize.const_int(ndims, false)) else { - unreachable!() - }; - - self.construct_impl(generator, ctx, ndims, name) - } - - /// Allocate an [`NDArrayValue`] on the stack given its `ndims` and `dtype`. - /// - /// `shape` and `strides` will be automatically allocated onto the stack. - /// - /// The returned ndarray's content will be: - /// - `data`: uninitialized. - /// - `itemsize`: set to the size of `dtype`. - /// - `ndims`: set to the value of `ndims`. - /// - `shape`: allocated with an array of length `ndims` with uninitialized values. - /// - `strides`: allocated with an array of length `ndims` with uninitialized values. 
- #[deprecated = "Prefer construct_uninitialized or construct_*_shape."] - #[must_use] - pub fn construct_dyn_ndims( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndims: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> >::Value { - assert!(self.ndims.is_none(), "NDArrayType::construct_dyn_ndims can only be called on an instance with compile-time unknown ndims (self.ndims = None)"); + let ndims = self.llvm_usize.const_int(self.ndims, false); self.construct_impl(generator, ctx, ndims, name) } @@ -291,12 +320,12 @@ impl<'ctx> NDArrayType<'ctx> { shape: &[u64], name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims)); + assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64)) + let ndarray = Self::new(ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); // Write shape let ndarray_shape = ndarray.shape(); @@ -326,12 +355,12 @@ impl<'ctx> NDArrayType<'ctx> { shape: &[IntValue<'ctx>], name: Option<&'ctx str>, ) -> >::Value { - assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims)); + assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64)) + let ndarray = Self::new(ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); // Write shape let ndarray_shape = ndarray.shape(); @@ -368,7 +397,7 @@ impl<'ctx> NDArrayType<'ctx> { let value = value.as_basic_value_enum(); assert_eq!(value.get_type(), self.dtype); - assert!(self.ndims.is_none_or(|ndims| ndims == 0)); + assert_eq!(self.ndims, 0); // We have to put the value on the stack to get a data pointer. 
let data = ctx.builder.build_alloca(value.get_type(), "construct_unsized").unwrap(); @@ -378,17 +407,37 @@ impl<'ctx> NDArrayType<'ctx> { .build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - let ndarray = Self::new_unsized(generator, ctx.ctx, value.get_type()) - .construct_uninitialized(generator, ctx, name); + let ndarray = + Self::new_unsized(ctx, value.get_type()).construct_uninitialized(generator, ctx, name); ctx.builder.build_store(ndarray.ptr_to_data(ctx), data).unwrap(); ndarray } /// Converts an existing value into a [`NDArrayValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.dtype, + self.ndims, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`NDArrayValue`]. 
+ #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( @@ -402,64 +451,56 @@ impl<'ctx> NDArrayType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDArrayValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ndarray_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDArray", + &[], + ) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Base { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + 
self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> StructProxyType<'ctx> for NDArrayType<'ctx> { + type StructFields = NDArrayStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index c9b6b7d..aec1a6f 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -11,10 +11,12 @@ use nac3core_derive::StructFields; use super::ProxyType; use crate::codegen::{ irrt, - types::structure::{check_struct_type_matches_fields, StructField, StructFields}, + types::structure::{ + check_struct_type_matches_fields, StructField, StructFields, StructProxyType, + }, values::{ ndarray::{NDArrayValue, NDIterValue}, - ArraySliceValue, ProxyValue, + ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAdapter, }, CodeGenContext, CodeGenerator, }; @@ -44,39 +46,12 @@ pub struct NDIterStructFields<'ctx> { } impl<'ctx> NDIterType<'ctx> { - /// Checks whether `llvm_ty` represents a `nditer` type, returning [Err] if it does not. 
- pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ty else { - return Err(format!("Expected struct type for `NDIter` type, got {llvm_ty}")); - }; - - check_struct_type_matches_fields( - Self::fields(ctx, llvm_usize), - llvm_ndarray_ty, - "NDIter", - &[], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> NDIterStructFields<'ctx> { NDIterStructFields::new(ctx, llvm_usize) } - /// See [`NDIterType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDIterStructFields<'ctx> { - Self::fields(ctx, self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of an `NDIter`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { @@ -86,19 +61,37 @@ impl<'ctx> NDIterType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`NDIter`]. - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_nditer = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_nditer, llvm_usize } } + /// Creates an instance of [`NDIter`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`NDIter`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + + /// Creates an [`NDIterType`] from a [`StructType`] representing an `NDIter`. 
+ #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + /// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } } @@ -109,8 +102,31 @@ impl<'ctx> NDIterType<'ctx> { self.llvm_usize } + /// Allocates an instance of [`NDIterValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. #[must_use] - pub fn alloca( + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + parent, + indices, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`NDIterValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. 
+ #[must_use] + pub fn alloca_var( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -119,7 +135,7 @@ impl<'ctx> NDIterType<'ctx> { name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), + self.raw_alloca_var(generator, ctx, name), parent, indices, self.llvm_usize, @@ -135,30 +151,48 @@ impl<'ctx> NDIterType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> >::Value { - let nditer = self.raw_alloca(generator, ctx, None); - let ndims = ndarray.load_ndims(ctx); + let nditer = self.raw_alloca_var(generator, ctx, None); + let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims(), false); // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = generator.gen_array_var_alloc(ctx, self.llvm_usize.into(), ndims, None).unwrap(); + let indices = + TypedArrayLikeAdapter::from(indices, |_, _, v| v.into_int_value(), |_, _, v| v.into()); - let nditer = >::Value::from_pointer_value( - nditer, - ndarray, - indices, - self.llvm_usize, - None, - ); + let nditer = + self.map_pointer_value(nditer, ndarray, indices.as_slice_value(ctx, generator), None); - irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, indices); + irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, &indices); nditer } #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + parent, + indices, + self.llvm_usize, + name, + ) + } + + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, parent: NDArrayValue<'ctx>, indices: ArraySliceValue<'ctx>, name: Option<&'ctx str>, @@ -174,64 +208,56 @@ impl<'ctx> 
NDIterType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDIterValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ty else { + return Err(format!("Expected struct type for `NDIter` type, got {llvm_ty}")); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDIter", + &[], + ) } - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Base { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> StructProxyType<'ctx> for 
NDIterType<'ctx> { + type StructFields = NDIterStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index e704455..e8f6f4d 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -1,26 +1,167 @@ use inkwell::{ context::Context, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + types::{AnyTypeEnum, ArrayType, BasicType, BasicTypeEnum, IntType, PointerType}, + values::{ArrayValue, PointerValue}, AddressSpace, }; use super::ProxyType; -use crate::codegen::{ - values::{ArraySliceValue, ProxyValue, RangeValue}, - {CodeGenContext, CodeGenerator}, +use crate::{ + codegen::{ + values::RangeValue, + {CodeGenContext, CodeGenerator}, + }, + typecheck::typedef::{Type, TypeEnum}, }; /// Proxy type for a `range` type in LLVM. #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct RangeType<'ctx> { ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, } impl<'ctx> RangeType<'ctx> { - /// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not. - pub fn is_representable(llvm_ty: PointerType<'ctx>) -> Result<(), String> { - let llvm_range_ty = llvm_ty.get_element_type(); + /// Creates an LLVM type corresponding to the expected structure of a `Range`. + #[must_use] + fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> { + // typedef int32_t Range[3]; + let llvm_i32 = ctx.i32_type(); + llvm_i32.array_type(3).ptr_type(AddressSpace::default()) + } + + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { + let llvm_range = Self::llvm_type(ctx); + + RangeType { ty: llvm_range, llvm_usize } + } + + /// Creates an instance of [`RangeType`]. 
+ #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`RangeType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + + /// Creates an [`RangeType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type(ctx: &mut CodeGenContext<'ctx, '_>, ty: Type) -> Self { + // Check unifier type + assert!( + matches!(&*ctx.unifier.get_ty_immutable(ty), TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap()) + ); + + Self::new(ctx) + } + + /// Creates an [`RangeType`] from a [`ArrayType`]. + #[must_use] + pub fn from_array_type(arr_ty: ArrayType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(arr_ty.ptr_type(AddressSpace::default()), llvm_usize) + } + + /// Creates an [`RangeType`] from a [`PointerType`]. + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); + + RangeType { ty: ptr_ty, llvm_usize } + } + + /// Returns the type of all fields of this `range` type. + #[must_use] + pub fn value_type(&self) -> IntType<'ctx> { + self.as_abi_type().get_element_type().into_array_type().get_element_type().into_int_type() + } + + /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. + #[must_use] + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. 
+ #[must_use] + pub fn alloca_var( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca_var(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`RangeValue`]. + #[must_use] + pub fn map_array_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: ArrayValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_array_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`RangeValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { + type ABI = PointerType<'ctx>; + type Base = PointerType<'ctx>; + type Value = RangeValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(ty: Self::Base, _: IntType<'ctx>) -> Result<(), String> { + let llvm_range_ty = ty.get_element_type(); let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else { return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")); }; @@ -47,120 +188,17 @@ impl<'ctx> RangeType<'ctx> { Ok(()) } - /// Creates an LLVM type corresponding to the expected structure of a `Range`. - #[must_use] - fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> { - // typedef int32_t Range[3]; - let llvm_i32 = ctx.i32_type(); - llvm_i32.array_type(3).ptr_type(AddressSpace::default()) - } - - /// Creates an instance of [`RangeType`]. 
- #[must_use] - pub fn new(ctx: &'ctx Context) -> Self { - let llvm_range = Self::llvm_type(ctx); - - RangeType::from_type(llvm_range) - } - - /// Creates an [`RangeType`] from a [`PointerType`]. - #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty).is_ok()); - - RangeType { ty: ptr_ty } - } - - /// Returns the type of all fields of this `range` type. - #[must_use] - pub fn value_type(&self) -> IntType<'ctx> { - self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() - } - - /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. - #[must_use] - pub fn alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Value { - >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), - name, - ) - } - - /// Converts an existing value into a [`RangeValue`]. - #[must_use] - pub fn map_value( - &self, - value: <>::Value as ProxyValue<'ctx>>::Base, - name: Option<&'ctx str>, - ) -> >::Value { - >::Value::from_pointer_value(value, name) - } -} - -impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { - type Base = PointerType<'ctx>; - type Value = RangeValue<'ctx>; - - fn is_type( - generator: &G, - ctx: &'ctx Context, - llvm_ty: impl BasicType<'ctx>, - ) -> Result<(), String> { - if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) - } else { - Err(format!("Expected pointer type, got {llvm_ty:?}")) - } - } - - fn is_representable( - _: &G, - _: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty) - } - - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Base { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - 
&self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs index 4d6dcaf..d2622b0 100644 --- a/nac3core/src/codegen/types/structure.rs +++ b/nac3core/src/codegen/types/structure.rs @@ -2,12 +2,55 @@ use std::marker::PhantomData; use inkwell::{ context::AsContextRef, - types::{BasicTypeEnum, IntType, StructType}, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, + types::{BasicTypeEnum, IntType, PointerType, StructType}, + values::{AggregateValueEnum, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, + AddressSpace, }; +use itertools::Itertools; +use super::ProxyType; use crate::codegen::CodeGenContext; +/// A LLVM type that is used to represent a corresponding structure-like type in NAC3. +pub trait StructProxyType<'ctx>: ProxyType<'ctx, Base = PointerType<'ctx>> { + /// The concrete type of [`StructFields`]. + type StructFields: StructFields<'ctx>; + + /// Whether this [`StructProxyType`] has the same LLVM type representation as + /// [`llvm_ty`][StructType]. + fn has_same_struct_repr( + llvm_ty: StructType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + Self::has_same_pointer_repr(llvm_ty.ptr_type(AddressSpace::default()), llvm_usize) + } + + /// Whether this [`StructProxyType`] has the same LLVM type representation as + /// [`llvm_ty`][PointerType]. 
+ fn has_same_pointer_repr( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + Self::has_same_repr(llvm_ty, llvm_usize) + } + + /// Returns the fields present in this [`StructProxyType`]. + #[must_use] + fn get_fields(&self) -> Self::StructFields; + + /// Returns the [`StructType`]. + #[must_use] + fn get_struct_type(&self) -> StructType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() + } + + /// Returns the [`PointerType`] representing this type. + #[must_use] + fn get_pointer_type(&self) -> PointerType<'ctx> { + self.as_base_type() + } +} + /// Trait indicating that the structure is a field-wise representation of an LLVM structure. /// /// # Usage @@ -55,6 +98,20 @@ pub trait StructFields<'ctx>: Eq + Copy { { self.into_vec().into_iter() } + + /// Returns the field index of a field in this structure. + fn index_of_field(&self, name: impl FnOnce(&Self) -> StructField<'ctx, V>) -> u32 + where + V: BasicValue<'ctx> + TryFrom, Error = ()>, + { + let field_name = name(self).name; + self.index_of_field_name(field_name).unwrap() + } + + /// Returns the field index of a field with the given name in this structure. + fn index_of_field_name(&self, field_name: &str) -> Option { + self.iter().find_position(|(name, _)| *name == field_name).map(|(idx, _)| idx as u32) + } } /// A single field of an LLVM structure. @@ -146,17 +203,38 @@ where /// Gets the value of this field for a given `obj`. #[must_use] - pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value { - obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap() + pub fn extract_value(&self, ctx: &CodeGenContext<'ctx, '_>, obj: StructValue<'ctx>) -> Value { + Value::try_from( + ctx.builder + .build_extract_value( + obj, + self.index, + &format!("{}.{}", obj.get_name().to_str().unwrap(), self.name), + ) + .unwrap(), + ) + .unwrap() } /// Sets the value of this field for a given `obj`. 
- pub fn set_for_value(&self, obj: StructValue<'ctx>, value: Value) { - obj.set_field_at_index(self.index, value); + #[must_use] + pub fn insert_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + obj: StructValue<'ctx>, + value: Value, + ) -> StructValue<'ctx> { + let obj_name = obj.get_name().to_str().unwrap(); + let new_obj_name = if obj_name.chars().all(char::is_numeric) { "" } else { obj_name }; + + ctx.builder + .build_insert_value(obj, value, self.index, new_obj_name) + .map(AggregateValueEnum::into_struct_value) + .unwrap() } - /// Gets the value of this field for a pointer-to-structure. - pub fn get( + /// Loads the value of this field for a pointer-to-structure. + pub fn load( &self, ctx: &CodeGenContext<'ctx, '_>, pobj: PointerValue<'ctx>, @@ -172,8 +250,8 @@ where .unwrap() } - /// Sets the value of this field for a pointer-to-structure. - pub fn set( + /// Stores the value of this field for a pointer-to-structure. + pub fn store( &self, ctx: &CodeGenContext<'ctx, '_>, pobj: PointerValue<'ctx>, diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs new file mode 100644 index 0000000..ea66feb --- /dev/null +++ b/nac3core/src/codegen/types/tuple.rs @@ -0,0 +1,213 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{BasicValueEnum, PointerValue, StructValue}, +}; +use itertools::Itertools; + +use super::ProxyType; +use crate::{ + codegen::{values::TupleValue, CodeGenContext, CodeGenerator}, + typecheck::typedef::{Type, TypeEnum}, +}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct TupleType<'ctx> { + ty: StructType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +impl<'ctx> TupleType<'ctx> { + /// Creates an LLVM type corresponding to the expected structure of a tuple. 
+ #[must_use] + fn llvm_type(ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>]) -> StructType<'ctx> { + ctx.struct_type(tys, false) + } + + fn new_impl( + ctx: &'ctx Context, + tys: &[BasicTypeEnum<'ctx>], + llvm_usize: IntType<'ctx>, + ) -> Self { + let llvm_tuple = Self::llvm_type(ctx, tys); + + Self { ty: llvm_tuple, llvm_usize } + } + + /// Creates an instance of [`TupleType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, tys: &[impl BasicType<'ctx>]) -> Self { + Self::new_impl( + ctx.ctx, + &tys.iter().map(BasicType::as_basic_type_enum).collect_vec(), + ctx.get_size_type(), + ) + } + + /// Creates an instance of [`TupleType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + tys: &[BasicTypeEnum<'ctx>], + ) -> Self { + Self::new_impl(ctx, tys, generator.get_size_type(ctx)) + } + + /// Creates an [`TupleType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + let llvm_usize = ctx.get_size_type(); + + // Sanity check on object type. + let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else { + panic!("Expected type to be a TypeEnum::TTuple, got {}", ctx.unifier.stringify(ty)); + }; + + let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec(); + Self { ty: Self::llvm_type(ctx.ctx, &llvm_tys), llvm_usize } + } + + /// Creates an [`TupleType`] from a [`StructType`]. + #[must_use] + pub fn from_struct_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(struct_ty, llvm_usize).is_ok()); + + TupleType { ty: struct_ty, llvm_usize } + } + + /// Creates an [`TupleType`] from a [`PointerType`]. 
+ #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_struct_type(ptr_ty.get_element_type().into_struct_type(), llvm_usize) + } + + /// Returns the number of elements present in this [`TupleType`]. + #[must_use] + pub fn num_elements(&self) -> u32 { + self.ty.count_fields() + } + + /// Returns the type of the tuple element at the given `index`, or [`None`] if `index` is out of + /// range. + #[must_use] + pub fn type_at_index(&self, index: u32) -> Option> { + if index < self.num_elements() { + Some(unsafe { self.type_at_index_unchecked(index) }) + } else { + None + } + } + + /// Returns the type of the tuple element at the given `index`. + /// + /// # Safety + /// + /// The caller must ensure that the index is valid. + #[must_use] + pub unsafe fn type_at_index_unchecked(&self, index: u32) -> BasicTypeEnum<'ctx> { + self.ty.get_field_type_at_index_unchecked(index) + } + + /// Constructs a [`TupleValue`] from this type by zero-initializing the tuple value. + #[must_use] + pub fn construct( + &self, + ctx: &CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + self.map_struct_value( + Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), + name, + ) + } + + /// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of + /// objects. + #[must_use] + pub fn construct_from_objects>>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + objects: I, + name: Option<&'ctx str>, + ) -> >::Value { + let values = objects.into_iter().collect_vec(); + + assert_eq!(values.len(), self.num_elements() as usize); + assert!(values + .iter() + .enumerate() + .all(|(i, v)| { v.get_type() == unsafe { self.type_at_index_unchecked(i as u32) } })); + + let mut value = self.construct(ctx, name); + for (i, val) in values.into_iter().enumerate() { + value.store_element(ctx, i as u32, val); + } + + value + } + + /// Converts an existing value into a [`ListValue`]. 
+ #[must_use] + pub fn map_struct_value( + &self, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value(value, self.llvm_usize, name) + } + + /// Converts an existing value into a [`TupleValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(ctx, value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { + type ABI = StructType<'ctx>; + type Base = StructType<'ctx>; + type Value = TupleValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected struct type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(_: Self::Base, _: IntType<'ctx>) -> Result<(), String> { + Ok(()) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> From> for StructType<'ctx> { + fn from(value: TupleType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index aba0efb..e43ac74 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context, ContextRef}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -12,10 +12,11 @@ use crate::codegen::{ types::{ structure::{ 
check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + StructProxyType, }, ProxyType, }, - values::{utils::SliceValue, ArraySliceValue, ProxyValue}, + values::utils::SliceValue, CodeGenContext, CodeGenerator, }; @@ -27,7 +28,7 @@ pub struct SliceType<'ctx> { } #[derive(PartialEq, Eq, Clone, Copy, StructFields)] -pub struct SliceFields<'ctx> { +pub struct SliceStructFields<'ctx> { #[value_type(bool_type())] pub start_defined: StructField<'ctx, IntValue<'ctx>>, #[value_type(usize)] @@ -42,14 +43,14 @@ pub struct SliceFields<'ctx> { pub step: StructField<'ctx, IntValue<'ctx>>, } -impl<'ctx> SliceFields<'ctx> { - /// Creates a new instance of [`SliceFields`] with a custom integer type for its range values. +impl<'ctx> SliceStructFields<'ctx> { + /// Creates a new instance of [`SliceStructFields`] with a custom integer type for its range values. #[must_use] pub fn new_sized(ctx: &impl AsContextRef<'ctx>, int_ty: IntType<'ctx>) -> Self { let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) }; let mut counter = FieldIndexCounter::default(); - SliceFields { + SliceStructFields { start_defined: StructField::create(&mut counter, "start_defined", ctx.bool_type()), start: StructField::create(&mut counter, "start", int_ty), stop_defined: StructField::create(&mut counter, "stop_defined", ctx.bool_type()), @@ -61,16 +62,173 @@ impl<'ctx> SliceFields<'ctx> { } impl<'ctx> SliceType<'ctx> { - /// Checks whether `llvm_ty` represents a `slice` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, + /// Creates an LLVM type corresponding to the expected structure of a `Slice`. 
+ #[must_use] + fn llvm_type(ctx: &'ctx Context, int_ty: IntType<'ctx>) -> PointerType<'ctx> { + let field_tys = SliceStructFields::new_sized(&int_ty.get_context(), int_ty) + .into_iter() + .map(|field| field.1) + .collect_vec(); + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + + fn new_impl(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let llvm_ty = Self::llvm_type(ctx, int_ty); + + Self { ty: llvm_ty, int_ty, llvm_usize } + } + + /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, int_ty: IntType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, int_ty, ctx.get_size_type()) + } + + /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + int_ty: IntType<'ctx>, + ) -> Self { + Self::new_impl(ctx, int_ty, generator.get_size_type(ctx)) + } + + /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. + #[must_use] + pub fn new_usize(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type(), ctx.get_size_type()) + } + + /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. + #[must_use] + pub fn new_usize_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx), generator.get_size_type(ctx)) + } + + /// Creates an [`SliceType`] from a [`StructType`] representing a `slice`. + #[must_use] + pub fn from_struct_type( + ty: StructType<'ctx>, + int_ty: IntType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), int_ty, llvm_usize) + } + + /// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`. 
+ #[must_use] + pub fn from_pointer_type( + ptr_ty: PointerType<'ctx>, + int_ty: IntType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, int_ty).is_ok()); + + Self { ty: ptr_ty, int_ty, llvm_usize } + } + + #[must_use] + pub fn element_type(&self) -> IntType<'ctx> { + self.int_ty + } + + /// Allocates an instance of [`SliceValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. + #[must_use] + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.int_ty, + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`SliceValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca_var(generator, ctx, name), + self.int_ty, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`SliceValue`]. + #[must_use] + pub fn map_struct_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.int_ty, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ContiguousNDArrayValue`]. 
+ #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + value, + self.int_ty, + self.llvm_usize, + name, + ) + } +} + +impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { + type ABI = PointerType<'ctx>; + type Base = PointerType<'ctx>; + type Value = SliceValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, llvm_usize: IntType<'ctx>, ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } - let fields = SliceFields::new(ctx, llvm_usize); + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); - let llvm_ty = llvm_ty.get_element_type(); + let fields = SliceStructFields::new(ctx, llvm_usize); + + let llvm_ty = ty.get_element_type(); let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { return Err(format!("Expected struct type for `Slice` type, got {llvm_ty}")); }; @@ -105,146 +263,25 @@ impl<'ctx> SliceType<'ctx> { ) } - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self) -> SliceFields<'ctx> { - SliceFields::new_sized(&self.int_ty.get_context(), self.int_ty) - } - - /// Creates an LLVM type corresponding to the expected structure of a `Slice`. - #[must_use] - fn llvm_type(ctx: &'ctx Context, int_ty: IntType<'ctx>) -> PointerType<'ctx> { - let field_tys = SliceFields::new_sized(&int_ty.get_context(), int_ty) - .into_iter() - .map(|field| field.1) - .collect_vec(); - - ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) - } - - /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. 
- #[must_use] - pub fn new(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - let llvm_ty = Self::llvm_type(ctx, int_ty); - - Self { ty: llvm_ty, int_ty, llvm_usize } - } - - /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. - #[must_use] - pub fn new_usize(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); - Self::new(ctx, llvm_usize, llvm_usize) - } - - /// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`. - #[must_use] - pub fn from_type( - ptr_ty: PointerType<'ctx>, - int_ty: IntType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Self { - debug_assert!(Self::is_representable(ptr_ty, int_ty).is_ok()); - - Self { ty: ptr_ty, int_ty, llvm_usize } - } - - #[must_use] - pub fn element_type(&self) -> IntType<'ctx> { - self.int_ty - } - - /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type. - #[must_use] - pub fn alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Value { - >::Value::from_pointer_value( - self.raw_alloca(generator, ctx, name), - self.int_ty, - self.llvm_usize, - name, - ) - } - - /// Converts an existing value into a [`ContiguousNDArrayValue`]. 
- #[must_use] - pub fn map_value( - &self, - value: <>::Value as ProxyValue<'ctx>>::Base, - name: Option<&'ctx str>, - ) -> >::Value { - >::Value::from_pointer_value( - value, - self.int_ty, - self.llvm_usize, - name, - ) - } -} - -impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { - type Base = PointerType<'ctx>; - type Value = SliceValue<'ctx>; - - fn is_type( - generator: &G, - ctx: &'ctx Context, - llvm_ty: impl BasicType<'ctx>, - ) -> Result<(), String> { - if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) - } else { - Err(format!("Expected pointer type, got {llvm_ty:?}")) - } - } - - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) - } - - fn raw_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Base { - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap() - } - - fn array_alloca( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> StructProxyType<'ctx> for SliceType<'ctx> { + type StructFields = SliceStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + SliceStructFields::new_sized(&self.ty.get_context(), self.int_ty) + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/values/array.rs 
b/nac3core/src/codegen/values/array.rs index 78975f0..9f6652b 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -51,8 +51,8 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> { /// This function should be called with a valid index. unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, name: Option<&str>, ) -> PointerValue<'ctx>; @@ -76,8 +76,8 @@ pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: /// This function should be called with a valid index. unsafe fn get_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, name: Option<&str>, ) -> BasicValueEnum<'ctx> { @@ -107,8 +107,8 @@ pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: /// This function should be called with a valid index. unsafe fn set_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, value: BasicValueEnum<'ctx>, ) { @@ -130,32 +130,33 @@ pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: } /// An array-like value that can have its array elements accessed as an arbitrary type `T`. -pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: +pub trait TypedArrayLikeAccessor<'ctx, G: CodeGenerator + ?Sized, T, Index = IntValue<'ctx>>: UntypedArrayLikeAccessor<'ctx, Index> { /// Casts an element from [`BasicValueEnum`] into `T`. fn downcast_to_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: BasicValueEnum<'ctx>, ) -> T; /// # Safety /// /// This function should be called with a valid index. 
- unsafe fn get_typed_unchecked( + unsafe fn get_typed_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, name: Option<&str>, ) -> T { let value = unsafe { self.get_unchecked(ctx, generator, idx, name) }; - self.downcast_to_type(ctx, value) + self.downcast_to_type(ctx, generator, value) } /// Returns the data at the `idx`-th index. - fn get_typed( + fn get_typed( &self, ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G, @@ -163,62 +164,63 @@ pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: name: Option<&str>, ) -> T { let value = self.get(ctx, generator, idx, name); - self.downcast_to_type(ctx, value) + self.downcast_to_type(ctx, generator, value) } } /// An array-like value that can have its array elements mutated as an arbitrary type `T`. -pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: +pub trait TypedArrayLikeMutator<'ctx, G: CodeGenerator + ?Sized, T, Index = IntValue<'ctx>>: UntypedArrayLikeMutator<'ctx, Index> { /// Casts an element from T into [`BasicValueEnum`]. fn upcast_from_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: T, ) -> BasicValueEnum<'ctx>; /// # Safety /// /// This function should be called with a valid index. - unsafe fn set_typed_unchecked( + unsafe fn set_typed_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &Index, value: T, ) { - let value = self.upcast_from_type(ctx, value); + let value = self.upcast_from_type(ctx, generator, value); unsafe { self.set_unchecked(ctx, generator, idx, value) } } /// Sets the data at the `idx`-th index. 
- fn set_typed( + fn set_typed( &self, ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G, idx: &Index, value: T, ) { - let value = self.upcast_from_type(ctx, value); + let value = self.upcast_from_type(ctx, generator, value); self.set(ctx, generator, idx, value); } } -/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`. -type ValueDowncastFn<'ctx, T> = - Box, BasicValueEnum<'ctx>) -> T + 'ctx>; -/// Type alias for a function that casts a `T` into a [`BasicValueEnum`]. -type ValueUpcastFn<'ctx, T> = Box, T) -> BasicValueEnum<'ctx>>; - /// An adapter for constraining untyped array values as typed values. -pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>> { +#[derive(Copy, Clone)] +pub struct TypedArrayLikeAdapter< + 'ctx, + G: CodeGenerator + ?Sized, + T, + Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>, +> { adapted: Adapted, - downcast_fn: ValueDowncastFn<'ctx, T>, - upcast_fn: ValueUpcastFn<'ctx, T>, + downcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, BasicValueEnum<'ctx>) -> T, + upcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, T) -> BasicValueEnum<'ctx>, } -impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Adapted> TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: ArrayLikeValue<'ctx>, { @@ -229,61 +231,70 @@ where /// * `upcast_fn` - The function converting a T into a [`BasicValueEnum`]. 
pub fn from( adapted: Adapted, - downcast_fn: ValueDowncastFn<'ctx, T>, - upcast_fn: ValueUpcastFn<'ctx, T>, + downcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, BasicValueEnum<'ctx>) -> T, + upcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, T) -> BasicValueEnum<'ctx>, ) -> Self { TypedArrayLikeAdapter { adapted, downcast_fn, upcast_fn } } } -impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Adapted> ArrayLikeValue<'ctx> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: ArrayLikeValue<'ctx>, { - fn element_type( + fn element_type( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + generator: &CG, ) -> AnyTypeEnum<'ctx> { self.adapted.element_type(ctx, generator) } - fn base_ptr( + fn base_ptr( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + generator: &CG, ) -> PointerValue<'ctx> { self.adapted.base_ptr(ctx, generator) } - fn size( + fn size( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + generator: &CG, ) -> IntValue<'ctx> { self.adapted.size(ctx, generator) } + + fn as_slice_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &CG, + ) -> ArraySliceValue<'ctx> { + self.adapted.as_slice_value(ctx, generator) + } } -impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: ArrayLikeIndexer<'ctx, Index>, { - unsafe fn ptr_offset_unchecked( + unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &CG, idx: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { unsafe { self.adapted.ptr_offset_unchecked(ctx, generator, idx, name) } } - fn ptr_offset( + fn ptr_offset( &self, ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + generator: &mut CG, idx: &Index, 
name: Option<&str>, ) -> PointerValue<'ctx> { @@ -291,44 +302,46 @@ where } } -impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeAccessor<'ctx, Index>, { } -impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeMutator<'ctx, Index>, { } -impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, G, T, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeAccessor<'ctx, Index>, { fn downcast_to_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: BasicValueEnum<'ctx>, ) -> T { - (self.downcast_fn)(ctx, value) + (self.downcast_fn)(ctx, generator, value) } } -impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> +impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> TypedArrayLikeMutator<'ctx, G, T, Index> + for TypedArrayLikeAdapter<'ctx, G, T, Adapted> where Adapted: UntypedArrayLikeMutator<'ctx, Index>, { fn upcast_from_type( &self, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, value: T, ) -> BasicValueEnum<'ctx> { - (self.upcast_fn)(ctx, value) + (self.upcast_fn)(ctx, generator, value) } } @@ -384,12 +397,12 @@ impl<'ctx> ArrayLikeValue<'ctx> for ArraySliceValue<'ctx> { impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { unsafe fn ptr_offset_unchecked( 
&self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); + let var_name = name.or(self.2).map(|v| format!("{v}.addr")).unwrap_or_default(); unsafe { ctx.builder @@ -405,7 +418,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); + debug_assert_eq!(idx.get_type(), ctx.get_size_type()); let size = self.size(ctx, generator); let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 7b1975f..453065b 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -1,14 +1,18 @@ use inkwell::{ types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, }; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ - types::ListType, + types::{ + structure::{StructField, StructProxyType}, + ListType, ProxyType, + }, {CodeGenContext, CodeGenerator}, }; @@ -21,13 +25,24 @@ pub struct ListValue<'ctx> { } impl<'ctx> ListValue<'ctx> { - /// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, + /// Creates an [`ListValue`] from a [`PointerValue`]. 
+ #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - ListType::is_representable(value.get_type(), llvm_usize) + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) } /// Creates an [`ListValue`] from a [`PointerValue`]. @@ -37,53 +52,25 @@ impl<'ctx> ListValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); ListValue { value: ptr, llvm_usize, name } } - /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` - /// on the field. - fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Returns the pointer to the field storing the size of this `list`. - fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } + fn items_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().items } /// Stores the array of data elements `data` into this instance. 
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - ctx.builder.build_store(self.pptr_to_data(ctx), data).unwrap(); + self.items_field().store(ctx, self.value, data, self.name); } /// Convenience method for creating a new array storing data elements with the given element /// type `elem_ty` and `size`. /// - /// If `size` is [None], the size stored in the field of this instance is used instead. + /// If `size` is [None], the size stored in the field of this instance is used instead. If + /// `size` is resolved to `0` at runtime, `(T*) 0` will be assigned to `data`. pub fn create_data( &self, ctx: &mut CodeGenContext<'ctx, '_>, @@ -114,47 +101,60 @@ impl<'ctx> ListValue<'ctx> { ListDataProxy(self) } - /// Stores the `size` of this `list` into this instance. - pub fn store_size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - size: IntValue<'ctx>, - ) { - debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx)); + fn len_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().len + } - let psize = self.ptr_to_size(ctx); - ctx.builder.build_store(psize, size).unwrap(); + /// Stores the `size` of this `list` into this instance. + pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) { + debug_assert_eq!(size.get_type(), ctx.get_size_type()); + + self.len_field().store(ctx, self.value, size, self.name); } /// Returns the size of this `list` as a value. 
- pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { - let psize = self.ptr_to_size(ctx); - let var_name = name - .map(ToString::to_string) - .or_else(|| self.name.map(|v| format!("{v}.size"))) - .unwrap_or_default(); + pub fn load_size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> IntValue<'ctx> { + self.len_field().load(ctx, self.value, name) + } - ctx.builder - .build_load(psize, var_name.as_str()) - .map(BasicValueEnum::into_int_value) - .unwrap() + /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. + #[must_use] + pub fn as_i8_list(&self, ctx: &CodeGenContext<'ctx, '_>) -> ListValue<'ctx> { + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_list_i8 = ::Type::new(ctx, &llvm_i8); + + Self::from_pointer_value( + ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_abi_type(), "").unwrap(), + self.llvm_usize, + self.name, + ) } } impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ListType<'ctx>; fn get_type(&self) -> Self::Type { - ListType::from_type(self.as_base_value().get_type(), self.llvm_usize) + ListType::from_pointer_type(self.as_base_value().get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } +impl<'ctx> StructProxyValue<'ctx> for ListValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: ListValue<'ctx>) -> Self { value.as_base_value() @@ -179,12 +179,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.pptr_to_data(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() + self.0.items_field().load(ctx, 
self.0.value, self.0.name) } fn size( @@ -199,8 +194,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -220,7 +215,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); + debug_assert_eq!(idx.get_type(), ctx.get_size_type()); let size = self.size(ctx, generator); let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 032f041..90f327e 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -1,42 +1,33 @@ -use inkwell::{context::Context, values::BasicValue}; +use inkwell::{types::IntType, values::BasicValue}; -use super::types::ProxyType; -use crate::codegen::CodeGenerator; +use super::{types::ProxyType, CodeGenContext}; pub use array::*; pub use list::*; pub use range::*; +pub use tuple::*; mod array; mod list; pub mod ndarray; mod range; +pub mod structure; +mod tuple; pub mod utils; /// A LLVM type that is used to represent a non-primitive value in NAC3. pub trait ProxyValue<'ctx>: Into { - /// The type of LLVM values represented by this instance. This is usually the - /// [LLVM pointer type][PointerValue]. + /// The ABI type of LLVM values represented by this instance. + type ABI: BasicValue<'ctx>; + + /// The type of LLVM values represented by this instance. type Base: BasicValue<'ctx>; /// The type of this value. type Type: ProxyType<'ctx, Value = Self>; /// Checks whether `value` can be represented by this [`ProxyValue`]. 
- fn is_instance( - generator: &G, - ctx: &'ctx Context, - value: impl BasicValue<'ctx>, - ) -> Result<(), String> { - Self::Type::is_type(generator, ctx, value.as_basic_value_enum().get_type()) - } - - /// Checks whether `value` can be represented by this [`ProxyValue`]. - fn is_representable( - generator: &G, - ctx: &'ctx Context, - value: Self::Base, - ) -> Result<(), String> { - Self::is_instance(generator, ctx, value.as_basic_value_enum()) + fn is_instance(value: impl BasicValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { + Self::Type::is_representable(value.as_basic_value_enum().get_type(), llvm_usize) } /// Returns the [type][ProxyType] of this value. @@ -44,4 +35,10 @@ pub trait ProxyValue<'ctx>: Into { /// Returns the [base value][Self::Base] of this proxy. fn as_base_value(&self) -> Self::Base; + + /// Returns this proxy as its ABI value, i.e. the expected value representation if a value + /// represented by this [`ProxyValue`] is being passed into or returned from a function. + /// + /// See [`CodeGenContext::get_llvm_abi_type`]. 
+ fn as_abi_value(&self, ctx: &CodeGenContext<'ctx, '_>) -> Self::ABI; } diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs new file mode 100644 index 0000000..4935a36 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -0,0 +1,262 @@ +use inkwell::{ + types::IntType, + values::{IntValue, PointerValue, StructValue}, +}; +use itertools::Itertools; + +use crate::codegen::{ + irrt, + types::{ + ndarray::{NDArrayType, ShapeEntryType}, + structure::{StructField, StructProxyType}, + ProxyType, + }, + values::{ + ndarray::NDArrayValue, structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue, + ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, +}; + +#[derive(Copy, Clone)] +pub struct ShapeEntryValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> ShapeEntryValue<'ctx> { + /// Creates an [`ShapeEntryValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + + /// Creates an [`ShapeEntryValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); + + Self { value: ptr, llvm_usize, name } + } + + fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().ndims + } + + /// Stores the number of dimensions into this value. 
+ pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { + self.ndims_field().store(ctx, self.value, value, self.name); + } + + fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().shape + } + + /// Stores the shape into this value. + pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + self.shape_field().store(ctx, self.value, value, self.name); + } +} + +impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { + type ABI = PointerValue<'ctx>; + type Base = PointerValue<'ctx>; + type Type = ShapeEntryType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } +} + +impl<'ctx> StructProxyValue<'ctx> for ShapeEntryValue<'ctx> {} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: ShapeEntryValue<'ctx>) -> Self { + value.as_base_value() + } +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Create a broadcast view on this ndarray with a target shape. + /// + /// The input shape will be checked to make sure that it contains no negative values. + /// + /// * `target_ndims` - The ndims type after broadcasting to the given shape. + /// The caller has to figure this out for this function. + /// * `target_shape` - An array pointer pointing to the target shape. 
+ #[must_use] + pub fn broadcast_to( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target_ndims: u64, + target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) -> Self { + assert!(self.ndims <= target_ndims); + assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); + + let broadcast_ndarray = NDArrayType::new(ctx, self.dtype, target_ndims) + .construct_uninitialized(generator, ctx, None); + broadcast_ndarray.copy_shape_from_array( + generator, + ctx, + target_shape.base_ptr(ctx, generator), + ); + + irrt::ndarray::call_nac3_ndarray_broadcast_to(ctx, *self, broadcast_ndarray); + broadcast_ndarray + } +} + +/// A result produced by [`broadcast_all_ndarrays`] +#[derive(Clone)] +pub struct BroadcastAllResult<'ctx, G: CodeGenerator + ?Sized> { + /// The statically known `ndims` of the broadcast result. + pub ndims: u64, + + /// The broadcasting shape. + pub shape: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>, + + /// Broadcasted views on the inputs. + /// + /// All of them will have `shape` [`BroadcastAllResult::shape`] and + /// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector + /// is the same as the input. + pub ndarrays: Vec>, +} + +/// Helper function to call [`irrt::ndarray::call_nac3_ndarray_broadcast_shapes`]. 
+fn broadcast_shapes<'ctx, G, Shape>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_shape_entries: &[(ArraySliceValue<'ctx>, u64)], // (shape, shape's length/ndims) + broadcast_ndims: u64, + broadcast_shape: &Shape, +) where + G: CodeGenerator + ?Sized, + Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, +{ + let llvm_usize = ctx.get_size_type(); + let llvm_shape_ty = ShapeEntryType::new(ctx); + + assert!(in_shape_entries + .iter() + .all(|entry| entry.0.element_type(ctx, generator) == llvm_usize.into())); + assert_eq!(broadcast_shape.element_type(ctx, generator), llvm_usize.into()); + + // Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`. + let num_shape_entries = + llvm_usize.const_int(u64::try_from(in_shape_entries.len()).unwrap(), false); + let shape_entries = llvm_shape_ty.array_alloca(ctx, num_shape_entries, None); + for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() { + let pshape_entry = unsafe { + shape_entries.ptr_offset_unchecked( + ctx, + generator, + &llvm_usize.const_int(i as u64, false), + None, + ) + }; + let shape_entry = llvm_shape_ty.map_pointer_value(pshape_entry, None); + + let in_ndims = llvm_usize.const_int(*in_ndims, false); + shape_entry.store_ndims(ctx, in_ndims); + + shape_entry.store_shape(ctx, in_shape.base_ptr(ctx, generator)); + } + + let broadcast_ndims = llvm_usize.const_int(broadcast_ndims, false); + irrt::ndarray::call_nac3_ndarray_broadcast_shapes( + generator, + ctx, + num_shape_entries, + shape_entries, + broadcast_ndims, + broadcast_shape, + ); +} + +impl<'ctx> NDArrayType<'ctx> { + /// Broadcast all ndarrays according to + /// [`np.broadcast()`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast.html) + /// and return a [`BroadcastAllResult`] containing all the information of the result of the + /// broadcast operation. 
+ pub fn broadcast( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarrays: &[NDArrayValue<'ctx>], + ) -> BroadcastAllResult<'ctx, G> { + assert!(!ndarrays.is_empty()); + + let llvm_usize = ctx.get_size_type(); + + // Infer the broadcast output ndims. + let broadcast_ndims_int = + ndarrays.iter().map(|ndarray| ndarray.get_type().ndims()).max().unwrap(); + assert!(self.ndims() >= broadcast_ndims_int); + + let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false); + let broadcast_shape = ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, broadcast_ndims, "").unwrap(), + broadcast_ndims, + None, + ); + let broadcast_shape = TypedArrayLikeAdapter::from( + broadcast_shape, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + let shape_entries = ndarrays + .iter() + .map(|ndarray| { + (ndarray.shape().as_slice_value(ctx, generator), ndarray.get_type().ndims()) + }) + .collect_vec(); + broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape); + + // Broadcast all the inputs to shape `dst_shape`. 
+ let broadcast_ndarrays = ndarrays + .iter() + .map(|ndarray| { + ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, &broadcast_shape) + }) + .collect_vec(); + + BroadcastAllResult { + ndims: broadcast_ndims_int, + shape: broadcast_shape, + ndarrays: broadcast_ndarrays, + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 87e2f1d..9dca06a 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -1,16 +1,17 @@ use inkwell::{ types::{BasicType, BasicTypeEnum, IntType}, - values::{IntValue, PointerValue}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; -use super::{ArrayLikeValue, NDArrayValue, ProxyValue}; +use super::NDArrayValue; use crate::codegen::{ stmt::gen_if_callback, types::{ ndarray::{ContiguousNDArrayType, NDArrayType}, - structure::StructField, + structure::{StructField, StructProxyType}, }, + values::{structure::StructProxyValue, ArrayLikeValue, ProxyValue}, CodeGenContext, CodeGenerator, }; @@ -23,13 +24,25 @@ pub struct ContiguousNDArrayValue<'ctx> { } impl<'ctx> ContiguousNDArrayValue<'ctx> { - /// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is - /// not an instance. - pub fn is_representable( - value: PointerValue<'ctx>, + /// Creates an [`ContiguousNDArrayValue`] from a [`StructValue`]. 
+ #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, dtype, llvm_usize, name) } /// Creates an [`ContiguousNDArrayValue`] from a [`PointerValue`]. @@ -40,7 +53,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, item: dtype, llvm_usize, name } } @@ -50,7 +63,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.ndims_field().set(ctx, self.as_base_value(), value, self.name); + self.ndims_field().store(ctx, self.as_abi_value(ctx), value, self.name); } fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -58,11 +71,11 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.shape_field().set(ctx, self.as_base_value(), value, self.name); + self.shape_field().store(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.shape_field().get(ctx, self.value, self.name) + self.shape_field().load(ctx, self.value, self.name) } fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -70,20 +83,21 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - 
self.data_field().set(ctx, self.as_base_value(), value, self.name); + self.data_field().store(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.data_field().get(ctx, self.value, self.name) + self.data_field().load(ctx, self.value, self.name) } } impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ContiguousNDArrayType<'ctx>; fn get_type(&self) -> Self::Type { - >::Type::from_type( + >::Type::from_pointer_type( self.as_base_value().get_type(), self.item, self.llvm_usize, @@ -93,8 +107,14 @@ impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } +impl<'ctx> StructProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: ContiguousNDArrayValue<'ctx>) -> Self { value.as_base_value() @@ -117,13 +137,11 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> ContiguousNDArrayValue<'ctx> { - let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype) - .alloca(generator, ctx, self.name); + let result = + ContiguousNDArrayType::new(ctx, &self.dtype).alloca_var(generator, ctx, self.name); // Set ndims and shape. - let ndims = self - .ndims - .map_or_else(|| self.load_ndims(ctx), |ndims| self.llvm_usize.const_int(ndims, false)); + let ndims = self.llvm_usize.const_int(self.ndims, false); result.store_ndims(ctx, ndims); let shape = self.shape(); @@ -132,10 +150,10 @@ impl<'ctx> NDArrayValue<'ctx> { gen_if_callback( generator, ctx, - |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), + |_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| { // This ndarray is contiguous. 
- let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name); + let data = self.data_field().load(ctx, self.as_abi_value(ctx), self.name); let data = ctx .builder .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") @@ -180,13 +198,16 @@ impl<'ctx> NDArrayValue<'ctx> { // TODO: Debug assert `ndims == carray.ndims` to catch bugs. // Allocate the resulting ndarray. - let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, Some(ndims)) - .construct_uninitialized(generator, ctx, carray.name); + let ndarray = NDArrayType::new(ctx, carray.item, ndims).construct_uninitialized( + generator, + ctx, + carray.name, + ); // Copy shape and update strides let shape = carray.load_shape(ctx); ndarray.copy_shape_from_array(generator, ctx, shape); - ndarray.set_strides_contiguous(generator, ctx); + ndarray.set_strides_contiguous(ctx); // Share data let data = carray.load_data(ctx); diff --git a/nac3core/src/codegen/values/ndarray/fold.rs b/nac3core/src/codegen/values/ndarray/fold.rs new file mode 100644 index 0000000..7c8aebd --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/fold.rs @@ -0,0 +1,101 @@ +use inkwell::values::{BasicValue, BasicValueEnum}; + +use super::{NDArrayValue, NDIterValue, ScalarOrNDArray}; +use crate::codegen::{ + stmt::{gen_for_callback, BreakContinueHooks}, + types::ndarray::NDIterType, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayValue<'ctx> { + /// Folds the elements of this ndarray into an accumulator value by applying `f`, returning the + /// final value. + /// + /// `f` has access to [`BreakContinueHooks`] to short-circuit the `fold` operation, an instance + /// of `V` representing the current accumulated value, and an [`NDIterValue`] to get the + /// properties of the current iterated element. 
+ pub fn fold<'a, G, V, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + init: V, + f: F, + ) -> Result + where + G: CodeGenerator + ?Sized, + V: BasicValue<'ctx> + TryFrom>, + >>::Error: std::fmt::Debug, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + V, + NDIterValue<'ctx>, + ) -> Result, + { + let acc_ptr = + generator.gen_var_alloc(ctx, init.as_basic_value_enum().get_type(), None).unwrap(); + ctx.builder.build_store(acc_ptr, init).unwrap(); + + gen_for_callback( + generator, + ctx, + Some("ndarray_fold"), + |generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)), + |_, ctx, nditer| Ok(nditer.has_element(ctx)), + |generator, ctx, hooks, nditer| { + let acc = V::try_from(ctx.builder.build_load(acc_ptr, "").unwrap()).unwrap(); + let acc = f(generator, ctx, hooks, acc, nditer)?; + ctx.builder.build_store(acc_ptr, acc).unwrap(); + Ok(()) + }, + |_, ctx, nditer| { + nditer.next(ctx); + Ok(()) + }, + )?; + + let acc = ctx.builder.build_load(acc_ptr, "").unwrap(); + Ok(V::try_from(acc).unwrap()) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// See [`NDArrayValue::fold`]. + /// + /// The primary differences between this function and `NDArrayValue::fold` are: + /// + /// - The 3rd parameter of `f` is an `Option` of hooks, since `break`/`continue` hooks are not + /// available if this instance represents a scalar value. + /// - The 5th parameter of `f` is a [`BasicValueEnum`], since no [iterator][`NDIterValue`] will + /// be created if this instance represents a scalar value. 
+ pub fn fold<'a, G, V, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + init: V, + f: F, + ) -> Result + where + G: CodeGenerator + ?Sized, + V: BasicValue<'ctx> + TryFrom>, + >>::Error: std::fmt::Debug, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + Option<&BreakContinueHooks<'ctx>>, + V, + BasicValueEnum<'ctx>, + ) -> Result, + { + match self { + ScalarOrNDArray::Scalar(v) => f(generator, ctx, None, init, *v), + ScalarOrNDArray::NDArray(v) => { + v.fold(generator, ctx, init, |generator, ctx, hooks, acc, nditer| { + let elem = nditer.get_scalar(ctx); + f(generator, ctx, Some(&hooks), acc, elem) + }) + } + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 69c0080..6ed0ed0 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -1,6 +1,6 @@ use inkwell::{ types::IntType, - values::{IntValue, PointerValue}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -12,10 +12,12 @@ use crate::{ irrt, types::{ ndarray::{NDArrayType, NDIndexType}, - structure::StructField, + structure::{StructField, StructProxyType}, utils::SliceType, }, - values::{ndarray::NDArrayValue, utils::RustSlice, ProxyValue}, + values::{ + ndarray::NDArrayValue, structure::StructProxyValue, utils::RustSlice, ProxyValue, + }, CodeGenContext, CodeGenerator, }, typecheck::typedef::Type, @@ -30,13 +32,24 @@ pub struct NDIndexValue<'ctx> { } impl<'ctx> NDIndexValue<'ctx> { - /// Checks whether `value` is an instance of `ndindex`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, + /// Creates an [`NDIndexValue`] from a [`StructValue`]. 
+ #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) } /// Creates an [`NDIndexValue`] from a [`PointerValue`]. @@ -46,7 +59,7 @@ impl<'ctx> NDIndexValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, llvm_usize, name } } @@ -56,11 +69,11 @@ impl<'ctx> NDIndexValue<'ctx> { } pub fn load_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.type_field().get(ctx, self.value, self.name) + self.type_field().load(ctx, self.value, self.name) } pub fn store_type(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.type_field().set(ctx, self.value, value, self.name); + self.type_field().store(ctx, self.value, value, self.name); } fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -68,27 +81,34 @@ impl<'ctx> NDIndexValue<'ctx> { } pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.data_field().get(ctx, self.value, self.name) + self.data_field().load(ctx, self.value, self.name) } pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.data_field().set(ctx, self.value, value, self.name); + self.data_field().store(ctx, self.value, value, self.name); } } impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDIndexType<'ctx>; fn get_type(&self) -> 
Self::Type { - Self::Type::from_type(self.value.get_type(), self.llvm_usize) + Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } +impl<'ctx> StructProxyValue<'ctx> for NDIndexValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: NDIndexValue<'ctx>) -> Self { value.as_base_value() @@ -98,8 +118,8 @@ impl<'ctx> From> for PointerValue<'ctx> { impl<'ctx> NDArrayValue<'ctx> { /// Get the expected `ndims` after indexing with `indices`. #[must_use] - fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> Option { - let mut ndims = self.ndims?; + fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 { + let mut ndims = self.ndims; for index in indices { match index { @@ -113,7 +133,7 @@ impl<'ctx> NDArrayValue<'ctx> { } } - Some(ndims) + ndims } /// Index into the ndarray, and return a newly-allocated view on this ndarray. 
@@ -127,14 +147,11 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, indices: &[RustNDIndex<'ctx>], ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::index is only supported for instances with compile-time known ndims (self.ndims = Some(...))"); - let dst_ndims = self.deduce_ndims_after_indexing_with(indices); - let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims) + let dst_ndarray = NDArrayType::new(ctx, self.dtype, dst_ndims) .construct_uninitialized(generator, ctx, None); - let indices = - NDIndexType::new(generator, ctx.ctx).construct_ndindices(generator, ctx, indices); + let indices = NDIndexType::new(ctx).construct_ndindices(generator, ctx, indices); irrt::ndarray::call_nac3_ndarray_index(generator, ctx, indices, *self, dst_ndarray); dst_ndarray @@ -247,8 +264,7 @@ impl<'ctx> RustNDIndex<'ctx> { } RustNDIndex::Slice(in_rust_slice) => { let user_slice_ptr = - SliceType::new(ctx.ctx, ctx.ctx.i32_type(), generator.get_size_type(ctx.ctx)) - .alloca(generator, ctx, None); + SliceType::new(ctx, ctx.ctx.i32_type()).alloca_var(generator, ctx, None); in_rust_slice.write_to_slice(ctx, user_slice_ptr); dst_ndindex.store_data( diff --git a/nac3core/src/codegen/values/ndarray/map.rs b/nac3core/src/codegen/values/ndarray/map.rs new file mode 100644 index 0000000..72d1bf9 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/map.rs @@ -0,0 +1,69 @@ +use inkwell::{types::BasicTypeEnum, values::BasicValueEnum}; + +use crate::codegen::{ + values::{ + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ProxyValue, + }, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayValue<'ctx> { + /// Map through this ndarray with an elementwise function. 
+ pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + out: NDArrayOut<'ctx>, + mapping: Mapping, + ) -> Result + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + self.get_type().broadcast_starmap( + generator, + ctx, + &[*self], + out, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Map through this [`ScalarOrNDArray`] with an elementwise function. + /// + /// If this is a scalar, `mapping` will directly act on the scalar. This function will return a + /// [`ScalarOrNDArray::Scalar`] of that result. + /// + /// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new + /// ndarray of the results will be created and returned as a [`ScalarOrNDArray::NDArray`]. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ret_dtype: BasicTypeEnum<'ctx>, + mapping: Mapping, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[*self], + ret_dtype, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs new file mode 100644 index 0000000..f12d36c --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -0,0 +1,323 @@ +use std::cmp::max; + +use nac3parser::ast::Operator; + +use super::{NDArrayOut, NDArrayValue, RustNDIndex}; +use crate::{ + codegen::{ + expr::gen_binop_expr_with_values, + irrt, + stmt::gen_for_callback_incrementing, + types::ndarray::NDArrayType, + values::{ + ArrayLikeValue, ArraySliceValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + 
UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, + }, + toplevel::helper::arraylike_flatten_element_type, + typecheck::{magic_methods::Binop, typedef::Type}, +}; + +/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`. +/// +/// `dst_dtype` defines the dtype of the returned ndarray. +fn matmul_at_least_2d<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst_dtype: Type, + (in_a_ty, in_a): (Type, NDArrayValue<'ctx>), + (in_b_ty, in_b): (Type, NDArrayValue<'ctx>), +) -> NDArrayValue<'ctx> { + assert!(in_a.ndims >= 2, "in_a (which is {}) must be >= 2", in_a.ndims); + assert!(in_b.ndims >= 2, "in_b (which is {}) must be >= 2", in_b.ndims); + + let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); + let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); + + let llvm_usize = ctx.get_size_type(); + let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); + + // Deduce ndims of the result of matmul. + let ndims_int = max(in_a.ndims, in_b.ndims); + let ndims = llvm_usize.const_int(ndims_int, false); + + // Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the + // destination ndarray to store the result of matmul. 
+ let (lhs, rhs, dst) = { + let in_lhs_ndims = llvm_usize.const_int(in_a.ndims, false); + let in_lhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + in_a.shape().base_ptr(ctx, generator), + in_lhs_ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let in_rhs_ndims = llvm_usize.const_int(in_b.ndims, false); + let in_rhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + in_b.shape().base_ptr(ctx, generator), + in_rhs_ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let lhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let rhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let dst_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + // Matmul dimension compatibility is checked here. 
+ irrt::ndarray::call_nac3_ndarray_matmul_calculate_shapes( + generator, + ctx, + &in_lhs_shape, + &in_rhs_shape, + ndims, + &lhs_shape, + &rhs_shape, + &dst_shape, + ); + + let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); + let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); + + let dst = NDArrayType::new(ctx, llvm_dst_dtype, ndims_int) + .construct_uninitialized(generator, ctx, None); + dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); + unsafe { + dst.create_data(generator, ctx); + } + + (lhs, rhs, dst) + }; + + let len = unsafe { + lhs.shape().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(ndims_int - 1, false), + None, + ) + }; + + let at_row = i64::try_from(ndims_int - 2).unwrap(); + let at_col = i64::try_from(ndims_int - 1).unwrap(); + + let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype); + let dst_zero = dst_dtype_llvm.const_zero(); + + dst.foreach(generator, ctx, |generator, ctx, _, hdl| { + let pdst_ij = hdl.get_pointer(ctx); + + ctx.builder.build_store(pdst_ij, dst_zero).unwrap(); + + let indices = hdl.get_indices::(); + let i = unsafe { + indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_row as u64, true), None) + }; + let j = unsafe { + indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_col as u64, true), None) + }; + + let num_0 = llvm_usize.const_int(0, false); + let num_1 = llvm_usize.const_int(1, false); + + gen_for_callback_incrementing( + generator, + ctx, + None, + num_0, + (len, false), + |generator, ctx, _, k| { + // `indices` is modified to index into `a` and `b`, and restored. 
+ unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + i, + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + k.into(), + ); + } + let a_ik = unsafe { lhs.data().get_unchecked(ctx, generator, &indices, None) }; + + unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + k.into(), + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + j, + ); + } + let b_kj = unsafe { rhs.data().get_unchecked(ctx, generator, &indices, None) }; + + // Restore `indices`. + unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + i, + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + j, + ); + } + + // x = a_[...]ik * b_[...]kj + let x = gen_binop_expr_with_values( + generator, + ctx, + (&Some(lhs_dtype), a_ik), + Binop::normal(Operator::Mult), + (&Some(rhs_dtype), b_kj), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + + // dst_[...]ij += x + let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); + let dst_ij = gen_binop_expr_with_values( + generator, + ctx, + (&Some(dst_dtype), dst_ij), + Binop::normal(Operator::Add), + (&Some(dst_dtype), x), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); + + Ok(()) + }, + num_1, + ) + }) + .unwrap(); + + dst +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Perform [`np.matmul`](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). + /// + /// This function always return an [`NDArrayValue`]. You may want to use + /// [`NDArrayValue::split_unsized`] to handle when the output could be a scalar. + /// + /// `dst_dtype` defines the dtype of the returned ndarray. 
+ #[must_use] + pub fn matmul( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + self_ty: Type, + (other_ty, other): (Type, Self), + (out_dtype, out): (Type, NDArrayOut<'ctx>), + ) -> Self { + // Sanity check, but type inference should prevent this. + assert!(self.ndims > 0 && other.ndims > 0, "np.matmul disallows scalar input"); + + // If both arguments are 2-D they are multiplied like conventional matrices. + // + // If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the + // last two indices and broadcast accordingly. + // + // If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its + // dimensions. After matrix multiplication the prepended 1 is removed. + // + // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its + // dimensions. After matrix multiplication the appended 1 is removed. + + let new_a = if self.ndims == 1 { + // Prepend 1 to its dimensions + self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) + } else { + *self + }; + + let new_b = if other.ndims == 1 { + // Append 1 to its dimensions + other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) + } else { + other + }; + + // NOTE: `result` will always be a newly allocated ndarray. + // Current implementation cannot do in-place matrix muliplication. + let mut result = + matmul_at_least_2d(generator, ctx, out_dtype, (self_ty, new_a), (other_ty, new_b)); + + // Postprocessing on the result to remove prepended/appended axes. 
+ let mut postindices = vec![]; + let zero = ctx.ctx.i32_type().const_zero(); + + if self.ndims == 1 { + // Remove the prepended 1 + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if other.ndims == 1 { + // Remove the appended 1 + postindices.push(RustNDIndex::Ellipsis); + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if !postindices.is_empty() { + result = result.index(generator, ctx, &postindices); + } + + match out { + NDArrayOut::NewNDArray { .. } => result, + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => { + let result_shape = result.shape(); + out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape); + + out_ndarray.copy_data_from(ctx, result); + out_ndarray + } + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 12fd863..dcb6947 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -1,29 +1,45 @@ +use std::iter::repeat_n; + use inkwell::{ types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, }; +use itertools::Itertools; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TupleValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, + UntypedArrayLikeMutator, }; -use crate::codegen::{ - irrt, - llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, - stmt::gen_for_callback_incrementing, - type_aligned_alloca, - types::{ndarray::NDArrayType, structure::StructField}, - CodeGenContext, CodeGenerator, +use crate::{ + codegen::{ + irrt, + llvm_intrinsics::{call_int_umin, 
call_memcpy_generic_array}, + stmt::gen_for_callback_incrementing, + type_aligned_alloca, + types::{ + ndarray::NDArrayType, + structure::{StructField, StructProxyType}, + TupleType, + }, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, }; +pub use broadcast::*; pub use contiguous::*; pub use indexing::*; pub use nditer::*; -pub use view::*; +mod broadcast; mod contiguous; +mod fold; mod indexing; +mod map; +mod matmul; mod nditer; +pub mod shape; mod view; /// Proxy type for accessing an `NDArray` value in LLVM. @@ -31,19 +47,32 @@ mod view; pub struct NDArrayValue<'ctx> { value: PointerValue<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, } impl<'ctx> NDArrayValue<'ctx> { - /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, + /// Creates an [`NDArrayValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, + ndims: u64, llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - NDArrayType::is_representable(value.get_type(), llvm_usize) + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, dtype, ndims, llvm_usize, name) } /// Creates an [`NDArrayValue`] from a [`PointerValue`]. 
@@ -51,77 +80,54 @@ impl<'ctx> NDArrayValue<'ctx> { pub fn from_pointer_value( ptr: PointerValue<'ctx>, dtype: BasicTypeEnum<'ctx>, - ndims: Option, + ndims: u64, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name } } - fn ndims_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).ndims - } - - /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. - fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.ndims_field(ctx).ptr_by_gep(ctx, self.value, self.name) + fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().ndims } /// Stores the number of dimensions `ndims` into this instance. - pub fn store_ndims( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ndims: IntValue<'ctx>, - ) { - debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); + pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, ndims: IntValue<'ctx>) { + debug_assert_eq!(ndims.get_type(), ctx.get_size_type()); - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_store(pndims, ndims).unwrap(); + self.ndims_field().store(ctx, self.value, ndims, self.name); } /// Returns the number of dimensions of this `NDArray` as a value. 
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() + self.ndims_field().load(ctx, self.value, self.name) } - fn itemsize_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).itemsize + fn itemsize_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().itemsize } /// Stores the size of each element `itemsize` into this instance. - pub fn store_itemsize( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - itemsize: IntValue<'ctx>, - ) { - debug_assert_eq!(itemsize.get_type(), generator.get_size_type(ctx.ctx)); + pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) { + debug_assert_eq!(itemsize.get_type(), ctx.get_size_type()); - self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name); + self.itemsize_field().store(ctx, self.value, itemsize, self.name); } /// Returns the size of each element of this `NDArray` as a value. pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.itemsize_field(ctx).get(ctx, self.value, self.name) + self.itemsize_field().load(ctx, self.value, self.name) } - fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).shape - } - - /// Returns the double-indirection pointer to the `shape` array, as if by calling - /// `getelementptr` on the field. - fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name) + fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().shape } /// Stores the array of dimension sizes `dims` into this instance. 
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name); + self.shape_field().store(ctx, self.value, dims, self.name); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. @@ -140,22 +146,13 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayShapeProxy(self) } - fn strides_field( - &self, - ctx: &CodeGenContext<'ctx, '_>, - ) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).strides - } - - /// Returns the double-indirection pointer to the `strides` array, as if by calling - /// `getelementptr` on the field. - fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.strides_field(ctx).ptr_by_gep(ctx, self.value, self.name) + fn strides_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().strides } /// Stores the array of stride sizes `strides` into this instance. fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { - self.strides_field(ctx).set(ctx, self.as_base_value(), strides, self.name); + self.strides_field().store(ctx, self.value, strides, self.name); } /// Convenience method for creating a new array storing the stride with the given `size`. @@ -174,23 +171,23 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayStridesProxy(self) } - fn data_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).data + fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().data } /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. 
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.data_field(ctx).ptr_by_gep(ctx, self.value, self.name) + self.data_field().ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of data elements `data` into this instance. - fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { let data = ctx .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name); + self.data_field().store(ctx, self.value, data.into_pointer_value(), self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -206,12 +203,12 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) { - let nbytes = self.nbytes(generator, ctx); + let nbytes = self.nbytes(ctx); let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None); self.store_data(ctx, data); - self.set_strides_contiguous(generator, ctx); + self.set_strides_contiguous(ctx); } /// Returns a proxy object to the field storing the data of this `NDArray`. 
@@ -246,26 +243,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, ) { - if self.ndims.is_some() && src_ndarray.ndims.is_some() { - assert_eq!(self.ndims, src_ndarray.ndims); - } else { - let self_ndims = self.load_ndims(ctx); - let src_ndims = src_ndarray.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - self_ndims, - src_ndims, - "" - ).unwrap(), - "0:AssertionError", - "NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})", - [Some(self_ndims), Some(src_ndims), None], - ctx.current_loc - ); - } + assert_eq!(self.ndims, src_ndarray.ndims); let src_shape = src_ndarray.shape().base_ptr(ctx, generator); self.copy_shape_from_array(generator, ctx, src_shape); @@ -297,96 +275,57 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, ) { - if self.ndims.is_some() && src_ndarray.ndims.is_some() { - assert_eq!(self.ndims, src_ndarray.ndims); - } else { - let self_ndims = self.load_ndims(ctx); - let src_ndims = src_ndarray.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare( - IntPredicate::EQ, - self_ndims, - src_ndims, - "" - ).unwrap(), - "0:AssertionError", - "NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})", - [Some(self_ndims), Some(src_ndims), None], - ctx.current_loc - ); - } + assert_eq!(self.ndims, src_ndarray.ndims); let src_strides = src_ndarray.strides().base_ptr(ctx, generator); self.copy_strides_from_array(generator, ctx, src_strides); } /// Get the `np.size()` of this ndarray. 
- pub fn size( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_size(generator, ctx, *self) + pub fn size(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_size(ctx, *self) } /// Get the `ndarray.nbytes` of this ndarray. - pub fn nbytes( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_nbytes(generator, ctx, *self) + pub fn nbytes(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_nbytes(ctx, *self) } /// Get the `len()` of this ndarray. - pub fn len( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self) + pub fn len(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_len(ctx, *self) } /// Check if this ndarray is C-contiguous. /// /// See NumPy's `flags["C_CONTIGUOUS"]`: - pub fn is_c_contiguous( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_is_c_contiguous(generator, ctx, *self) + pub fn is_c_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_is_c_contiguous(ctx, *self) } /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. /// /// Update the ndarray's strides to make the ndarray contiguous. - pub fn set_strides_contiguous( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) { - irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self); + pub fn set_strides_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) { + irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(ctx, *self); } + /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and + /// copy the contents over. 
+ /// + /// The new ndarray will own its data and will be C-contiguous. #[must_use] pub fn make_copy( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> Self { - let clone = if self.ndims.is_some() { - self.get_type().construct_uninitialized(generator, ctx, None) - } else { - self.get_type().construct_dyn_ndims(generator, ctx, self.load_ndims(ctx), None) - }; + let clone = self.get_type().construct_uninitialized(generator, ctx, None); let shape = self.shape(); clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); unsafe { clone.create_data(generator, ctx) }; - clone.copy_data_from(generator, ctx, *self); + clone.copy_data_from(ctx, *self); clone } @@ -396,50 +335,155 @@ impl<'ctx> NDArrayValue<'ctx> { /// do not matter. The copying order is determined by how their flattened views look. /// /// Panics if the `dtype`s of ndarrays are different. - pub fn copy_data_from( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - src: NDArrayValue<'ctx>, - ) { + pub fn copy_data_from(&self, ctx: &CodeGenContext<'ctx, '_>, src: NDArrayValue<'ctx>) { assert_eq!(self.dtype, src.dtype, "self and src dtype should match"); - irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self); + irrt::ndarray::call_nac3_ndarray_copy_data(ctx, src, *self); + } + + /// Fill the ndarray with a scalar. + /// + /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. + pub fn fill( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) { + // TODO: It is possible to optimize this by exploiting contiguous strides with memset. + // Probably best to implement in IRRT. + self.foreach(generator, ctx, |_, ctx, _, nditer| { + let p = nditer.get_pointer(ctx); + ctx.builder.build_store(p, value).unwrap(); + Ok(()) + }) + .unwrap(); + } + + /// Create the shape tuple of this ndarray like + /// [`np.shape()`](https://numpy.org/doc/stable/reference/generated/numpy.shape.html). 
+ /// + /// All elements in the tuple are `i32`. + pub fn make_shape_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + + let objects = (0..self.ndims) + .map(|i| { + let dim = unsafe { + self.shape().get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(i, false), + None, + ) + }; + ctx.builder.build_int_truncate_or_bit_cast(dim, llvm_i32, "").unwrap() + }) + .map(|obj| obj.as_basic_value_enum()) + .collect_vec(); + + TupleType::new(ctx, &repeat_n(llvm_i32, self.ndims as usize).collect_vec()) + .construct_from_objects(ctx, objects, None) + } + + /// Create the strides tuple of this ndarray like + /// [`.strides`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html). + /// + /// All elements in the tuple are `i32`. + pub fn make_strides_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + + let objects = (0..self.ndims) + .map(|i| { + let dim = unsafe { + self.strides().get_typed_unchecked( + ctx, + generator, + &self.llvm_usize.const_int(i, false), + None, + ) + }; + ctx.builder.build_int_truncate_or_bit_cast(dim, llvm_i32, "").unwrap() + }) + .map(|obj| obj.as_basic_value_enum()) + .collect_vec(); + + TupleType::new(ctx, &repeat_n(llvm_i32, self.ndims as usize).collect_vec()) + .construct_from_objects(ctx, objects, None) } /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] - pub fn is_unsized(&self) -> Option { - self.ndims.map(|ndims| ndims == 0) + pub fn is_unsized(&self) -> bool { + self.ndims == 0 } - /// If this ndarray is unsized, return its sole value as an [`AnyObject`]. + /// Returns the element present in this `ndarray` if this is unsized. 
+ pub fn get_unsized_element( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Option> { + if self.is_unsized() { + // NOTE: `np.size(self) == 0` here is never possible. + let zero = ctx.get_size_type().const_zero(); + let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; + + Some(value) + } else { + None + } + } + + /// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`]. /// Otherwise, do nothing and return the ndarray itself. - // TODO: Rename to get_unsized_element pub fn split_unsized( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> ScalarOrNDArray<'ctx> { - let Some(is_unsized) = self.is_unsized() else { todo!() }; - - if is_unsized { - // NOTE: `np.size(self) == 0` here is never possible. - let zero = generator.get_size_type(ctx.ctx).const_zero(); - let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; - - ScalarOrNDArray::Scalar(value) + if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) { + ScalarOrNDArray::Scalar(unsized_elem) } else { ScalarOrNDArray::NDArray(*self) } } + + /// Check if this `NDArray` can be used as an `out` ndarray for an operation. + /// + /// Raise an exception if the shapes do not match. 
+ pub fn assert_can_be_written_by_out( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + out_shape: impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) { + let ndarray_shape = self.shape(); + let output_shape = out_shape; + + irrt::ndarray::call_nac3_ndarray_util_assert_output_shape_same( + generator, + ctx, + &ndarray_shape, + &output_shape, + ); + } } impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDArrayType<'ctx>; fn get_type(&self) -> Self::Type { - NDArrayType::from_type( + NDArrayType::from_pointer_type( self.as_base_value().get_type(), self.dtype, self.ndims, @@ -450,8 +494,14 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } +impl<'ctx> StructProxyValue<'ctx> for NDArrayValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: NDArrayValue<'ctx>) -> Self { value.as_base_value() @@ -476,7 +526,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.shape_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.shape_field().load(ctx, self.0.value, self.0.name) } fn size( @@ -491,8 +541,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -530,20 +580,26 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_ impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} impl<'ctx> 
UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + for NDArrayShapeProxy<'ctx, '_> +{ fn downcast_to_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: BasicValueEnum<'ctx>, ) -> IntValue<'ctx> { value.into_int_value() } } -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>> + for NDArrayShapeProxy<'ctx, '_> +{ fn upcast_from_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: IntValue<'ctx>, ) -> BasicValueEnum<'ctx> { value.into() @@ -568,7 +624,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.strides_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.strides_field().load(ctx, self.0.value, self.0.name) } fn size( @@ -583,8 +639,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { @@ -622,20 +678,26 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {} impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {} -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> 
TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + for NDArrayStridesProxy<'ctx, '_> +{ fn downcast_to_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: BasicValueEnum<'ctx>, ) -> IntValue<'ctx> { value.into_int_value() } } -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> { +impl<'ctx, G: CodeGenerator + ?Sized> TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>> + for NDArrayStridesProxy<'ctx, '_> +{ fn upcast_from_type( &self, - _: &mut CodeGenContext<'ctx, '_>, + _: &CodeGenContext<'ctx, '_>, + _: &G, value: IntValue<'ctx>, ) -> BasicValueEnum<'ctx> { value.into() @@ -660,49 +722,27 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.data_field().load(ctx, self.0.value, self.0.name) } fn size( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + _: &G, ) -> IntValue<'ctx> { - irrt::ndarray::call_ndarray_calc_size( - generator, - ctx, - &self.as_slice_value(ctx, generator), - (None, None), - ) + irrt::ndarray::call_nac3_ndarray_len(ctx, *self.0) } } impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let sizeof_elem = ctx - .builder - .build_int_truncate_or_bit_cast( - self.element_type(ctx, generator).size_of().unwrap(), - idx.get_type(), - "", - ) - .unwrap(); - let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap(); - let ptr = unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[idx], - name.unwrap_or_default(), - ) - .unwrap() - }; + let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(ctx, *self.0, *idx); // Current implementation is 
transparent - The returned pointer type is // already cast into the expected type, allowing for immediately @@ -713,7 +753,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { BasicTypeEnum::try_from(self.element_type(ctx, generator)) .unwrap() .ptr_type(AddressSpace::default()), - "", + name.unwrap_or_default(), ) .unwrap() } @@ -761,55 +801,33 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> { unsafe fn ptr_offset_unchecked( &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + assert_eq!(indices.element_type(ctx, generator), ctx.get_size_type().into()); - let indices_elem_ty = indices - .ptr_offset(ctx, generator, &llvm_usize.const_zero(), None) - .get_type() - .get_element_type(); - let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { - panic!("Expected list[int32] but got {indices_elem_ty}") - }; - assert_eq!( - indices_elem_ty.get_bit_width(), - 32, - "Expected list[int32] but got list[int{}]", - indices_elem_ty.get_bit_width() + let indices = TypedArrayLikeAdapter::from( + indices.as_slice_value(ctx, generator), + |_, _, v| v.into_int_value(), + |_, _, v| v.into(), ); - let index = irrt::ndarray::call_ndarray_flatten_index(generator, ctx, *self.0, indices); - let sizeof_elem = ctx - .builder - .build_int_truncate_or_bit_cast( - self.element_type(ctx, generator).size_of().unwrap(), - index.get_type(), - "", - ) - .unwrap(); - let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap(); + let ptr = irrt::ndarray::call_nac3_ndarray_get_pelement_by_indices( + generator, ctx, *self.0, &indices, + ); - let ptr = unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[index], - name.unwrap_or_default(), - ) - .unwrap() - }; - // TODO: Current implementation is transparent + // 
Current implementation is transparent - The returned pointer type is + // already cast into the expected type, allowing for immediately + // load/store. ctx.builder .build_pointer_cast( ptr, BasicTypeEnum::try_from(self.element_type(ctx, generator)) .unwrap() .ptr_type(AddressSpace::default()), - "", + name.unwrap_or_default(), ) .unwrap() } @@ -821,7 +839,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let indices_size = indices.size(ctx, generator); let nidx_leq_ndims = ctx @@ -904,10 +922,9 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, /// This function is used generating strides for globally defined contiguous ndarrays. #[must_use] pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec { - let mut strides = Vec::with_capacity(ndims as usize); + let mut strides = vec![0u64; ndims as usize]; let mut stride_product = 1u64; - for i in 0..ndims { - let axis = ndims - i - 1; + for axis in (0..ndims).rev() { strides[axis as usize] = stride_product * itemsize; stride_product *= shape[axis as usize]; } @@ -921,13 +938,109 @@ pub enum ScalarOrNDArray<'ctx> { NDArray(NDArrayValue<'ctx>), } +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(scalar) => Ok(*scalar), + ScalarOrNDArray::NDArray(_) => Err(()), + } + } +} + +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayValue<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(_) => Err(()), + ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray), + } + } +} + impl<'ctx> ScalarOrNDArray<'ctx> { + /// Split on `object` either into a scalar or an ndarray. 
+ /// + /// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`]. + /// + /// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`]. + pub fn from_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (object_ty, object): (Type, BasicValueEnum<'ctx>), + ) -> ScalarOrNDArray<'ctx> { + match &*ctx.unifier.get_ty(object_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) + .map_pointer_value(object.into_pointer_value(), None); + ScalarOrNDArray::NDArray(ndarray) + } + + _ => ScalarOrNDArray::Scalar(object), + } + } + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. #[must_use] pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { match self { ScalarOrNDArray::Scalar(scalar) => scalar, - ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(), + ScalarOrNDArray::NDArray(ndarray) => ndarray.value.into(), + } + } + + /// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`. + /// + /// - If this is an ndarray, the ndarray is returned. + /// - If this is a scalar, this function returns new ndarray created with + /// [`NDArrayType::construct_unsized`]. + pub fn to_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> NDArrayValue<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => *ndarray, + ScalarOrNDArray::Scalar(scalar) => NDArrayType::new_unsized(ctx, scalar.get_type()) + .construct_unsized(generator, ctx, scalar, None), + } + } + + /// Get the dtype of the ndarray created if this were called with + /// [`ScalarOrNDArray::to_ndarray`]. 
+ #[must_use] + pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, + ScalarOrNDArray::Scalar(scalar) => scalar.get_type(), + } + } +} + +/// An helper enum specifying how a function should produce its output. +/// +/// Many functions in NumPy has an optional `out` parameter (e.g., `matmul`). If `out` is specified +/// with an ndarray, the result of a function will be written to `out`. If `out` is not specified, a +/// function will create a new ndarray and store the result in it. +#[derive(Clone, Copy)] +pub enum NDArrayOut<'ctx> { + /// Tell a function should create a new ndarray with the expected element type `dtype`. + NewNDArray { dtype: BasicTypeEnum<'ctx> }, + /// Tell a function to write the result to `ndarray`. + WriteToNDArray { ndarray: NDArrayValue<'ctx> }, +} + +impl<'ctx> NDArrayOut<'ctx> { + /// Get the dtype of this output. + #[must_use] + pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> { + match self { + NDArrayOut::NewNDArray { dtype } => *dtype, + NDArrayOut::WriteToNDArray { ndarray } => ndarray.dtype, } } } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 45a82b3..c1bf751 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -1,15 +1,18 @@ use inkwell::{ types::{BasicType, IntType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, }; -use super::{NDArrayValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator}; +use super::NDArrayValue; use crate::codegen::{ irrt, stmt::{gen_for_callback, BreakContinueHooks}, - types::{ndarray::NDIterType, structure::StructField}, - values::{ArraySliceValue, TypedArrayLikeAdapter}, + types::{ + ndarray::NDIterType, + structure::{StructField, StructProxyType}, + }, + values::{structure::StructProxyValue, ArraySliceValue, ProxyValue, 
TypedArrayLikeAdapter}, CodeGenContext, CodeGenerator, }; @@ -23,13 +26,26 @@ pub struct NDIterValue<'ctx> { } impl<'ctx> NDIterValue<'ctx> { - /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, + /// Creates an [`NDArrayValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - ::Type::is_representable(value.get_type(), llvm_usize) + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, parent, indices, llvm_usize, name) } /// Creates an [`NDArrayValue`] from a [`PointerValue`]. @@ -41,7 +57,7 @@ impl<'ctx> NDIterValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, parent, indices, llvm_usize, name } } @@ -53,24 +69,20 @@ impl<'ctx> NDIterValue<'ctx> { /// If `ndarray` is unsized, this returns true only for the first iteration. /// If `ndarray` is 0-sized, this always returns false. #[must_use] - pub fn has_element( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_nditer_has_element(generator, ctx, *self) + pub fn has_element(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_nditer_has_element(ctx, *self) } /// Go to the next element. If `has_element()` is false, then this has undefined behavior. /// /// If `ndarray` is unsized, this can only be called once. 
/// If `ndarray` is 0-sized, this can never be called. - pub fn next(&self, generator: &G, ctx: &CodeGenContext<'ctx, '_>) { - irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self); + pub fn next(&self, ctx: &CodeGenContext<'ctx, '_>) { + irrt::ndarray::call_nac3_nditer_next(ctx, *self); } - fn element(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).element + fn element_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().element } /// Get pointer to the current element. @@ -78,7 +90,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element(ctx).get(ctx, self.as_base_value(), None); + let p = self.element_field().load(ctx, self.as_abi_value(ctx), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -91,47 +103,49 @@ impl<'ctx> NDIterValue<'ctx> { ctx.builder.build_load(p, "value").unwrap() } - fn nth(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).nth + fn nth_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().nth } /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth(ctx).get(ctx, self.as_base_value(), None) + self.nth_field().load(ctx, self.as_abi_value(ctx), self.name) } /// Get the indices of the current element. 
#[must_use] - pub fn get_indices( - &'ctx self, - ) -> impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> + TypedArrayLikeMutator<'ctx, IntValue<'ctx>> - { + pub fn get_indices( + &self, + ) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { TypedArrayLikeAdapter::from( self.indices, - Box::new(|ctx, val| { - ctx.builder - .build_int_z_extend_or_bit_cast(val.into_int_value(), self.llvm_usize, "") - .unwrap() - }), - Box::new(|_, val| val.into()), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), ) } } impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDIterType<'ctx>; fn get_type(&self) -> Self::Type { - NDIterType::from_type(self.as_base_value().get_type(), self.llvm_usize) + NDIterType::from_pointer_type(self.as_base_value().get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } +impl<'ctx> StructProxyValue<'ctx> for NDIterValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: NDIterValue<'ctx>) -> Self { value.as_base_value() @@ -162,13 +176,11 @@ impl<'ctx> NDArrayValue<'ctx> { generator, ctx, Some("ndarray_foreach"), - |generator, ctx| { - Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self)) - }, - |generator, ctx, nditer| Ok(nditer.has_element(generator, ctx)), + |generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)), + |_, ctx, nditer| Ok(nditer.has_element(ctx)), |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), - |generator, ctx, nditer| { - nditer.next(generator, ctx); + |_, ctx, nditer| { + nditer.next(ctx); Ok(()) }, ) diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs new file mode 100644 index 0000000..b3331b6 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -0,0 +1,152 @@ +use 
inkwell::values::{BasicValueEnum, IntValue}; + +use crate::{ + codegen::{ + stmt::gen_for_callback_incrementing, + types::{ListType, TupleType}, + values::{ + ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, UntypedArrayLikeAccessor, + }, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length. +/// +/// * `sequence` - The `sequence` parameter. +/// * `sequence_ty` - The typechecker type of `sequence` +/// +/// The `sequence` argument type may only be one of the following: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to +/// `np.empty([3])` +/// +/// All `int32` values will be sign-extended to `SizeT`. +pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>), +) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> { + let llvm_usize = ctx.get_size_type(); + let zero = llvm_usize.const_zero(); + let one = llvm_usize.const_int(1, false); + + // The result `list` to return. + match &*ctx.unifier.get_ty_immutable(input_seq_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. 
A list of `int32`; e.g., `np.empty([600, 800, 3])` + + let input_seq = ListType::from_unifier_type(generator, ctx, input_seq_ty) + .map_pointer_value(input_seq.into_pointer_value(), None); + + let len = input_seq.load_size(ctx, None); + // TODO: Find a way to remove this mid-BB allocation + let result = ctx.builder.build_array_alloca(llvm_usize, len, "").unwrap(); + let result = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(result, len, None), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result` + gen_for_callback_incrementing( + generator, + ctx, + None, + zero, + (len, false), + |generator, ctx, _, i| { + // Load the i-th int32 in the input sequence + let int = unsafe { + input_seq.data().get_unchecked(ctx, generator, &i, None).into_int_value() + }; + + // Cast to SizeT + let int = + ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); + + // Store + unsafe { result.set_typed_unchecked(ctx, generator, &i, int) }; + + Ok(()) + }, + one, + ) + .unwrap(); + + result + } + + TypeEnum::TTuple { .. } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + let input_seq = TupleType::from_unifier_type(generator, ctx, input_seq_ty) + .map_struct_value(input_seq.into_struct_value(), None); + + let len = input_seq.get_type().num_elements(); + + let result = generator + .gen_array_var_alloc( + ctx, + llvm_usize.into(), + llvm_usize.const_int(u64::from(len), false), + None, + ) + .unwrap(); + let result = TypedArrayLikeAdapter::from( + result, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + for i in 0..input_seq.get_type().num_elements() { + // Get the i-th element off of the tuple and load it into `result`. 
+ let int = input_seq.load_element(ctx, i).into_int_value(); + let int = ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); + + unsafe { + result.set_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(u64::from(i), false), + int, + ); + } + } + + result + } + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + + let input_int = input_seq.into_int_value(); + + let len = one; + let result = generator.gen_array_var_alloc(ctx, llvm_usize.into(), len, None).unwrap(); + let result = TypedArrayLikeAdapter::from( + result, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let int = + ctx.builder.build_int_s_extend_or_bit_cast(input_int, llvm_usize, "").unwrap(); + + // Storing into result[0] + unsafe { + result.set_typed_unchecked(ctx, generator, &zero, int); + } + + result + } + + _ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(input_seq_ty)), + } +} diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 70a9d65..9ab3d30 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -1,9 +1,16 @@ use std::iter::{once, repeat_n}; +use inkwell::values::{IntValue, PointerValue}; use itertools::Itertools; use crate::codegen::{ - values::ndarray::{NDArrayValue, RustNDIndex}, + irrt, + stmt::gen_if_callback, + types::ndarray::NDArrayType, + values::{ + ndarray::{NDArrayValue, RustNDIndex}, + ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + }, CodeGenContext, CodeGenerator, }; @@ -19,9 +26,7 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndmin: u64, ) -> Self { - assert!(self.ndims.is_some(), "NDArrayValue::atleast_nd is only supported for instances with compile-time known ndims 
(self.ndims = Some(...))"); - - let ndims = self.ndims.unwrap(); + let ndims = self.ndims; if ndims < ndmin { // Extend the dimensions with np.newaxis. @@ -33,4 +38,117 @@ impl<'ctx> NDArrayValue<'ctx> { *self } } + + /// Create a reshaped view on this ndarray like + /// [`np.reshape()`](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html). + /// + /// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be + /// modified as a result. + /// + /// If reshape without copying is impossible, this function will allocate a new ndarray and copy + /// contents. + /// + /// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`]. + /// * `new_shape` - The target shape to do `np.reshape()`. + #[must_use] + pub fn reshape_or_copy( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + new_ndims: u64, + new_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + ) -> Self { + assert_eq!(new_shape.element_type(ctx, generator), self.llvm_usize.into()); + + // TODO: The current criterion for whether to do a full copy or not is by checking + // `is_c_contiguous`, but this is not optimal - there are cases when the ndarray is + // not contiguous but could be reshaped without copying data. Look into how numpy does + // it. 
+ + let dst_ndarray = NDArrayType::new(ctx, self.dtype, new_ndims) + .construct_uninitialized(generator, ctx, None); + dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); + + // Resolve negative indices + let size = self.size(ctx); + let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false); + let dst_shape = dst_ndarray.shape(); + irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( + generator, + ctx, + size, + dst_ndims, + dst_shape.as_slice_value(ctx, generator), + ); + + gen_if_callback( + generator, + ctx, + |_, ctx| Ok(self.is_c_contiguous(ctx)), + |generator, ctx| { + // Reshape is possible without copying + dst_ndarray.set_strides_contiguous(ctx); + dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator)); + + Ok(()) + }, + |generator, ctx| { + // Reshape is impossible without copying + unsafe { + dst_ndarray.create_data(generator, ctx); + } + dst_ndarray.copy_data_from(ctx, *self); + + Ok(()) + }, + ) + .unwrap(); + + dst_ndarray + } + + /// Create a transposed view on this ndarray like + /// [`np.transpose(, = None)`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html). + /// + /// * `axes` - If specified, should be an array of the permutation (negative indices are + /// **allowed**). + #[must_use] + pub fn transpose( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + axes: Option>, + ) -> Self { + assert!( + axes.is_none_or(|axes| axes.get_type().get_element_type() == self.llvm_usize.into()) + ); + + // Define models + let transposed_ndarray = self.get_type().construct_uninitialized(generator, ctx, None); + + let axes = if let Some(axes) = axes { + let num_axes = self.llvm_usize.const_int(self.ndims, false); + + // `axes = nullptr` if `axes` is unspecified. 
+ let axes = ArraySliceValue::from_ptr_val(axes, num_axes, None); + + Some(TypedArrayLikeAdapter::from( + axes, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + )) + } else { + None + }; + + irrt::ndarray::call_nac3_ndarray_transpose( + generator, + ctx, + *self, + transposed_ndarray, + axes.as_ref(), + ); + + transposed_ndarray + } } diff --git a/nac3core/src/codegen/values/range.rs b/nac3core/src/codegen/values/range.rs index 7e9976a..67e623a 100644 --- a/nac3core/src/codegen/values/range.rs +++ b/nac3core/src/codegen/values/range.rs @@ -1,27 +1,50 @@ -use inkwell::values::{BasicValueEnum, IntValue, PointerValue}; +use inkwell::{ + types::IntType, + values::{ArrayValue, BasicValueEnum, IntValue, PointerValue}, +}; use super::ProxyValue; -use crate::codegen::{types::RangeType, CodeGenContext}; +use crate::codegen::{types::RangeType, CodeGenContext, CodeGenerator}; /// Proxy type for accessing a `range` value in LLVM. #[derive(Copy, Clone)] pub struct RangeValue<'ctx> { value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, } impl<'ctx> RangeValue<'ctx> { - /// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance. - pub fn is_representable(value: PointerValue<'ctx>) -> Result<(), String> { - RangeType::is_representable(value.get_type()) + /// Creates an [`RangeValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_array_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: ArrayValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) } /// Creates an [`RangeValue`] from a [`PointerValue`]. 
#[must_use] - pub fn from_pointer_value(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { - debug_assert!(Self::is_representable(ptr).is_ok()); + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); - RangeValue { value: ptr, name } + RangeValue { value: ptr, llvm_usize, name } } fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -31,7 +54,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], var_name.as_str(), ) @@ -46,7 +69,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], var_name.as_str(), ) @@ -61,7 +84,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], var_name.as_str(), ) @@ -134,16 +157,21 @@ impl<'ctx> RangeValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = RangeType<'ctx>; fn get_type(&self) -> Self::Type { - RangeType::from_type(self.value.get_type()) + RangeType::from_pointer_type(self.value.get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/structure.rs b/nac3core/src/codegen/values/structure.rs new file mode 100644 index 0000000..dfe4543 --- /dev/null +++ b/nac3core/src/codegen/values/structure.rs @@ -0,0 +1,24 @@ +use inkwell::values::{BasicValueEnum, PointerValue, StructValue}; + +use 
super::ProxyValue; +use crate::codegen::{types::structure::StructProxyType, CodeGenContext}; + +/// An LLVM value that is used to represent a corresponding structure-like value in NAC3. +pub trait StructProxyValue<'ctx>: + ProxyValue<'ctx, Base = PointerValue<'ctx>, Type: StructProxyType<'ctx, Value = Self>> +{ + /// Returns this value as a [`StructValue`]. + #[must_use] + fn get_struct_value(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructValue<'ctx> { + ctx.builder + .build_load(self.get_pointer_value(ctx), "") + .map(BasicValueEnum::into_struct_value) + .unwrap() + } + + /// Returns this value as a [`PointerValue`]. + #[must_use] + fn get_pointer_value(&self, _: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.as_base_value() + } +} diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs new file mode 100644 index 0000000..320e219 --- /dev/null +++ b/nac3core/src/codegen/values/tuple.rs @@ -0,0 +1,99 @@ +use inkwell::{ + types::IntType, + values::{BasicValue, BasicValueEnum, PointerValue, StructValue}, +}; + +use super::ProxyValue; +use crate::codegen::{types::TupleType, CodeGenContext}; + +#[derive(Copy, Clone)] +pub struct TupleValue<'ctx> { + value: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> TupleValue<'ctx> { + /// Creates an [`TupleValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + value: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(value, llvm_usize).is_ok()); + + Self { value, llvm_usize, name } + } + + /// Creates an [`TupleValue`] from a [`PointerValue`]. 
+ #[must_use] + pub fn from_pointer_value( + ctx: &CodeGenContext<'ctx, '_>, + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + Self::from_struct_value( + ctx.builder + .build_load(ptr, name.unwrap_or_default()) + .map(BasicValueEnum::into_struct_value) + .unwrap(), + llvm_usize, + name, + ) + } + + /// Stores a value into the tuple element at the given `index`. + pub fn store_element( + &mut self, + ctx: &CodeGenContext<'ctx, '_>, + index: u32, + element: impl BasicValue<'ctx>, + ) { + assert_eq!(element.as_basic_value_enum().get_type(), unsafe { + self.get_type().type_at_index_unchecked(index) + }); + + let new_value = ctx + .builder + .build_insert_value(self.value, element, index, self.name.unwrap_or_default()) + .unwrap(); + self.value = new_value.into_struct_value(); + } + + /// Loads a value from the tuple element at the given `index`. + pub fn load_element(&self, ctx: &CodeGenContext<'ctx, '_>, index: u32) -> BasicValueEnum<'ctx> { + ctx.builder + .build_extract_value( + self.value, + index, + &format!("{}[{{i}}]", self.name.unwrap_or("tuple")), + ) + .unwrap() + } +} + +impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { + type ABI = StructValue<'ctx>; + type Base = StructValue<'ctx>; + type Type = TupleType<'ctx>; + + fn get_type(&self) -> Self::Type { + TupleType::from_struct_type(self.as_base_value().get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } +} + +impl<'ctx> From> for StructValue<'ctx> { + fn from(value: TupleValue<'ctx>) -> Self { + value.as_base_value() + } +} diff --git a/nac3core/src/codegen/values/utils/slice.rs b/nac3core/src/codegen/values/utils/slice.rs index dffe6ce..549e556 100644 --- a/nac3core/src/codegen/values/utils/slice.rs +++ b/nac3core/src/codegen/values/utils/slice.rs @@ -1,14 +1,17 @@ use inkwell::{ types::IntType, - values::{IntValue, 
PointerValue}, + values::{IntValue, PointerValue, StructValue}, }; use nac3parser::ast::Expr; use crate::{ codegen::{ - types::{structure::StructField, utils::SliceType}, - values::ProxyValue, + types::{ + structure::{StructField, StructProxyType}, + utils::SliceType, + }, + values::{structure::StructProxyValue, ProxyValue}, CodeGenContext, CodeGenerator, }, typecheck::typedef::Type, @@ -24,13 +27,25 @@ pub struct SliceValue<'ctx> { } impl<'ctx> SliceValue<'ctx> { - /// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is - /// not an instance. - pub fn is_representable( - value: PointerValue<'ctx>, + /// Creates an [`SliceValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, int_ty, llvm_usize, name) } /// Creates an [`SliceValue`] from a [`PointerValue`]. 
@@ -41,7 +56,7 @@ impl<'ctx> SliceValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, int_ty, llvm_usize, name } } @@ -51,7 +66,7 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_start_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.start_defined_field().get(ctx, self.value, self.name) + self.start_defined_field().load(ctx, self.value, self.name) } fn start_field(&self) -> StructField<'ctx, IntValue<'ctx>> { @@ -59,22 +74,22 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.start_field().get(ctx, self.value, self.name) + self.start_field().load(ctx, self.value, self.name) } pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option>) { match value { Some(start) => { - self.start_defined_field().set( + self.start_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_all_ones(), self.name, ); - self.start_field().set(ctx, self.value, start, self.name); + self.start_field().store(ctx, self.value, start, self.name); } - None => self.start_defined_field().set( + None => self.start_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_zero(), @@ -88,7 +103,7 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_stop_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.stop_defined_field().get(ctx, self.value, self.name) + self.stop_defined_field().load(ctx, self.value, self.name) } fn stop_field(&self) -> StructField<'ctx, IntValue<'ctx>> { @@ -96,22 +111,22 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_stop(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.stop_field().get(ctx, self.value, self.name) + self.stop_field().load(ctx, self.value, self.name) } pub fn store_stop(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option>) { match value { 
Some(stop) => { - self.stop_defined_field().set( + self.stop_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_all_ones(), self.name, ); - self.stop_field().set(ctx, self.value, stop, self.name); + self.stop_field().store(ctx, self.value, stop, self.name); } - None => self.stop_defined_field().set( + None => self.stop_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_zero(), @@ -125,7 +140,7 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_step_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.step_defined_field().get(ctx, self.value, self.name) + self.step_defined_field().load(ctx, self.value, self.name) } fn step_field(&self) -> StructField<'ctx, IntValue<'ctx>> { @@ -133,22 +148,22 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.step_field().get(ctx, self.value, self.name) + self.step_field().load(ctx, self.value, self.name) } pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option>) { match value { Some(step) => { - self.step_defined_field().set( + self.step_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_all_ones(), self.name, ); - self.step_field().set(ctx, self.value, step, self.name); + self.step_field().store(ctx, self.value, step, self.name); } - None => self.step_defined_field().set( + None => self.step_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_zero(), @@ -159,18 +174,25 @@ impl<'ctx> SliceValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = SliceType<'ctx>; fn get_type(&self) -> Self::Type { - Self::Type::from_type(self.value.get_type(), self.int_ty, self.llvm_usize) + Self::Type::from_pointer_type(self.value.get_type(), self.int_ty, self.llvm_usize) } fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + 
self.as_base_value() + } } +impl<'ctx> StructProxyValue<'ctx> for SliceValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: SliceValue<'ctx>) -> Self { value.as_base_value() diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index bab823c..4829093 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -6,7 +6,7 @@ use std::{ }; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; -use itertools::{chain, izip, Itertools}; +use itertools::{izip, Itertools}; use parking_lot::RwLock; use nac3parser::ast::{Constant, Expr, Location, StrRef}; @@ -452,11 +452,11 @@ pub fn parse_type_annotation( type_vars.len() )])); } - let fields = chain( - fields.iter().map(|(k, v, m)| (*k, (*v, *m))), - methods.iter().map(|(k, v, _)| (*k, (*v, false))), - ) - .collect(); + let fields = fields + .iter() + .map(|(k, v, m)| (*k, (*v, *m))) + .chain(methods.iter().map(|(k, v, _)| (*k, (*v, false)))) + .collect(); Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() })) } else { Err(HashSet::from([format!("Cannot use function name as type at {loc}")])) @@ -598,10 +598,12 @@ impl dyn SymbolResolver + Send + Sync { unifier.internal_stringify( ty, &mut |id| { - let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else { - unreachable!("expected class definition") + let top_level_def = &*top_level_defs[id].read(); + let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. 
}) = + top_level_def + else { + unreachable!("expected class/module definition") }; - name.to_string() }, &mut |id| format!("typevar{id}"), diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e382222..eff614e 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,26 +1,27 @@ use std::iter::once; use indexmap::IndexMap; -use inkwell::{ - attributes::{Attribute, AttributeLoc}, - types::{BasicMetadataTypeEnum, BasicType}, - values::{BasicMetadataValueEnum, BasicValue, CallSiteValue}, - IntPredicate, -}; -use itertools::Either; +use inkwell::{values::BasicValue, IntPredicate}; use strum::IntoEnumIterator; use super::{ - helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDef, PrimDefDetails}, - numpy::make_ndarray_ty, + helper::{ + arraylike_flatten_element_type, debug_assert_prim_is_allowed, extract_ndims, + make_exception_fields, PrimDef, PrimDefDetails, + }, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, *, }; use crate::{ codegen::{ builtin_fns, numpy::*, - stmt::exn_constructor, - values::{ProxyValue, RangeValue}, + stmt::{exn_constructor, gen_if_callback}, + types::{ndarray::NDArrayType, RangeType}, + values::{ + ndarray::{shape::parse_numpy_int_sequence, ScalarOrNDArray}, + ProxyValue, + }, }, symbol_resolver::SymbolValue, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, @@ -148,144 +149,6 @@ fn create_fn_by_codegen( } } -/// Creates a NumPy [`TopLevelDef`] function using an LLVM intrinsic. -/// -/// * `name`: The name of the implemented NumPy function. -/// * `ret_ty`: The return type of this function. -/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the -/// [parameter type][Type] and the parameter symbol name. -/// * `intrinsic_fn`: The fully-qualified name of the LLVM intrinsic function. 
-fn create_fn_by_intrinsic( - unifier: &mut Unifier, - var_map: &VarMap, - name: &'static str, - ret_ty: Type, - params: &[(Type, &'static str)], - intrinsic_fn: &'static str, -) -> TopLevelDef { - let param_tys = params.iter().map(|p| p.0).collect_vec(); - - create_fn_by_codegen( - unifier, - var_map, - name, - ret_ty, - params, - Box::new(move |ctx, _, fun, args, generator| { - let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); - - assert!(param_tys - .iter() - .zip(&args_ty) - .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); - - let args_val = args_ty - .iter() - .zip_eq(args.iter()) - .map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap()) - .map_into::() - .collect_vec(); - - let intrinsic_fn = ctx.module.get_function(intrinsic_fn).unwrap_or_else(|| { - let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let param_llvm_ty = param_tys - .iter() - .map(|p| ctx.get_llvm_abi_type(generator, *p)) - .map_into::() - .collect_vec(); - let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); - - ctx.module.add_function(intrinsic_fn, fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, args_val.as_slice(), name) - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(val.into()) - }), - ) -} - -/// Creates a unary NumPy [`TopLevelDef`] function using an extern function (e.g. from `libc` or -/// `libm`). -/// -/// * `name`: The name of the implemented NumPy function. -/// * `ret_ty`: The return type of this function. -/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the -/// [parameter type][Type] and the parameter symbol name. -/// * `extern_fn`: The fully-qualified name of the extern function used as the implementation. -/// * `attrs`: The list of attributes to apply to this function declaration. Note that `nounwind` is -/// already implied by the C ABI. 
-fn create_fn_by_extern( - unifier: &mut Unifier, - var_map: &VarMap, - name: &'static str, - ret_ty: Type, - params: &[(Type, &'static str)], - extern_fn: &'static str, - attrs: &'static [&str], -) -> TopLevelDef { - let param_tys = params.iter().map(|p| p.0).collect_vec(); - - create_fn_by_codegen( - unifier, - var_map, - name, - ret_ty, - params, - Box::new(move |ctx, _, fun, args, generator| { - let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); - - assert!(param_tys - .iter() - .zip(&args_ty) - .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); - - let args_val = args_ty - .iter() - .zip_eq(args.iter()) - .map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap()) - .map_into::() - .collect_vec(); - - let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| { - let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let param_llvm_ty = param_tys - .iter() - .map(|p| ctx.get_llvm_abi_type(generator, *p)) - .map_into::() - .collect_vec(); - let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); - let func = ctx.module.add_function(extern_fn, fn_type, None); - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0), - ); - - for attr in attrs { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); - } - - func - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &args_val, name) - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(val.into()) - }), - ) -} - pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> BuiltinInfo { BuiltinBuilder::new(unifier, primitives) .build_all_builtins() @@ -336,7 +199,6 @@ struct BuiltinBuilder<'a> { ndarray_float: Type, ndarray_float_2d: Type, - ndarray_num_ty: Type, float_or_ndarray_ty: TypeVar, float_or_ndarray_var_map: VarMap, @@ 
-450,7 +312,6 @@ impl<'a> BuiltinBuilder<'a> { ndarray_float, ndarray_float_2d, - ndarray_num_ty, float_or_ndarray_ty, float_or_ndarray_var_map, @@ -512,6 +373,14 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), + PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => { + self.build_ndarray_property_getter_function(prim) + } + + PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + self.build_ndarray_view_function(prim) + } + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -540,6 +409,8 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim), + PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim), + PrimDef::FunNpSin | PrimDef::FunNpCos | PrimDef::FunNpTan @@ -577,10 +448,6 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpHypot | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { - self.build_np_sp_ndarray_function(prim) - } - PrimDef::FunNpDot | PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr @@ -710,7 +577,7 @@ impl<'a> BuiltinBuilder<'a> { let (zelf_ty, zelf) = obj.unwrap(); let zelf = zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value(); - let zelf = RangeValue::from_pointer_value(zelf, Some("range")); + let zelf = RangeType::new(ctx).map_pointer_value(zelf, Some("range")); let mut start = None; let mut stop = None; @@ -797,7 +664,7 @@ impl<'a> BuiltinBuilder<'a> { zelf.store_end(ctx, stop); zelf.store_step(ctx, step); - Ok(Some(zelf.as_base_value().into())) + Ok(Some(zelf.as_abi_value(ctx).into())) }, )))), loc: None, @@ -1386,6 +1253,172 @@ impl<'a> BuiltinBuilder<'a> { } } + fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { + 
debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides], + ); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpSize => create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + self.primitives.int32, + &[(in_ndarray_ty.ty, "a")], + Box::new(|ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_pointer_value(ndarray.into_pointer_value(), None); + + let size = ctx + .builder + .build_int_truncate_or_bit_cast(ndarray.size(ctx), ctx.ctx.i32_type(), "") + .unwrap(); + Ok(Some(size.into())) + }), + ), + + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + // The function signatures of `np_shape` an `np_size` are the same. + // Mixed together for convenience. + + // The return type is a tuple of variable length depending on the ndims of the input ndarray. 
+ let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding + + create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + ret_ty, + &[(in_ndarray_ty.ty, "a")], + Box::new(move |ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_pointer_value(ndarray.into_pointer_value(), None); + + let result_tuple = match prim { + PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx), + PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx), + _ => unreachable!(), + }; + + Ok(Some(result_tuple.as_abi_value(ctx).into())) + }), + ) + } + _ => unreachable!(), + } + } + + /// Build np/sp functions that take as input `NDArray` only + fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape], + ); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpTranspose => create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + in_ndarray_ty.ty, + &[(in_ndarray_ty.ty, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) + .map_pointer_value(arg_val.into_pointer_value(), None); + + let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument + Ok(Some(ndarray.as_abi_value(ctx).into())) + }), + ), + + // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and + // the `param_ty` for `create_fn_by_codegen`. 
+ // + // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking + // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], + // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. + PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => { + // These two functions have the same function signature. + // Mixed together for convenience. + + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding + + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + ret_ty, + &[ + (in_ndarray_ty.ty, "x"), + (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), // Handled by special folding + ], + Box::new(move |ctx, _, fun, args, generator| { + let ndarray_ty = fun.0.args[0].ty; + let ndarray_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let shape_ty = fun.0.args[1].ty; + let shape_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?; + + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) + .map_pointer_value(ndarray_val.into_pointer_value(), None); + + let shape = parse_numpy_int_sequence(generator, ctx, (shape_ty, shape_val)); + + // The ndims after reshaping is gotten from the return type of the call. + let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); + let ndims = extract_ndims(&ctx.unifier, ndims); + + let new_ndarray = match prim { + PrimDef::FunNpBroadcastTo => { + ndarray.broadcast_to(generator, ctx, ndims, &shape) + } + + PrimDef::FunNpReshape => { + ndarray.reshape_or_copy(generator, ctx, ndims, &shape) + } + + _ => unreachable!(), + }; + Ok(Some(new_ndarray.as_abi_value(ctx).as_basic_value_enum())) + }), + ) + } + + _ => unreachable!(), + } + } + /// Build the `str()` function. 
fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; @@ -1693,6 +1726,64 @@ impl<'a> BuiltinBuilder<'a> { ) } + fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]); + + let param_ty = &[(self.num_or_ndarray_ty.ty, "a")]; + let ret_ty = self.primitives.bool; + let var_map = &self.num_or_ndarray_var_map; + let codegen_callback: Box = + Box::new(move |ctx, _, fun, args, generator| { + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i1_k0 = llvm_i1.const_zero(); + let llvm_i1_k1 = llvm_i1.const_all_ones(); + + let a_ty = fun.0.args[0].ty; + let a_val = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; + let a = ScalarOrNDArray::from_value(generator, ctx, (a_ty, a_val)); + let a_elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, a_ty); + + let (init, sc_val) = match prim { + PrimDef::FunNpAny => (llvm_i1_k0, llvm_i1_k1), + PrimDef::FunNpAll => (llvm_i1_k1, llvm_i1_k0), + _ => unreachable!(), + }; + + let acc = a.fold(generator, ctx, init, |generator, ctx, hooks, acc, elem| { + gen_if_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare(IntPredicate::EQ, acc, sc_val, "") + .unwrap()) + }, + |_, ctx| { + if let Some(hooks) = hooks { + hooks.build_break_branch(&ctx.builder); + } + Ok(()) + }, + |_, _| Ok(()), + )?; + + let is_truthy = + builtin_fns::call_bool(generator, ctx, (a_elem_ty, elem))?.into_int_value(); + + Ok(match prim { + PrimDef::FunNpAny => ctx.builder.build_or(acc, is_truthy, "").unwrap(), + PrimDef::FunNpAll => ctx.builder.build_and(acc, is_truthy, "").unwrap(), + _ => unreachable!(), + }) + })?; + + Ok(Some(acc.as_basic_value_enum())) + }); + + create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback) + } + /// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input. 
fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed( @@ -1873,57 +1964,6 @@ impl<'a> BuiltinBuilder<'a> { } } - /// Build np/sp functions that take as input `NDArray` only - fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); - - match prim { - PrimDef::FunNpTranspose => { - let ndarray_ty = self.unifier.get_fresh_var_with_range( - &[self.ndarray_num_ty], - Some("T".into()), - None, - ); - create_fn_by_codegen( - self.unifier, - &into_var_map([ndarray_ty]), - prim.name(), - ndarray_ty.ty, - &[(ndarray_ty.ty, "x")], - Box::new(move |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) - }), - ) - } - - // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and - // the `param_ty` for `create_fn_by_codegen`. - // - // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking - // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], - // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. 
- PrimDef::FunNpReshape => create_fn_by_codegen( - self.unifier, - &VarMap::new(), - prim.name(), - self.ndarray_num_ty, - &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], - Box::new(move |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - }), - ), - - _ => unreachable!(), - } - } - /// Build `np_linalg` and `sp_linalg` functions /// /// The input to these functions must be floating point `NDArray` @@ -1955,10 +1995,12 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?; + Ok(Some(result)) }), ), diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 6040ced..a6a0ce7 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -101,7 +101,9 @@ impl TopLevelComposer { let builtin_name_list = definition_ast_list .iter() .map(|def_ast| match *def_ast.0.read() { - TopLevelDef::Class { name, .. } => name.to_string(), + TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => { + name.to_string() + } TopLevelDef::Function { simple_name, .. } | TopLevelDef::Variable { simple_name, .. 
} => simple_name.to_string(), }) @@ -201,6 +203,43 @@ impl TopLevelComposer { self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() } + /// register top level modules + pub fn register_top_level_module( + &mut self, + module_name: &str, + name_to_pyid: &Rc>, + resolver: Arc, + location: Option, + ) -> Result { + let mut methods: HashMap = HashMap::new(); + let mut attributes: Vec<(StrRef, DefinitionId)> = Vec::new(); + + for (name, _) in name_to_pyid.iter() { + if let Ok(def_id) = resolver.get_identifier_def(*name) { + // Avoid repeated attribute instances resulting from multiple imports of same module + if self.defined_names.contains(&format!("{module_name}.{name}")) { + match &*self.definition_ast_list[def_id.0].0.read() { + TopLevelDef::Class { .. } | TopLevelDef::Function { .. } => { + methods.insert(*name, def_id); + } + _ => attributes.push((*name, def_id)), + } + } + }; + } + let module_def = TopLevelDef::Module { + name: module_name.to_string().into(), + module_id: DefinitionId(self.definition_ast_list.len()), + methods, + attributes, + resolver: Some(resolver), + loc: location, + }; + + self.definition_ast_list.push((Arc::new(RwLock::new(module_def)), None)); + Ok(DefinitionId(self.definition_ast_list.len() - 1)) + } + /// register, just remember the names of top level classes/function /// and check duplicate class/method/function definition pub fn register_top_level( @@ -469,10 +508,10 @@ impl TopLevelComposer { self.analyze_top_level_class_definition()?; self.analyze_top_level_class_fields_methods()?; self.analyze_top_level_function()?; + self.analyze_top_level_variables()?; if inference { self.analyze_function_instance()?; } - self.analyze_top_level_variables()?; Ok(()) } @@ -1052,7 +1091,7 @@ impl TopLevelComposer { } let mut result = Vec::new(); let no_defaults = args.args.len() - args.defaults.len() - 1; - for (idx, x) in itertools::enumerate(args.args.iter().skip(1)) { + for (idx, x) in args.args.iter().skip(1).enumerate() { let 
type_ann = { let Some(annotation_expr) = x.node.annotation.as_ref() else {return Err(HashSet::from([format!("type annotation needed for `{}` (at {})", x.node.arg, x.location)]));}; parse_ast_to_type_annotation_kinds( @@ -1410,7 +1449,7 @@ impl TopLevelComposer { Ok(()) } - /// step 4, analyze and call type inferencer to fill the `instance_to_stmt` of + /// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of /// [`TopLevelDef::Function`] fn analyze_function_instance(&mut self) -> Result<(), HashSet> { // first get the class constructor type correct for the following type check in function body @@ -1941,7 +1980,7 @@ impl TopLevelComposer { Ok(()) } - /// Step 5. Analyze and populate the types of global variables. + /// Step 4. Analyze and populate the types of global variables. fn analyze_top_level_variables(&mut self) -> Result<(), HashSet> { let def_list = &self.definition_ast_list; let temp_def_list = self.extract_def_list(); @@ -1959,6 +1998,19 @@ impl TopLevelComposer { let resolver = &**resolver.as_ref().unwrap(); if let Some(ty_decl) = ty_decl { + let ty_decl = match &ty_decl.node { + ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. 
} if self.core_config.kernel_ann.map_or(false, |c| id == &c.into()) + ) => + { + slice + } + _ if self.core_config.kernel_ann.is_none() => ty_decl, + _ => unreachable!("Global variables should be annotated with Kernel[]"), // ignore fields annotated otherwise + }; + let ty_annotation = parse_ast_to_type_annotation_kinds( resolver, &temp_def_list, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 71c1859..4ca5464 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -54,6 +54,16 @@ pub enum PrimDef { FunNpEye, FunNpIdentity, + // NumPy ndarray property getters + FunNpSize, + FunNpShape, + FunNpStrides, + + // NumPy ndarray view functions + FunNpBroadcastTo, + FunNpTranspose, + FunNpReshape, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -101,8 +111,8 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, - FunNpTranspose, - FunNpReshape, + FunNpAny, + FunNpAll, // Linalg functions FunNpDot, @@ -240,6 +250,16 @@ impl PrimDef { PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpIdentity => fun("np_identity", None), + // NumPy NDArray property getters, + PrimDef::FunNpSize => fun("np_size", None), + PrimDef::FunNpShape => fun("np_shape", None), + PrimDef::FunNpStrides => fun("np_strides", None), + + // NumPy NDArray view functions + PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None), + PrimDef::FunNpTranspose => fun("np_transpose", None), + PrimDef::FunNpReshape => fun("np_reshape", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), @@ -287,8 +307,8 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), - PrimDef::FunNpTranspose => fun("np_transpose", None), - PrimDef::FunNpReshape => fun("np_reshape", None), + PrimDef::FunNpAny => fun("np_any", None), + 
PrimDef::FunNpAll => fun("np_all", None), // Linalg functions PrimDef::FunNpDot => fun("np_dot", None), @@ -359,21 +379,37 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { - TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => { + TopLevelDef::Module { name, attributes, methods, .. } => { + format!( + "Module {{\nname: {:?},\nattributes: {:?}\nmethods: {:?}\n}}", + name, + attributes.iter().map(|(n, _)| n.to_string()).collect_vec(), + methods.iter().map(|(n, _)| n.to_string()).collect_vec() + ) + } + TopLevelDef::Class { + name, ancestors, fields, methods, attributes, type_vars, .. + } => { let fields_str = fields .iter() .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) .collect_vec(); + let attributes_str = attributes + .iter() + .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) + .collect_vec(); + let methods_str = methods .iter() .map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id)) .collect_vec(); format!( - "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", + "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nattributes: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", name, ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(), fields_str.iter().map(|(a, _)| a).collect_vec(), + attributes_str.iter().map(|(a, _)| a).collect_vec(), methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(), type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(), ) diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index cba2f5e..3ffd568 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -92,6 +92,20 @@ pub struct FunInstance { #[derive(Debug, Clone)] pub enum TopLevelDef { + Module { + /// Name of the module + name: StrRef, + /// Module ID used for [`TypeEnum`] + module_id: 
DefinitionId, + /// `DefinitionId` of `TopLevelDef::{Class, Function}` within the module + methods: HashMap, + /// `DefinitionId` of `TopLevelDef::{Variable}` within the module + attributes: Vec<(StrRef, DefinitionId)>, + /// Symbol resolver of the module defined the class. + resolver: Option>, + /// Definition location. + loc: Option, + }, Class { /// Name for error messages and symbols. name: StrRef, diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 1b0c9b8..8c827ee 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", + "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function 
{\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 2621337..b8a80a5 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -3,13 +3,13 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], 
list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", - "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index d076930..05f4488 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -4,10 +4,10 @@ expression: res_vec --- [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", - "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nattributes: 
[],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 5ebdf86..7d3922e 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", - "Class {\nname: 
\"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 9a9c4dd..b55e998 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -3,15 +3,15 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: 
\"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", - "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n", + "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap index 5178f1b..2f37789 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap +++ 
b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap @@ -1,9 +1,7 @@ --- source: nac3core/src/toplevel/test.rs -assertion_line: 549 expression: res_vec - --- [ - "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nmethods: [],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nattributes: [],\nmethods: [],\ntype_vars: []\n}\n", ] diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 37a3ded..6a83632 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -15,14 +15,13 @@ use crate::{ symbol_resolver::{SymbolResolver, ValueEnum}, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{into_var_map, Type, Unifier}, + typedef::{Type, Unifier}, }, }; struct ResolverInternal { id_to_type: Mutex>, id_to_def: Mutex>, - class_names: Mutex>, } impl ResolverInternal { @@ -179,11 +178,8 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { let mut composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0; - let internal_resolver = Arc::new(ResolverInternal { - id_to_def: Mutex::default(), - id_to_type: Mutex::default(), - class_names: Mutex::default(), - }); + let internal_resolver = + Arc::new(ResolverInternal { id_to_def: Mutex::default(), id_to_type: Mutex::default() }); let resolver = Arc::new(Resolver(internal_resolver.clone())) as Arc; @@ -784,13 +780,6 @@ fn make_internal_resolver_with_tvar( unifier: &mut Unifier, print: bool, ) -> Arc { - let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None); - let list = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimDef::List.id(), - fields: HashMap::new(), - params: into_var_map([list_elem_tvar]), - }); - let res: Arc = ResolverInternal { id_to_def: Mutex::new(HashMap::from([("list".into(), PrimDef::List.id())])), id_to_type: tvars @@ -806,7 +795,6 @@ fn make_internal_resolver_with_tvar( }) .collect::>() .into(), - 
class_names: Mutex::new(HashMap::from([("list".into(), list)])), } .into(); if print { @@ -819,7 +807,7 @@ struct TypeToStringFolder<'a> { unifier: &'a mut Unifier, } -impl<'a> Fold> for TypeToStringFolder<'a> { +impl Fold> for TypeToStringFolder<'_> { type TargetU = String; type Error = String; fn map_user(&mut self, user: Option) -> Result { diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index ed801a1..2e655d1 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -15,7 +15,7 @@ use super::{ }; use crate::toplevel::helper::PrimDef; -impl<'a> Inferencer<'a> { +impl Inferencer<'_> { fn should_have_value(&mut self, expr: &Expr>) -> Result<(), HashSet> { if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) { Err(HashSet::from([format!("Error at {}: cannot have value none", expr.location)])) @@ -94,7 +94,7 @@ impl<'a> Inferencer<'a> { // there are some cases where the custom field is None if let Some(ty) = &expr.custom { if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. 
}) - && !ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::List.id()) + && ty.obj_id(self.unifier).is_none_or(|id| id != PrimDef::List.id()) && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) { return Err(HashSet::from([format!( diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 60972f0..40bbdea 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -7,12 +7,12 @@ use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop}; use super::{ type_inferencer::*, - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, + typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }; use crate::{ symbol_resolver::SymbolValue, toplevel::{ - helper::PrimDef, + helper::{extract_ndims, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, }, }; @@ -175,19 +175,8 @@ pub fn impl_binop( ops: &[Operator], ) { with_fields(unifier, ty, |unifier, fields| { - let (other_ty, other_var_id) = if other_ty.len() == 1 { - (other_ty[0], None) - } else { - let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); - (tvar.ty, Some(tvar.id)) - }; - - let function_vars = if let Some(var_id) = other_var_id { - vec![(var_id, other_ty)].into_iter().collect::() - } else { - VarMap::new() - }; - + let other_tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); + let function_vars = into_var_map([other_tvar]); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) { @@ -198,7 +187,7 @@ pub fn impl_binop( ret: ret_ty, vars: function_vars.clone(), args: vec![FuncArg { - ty: other_ty, + ty: other_tvar.ty, default_value: None, name: "other".into(), is_vararg: false, @@ -541,36 +530,43 @@ pub fn typeof_binop( } } - let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); - let lhs_ndims = match 
&*unifier.get_ty_immutable(lhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() + let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = extract_ndims(unifier, lhs_ndims); + + let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = extract_ndims(unifier, rhs_ndims); + + if !(unifier.unioned(lhs_dtype, primitives.float) + && unifier.unioned(rhs_dtype, primitives.float)) + { + return Err(format!( + "ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}", + unifier.stringify(lhs), + unifier.stringify(rhs) + )); + } + + // Deduce the ndims of the resulting ndarray. + // If this is 0 (an unsized ndarray), matmul returns a scalar just like NumPy. + let result_ndims = match (lhs_ndims, rhs_ndims) { + (0, _) | (_, 0) => { + return Err( + "ndarray.__matmul__ does not allow unsized ndarray input".to_string() + ) } - _ => unreachable!(), - }; - let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); - let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() - } - _ => unreachable!(), + (1, 1) => 0, + (1, _) => rhs_ndims - 1, + (_, 1) => lhs_ndims - 1, + (m, n) => max(m, n), }; - match (lhs_ndims, rhs_ndims) { - (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, - (lhs, rhs) if lhs == 0 || rhs == 0 => { - return Err(format!( - "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", - u8::from(rhs == 0) - )) - } - (lhs, rhs) => { - return Err(format!( - "ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported" - )) - } + if result_ndims == 0 { + // If the result is unsized, NumPy returns a scalar. 
+ primitives.float + } else { + let result_ndims_ty = + unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None); + make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty)) } } @@ -773,7 +769,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); + impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None); impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs index 3cf5bdc..b144f8a 100644 --- a/nac3core/src/typecheck/type_error.rs +++ b/nac3core/src/typecheck/type_error.rs @@ -94,7 +94,7 @@ fn loc_to_str(loc: Option) -> String { } } -impl<'a> Display for DisplayTypeError<'a> { +impl Display for DisplayTypeError<'_> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { use TypeErrorKind::*; let mut notes = Some(HashMap::new()); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 742fa19..7ce659f 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -2008,72 +2008,90 @@ impl Inferencer<'_> { ctx: ExprContext, ) -> InferenceResult { let ty = value.custom.unwrap(); - if let TypeEnum::TObj { obj_id, fields, .. 
} = &*self.unifier.get_ty(ty) { - // just a fast path - match (fields.get(&attr), ctx == ExprContext::Store) { - (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), - (Some((ty, false)), true) => report_type_error( - TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), - Some(value.location), - self.unifier, - ), - (None, mutable) => { - // Check whether it is a class attribute - let defs = self.top_level.definitions.read(); - let result = { - if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() { - attributes.iter().find_map(|f| { - if f.0 == attr { - return Some(f.1); - } - None - }) - } else { - None - } - }; - match result { - Some(res) if !mutable => Ok(res), - Some(_) => report_error( - &format!("Class Attribute `{attr}` is immutable"), - value.location, - ), - None => report_type_error( - TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), - Some(value.location), - self.unifier, - ), - } - } - } - } else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { - // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 - let result = { - self.top_level.definitions.read().iter().find_map(|def| { - if let Some(rear_guard) = def.try_read() { - if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { - if name.to_string() == self.unifier.stringify(sign.ret) { - return attributes.iter().find_map(|f| { + match &*self.unifier.get_ty(ty) { + TypeEnum::TObj { obj_id, fields, .. } => { + // just a fast path + match (fields.get(&attr), ctx == ExprContext::Store) { + (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), + (Some((ty, false)), true) => report_type_error( + TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), + Some(value.location), + self.unifier, + ), + (None, mutable) => { + // Check whether it is a class attribute + let defs = self.top_level.definitions.read(); + let result = { + if let TopLevelDef::Class { attributes, .. 
} = &*defs[obj_id.0].read() { + attributes.iter().find_map(|f| { if f.0 == attr { - return Some(f.clone().1); + return Some(f.1); } None - }); + }) + } else { + None } + }; + match result { + Some(res) if !mutable => Ok(res), + Some(_) => report_error( + &format!("Class Attribute `{attr}` is immutable"), + value.location, + ), + None => report_type_error( + TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), + Some(value.location), + self.unifier, + ), } } - None - }) - }; - match result { - Some(f) if ctx != ExprContext::Store => Ok(f), - Some(_) => { - report_error(&format!("Class Attribute `{attr}` is immutable"), value.location) } - None => self.infer_general_attribute(value, attr, ctx), } - } else { - self.infer_general_attribute(value, attr, ctx) + TypeEnum::TFunc(sign) => { + // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 + let result = { + self.top_level.definitions.read().iter().find_map(|def| { + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { + if name.to_string() == self.unifier.stringify(sign.ret) { + return attributes.iter().find_map(|f| { + if f.0 == attr { + return Some(f.clone().1); + } + None + }); + } + } + } + None + }) + }; + match result { + Some(f) if ctx != ExprContext::Store => Ok(f), + Some(_) => report_error( + &format!("Class Attribute `{attr}` is immutable"), + value.location, + ), + None => self.infer_general_attribute(value, attr, ctx), + } + } + TypeEnum::TModule { attributes, .. 
} => { + match (attributes.get(&attr), ctx == ExprContext::Load) { + (Some((ty, _)), true) | (Some((ty, false)), false) => Ok(*ty), + (Some((ty, true)), false) => report_type_error( + TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), + Some(value.location), + self.unifier, + ), + (None, _) => report_type_error( + TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), + Some(value.location), + self.unifier, + ), + } + } + _ => self.infer_general_attribute(value, attr, ctx), } } @@ -2734,7 +2752,7 @@ impl Inferencer<'_> { .read() .iter() .map(|def| match *def.read() { - TopLevelDef::Class { name, .. } => (name, false), + TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => (name, false), TopLevelDef::Function { simple_name, .. } => (simple_name, false), TopLevelDef::Variable { simple_name, .. } => (simple_name, true), }) diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index e56cb28..a658353 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -18,7 +18,6 @@ use crate::{ struct Resolver { id_to_type: HashMap, id_to_def: HashMap, - class_names: HashMap, } impl SymbolResolver for Resolver { @@ -198,7 +197,6 @@ impl TestEnvironment { let resolver = Arc::new(Resolver { id_to_type: identifier_mapping.clone(), id_to_def: HashMap::default(), - class_names: HashMap::default(), }) as Arc; TestEnvironment { @@ -454,7 +452,6 @@ impl TestEnvironment { vars: IndexMap::default(), })), ); - let class_names: HashMap<_, _> = [("Bar".into(), bar), ("Bar2".into(), bar2)].into(); let id_to_name = [ "int32".into(), @@ -492,7 +489,6 @@ impl TestEnvironment { ("Bar2".into(), DefinitionId(defs + 3)), ] .into(), - class_names, }) as Arc; TestEnvironment { diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 49cea04..f2f9ed6 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ 
b/nac3core/src/typecheck/typedef/mod.rs @@ -3,13 +3,13 @@ use std::{ cell::RefCell, collections::{HashMap, HashSet}, fmt::{self, Display}, - iter::{repeat, zip}, + iter::{repeat, repeat_n, zip}, rc::Rc, sync::{Arc, Mutex}, }; use indexmap::IndexMap; -use itertools::{repeat_n, Itertools}; +use itertools::Itertools; use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop}; @@ -270,6 +270,19 @@ pub enum TypeEnum { /// A function type. TFunc(FunSignature), + + /// Module Type + TModule { + /// The [`DefinitionId`] of this object type. + module_id: DefinitionId, + + /// The attributes present in this object type. + /// + /// The key of the [Mapping] is the identifier of the field, while the value is a tuple + /// containing the [Type] of the field, and a `bool` indicating whether the field is a + /// variable (as opposed to a function). + attributes: Mapping, + }, } impl TypeEnum { @@ -284,6 +297,7 @@ impl TypeEnum { TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TCall { .. } => "TCall", TypeEnum::TFunc { .. } => "TFunc", + TypeEnum::TModule { .. } => "TModule", } } } @@ -593,7 +607,8 @@ impl Unifier { | TLiteral { .. } // functions are instantiated for each call sites, so the function type can contain // type variables. - | TFunc { .. } => true, + | TFunc { .. } + | TModule { .. } => true, TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, @@ -1315,10 +1330,12 @@ impl Unifier { || format!("{id}"), |top_level| { let top_level_def = &top_level.definitions.read()[id]; - let TopLevelDef::Class { name, .. } = &*top_level_def.read() else { - unreachable!("expected class definition") + let top_level_def = top_level_def.read(); + let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. 
}) = + &*top_level_def + else { + unreachable!("expected module/class definition") }; - name.to_string() }, ) @@ -1446,6 +1463,10 @@ impl Unifier { let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes); format!("fn[[{params}], {ret}]") } + TypeEnum::TModule { module_id, .. } => { + let name = obj_to_name(module_id.0); + name.to_string() + } } } @@ -1521,7 +1542,9 @@ impl Unifier { // variables, i.e. things like TRecord, TCall should not occur, and we // should be safe to not implement the substitution for those variants. match &*ty { - TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None, + TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } | TypeEnum::TModule { .. } => { + None + } TypeEnum::TVar { id, .. } => mapping.get(id).copied(), TypeEnum::TTuple { ty, is_vararg_ctx } => { let mut new_ty = Cow::from(ty); diff --git a/nac3ld/src/dwarf.rs b/nac3ld/src/dwarf.rs index e85a4e4..4bcccd3 100644 --- a/nac3ld/src/dwarf.rs +++ b/nac3ld/src/dwarf.rs @@ -30,7 +30,7 @@ pub struct DwarfReader<'a> { pub virt_addr: u32, } -impl<'a> DwarfReader<'a> { +impl DwarfReader<'_> { pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader { DwarfReader { slice, virt_addr } } @@ -113,7 +113,7 @@ pub struct DwarfWriter<'a> { pub offset: usize, } -impl<'a> DwarfWriter<'a> { +impl DwarfWriter<'_> { pub fn new(slice: &mut [u8]) -> DwarfWriter { DwarfWriter { slice, offset: 0 } } @@ -375,7 +375,7 @@ pub struct FDE_Records<'a> { available: usize, } -impl<'a> Iterator for FDE_Records<'a> { +impl Iterator for FDE_Records<'_> { type Item = (u32, u32); fn next(&mut self) -> Option { @@ -423,7 +423,7 @@ pub struct EH_Frame_Hdr<'a> { fdes: Vec<(u32, u32)>, } -impl<'a> EH_Frame_Hdr<'a> { +impl EH_Frame_Hdr<'_> { /// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory. /// /// Load address is not known at this point. 
diff --git a/nac3ld/src/lib.rs b/nac3ld/src/lib.rs index 73e065d..8ef8653 100644 --- a/nac3ld/src/lib.rs +++ b/nac3ld/src/lib.rs @@ -159,7 +159,7 @@ struct SymbolTableReader<'a> { strtab: &'a [u8], } -impl<'a> SymbolTableReader<'a> { +impl SymbolTableReader<'_> { pub fn find_index_by_name(&self, sym_name: &[u8]) -> Option { self.symtab.iter().position(|sym| { if let Ok(dynsym_name) = name_starting_at_slice(self.strtab, sym.st_name as usize) { diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 4f19db9..180d24f 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -67,7 +67,7 @@ def _bool(x): def _float(x): if isinstance(x, np.ndarray): - return np.float_(x) + return np.float64(x) else: return float(x) @@ -111,6 +111,9 @@ def patch(module): def output_strln(x): print(x, end='') + def output_int32_list(x): + print([int(e) for e in x]) + def dbg_stack_address(_): return 0 @@ -126,11 +129,12 @@ def patch(module): return output_float elif name == "output_str": return output_strln + elif name == "output_int32_list": + return output_int32_list elif name in { "output_bool", "output_int32", "output_int64", - "output_int32_list", "output_uint32", "output_uint64", "output_strln", @@ -179,6 +183,16 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy NDArray view functions + module.np_broadcast_to = np.broadcast_to + module.np_transpose = np.transpose + module.np_reshape = np.reshape + + # NumPy NDArray property getters + module.np_size = np.size + module.np_shape = np.shape + module.np_strides = lambda ndarray: ndarray.strides + # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf @@ -218,8 +232,8 @@ def patch(module): module.np_ldexp = np.ldexp module.np_hypot = np.hypot module.np_nextafter = np.nextafter - module.np_transpose = np.transpose - module.np_reshape = np.reshape + module.np_any = np.any + module.np_all = np.all # 
SciPy Math functions module.sp_spec_erf = special.erf diff --git a/nac3standalone/demo/run_demo.sh b/nac3standalone/demo/run_demo.sh index bec2eb6..78e32dd 100755 --- a/nac3standalone/demo/run_demo.sh +++ b/nac3standalone/demo/run_demo.sh @@ -58,7 +58,7 @@ rm -f ./*.o ./*.bc demo if [ -z "$i686" ]; then $nac3standalone "${nac3args[@]}" clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c - clang -o demo module.o demo.o $DEMO_LINALG_STUB -lm -Wl,--no-warn-search-mismatch + clang -o demo module.o demo.o $DEMO_LINALG_STUB -fuse-ld=lld -lm else $nac3standalone --triple i686-unknown-linux-gnu --target-features +sse2 "${nac3args[@]}" clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c diff --git a/nac3standalone/demo/src/class_attributes.py b/nac3standalone/demo/src/class_attributes.py new file mode 100644 index 0000000..b58958f --- /dev/null +++ b/nac3standalone/demo/src/class_attributes.py @@ -0,0 +1,35 @@ +@extern +def output_int32(x: int32): + ... + +@extern +def output_strln(x: str): + ... 
+ + +class A: + a: int32 = 1 + b: int32 + c: str = "test" + d: str + + def __init__(self): + self.b = 2 + self.d = "test" + + output_int32(self.a) # Attributes can be accessed within class + + +def run() -> int32: + output_int32(A.a) # Attributes can be directly accessed with class name + # A.b # Only attributes can be accessed in this way + # A.a = 2 # Attributes are immutable + + obj = A() + output_int32(obj.a) # Attributes can be accessed by class objects + + output_strln(obj.c) + output_strln(obj.d) + + return 0 + diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index d42f3b9..d077b82 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -68,6 +68,19 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]): for c in range(len(n[r])): output_float64(n[r][c]) +def output_ndarray_float_3(n: ndarray[float, Literal[3]]): + for d in range(len(n)): + for r in range(len(n[d])): + for c in range(len(n[d][r])): + output_float64(n[d][r][c]) + +def output_ndarray_float_4(n: ndarray[float, Literal[4]]): + for x in range(len(n)): + for y in range(len(n[x])): + for z in range(len(n[x][y])): + for w in range(len(n[x][y][z])): + output_float64(n[x][y][z][w]) + def consume_ndarray_1(n: ndarray[float, Literal[1]]): pass @@ -197,6 +210,104 @@ def test_ndarray_nd_idx(): output_float64(x[1, 0]) output_float64(x[1, 1]) +def test_ndarray_transpose(): + x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]]) + y = np_transpose(x) + z = np_transpose(y) + + output_int32(np_shape(x)[0]) + output_int32(np_shape(x)[1]) + output_ndarray_float_2(x) + + output_int32(np_shape(y)[0]) + output_int32(np_shape(y)[1]) + output_ndarray_float_2(y) + + output_int32(np_shape(z)[0]) + output_int32(np_shape(z)[1]) + output_ndarray_float_2(z) + +def test_ndarray_reshape(): + w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + x = np_reshape(w, (1, 2, 1, -1)) + y = np_reshape(x, [2, -1]) + z = 
np_reshape(y, 10) + + output_int32(np_shape(w)[0]) + output_ndarray_float_1(w) + + output_int32(np_shape(x)[0]) + output_int32(np_shape(x)[1]) + output_int32(np_shape(x)[2]) + output_int32(np_shape(x)[3]) + output_ndarray_float_4(x) + + output_int32(np_shape(y)[0]) + output_int32(np_shape(y)[1]) + output_ndarray_float_2(y) + + output_int32(np_shape(z)[0]) + output_ndarray_float_1(z) + + x1: ndarray[int32, 1] = np_array([1, 2, 3, 4]) + x2: ndarray[int32, 2] = np_reshape(x1, (2, 2)) + + output_int32(np_shape(x1)[0]) + output_ndarray_int32_1(x1) + + output_int32(np_shape(x2)[0]) + output_int32(np_shape(x2)[1]) + output_ndarray_int32_2(x2) + +def test_ndarray_broadcast_to(): + xs = np_array([1.0, 2.0, 3.0]) + ys = np_broadcast_to(xs, (1, 3)) + zs = np_broadcast_to(ys, (2, 4, 3)) + + output_int32(np_shape(xs)[0]) + output_ndarray_float_1(xs) + + output_int32(np_shape(ys)[0]) + output_int32(np_shape(ys)[1]) + output_ndarray_float_2(ys) + + output_int32(np_shape(zs)[0]) + output_int32(np_shape(zs)[1]) + output_int32(np_shape(zs)[2]) + output_ndarray_float_3(zs) + +def test_ndarray_subscript_assignment(): + xs = np_array([[11.0, 22.0, 33.0, 44.0], [55.0, 66.0, 77.0, 88.0]]) + + xs[0, 0] = 99.0 + output_ndarray_float_2(xs) + + xs[0] = 100.0 + output_ndarray_float_2(xs) + + xs[:, ::2] = 101.0 + output_ndarray_float_2(xs) + + xs[1:, 0] = 102.0 + output_ndarray_float_2(xs) + + xs[0] = np_array([-1.0, -2.0, -3.0, -4.0]) + output_ndarray_float_2(xs) + + xs[:] = np_array([-5.0, -6.0, -7.0, -8.0]) + output_ndarray_float_2(xs) + + # Test assignment with memory sharing + ys1 = np_reshape(xs, (2, 4)) + ys2 = np_transpose(ys1) + ys3 = ys2[::-1, 0] + ys3[0] = -999.0 + + output_ndarray_float_2(xs) + output_ndarray_float_2(ys1) + output_ndarray_float_2(ys2) + output_ndarray_float_1(ys3) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -1440,26 +1551,58 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) 
output_ndarray_float_2(nextafter_x_ones) -def test_ndarray_transpose(): - x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]]) - y = np_transpose(x) - z = np_transpose(y) - output_ndarray_float_2(x) - output_ndarray_float_2(y) +def test_ndarray_any(): + s0 = 0 + output_bool(np_any(s0)) + s1 = 1 + output_bool(np_any(s1)) -def test_ndarray_reshape(): - w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) - x = np_reshape(w, (1, 2, 1, -1)) - y = np_reshape(x, [2, -1]) - z = np_reshape(y, 10) + x1 = np_identity(5) + y1 = np_any(x1) + output_ndarray_float_2(x1) + output_bool(y1) - x1: ndarray[int32, 1] = np_array([1, 2, 3, 4]) - x2: ndarray[int32, 2] = np_reshape(x1, (2, 2)) + x2 = np_identity(1) + y2 = np_any(x2) + output_ndarray_float_2(x2) + output_bool(y2) - output_ndarray_float_1(w) - output_ndarray_float_2(y) - output_ndarray_float_1(z) + x3 = np_array([[1.0, 2.0], [3.0, 4.0]]) + y3 = np_any(x3) + output_ndarray_float_2(x3) + output_bool(y3) + + x4 = np_zeros([3, 5]) + y4 = np_any(x4) + output_ndarray_float_2(x4) + output_bool(y4) + +def test_ndarray_all(): + s0 = 0 + output_bool(np_all(s0)) + s1 = 1 + output_bool(np_all(s1)) + + x1 = np_identity(5) + y1 = np_all(x1) + output_ndarray_float_2(x1) + output_bool(y1) + + x2 = np_identity(1) + y2 = np_all(x2) + output_ndarray_float_2(x2) + output_bool(y2) + + x3 = np_array([[1.0, 2.0], [3.0, 4.0]]) + y3 = np_all(x3) + output_ndarray_float_2(x3) + output_bool(y3) + + x4 = np_zeros([3, 5]) + y4 = np_all(x4) + output_ndarray_float_2(x4) + output_bool(y4) def test_ndarray_dot(): x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) @@ -1592,6 +1735,11 @@ def run() -> int32: test_ndarray_slices() test_ndarray_nd_idx() + test_ndarray_transpose() + test_ndarray_reshape() + test_ndarray_broadcast_to() + test_ndarray_subscript_assignment() + test_ndarray_add() test_ndarray_add_broadcast() test_ndarray_add_broadcast_lhs_scalar() @@ -1755,8 +1903,9 @@ def run() -> int32: 
test_ndarray_nextafter_broadcast() test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() - test_ndarray_transpose() - test_ndarray_reshape() + + test_ndarray_any() + test_ndarray_all() test_ndarray_dot() test_ndarray_cholesky() diff --git a/nac3standalone/demo/src/numeric_primitives.py b/nac3standalone/demo/src/numeric_primitives.py index 77a641f..2c08c16 100644 --- a/nac3standalone/demo/src/numeric_primitives.py +++ b/nac3standalone/demo/src/numeric_primitives.py @@ -29,10 +29,10 @@ def u32_max() -> uint32: return ~uint32(0) def i32_min() -> int32: - return int32(1 << 31) + return int32(-(1 << 31)) def i32_max() -> int32: - return int32(~(1 << 31)) + return int32((1 << 31)-1) def u64_min() -> uint64: return uint64(0) @@ -63,8 +63,9 @@ def test_conv_from_i32(): i32_max() ]: output_int64(int64(x)) - output_uint32(uint32(x)) - output_uint64(uint64(x)) + if x >= 0: + output_uint32(uint32(x)) + output_uint64(uint64(x)) output_float64(float(x)) def test_conv_from_u32(): @@ -108,7 +109,6 @@ def test_conv_from_u64(): def test_f64toi32(): for x in [ - float(i32_min()) - 1.0, float(i32_min()), float(i32_min()) + 1.0, -1.5, @@ -117,7 +117,6 @@ def test_f64toi32(): 1.5, float(i32_max()) - 1.0, float(i32_max()), - float(i32_max()) + 1.0 ]: output_int32(int32(x)) @@ -138,24 +137,17 @@ def test_f64toi64(): def test_f64tou32(): for x in [ - -1.5, - float(u32_min()) - 1.0, - -0.5, float(u32_min()), 0.5, float(u32_min()) + 1.0, 1.5, float(u32_max()) - 1.0, float(u32_max()), - float(u32_max()) + 1.0 ]: output_uint32(uint32(x)) def test_f64tou64(): for x in [ - -1.5, - float(u64_min()) - 1.0, - -0.5, float(u64_min()), 0.5, float(u64_min()) + 1.0, @@ -181,4 +173,4 @@ def run() -> int32: test_f64tou32() test_f64tou64() - return 0 \ No newline at end of file + return 0 diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 2fce5d1..d54e08e 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -456,7 +456,13 
@@ fn main() { membuffer.lock().push(buffer); }))); let threads = (0..threads) - .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), size_t))) + .map(|i| { + Box::new(DefaultCodeGenerator::with_target_machine( + format!("module{i}"), + &context, + &target_machine, + )) + }) .collect(); let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); registry.add_task(task); diff --git a/nix/windows/msys2_packages.nix b/nix/windows/msys2_packages.nix index 0ac1aa8..0859244 100644 --- a/nix/windows/msys2_packages.nix +++ b/nix/windows/msys2_packages.nix @@ -1,15 +1,15 @@ { pkgs } : [ (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-19.1.4-1-any.pkg.tar.zst"; - sha256 = "0frb5k16bbxdf8g379d16vl3qrh7n9pydn83gpfxpvwf3qlvnzyl"; - name = "mingw-w64-clang-x86_64-libunwind-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-19.1.6-1-any.pkg.tar.zst"; + sha256 = "1gv6hbqvfgjzirpljql1shlchldmf5ww3rfsspg90pq1frnwavjl"; + name = "mingw-w64-clang-x86_64-libunwind-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-19.1.4-1-any.pkg.tar.zst"; - sha256 = "0wh5km0v8j50pqz9bxb4f0w7r8zhsvssrjvc94np53iq8wjagk86"; - name = "mingw-w64-clang-x86_64-libc++-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-19.1.6-1-any.pkg.tar.zst"; + sha256 = "1wbkvrx14ahc04cgkydvlxwmsl8jfnqwhy9sy4kn4wkdzmlcp1ax"; + name = "mingw-w64-clang-x86_64-libc++-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -19,15 +19,15 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libiconv-1.17-4-any.pkg.tar.zst"; - sha256 = "1g2bkhgf60dywccxw911ydyigf3m25yqfh81m5099swr7mjsmzyf"; - name = "mingw-w64-clang-x86_64-libiconv-1.17-4-any.pkg.tar.zst"; + url = 
"https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libiconv-1.18-1-any.pkg.tar.zst"; + sha256 = "0vn5xgx9jjg66f8r9ylm9220qdbjdkffykfl6nwj14zv9y7xh4nj"; + name = "mingw-w64-clang-x86_64-libiconv-1.18-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-gettext-runtime-0.22.5-2-any.pkg.tar.zst"; - sha256 = "0ll6ci6d3mc7g04q0xixjc209bh8r874dqbczgns69jsad3wg6mi"; - name = "mingw-w64-clang-x86_64-gettext-runtime-0.22.5-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-gettext-runtime-0.23.1-1-any.pkg.tar.zst"; + sha256 = "0wbp5pmrr0rk4mx7d1frvqlk4a061zw31zscs57srmvl0wv3pi2a"; + name = "mingw-w64-clang-x86_64-gettext-runtime-0.23.1-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -55,69 +55,69 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-libs-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1clrbm8dk893byj8s15pgcgqqijm2zkd10zgyakamd8m354kj9q4"; - name = "mingw-w64-clang-x86_64-llvm-libs-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-libs-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0fpsnfyf0bg39a4ygzga06sr4wv4jp1jnc8lk6sr3z0nim0nlhjn"; + name = "mingw-w64-clang-x86_64-llvm-libs-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1iz2c9475h8p20ydpp0znbhyb62rlrk7wr7xl7cmwbam7wkwr8rn"; - name = "mingw-w64-clang-x86_64-llvm-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0whqs9nvfmgxj3c83px6dipcdw9zi858kgd8130201fy1mbnafp1"; + name = "mingw-w64-clang-x86_64-llvm-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-libs-19.1.4-1-any.pkg.tar.zst"; - sha256 = 
"1hidciwlakxrp4kyb0j2v6g4lv76nn834g6b88w1j94fk3qc765d"; - name = "mingw-w64-clang-x86_64-clang-libs-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-libs-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0rmzri7h043i73jy3c2jcrg3hy40dr5s9n96kmxgaghfhvlpilps"; + name = "mingw-w64-clang-x86_64-clang-libs-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-compiler-rt-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1m1yhjkgzlbk10sv966qk4yji009ga0lr25gpgj2w7mcd2wixcr3"; - name = "mingw-w64-clang-x86_64-compiler-rt-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-compiler-rt-19.1.6-1-any.pkg.tar.zst"; + sha256 = "04cqlh35asvlh06nmhwnx9h0yrqk8zxd9lpzxmm1xh64kvm9maxn"; + name = "mingw-w64-clang-x86_64-compiler-rt-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "08gxc7h2achckknn6fz3p6yi7gxxvbaday8fpm4j56c4sa04n0df"; - name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "05zsqgq8zwdcfacyqdxdjcf80447bgnrz71xv5cds0y135yziy7l"; + name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "0fxd1pb197ki0gzw6z8gmd6wgpd9d28js6cp5d31d55kw7d1vz13"; - name = "mingw-w64-clang-x86_64-crt-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "12fkxpk7rwy36snvvc7sdivx81pd4ckzh5ilyh7gl6ly4qayppp6"; + name = 
"mingw-w64-clang-x86_64-crt-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lld-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1a8pjyhrzpc2z3784xxwix4i7yrz03ygnsk1wv9k0yq8m8wi9nbw"; - name = "mingw-w64-clang-x86_64-lld-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lld-19.1.6-1-any.pkg.tar.zst"; + sha256 = "102bbv5acq1fvrfn8bp1x3503cb8hvcxmlpr86qsba4vm11l0wrw"; + name = "mingw-w64-clang-x86_64-lld-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "140m312jx1sywqjkvfij69d268m4jpdmilq5bb8khkf0ayb16036"; - name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "1sris0qczxk5px9xy85976hbmqrpg49ns7yyzd9p455ckf740cid"; + name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "017j4h511wg37bacym73f8g6s0jcfgzbzabzxpc6anr3gy4kkpbg"; - name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "1r0m5xpsxdl00a2daj4p0wgl6037700pvw6p6zl91h1dr092r6pa"; + name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-19.1.4-1-any.pkg.tar.zst"; - sha256 = "11f4i4ai2bzvq6f06vxk1ymv7056c9707vdw489f1i2bdrf0c0ii"; - name = 
"mingw-w64-clang-x86_64-clang-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0j4a642fpnvqs79chhinc8r5q53q1wllmc1bzb01a4y7w9rqg4hw"; + name = "mingw-w64-clang-x86_64-clang-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.83.0-3-any.pkg.tar.zst"; - sha256 = "0nxs571vb4f1i5vp91134p5blns9ml2r25nx6kdlg0zhd5x85kvm"; - name = "mingw-w64-clang-x86_64-rust-1.83.0-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.84.0-1-any.pkg.tar.zst"; + sha256 = "0nrz9788grl50nkbhxswry143rrwpdnc6pk6f0k30kcp19qq6y2d"; + name = "mingw-w64-clang-x86_64-rust-1.84.0-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -127,9 +127,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.34.3-1-any.pkg.tar.zst"; - sha256 = "1mpn397qsdz3l2fav6ymwjlj96ialn9m8sldii3ymbcyhranl3xx"; - name = "mingw-w64-clang-x86_64-c-ares-1.34.3-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.34.4-1-any.pkg.tar.zst"; + sha256 = "1dppwwx3wrn0lzrlk2q7bpsainbidrpw1ndp1aasyv42xhxl1sn1"; + name = "mingw-w64-clang-x86_64-c-ares-1.34.4-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -139,9 +139,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunistring-1.2-1-any.pkg.tar.zst"; - sha256 = "13nz49li39z1zgfx1q9jg4vrmyrmqb6qdq0nqshidaqc6zr16k3g"; - name = "mingw-w64-clang-x86_64-libunistring-1.2-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunistring-1.3-1-any.pkg.tar.zst"; + sha256 = "1zg58qbfybyqzcj0dalb13l48f9jsras318h02rka65r7wi0pdcg"; + name = "mingw-w64-clang-x86_64-libunistring-1.3-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -169,9 +169,9 @@ }) (pkgs.fetchurl { - url = 
"https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ca-certificates-20240203-1-any.pkg.tar.zst"; - sha256 = "1q5nxhsk04gidz66ai5wgd4dr04lfyakkfja9p0r5hrgg4ppqqjg"; - name = "mingw-w64-clang-x86_64-ca-certificates-20240203-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ca-certificates-20241223-1-any.pkg.tar.zst"; + sha256 = "0c36lg63imzw8i6j1ard42v5wgzpc83phzk8lvifvm0djndq2bbj"; + name = "mingw-w64-clang-x86_64-ca-certificates-20241223-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -193,9 +193,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.6.0-1-any.pkg.tar.zst"; - sha256 = "1p7q47fin12vzyf126v1azbbpgpa0y6ighfh6mbfdb6zcyq74kbd"; - name = "mingw-w64-clang-x86_64-nghttp3-1.6.0-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.7.0-1-any.pkg.tar.zst"; + sha256 = "0kd2f7yh90815kyldxvdy8c6jyxyw0wv4f7k3shwp98w874m0mxd"; + name = "mingw-w64-clang-x86_64-nghttp3-1.7.0-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -271,15 +271,15 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst"; - sha256 = "1ysbxirpfr0yf7pvyps75lnwc897w2a2kcid3nb4j6ilw6n64jmc"; - name = "mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.5-1-any.pkg.tar.zst"; + sha256 = "0gdn1351knjwgsqgyaa3l55qs135k7dn6mlf04vzjxlc1895wx5z"; + name = "mingw-w64-clang-x86_64-rhash-1.4.5-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.31.2-3-any.pkg.tar.zst"; - sha256 = "139f91r392c68hsajm0c81690pmzkywb0p4x8ms8ms53ncxnz6gz"; - name = "mingw-w64-clang-x86_64-cmake-3.31.2-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.31.4-1-any.pkg.tar.zst"; + sha256 = 
"1xjjwgkqf2j97pcx0yd6j0lgmzgbgqjjf0s7j29mc03g89fhdhw0"; + name = "mingw-w64-clang-x86_64-cmake-3.31.4-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -289,9 +289,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ncurses-6.5.20240831-1-any.pkg.tar.zst"; - sha256 = "1hlfj9g4s767s502sawwbcv4a0xd3ym3ip4jswmhq48wh5050iyb"; - name = "mingw-w64-clang-x86_64-ncurses-6.5.20240831-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ncurses-6.5.20241228-3-any.pkg.tar.zst"; + sha256 = "0f98pzrwsxil90n55hz2ym2x2rzrrjrmnj8i2203n189qbxbg2c9"; + name = "mingw-w64-clang-x86_64-ncurses-6.5.20241228-3-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -331,32 +331,32 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.12.7-3-any.pkg.tar.zst"; - sha256 = "1v15j2pzy9wj4n1rjngdi2hf8h0l9z4lri3xb86yvdv1xl2msj6h"; - name = "mingw-w64-clang-x86_64-python-3.12.7-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.12.8-2-any.pkg.tar.zst"; + sha256 = "0lksgrmylvpr7yyjcc1szm30pnag7ixrj7vhdql1ryi4k9309v8s"; + name = "mingw-w64-clang-x86_64-python-3.12.8-2-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-openmp-19.1.4-2-any.pkg.tar.zst"; - sha256 = "1pn1fbj74rx837s9z8gqs4b0cr7kqi5m1m2mi9ibjpw64m1aqwxv"; - name = "mingw-w64-clang-x86_64-llvm-openmp-19.1.4-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-openmp-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0d3mm26hnw716n0ppzqhydxcgm4im081hiiy6l4zp267ad3kfg93"; + name = "mingw-w64-clang-x86_64-llvm-openmp-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.28-2-any.pkg.tar.zst"; - sha256 = "18p1zhf7h3k3phf3bl483jg3k7y9zq375z6ww75g62158ic9lfyc"; - name = 
"mingw-w64-clang-x86_64-openblas-0.3.28-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.29-1-any.pkg.tar.zst"; + sha256 = "006f2s12jmk35rppkp20rlm7k4kknsnh5h4krqs2ry2rd6qqkk9h"; + name = "mingw-w64-clang-x86_64-openblas-0.3.29-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-2.1.1-2-any.pkg.tar.zst"; - sha256 = "1kiy7ail04ias47xbbhl9vpsz02g0g3f29ncgx5gcks9vgqldp6m"; - name = "mingw-w64-clang-x86_64-python-numpy-2.1.1-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-2.2.1-1-any.pkg.tar.zst"; + sha256 = "0sgkhax9cwmkkrfrir45l91h6pgg339gaw6147gsayf8h8ag4brg"; + name = "mingw-w64-clang-x86_64-python-numpy-2.2.1-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-75.6.0-1-any.pkg.tar.zst"; - sha256 = "03l04kjmy5p9whaw0h619gdg7yw1gxbz8phifq4pzh3c1wlw7yfd"; - name = "mingw-w64-clang-x86_64-python-setuptools-75.6.0-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-75.8.0-1-any.pkg.tar.zst"; + sha256 = "12ivpaj967y4bi8396q3fpii4fy5aakidxpv16rkyg1b831k0h93"; + name = "mingw-w64-clang-x86_64-python-setuptools-75.8.0-1-any.pkg.tar.zst"; }) ]