forked from M-Labs/nac3
1
0
Fork 0

Compare commits

..

6 Commits

Author SHA1 Message Date
David Mak 3a64b0cf07 core/toplevel/type_annotation: Add handling for mismatching class def
Primitive types only contain fields in its Type and not its TopLevelDef.
This causes primitive object types to lack some fields.
2024-07-19 13:41:42 +08:00
David Mak 2cdb057d20 artiq/symbol_resolver: Handle type of zero-length lists 2024-07-19 13:41:42 +08:00
David Mak 8f95c707d7 artiq/symbol_resolver: Determine global array type by init-val type 2024-07-19 13:41:42 +08:00
David Mak 9b5fb69875 core/codegen/stmt: Convert assertion values to i1 2024-07-19 13:41:42 +08:00
David Mak 6c1d8ac001 core: Add compile-time feature to disable escape analysis 2024-07-19 13:41:41 +08:00
David Mak 6d6f8be07f meta: Update dependencies 2024-07-19 13:41:41 +08:00
50 changed files with 1025 additions and 4340 deletions

1
.gitignore vendored
View File

@ -1,4 +1,3 @@
__pycache__ __pycache__
/target /target
/nac3standalone/demo/linalg/target
nix/windows/msys2 nix/windows/msys2

147
Cargo.lock generated
View File

@ -26,9 +26,9 @@ dependencies = [
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.15" version = "0.6.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"anstyle-parse", "anstyle-parse",
@ -41,36 +41,36 @@ dependencies = [
[[package]] [[package]]
name = "anstyle" name = "anstyle"
version = "1.0.8" version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b"
[[package]] [[package]]
name = "anstyle-parse" name = "anstyle-parse"
version = "0.2.5" version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4"
dependencies = [ dependencies = [
"utf8parse", "utf8parse",
] ]
[[package]] [[package]]
name = "anstyle-query" name = "anstyle-query"
version = "1.1.1" version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391"
dependencies = [ dependencies = [
"windows-sys 0.52.0", "windows-sys",
] ]
[[package]] [[package]]
name = "anstyle-wincon" name = "anstyle-wincon"
version = "3.0.4" version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"windows-sys 0.52.0", "windows-sys",
] ]
[[package]] [[package]]
@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.7" version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" checksum = "324c74f2155653c90b04f25b2a47a8a631360cb908f92a772695f430c7e31052"
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
@ -129,9 +129,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.13" version = "4.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fbb260a053428790f3de475e304ff84cdbc4face759ea7a3e64c1edd938a7fc" checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@ -139,9 +139,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.13" version = "4.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64b17d7ea74e9f833c7dbf2cbe4fb12ff26783eda4782a8975b72f895c9b4d99" checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@ -151,27 +151,27 @@ dependencies = [
[[package]] [[package]]
name = "clap_derive" name = "clap_derive"
version = "4.5.13" version = "4.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" checksum = "2bac35c6dafb060fd4d275d9a4ffae97917c13a6327903a8be2153cd964f7085"
dependencies = [ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.72", "syn 2.0.71",
] ]
[[package]] [[package]]
name = "clap_lex" name = "clap_lex"
version = "0.7.2" version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70"
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.2" version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
[[package]] [[package]]
name = "console" name = "console"
@ -182,7 +182,7 @@ dependencies = [
"encode_unicode", "encode_unicode",
"lazy_static", "lazy_static",
"libc", "libc",
"windows-sys 0.52.0", "windows-sys",
] ]
[[package]] [[package]]
@ -302,7 +302,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.52.0", "windows-sys",
] ]
[[package]] [[package]]
@ -385,9 +385,9 @@ dependencies = [
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "2.3.0" version = "2.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0" checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown 0.14.5", "hashbrown 0.14.5",
@ -421,7 +421,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.72", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -440,9 +440,9 @@ dependencies = [
[[package]] [[package]]
name = "is_terminal_polyfill" name = "is_terminal_polyfill"
version = "1.70.1" version = "1.70.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800"
[[package]] [[package]]
name = "itertools" name = "itertools"
@ -513,9 +513,9 @@ checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]] [[package]]
name = "libloading" name = "libloading"
version = "0.8.5" version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"windows-targets", "windows-targets",
@ -616,7 +616,7 @@ name = "nac3core"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"crossbeam", "crossbeam",
"indexmap 2.3.0", "indexmap 2.2.6",
"indoc", "indoc",
"inkwell", "inkwell",
"insta", "insta",
@ -706,7 +706,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [ dependencies = [
"fixedbitset", "fixedbitset",
"indexmap 2.3.0", "indexmap 2.2.6",
] ]
[[package]] [[package]]
@ -749,7 +749,7 @@ dependencies = [
"phf_shared 0.11.2", "phf_shared 0.11.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.72", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -778,18 +778,15 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315"
[[package]] [[package]]
name = "portable-atomic" name = "portable-atomic"
version = "1.7.0" version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.20" version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
dependencies = [
"zerocopy",
]
[[package]] [[package]]
name = "precomputed-hash" name = "precomputed-hash"
@ -853,7 +850,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.72", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -866,7 +863,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.72", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -950,9 +947,9 @@ dependencies = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.10.6" version = "1.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
@ -994,7 +991,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
"windows-sys 0.52.0", "windows-sys",
] ]
[[package]] [[package]]
@ -1047,17 +1044,16 @@ checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.72", "syn 2.0.71",
] ]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.122" version = "1.0.120"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784b6203951c57ff748476b126ccb5e8e2959a5c19e5c617ab1956be3dbc68da" checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5"
dependencies = [ dependencies = [
"itoa", "itoa",
"memchr",
"ryu", "ryu",
"serde", "serde",
] ]
@ -1076,9 +1072,9 @@ dependencies = [
[[package]] [[package]]
name = "similar" name = "similar"
version = "2.6.0" version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" checksum = "fa42c91313f1d05da9b26f267f931cf178d4aba455b4c4622dd7355eb80c6640"
[[package]] [[package]]
name = "siphasher" name = "siphasher"
@ -1138,7 +1134,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.72", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -1154,9 +1150,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.72" version = "2.0.71"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" checksum = "b146dcf730474b4bcd16c311627b31ede9ab149045db4d6088b3becaea046462"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1165,21 +1161,20 @@ dependencies = [
[[package]] [[package]]
name = "target-lexicon" name = "target-lexicon"
version = "0.12.16" version = "0.12.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2"
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.11.0" version = "3.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8fcd239983515c23a32fb82099f97d0b11b8c72f654ed659363a95c3dad7a53" checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"fastrand", "fastrand",
"once_cell",
"rustix", "rustix",
"windows-sys 0.52.0", "windows-sys",
] ]
[[package]] [[package]]
@ -1223,7 +1218,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.72", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -1341,9 +1336,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.5" version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]] [[package]]
name = "walkdir" name = "walkdir"
@ -1379,11 +1374,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]] [[package]]
name = "winapi-util" name = "winapi-util"
version = "0.1.9" version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
dependencies = [ dependencies = [
"windows-sys 0.59.0", "windows-sys",
] ]
[[package]] [[package]]
@ -1401,15 +1396,6 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets",
]
[[package]] [[package]]
name = "windows-targets" name = "windows-targets"
version = "0.52.6" version = "0.52.6"
@ -1489,7 +1475,6 @@ version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
dependencies = [ dependencies = [
"byteorder",
"zerocopy-derive", "zerocopy-derive",
] ]
@ -1501,5 +1486,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.72", "syn 2.0.71",
] ]

View File

@ -2,11 +2,11 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1721924956, "lastModified": 1720418205,
"narHash": "sha256-Sb1jlyRO+N8jBXEX9Pg9Z1Qb8Bw9QyOgLDNMEpmjZ2M=", "narHash": "sha256-cPJoFPXU44GlhWg4pUk9oUPqurPlCFZ11ZQPk21GTPU=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "5ad6a14c6bf098e98800b091668718c336effc95", "rev": "655a58a72a6601292512670343087c2d75d859c1",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -6,7 +6,6 @@
outputs = { self, nixpkgs }: outputs = { self, nixpkgs }:
let let
pkgs = import nixpkgs { system = "x86_64-linux"; }; pkgs = import nixpkgs { system = "x86_64-linux"; };
pkgs32 = import nixpkgs { system = "i686-linux"; };
in rec { in rec {
packages.x86_64-linux = rec { packages.x86_64-linux = rec {
llvm-nac3 = pkgs.callPackage ./nix/llvm {}; llvm-nac3 = pkgs.callPackage ./nix/llvm {};
@ -16,22 +15,6 @@
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
''; '';
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
name = "demo-linalg-stub";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
demo-linalg-stub32 = pkgs32.rustPlatform.buildRustPackage {
name = "demo-linalg-stub32";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
nac3artiq = pkgs.python3Packages.toPythonModule ( nac3artiq = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage rec { pkgs.rustPlatform.buildRustPackage rec {
name = "nac3artiq"; name = "nac3artiq";
@ -41,7 +24,7 @@
lockFile = ./Cargo.lock; lockFile = ./Cargo.lock;
}; };
passthru.cargoLock = cargoLock; passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ]; nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ]; checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase = checkPhase =
@ -49,9 +32,7 @@
echo "Checking nac3standalone demos..." echo "Checking nac3standalone demos..."
pushd nac3standalone/demo pushd nac3standalone/demo
patchShebangs . patchShebangs .
export DEMO_LINALG_STUB=${demo-linalg-stub}/lib/liblinalg.a ./check_demos.sh
export DEMO_LINALG_STUB32=${demo-linalg-stub32}/lib/liblinalg.a
./check_demos.sh -i686
popd popd
echo "Running Cargo tests..." echo "Running Cargo tests..."
cargoCheckHook cargoCheckHook
@ -168,7 +149,7 @@
buildInputs = with pkgs; [ buildInputs = with pkgs; [
# build dependencies # build dependencies
packages.x86_64-linux.llvm-nac3 packages.x86_64-linux.llvm-nac3
(pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos
packages.x86_64-linux.llvm-tools-irrt packages.x86_64-linux.llvm-tools-irrt
cargo cargo
rustc rustc
@ -181,11 +162,6 @@
pre-commit pre-commit
rustfmt rustfmt
]; ];
shellHook =
''
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
'';
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -24,3 +24,4 @@ features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-l
[features] [features]
init-llvm-profile = [] init-llvm-profile = []
no-escape-analysis = ["nac3core/no-escape-analysis"]

View File

@ -386,7 +386,7 @@ fn gen_rpc_tag(
} else { } else {
let ty_enum = ctx.unifier.get_ty(ty); let ty_enum = ctx.unifier.get_ty(ty);
match &*ty_enum { match &*ty_enum {
TTuple { ty, is_vararg_ctx: false } => { TTuple { ty } => {
buffer.push(b't'); buffer.push(b't');
buffer.push(ty.len() as u8); buffer.push(ty.len() as u8);
for ty in ty { for ty in ty {
@ -700,7 +700,6 @@ pub fn attributes_writeback(
name: i.to_string().into(), name: i.to_string().into(),
ty: *ty, ty: *ty,
default_value: None, default_value: None,
is_vararg: false,
}) })
.collect(), .collect(),
ret: ctx.primitives.none, ret: ctx.primitives.none,

View File

@ -24,7 +24,6 @@ use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use inkwell::{ use inkwell::{
context::Context,
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer,
module::{Linkage, Module}, module::{Linkage, Module},
passes::PassBuilderOptions, passes::PassBuilderOptions,
@ -265,7 +264,7 @@ impl Nac3 {
arg_names.len(), arg_names.len(),
)); ));
} }
for (i, FuncArg { ty, default_value, name, .. }) in args.iter().enumerate() { for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() {
let in_name = match arg_names.get(i) { let in_name = match arg_names.get(i) {
Some(n) => n, Some(n) => n,
None if default_value.is_none() => { None if default_value.is_none() => {
@ -626,9 +625,7 @@ impl Nac3 {
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let size_t = Context::create() let size_t = if self.isa == Isa::Host { 64 } else { 32 };
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width();
let num_threads = if is_multithreaded() { 4 } else { 1 }; let num_threads = if is_multithreaded() { 4 } else { 1 };
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect(); let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names let threads: Vec<_> = thread_names
@ -647,9 +644,6 @@ impl Nac3 {
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = inkwell::context::Context::create(); let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback"); let module = context.create_module("attributes_writeback");
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
let builder = context.create_builder(); let builder = context.create_builder();
let (_, module, _) = gen_func_impl( let (_, module, _) = gen_func_impl(
&context, &context,
@ -869,7 +863,6 @@ impl Nac3 {
name: "t".into(), name: "t".into(),
ty: primitive.int64, ty: primitive.int64,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: primitive.none, ret: primitive.none,
vars: VarMap::new(), vars: VarMap::new(),
@ -889,7 +882,6 @@ impl Nac3 {
name: "dt".into(), name: "dt".into(),
ty: primitive.int64, ty: primitive.int64,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: primitive.none, ret: primitive.none,
vars: VarMap::new(), vars: VarMap::new(),

View File

@ -351,7 +351,7 @@ impl InnerResolver {
Ok(Ok((ndarray, false))) Ok(Ok((ndarray, false)))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false }), false))) Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
} else if ty_id == self.primitive_ids.option { } else if ty_id == self.primitive_ids.option {
Ok(Ok((primitives.option, false))) Ok(Ok((primitives.option, false)))
} else if ty_id == self.primitive_ids.none { } else if ty_id == self.primitive_ids.none {
@ -555,10 +555,7 @@ impl InnerResolver {
Err(err) => return Ok(Err(err)), Err(err) => return Ok(Err(err)),
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string())) _ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
}; };
Ok(Ok(( Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true)))
unifier.add_ty(TypeEnum::TTuple { ty: args, is_vararg_ctx: false }),
true,
)))
} }
TypeEnum::TObj { params, obj_id, .. } => { TypeEnum::TObj { params, obj_id, .. } => {
let subst = { let subst = {
@ -800,9 +797,7 @@ impl InnerResolver {
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives)) .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))
.collect(); .collect();
let types = types?; let types = types?;
Ok(types.map(|types| { Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
unifier.add_ty(TypeEnum::TTuple { ty: types, is_vararg_ctx: false })
}))
} }
// special handling for option type since its class member layout in python side // special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below // is special and cannot be mapped directly to a nac3 type as below
@ -1085,6 +1080,8 @@ impl InnerResolver {
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
// TODO: Special handling required for strings, since there are two representations:
// struct %str and [n x i8].
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty); let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
@ -1152,31 +1149,44 @@ impl InnerResolver {
}) })
}) })
.collect(); .collect();
let data = data?.unwrap().into_iter(); let data = data?.unwrap();
let data = match ndarray_dtype_llvm_ty {
BasicTypeEnum::ArrayType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
}
BasicTypeEnum::FloatType(ty) => { let make_llvm_array =
ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec()) |llvm_ty: BasicTypeEnum<'ctx>, elems: Vec<BasicValueEnum<'ctx>>| {
} debug_assert!(elems.iter().all(|elem| elem.get_type() == llvm_ty));
BasicTypeEnum::IntType(ty) => { match llvm_ty {
ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec()) BasicTypeEnum::ArrayType(ty) => ty.const_array(
} &elems.into_iter().map(BasicValueEnum::into_array_value).collect_vec(),
),
BasicTypeEnum::PointerType(ty) => { BasicTypeEnum::FloatType(ty) => ty.const_array(
ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec()) &elems.into_iter().map(BasicValueEnum::into_float_value).collect_vec(),
} ),
BasicTypeEnum::StructType(ty) => { BasicTypeEnum::IntType(ty) => ty.const_array(
ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec()) &elems.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
} ),
BasicTypeEnum::PointerType(ty) => ty.const_array(
&elems
.into_iter()
.map(BasicValueEnum::into_pointer_value)
.collect_vec(),
),
BasicTypeEnum::StructType(ty) => ty.const_array(
&elems.into_iter().map(BasicValueEnum::into_struct_value).collect_vec(),
),
BasicTypeEnum::VectorType(_) => unreachable!(), BasicTypeEnum::VectorType(_) => unreachable!(),
}
}; };
let ndarray_dtype_llvm_ty =
if data.is_empty() { ndarray_dtype_llvm_ty } else { data[0].get_type() };
let data = make_llvm_array(ndarray_dtype_llvm_ty, data);
// create a global for ndarray.data and initialize it using the elements // create a global for ndarray.data and initialize it using the elements
let data_global = ctx.module.add_global( let data_global = ctx.module.add_global(
ndarray_dtype_llvm_ty.array_type(sz as u32), ndarray_dtype_llvm_ty.array_type(sz as u32),
@ -1208,9 +1218,7 @@ impl InnerResolver {
Ok(Some(ndarray.as_pointer_value().into())) Ok(Some(ndarray.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };
unreachable!()
};
let tup_tys = ty.iter(); let tup_tys = ty.iter();
let elements: &PyTuple = obj.downcast()?; let elements: &PyTuple = obj.downcast()?;

View File

@ -4,6 +4,9 @@ version = "0.1.0"
authors = ["M-Labs"] authors = ["M-Labs"]
edition = "2021" edition = "2021"
[features]
no-escape-analysis = []
[dependencies] [dependencies]
itertools = "0.13" itertools = "0.13"
crossbeam = "0.8" crossbeam = "0.8"

View File

@ -1,11 +1,9 @@
use inkwell::types::BasicTypeEnum; use inkwell::types::BasicTypeEnum;
use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; use inkwell::values::BasicValueEnum;
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::classes::{ use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
@ -33,6 +31,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n; let (n_ty, n) = n;
Ok(match n { Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
@ -603,7 +602,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
ret_elem_ty, ret_elem_ty,
None, None,
NDArrayValue::from_ptr_val(n, llvm_usize, None), NDArrayValue::from_ptr_val(n, llvm_usize, None),
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
)?; )?;
ndarray.as_base_value().into() ndarray.as_base_value().into()
@ -864,7 +863,6 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_int64.const_int(1, false), llvm_int64.const_int(1, false),
(n_sz, false), (n_sz, false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
@ -1837,501 +1835,3 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
}) })
} }
/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it
fn build_output_struct<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
out_matrices: Vec<BasicValueEnum<'ctx>>,
) -> PointerValue<'ctx> {
let field_ty =
out_matrices.iter().map(BasicValueEnum::get_type).collect::<Vec<BasicTypeEnum>>();
let out_ty = ctx.ctx.struct_type(&field_ty, false);
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
for (i, v) in out_matrices.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
out_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
out_ptr
}
/// Invokes the `np_linalg_cholesky` linalg function
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_cholesky";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_qr` linalg function
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_svd` linalg function
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]);
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_inv` linalg function
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_inv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_pinv` linalg function
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_lu` linalg function
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_matrix_power` linalg function
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
// Changing second parameter to a `NDArray` for uniformity in function call
let n2_array = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
unsafe {
n2_array.data().set_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
n2.as_basic_value_enum(),
);
};
let n2_array = n2_array.as_base_value().as_basic_value_enum();
let outdim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
/// Invokes the `np_linalg_det` linalg function
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(_) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
// Changing second parameter to a `NDArray` for uniformity in function call
let out = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
let res =
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
Ok(res)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_schur` linalg function
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_schur";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_hessenberg` linalg function
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_hessenberg";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
Ok(ctx
.builder
.build_load(out_ptr, "Hessenberg_decomposition_result")
.map(Into::into)
.unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}

View File

@ -1717,7 +1717,6 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(len, false), (len, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {

View File

@ -25,7 +25,6 @@ pub struct ConcreteFuncArg {
pub name: StrRef, pub name: StrRef,
pub ty: ConcreteType, pub ty: ConcreteType,
pub default_value: Option<SymbolValue>, pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -47,7 +46,6 @@ pub enum ConcreteTypeEnum {
TPrimitive(Primitive), TPrimitive(Primitive),
TTuple { TTuple {
ty: Vec<ConcreteType>, ty: Vec<ConcreteType>,
is_vararg_ctx: bool,
}, },
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
@ -104,16 +102,8 @@ impl ConcreteTypeStore {
.iter() .iter()
.map(|arg| ConcreteFuncArg { .map(|arg| ConcreteFuncArg {
name: arg.name, name: arg.name,
ty: if arg.is_vararg { ty: self.from_unifier_type(unifier, primitives, arg.ty, cache),
let tuple_ty = unifier
.add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true });
self.from_unifier_type(unifier, primitives, tuple_ty, cache)
} else {
self.from_unifier_type(unifier, primitives, arg.ty, cache)
},
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
}) })
.collect(), .collect(),
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache), ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
@ -168,12 +158,11 @@ impl ConcreteTypeStore {
cache.insert(ty, None); cache.insert(ty, None);
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple { TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple {
ty: ty ty: ty
.iter() .iter()
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache)) .map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(), .collect(),
is_vararg_ctx: *is_vararg_ctx,
}, },
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id, obj_id: *obj_id,
@ -259,12 +248,11 @@ impl ConcreteTypeStore {
*cache.get_mut(&cty).unwrap() = Some(ty); *cache.get_mut(&cty).unwrap() = Some(ty);
return ty; return ty;
} }
ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple { ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple {
ty: ty ty: ty
.iter() .iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache)) .map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(), .collect(),
is_vararg_ctx: *is_vararg_ctx,
}, },
ConcreteTypeEnum::TVirtual { ty } => { ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
@ -289,7 +277,6 @@ impl ConcreteTypeStore {
name: arg.name, name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache), ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: false,
}) })
.collect(), .collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache), ret: self.to_unifier_type(unifier, primitives, *ret, cache),

View File

@ -1,3 +1,5 @@
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use crate::{ use crate::{
codegen::{ codegen::{
classes::{ classes::{
@ -5,7 +7,7 @@ use crate::{
ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
}, },
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name, gen_in_range_check, get_llvm_abi_type, get_llvm_type,
irrt::*, irrt::*,
llvm_intrinsics::{ llvm_intrinsics::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
@ -40,8 +42,6 @@ use nac3parser::ast::{
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
Unaryop, Unaryop,
}; };
use std::iter::{repeat, repeat_with};
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
pub fn get_subst_key( pub fn get_subst_key(
unifier: &mut Unifier, unifier: &mut Unifier,
@ -201,7 +201,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
/// See [`get_llvm_type`]. /// See [`get_llvm_type`].
pub fn get_llvm_type<G: CodeGenerator + ?Sized>( pub fn get_llvm_type<G: CodeGenerator + ?Sized>(
&mut self, &mut self,
generator: &G, generator: &mut G,
ty: Type, ty: Type,
) -> BasicTypeEnum<'ctx> { ) -> BasicTypeEnum<'ctx> {
get_llvm_type( get_llvm_type(
@ -218,7 +218,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
/// See [`get_llvm_abi_type`]. /// See [`get_llvm_abi_type`].
pub fn get_llvm_abi_type<G: CodeGenerator + ?Sized>( pub fn get_llvm_abi_type<G: CodeGenerator + ?Sized>(
&mut self, &mut self,
generator: &G, generator: &mut G,
ty: Type, ty: Type,
) -> BasicTypeEnum<'ctx> { ) -> BasicTypeEnum<'ctx> {
get_llvm_abi_type( get_llvm_abi_type(
@ -267,16 +267,13 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
Constant::Tuple(v) => { Constant::Tuple(v) => {
let ty = self.unifier.get_ty(ty); let ty = self.unifier.get_ty(ty);
let (types, is_vararg_ctx) = if let TypeEnum::TTuple { ty, is_vararg_ctx } = &*ty { let types =
(ty.clone(), *is_vararg_ctx) if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() };
} else {
unreachable!()
};
let values = zip(types, v.iter()) let values = zip(types, v.iter())
.map_while(|(ty, v)| self.gen_const(generator, v, ty)) .map_while(|(ty, v)| self.gen_const(generator, v, ty))
.collect_vec(); .collect_vec();
if is_vararg_ctx || values.len() == v.len() { if values.len() == v.len() {
let types = values.iter().map(BasicValueEnum::get_type).collect_vec(); let types = values.iter().map(BasicValueEnum::get_type).collect_vec();
let ty = self.ctx.struct_type(&types, false); let ty = self.ctx.struct_type(&types, false);
Some(ty.const_named_struct(&values).into()) Some(ty.const_named_struct(&values).into())
@ -517,19 +514,16 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
} }
} }
let params = if loc_params.is_empty() { params } else { &loc_params }; let params = if loc_params.is_empty() { params } else { &loc_params };
let params = fun let params = fun
.get_type() .get_type()
.get_param_types() .get_param_types()
.into_iter() .into_iter()
.map(Some)
.chain(repeat(None))
.zip(params.iter()) .zip(params.iter())
.map(|(ty, val)| match (ty, val.get_type()) { .map(|(ty, val)| match (ty, val.get_type()) {
(Some(BasicTypeEnum::PointerType(arg_ty)), BasicTypeEnum::PointerType(val_ty)) (BasicTypeEnum::PointerType(arg_ty), BasicTypeEnum::PointerType(val_ty))
if { if {
ty.unwrap() != val.get_type() ty != val.get_type()
&& arg_ty.get_element_type().is_struct_type() && arg_ty.get_element_type().is_struct_type()
&& val_ty.get_element_type().is_struct_type() && val_ty.get_element_type().is_struct_type()
} => } =>
@ -539,7 +533,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
_ => *val, _ => *val,
}) })
.collect_vec(); .collect_vec();
let result = if let Some(target) = self.unwind_target { let result = if let Some(target) = self.unwind_target {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_block = self.ctx.append_basic_block(current, &format!("after.{call_name}")); let then_block = self.ctx.append_basic_block(current, &format!("after.{call_name}"));
@ -559,7 +552,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.map(Either::left) .map(Either::left)
.unwrap() .unwrap()
}; };
if let Some(slot) = return_slot { if let Some(slot) = return_slot {
Some(self.builder.build_load(slot, call_name).unwrap()) Some(self.builder.build_load(slot, call_name).unwrap())
} else { } else {
@ -734,41 +726,13 @@ pub fn gen_func_instance<'ctx>(
.collect(); .collect();
let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache); let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache);
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
if let Some(obj) = &obj { if let Some(obj) = &obj {
let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache); let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache);
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
args.insert( args.insert(0, ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None });
0,
ConcreteFuncArg {
name: "self".into(),
ty: zelf,
default_value: None,
is_vararg: false,
},
);
} }
if let Some(vararg_arg) = sign.args.iter().find(|arg| arg.is_vararg) {
let va_count_arg = get_va_count_arg_name(vararg_arg.name);
args.insert(
args.len() - 1,
ConcreteFuncArg {
name: va_count_arg,
ty: store.from_unifier_type(
&mut ctx.unifier,
&ctx.primitives,
ctx.primitives.usize(),
&mut cache,
),
default_value: None,
is_vararg: false,
},
);
}
let signature = store.add_cty(signature); let signature = store.add_cty(signature);
ctx.registry.add_task(CodeGenTask { ctx.registry.add_task(CodeGenTask {
@ -793,17 +757,11 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<Option<BasicValueEnum<'ctx>>, String> { ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap();
let id; let id;
let key; let key;
let param_vals; let param_vals;
let is_extern; let is_extern;
let vararg_arg;
// Ensure that the function object only contains up to 1 vararg parameter
debug_assert!(fun.0.args.iter().filter(|arg| arg.is_vararg).count() <= 1);
let symbol = { let symbol = {
// make sure this lock guard is dropped at the end of this scope... // make sure this lock guard is dropped at the end of this scope...
@ -819,72 +777,22 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
return callback.run(ctx, obj, fun, params, generator); return callback.run(ctx, obj, fun, params, generator);
} }
is_extern = instance_to_stmt.is_empty(); is_extern = instance_to_stmt.is_empty();
vararg_arg = fun.0.args.iter().find(|arg| arg.is_vararg);
let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None); let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None);
let mut keys = fun.0.args.clone(); let mut keys = fun.0.args.clone();
let mut mapping = HashMap::<_, Vec<ValueEnum>>::new(); let mut mapping = HashMap::new();
for (key, value) in params { for (key, value) in params {
// Find the matching argument mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
let matching_param = fun
.0
.args
.iter()
.find_or_last(|p| key.is_some_and(|k| k == p.name))
.unwrap();
if matching_param.is_vararg {
if key.is_none() && !keys.is_empty() {
keys.remove(0);
} }
// vararg is lowered into two arguments - va_count and `...`
// Handle va_count first, for each argument encountered we increment it by 1
let va_count = get_va_count_arg_name(matching_param.name);
if let Some(params) = mapping.get_mut(&va_count) {
debug_assert_eq!(params.len(), 1);
let param = params[0]
.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.usize())?
.into_int_value();
params[0] = param.const_add(llvm_usize.const_int(1, false)).into();
} else {
mapping.insert(va_count, vec![llvm_usize.const_int(1, false).into()]);
}
if let Some(param) = mapping.get_mut(&matching_param.name) {
param.push(value);
} else {
mapping.insert(key.unwrap_or(matching_param.name), vec![value]);
}
} else {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), vec![value]);
}
}
// default value handling // default value handling
for k in keys { for k in keys {
if mapping.contains_key(&k.name) { if mapping.contains_key(&k.name) {
continue; continue;
} }
if k.is_vararg {
mapping.insert(
get_va_count_arg_name(k.name),
vec![llvm_usize.const_zero().into()],
);
mapping.insert(k.name, Vec::default());
} else {
mapping.insert( mapping.insert(
k.name, k.name,
vec![ctx ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into(),
.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty)
.into()],
); );
} }
}
// reorder the parameters // reorder the parameters
let mut real_params = fun let mut real_params = fun
.0 .0
@ -893,24 +801,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
.map(|arg| (mapping.remove(&arg.name).unwrap(), arg.ty)) .map(|arg| (mapping.remove(&arg.name).unwrap(), arg.ty))
.collect_vec(); .collect_vec();
if let Some(obj) = &obj { if let Some(obj) = &obj {
real_params.insert(0, (vec![obj.1.clone()], obj.0)); real_params.insert(0, (obj.1.clone(), obj.0));
} }
if let Some(vararg) = vararg_arg {
let vararg_arg_name = get_va_count_arg_name(vararg.name);
real_params.insert(
real_params.len() - 1,
(mapping[&vararg_arg_name].clone(), ctx.primitives.usize()),
);
}
let static_params = real_params let static_params = real_params
.iter() .iter()
.enumerate() .enumerate()
.filter_map(|(i, (v, _))| { .filter_map(|(i, (v, _))| {
if v.len() != 1 { if let ValueEnum::Static(s) = v {
None
} else if let ValueEnum::Static(s) = &v[0] {
Some((i, s.clone())) Some((i, s.clone()))
} else { } else {
None None
@ -940,13 +837,8 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
}; };
param_vals = real_params param_vals = real_params
.into_iter() .into_iter()
.map(|(ps, t)| { .map(|(p, t)| p.to_basic_value_enum(ctx, generator, t))
ps.into_iter().map(|p| p.to_basic_value_enum(ctx, generator, t)).collect() .collect::<Result<Vec<_>, String>>()?;
})
.collect::<Result<Vec<Vec<_>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
instance_to_symbol.get(&key).cloned().ok_or_else(String::new) instance_to_symbol.get(&key).cloned().ok_or_else(String::new)
} }
TopLevelDef::Class { .. } => { TopLevelDef::Class { .. } => {
@ -960,10 +852,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
let fun_val = ctx.module.get_function(&symbol).unwrap_or_else(|| { let fun_val = ctx.module.get_function(&symbol).unwrap_or_else(|| {
let mut args = fun.0.args.clone(); let mut args = fun.0.args.clone();
if let Some(obj) = &obj { if let Some(obj) = &obj {
args.insert( args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None });
0,
FuncArg { name: "self".into(), ty: obj.0, default_value: None, is_vararg: false },
);
} }
let ret_type = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { let ret_type = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
None None
@ -975,7 +864,6 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
let mut params = args let mut params = args
.iter() .iter()
.enumerate() .enumerate()
.filter(|(_, arg)| !arg.is_vararg)
.map(|(i, arg)| { .map(|(i, arg)| {
match ctx.get_llvm_abi_type(generator, arg.ty) { match ctx.get_llvm_abi_type(generator, arg.ty) {
BasicTypeEnum::StructType(ty) if is_extern => { BasicTypeEnum::StructType(ty) if is_extern => {
@ -990,13 +878,9 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
if has_sret { if has_sret {
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into()); params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
} }
let is_vararg = args.iter().any(|arg| arg.is_vararg);
if is_vararg {
params.push(generator.get_size_type(ctx.ctx).into());
}
let fun_ty = match ret_type { let fun_ty = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, is_vararg), Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => ctx.ctx.void_type().fn_type(&params, is_vararg), _ => ctx.ctx.void_type().fn_type(&params, false),
}; };
let fun_val = ctx.module.add_function(&symbol, fun_ty, None); let fun_val = ctx.module.add_function(&symbol, fun_ty, None);
let offset = if has_sret { let offset = if has_sret {
@ -1028,16 +912,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
}); });
// Convert boolean parameter values into i1 // Convert boolean parameter values into i1
let vararg_ty = vararg_arg.map(|vararg| ctx.get_llvm_abi_type(generator, vararg.ty));
let param_vals = fun_val let param_vals = fun_val
.get_params() .get_params()
.iter() .iter()
.map(BasicValueEnum::get_type)
.chain(repeat_with(|| vararg_ty.unwrap()))
.zip(param_vals) .zip(param_vals)
.map(|(p, v)| { .map(|(p, v)| {
if p.is_int_type() && v.is_int_value() { if p.is_int_value() && v.is_int_value() {
let expected_ty = p.into_int_type(); let expected_ty = p.into_int_value().get_type();
let param_val = v.into_int_value(); let param_val = v.into_int_value();
if expected_ty.get_bit_width() == 1 && param_val.get_type().get_bit_width() != 1 { if expected_ty.get_bit_width() == 1 && param_val.get_type().get_bit_width() != 1 {
@ -1114,10 +995,8 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
ctx.builder.position_at_end(init_bb); ctx.builder.position_at_end(init_bb);
let Comprehension { target, iter, ifs, .. } = &generators[0]; let Comprehension { target, iter, ifs, .. } = &generators[0];
let iter_ty = iter.custom.unwrap();
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
v.to_basic_value_enum(ctx, generator, iter_ty)? v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())?
} else { } else {
for bb in [test_bb, body_bb, cont_bb] { for bb in [test_bb, body_bb, cont_bb] {
ctx.builder.position_at_end(bb); ctx.builder.position_at_end(bb);
@ -1135,12 +1014,11 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
ctx.builder.build_store(index, zero_size_t).unwrap(); ctx.builder.build_store(index, zero_size_t).unwrap();
let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap()); let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap());
let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
let list; let list;
let list_content;
match &*ctx.unifier.get_ty(iter_ty) { if is_range {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
{
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
let (start, stop, step) = destructure_range(ctx, iter_val); let (start, stop, step) = destructure_range(ctx, iter_val);
let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap();
@ -1148,8 +1026,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
// the length may be 1 more than the actual length if the division is exact, but the // the length may be 1 more than the actual length if the division is exact, but the
// length is a upper bound only anyway so it does not matter. // length is a upper bound only anyway so it does not matter.
let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap(); let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap();
let length = let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap();
ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap();
// in case length is non-positive // in case length is non-positive
let is_valid = let is_valid =
ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap(); ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap();
@ -1158,9 +1035,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
.builder .builder
.build_select( .build_select(
is_valid, is_valid,
ctx.builder ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len").unwrap(),
.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len")
.unwrap(),
zero_size_t, zero_size_t,
"listcomp.alloc_size", "listcomp.alloc_size",
) )
@ -1172,6 +1047,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
list_alloc_size.into_int_value(), list_alloc_size.into_int_value(),
Some("listcomp.addr"), Some("listcomp.addr"),
); );
list_content = list.data().base_ptr(ctx, generator);
let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap();
ctx.builder ctx.builder
@ -1179,11 +1055,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
.unwrap(); .unwrap();
ctx.builder ctx.builder
.build_conditional_branch( .build_conditional_branch(gen_in_range_check(ctx, start, stop, step), test_bb, cont_bb)
gen_in_range_check(ctx, start, stop, step),
test_bb,
cont_bb,
)
.unwrap(); .unwrap();
ctx.builder.position_at_end(test_bb); ctx.builder.position_at_end(test_bb);
@ -1198,18 +1070,11 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
.unwrap(); .unwrap();
ctx.builder.build_store(i, tmp).unwrap(); ctx.builder.build_store(i, tmp).unwrap();
ctx.builder ctx.builder
.build_conditional_branch( .build_conditional_branch(gen_in_range_check(ctx, tmp, stop, step), body_bb, cont_bb)
gen_in_range_check(ctx, tmp, stop, step),
body_bb,
cont_bb,
)
.unwrap(); .unwrap();
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
} } else {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let length = ctx let length = ctx
.build_gep_and_load( .build_gep_and_load(
iter_val.into_pointer_value(), iter_val.into_pointer_value(),
@ -1218,15 +1083,14 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
) )
.into_int_value(); .into_int_value();
list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp"));
list_content = list.data().base_ptr(ctx, generator);
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?;
// counter = -1 // counter = -1
ctx.builder.build_store(counter, size_t.const_all_ones()).unwrap(); ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)).unwrap();
ctx.builder.build_unconditional_branch(test_bb).unwrap(); ctx.builder.build_unconditional_branch(test_bb).unwrap();
ctx.builder.position_at_end(test_bb); ctx.builder.position_at_end(test_bb);
let tmp = let tmp = ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap();
ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap();
let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap(); let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap();
ctx.builder.build_store(counter, tmp).unwrap(); ctx.builder.build_store(counter, tmp).unwrap();
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap(); let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap();
@ -1241,14 +1105,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
) )
.into_pointer_value(); .into_pointer_value();
let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val"));
generator.gen_assign(ctx, target, val.into(), elt.custom.unwrap())?; generator.gen_assign(ctx, target, val.into())?;
}
_ => {
panic!(
"unsupported list comprehension iterator type: {}",
ctx.unifier.stringify(iter_ty)
);
}
} }
// Emits the content of `cont_bb` // Emits the content of `cont_bb`
@ -1286,8 +1143,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
return Ok(None); return Ok(None);
}; };
let i = ctx.builder.build_load(index, "i").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(index, "i").map(BasicValueEnum::into_int_value).unwrap();
let elem_ptr = let elem_ptr = unsafe { ctx.builder.build_gep(list_content, &[i], "elem_ptr") }.unwrap();
unsafe { list.data().ptr_offset_unchecked(ctx, generator, &i, Some("elem_ptr")) };
let val = elem.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?; let val = elem.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?;
ctx.builder.build_store(elem_ptr, val).unwrap(); ctx.builder.build_store(elem_ptr, val).unwrap();
ctx.builder ctx.builder
@ -1370,7 +1226,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2)); debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2));
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1);
let sizeof_elem = llvm_elem_ty.size_of().unwrap();
let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None); let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
@ -1382,25 +1237,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None); let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None);
let lhs_size = ctx let lhs_len = ctx
.builder .builder
.build_int_z_extend_or_bit_cast( .build_int_mul(lhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "")
lhs.load_size(ctx, None),
sizeof_elem.get_type(),
"",
)
.unwrap(); .unwrap();
let lhs_len = ctx.builder.build_int_mul(lhs_size, sizeof_elem, "").unwrap(); let rhs_len = ctx
let rhs_size = ctx
.builder .builder
.build_int_z_extend_or_bit_cast( .build_int_mul(rhs.load_size(ctx, None), llvm_elem_ty.size_of().unwrap(), "")
rhs.load_size(ctx, None),
sizeof_elem.get_type(),
"",
)
.unwrap(); .unwrap();
let rhs_len = ctx.builder.build_int_mul(rhs_size, sizeof_elem, "").unwrap();
let list_ptr = new_list.data().base_ptr(ctx, generator); let list_ptr = new_list.data().base_ptr(ctx, generator);
call_memcpy_generic( call_memcpy_generic(
@ -1465,7 +1309,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None); let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None);
let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty);
let sizeof_elem = elem_llvm_ty.size_of().unwrap();
let new_list = allocate_list( let new_list = allocate_list(
generator, generator,
@ -1478,7 +1321,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(int_val, false), (int_val, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -1490,18 +1332,15 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None) new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None)
}; };
let list_size = ctx let memcpy_sz = ctx
.builder .builder
.build_int_z_extend_or_bit_cast( .build_int_mul(
list_val.load_size(ctx, None), list_val.load_size(ctx, None),
sizeof_elem.get_type(), elem_llvm_ty.size_of().unwrap(),
"", "",
) )
.unwrap(); .unwrap();
let memcpy_sz =
ctx.builder.build_int_mul(list_size, sizeof_elem, "").unwrap();
call_memcpy_generic( call_memcpy_generic(
ctx, ctx,
ptr, ptr,
@ -2089,7 +1928,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(left_val.load_size(ctx, None), false), (left_val.load_size(ctx, None), false),
|generator, ctx, hooks, i| { |generator, ctx, hooks, i| {
@ -2310,7 +2148,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap();
// Check that len is non-zero // Check that len is non-zero
let len = v.load_ndims(ctx); let len = v.load_ndims(ctx);
@ -2521,14 +2358,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ctx let ndarray_num_dims = ndarray.load_ndims(ctx);
.builder
.build_int_z_extend_or_bit_cast(
ndarray.load_ndims(ctx),
llvm_usize.size_of().get_type(),
"",
)
.unwrap();
let v_dims_src_ptr = unsafe { let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked( v.dim_sizes().ptr_offset_unchecked(
ctx, ctx,
@ -2554,10 +2384,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
&ndarray.dim_sizes().as_slice_value(ctx, generator), &ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None), (None, None),
); );
let ndarray_num_elems = ctx
.builder
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
.unwrap();
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None); let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);

View File

@ -130,62 +130,3 @@ pub fn call_ldexp<'ctx>(
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap() .unwrap()
} }
/// Macro to generate `np_linalg` and `sp_linalg` functions
/// The function takes as input `NDArray` and returns ()
///
/// Arguments:
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
/// * (2/3/4): Number of `NDArray` that function takes as input
///
/// Note:
/// The operands and resulting `NDArray` are both passed as input to the funcion
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
/// The function changes the content of the output `NDArray` in-place
macro_rules! generate_linalg_extern_fn {
($fn_name:ident, $extern_fn:literal, 2) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
};
($fn_name:ident, $extern_fn:literal, 3) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $extern_fn:literal, 4) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: BasicValueEnum<'ctx>)*,
name: Option<&str>,
){
const FN_NAME: &str = $extern_fn;
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
}
};
}
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);

View File

@ -123,45 +123,11 @@ pub trait CodeGenerator {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> ) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {
gen_assign(self, ctx, target, value, value_ty) gen_assign(self, ctx, target, value)
}
/// Generate code for an assignment expression where LHS is a `"target_list"`.
///
/// See <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
fn gen_assign_target_list<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
targets: &Vec<Expr<Option<Type>>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign_target_list(self, ctx, targets, value, value_ty)
}
/// Generate code for an item assignment.
///
/// i.e., `target[key] = value`
fn gen_setitem<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
key: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_setitem(self, ctx, target, key, value, value_ty)
} }
/// Generate code for a while expression. /// Generate code for a while expression.

View File

@ -798,7 +798,6 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(min_ndims, false), (min_ndims, false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {

View File

@ -35,40 +35,6 @@ fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
unreachable!() unreachable!()
} }
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_start";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_end";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic) /// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic. /// intrinsic.
pub fn call_stacksave<'ctx>( pub fn call_stacksave<'ctx>(

View File

@ -68,16 +68,6 @@ pub struct CodeGenLLVMOptions {
pub target: CodeGenTargetMachineOptions, pub target: CodeGenTargetMachineOptions,
} }
impl CodeGenLLVMOptions {
/// Creates a [`TargetMachine`] using the target options specified by this struct.
///
/// See [`Target::create_target_machine`].
#[must_use]
pub fn create_target_machine(&self) -> Option<TargetMachine> {
self.target.create_target_machine(self.opt_level)
}
}
/// Additional options for code generation for the target machine. /// Additional options for code generation for the target machine.
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub struct CodeGenTargetMachineOptions { pub struct CodeGenTargetMachineOptions {
@ -348,10 +338,6 @@ impl WorkerRegistry {
let mut builder = context.create_builder(); let mut builder = context.create_builder();
let mut module = context.create_module(generator.get_name()); let mut module = context.create_module(generator.get_name());
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag( module.add_basic_value_flag(
"Debug Info Version", "Debug Info Version",
inkwell::module::FlagBehavior::Warning, inkwell::module::FlagBehavior::Warning,
@ -375,10 +361,6 @@ impl WorkerRegistry {
errors.insert(e); errors.insert(e);
// create a new empty module just to continue codegen and collect errors // create a new empty module just to continue codegen and collect errors
module = context.create_module(&format!("{}_recover", generator.get_name())); module = context.create_module(&format!("{}_recover", generator.get_name()));
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
} }
} }
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
@ -444,7 +426,7 @@ pub struct CodeGenTask {
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &G, generator: &mut G,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -538,10 +520,8 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
}; };
return ty; return ty;
} }
TTuple { ty, is_vararg_ctx } => { TTuple { ty } => {
// a struct with fields in the order present in the tuple // a struct with fields in the order present in the tuple
assert!(!is_vararg_ctx, "Tuples in vararg context must be instantiated with the correct number of arguments before calling get_llvm_type");
let fields = ty let fields = ty
.iter() .iter()
.map(|ty| { .map(|ty| {
@ -571,7 +551,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>( fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &G, generator: &mut G,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -609,40 +589,6 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
need_sret_impl(ty, true) need_sret_impl(ty, true)
} }
/// Returns the [`BasicTypeEnum`] representing a `va_list` struct for variadic arguments.
fn get_llvm_valist_type<'ctx>(ctx: &'ctx Context, triple: &TargetTriple) -> BasicTypeEnum<'ctx> {
let triple = TargetMachine::normalize_triple(triple);
let triple = triple.as_str().to_str().unwrap();
let arch = triple.split('-').next().unwrap();
let llvm_pi8 = ctx.i8_type().ptr_type(AddressSpace::default());
// Referenced from parseArch() in llvm/lib/Support/Triple.cpp
match arch {
"i386" | "i486" | "i586" | "i686" | "riscv32" => {
ctx.i8_type().ptr_type(AddressSpace::default()).into()
}
"amd64" | "x86_64" | "x86_64h" => {
let llvm_i32 = ctx.i32_type();
let va_list_tag = ctx.opaque_struct_type("struct.__va_list_tag");
va_list_tag.set_body(
&[llvm_i32.into(), llvm_i32.into(), llvm_pi8.into(), llvm_pi8.into()],
false,
);
va_list_tag.into()
}
"armv7" => {
let va_list = ctx.opaque_struct_type("struct.__va_list");
va_list.set_body(&[llvm_pi8.into()], false);
va_list.into()
}
triple => {
todo!("Unsupported platform for varargs: {triple}")
}
}
}
/// Implementation for generating LLVM IR for a function. /// Implementation for generating LLVM IR for a function.
pub fn gen_func_impl< pub fn gen_func_impl<
'ctx, 'ctx,
@ -754,7 +700,6 @@ pub fn gen_func_impl<
name: arg.name, name: arg.name,
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache), ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache),
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
}) })
.collect_vec(), .collect_vec(),
task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache), task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache),
@ -777,10 +722,7 @@ pub fn gen_func_impl<
let has_sret = ret_type.map_or(false, |ty| need_sret(ty)); let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
let mut params = args let mut params = args
.iter() .iter()
.filter(|arg| !arg.is_vararg)
.map(|arg| { .map(|arg| {
debug_assert!(!arg.is_vararg);
get_llvm_abi_type( get_llvm_abi_type(
context, context,
&module, &module,
@ -799,12 +741,9 @@ pub fn gen_func_impl<
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into()); params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
} }
debug_assert!(matches!(args.iter().filter(|arg| arg.is_vararg).count(), 0..=1));
let vararg_arg = args.iter().find(|arg| arg.is_vararg);
let fn_type = match ret_type { let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, vararg_arg.is_some()), Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => context.void_type().fn_type(&params, vararg_arg.is_some()), _ => context.void_type().fn_type(&params, false),
}; };
let symbol = &task.symbol_name; let symbol = &task.symbol_name;
@ -834,9 +773,7 @@ pub fn gen_func_impl<
let mut var_assignment = HashMap::new(); let mut var_assignment = HashMap::new();
let offset = u32::from(has_sret); let offset = u32::from(has_sret);
for (n, arg) in args.iter().enumerate() {
// Store non-vararg argument values into local variables
for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) {
let param = fn_val.get_nth_param((n as u32) + offset).unwrap(); let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
let local_type = get_llvm_type( let local_type = get_llvm_type(
context, context,
@ -869,8 +806,6 @@ pub fn gen_func_impl<
var_assignment.insert(arg.name, (alloca, None, 0)); var_assignment.insert(arg.name, (alloca, None, 0));
} }
// TODO: Save vararg parameters as list
let return_buffer = if has_sret { let return_buffer = if has_sret {
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value()) Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
} else { } else {
@ -1093,9 +1028,3 @@ fn gen_in_range_check<'ctx>(
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap() ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
} }
/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments
/// passed to the variadic function.
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
format!("__{}_va_count", &arg_name).into()
}

View File

@ -26,15 +26,12 @@ use crate::{
typedef::{FunSignature, Type, TypeEnum}, typedef::{FunSignature, Type, TypeEnum},
}, },
}; };
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{ use inkwell::{
types::BasicType, types::BasicType,
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
values::BasicValue,
};
use nac3parser::ast::{Operator, StrRef}; use nac3parser::ast::{Operator, StrRef};
/// Creates an uninitialized `NDArray` instance. /// Creates an uninitialized `NDArray` instance.
@ -89,7 +86,6 @@ where
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(shape_len, false), (shape_len, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -135,7 +131,6 @@ where
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(shape_len, false), (shape_len, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -162,7 +157,7 @@ where
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. /// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
@ -387,7 +382,6 @@ where
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(ndarray_num_elems, false), (ndarray_num_elems, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -709,12 +703,11 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
gen_for_range_callback( gen_for_range_callback(
generator, generator,
ctx, ctx,
None,
true, true,
|_, _| Ok(llvm_usize.const_zero()), |_, _| Ok(llvm_usize.const_zero()),
(|_, ctx| Ok(src_lst.load_size(ctx, None)), false), (|_, ctx| Ok(src_lst.load_size(ctx, None)), false),
|_, _| Ok(llvm_usize.const_int(1, false)), |_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, _, i| { |generator, ctx, i| {
let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); let offset = ctx.builder.build_int_mul(stride, i, "").unwrap();
let dst_ptr = let dst_ptr =
@ -950,12 +943,11 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
gen_for_range_callback( gen_for_range_callback(
generator, generator,
ctx, ctx,
None,
true, true,
|_, _| Ok(llvm_usize.const_zero()), |_, _| Ok(llvm_usize.const_zero()),
(|_, _| Ok(stop), false), (|_, _| Ok(stop), false),
|_, _| Ok(llvm_usize.const_int(1, false)), |_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, _, _| { |generator, ctx, _| {
let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into())
.ptr_type(AddressSpace::default()); .ptr_type(AddressSpace::default());
@ -1094,17 +1086,13 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
// If there are no (remaining) slice expressions, memcpy the entire dimension // If there are no (remaining) slice expressions, memcpy the entire dimension
if slices.is_empty() { if slices.is_empty() {
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
let stride = call_ndarray_calc_size( let stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&src_arr.dim_sizes(), &src_arr.dim_sizes(),
(Some(llvm_usize.const_int(dim, false)), None), (Some(llvm_usize.const_int(dim, false)), None),
); );
let stride = let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap();
let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap();
call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero());
@ -1138,12 +1126,11 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
gen_for_range_callback( gen_for_range_callback(
generator, generator,
ctx, ctx,
None,
false, false,
|_, _| Ok(start), |_, _| Ok(start),
(|_, _| Ok(stop), true), (|_, _| Ok(stop), true),
|_, _| Ok(step), |_, _| Ok(step),
|generator, ctx, _, src_i| { |generator, ctx, src_i| {
// Calculate the offset of the active slice // Calculate the offset of the active slice
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
let dst_i = let dst_i =
@ -1256,7 +1243,6 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_usize.const_int(slices.len() as u64, false), llvm_usize.const_int(slices.len() as u64, false),
(this.load_ndims(ctx), false), (this.load_ndims(ctx), false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
@ -1661,7 +1647,6 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
None,
llvm_i32.const_zero(), llvm_i32.const_zero(),
(common_dim, false), (common_dim, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
@ -2029,493 +2014,3 @@ pub fn gen_ndarray_fill<'ctx>(
Ok(()) Ok(())
} }
/// Generates LLVM IR for `ndarray.transpose`.
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_transpose";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
// Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&n1,
|_, ctx, n| Ok(n.load_ndims(ctx)),
|generator, ctx, n, idx| {
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
let new_idx = ctx
.builder
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
.unwrap();
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) }
},
)
.unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
ctx.builder.build_store(rem_idx, idx).unwrap();
// Incrementally calculate the new index in the transposed array
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
// The formula used for indexing is:
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1.load_ndims(ctx), false),
|generator, ctx, _, ndim| {
let ndim_rev =
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
let ndim_rev = ctx
.builder
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
.unwrap();
let dim = unsafe {
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None)
};
let rem_idx_val =
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
let new_idx_val =
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
let add_component =
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
let rem_idx_val =
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
let new_idx_val =
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
///
/// * `x1` - `NDArray` to reshape.
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
shape: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_reshape";
let (x1_ty, x1) = x1;
let (_, shape) = shape;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
// Check for -1 in dimensions
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_list.load_size(ctx, None), false),
|generator, ctx, _, idx| {
let ele =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
ele,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_neg_value =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_neg_value = ctx
.builder
.build_int_add(
num_neg_value,
llvm_usize.const_int(1, false),
"",
)
.unwrap();
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
Ok(None)
},
|_, ctx| {
let acc_value =
ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_value =
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
ctx.builder.build_store(acc, acc_value).unwrap();
Ok(None)
},
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
// Generate the output shape by filling -1 with `rem`
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
let ndims = shape_tuple.get_type().count_fields();
// Check for -1 in dims
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_negs =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_negs = ctx
.builder
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
.unwrap();
ctx.builder.build_store(num_neg, num_negs).unwrap();
Ok(None)
},
|_, ctx| {
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(None)
},
)?;
}
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize);
// Reconstruct shape filling negatives with rem
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
let dim = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
let shape_int = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
shape_int,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(n_sz)),
|_, ctx| {
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
},
)?
.unwrap()
.into_int_value();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => unreachable!(),
}
.unwrap();
// Only allow one dimension to be negative
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"can only specify one unknown dimension",
[None, None, None],
ctx.current_loc,
);
// The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError",
"cannot reshape array of size {0} into provided shape of size {1}",
[Some(n_sz), Some(out_sz), None],
ctx.current_loc,
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul`
///
/// The input `NDArray` are flattened and treated as 1D
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_dot";
let (x1_ty, x1) = x1;
let (_, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
match (x1, x2) {
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
"0:ValueError",
"shapes ({0}), ({1}) not aligned",
[Some(n1_sz), Some(n2_sz), None],
ctx.current_loc,
);
let identity =
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1_sz, false),
|generator, ctx, _, idx| {
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
let product = match elem1 {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_mul(e1, elem2.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_mul(e1, elem2.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_add(e1, product.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_add(e1, product.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap();
Ok(acc_val)
}
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
_ => unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
),
}
}

View File

@ -10,10 +10,10 @@ use crate::{
expr::gen_binop_expr, expr::gen_binop_expr,
gen_in_range_check, gen_in_range_check,
}, },
toplevel::{DefinitionId, TopLevelDef}, toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
typecheck::{ typecheck::{
magic_methods::Binop, magic_methods::Binop,
typedef::{iter_type_vars, FunSignature, Type, TypeEnum}, typedef::{FunSignature, Type, TypeEnum},
}, },
}; };
use inkwell::{ use inkwell::{
@ -23,10 +23,10 @@ use inkwell::{
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
IntPredicate, IntPredicate,
}; };
use itertools::{izip, Itertools};
use nac3parser::ast::{ use nac3parser::ast::{
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef, Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
}; };
use std::convert::TryFrom;
/// See [`CodeGenerator::gen_var_alloc`]. /// See [`CodeGenerator::gen_var_alloc`].
pub fn gen_var<'ctx>( pub fn gen_var<'ctx>(
@ -97,6 +97,8 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
pattern: &Expr<Option<Type>>, pattern: &Expr<Option<Type>>,
name: Option<&str>, name: Option<&str>,
) -> Result<Option<PointerValue<'ctx>>, String> { ) -> Result<Option<PointerValue<'ctx>>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
// very similar to gen_expr, but we don't do an extra load at the end // very similar to gen_expr, but we don't do an extra load at the end
// and we flatten nested tuples // and we flatten nested tuples
Ok(Some(match &pattern.node { Ok(Some(match &pattern.node {
@ -135,6 +137,65 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
} }
.unwrap() .unwrap()
} }
ExprKind::Subscript { value, slice, .. } => {
match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
let v = generator
.gen_expr(ctx, value)?
.unwrap()
.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
.into_pointer_value();
let v = ListValue::from_ptr_val(v, llvm_usize, None);
let len = v.load_size(ctx, Some("len"));
let raw_index = generator
.gen_expr(ctx, slice)?
.unwrap()
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value();
let raw_index = ctx
.builder
.build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext")
.unwrap();
// handle negative index
let is_negative = ctx
.builder
.build_int_compare(
IntPredicate::SLT,
raw_index,
generator.get_size_type(ctx.ctx).const_zero(),
"is_neg",
)
.unwrap();
let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted").unwrap();
let index = ctx
.builder
.build_select(is_negative, adjusted, raw_index, "index")
.map(BasicValueEnum::into_int_value)
.unwrap();
// unsigned less than is enough, because negative index after adjustment is
// bigger than the length (for unsigned cmp)
let bound_check = ctx
.builder
.build_int_compare(IntPredicate::ULT, index, len, "inbound")
.unwrap();
ctx.make_assert(
generator,
bound_check,
"0:IndexError",
"index {0} out of bounds 0:{1}",
[Some(raw_index), Some(len), None],
slice.location,
);
v.data().ptr_offset(ctx, generator, &index, name)
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
todo!()
}
_ => unreachable!(),
}
}
_ => unreachable!(), _ => unreachable!(),
})) }))
} }
@ -145,20 +206,70 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> { ) -> Result<(), String> {
// See https://docs.python.org/3/reference/simple_stmts.html#assignment-statements. let llvm_usize = generator.get_size_type(ctx.ctx);
match &target.node { match &target.node {
ExprKind::Subscript { value: target, slice: key, .. } => { ExprKind::Tuple { elts, .. } => {
// Handle "slicing" or "subscription" let BasicValueEnum::StructValue(v) =
generator.gen_setitem(ctx, target, key, value, value_ty)?; value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
else {
unreachable!()
};
for (i, elt) in elts.iter().enumerate() {
let v = ctx
.builder
.build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem")
.unwrap();
generator.gen_assign(ctx, elt, v.into())?;
} }
ExprKind::Tuple { elts, .. } | ExprKind::List { elts, .. } => { }
// Fold on `"[" [target_list] "]"` and `"(" [target_list] ")"` ExprKind::Subscript { value: ls, slice, .. }
generator.gen_assign_target_list(ctx, elts, value, value_ty)?; if matches!(&slice.node, ExprKind::Slice { .. }) =>
{
let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() };
let ls = generator
.gen_expr(ctx, ls)?
.unwrap()
.to_basic_value_enum(ctx, generator, ls.custom.unwrap())?
.into_pointer_value();
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
let Some((start, end, step)) =
handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))?
else {
return Ok(());
};
let value = value
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
.into_pointer_value();
let value = ListValue::from_ptr_val(value, llvm_usize, None);
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
*params.iter().next().unwrap().1
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
}
_ => unreachable!(),
};
let ty = ctx.get_llvm_type(generator, ty);
let Some(src_ind) = handle_slice_indices(
&None,
&None,
&None,
ctx,
generator,
value.load_size(ctx, None),
)?
else {
return Ok(());
};
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
} }
_ => { _ => {
// Handle attribute and direct variable assignments.
let name = if let ExprKind::Name { id, .. } = &target.node { let name = if let ExprKind::Name { id, .. } = &target.node {
format!("{id}.addr") format!("{id}.addr")
} else { } else {
@ -182,234 +293,6 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
Ok(()) Ok(())
} }
/// See [`CodeGenerator::gen_assign_target_list`].
pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
targets: &Vec<Expr<Option<Type>>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> {
// Deconstruct the tuple `value`
let BasicValueEnum::StructValue(tuple) = value.to_basic_value_enum(ctx, generator, value_ty)?
else {
unreachable!()
};
// NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer.
let TypeEnum::TTuple { ty: tuple_tys, .. } = &*ctx.unifier.get_ty(value_ty) else {
unreachable!();
};
assert_eq!(tuple.get_type().count_fields() as usize, tuple_tys.len());
let tuple = (0..tuple.get_type().count_fields())
.map(|i| ctx.builder.build_extract_value(tuple, i, "item").unwrap())
.collect_vec();
// Find the starred target if it exists.
let mut starred_target_index: Option<usize> = None; // Index of the "starred" target. If it exists, there may only be one.
for (i, target) in targets.iter().enumerate() {
if matches!(target.node, ExprKind::Starred { .. }) {
assert!(starred_target_index.is_none()); // The typechecker ensures this
starred_target_index = Some(i);
}
}
if let Some(starred_target_index) = starred_target_index {
assert!(tuple_tys.len() >= targets.len() - 1); // The typechecker ensures this
let a = starred_target_index; // Number of RHS values before the starred target
let b = tuple_tys.len() - (targets.len() - 1 - starred_target_index); // Number of RHS values after the starred target
// Thus `tuple[a..b]` is assigned to the starred target.
// Handle assignment before the starred target
for (target, val, val_ty) in
izip!(&targets[..starred_target_index], &tuple[..a], &tuple_tys[..a])
{
generator.gen_assign(ctx, target, ValueEnum::Dynamic(*val), *val_ty)?;
}
// Handle assignment to the starred target
if let ExprKind::Starred { value: target, .. } = &targets[starred_target_index].node {
let vals = &tuple[a..b];
let val_tys = &tuple_tys[a..b];
// Create a sub-tuple from `value` for the starred target.
let sub_tuple_ty = ctx
.ctx
.struct_type(&vals.iter().map(BasicValueEnum::get_type).collect_vec(), false);
let psub_tuple_val =
ctx.builder.build_alloca(sub_tuple_ty, "starred_target_value_ptr").unwrap();
for (i, val) in vals.iter().enumerate() {
let pitem = ctx
.builder
.build_struct_gep(psub_tuple_val, i as u32, "starred_target_value_item")
.unwrap();
ctx.builder.build_store(pitem, *val).unwrap();
}
let sub_tuple_val =
ctx.builder.build_load(psub_tuple_val, "starred_target_value").unwrap();
// Create the typechecker type of the sub-tuple
let sub_tuple_ty =
ctx.unifier.add_ty(TypeEnum::TTuple { ty: val_tys.to_vec(), is_vararg_ctx: false });
// Now assign with that sub-tuple to the starred target.
generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?;
} else {
unreachable!() // The typechecker ensures this
}
// Handle assignment after the starred target
for (target, val, val_ty) in
izip!(&targets[starred_target_index + 1..], &tuple[b..], &tuple_tys[b..])
{
generator.gen_assign(ctx, target, ValueEnum::Dynamic(*val), *val_ty)?;
}
} else {
assert_eq!(tuple_tys.len(), targets.len()); // The typechecker ensures this
for (target, val, val_ty) in izip!(targets, tuple, tuple_tys) {
generator.gen_assign(ctx, target, ValueEnum::Dynamic(val), *val_ty)?;
}
}
Ok(())
}
/// See [`CodeGenerator::gen_setitem`].
pub fn gen_setitem<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
key: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> {
let target_ty = target.custom.unwrap();
let key_ty = key.custom.unwrap();
match &*ctx.unifier.get_ty(target_ty) {
TypeEnum::TObj { obj_id, params: list_params, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// Handle list item assignment
let llvm_usize = generator.get_size_type(ctx.ctx);
let target_item_ty = iter_type_vars(list_params).next().unwrap().ty;
let target = generator
.gen_expr(ctx, target)?
.unwrap()
.to_basic_value_enum(ctx, generator, target_ty)?
.into_pointer_value();
let target = ListValue::from_ptr_val(target, llvm_usize, None);
if let ExprKind::Slice { .. } = &key.node {
// Handle assigning to a slice
let ExprKind::Slice { lower, upper, step } = &key.node else { unreachable!() };
let Some((start, end, step)) = handle_slice_indices(
lower,
upper,
step,
ctx,
generator,
target.load_size(ctx, None),
)?
else {
return Ok(());
};
let value =
value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value();
let value = ListValue::from_ptr_val(value, llvm_usize, None);
let target_item_ty = ctx.get_llvm_type(generator, target_item_ty);
let Some(src_ind) = handle_slice_indices(
&None,
&None,
&None,
ctx,
generator,
value.load_size(ctx, None),
)?
else {
return Ok(());
};
list_slice_assignment(
generator,
ctx,
target_item_ty,
target,
(start, end, step),
value,
src_ind,
);
} else {
// Handle assigning to an index
let len = target.load_size(ctx, Some("len"));
let index = generator
.gen_expr(ctx, key)?
.unwrap()
.to_basic_value_enum(ctx, generator, key_ty)?
.into_int_value();
let index = ctx
.builder
.build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext")
.unwrap();
// handle negative index
let is_negative = ctx
.builder
.build_int_compare(
IntPredicate::SLT,
index,
generator.get_size_type(ctx.ctx).const_zero(),
"is_neg",
)
.unwrap();
let adjusted = ctx.builder.build_int_add(index, len, "adjusted").unwrap();
let index = ctx
.builder
.build_select(is_negative, adjusted, index, "index")
.map(BasicValueEnum::into_int_value)
.unwrap();
// unsigned less than is enough, because negative index after adjustment is
// bigger than the length (for unsigned cmp)
let bound_check = ctx
.builder
.build_int_compare(IntPredicate::ULT, index, len, "inbound")
.unwrap();
ctx.make_assert(
generator,
bound_check,
"0:IndexError",
"index {0} out of bounds 0:{1}",
[Some(index), Some(len), None],
key.location,
);
// Write value to index on list
let item_ptr =
target.data().ptr_offset(ctx, generator, &index, Some("list_item_ptr"));
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
ctx.builder.build_store(item_ptr, value).unwrap();
}
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
// Handle NDArray item assignment
todo!("ndarray subscript assignment is not yet implemented");
}
_ => {
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
}
}
Ok(())
}
/// See [`CodeGenerator::gen_for`]. /// See [`CodeGenerator::gen_for`].
pub fn gen_for<G: CodeGenerator>( pub fn gen_for<G: CodeGenerator>(
generator: &mut G, generator: &mut G,
@ -432,6 +315,9 @@ pub fn gen_for<G: CodeGenerator>(
let orelse_bb = let orelse_bb =
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") }; if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
// Whether the iterable is a range() expression
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
// The BB containing the increment expression // The BB containing the increment expression
let incr_bb = ctx.ctx.append_basic_block(current, "for.incr"); let incr_bb = ctx.ctx.append_basic_block(current, "for.incr");
// The BB containing the loop condition check // The BB containing the loop condition check
@ -440,23 +326,17 @@ pub fn gen_for<G: CodeGenerator>(
// store loop bb information and restore it later // store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb)); let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb));
let iter_ty = iter.custom.unwrap();
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
v.to_basic_value_enum(ctx, generator, iter_ty)? v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())?
} else { } else {
return Ok(()); return Ok(());
}; };
if is_iterable_range_expr {
match &*ctx.unifier.get_ty(iter_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
{
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
// Internal variable for loop; Cannot be assigned // Internal variable for loop; Cannot be assigned
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
let Some(target_i) = let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))?
generator.gen_store_target(ctx, target, Some("for.target.addr"))?
else { else {
unreachable!() unreachable!()
}; };
@ -465,10 +345,8 @@ pub fn gen_for<G: CodeGenerator>(
ctx.builder.build_store(i, start).unwrap(); ctx.builder.build_store(i, start).unwrap();
// Check "If step is zero, ValueError is raised." // Check "If step is zero, ValueError is raised."
let rangenez = ctx let rangenez =
.builder ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "").unwrap();
.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "")
.unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
rangenez, rangenez,
@ -485,10 +363,7 @@ pub fn gen_for<G: CodeGenerator>(
.build_conditional_branch( .build_conditional_branch(
gen_in_range_check( gen_in_range_check(
ctx, ctx,
ctx.builder ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
.build_load(i, "")
.map(BasicValueEnum::into_int_value)
.unwrap(),
stop, stop,
step, step,
), ),
@ -518,10 +393,7 @@ pub fn gen_for<G: CodeGenerator>(
) )
.unwrap(); .unwrap();
generator.gen_block(ctx, body.iter())?; generator.gen_block(ctx, body.iter())?;
} } else {
TypeEnum::TObj { obj_id, params: list_params, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?; let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap(); ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap();
let len = ctx let len = ctx
@ -559,14 +431,9 @@ pub fn gen_for<G: CodeGenerator>(
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val")); let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
let val_ty = iter_type_vars(list_params).next().unwrap().ty; generator.gen_assign(ctx, target, val.into())?;
generator.gen_assign(ctx, target, val.into(), val_ty)?;
generator.gen_block(ctx, body.iter())?; generator.gen_block(ctx, body.iter())?;
} }
_ => {
panic!("unsupported for loop iterator type: {}", ctx.unifier.stringify(iter_ty));
}
}
for (k, (_, _, counter)) in &var_assignment { for (k, (_, _, counter)) in &var_assignment {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
@ -627,7 +494,6 @@ pub struct BreakContinueHooks<'ctx> {
pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
label: Option<&str>,
init: InitFn, init: InitFn,
cond: CondFn, cond: CondFn,
body: BodyFn, body: BodyFn,
@ -642,16 +508,14 @@ where
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>, FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>,
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{ {
let label = label.unwrap_or("for");
let current_bb = ctx.builder.get_insert_block().unwrap(); let current_bb = ctx.builder.get_insert_block().unwrap();
let init_bb = ctx.ctx.insert_basic_block_after(current_bb, &format!("{label}.init")); let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init");
// The BB containing the loop condition check // The BB containing the loop condition check
let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, &format!("{label}.cond")); let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, "for.cond");
let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, &format!("{label}.body")); let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, "for.body");
// The BB containing the increment expression // The BB containing the increment expression
let update_bb = ctx.ctx.insert_basic_block_after(body_bb, &format!("{label}.update")); let update_bb = ctx.ctx.insert_basic_block_after(body_bb, "for.update");
let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, &format!("{label}.end")); let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, "for.end");
// store loop bb information and restore it later // store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((update_bb, cont_bb)); let loop_bb = ctx.loop_target.replace((update_bb, cont_bb));
@ -708,7 +572,6 @@ where
pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
label: Option<&str>,
init_val: IntValue<'ctx>, init_val: IntValue<'ctx>,
max_val: (IntValue<'ctx>, bool), max_val: (IntValue<'ctx>, bool),
body: BodyFn, body: BodyFn,
@ -728,7 +591,6 @@ where
gen_for_callback( gen_for_callback(
generator, generator,
ctx, ctx,
label,
|generator, ctx| { |generator, ctx| {
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?; let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
ctx.builder.build_store(i_addr, init_val).unwrap(); ctx.builder.build_store(i_addr, init_val).unwrap();
@ -780,11 +642,9 @@ where
/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like /// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like
/// iterable. This value will be extended to the size of `start`. /// iterable. This value will be extended to the size of `start`.
/// - `body_fn`: A lambda of IR statements within the loop body. /// - `body_fn`: A lambda of IR statements within the loop body.
#[allow(clippy::too_many_arguments)]
pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
label: Option<&str>,
is_unsigned: bool, is_unsigned: bool,
start_fn: StartFn, start_fn: StartFn,
(stop_fn, stop_inclusive): (StopFn, bool), (stop_fn, stop_inclusive): (StopFn, bool),
@ -796,19 +656,13 @@ where
StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce( BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks,
IntValue<'ctx>,
) -> Result<(), String>,
{ {
let init_val_t = start_fn(generator, ctx).map(IntValue::get_type).unwrap(); let init_val_t = start_fn(generator, ctx).map(IntValue::get_type).unwrap();
gen_for_callback( gen_for_callback(
generator, generator,
ctx, ctx,
label,
|generator, ctx| { |generator, ctx| {
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?; let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
@ -866,10 +720,10 @@ where
Ok(cond) Ok(cond)
}, },
|generator, ctx, hooks, (i_addr, _)| { |generator, ctx, _, (i_addr, _)| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
body_fn(generator, ctx, hooks, i) body_fn(generator, ctx, i)
}, },
|generator, ctx, (i_addr, _)| { |generator, ctx, (i_addr, _)| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
@ -1721,14 +1575,14 @@ pub fn gen_stmt<G: CodeGenerator>(
} }
StmtKind::AnnAssign { target, value, .. } => { StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value { if let Some(value) = value {
let Some(value_enum) = generator.gen_expr(ctx, value)? else { return Ok(()) }; let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) };
generator.gen_assign(ctx, target, value_enum, value.custom.unwrap())?; generator.gen_assign(ctx, target, value)?;
} }
} }
StmtKind::Assign { targets, value, .. } => { StmtKind::Assign { targets, value, .. } => {
let Some(value_enum) = generator.gen_expr(ctx, value)? else { return Ok(()) }; let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) };
for target in targets { for target in targets {
generator.gen_assign(ctx, target, value_enum.clone(), value.custom.unwrap())?; generator.gen_assign(ctx, target, value.clone())?;
} }
} }
StmtKind::Continue { .. } => { StmtKind::Continue { .. } => {
@ -1742,16 +1596,15 @@ pub fn gen_stmt<G: CodeGenerator>(
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
StmtKind::AugAssign { target, op, value, .. } => { StmtKind::AugAssign { target, op, value, .. } => {
let value_enum = gen_binop_expr( let value = gen_binop_expr(
generator, generator,
ctx, ctx,
target, target,
Binop::aug_assign(*op), Binop::aug_assign(*op),
value, value,
stmt.location, stmt.location,
)? )?;
.unwrap(); generator.gen_assign(ctx, target, value.unwrap())?;
generator.gen_assign(ctx, target, value_enum, value.custom.unwrap())?;
} }
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
StmtKind::Raise { exc, .. } => { StmtKind::Raise { exc, .. } => {

View File

@ -109,18 +109,8 @@ fn test_primitives() {
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature { let signature = FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg { name: "a".into(), ty: primitives.int32, default_value: None },
name: "a".into(), FuncArg { name: "b".into(), ty: primitives.int32, default_value: None },
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "b".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
], ],
ret: primitives.int32, ret: primitives.int32,
vars: VarMap::new(), vars: VarMap::new(),
@ -199,8 +189,6 @@ fn test_primitives() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 { define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
@ -265,12 +253,7 @@ fn test_simple_call() {
unifier.top_level = Some(top_level.clone()); unifier.top_level = Some(top_level.clone());
let signature = FunSignature { let signature = FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }],
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
}],
ret: primitives.int32, ret: primitives.int32,
vars: VarMap::new(), vars: VarMap::new(),
}; };
@ -385,8 +368,6 @@ fn test_simple_call() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 { define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {

View File

@ -78,14 +78,14 @@ impl SymbolValue {
} }
Constant::Tuple(t) => { Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty); let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx } = expected_ty.as_ref() else { let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
return Err(format!( return Err(format!(
"Expected {:?}, but got Tuple", "Expected {:?}, but got Tuple",
expected_ty.get_type_name() expected_ty.get_type_name()
)); ));
}; };
assert!(*is_vararg_ctx || ty.len() == t.len()); assert_eq!(ty.len(), t.len());
let elems = t let elems = t
.iter() .iter()
@ -155,7 +155,7 @@ impl SymbolValue {
SymbolValue::Bool(_) => primitives.bool, SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => { SymbolValue::Tuple(vs) => {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>(); let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys, is_vararg_ctx: false }) unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
} }
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
} }
@ -482,7 +482,7 @@ pub fn parse_type_annotation<T>(
parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt) parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false })) Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else { } else {
Err(HashSet::from(["Expected multiple elements for tuple".into()])) Err(HashSet::from(["Expected multiple elements for tuple".into()]))
} }

View File

@ -45,26 +45,10 @@ pub fn get_exn_constructor(
name: "msg".into(), name: "msg".into(),
ty: string, ty: string,
default_value: Some(SymbolValue::Str(String::new())), default_value: Some(SymbolValue::Str(String::new())),
is_vararg: false,
},
FuncArg {
name: "param0".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
},
FuncArg {
name: "param1".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
},
FuncArg {
name: "param2".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
}, },
FuncArg { name: "param0".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
FuncArg { name: "param1".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
FuncArg { name: "param2".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
]; ];
let exn_type = unifier.add_ty(TypeEnum::TObj { let exn_type = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(class_id), obj_id: DefinitionId(class_id),
@ -130,12 +114,7 @@ fn create_fn_by_codegen(
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty args: param_ty
.iter() .iter()
.map(|p| FuncArg { .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
name: p.1.into(),
ty: p.0,
default_value: None,
is_vararg: false,
})
.collect(), .collect(),
ret: ret_ty, ret: ret_ty,
vars: var_map.clone(), vars: var_map.clone(),
@ -367,8 +346,8 @@ impl<'a> BuiltinBuilder<'a> {
let (is_some_ty, unwrap_ty, option_tvar) = let (is_some_ty, unwrap_ty, option_tvar) =
if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() { if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() {
( (
*fields.get(&PrimDef::FunOptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
*fields.get(&PrimDef::FunOptionUnwrap.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
iter_type_vars(params).next().unwrap(), iter_type_vars(params).next().unwrap(),
) )
} else { } else {
@ -383,9 +362,9 @@ impl<'a> BuiltinBuilder<'a> {
let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap(); let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap();
let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap(); let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap();
let ndarray_copy_ty = let ndarray_copy_ty =
*ndarray_fields.get(&PrimDef::FunNDArrayCopy.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
let ndarray_fill_ty = let ndarray_fill_ty =
*ndarray_fields.get(&PrimDef::FunNDArrayFill.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::NDArrayFill.simple_name().into()).unwrap();
let num_ty = unifier.get_fresh_var_with_range( let num_ty = unifier.get_fresh_var_with_range(
&[int32, int64, float, boolean, uint32, uint64], &[int32, int64, float, boolean, uint32, uint64],
@ -485,14 +464,14 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Exception => self.build_exception_class_related(prim), PrimDef::Exception => self.build_exception_class_related(prim),
PrimDef::Option PrimDef::Option
| PrimDef::FunOptionIsSome | PrimDef::OptionIsSome
| PrimDef::FunOptionIsNone | PrimDef::OptionIsNone
| PrimDef::FunOptionUnwrap | PrimDef::OptionUnwrap
| PrimDef::FunSome => self.build_option_class_related(prim), | PrimDef::FunSome => self.build_option_class_related(prim),
PrimDef::List => self.build_list_class_related(prim), PrimDef::List => self.build_list_class_related(prim),
PrimDef::NDArray | PrimDef::FunNDArrayCopy | PrimDef::FunNDArrayFill => { PrimDef::NDArray | PrimDef::NDArrayCopy | PrimDef::NDArrayFill => {
self.build_ndarray_class_related(prim) self.build_ndarray_class_related(prim)
} }
@ -577,22 +556,6 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpLdExp | PrimDef::FunNpLdExp
| PrimDef::FunNpHypot | PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
self.build_np_sp_ndarray_function(prim)
}
PrimDef::FunNpDot
| PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr
| PrimDef::FunNpLinalgSvd
| PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv
| PrimDef::FunNpLinalgMatrixPower
| PrimDef::FunNpLinalgDet
| PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim),
}; };
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
@ -650,24 +613,17 @@ impl<'a> BuiltinBuilder<'a> {
let make_ctor_signature = |unifier: &mut Unifier| { let make_ctor_signature = |unifier: &mut Unifier| {
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg { name: "start".into(), ty: int32, default_value: None },
name: "start".into(),
ty: int32,
default_value: None,
is_vararg: false,
},
FuncArg { FuncArg {
name: "stop".into(), name: "stop".into(),
ty: int32, ty: int32,
// placeholder // placeholder
default_value: Some(SymbolValue::I32(0)), default_value: Some(SymbolValue::I32(0)),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "step".into(), name: "step".into(),
ty: int32, ty: int32,
default_value: Some(SymbolValue::I32(1)), default_value: Some(SymbolValue::I32(1)),
is_vararg: false,
}, },
], ],
ret: range, ret: range,
@ -838,9 +794,9 @@ impl<'a> BuiltinBuilder<'a> {
prim, prim,
&[ &[
PrimDef::Option, PrimDef::Option,
PrimDef::FunOptionIsSome, PrimDef::OptionIsSome,
PrimDef::FunOptionIsNone, PrimDef::OptionIsNone,
PrimDef::FunOptionUnwrap, PrimDef::OptionUnwrap,
PrimDef::FunSome, PrimDef::FunSome,
], ],
); );
@ -853,9 +809,9 @@ impl<'a> BuiltinBuilder<'a> {
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
Self::create_method(PrimDef::FunOptionIsSome, self.is_some_ty.0), Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
Self::create_method(PrimDef::FunOptionIsNone, self.is_some_ty.0), Self::create_method(PrimDef::OptionIsNone, self.is_some_ty.0),
Self::create_method(PrimDef::FunOptionUnwrap, self.unwrap_ty.0), Self::create_method(PrimDef::OptionUnwrap, self.unwrap_ty.0),
], ],
ancestors: vec![TypeAnnotation::CustomClass { ancestors: vec![TypeAnnotation::CustomClass {
id: prim.id(), id: prim.id(),
@ -866,7 +822,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::FunOptionUnwrap => TopLevelDef::Function { PrimDef::OptionUnwrap => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unwrap_ty.0, signature: self.unwrap_ty.0,
@ -880,7 +836,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::FunOptionIsNone | PrimDef::FunOptionIsSome => TopLevelDef::Function { PrimDef::OptionIsNone | PrimDef::OptionIsSome => TopLevelDef::Function {
name: prim.name().to_string(), name: prim.name().to_string(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.is_some_ty.0, signature: self.is_some_ty.0,
@ -901,10 +857,10 @@ impl<'a> BuiltinBuilder<'a> {
}; };
let returned_int = match prim { let returned_int = match prim {
PrimDef::FunOptionIsNone => { PrimDef::OptionIsNone => {
ctx.builder.build_is_null(ptr, prim.simple_name()) ctx.builder.build_is_null(ptr, prim.simple_name())
} }
PrimDef::FunOptionIsSome => { PrimDef::OptionIsSome => {
ctx.builder.build_is_not_null(ptr, prim.simple_name()) ctx.builder.build_is_not_null(ptr, prim.simple_name())
} }
_ => unreachable!(), _ => unreachable!(),
@ -923,7 +879,6 @@ impl<'a> BuiltinBuilder<'a> {
name: "n".into(), name: "n".into(),
ty: self.option_tvar.ty, ty: self.option_tvar.ty,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: self.primitives.option, ret: self.primitives.option,
vars: into_var_map([self.option_tvar]), vars: into_var_map([self.option_tvar]),
@ -978,7 +933,7 @@ impl<'a> BuiltinBuilder<'a> {
fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef { fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed( debug_assert_prim_is_allowed(
prim, prim,
&[PrimDef::NDArray, PrimDef::FunNDArrayCopy, PrimDef::FunNDArrayFill], &[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill],
); );
match prim { match prim {
@ -989,8 +944,8 @@ impl<'a> BuiltinBuilder<'a> {
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
Self::create_method(PrimDef::FunNDArrayCopy, self.ndarray_copy_ty.0), Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
Self::create_method(PrimDef::FunNDArrayFill, self.ndarray_fill_ty.0), Self::create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0),
], ],
ancestors: Vec::default(), ancestors: Vec::default(),
constructor: None, constructor: None,
@ -998,7 +953,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::FunNDArrayCopy => TopLevelDef::Function { PrimDef::NDArrayCopy => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_copy_ty.0, signature: self.ndarray_copy_ty.0,
@ -1015,7 +970,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::FunNDArrayFill => TopLevelDef::Function { PrimDef::NDArrayFill => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_fill_ty.0, signature: self.ndarray_fill_ty.0,
@ -1058,7 +1013,6 @@ impl<'a> BuiltinBuilder<'a> {
name: "n".into(), name: "n".into(),
ty: self.num_or_ndarray_ty.ty, ty: self.num_or_ndarray_ty.ty,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: self.num_or_ndarray_ty.ty, ret: self.num_or_ndarray_ty.ty,
vars: self.num_or_ndarray_var_map.clone(), vars: self.num_or_ndarray_var_map.clone(),
@ -1278,23 +1232,16 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg { name: "object".into(), ty: tv.ty, default_value: None },
name: "object".into(),
ty: tv.ty,
default_value: None,
is_vararg: false,
},
FuncArg { FuncArg {
name: "copy".into(), name: "copy".into(),
ty: bool, ty: bool,
default_value: Some(SymbolValue::Bool(true)), default_value: Some(SymbolValue::Bool(true)),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "ndmin".into(), name: "ndmin".into(),
ty: int32, ty: int32,
default_value: Some(SymbolValue::U32(0)), default_value: Some(SymbolValue::U32(0)),
is_vararg: false,
}, },
], ],
ret: ndarray, ret: ndarray,
@ -1336,24 +1283,17 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg { name: "N".into(), ty: int32, default_value: None },
name: "N".into(),
ty: int32,
default_value: None,
is_vararg: false,
},
// TODO(Derppening): Default values current do not work? // TODO(Derppening): Default values current do not work?
FuncArg { FuncArg {
name: "M".into(), name: "M".into(),
ty: int32, ty: int32,
default_value: Some(SymbolValue::OptionNone), default_value: Some(SymbolValue::OptionNone),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "k".into(), name: "k".into(),
ty: int32, ty: int32,
default_value: Some(SymbolValue::I32(0)), default_value: Some(SymbolValue::I32(0)),
is_vararg: false,
}, },
], ],
ret: self.ndarray_float_2d, ret: self.ndarray_float_2d,
@ -1397,12 +1337,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "s".into(), ty: str, default_value: None }],
name: "s".into(),
ty: str,
default_value: None,
is_vararg: false,
}],
ret: str, ret: str,
vars: VarMap::default(), vars: VarMap::default(),
})), })),
@ -1488,12 +1423,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "ls".into(), ty: arg_ty.ty, default_value: None }],
name: "ls".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
}],
ret: int32, ret: int32,
vars: into_var_map([tvar, arg_ty]), vars: into_var_map([tvar, arg_ty]),
})), })),
@ -1598,18 +1528,8 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg { name: "m".into(), ty: self.num_ty.ty, default_value: None },
name: "m".into(), FuncArg { name: "n".into(), ty: self.num_ty.ty, default_value: None },
ty: self.num_ty.ty,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "n".into(),
ty: self.num_ty.ty,
default_value: None,
is_vararg: false,
},
], ],
ret: self.num_ty.ty, ret: self.num_ty.ty,
vars: self.num_var_map.clone(), vars: self.num_var_map.clone(),
@ -1691,12 +1611,7 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty args: param_ty
.iter() .iter()
.map(|p| FuncArg { .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
name: p.1.into(),
ty: p.0,
default_value: None,
is_vararg: false,
})
.collect(), .collect(),
ret: ret_ty.ty, ret: ret_ty.ty,
vars: into_var_map([x1_ty, x2_ty, ret_ty]), vars: into_var_map([x1_ty, x2_ty, ret_ty]),
@ -1737,7 +1652,6 @@ impl<'a> BuiltinBuilder<'a> {
name: "n".into(), name: "n".into(),
ty: self.num_or_ndarray_ty.ty, ty: self.num_or_ndarray_ty.ty,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: self.num_or_ndarray_ty.ty, ret: self.num_or_ndarray_ty.ty,
vars: self.num_or_ndarray_var_map.clone(), vars: self.num_or_ndarray_var_map.clone(),
@ -1926,12 +1840,7 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty args: param_ty
.iter() .iter()
.map(|p| FuncArg { .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
name: p.1.into(),
ty: p.0,
default_value: None,
is_vararg: false,
})
.collect(), .collect(),
ret: ret_ty.ty, ret: ret_ty.ty,
vars: into_var_map([x1_ty, x2_ty, ret_ty]), vars: into_var_map([x1_ty, x2_ty, ret_ty]),
@ -1965,207 +1874,6 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build np/sp functions that take as input `NDArray` only
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
match prim {
PrimDef::FunNpTranspose => {
let ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.ndarray_num_ty],
Some("T".into()),
None,
);
create_fn_by_codegen(
self.unifier,
&into_var_map([ndarray_ty]),
prim.name(),
ndarray_ty.ty,
&[(ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
}),
)
}
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.
//
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
PrimDef::FunNpReshape => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_num_ty,
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
),
_ => unreachable!(),
}
}
/// Build `np_linalg` and `sp_linalg` functions
///
/// The input to these functions must be floating point `NDArray`
fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[
PrimDef::FunNpDot,
PrimDef::FunNpLinalgCholesky,
PrimDef::FunNpLinalgQr,
PrimDef::FunNpLinalgSvd,
PrimDef::FunNpLinalgInv,
PrimDef::FunNpLinalgPinv,
PrimDef::FunNpLinalgMatrixPower,
PrimDef::FunNpLinalgDet,
PrimDef::FunSpLinalgLu,
PrimDef::FunSpLinalgSchur,
PrimDef::FunSpLinalgHessenberg,
],
);
match prim {
PrimDef::FunNpDot => create_fn_by_codegen(
self.unifier,
&self.num_or_ndarray_var_map,
prim.name(),
self.num_ty.ty,
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
),
PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => {
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky,
PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv,
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
PrimDef::FunNpLinalgQr
| PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
is_vararg_ctx: false,
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr,
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu,
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur,
PrimDef::FunSpLinalgHessenberg => {
builtin_fns::call_sp_linalg_hessenberg
}
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
PrimDef::FunNpLinalgSvd => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
is_vararg_ctx: false,
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
PrimDef::FunNpLinalgMatrixPower => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.primitives.int32, "power")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matrix_power(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgDet => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.primitives.float,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?))
}),
),
_ => unreachable!(),
}
}
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
(prim.simple_name().into(), method_ty, prim.id()) (prim.simple_name().into(), method_ty, prim.id())
} }

View File

@ -860,73 +860,7 @@ impl TopLevelComposer {
let resolver = &**resolver; let resolver = &**resolver;
let mut function_var_map = VarMap::new(); let mut function_var_map = VarMap::new();
let arg_types = {
let vararg = args
.vararg
.as_ref()
.map(|vararg| -> Result<_, HashSet<String>> {
let vararg = vararg.as_ref();
let annotation = vararg
.node
.annotation
.as_ref()
.ok_or_else(|| {
HashSet::from([format!(
"function parameter `{}` needs type annotation at {}",
vararg.node.arg, vararg.location
)])
})?
.as_ref();
let type_annotation = parse_ast_to_type_annotation_kinds(
resolver,
temp_def_list.as_slice(),
unifier,
primitives_store,
annotation,
// NOTE: since only class need this, for function
// it should be fine to be empty map
HashMap::new(),
)?;
let type_vars_within =
get_type_var_contained_in_type_annotation(&type_annotation)
.into_iter()
.map(|x| -> Result<TypeVar, HashSet<String>> {
let TypeAnnotation::TypeVar(ty) = x else {
unreachable!("must be type var annotation kind")
};
let id = Self::get_var_id(ty, unifier)?;
Ok(TypeVar { id, ty })
})
.collect::<Result<Vec<_>, _>>()?;
for var in type_vars_within {
if let Some(prev_ty) = function_var_map.insert(var.id, var.ty) {
// if already have the type inserted, make sure they are the same thing
assert_eq!(prev_ty, var.ty);
}
}
let ty = get_type_from_type_annotation_kinds(
temp_def_list.as_ref(),
unifier,
primitives_store,
&type_annotation,
&mut None,
)?;
Ok(FuncArg {
name: vararg.node.arg,
ty,
default_value: Some(SymbolValue::Tuple(Vec::default())),
is_vararg: true,
})
})
.transpose()?;
let mut arg_types = {
// make sure no duplicate parameter // make sure no duplicate parameter
let mut defined_parameter_name: HashSet<_> = HashSet::new(); let mut defined_parameter_name: HashSet<_> = HashSet::new();
for x in &args.args { for x in &args.args {
@ -1027,18 +961,11 @@ impl TopLevelComposer {
v v
}), }),
}, },
is_vararg: false,
}) })
}) })
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
}; };
if let Some(vararg) = vararg {
arg_types.push(vararg);
};
let arg_types = arg_types;
let return_ty = { let return_ty = {
if let Some(returns) = returns { if let Some(returns) = returns {
let return_ty_annotation = { let return_ty_annotation = {
@ -1290,7 +1217,6 @@ impl TopLevelComposer {
}) })
} }
}, },
is_vararg: false,
}; };
// push the dummy type and the type annotation // push the dummy type and the type annotation
// into the list for later unification // into the list for later unification
@ -1716,25 +1642,21 @@ impl TopLevelComposer {
name: "msg".into(), name: "msg".into(),
ty: string, ty: string,
default_value: Some(SymbolValue::Str(String::new())), default_value: Some(SymbolValue::Str(String::new())),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "param0".into(), name: "param0".into(),
ty: int64, ty: int64,
default_value: Some(SymbolValue::I64(0)), default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "param1".into(), name: "param1".into(),
ty: int64, ty: int64,
default_value: Some(SymbolValue::I64(0)), default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
}, },
FuncArg { FuncArg {
name: "param2".into(), name: "param2".into(),
ty: int64, ty: int64,
default_value: Some(SymbolValue::I64(0)), default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
}, },
], ],
ret: self_type, ret: self_type,
@ -1944,7 +1866,6 @@ impl TopLevelComposer {
name: a.name, name: a.name,
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
default_value: a.default_value.clone(), default_value: a.default_value.clone(),
is_vararg: false,
}) })
.collect_vec() .collect_vec()
}; };

View File

@ -27,22 +27,17 @@ pub enum PrimDef {
List, List,
NDArray, NDArray,
// Option methods // Member Functions
FunOptionIsSome, OptionIsSome,
FunOptionIsNone, OptionIsNone,
FunOptionUnwrap, OptionUnwrap,
NDArrayCopy,
// Option-related functions NDArrayFill,
FunSome, FunInt32,
FunInt64,
// NDArray methods FunUInt32,
FunNDArrayCopy, FunUInt64,
FunNDArrayFill, FunFloat,
// Range methods
FunRangeInit,
// NumPy factory functions
FunNpNDArray, FunNpNDArray,
FunNpEmpty, FunNpEmpty,
FunNpZeros, FunNpZeros,
@ -51,17 +46,28 @@ pub enum PrimDef {
FunNpArray, FunNpArray,
FunNpEye, FunNpEye,
FunNpIdentity, FunNpIdentity,
FunRound,
// Miscellaneous NumPy & SciPy functions FunRound64,
FunNpRound, FunNpRound,
FunRangeInit,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunNpFloor, FunNpFloor,
FunCeil,
FunCeil64,
FunNpCeil, FunNpCeil,
FunLen,
FunMin,
FunNpMin, FunNpMin,
FunNpMinimum, FunNpMinimum,
FunNpArgmin, FunNpArgmin,
FunMax,
FunNpMax, FunNpMax,
FunNpMaximum, FunNpMaximum,
FunNpArgmax, FunNpArgmax,
FunAbs,
FunNpIsNan, FunNpIsNan,
FunNpIsInf, FunNpIsInf,
FunNpSin, FunNpSin,
@ -99,40 +105,9 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunNpTranspose,
FunNpReshape,
// Linalg functions // Top-Level Functions
FunNpDot, FunSome,
FunNpLinalgCholesky,
FunNpLinalgQr,
FunNpLinalgSvd,
FunNpLinalgInv,
FunNpLinalgPinv,
FunNpLinalgMatrixPower,
FunNpLinalgDet,
FunSpLinalgLu,
FunSpLinalgSchur,
FunSpLinalgHessenberg,
// Miscellaneous Python & NAC3 functions
FunInt32,
FunInt64,
FunUInt32,
FunUInt64,
FunFloat,
FunRound,
FunRound64,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunCeil,
FunCeil64,
FunLen,
FunMin,
FunMax,
FunAbs,
} }
/// Associated details of a [`PrimDef`] /// Associated details of a [`PrimDef`]
@ -198,7 +173,6 @@ impl PrimDef {
} }
match self { match self {
// Classes
PrimDef::Int32 => class("int32", |primitives| primitives.int32), PrimDef::Int32 => class("int32", |primitives| primitives.int32),
PrimDef::Int64 => class("int64", |primitives| primitives.int64), PrimDef::Int64 => class("int64", |primitives| primitives.int64),
PrimDef::Float => class("float", |primitives| primitives.float), PrimDef::Float => class("float", |primitives| primitives.float),
@ -210,25 +184,18 @@ impl PrimDef {
PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32), PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32),
PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64), PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64),
PrimDef::Option => class("Option", |primitives| primitives.option), PrimDef::Option => class("Option", |primitives| primitives.option),
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::List => class("list", |primitives| primitives.list), PrimDef::List => class("list", |primitives| primitives.list),
PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray), PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray),
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
// Option methods PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
PrimDef::FunOptionIsSome => fun("Option.is_some", Some("is_some")), PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunOptionIsNone => fun("Option.is_none", Some("is_none")), PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunOptionUnwrap => fun("Option.unwrap", Some("unwrap")), PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
// Option-related functions PrimDef::FunFloat => fun("float", None),
PrimDef::FunSome => fun("Some", None),
// NDArray methods
PrimDef::FunNDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::FunNDArrayFill => fun("ndarray.fill", Some("fill")),
// Range methods
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
// NumPy factory functions
PrimDef::FunNpNDArray => fun("np_ndarray", None), PrimDef::FunNpNDArray => fun("np_ndarray", None),
PrimDef::FunNpEmpty => fun("np_empty", None), PrimDef::FunNpEmpty => fun("np_empty", None),
PrimDef::FunNpZeros => fun("np_zeros", None), PrimDef::FunNpZeros => fun("np_zeros", None),
@ -237,17 +204,28 @@ impl PrimDef {
PrimDef::FunNpArray => fun("np_array", None), PrimDef::FunNpArray => fun("np_array", None),
PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None), PrimDef::FunNpIdentity => fun("np_identity", None),
PrimDef::FunRound => fun("round", None),
// Miscellaneous NumPy & SciPy functions PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunNpFloor => fun("np_floor", None), PrimDef::FunNpFloor => fun("np_floor", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunNpCeil => fun("np_ceil", None), PrimDef::FunNpCeil => fun("np_ceil", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunNpMin => fun("np_min", None), PrimDef::FunNpMin => fun("np_min", None),
PrimDef::FunNpMinimum => fun("np_minimum", None), PrimDef::FunNpMinimum => fun("np_minimum", None),
PrimDef::FunNpArgmin => fun("np_argmin", None), PrimDef::FunNpArgmin => fun("np_argmin", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunNpMax => fun("np_max", None), PrimDef::FunNpMax => fun("np_max", None),
PrimDef::FunNpMaximum => fun("np_maximum", None), PrimDef::FunNpMaximum => fun("np_maximum", None),
PrimDef::FunNpArgmax => fun("np_argmax", None), PrimDef::FunNpArgmax => fun("np_argmax", None),
PrimDef::FunAbs => fun("abs", None),
PrimDef::FunNpIsNan => fun("np_isnan", None), PrimDef::FunNpIsNan => fun("np_isnan", None),
PrimDef::FunNpIsInf => fun("np_isinf", None), PrimDef::FunNpIsInf => fun("np_isinf", None),
PrimDef::FunNpSin => fun("np_sin", None), PrimDef::FunNpSin => fun("np_sin", None),
@ -285,40 +263,7 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpTranspose => fun("np_transpose", None), PrimDef::FunSome => fun("Some", None),
PrimDef::FunNpReshape => fun("np_reshape", None),
// Linalg functions
PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None),
PrimDef::FunNpLinalgDet => fun("np_linalg_det", None),
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
// Miscellaneous Python & NAC3 functions
PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunAbs => fun("abs", None),
} }
} }
} }
@ -469,9 +414,9 @@ impl TopLevelComposer {
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(), obj_id: PrimDef::Option.id(),
fields: vec![ fields: vec![
(PrimDef::FunOptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::OptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::FunOptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::OptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::FunOptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)), (PrimDef::OptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)),
] ]
.into_iter() .into_iter()
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
@ -505,7 +450,6 @@ impl TopLevelComposer {
name: "value".into(), name: "value".into(),
ty: ndarray_dtype_tvar.ty, ty: ndarray_dtype_tvar.ty,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: none, ret: none,
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
@ -513,8 +457,8 @@ impl TopLevelComposer {
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
fields: Mapping::from([ fields: Mapping::from([
(PrimDef::FunNDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)), (PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)),
(PrimDef::FunNDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)), (PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)),
]), ]),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
] ]

View File

@ -552,7 +552,7 @@ pub fn get_type_from_type_annotation_kinds(
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys, is_vararg_ctx: false })) Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))
} }
} }
} }

View File

@ -34,18 +34,13 @@ impl<'a> Inferencer<'a> {
self.should_have_value(pattern)?; self.should_have_value(pattern)?;
Ok(()) Ok(())
} }
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
for elt in elts { for elt in elts {
self.check_pattern(elt, defined_identifiers)?; self.check_pattern(elt, defined_identifiers)?;
self.should_have_value(elt)?; self.should_have_value(elt)?;
} }
Ok(()) Ok(())
} }
ExprKind::Starred { value, .. } => {
self.check_pattern(value, defined_identifiers)?;
self.should_have_value(value)?;
Ok(())
}
ExprKind::Subscript { value, slice, .. } => { ExprKind::Subscript { value, slice, .. } => {
self.check_expr(value, defined_identifiers)?; self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?; self.should_have_value(value)?;
@ -212,6 +207,9 @@ impl<'a> Inferencer<'a> {
/// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which /// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
/// is freed when the function returns. /// is freed when the function returns.
fn check_return_value_ty(&mut self, ret_ty: Type) -> bool { fn check_return_value_ty(&mut self, ret_ty: Type) -> bool {
if cfg!(feature = "no-escape-analysis") {
true
} else {
match &*self.unifier.get_ty_immutable(ret_ty) { match &*self.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { .. } => [ TypeEnum::TObj { .. } => [
self.primitives.int32, self.primitives.int32,
@ -223,10 +221,11 @@ impl<'a> Inferencer<'a> {
] ]
.iter() .iter()
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)), .any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty, .. } => ty.iter().all(|t| self.check_return_value_ty(*t)), TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false, _ => false,
} }
} }
}
// check statements for proper identifier def-use and return on all paths // check statements for proper identifier def-use and return on all paths
fn check_stmt( fn check_stmt(

View File

@ -197,7 +197,6 @@ pub fn impl_binop(
ty: other_ty, ty: other_ty,
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
is_vararg: false,
}], }],
})), })),
false, false,
@ -262,7 +261,6 @@ pub fn impl_cmpop(
ty: other_ty, ty: other_ty,
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
is_vararg: false,
}], }],
})), })),
false, false,

View File

@ -183,10 +183,9 @@ impl<'a> Display for DisplayTypeError<'a> {
} }
result result
} }
( (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 })
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: is_vararg1 }, if ty1.len() != ty2.len() =>
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: is_vararg2 }, {
) if !is_vararg1 && !is_vararg2 && ty1.len() != ty2.len() => {
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Tuple length mismatch: got {t1} and {t2}") write!(f, "Tuple length mismatch: got {t1} and {t2}")

File diff suppressed because it is too large Load Diff

View File

@ -83,12 +83,7 @@ impl TestEnvironment {
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
ret: int32, ret: int32,
vars: VarMap::new(), vars: VarMap::new(),
})); }));
@ -229,12 +224,7 @@ impl TestEnvironment {
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
ret: int32, ret: int32,
vars: VarMap::new(), vars: VarMap::new(),
})); }));

View File

@ -1,14 +1,15 @@
use indexmap::IndexMap; use indexmap::IndexMap;
use itertools::{repeat_n, Itertools}; use itertools::Itertools;
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::iter::{repeat, zip}; use std::iter::zip;
use std::rc::Rc; use std::rc::Rc;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet}; use std::{borrow::Cow, collections::HashSet};
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
use super::magic_methods::Binop; use super::magic_methods::Binop;
use super::type_error::{TypeError, TypeErrorKind}; use super::type_error::{TypeError, TypeErrorKind};
use super::unification_table::{UnificationKey, UnificationTable}; use super::unification_table::{UnificationKey, UnificationTable};
@ -114,7 +115,6 @@ pub struct FuncArg {
pub name: StrRef, pub name: StrRef,
pub ty: Type, pub ty: Type,
pub default_value: Option<SymbolValue>, pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
} }
impl FuncArg { impl FuncArg {
@ -233,12 +233,6 @@ pub enum TypeEnum {
TTuple { TTuple {
/// The types of elements present in this tuple. /// The types of elements present in this tuple.
ty: Vec<Type>, ty: Vec<Type>,
/// Whether this tuple is used in a vararg context.
///
/// If `true`, `ty` must only contain one type, and the tuple is assumed to contain any
/// number of `ty`-typed values.
is_vararg_ctx: bool,
}, },
/// An object type. /// An object type.
@ -533,7 +527,7 @@ impl Unifier {
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| { TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec() ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
}), }),
TypeEnum::TTuple { ty, is_vararg_ctx } => { TypeEnum::TTuple { ty } => {
let tuples = ty let tuples = ty
.iter() .iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
@ -543,12 +537,7 @@ impl Unifier {
None None
} else { } else {
Some( Some(
tuples tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(),
.into_iter()
.map(|ty| {
self.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: *is_vararg_ctx })
})
.collect(),
) )
} }
} }
@ -592,7 +581,7 @@ impl Unifier {
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty, .. } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => { TObj { params: vars, .. } => {
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
} }
@ -660,7 +649,6 @@ impl Unifier {
// Get details about the function signature/parameters. // Get details about the function signature/parameters.
let num_params = signature.args.len(); let num_params = signature.args.len();
let is_vararg = signature.args.iter().any(|arg| arg.is_vararg);
// Force the type vars in `b` and `signature' to be up-to-date. // Force the type vars in `b` and `signature' to be up-to-date.
let b = self.instantiate_fun(b, signature); let b = self.instantiate_fun(b, signature);
@ -749,7 +737,7 @@ impl Unifier {
}; };
// Check for "too many arguments" // Check for "too many arguments"
if !is_vararg && num_params < posargs.len() { if num_params < posargs.len() {
let expected_min_count = let expected_min_count =
signature.args.iter().filter(|param| param.is_required()).count(); signature.args.iter().filter(|param| param.is_required()).count();
let expected_max_count = num_params; let expected_max_count = num_params;
@ -782,19 +770,6 @@ impl Unifier {
type_check_arg(param.name, param.ty, arg_ty)?; type_check_arg(param.name, param.ty, arg_ty)?;
} }
if is_vararg {
debug_assert!(!signature.args.is_empty());
let vararg_args = posargs.iter().skip(signature.args.len());
let vararg_param = signature.args.last().unwrap();
for (&arg_ty, param) in zip(vararg_args, repeat(vararg_param)) {
// `param_info` for this argument would've already been marked as supplied
// during non-vararg posarg typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
}
// Now consume all keyword arguments and typecheck them. // Now consume all keyword arguments and typecheck them.
for (&param_name, &arg_ty) in kwargs { for (&param_name, &arg_ty) in kwargs {
// We will also use this opportunity to check if this keyword argument is "legal". // We will also use this opportunity to check if this keyword argument is "legal".
@ -984,10 +959,7 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
( (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
TVar { fields: Some(fields), range, is_const_generic: false, .. },
TTuple { ty, .. },
) => {
let len = i32::try_from(ty.len()).unwrap(); let len = i32::try_from(ty.len()).unwrap();
for (k, v) in fields { for (k, v) in fields {
match *k { match *k {
@ -1084,47 +1056,15 @@ impl Unifier {
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
( (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
) => {
// Rules for Tuples:
// - ty1: is_vararg && ty2: is_vararg -> ty1[0] == ty2[0]
// - ty1: is_vararg && ty2: !is_vararg -> type error (not enough info to infer the correct number of arguments)
// - ty1: !is_vararg && ty2: is_vararg -> ty1[..] == ty2[0]
// - ty1: !is_vararg && ty2: !is_vararg -> ty1.len() == ty2.len() && ty1[i] == ty2[i]
debug_assert!(!is_vararg1 || ty1.len() == 1);
debug_assert!(!is_vararg2 || ty2.len() == 1);
match (*is_vararg1, *is_vararg2) {
(true, true) => {
if self.unify_impl(ty1[0], ty2[0], false).is_err() {
return Self::incompatible_types(a, b);
}
}
(true, false) => return Self::incompatible_types(a, b),
(false, true) => {
for y in ty2 {
if self.unify_impl(ty1[0], *y, false).is_err() {
return Self::incompatible_types(a, b);
}
}
}
(false, false) => {
if ty1.len() != ty2.len() { if ty1.len() != ty2.len() {
return Self::incompatible_types(a, b); return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
} }
for (x, y) in ty1.iter().zip(ty2.iter()) { for (x, y) in ty1.iter().zip(ty2.iter()) {
if self.unify_impl(*x, *y, false).is_err() { if self.unify_impl(*x, *y, false).is_err() {
return Self::incompatible_types(a, b); return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
} }
} }
}
}
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
(TVar { fields: Some(map), range, .. }, TObj { obj_id, fields, params }) => { (TVar { fields: Some(map), range, .. }, TObj { obj_id, fields, params }) => {
@ -1367,23 +1307,11 @@ impl Unifier {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {
format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", ")) format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", "))
} }
TypeEnum::TTuple { ty, is_vararg_ctx } => { TypeEnum::TTuple { ty } => {
if *is_vararg_ctx { let mut fields =
debug_assert_eq!(ty.len(), 1); ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
let field = self.internal_stringify(
*ty.iter().next().unwrap(),
obj_to_name,
var_to_name,
notes,
);
format!("tuple[*{field}]")
} else {
let mut fields = ty
.iter()
.map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
format!("tuple[{}]", fields.join(", ")) format!("tuple[{}]", fields.join(", "))
} }
}
TypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty } => {
format!( format!(
"virtual[{}]", "virtual[{}]",
@ -1407,21 +1335,17 @@ impl Unifier {
.args .args
.iter() .iter()
.map(|arg| { .map(|arg| {
let vararg_prefix = if arg.is_vararg { "*" } else { "" };
if let Some(dv) = &arg.default_value { if let Some(dv) = &arg.default_value {
format!( format!(
"{}:{}{}={}", "{}:{}={}",
arg.name, arg.name,
vararg_prefix,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes), self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes),
dv dv
) )
} else { } else {
format!( format!(
"{}:{}{}", "{}:{}",
arg.name, arg.name,
vararg_prefix,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes) self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes)
) )
} }
@ -1507,7 +1431,7 @@ impl Unifier {
match &*ty { match &*ty {
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None, TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None,
TypeEnum::TVar { id, .. } => mapping.get(id).copied(), TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
TypeEnum::TTuple { ty, is_vararg_ctx } => { TypeEnum::TTuple { ty } => {
let mut new_ty = Cow::from(ty); let mut new_ty = Cow::from(ty);
for (i, t) in ty.iter().enumerate() { for (i, t) in ty.iter().enumerate() {
if let Some(t1) = self.subst_impl(*t, mapping, cache) { if let Some(t1) = self.subst_impl(*t, mapping, cache) {
@ -1515,10 +1439,7 @@ impl Unifier {
} }
} }
if matches!(new_ty, Cow::Owned(_)) { if matches!(new_ty, Cow::Owned(_)) {
Some(self.add_ty(TypeEnum::TTuple { Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() }))
ty: new_ty.into_owned(),
is_vararg_ctx: *is_vararg_ctx,
}))
} else { } else {
None None
} }
@ -1678,37 +1599,16 @@ impl Unifier {
} }
} }
(TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())), (TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())),
( (TTuple { ty: ty1 }, TTuple { ty: ty2 }) if ty1.len() == ty2.len() => {
TTuple { ty: ty1, is_vararg_ctx: is_vararg1 }, let ty: Vec<_> = zip(ty1.iter(), ty2.iter())
TTuple { ty: ty2, is_vararg_ctx: is_vararg2 }, .map(|(a, b)| self.get_intersection(*a, *b))
) => { .try_collect()?;
if *is_vararg1 && *is_vararg2 { if ty.iter().any(Option::is_some) {
let isect_ty = self.get_intersection(ty1[0], ty2[0])?; Ok(Some(self.add_ty(TTuple {
Ok(isect_ty.map(|ty| self.add_ty(TTuple { ty: vec![ty], is_vararg_ctx: true })))
} else {
let zip_iter: Box<dyn Iterator<Item = (&Type, &Type)>> =
match (*is_vararg1, *is_vararg2) {
(true, _) => Box::new(repeat_n(&ty1[0], ty2.len()).zip(ty2.iter())),
(_, false) => Box::new(ty1.iter().zip(repeat_n(&ty2[0], ty1.len()))),
_ => {
if ty1.len() != ty2.len() {
return Err(());
}
Box::new(ty1.iter().zip(ty2.iter()))
}
};
let ty: Vec<_> =
zip_iter.map(|(a, b)| self.get_intersection(*a, *b)).try_collect()?;
Ok(if ty.iter().any(Option::is_some) {
Some(self.add_ty(TTuple {
ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(), ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
is_vararg_ctx: false, })))
}))
} else { } else {
None Ok(None)
})
} }
} }
// TODO(Derppening): #444 // TODO(Derppening): #444

View File

@ -28,10 +28,7 @@ impl Unifier {
TypeEnum::TVar { fields: Some(map1), .. }, TypeEnum::TVar { fields: Some(map1), .. },
TypeEnum::TVar { fields: Some(map2), .. }, TypeEnum::TVar { fields: Some(map2), .. },
) => self.map_eq2(map1, map2), ) => self.map_eq2(map1, map2),
( (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: false },
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: false },
) => {
ty1.len() == ty2.len() ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
} }
@ -181,7 +178,7 @@ impl TestEnvironment {
ty.push(result.0); ty.push(result.0);
s = result.1; s = result.1;
} }
(self.unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }), &s[1..]) (self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
} }
"Record" => { "Record" => {
let mut s = &typ[end..]; let mut s = &typ[end..];
@ -611,7 +608,7 @@ fn test_instantiation() {
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty; let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty; let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty;
let t = env.unifier.get_dummy_var().ty; let t = env.unifier.get_dummy_var().ty;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2], is_vararg_ctx: false }); let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty; let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty;
// t = TypeVar('t') // t = TypeVar('t')
// v = TypeVar('v', int, bool) // v = TypeVar('v', int, bool)

View File

@ -4,6 +4,9 @@ version = "0.1.0"
authors = ["M-Labs"] authors = ["M-Labs"]
edition = "2021" edition = "2021"
[features]
no-escape-analysis = ["nac3core/no-escape-analysis"]
[dependencies] [dependencies]
parking_lot = "0.12" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }

View File

@ -3,55 +3,23 @@
set -e set -e
if [ -z "$1" ]; then if [ -z "$1" ]; then
echo "No argument supplied" echo "Requires at least one argument"
exit 1 exit 1
fi fi
declare -a nac3args declare -a nac3args
while [ $# -ge 2 ]; do
case "$1" in
--help)
echo "Usage: check_demo.sh [-i686] -- demo [NAC3ARGS...]"
exit
;;
-i686)
i686=1
;;
--)
shift
break
;;
*)
break
;;
esac
shift
done
demo="$1"
shift
while [ $# -gt 1 ]; do while [ $# -gt 1 ]; do
nac3args+=("$1") nac3args+=("$1")
shift shift
done done
demo="$1"
echo -n "Checking $demo... "
echo "### Checking $demo..."
echo ">>>>>> Running $demo with the Python interpreter"
./interpret_demo.py "$demo" > interpreted.log ./interpret_demo.py "$demo" > interpreted.log
./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run.log
diff -Nau interpreted.log run_lli.log
echo "ok"
if [ -n "$i686" ]; then rm -f interpreted.log run.log run_lli.log
echo "...... Trying NAC3's 32-bit code generator output"
./run_demo.sh -i686 --out run_32.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run_32.log
fi
echo "...... Trying NAC3's 64-bit code generator output"
./run_demo.sh --out run_64.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run_64.log
echo "...... OK"
rm -f interpreted.log \
run_32.log run_64.log

View File

@ -6,6 +6,8 @@
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#define usize size_t
double dbl_nan(void) { double dbl_nan(void) {
return NAN; return NAN;
} }
@ -62,14 +64,14 @@ void output_asciiart(int32_t x) {
struct cslice { struct cslice {
void *data; void *data;
size_t len; usize len;
}; };
void output_int32_list(struct cslice *slice) { void output_int32_list(struct cslice *slice) {
const int32_t *data = (int32_t *) slice->data; const int32_t *data = (int32_t *) slice->data;
putchar('['); putchar('[');
for (size_t i = 0; i < slice->len; ++i) { for (usize i = 0; i < slice->len; ++i) {
if (i == slice->len - 1) { if (i == slice->len - 1) {
printf("%d", data[i]); printf("%d", data[i]);
} else { } else {
@ -83,7 +85,7 @@ void output_int32_list(struct cslice *slice) {
void output_str(struct cslice *slice) { void output_str(struct cslice *slice) {
const char *data = (const char *) slice->data; const char *data = (const char *) slice->data;
for (size_t i = 0; i < slice->len; ++i) { for (usize i = 0; i < slice->len; ++i) {
putchar(data[i]); putchar(data[i]);
} }
} }
@ -119,11 +121,11 @@ struct Exception {
uint32_t __nac3_raise(struct Exception* e) { uint32_t __nac3_raise(struct Exception* e) {
printf("__nac3_raise called. Exception details:\n"); printf("__nac3_raise called. Exception details:\n");
printf(" ID: %"PRIu32"\n", e->id); printf(" ID: %lld\n", e->id);
printf(" Location: %*s:%"PRIu32":%"PRIu32"\n" , (int) e->file.len, (const char*) e->file.data, e->line, e->column); printf(" Location: %*s:%lld:%lld\n" , e->file.len, (const char*) e->file.data, e->line, e->column);
printf(" Function: %*s\n" , (int) e->function.len, (const char*) e->function.data); printf(" Function: %*s\n" , e->function.len, (const char*) e->function.data);
printf(" Message: \"%*s\"\n" , (int) e->message.len, (const char*) e->message.data); printf(" Message: \"%*s\"\n" , e->message.len, (const char*) e->message.data);
printf(" Params: {0}=%"PRId64", {1}=%"PRId64", {2}=%"PRId64"\n", e->param[0], e->param[1], e->param[2]); printf(" Params: {0}=%lld, {1}=%lld, {2}=%lld\n", e->param[0], e->param[1], e->param[2]);
exit(101); exit(101);
__builtin_unreachable(); __builtin_unreachable();
} }

View File

@ -6,7 +6,6 @@ import importlib.machinery
import math import math
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import scipy as sp
import pathlib import pathlib
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
@ -168,7 +167,7 @@ def patch(module):
module.ceil64 = _ceil module.ceil64 = _ceil
module.np_ceil = np.ceil module.np_ceil = np.ceil
# NumPy NDArray factory functions # NumPy ndarray functions
module.ndarray = NDArray module.ndarray = NDArray
module.np_ndarray = np.ndarray module.np_ndarray = np.ndarray
module.np_empty = np.empty module.np_empty = np.empty
@ -218,10 +217,8 @@ def patch(module):
module.np_ldexp = np.ldexp module.np_ldexp = np.ldexp
module.np_hypot = np.hypot module.np_hypot = np.hypot
module.np_nextafter = np.nextafter module.np_nextafter = np.nextafter
module.np_transpose = np.transpose
module.np_reshape = np.reshape
# SciPy Math functions # SciPy Math Functions
module.sp_spec_erf = special.erf module.sp_spec_erf = special.erf
module.sp_spec_erfc = special.erfc module.sp_spec_erfc = special.erfc
module.sp_spec_gamma = special.gamma module.sp_spec_gamma = special.gamma
@ -229,19 +226,14 @@ def patch(module):
module.sp_spec_j0 = special.j0 module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1 module.sp_spec_j1 = special.j1
# Linalg functions # NumPy NDArray Functions
module.np_dot = np.dot module.np_ndarray = np.ndarray
module.np_linalg_cholesky = np.linalg.cholesky module.np_empty = np.empty
module.np_linalg_qr = np.linalg.qr module.np_zeros = np.zeros
module.np_linalg_svd = np.linalg.svd module.np_ones = np.ones
module.np_linalg_inv = np.linalg.inv module.np_full = np.full
module.np_linalg_pinv = np.linalg.pinv module.np_eye = np.eye
module.np_linalg_matrix_power = np.linalg.matrix_power module.np_identity = np.identity
module.np_linalg_det = np.linalg.det
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
module.sp_linalg_schur = sp.linalg.schur
module.sp_linalg_hessenberg = lambda x: sp.linalg.hessenberg(x, True)
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

@ -1,114 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]]
name = "autocfg"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "linalg"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"

View File

@ -1,13 +0,0 @@
[package]
name = "linalg"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["staticlib"]
[dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0"
[workspace]

View File

@ -1,406 +0,0 @@
// Uses `nalgebra` crate to invoke `np_linalg` and `sp_linalg` functions
// When converting between `nalgebra::Matrix` and `NDArray` following considerations are necessary
//
// * Both `nalgebra::Matrix` and `NDArray` require their content to be stored in row-major order
// * `NDArray` data pointer can be directly read and converted to `nalgebra::Matrix` (row and column number must be known)
// * `nalgebra::Matrix::as_slice` returns the content of matrix in column-major order and initial data needs to be transposed before storing it in `NDArray` data pointer
use core::slice;
use nalgebra::DMatrix;
fn report_error(
error_name: &str,
fn_name: &str,
file_name: &str,
line_num: u32,
col_num: u32,
err_msg: &str,
) -> ! {
panic!(
"Exception {} from {} in {}:{}:{}, message: {}",
error_name, fn_name, file_name, line_num, col_num, err_msg
);
}
pub struct InputMatrix {
pub ndims: usize,
pub dims: *const usize,
pub data: *mut f64,
}
impl InputMatrix {
fn get_dims(&mut self) -> Vec<usize> {
let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) };
dims.to_vec()
}
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix1.cholesky();
match result {
Some(res) => {
out_slice.copy_from_slice(res.unpack().transpose().as_slice());
}
None => {
report_error(
"LinAlgError",
"np_linalg_cholesky",
file!(),
line!(),
column!(),
"Matrix is not positive definite",
);
}
};
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_qr(
mat1: *mut InputMatrix,
out_q: *mut InputMatrix,
out_r: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_q = out_q.as_mut().unwrap();
let out_r = out_r.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outq_dim = (*out_q).get_dims();
let outr_dim = (*out_r).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, outq_dim[0] * outq_dim[1]) };
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r.data, outr_dim[0] * outr_dim[1]) };
// Refer to https://github.com/dimforge/nalgebra/issues/735
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let res = matrix1.qr();
let (q, r) = res.unpack();
// Uses different algo need to match numpy
out_q_slice.copy_from_slice(q.transpose().as_slice());
out_r_slice.copy_from_slice(r.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_svd(
mat1: *mut InputMatrix,
outu: *mut InputMatrix,
outs: *mut InputMatrix,
outvh: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let outu = outu.as_mut().unwrap();
let outs = outs.as_mut().unwrap();
let outvh = outvh.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outu_dim = (*outu).get_dims();
let outs_dim = (*outs).get_dims();
let outvh_dim = (*outvh).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) };
let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) };
let out_vh_slice =
unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix.svd(true, true);
out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice());
out_s_slice.copy_from_slice(result.singular_values.as_slice());
out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_invertible() {
report_error(
"LinAlgError",
"np_linalg_inv",
file!(),
line!(),
column!(),
"no inverse for Singular Matrix",
);
}
let inv = matrix.try_inverse().unwrap();
out_slice.copy_from_slice(inv.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let svd = matrix.svd(true, true);
let inv = svd.pseudo_inverse(1e-15);
match inv {
Ok(m) => {
out_slice.copy_from_slice(m.transpose().as_slice());
}
Err(err_msg) => {
report_error("LinAlgError", "np_linalg_pinv", file!(), line!(), column!(), err_msg);
}
}
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matrix_power(
mat1: *mut InputMatrix,
mat2: *mut InputMatrix,
out: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D", mat1.ndims);
report_error("ValueError", "np_linalg_matrix_power", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
let power = power[0];
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let abs_pow = power.abs();
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let mut result = matrix1.pow(abs_pow as u32);
if power < 0.0 {
if !result.is_invertible() {
report_error(
"LinAlgError",
"np_linalg_inv",
file!(),
line!(),
column!(),
"no inverse for Singular Matrix",
);
}
result = result.try_inverse().unwrap();
}
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_det", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_square() {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
out_slice[0] = matrix.determinant();
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_lu(
mat1: *mut InputMatrix,
out_l: *mut InputMatrix,
out_u: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_l = out_l.as_mut().unwrap();
let out_u = out_u.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outl_dim = (*out_l).get_dims();
let outu_dim = (*out_u).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l.data, outl_dim[0] * outl_dim[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u.data, outu_dim[0] * outu_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (_, l, u) = matrix.lu().unpack();
out_l_slice.copy_from_slice(l.transpose().as_slice());
out_u_slice.copy_from_slice(u.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_schur(
mat1: *mut InputMatrix,
out_t: *mut InputMatrix,
out_z: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_t = out_t.as_mut().unwrap();
let out_z = out_z.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_schur", file!(), line!(), column!(), &err_msg);
}
let out_t_dim = (*out_t).get_dims();
let out_z_dim = (*out_z).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t.data, out_t_dim[0] * out_t_dim[1]) };
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z.data, out_z_dim[0] * out_z_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (z, t) = matrix.schur().unpack();
out_t_slice.copy_from_slice(t.transpose().as_slice());
out_z_slice.copy_from_slice(z.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_hessenberg(
mat1: *mut InputMatrix,
out_h: *mut InputMatrix,
out_q: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_h = out_h.as_mut().unwrap();
let out_q = out_q.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {} != {}", dim1[0], dim1[1]);
report_error("LinAlgError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
}
let out_h_dim = (*out_h).get_dims();
let out_q_dim = (*out_q).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h.data, out_h_dim[0] * out_h_dim[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, out_q_dim[0] * out_q_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (q, h) = matrix.hessenberg().unpack();
out_h_slice.copy_from_slice(h.transpose().as_slice());
out_q_slice.copy_from_slice(q.transpose().as_slice());
}

View File

@ -2,9 +2,6 @@
set -e set -e
: "${DEMO_LINALG_STUB:=linalg/target/release/liblinalg.a}"
: "${DEMO_LINALG_STUB32:=linalg/target/i686-unknown-linux-gnu/release/liblinalg.a}"
if [ -z "$1" ]; then if [ -z "$1" ]; then
echo "No argument supplied" echo "No argument supplied"
exit 1 exit 1
@ -14,19 +11,19 @@ declare -a nac3args
while [ $# -ge 1 ]; do while [ $# -ge 1 ]; do
case "$1" in case "$1" in
--help) --help)
echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--debug] [-i686] -- [NAC3ARGS...]" echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--lli] [--debug] -- [NAC3ARGS...]"
exit exit
;; ;;
--out) --out)
shift shift
outfile="$1" outfile="$1"
;; ;;
--lli)
use_lli=1
;;
--debug) --debug)
debug=1 debug=1
;; ;;
-i686)
i686=1
;;
--) --)
shift shift
break break
@ -53,19 +50,29 @@ else
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo
if [ -z "$use_lli" ]; then
if [ -z "$i686" ]; then
$nac3standalone "${nac3args[@]}" $nac3standalone "${nac3args[@]}"
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -o demo module.o demo.o $DEMO_LINALG_STUB -lm -Wl,--no-warn-search-mismatch clang -lm -o demo module.o demo.o
else
$nac3standalone --triple i686-unknown-linux-gnu "${nac3args[@]}"
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
clang -m32 -o demo module.o demo.o $DEMO_LINALG_STUB32 -lm -Wl,--no-warn-search-mismatch
fi
if [ -z "$outfile" ]; then if [ -z "$outfile" ]; then
./demo ./demo
else else
./demo > "$outfile" ./demo > "$outfile"
fi fi
else
$nac3standalone --emit-llvm "${nac3args[@]}"
clang -c -std=gnu11 -Wall -Wextra -O3 -emit-llvm -o demo.bc demo.c
shopt -s nullglob
llvm-link -o nac3out.bc module*.bc main.bc
shopt -u nullglob
if [ -z "$outfile" ]; then
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc
else
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile"
fi
fi

View File

@ -1,66 +0,0 @@
@extern
def output_int32(x: int32):
...
@extern
def output_bool(x: bool):
...
def example1():
x, *ys, z = (1, 2, 3, 4, 5)
output_int32(x)
output_int32(ys[0])
output_int32(ys[1])
output_int32(ys[2])
output_int32(z)
def example2():
x, y, *zs = (1, 2, 3, 4, 5)
output_int32(x)
output_int32(y)
output_int32(zs[0])
output_int32(zs[1])
output_int32(zs[2])
def example3():
*xs, y, z = (1, 2, 3, 4, 5)
output_int32(xs[0])
output_int32(xs[1])
output_int32(xs[2])
output_int32(y)
output_int32(z)
def example4():
# Example from: https://docs.python.org/3/reference/simple_stmts.html#assignment-statements
x = [0, 1]
i = 0
i, x[i] = 1, 2 # i is updated, then x[i] is updated
output_int32(i)
output_int32(x[0])
output_int32(x[1])
class A:
value: int32
def __init__(self):
self.value = 1000
def example5():
ws = [88, 7, 8]
a = A()
x, [y, *ys, a.value], ws[0], (ws[0],) = 1, (2, False, 4, 5), 99, (6,)
output_int32(x)
output_int32(y)
output_bool(ys[0])
output_int32(ys[1])
output_int32(a.value)
output_int32(ws[0])
output_int32(ws[1])
output_int32(ws[2])
def run() -> int32:
example1()
example2()
example3()
example4()
example5()
return 0

View File

@ -1429,142 +1429,6 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_transpose():
x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
y = np_transpose(x)
z = np_transpose(y)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_reshape():
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
x = np_reshape(w, (1, 2, 1, -1))
y = np_reshape(x, [2, -1])
z = np_reshape(y, 10)
x1: ndarray[int32, 1] = np_array([1, 2, 3, 4])
x2: ndarray[int32, 2] = np_reshape(x1, (2, 2))
output_ndarray_float_1(w)
output_ndarray_float_2(y)
output_ndarray_float_1(z)
def test_ndarray_dot():
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
z1 = np_dot(x1, y1)
x2: ndarray[int32, 1] = np_array([5, 1, 4, 2])
y2: ndarray[int32, 1] = np_array([5, 1, 6, 6])
z2 = np_dot(x2, y2)
x3: ndarray[bool, 1] = np_array([True, True, True, True])
y3: ndarray[bool, 1] = np_array([True, True, True, True])
z3 = np_dot(x3, y3)
z4 = np_dot(2, 3)
z5 = np_dot(2., 3.)
z6 = np_dot(True, False)
output_float64(z1)
output_int32(z2)
output_bool(z3)
output_int32(z4)
output_float64(z5)
output_bool(z6)
def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = np_linalg_cholesky(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_qr():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y, z = np_linalg_qr(x)
output_ndarray_float_2(x)
# QR Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = y @ z
output_ndarray_float_2(a)
def test_ndarray_linalg_inv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_inv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pinv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
y = np_linalg_pinv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_matrix_power():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_matrix_power(x, -9)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_det():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_det(x)
output_ndarray_float_2(x)
output_float64(y)
def test_ndarray_schur():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
t, z = sp_linalg_schur(x)
output_ndarray_float_2(x)
# Schur Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = (z @ t) @ np_linalg_inv(z)
output_ndarray_float_2(a)
def test_ndarray_hessenberg():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 5.0, 8.5]])
h, q = sp_linalg_hessenberg(x)
output_ndarray_float_2(x)
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = (q @ h) @ np_linalg_inv(q)
output_ndarray_float_2(a)
def test_ndarray_lu():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
l, u = sp_linalg_lu(x)
output_ndarray_float_2(x)
output_ndarray_float_2(l)
output_ndarray_float_2(u)
def test_ndarray_svd():
w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
x, y, z = np_linalg_svd(w)
output_ndarray_float_2(w)
# SVD Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = x @ z
output_ndarray_float_2(a)
output_ndarray_float_1(y)
def run() -> int32: def run() -> int32:
test_ndarray_ctor() test_ndarray_ctor()
test_ndarray_empty() test_ndarray_empty()
@ -1743,18 +1607,5 @@ def run() -> int32:
test_ndarray_nextafter_broadcast() test_ndarray_nextafter_broadcast()
test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_transpose()
test_ndarray_reshape()
test_ndarray_dot()
test_ndarray_cholesky()
test_ndarray_qr()
test_ndarray_svd()
test_ndarray_linalg_inv()
test_ndarray_pinv()
test_ndarray_matrix_power()
test_ndarray_det()
test_ndarray_lu()
test_ndarray_schur()
test_ndarray_hessenberg()
return 0 return 0

View File

@ -1,11 +0,0 @@
def f(*args: int32):
pass
def run() -> int32:
f()
f(1)
f(1, 2)
f(1, 2, 3)
return 0

View File

@ -9,11 +9,15 @@
#![allow(clippy::too_many_lines, clippy::wildcard_imports)] #![allow(clippy::too_many_lines, clippy::wildcard_imports)]
use clap::Parser; use clap::Parser;
use inkwell::context::Context;
use inkwell::{ use inkwell::{
memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*, memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*,
OptimizationLevel, OptimizationLevel,
}; };
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions, concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions,
@ -35,10 +39,6 @@ use nac3parser::{
ast::{Constant, Expr, ExprKind, StmtKind, StrRef}, ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
parser, parser,
}; };
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
mod basic_symbol_resolver; mod basic_symbol_resolver;
use basic_symbol_resolver::*; use basic_symbol_resolver::*;
@ -241,6 +241,8 @@ fn handle_assignment_pattern(
} }
fn main() { fn main() {
const SIZE_T: u32 = usize::BITS;
let cli = CommandLineArgs::parse(); let cli = CommandLineArgs::parse();
let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } = let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } =
cli; cli;
@ -273,24 +275,6 @@ fn main() {
_ => OptimizationLevel::Aggressive, _ => OptimizationLevel::Aggressive,
}; };
let target_machine_options = CodeGenTargetMachineOptions {
triple,
cpu: mcpu,
features: target_features,
reloc_mode: RelocMode::PIC,
..host_target_machine
};
let size_t = Context::create()
.ptr_sized_int_type(
&target_machine_options
.create_target_machine(opt_level)
.map(|tm| tm.get_target_data())
.unwrap(),
None,
)
.get_bit_width();
let program = match fs::read_to_string(file_name.clone()) { let program = match fs::read_to_string(file_name.clone()) {
Ok(program) => program, Ok(program) => program,
Err(err) => { Err(err) => {
@ -299,9 +283,9 @@ fn main() {
} }
}; };
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(size_t).0; let primitive: PrimitiveStore = TopLevelComposer::make_primitives(SIZE_T).0;
let (mut composer, builtins_def, builtins_ty) = let (mut composer, builtins_def, builtins_ty) =
TopLevelComposer::new(vec![], ComposerConfig::default(), size_t); TopLevelComposer::new(vec![], ComposerConfig::default(), SIZE_T);
let internal_resolver: Arc<ResolverInternal> = ResolverInternal { let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
id_to_type: builtins_ty.into(), id_to_type: builtins_ty.into(),
@ -389,7 +373,16 @@ fn main() {
instance_to_stmt[""].clone() instance_to_stmt[""].clone()
}; };
let llvm_options = CodeGenLLVMOptions { opt_level, target: target_machine_options }; let llvm_options = CodeGenLLVMOptions {
opt_level,
target: CodeGenTargetMachineOptions {
triple,
cpu: mcpu,
features: target_features,
reloc_mode: RelocMode::PIC,
..host_target_machine
},
};
let task = CodeGenTask { let task = CodeGenTask {
subst: Vec::default(), subst: Vec::default(),
@ -412,7 +405,7 @@ fn main() {
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let threads = (0..threads) let threads = (0..threads)
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), size_t))) .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), SIZE_T)))
.collect(); .collect();
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task); registry.add_task(task);