diff --git a/Cargo.lock b/Cargo.lock index f16dee50..1d50ff1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,31 +1,32 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. +version = 3 + +[[package]] +name = "ahash" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43bb833f0bf979d8475d38fbf09ed3b8a55e1885fe93ad3f93239fc6a4f17b98" +dependencies = [ + "getrandom 0.2.3", + "once_cell", + "version_check", +] + [[package]] name = "aho-corasick" -version = "0.7.15" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7404febffaa47dac81aa44dba71523c9d069b1bdc50a77db41195149e17f68e5" +checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" dependencies = [ "memchr", ] -[[package]] -name = "arrayref" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544" - -[[package]] -name = "arrayvec" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" - [[package]] name = "ascii-canvas" -version = "2.0.0" +version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff8eb72df928aafb99fe5d37b383f2fe25bd2a765e3e5f7c365916b6f2463a29" +checksum = "8824ecca2e851cec16968d54a01dd372ef8f95b244fb84b84e70128be347c3c6" dependencies = [ "term", ] @@ -47,12 +48,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" -[[package]] -name = "base64" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" - [[package]] name = "bit-set" version = "0.5.2" @@ -69,60 +64,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" [[package]] -name = "blake2b_simd" -version = "0.5.11" +name = "bitflags" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afa748e348ad3be8263be728124b24a24f268266f6f5d58af9d75f6a40b5c587" -dependencies = [ - "arrayref", - "arrayvec", - "constant_time_eq", -] - -[[package]] -name = "block-buffer" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0940dc441f31689269e10ac70eb1002a3a1d3ad1390e030043662eb7fe4688b" -dependencies = [ - "block-padding", - "byte-tools", - "byteorder", - "generic-array", -] - -[[package]] -name = "block-padding" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa79dedbb091f449f1f39e53edf88d5dbe95f895dae6135a8d7b881fb5af73f5" -dependencies = [ - "byte-tools", -] - -[[package]] -name = "byte-tools" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3b5ca7a04898ad4bcd41c90c5285445ff5b791899bb1b0abdd2a2aa791211d7" - -[[package]] -name = "byteorder" -version = "1.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08c48aae112d48ed9f069b33538ea9e3e90aa263cfa3d1c24309612b1f7472de" +checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" [[package]] name = "cc" -version = "1.0.66" +version = "1.0.68" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c0496836a84f8d0495758516b8621a622beb77c0fed418570e50764093ced48" - -[[package]] -name = "cfg-if" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" +checksum = "4a72c244c1ff497a746a7e1fb3d14bd08420ecda70c8f25c7112f2781652d787" [[package]] name = "cfg-if" @@ -131,27 +82,84 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] -name = "constant_time_eq" -version = "0.1.5" +name = "crossbeam" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" +checksum = "4ae5588f6b3c3cb05239e90bd110f257254aecd01e4635400391aeae07497845" +dependencies = [ + "cfg-if", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "lazy_static", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b10ddc024425c88c2ad148c1b0fd53f4c6d38db9697c9f1588381212fa657c9" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] [[package]] name = "crossbeam-utils" -version = "0.8.1" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02d96d1e189ef58269ebe5b97953da3274d83a93af647c2ddd6f9dab28cedb8d" +checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db" dependencies = [ - "autocfg", - "cfg-if 1.0.0", + "cfg-if", "lazy_static", ] [[package]] -name = "ctor" -version = "0.1.16" +name = "crunchy" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fbaabec2c953050352311293be5c6aba8e141ba19d6811862b232d6fd020484" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "ctor" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e98e2ad1a782e33928b96fc3948e7c355e5af34ba4de7670fe8bac2a3b2006d" dependencies = [ "quote", "syn", @@ -164,37 +172,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e25ea47919b1560c4e3b7fe0aaab9becf5b84a10325ddf7db0f0ba5e1026499" [[package]] -name = "digest" -version = "0.8.1" +name = "dirs-next" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3d0c8c8752312f9713efd397ff63acb9f85585afbf179282e720e7704954dd5" +checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" 
dependencies = [ - "generic-array", + "cfg-if", + "dirs-sys-next", ] [[package]] -name = "dirs" -version = "1.0.5" +name = "dirs-sys-next" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fd78930633bd1c6e35c4b42b1df7b0cbc6bc191146e512bb3bedf243fcc3901" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" dependencies = [ "libc", "redox_users", "winapi", ] -[[package]] -name = "docopt" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f525a586d310c87df72ebcd98009e57f1cc030c8c268305287a476beb653969" -dependencies = [ - "lazy_static", - "regex", - "serde", - "strsim", -] - [[package]] name = "either" version = "1.6.1" @@ -210,12 +207,6 @@ dependencies = [ "log", ] -[[package]] -name = "fake-simd" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" - [[package]] name = "fixedbitset" version = "0.2.0" @@ -223,23 +214,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37ab347416e802de484e4d03c7316c48f1ecb56574dfd4a46a80f173ce1de04d" [[package]] -name = "generic-array" -version = "0.12.3" +name = "getrandom" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c68f0274ae0e023facc3c97b2e00f076be70e254bc851d972503b328db79b2ec" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" dependencies = [ - "typenum", + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", ] [[package]] name = "getrandom" -version = "0.1.15" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc587bc0ec293155d5bfa6b9891ec18a1e330c234f896ea47fbada4cadbe47e6" +checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753" dependencies = [ - "cfg-if 0.1.10", + "cfg-if", "libc", - "wasi", + "wasi 0.10.2+wasi-snapshot-preview1", ] [[package]] @@ -261,18 +254,18 @@ checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" [[package]] name = "hermit-abi" -version = "0.1.17" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aca5565f760fb5b220e499d72710ed156fdb74e631659e99377d9ebfbd13ae8" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" dependencies = [ "libc", ] [[package]] name = "indexmap" -version = "1.6.1" +version = "1.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb1fa934250de4de8aef298d81c729a7d33d8c239daa3a7575e6b92bfc7313b" +checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3" dependencies = [ "autocfg", "hashbrown", @@ -288,6 +281,15 @@ dependencies = [ "proc-macro-hack", ] +[[package]] +name = "indoc" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a75aeaaef0ce18b58056d306c27b07436fbb34b8816c53094b76dd81803136" +dependencies = [ + "unindent", +] + [[package]] name = "indoc-impl" version = "0.3.6" @@ -304,12 +306,13 @@ dependencies = [ [[package]] name = "inkwell" version = "0.1.0" -source = "git+https://github.com/TheDan64/inkwell#3eab4db479c2ca9d20b191f431a6d36835093108" +source = "git+https://github.com/TheDan64/inkwell?branch=master#aa4de1d78471a3d2f0fda1f56801177ddf80f3bf" dependencies = [ "either", "inkwell_internals", "libc", - "llvm-sys", + "llvm-sys 100.2.1", + "llvm-sys 110.0.1", "once_cell", "parking_lot", 
"regex", @@ -317,8 +320,8 @@ dependencies = [ [[package]] name = "inkwell_internals" -version = "0.2.0" -source = "git+https://github.com/TheDan64/inkwell#3eab4db479c2ca9d20b191f431a6d36835093108" +version = "0.3.0" +source = "git+https://github.com/TheDan64/inkwell?branch=master#aa4de1d78471a3d2f0fda1f56801177ddf80f3bf" dependencies = [ "proc-macro2", "quote", @@ -331,7 +334,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61124eeebbd69b8190558df225adf7e4caafce0d743919e5d6b19652314ec5ec" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", ] [[package]] @@ -358,43 +361,44 @@ dependencies = [ [[package]] name = "itertools" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" dependencies = [ "either", ] [[package]] name = "lalrpop" -version = "0.19.1" +version = "0.19.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60fb56191fb8ed5311597e5750debe6779c9fdb487dbaa5ff302592897d7a2c8" +checksum = "b15174f1c529af5bf1283c3bc0058266b483a67156f79589fab2a25e23cf8988" dependencies = [ "ascii-canvas", "atty", "bit-set", "diff", - "docopt", "ena", "itertools", "lalrpop-util", "petgraph", + "pico-args", "regex", "regex-syntax", - "serde", - "serde_derive", - "sha2", "string_cache", "term", + "tiny-keccak", "unicode-xid", ] [[package]] name = "lalrpop-util" -version = "0.19.1" +version = "0.19.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6771161eff561647fad8bb7e745e002c304864fb8f436b52b30acda51fca4408" +checksum = "d3e58cce361efcc90ba8a0a5f982c741ff86b603495bb15a998412e957dcd278" +dependencies = [ + "regex", +] [[package]] name = "lazy_static" @@ -404,55 +408,83 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.81" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1482821306169ec4d07f6aca392a4681f66c75c9918aa49641a2595db64053cb" +checksum = "12b8adadd720df158f4d70dfe7ccc6adb0472d7c55ca83445f6a5ab3e36f8fb6" [[package]] name = "llvm-sys" -version = "100.2.0" +version = "100.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9109e19fbfac3458f2970189719fa19f1007c6fd4e08c44fdebf4be0ddbe261d" +checksum = "15d9c00ce56221b2150e2d4d51887ff139fce5a0e50346c744861d1e66d2f7c4" dependencies = [ "cc", "lazy_static", "libc", "regex", - "semver", + "semver 0.9.0", +] + +[[package]] +name = "llvm-sys" +version = "110.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21ede189444b8c78907e5d36da5dabcf153170fcff9c1dba48afc4b33c7e19f0" +dependencies = [ + "cc", + "lazy_static", + "libc", + "regex", + "semver 0.11.0", ] [[package]] name = "lock_api" -version = "0.4.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd96ffd135b2fd7b973ac026d28085defbe8983df057ced3eb4f2130b0831312" +checksum = "0382880606dff6d15c9476c416d18690b72742aa7b605bb6dd6ec9030fbf07eb" dependencies = [ "scopeguard", ] [[package]] name = "log" -version = "0.4.11" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b" +checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" 
dependencies = [ - "cfg-if 0.1.10", + "cfg-if", ] [[package]] name = "memchr" -version = "2.3.4" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525" +checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" + +[[package]] +name = "memoffset" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59accc507f1338036a0477ef61afdae33cde60840f4dfe481319ce3ad116ddf9" +dependencies = [ + "autocfg", +] [[package]] name = "nac3core" version = "0.1.0" dependencies = [ + "crossbeam", + "indoc 1.0.3", "inkwell", - "num-bigint", + "itertools", + "num-bigint 0.3.2", "num-traits", + "parking_lot", + "rayon", "rustpython-parser", + "test-case", ] [[package]] @@ -482,9 +514,20 @@ checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54" [[package]] name = "num-bigint" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e9a41747ae4633fce5adffb4d2e81ffc5e89593cb19917f8fb2cc5ff76507bf" +checksum = "7d0a3d5e207573f948a9e5376662aa743a2ea13f7c50a554d7af443a73fbfeba" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e0d047c1062aa51e256408c560894e5251f08925980e53cf1aa5bd00eec6512" dependencies = [ "autocfg", "num-integer", @@ -511,16 +554,20 @@ dependencies = [ ] [[package]] -name = "once_cell" -version = "1.5.2" +name = "num_cpus" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13bd41f508810a131401606d54ac32a467c97172d74ba7662562ebba5ad07fa0" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +dependencies = [ + "hermit-abi", + "libc", +] [[package]] -name = "opaque-debug" -version = "0.2.3" +name = "once_cell" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c" +checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56" [[package]] name = "parking_lot" @@ -535,11 +582,11 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c6d9b8427445284a09c55be860a15855ab580a417ccad9da88f5a06787ced0" +checksum = "fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "instant", "libc", "redox_syscall", @@ -566,6 +613,15 @@ dependencies = [ "proc-macro-hack", ] +[[package]] +name = "pest" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f4872ae94d7b90ae48754df22fd42ad52ce740b8f370b03da4835417403e53" +dependencies = [ + "ucd-trie", +] + [[package]] name = "petgraph" version = "0.5.1" @@ -576,6 +632,41 @@ dependencies = [ "indexmap", ] +[[package]] +name = "phf" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dfb61232e34fcb633f43d12c58f83c1df82962dcdfa565a4e866ffc17dafe12" +dependencies = [ + "phf_macros", + "phf_shared", + "proc-macro-hack", +] + +[[package]] +name = "phf_generator" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17367f0cc86f2d25802b2c26ee58a7b23faeccf78a396094c13dced0d0182526" 
+dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_macros" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f6fde18ff429ffc8fe78e2bf7f8b7a5a5a6e2a8b58bc5a9ac69198bbda9189c" +dependencies = [ + "phf_generator", + "phf_shared", + "proc-macro-hack", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "phf_shared" version = "0.8.0" @@ -585,6 +676,18 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pico-args" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db8bcd96cb740d03149cbad5518db9fd87126a10ab519c011893b1754134c468" + +[[package]] +name = "ppv-lite86" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" + [[package]] name = "precomputed-hash" version = "0.1.1" @@ -599,9 +702,9 @@ checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" [[package]] name = "proc-macro2" -version = "1.0.24" +version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71" +checksum = "f0d8caf72986c1a598726adc988bb5984792ef84f5ee5aa50209145ee8077038" dependencies = [ "unicode-xid", ] @@ -613,7 +716,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf6bbbe8f70d179260b3728e5d04eb012f4f0c7988e58c11433dd689cecaa72e" dependencies = [ "ctor", - "indoc", + "indoc 0.3.6", "inventory", "libc", "parking_lot", @@ -646,84 +749,157 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.7" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37" +checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" dependencies = [ "proc-macro2", ] [[package]] -name = "redox_syscall" -version = "0.1.57" +name = "rand" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" +checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +dependencies = [ + "getrandom 0.1.16", + "libc", + "rand_chacha", + "rand_core", + "rand_hc", + "rand_pcg", +] + +[[package]] +name = "rand_chacha" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +dependencies = [ + "rand_core", +] + +[[package]] +name = "rand_pcg" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16abd0c1b639e9eb4d7c50c0b8100b0d0f849be2349829c740fe8e6eb4816429" +dependencies = [ + "rand_core", +] + +[[package]] +name = "rayon" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90" 
+dependencies = [ + "autocfg", + "crossbeam-deque", + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "lazy_static", + "num_cpus", +] + +[[package]] +name = "redox_syscall" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ab49abadf3f9e1c4bc499e8845e152ad87d2ad2d30371841171169e9d75feee" +dependencies = [ + "bitflags", +] [[package]] name = "redox_users" -version = "0.3.5" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de0737333e7a9502c789a36d7c7fa6092a49895d4faa31ca5df163857ded2e9d" +checksum = "528532f3d801c87aec9def2add9ca802fe569e44a544afe633765267840abe64" dependencies = [ - "getrandom", + "getrandom 0.2.3", "redox_syscall", - "rust-argon2", ] [[package]] name = "regex" -version = "1.4.2" +version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38cf2c13ed4745de91a5eb834e11c00bcc3709e773173b2ce4c56c9fbde04b9c" +checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" dependencies = [ "aho-corasick", "memchr", "regex-syntax", - "thread_local", ] [[package]] name = "regex-syntax" -version = "0.6.21" +version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b181ba2dcf07aaccad5448e8ead58db5b742cf85dfe035e2227f137a539a189" - -[[package]] -name = "rust-argon2" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b18820d944b33caa75a71378964ac46f58517c92b6ae5f762636247c09e78fb" -dependencies = [ - "base64", - "blake2b_simd", - "constant_time_eq", - "crossbeam-utils", -] +checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" [[package]] name = "rustpython-ast" version = "0.1.0" -source = "git+https://github.com/RustPython/RustPython#b01cca97f4f94af6c1e15f161283f1ac9e0617b1" +source = "git+https://github.com/RustPython/RustPython?branch=master#bee5794b6e2b777ee343c7277954b73d06b5cb7d" dependencies = [ - "num-bigint", + "num-bigint 0.4.0", ] [[package]] name = "rustpython-parser" version = "0.1.2" -source = "git+https://github.com/RustPython/RustPython#b01cca97f4f94af6c1e15f161283f1ac9e0617b1" +source = "git+https://github.com/RustPython/RustPython?branch=master#bee5794b6e2b777ee343c7277954b73d06b5cb7d" dependencies = [ + "ahash", "lalrpop", "lalrpop-util", "log", - "num-bigint", + "num-bigint 0.4.0", "num-traits", + "phf", "rustpython-ast", "unic-emoji-char", "unic-ucd-ident", "unicode_names2", ] +[[package]] +name = "rustversion" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61b3909d758bb75c79f23d4736fac9433868679d3ad2ea7a61e3c25cfda9a088" + [[package]] name = "scopeguard" version = "1.1.0" @@ -736,7 +912,16 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" dependencies = [ - "semver-parser", + "semver-parser 0.7.0", +] + +[[package]] +name = "semver" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6" +dependencies = [ + "semver-parser 0.10.2", ] [[package]] @@ -746,48 +931,25 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] -name = "serde" -version = "1.0.118" +name = "semver-parser" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06c64263859d87aa2eb554587e2d23183398d617427327cf2b3d0ed8c69e4800" +checksum = "00b0bef5b7f9e0df16536d3961cfb6e84331c065b4066afb39768d0e319411f7" dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.118" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c84d3526699cd55261af4b941e4e725444df67aa4f9e6a3564f18030d12672df" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "sha2" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a256f46ea78a0c0d9ff00077504903ac881a1dafdc20da66545699e7776b3e69" -dependencies = [ - "block-buffer", - "digest", - "fake-simd", - "opaque-debug", + "pest", ] [[package]] name = "siphasher" -version = "0.3.3" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa8f3741c7372e75519bd9346068370c9cdaabcc1f9599cbcf2a2719352286b7" +checksum = "cbce6d4507c7e4a3962091436e56e95290cb71fa302d0d270e32130b75fbff27" [[package]] name = "smallvec" -version = "1.5.1" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae524f056d7d770e174287294f562e95044c68e88dec909a00d2094805db9d75" +checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" [[package]] name = "string_cache" @@ -799,20 +961,13 @@ dependencies = [ "new_debug_unreachable", "phf_shared", "precomputed-hash", - "serde", ] -[[package]] -name = "strsim" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c" - [[package]] name = "syn" -version = "1.0.54" +version = "1.0.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2af957a63d6bd42255c359c93d9bfdb97076bd3b820897ce55ffbfbf107f44" +checksum = "f71489ff30030d2ae598524f61326b902466f72a0fb1a8564c001cc63425bcc7" dependencies = [ "proc-macro2", "quote", @@ -821,29 +976,42 @@ dependencies = [ [[package]] name = "term" -version = "0.5.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edd106a334b7657c10b7c540a0106114feadeb4dc314513e97df481d5d966f42" +checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f" dependencies = [ - "byteorder", - "dirs", + "dirs-next", + "rustversion", "winapi", ] [[package]] -name = "thread_local" -version = "1.0.1" +name = "test-case" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14" +checksum = "3b114ece25254e97bf48dd4bfc2a12bad0647adacfe4cae1247a9ca6ad302cec" dependencies = [ - "lazy_static", + "cfg-if", + "proc-macro2", + "quote", + "syn", + "version_check", ] [[package]] -name = "typenum" -version = "1.12.0" +name = "tiny-keccak" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "373c8a200f9e67a0c95e62a4f52fbf80c23b4381c05a17845531982fa99e6b33" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "ucd-trie" +version = "0.1.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c" [[package]] name = "unic-char-property" @@ -899,9 +1067,9 @@ dependencies = [ [[package]] name = "unicode-xid" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564" +checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" [[package]] name = "unicode_names2" @@ -915,12 +1083,24 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f14ee04d9415b52b3aeab06258a3f07093182b88ba0f9b8d203f211a7a7d41c7" +[[package]] +name = "version_check" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe" + [[package]] name = "wasi" version = "0.9.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" +[[package]] +name = "wasi" +version = "0.10.2+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" + [[package]] name = "winapi" version = "0.3.9" diff --git a/README.md b/README.md index 23c93519..00f7dcaf 100644 --- a/README.md +++ b/README.md @@ -15,20 +15,20 @@ caller to specify which methods should be compiled). After type checking, the compiler would analyse the set of functions/classes that are used and perform code generation. - -Symbol resolver: -- Str -> Nac3Type -- Str -> Value - value could be integer values, boolean values, bytes (for memcpy), function ID (full name + concrete type) ## Current Plan -1. Write out the syntax-directed type checking/inferencing rules. Fix the rule - for type variable instantiation. -2. Update the library dependencies and rewrite some of the type checking code. -3. Design the symbol resolver API. -4. Move tests from code to external files to cleanup the code. +Type checking: + +- [x] Basic interface for symbol resolver. +- [x] Track location information in context object (for diagnostics). +- [ ] Refactor old expression and statement type inference code. (anto) +- [ ] Error diagnostics utilities. (pca) +- [ ] Move tests to external files, write scripts for testing. (pca) +- [ ] Implement function type checking (instantiate bounded type parameters), + loop unrolling, type inference for lists with virtual objects. 
(pca)
+
diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml
index 62af7c29..bbd553e4 100644
--- a/nac3core/Cargo.toml
+++ b/nac3core/Cargo.toml
@@ -7,5 +7,13 @@ edition = "2018"
 [dependencies]
 num-bigint = "0.3"
 num-traits = "0.2"
-inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] }
+inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
 rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" }
+itertools = "0.10.1"
+crossbeam = "0.8.1"
+parking_lot = "0.11.1"
+rayon = "1.5.1"
+
+[dev-dependencies]
+test-case = "1.2.0"
+indoc = "1.0"
diff --git a/nac3core/rustfmt.toml b/nac3core/rustfmt.toml
new file mode 100644
index 00000000..cfaa54ae
--- /dev/null
+++ b/nac3core/rustfmt.toml
@@ -0,0 +1 @@
+use_small_heuristics = "Max"
diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs
new file mode 100644
index 00000000..81a2161b
--- /dev/null
+++ b/nac3core/src/codegen/expr.rs
@@ -0,0 +1,527 @@
+use std::{collections::HashMap, convert::TryInto, iter::once};
+
+use super::{get_llvm_type, CodeGenContext};
+use crate::{
+    symbol_resolver::SymbolValue,
+    top_level::{DefinitionId, TopLevelDef},
+    typecheck::typedef::{FunSignature, Type, TypeEnum},
+};
+use inkwell::{
+    types::{BasicType, BasicTypeEnum},
+    values::BasicValueEnum,
+    AddressSpace,
+};
+use itertools::{chain, izip, zip, Itertools};
+use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
+
+impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
+    fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
+        let mut vars = obj
+            .map(|ty| {
+                if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty(ty) {
+                    params.borrow().clone()
+                } else {
+                    unreachable!()
+                }
+            })
+            .unwrap_or_default();
+        vars.extend(fun.vars.iter());
+        let sorted = vars.keys().sorted();
+        sorted
+            .map(|id| {
+                self.unifier.stringify(vars[id], &mut |id| id.to_string(), &mut |id| id.to_string())
+            })
+            .join(", ")
+    }
+
+    pub fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize {
+        let obj_id = match &*self.unifier.get_ty(ty) {
+            TypeEnum::TObj { obj_id, .. } => *obj_id,
+            // we cannot have other types, virtual type should be handled by function calls
+            _ => unreachable!(),
+        };
+        let def = &self.top_level.definitions.read()[obj_id.0];
+        let index = if let TopLevelDef::Class { fields, .. } = &*def.read() {
+            fields.iter().find_position(|x| x.0 == attr).unwrap().0
+        } else {
+            unreachable!()
+        };
+        index
+    }
+
+    fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> {
+        match val {
+            SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(),
+            SymbolValue::I64(v) => self.ctx.i64_type().const_int(*v as u64, true).into(),
+            SymbolValue::Bool(v) => self.ctx.bool_type().const_int(*v as u64, true).into(),
+            SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(),
+            SymbolValue::Tuple(ls) => {
+                let vals = ls.iter().map(|v| self.gen_symbol_val(v)).collect_vec();
+                let fields = vals.iter().map(|v| v.get_type()).collect_vec();
+                let ty = self.ctx.struct_type(&fields, false);
+                let ptr = self.builder.build_alloca(ty, "tuple");
+                let zero = self.ctx.i32_type().const_zero();
+                unsafe {
+                    for (i, val) in vals.into_iter().enumerate() {
+                        let p = ptr.const_in_bounds_gep(&[
+                            zero,
+                            self.ctx.i32_type().const_int(i as u64, false),
+                        ]);
+                        self.builder.build_store(p, val);
+                    }
+                }
+                ptr.into()
+            }
+        }
+    }
+
+    pub fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> {
+        get_llvm_type(self.ctx, &mut self.unifier, self.top_level, &mut self.type_cache, ty)
+    }
+
+    fn gen_call(
+        &mut self,
+        obj: Option<(Type, BasicValueEnum<'ctx>)>,
+        fun: (&FunSignature, DefinitionId),
+        params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
+        ret: Type,
+    ) -> Option<BasicValueEnum<'ctx>> {
+        let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
+        let defs = self.top_level.definitions.read();
+        let definition = defs.get(fun.1 .0).unwrap();
+        let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
+            let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
+                // TODO: codegen for function that are not yet generated
+                unimplemented!()
+            });
+            let fun_val = self.module.get_function(symbol).unwrap_or_else(|| {
+                let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
+                let fun_ty = if self.unifier.unioned(ret, self.primitives.none) {
+                    self.ctx.void_type().fn_type(&params, false)
+                } else {
+                    self.get_llvm_type(ret).fn_type(&params, false)
+                };
+                self.module.add_function(symbol, fun_ty, None)
+            });
+            let mut keys = fun.0.args.clone();
+            let mut mapping = HashMap::new();
+            for (key, value) in params.into_iter() {
+                mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
+            }
+            // default value handling
+            for k in keys.into_iter() {
+                mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
+            }
+            // reorder the parameters
+            let params =
+                fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
+            self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left()
+        } else {
+            unreachable!()
+        };
+        val
+    }
+
+    fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> {
+        match value {
+            Constant::Bool(v) => {
+                assert!(self.unifier.unioned(ty, self.primitives.bool));
+                let ty = self.ctx.bool_type();
+                ty.const_int(if *v { 1 } else { 0 }, false).into()
+            }
+            Constant::Int(v) => {
+                let ty = if self.unifier.unioned(ty, self.primitives.int32) {
+                    self.ctx.i32_type()
+                } else if self.unifier.unioned(ty, self.primitives.int64) {
+                    self.ctx.i64_type()
+                } else {
+                    unreachable!();
+                };
+                ty.const_int(v.try_into().unwrap(), false).into()
+            }
+            Constant::Float(v) => {
+                assert!(self.unifier.unioned(ty, self.primitives.float));
+                let ty = self.ctx.f64_type();
+                ty.const_float(*v).into()
+            }
+            Constant::Tuple(v) => {
+                let ty = self.unifier.get_ty(ty);
+                let types =
+                    if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() };
+                let values = zip(types.into_iter(), v.iter())
+                    .map(|(ty, v)| self.gen_const(v, ty))
+                    .collect_vec();
+                let types = values.iter().map(BasicValueEnum::get_type).collect_vec();
+                let ty = self.ctx.struct_type(&types, false);
+                ty.const_named_struct(&values).into()
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    fn gen_int_ops(
+        &mut self,
+        op: &Operator,
+        lhs: BasicValueEnum<'ctx>,
+        rhs: BasicValueEnum<'ctx>,
+    ) -> BasicValueEnum<'ctx> {
+        let (lhs, rhs) =
+            if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) {
+                (lhs, rhs)
+            } else {
+                unreachable!()
+            };
+        match op {
+            Operator::Add => self.builder.build_int_add(lhs, rhs, "add").into(),
+            Operator::Sub => self.builder.build_int_sub(lhs, rhs, "sub").into(),
+            Operator::Mult => self.builder.build_int_mul(lhs, rhs, "mul").into(),
+            Operator::Div => {
+                let float = self.ctx.f64_type();
+                let left = self.builder.build_signed_int_to_float(lhs, float, "i2f");
+                let right = self.builder.build_signed_int_to_float(rhs, float, "i2f");
+                self.builder.build_float_div(left, right, "fdiv").into()
+            }
+            Operator::Mod => self.builder.build_int_signed_rem(lhs, rhs, "mod").into(),
+            Operator::BitOr => self.builder.build_or(lhs, rhs, "or").into(),
+            Operator::BitXor => self.builder.build_xor(lhs, rhs, "xor").into(),
+            Operator::BitAnd => self.builder.build_and(lhs, rhs, "and").into(),
+            Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").into(),
+            Operator::RShift => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(),
+            Operator::FloorDiv => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(),
+            // special implementation?
+            Operator::Pow => unimplemented!(),
+            Operator::MatMult => unreachable!(),
+        }
+    }
+
+    fn gen_float_ops(
+        &mut self,
+        op: &Operator,
+        lhs: BasicValueEnum<'ctx>,
+        rhs: BasicValueEnum<'ctx>,
+    ) -> BasicValueEnum<'ctx> {
+        let (lhs, rhs) = if let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) =
+            (lhs, rhs)
+        {
+            (lhs, rhs)
+        } else {
+            unreachable!()
+        };
+        match op {
+            Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").into(),
+            Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").into(),
+            Operator::Mult => self.builder.build_float_mul(lhs, rhs, "fmul").into(),
+            Operator::Div => self.builder.build_float_div(lhs, rhs, "fdiv").into(),
+            Operator::Mod => self.builder.build_float_rem(lhs, rhs, "fmod").into(),
+            Operator::FloorDiv => {
+                let div = self.builder.build_float_div(lhs, rhs, "fdiv");
+                let floor_intrinsic =
+                    self.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
+                        let float = self.ctx.f64_type();
+                        let fn_type = float.fn_type(&[float.into()], false);
+                        self.module.add_function("llvm.floor.f64", fn_type, None)
+                    });
+                self.builder
+                    .build_call(floor_intrinsic, &[div.into()], "floor")
+                    .try_as_basic_value()
+                    .left()
+                    .unwrap()
+            }
+            // special implementation?
+            _ => unimplemented!(),
+        }
+    }
+
+    pub fn gen_expr(&mut self, expr: &Expr<Option<Type>>) -> BasicValueEnum<'ctx> {
+        let zero = self.ctx.i32_type().const_int(0, false);
+        match &expr.node {
+            ExprKind::Constant { value, .. } => {
+                let ty = expr.custom.unwrap();
+                self.gen_const(value, ty)
+            }
+            ExprKind::Name { id, ..
} => { + let ptr = self.var_assignment.get(id).unwrap(); + let primitives = &self.primitives; + // we should only dereference primitive types + if [primitives.int32, primitives.int64, primitives.float, primitives.bool] + .contains(&self.unifier.get_representative(expr.custom.unwrap())) + { + self.builder.build_load(*ptr, "load") + } else { + (*ptr).into() + } + } + ExprKind::List { elts, .. } => { + // this shall be optimized later for constant primitive lists... + // we should use memcpy for that instead of generating thousands of stores + let elements = elts.iter().map(|x| self.gen_expr(x)).collect_vec(); + let ty = if elements.is_empty() { + self.ctx.i32_type().into() + } else { + elements[0].get_type() + }; + let arr_ptr = self.builder.build_array_alloca( + ty, + self.ctx.i32_type().const_int(elements.len() as u64, false), + "tmparr", + ); + let arr_ty = self.ctx.struct_type( + &[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()], + false, + ); + let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr"); + unsafe { + self.builder.build_store( + arr_str_ptr.const_in_bounds_gep(&[zero, zero]), + self.ctx.i32_type().const_int(elements.len() as u64, false), + ); + self.builder.build_store( + arr_str_ptr + .const_in_bounds_gep(&[zero, self.ctx.i32_type().const_int(1, false)]), + arr_ptr, + ); + let arr_offset = self.ctx.i32_type().const_int(1, false); + for (i, v) in elements.iter().enumerate() { + let ptr = self.builder.build_in_bounds_gep( + arr_ptr, + &[zero, arr_offset, self.ctx.i32_type().const_int(i as u64, false)], + "arr_element", + ); + self.builder.build_store(ptr, *v); + } + } + arr_str_ptr.into() + } + ExprKind::Tuple { elts, .. } => { + let element_val = elts.iter().map(|x| self.gen_expr(x)).collect_vec(); + let element_ty = element_val.iter().map(BasicValueEnum::get_type).collect_vec(); + let tuple_ty = self.ctx.struct_type(&element_ty, false); + let tuple_ptr = self.builder.build_alloca(tuple_ty, "tuple"); + for (i, v) in element_val.into_iter().enumerate() { + unsafe { + let ptr = tuple_ptr.const_in_bounds_gep(&[ + zero, + self.ctx.i32_type().const_int(i as u64, false), + ]); + self.builder.build_store(ptr, v); + } + } + tuple_ptr.into() + } + ExprKind::Attribute { value, attr, .. } => { + // note that we would handle class methods directly in calls + let index = self.get_attr_index(value.custom.unwrap(), attr); + let val = self.gen_expr(value); + let ptr = if let BasicValueEnum::PointerValue(v) = val { + v + } else { + unreachable!(); + }; + unsafe { + let ptr = ptr.const_in_bounds_gep(&[ + zero, + self.ctx.i32_type().const_int(index as u64, false), + ]); + self.builder.build_load(ptr, "field") + } + } + ExprKind::BoolOp { op, values } => { + // requires conditional branches for short-circuiting... 
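+                // branch on the left operand: `or` only evaluates the right operand when the left is false,
+                // `and` only when it is true; the two incoming values are then merged with a phi node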
+ let left = if let BasicValueEnum::IntValue(left) = self.gen_expr(&values[0]) { + left + } else { + unreachable!() + }; + let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); + let a_bb = self.ctx.append_basic_block(current, "a"); + let b_bb = self.ctx.append_basic_block(current, "b"); + let cont_bb = self.ctx.append_basic_block(current, "cont"); + self.builder.build_conditional_branch(left, a_bb, b_bb); + let (a, b) = match op { + Boolop::Or => { + self.builder.position_at_end(a_bb); + let a = self.ctx.bool_type().const_int(1, false); + self.builder.build_unconditional_branch(cont_bb); + self.builder.position_at_end(b_bb); + let b = if let BasicValueEnum::IntValue(b) = self.gen_expr(&values[1]) { + b + } else { + unreachable!() + }; + self.builder.build_unconditional_branch(cont_bb); + (a, b) + } + Boolop::And => { + self.builder.position_at_end(a_bb); + let a = if let BasicValueEnum::IntValue(a) = self.gen_expr(&values[1]) { + a + } else { + unreachable!() + }; + self.builder.build_unconditional_branch(cont_bb); + self.builder.position_at_end(b_bb); + let b = self.ctx.bool_type().const_int(0, false); + self.builder.build_unconditional_branch(cont_bb); + (a, b) + } + }; + self.builder.position_at_end(cont_bb); + let phi = self.builder.build_phi(self.ctx.bool_type(), "phi"); + phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]); + phi.as_basic_value() + } + ExprKind::BinOp { op, left, right } => { + let ty1 = self.unifier.get_representative(left.custom.unwrap()); + let ty2 = self.unifier.get_representative(right.custom.unwrap()); + let left = self.gen_expr(left); + let right = self.gen_expr(right); + + // we can directly compare the types, because we've got their representatives + // which would be unchanged until further unification, which we would never do + // when doing code generation for function instances + if ty1 == ty2 && [self.primitives.int32, self.primitives.int64].contains(&ty1) { + self.gen_int_ops(op, left, right) + } else if ty1 == ty2 && self.primitives.float == ty1 { + self.gen_float_ops(op, left, right) + } else { + unimplemented!() + } + } + ExprKind::UnaryOp { op, operand } => { + let ty = self.unifier.get_representative(operand.custom.unwrap()); + let val = self.gen_expr(operand); + if ty == self.primitives.bool { + let val = + if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() }; + match op { + ast::Unaryop::Invert | ast::Unaryop::Not => { + self.builder.build_not(val, "not").into() + } + _ => val.into(), + } + } else if [self.primitives.int32, self.primitives.int64].contains(&ty) { + let val = + if let BasicValueEnum::IntValue(val) = val { val } else { unreachable!() }; + match op { + ast::Unaryop::USub => self.builder.build_int_neg(val, "neg").into(), + ast::Unaryop::Invert => self.builder.build_not(val, "not").into(), + ast::Unaryop::Not => self + .builder + .build_int_compare( + inkwell::IntPredicate::EQ, + val, + val.get_type().const_zero(), + "not", + ) + .into(), + _ => val.into(), + } + } else if ty == self.primitives.float { + let val = if let BasicValueEnum::FloatValue(val) = val { + val + } else { + unreachable!() + }; + match op { + ast::Unaryop::USub => self.builder.build_float_neg(val, "neg").into(), + ast::Unaryop::Not => self + .builder + .build_float_compare( + inkwell::FloatPredicate::OEQ, + val, + val.get_type().const_zero(), + "not", + ) + .into(), + _ => val.into(), + } + } else { + unimplemented!() + } + } + ExprKind::Compare { left, ops, comparators } => { + izip!( + chain(once(left.as_ref()), 
comparators.iter()), + comparators.iter(), + ops.iter(), + ) + .fold(None, |prev, (lhs, rhs, op)| { + let ty = self.unifier.get_representative(lhs.custom.unwrap()); + let current = + if [self.primitives.int32, self.primitives.int64, self.primitives.bool] + .contains(&ty) + { + let (lhs, rhs) = if let ( + BasicValueEnum::IntValue(lhs), + BasicValueEnum::IntValue(rhs), + ) = (self.gen_expr(lhs), self.gen_expr(rhs)) + { + (lhs, rhs) + } else { + unreachable!() + }; + let op = match op { + ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ, + ast::Cmpop::NotEq => inkwell::IntPredicate::NE, + ast::Cmpop::Lt => inkwell::IntPredicate::SLT, + ast::Cmpop::LtE => inkwell::IntPredicate::SLE, + ast::Cmpop::Gt => inkwell::IntPredicate::SGT, + ast::Cmpop::GtE => inkwell::IntPredicate::SGE, + _ => unreachable!(), + }; + self.builder.build_int_compare(op, lhs, rhs, "cmp") + } else if ty == self.primitives.float { + let (lhs, rhs) = if let ( + BasicValueEnum::FloatValue(lhs), + BasicValueEnum::FloatValue(rhs), + ) = (self.gen_expr(lhs), self.gen_expr(rhs)) + { + (lhs, rhs) + } else { + unreachable!() + }; + let op = match op { + ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ, + ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE, + ast::Cmpop::Lt => inkwell::FloatPredicate::OLT, + ast::Cmpop::LtE => inkwell::FloatPredicate::OLE, + ast::Cmpop::Gt => inkwell::FloatPredicate::OGT, + ast::Cmpop::GtE => inkwell::FloatPredicate::OGE, + _ => unreachable!(), + }; + self.builder.build_float_compare(op, lhs, rhs, "cmp") + } else { + unimplemented!() + }; + prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current)) + }) + .unwrap() + .into() // as there should be at least 1 element, it should never be none + } + ExprKind::IfExp { test, body, orelse } => { + let test = if let BasicValueEnum::IntValue(test) = self.gen_expr(test) { + test + } else { + unreachable!() + }; + + let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); + let then_bb = self.ctx.append_basic_block(current, "then"); + let else_bb = self.ctx.append_basic_block(current, "else"); + let cont_bb = self.ctx.append_basic_block(current, "cont"); + self.builder.build_conditional_branch(test, then_bb, else_bb); + self.builder.position_at_end(then_bb); + let a = self.gen_expr(body); + self.builder.build_unconditional_branch(cont_bb); + self.builder.position_at_end(else_bb); + let b = self.gen_expr(orelse); + self.builder.build_unconditional_branch(cont_bb); + self.builder.position_at_end(cont_bb); + let phi = self.builder.build_phi(a.get_type(), "ifexpr"); + phi.add_incoming(&[(&a, then_bb), (&b, else_bb)]); + phi.as_basic_value() + } + _ => unimplemented!(), + } + } +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs new file mode 100644 index 00000000..4df3ee99 --- /dev/null +++ b/nac3core/src/codegen/mod.rs @@ -0,0 +1,343 @@ +use crate::{ + symbol_resolver::SymbolResolver, + top_level::{TopLevelContext, TopLevelDef}, + typecheck::{ + type_inferencer::PrimitiveStore, + typedef::{FunSignature, Type, TypeEnum, Unifier}, + }, +}; +use crossbeam::channel::{unbounded, Receiver, Sender}; +use inkwell::{ + basic_block::BasicBlock, + builder::Builder, + context::Context, + module::Module, + types::{BasicType, BasicTypeEnum}, + values::PointerValue, + AddressSpace, +}; +use itertools::Itertools; +use parking_lot::{Condvar, Mutex}; +use rustpython_parser::ast::Stmt; +use std::collections::HashMap; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; +use std::thread; + 
+mod expr;
+mod stmt;
+
+#[cfg(test)]
+mod test;
+
+pub struct CodeGenContext<'ctx, 'a> {
+    pub ctx: &'ctx Context,
+    pub builder: Builder<'ctx>,
+    pub module: Module<'ctx>,
+    pub top_level: &'a TopLevelContext,
+    pub unifier: Unifier,
+    pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
+    pub var_assignment: HashMap<String, PointerValue<'ctx>>,
+    pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
+    pub primitives: PrimitiveStore,
+    // stores the alloca for variables
+    pub init_bb: BasicBlock<'ctx>,
+    // where continue and break should go to respectively
+    // the first one is the test_bb, and the second one is bb after the loop
+    pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
+}
+
+type Fp = Box<dyn Fn(&Module) + Send + Sync>;
+
+pub struct WithCall {
+    fp: Fp,
+}
+
+impl WithCall {
+    pub fn new(fp: Fp) -> WithCall {
+        WithCall { fp }
+    }
+
+    pub fn run<'ctx>(&self, m: &Module<'ctx>) {
+        (self.fp)(m)
+    }
+}
+
+pub struct WorkerRegistry {
+    sender: Arc<Sender<Option<CodeGenTask>>>,
+    receiver: Arc<Receiver<Option<CodeGenTask>>>,
+    panicked: AtomicBool,
+    task_count: Mutex<usize>,
+    thread_count: usize,
+    wait_condvar: Condvar,
+}
+
+impl WorkerRegistry {
+    pub fn create_workers(
+        names: &[&str],
+        top_level_ctx: Arc<TopLevelContext>,
+        f: Arc<WithCall>,
+    ) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
+        let (sender, receiver) = unbounded();
+        let task_count = Mutex::new(0);
+        let wait_condvar = Condvar::new();
+
+        let registry = Arc::new(WorkerRegistry {
+            sender: Arc::new(sender),
+            receiver: Arc::new(receiver),
+            thread_count: names.len(),
+            panicked: AtomicBool::new(false),
+            task_count,
+            wait_condvar,
+        });
+
+        let mut handles = Vec::new();
+        for name in names.iter() {
+            let top_level_ctx = top_level_ctx.clone();
+            let registry = registry.clone();
+            let registry2 = registry.clone();
+            let name = name.to_string();
+            let f = f.clone();
+            let handle = thread::spawn(move || {
+                registry.worker_thread(name, top_level_ctx, f);
+            });
+            let handle = thread::spawn(move || {
+                if let Err(e) = handle.join() {
+                    if let Some(e) = e.downcast_ref::<&'static str>() {
+                        eprintln!("Got an error: {}", e);
+                    } else {
+                        eprintln!("Got an unknown error: {:?}", e);
+                    }
+                    registry2.panicked.store(true, Ordering::SeqCst);
+                    registry2.wait_condvar.notify_all();
+                }
+            });
+            handles.push(handle);
+        }
+        (registry, handles)
+    }
+
+    pub fn wait_tasks_complete(&self, handles: Vec<thread::JoinHandle<()>>) {
+        {
+            let mut count = self.task_count.lock();
+            while *count != 0 {
+                if self.panicked.load(Ordering::SeqCst) {
+                    break;
+                }
+                self.wait_condvar.wait(&mut count);
+            }
+        }
+        for _ in 0..self.thread_count {
+            self.sender.send(None).unwrap();
+        }
+        {
+            let mut count = self.task_count.lock();
+            while *count != self.thread_count {
+                if self.panicked.load(Ordering::SeqCst) {
+                    break;
+                }
+                self.wait_condvar.wait(&mut count);
+            }
+        }
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        if self.panicked.load(Ordering::SeqCst) {
+            panic!("tasks panicked");
+        }
+    }
+
+    pub fn add_task(&self, task: CodeGenTask) {
+        *self.task_count.lock() += 1;
+        self.sender.send(Some(task)).unwrap();
+    }
+
+    fn worker_thread(
+        &self,
+        module_name: String,
+        top_level_ctx: Arc<TopLevelContext>,
+        f: Arc<WithCall>,
+    ) {
+        let context = Context::create();
+        let mut builder = context.create_builder();
+        let mut module = context.create_module(&module_name);
+
+        while let Some(task) = self.receiver.recv().unwrap() {
+            let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
+            builder = result.0;
+            module = result.1;
+            *self.task_count.lock() -= 1;
+            self.wait_condvar.notify_all();
+        }
+
+        // do whatever...
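+        // every task has been processed at this point: verify the module, hand it to the
+        // caller-supplied callback, then bump task_count so wait_tasks_complete sees this worker finish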
+        let mut lock = self.task_count.lock();
+        module.verify().unwrap();
+        f.run(&module);
+        *lock += 1;
+        self.wait_condvar.notify_all();
+    }
+}
+
+pub struct CodeGenTask {
+    pub subst: Vec<(Type, Type)>,
+    pub symbol_name: String,
+    pub signature: FunSignature,
+    pub body: Vec<Stmt<Option<Type>>>,
+    pub unifier_index: usize,
+    pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
+}
+
+fn get_llvm_type<'ctx>(
+    ctx: &'ctx Context,
+    unifier: &mut Unifier,
+    top_level: &TopLevelContext,
+    type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
+    ty: Type,
+) -> BasicTypeEnum<'ctx> {
+    use TypeEnum::*;
+    // we assume the type cache should already contain primitive types,
+    // and they should be passed by value instead of passing as pointer.
+    type_cache.get(&unifier.get_representative(ty)).cloned().unwrap_or_else(|| {
+        match &*unifier.get_ty(ty) {
+            TObj { obj_id, fields, .. } => {
+                // a struct with fields in the order of declaration
+                let defs = top_level.definitions.read();
+                let definition = defs.get(obj_id.0).unwrap();
+                let ty = if let TopLevelDef::Class { fields: fields_list, .. } = &*definition.read()
+                {
+                    let fields = fields.borrow();
+                    let fields = fields_list
+                        .iter()
+                        .map(|f| get_llvm_type(ctx, unifier, top_level, type_cache, fields[&f.0]))
+                        .collect_vec();
+                    ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
+                } else {
+                    unreachable!()
+                };
+                ty
+            }
+            TTuple { ty } => {
+                // a struct with fields in the order present in the tuple
+                let fields = ty
+                    .iter()
+                    .map(|ty| get_llvm_type(ctx, unifier, top_level, type_cache, *ty))
+                    .collect_vec();
+                ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
+            }
+            TList { ty } => {
+                // a struct with an integer and a pointer to an array
+                let element_type = get_llvm_type(ctx, unifier, top_level, type_cache, *ty);
+                let fields =
+                    [ctx.i32_type().into(), element_type.ptr_type(AddressSpace::Generic).into()];
+                ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
+            }
+            TVirtual { .. } => unimplemented!(),
+            _ => unreachable!(),
+        }
+    })
+}
+
+pub fn gen_func<'ctx>(
+    context: &'ctx Context,
+    builder: Builder<'ctx>,
+    module: Module<'ctx>,
+    task: CodeGenTask,
+    top_level_ctx: Arc<TopLevelContext>,
+) -> (Builder<'ctx>, Module<'ctx>) {
+    // unwrap_or(0) is for unit tests without using rayon
+    let (mut unifier, primitives) = {
+        let unifiers = top_level_ctx.unifiers.read();
+        let (unifier, primitives) = &unifiers[task.unifier_index];
+        (Unifier::from_shared_unifier(unifier), *primitives)
+    };
+
+    for (a, b) in task.subst.iter() {
+        // this should be unification between variables and concrete types
+        // and should not cause any problem...
+        unifier.unify(*a, *b).unwrap();
+    }
+
+    // rebuild primitive store with unique representatives
+    let primitives = PrimitiveStore {
+        int32: unifier.get_representative(primitives.int32),
+        int64: unifier.get_representative(primitives.int64),
+        float: unifier.get_representative(primitives.float),
+        bool: unifier.get_representative(primitives.bool),
+        none: unifier.get_representative(primitives.none),
+    };
+
+    let mut type_cache: HashMap<_, _> = [
+        (unifier.get_representative(primitives.int32), context.i32_type().into()),
+        (unifier.get_representative(primitives.int64), context.i64_type().into()),
+        (unifier.get_representative(primitives.float), context.f64_type().into()),
+        (unifier.get_representative(primitives.bool), context.bool_type().into()),
+    ]
+    .iter()
+    .cloned()
+    .collect();
+
+    let params = task
+        .signature
+        .args
+        .iter()
+        .map(|arg| {
+            get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty)
+        })
+        .collect_vec();
+
+    let fn_type = if unifier.unioned(task.signature.ret, primitives.none) {
+        context.void_type().fn_type(&params, false)
+    } else {
+        get_llvm_type(
+            &context,
+            &mut unifier,
+            top_level_ctx.as_ref(),
+            &mut type_cache,
+            task.signature.ret,
+        )
+        .fn_type(&params, false)
+    };
+
+    let fn_val = module.add_function(&task.symbol_name, fn_type, None);
+    let init_bb = context.append_basic_block(fn_val, "init");
+    builder.position_at_end(init_bb);
+    let body_bb = context.append_basic_block(fn_val, "body");
+
+    let mut var_assignment = HashMap::new();
+    for (n, arg) in task.signature.args.iter().enumerate() {
+        let param = fn_val.get_nth_param(n as u32).unwrap();
+        let alloca = builder.build_alloca(
+            get_llvm_type(&context, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, arg.ty),
+            &arg.name,
+        );
+        builder.build_store(alloca, param);
+        var_assignment.insert(arg.name.clone(), alloca);
+    }
+    builder.build_unconditional_branch(body_bb);
+    builder.position_at_end(body_bb);
+
+    let mut code_gen_context = CodeGenContext {
+        ctx: &context,
+        resolver: task.resolver,
+        top_level: top_level_ctx.as_ref(),
+        loop_bb: None,
+        var_assignment,
+        type_cache,
+        primitives,
+        init_bb,
+        builder,
+        module,
+        unifier,
+    };
+
+    for stmt in task.body.iter() {
+        code_gen_context.gen_stmt(stmt);
+    }
+
+    let CodeGenContext { builder, module, .. } = code_gen_context;
+
+    (builder, module)
+}
diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs
new file mode 100644
index 00000000..9f468609
--- /dev/null
+++ b/nac3core/src/codegen/stmt.rs
@@ -0,0 +1,138 @@
+use super::CodeGenContext;
+use crate::typecheck::typedef::Type;
+use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
+use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind};
+
+impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
+    fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> {
+        // put the alloca in init block
+        let current = self.builder.get_insert_block().unwrap();
+        // position before the last branching instruction...
+        self.builder.position_before(&self.init_bb.get_last_instruction().unwrap());
+        let ty = self.get_llvm_type(ty);
+        let ptr = self.builder.build_alloca(ty, "tmp");
+        self.builder.position_at_end(current);
+        ptr
+    }
+
+    fn parse_pattern(&mut self, pattern: &Expr<Option<Type>>) -> PointerValue<'ctx> {
+        // very similar to gen_expr, but we don't do an extra load at the end
+        // and we flatten nested tuples
+        match &pattern.node {
+            ExprKind::Name { id, .. } => {
} => { + self.var_assignment.get(id).cloned().unwrap_or_else(|| { + let ptr = self.gen_var(pattern.custom.unwrap()); + self.var_assignment.insert(id.clone(), ptr); + ptr + }) + } + ExprKind::Attribute { value, attr, .. } => { + let index = self.get_attr_index(value.custom.unwrap(), attr); + let val = self.gen_expr(value); + let ptr = if let BasicValueEnum::PointerValue(v) = val { + v + } else { + unreachable!(); + }; + unsafe { + ptr.const_in_bounds_gep(&[ + self.ctx.i32_type().const_zero(), + self.ctx.i32_type().const_int(index as u64, false), + ]) + } + } + ExprKind::Subscript { .. } => unimplemented!(), + _ => unreachable!(), + } + } + + fn gen_assignment(&mut self, target: &Expr>, value: BasicValueEnum<'ctx>) { + if let ExprKind::Tuple { elts, .. } = &target.node { + if let BasicValueEnum::PointerValue(ptr) = value { + for (i, elt) in elts.iter().enumerate() { + unsafe { + let t = ptr.const_in_bounds_gep(&[ + self.ctx.i32_type().const_zero(), + self.ctx.i32_type().const_int(i as u64, false), + ]); + let v = self.builder.build_load(t, "tmpload"); + self.gen_assignment(elt, v); + } + } + } else { + unreachable!() + } + } else { + let ptr = self.parse_pattern(target); + self.builder.build_store(ptr, value); + } + } + + pub fn gen_stmt(&mut self, stmt: &Stmt>) { + match &stmt.node { + StmtKind::Expr { value } => { + self.gen_expr(&value); + } + StmtKind::Return { value } => { + let value = value.as_ref().map(|v| self.gen_expr(&v)); + let value = value.as_ref().map(|v| v as &dyn BasicValue); + self.builder.build_return(value); + } + StmtKind::AnnAssign { target, value, .. } => { + if let Some(value) = value { + let value = self.gen_expr(&value); + self.gen_assignment(target, value); + } + } + StmtKind::Assign { targets, value, .. } => { + let value = self.gen_expr(&value); + for target in targets.iter() { + self.gen_assignment(target, value); + } + } + StmtKind::Continue => { + self.builder.build_unconditional_branch(self.loop_bb.unwrap().0); + } + StmtKind::Break => { + self.builder.build_unconditional_branch(self.loop_bb.unwrap().1); + } + StmtKind::While { test, body, orelse } => { + let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); + let test_bb = self.ctx.append_basic_block(current, "test"); + let body_bb = self.ctx.append_basic_block(current, "body"); + let cont_bb = self.ctx.append_basic_block(current, "cont"); + // if there is no orelse, we just go to cont_bb + let orelse_bb = if orelse.is_empty() { + cont_bb + } else { + self.ctx.append_basic_block(current, "orelse") + }; + // store loop bb information and restore it later + let loop_bb = self.loop_bb.replace((test_bb, cont_bb)); + self.builder.build_unconditional_branch(test_bb); + self.builder.position_at_end(test_bb); + let test = self.gen_expr(test); + if let BasicValueEnum::IntValue(test) = test { + self.builder.build_conditional_branch(test, body_bb, orelse_bb); + } else { + unreachable!() + }; + self.builder.position_at_end(body_bb); + for stmt in body.iter() { + self.gen_stmt(stmt); + } + self.builder.build_unconditional_branch(test_bb); + if !orelse.is_empty() { + self.builder.position_at_end(orelse_bb); + for stmt in orelse.iter() { + self.gen_stmt(stmt); + } + self.builder.build_unconditional_branch(cont_bb); + } + self.builder.position_at_end(cont_bb); + self.loop_bb = loop_bb; + } + _ => unimplemented!(), + } + } +} diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs new file mode 100644 index 00000000..72ea9a75 --- /dev/null +++ b/nac3core/src/codegen/test.rs @@ -0,0 
+1,247 @@ +use super::{CodeGenTask, WorkerRegistry}; +use crate::{ + codegen::WithCall, + location::Location, + symbol_resolver::{SymbolResolver, SymbolValue}, + top_level::{DefinitionId, TopLevelContext}, + typecheck::{ + magic_methods::set_primitives_magic_methods, + type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore}, + typedef::{CallId, FunSignature, FuncArg, Type, TypeEnum, Unifier}, + }, +}; +use indoc::indoc; +use parking_lot::RwLock; +use rustpython_parser::{ast::fold::Fold, parser::parse_program}; +use std::collections::HashMap; +use std::sync::Arc; + +#[derive(Clone)] +struct Resolver { + id_to_type: HashMap, + id_to_def: HashMap, + class_names: HashMap, +} + +impl SymbolResolver for Resolver { + fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { + self.id_to_type.get(str).cloned() + } + + fn get_symbol_value(&self, _: &str) -> Option { + unimplemented!() + } + + fn get_symbol_location(&self, _: &str) -> Option { + unimplemented!() + } + + fn get_identifier_def(&self, id: &str) -> Option { + self.id_to_def.get(id).cloned() + } +} + +struct TestEnvironment { + pub unifier: Unifier, + pub function_data: FunctionData, + pub primitives: PrimitiveStore, + pub id_to_name: HashMap, + pub identifier_mapping: HashMap, + pub virtual_checks: Vec<(Type, Type)>, + pub calls: HashMap, + pub top_level: TopLevelContext, +} + +impl TestEnvironment { + pub fn basic_test_env() -> TestEnvironment { + let mut unifier = Unifier::new(); + + let int32 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(0), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let int64 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(1), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let float = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(2), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let bool = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(3), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let none = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(4), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let primitives = PrimitiveStore { int32, int64, float, bool, none }; + set_primitives_magic_methods(&primitives, &mut unifier); + + let id_to_name = [ + (0, "int32".to_string()), + (1, "int64".to_string()), + (2, "float".to_string()), + (3, "bool".to_string()), + (4, "none".to_string()), + ] + .iter() + .cloned() + .collect(); + + let mut identifier_mapping = HashMap::new(); + identifier_mapping.insert("None".into(), none); + + let resolver = Arc::new(Resolver { + id_to_type: identifier_mapping.clone(), + id_to_def: Default::default(), + class_names: Default::default(), + }) as Arc; + + TestEnvironment { + unifier, + top_level: TopLevelContext { + definitions: Default::default(), + unifiers: Default::default(), + // conetexts: Default::default(), + }, + function_data: FunctionData { + resolver, + bound_variables: Vec::new(), + return_type: Some(primitives.int32), + }, + primitives, + id_to_name, + identifier_mapping, + virtual_checks: Vec::new(), + calls: HashMap::new(), + } + } + + fn get_inferencer(&mut self) -> Inferencer { + Inferencer { + top_level: &self.top_level, + function_data: &mut self.function_data, + unifier: &mut self.unifier, + variable_mapping: Default::default(), + primitives: &mut self.primitives, + virtual_checks: &mut self.virtual_checks, + calls: &mut self.calls, + } + } +} + 
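The `expected` IR checked in `test_primitives` below follows directly from the block layout that `gen_func` and `CodeGenContext::gen_var` set up: every `alloca` is emitted into a dedicated `init` block, which then falls through to `body` via an unconditional branch. The standalone sketch below is not part of the patch; it assumes an inkwell release contemporary with this change (typed-pointer builder API, pre-0.2) and simply reproduces that layout for a two-argument `i32` function so the shape of the printed IR can be compared with the expected string in the test.

```rust
use inkwell::context::Context;

fn main() {
    let context = Context::create();
    let module = context.create_module("test");
    let builder = context.create_builder();

    // i32 testing(i32 %0, i32 %1), same shape as the function built in test_primitives
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("testing", fn_type, None);

    // all stack slots live in `init`, mirroring gen_func / CodeGenContext::gen_var
    let init = context.append_basic_block(function, "init");
    let body = context.append_basic_block(function, "body");

    builder.position_at_end(init);
    let a = builder.build_alloca(i32_type, "a");
    let b = builder.build_alloca(i32_type, "b");
    builder.build_store(a, function.get_nth_param(0).unwrap());
    builder.build_store(b, function.get_nth_param(1).unwrap());
    builder.build_unconditional_branch(body);

    builder.position_at_end(body);
    let lhs = builder.build_load(a, "load").into_int_value();
    let rhs = builder.build_load(b, "load1").into_int_value();
    let sum = builder.build_int_add(lhs, rhs, "add");
    builder.build_return(Some(&sum));

    // prints IR with the same init/body structure as the expected output below
    println!("{}", module.print_to_string().to_string());
}
```

Keeping all allocas in one block adjacent to the entry is the conventional way to make them eligible for LLVM's mem2reg promotion, which is how the O2 form quoted in the test's comment ends up entirely in registers.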
+#[test] +fn test_primitives() { + let mut env = TestEnvironment::basic_test_env(); + let threads = ["test"]; + let signature = FunSignature { + args: vec![ + FuncArg { name: "a".to_string(), ty: env.primitives.int32, default_value: None }, + FuncArg { name: "b".to_string(), ty: env.primitives.int32, default_value: None }, + ], + ret: env.primitives.int32, + vars: HashMap::new(), + }; + + let mut inferencer = env.get_inferencer(); + inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32); + inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32); + let source = indoc! { " + c = a + b + d = a if c == 1 else 0 + return d + "}; + let statements = parse_program(source).unwrap(); + + let statements = statements + .into_iter() + .map(|v| inferencer.fold_stmt(v)) + .collect::, _>>() + .unwrap(); + let mut identifiers = vec!["a".to_string(), "b".to_string()]; + inferencer.check_block(&statements, &mut identifiers).unwrap(); + + let top_level = Arc::new(TopLevelContext { + definitions: Default::default(), + unifiers: Arc::new(RwLock::new(vec![(env.unifier.get_shared_unifier(), env.primitives)])), + // conetexts: Default::default(), + }); + let task = CodeGenTask { + subst: Default::default(), + symbol_name: "testing".to_string(), + body: statements, + unifier_index: 0, + resolver: env.function_data.resolver.clone(), + signature, + }; + + let f = Arc::new(WithCall::new(Box::new(|module| { + // the following IR is equivalent to + // ``` + // ; ModuleID = 'test.ll' + // source_filename = "test" + // + // ; Function Attrs: norecurse nounwind readnone + // define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 { + // init: + // %add = add i32 %1, %0 + // %cmp = icmp eq i32 %add, 1 + // %ifexpr = select i1 %cmp, i32 %0, i32 0 + // ret i32 %ifexpr + // } + // + // attributes #0 = { norecurse nounwind readnone } + // ``` + // after O2 optimization + + let expected = indoc! 
{" + ; ModuleID = 'test' + source_filename = \"test\" + + define i32 @testing(i32 %0, i32 %1) { + init: + %a = alloca i32, align 4 + store i32 %0, i32* %a, align 4 + %b = alloca i32, align 4 + store i32 %1, i32* %b, align 4 + %tmp = alloca i32, align 4 + %tmp4 = alloca i32, align 4 + br label %body + + body: ; preds = %init + %load = load i32, i32* %a, align 4 + %load1 = load i32, i32* %b, align 4 + %add = add i32 %load, %load1 + store i32 %add, i32* %tmp, align 4 + %load2 = load i32, i32* %tmp, align 4 + %cmp = icmp eq i32 %load2, 1 + br i1 %cmp, label %then, label %else + + then: ; preds = %body + %load3 = load i32, i32* %a, align 4 + br label %cont + + else: ; preds = %body + br label %cont + + cont: ; preds = %else, %then + %ifexpr = phi i32 [ %load3, %then ], [ 0, %else ] + store i32 %ifexpr, i32* %tmp4, align 4 + %load5 = load i32, i32* %tmp4, align 4 + ret i32 %load5 + } + "} + .trim(); + assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); + }))); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); + registry.add_task(task); + registry.wait_tasks_complete(handles); +} diff --git a/nac3core/src/context/inference_context.rs b/nac3core/src/context/inference_context.rs deleted file mode 100644 index e1fbff2f..00000000 --- a/nac3core/src/context/inference_context.rs +++ /dev/null @@ -1,212 +0,0 @@ -use super::TopLevelContext; -use crate::typedef::*; -use std::boxed::Box; -use std::collections::HashMap; - -struct ContextStack<'a> { - /// stack level, starts from 0 - level: u32, - /// stack of variable definitions containing (id, def, level) where `def` is the original - /// definition in `level-1`. - var_defs: Vec<(usize, VarDef<'a>, u32)>, - /// stack of symbol definitions containing (name, level) where `level` is the smallest level - /// where the name is assigned a value - sym_def: Vec<(&'a str, u32)>, -} - -pub struct InferenceContext<'a> { - /// top level context - top_level: TopLevelContext<'a>, - - /// list of primitive instances - primitives: Vec, - /// list of variable instances - variables: Vec, - /// identifier to (type, readable) mapping. - /// an identifier might be defined earlier but has no value (for some code path), thus not - /// readable. - sym_table: HashMap<&'a str, (Type, bool)>, - /// resolution function reference, that may resolve unbounded identifiers to some type - resolution_fn: Box Result>, - /// stack - stack: ContextStack<'a>, -} - -// non-trivial implementations here -impl<'a> InferenceContext<'a> { - /// return a new `InferenceContext` from `TopLevelContext` and resolution function. - pub fn new( - top_level: TopLevelContext, - resolution_fn: Box Result>, - ) -> InferenceContext { - let primitives = (0..top_level.primitive_defs.len()) - .map(|v| TypeEnum::PrimitiveType(PrimitiveId(v)).into()) - .collect(); - let variables = (0..top_level.var_defs.len()) - .map(|v| TypeEnum::TypeVariable(VariableId(v)).into()) - .collect(); - InferenceContext { - top_level, - primitives, - variables, - sym_table: HashMap::new(), - resolution_fn, - stack: ContextStack { - level: 0, - var_defs: Vec::new(), - sym_def: Vec::new(), - }, - } - } - - /// execute the function with new scope. - /// variable assignment would be limited within the scope (not readable outside), and type - /// variable type guard would be limited within the scope. 
- /// returns the list of variables assigned within the scope, and the result of the function - pub fn with_scope(&mut self, f: F) -> (Vec<&'a str>, R) - where - F: FnOnce(&mut Self) -> R, - { - self.stack.level += 1; - let result = f(self); - self.stack.level -= 1; - while !self.stack.var_defs.is_empty() { - let (_, _, level) = self.stack.var_defs.last().unwrap(); - if *level > self.stack.level { - let (id, def, _) = self.stack.var_defs.pop().unwrap(); - self.top_level.var_defs[id] = def; - } else { - break; - } - } - let mut poped_names = Vec::new(); - while !self.stack.sym_def.is_empty() { - let (_, level) = self.stack.sym_def.last().unwrap(); - if *level > self.stack.level { - let (name, _) = self.stack.sym_def.pop().unwrap(); - self.sym_table.remove(name).unwrap(); - poped_names.push(name); - } else { - break; - } - } - (poped_names, result) - } - - /// assign a type to an identifier. - /// may return error if the identifier was defined but with different type - pub fn assign(&mut self, name: &'a str, ty: Type) -> Result { - if let Some((t, x)) = self.sym_table.get_mut(name) { - if t == &ty { - if !*x { - self.stack.sym_def.push((name, self.stack.level)); - } - *x = true; - Ok(ty) - } else { - Err("different types".into()) - } - } else { - self.stack.sym_def.push((name, self.stack.level)); - self.sym_table.insert(name, (ty.clone(), true)); - Ok(ty) - } - } - - /// check if an identifier is already defined - pub fn defined(&self, name: &str) -> bool { - self.sym_table.get(name).is_some() - } - - /// get the type of an identifier - /// may return error if the identifier is not defined, and cannot be resolved with the - /// resolution function. - pub fn resolve(&mut self, name: &str) -> Result { - if let Some((t, x)) = self.sym_table.get(name) { - if *x { - Ok(t.clone()) - } else { - Err("may not have value".into()) - } - } else { - self.resolution_fn.as_mut()(name) - } - } - - /// restrict the bound of a type variable by replacing its definition. 
- /// used for implementing type guard - pub fn restrict(&mut self, id: VariableId, mut def: VarDef<'a>) { - std::mem::swap(self.top_level.var_defs.get_mut(id.0).unwrap(), &mut def); - self.stack.var_defs.push((id.0, def, self.stack.level)); - } -} - -// trivial getters: -impl<'a> InferenceContext<'a> { - pub fn get_primitive(&self, id: PrimitiveId) -> Type { - self.primitives.get(id.0).unwrap().clone() - } - pub fn get_variable(&self, id: VariableId) -> Type { - self.variables.get(id.0).unwrap().clone() - } - - pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> { - self.top_level.fn_table.get(name) - } - pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef { - self.top_level.primitive_defs.get(id.0).unwrap() - } - pub fn get_class_def(&self, id: ClassId) -> &ClassDef { - self.top_level.class_defs.get(id.0).unwrap() - } - pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef { - self.top_level.parametric_defs.get(id.0).unwrap() - } - pub fn get_variable_def(&self, id: VariableId) -> &VarDef { - self.top_level.var_defs.get(id.0).unwrap() - } - pub fn get_type(&self, name: &str) -> Option { - self.top_level.get_type(name) - } -} - -impl TypeEnum { - pub fn subst(&self, map: &HashMap) -> TypeEnum { - match self { - TypeEnum::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(), - TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType( - *id, - params - .iter() - .map(|v| v.as_ref().subst(map).into()) - .collect(), - ), - _ => self.clone(), - } - } - - pub fn get_subst(&self, ctx: &InferenceContext) -> HashMap { - match self { - TypeEnum::ParametricType(id, params) => { - let vars = &ctx.get_parametric_def(*id).params; - vars.iter() - .zip(params) - .map(|(v, p)| (*v, p.as_ref().clone().into())) - .collect() - } - // if this proves to be slow, we can use option type - _ => HashMap::new(), - } - } - - pub fn get_base<'b: 'a, 'a>(&'a self, ctx: &'b InferenceContext) -> Option<&'b TypeDef> { - match self { - TypeEnum::PrimitiveType(id) => Some(ctx.get_primitive_def(*id)), - TypeEnum::ClassType(id) | TypeEnum::VirtualClassType(id) => { - Some(&ctx.get_class_def(*id).base) - } - TypeEnum::ParametricType(id, _) => Some(&ctx.get_parametric_def(*id).base), - _ => None, - } - } -} diff --git a/nac3core/src/context/mod.rs b/nac3core/src/context/mod.rs deleted file mode 100644 index f59140d9..00000000 --- a/nac3core/src/context/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod inference_context; -mod top_level_context; -pub use inference_context::InferenceContext; -pub use top_level_context::TopLevelContext; diff --git a/nac3core/src/context/top_level_context.rs b/nac3core/src/context/top_level_context.rs deleted file mode 100644 index 004b271e..00000000 --- a/nac3core/src/context/top_level_context.rs +++ /dev/null @@ -1,136 +0,0 @@ -use crate::typedef::*; -use std::collections::HashMap; -use std::rc::Rc; - -/// Structure for storing top-level type definitions. -/// Used for collecting type signature from source code. -/// Can be converted to `InferenceContext` for type inference in functions. -pub struct TopLevelContext<'a> { - /// List of primitive definitions. - pub(super) primitive_defs: Vec>, - /// List of class definitions. - pub(super) class_defs: Vec>, - /// List of parametric type definitions. - pub(super) parametric_defs: Vec>, - /// List of type variable definitions. - pub(super) var_defs: Vec>, - /// Function name to signature mapping. - pub(super) fn_table: HashMap<&'a str, FnDef>, - /// Type name to type mapping. 
- pub(super) sym_table: HashMap<&'a str, Type>, - - primitives: Vec, - variables: Vec, -} - -impl<'a> TopLevelContext<'a> { - pub fn new(primitive_defs: Vec>) -> TopLevelContext { - let mut sym_table = HashMap::new(); - let mut primitives = Vec::new(); - for (i, t) in primitive_defs.iter().enumerate() { - primitives.push(TypeEnum::PrimitiveType(PrimitiveId(i)).into()); - sym_table.insert(t.name, TypeEnum::PrimitiveType(PrimitiveId(i)).into()); - } - TopLevelContext { - primitive_defs, - class_defs: Vec::new(), - parametric_defs: Vec::new(), - var_defs: Vec::new(), - fn_table: HashMap::new(), - sym_table, - primitives, - variables: Vec::new(), - } - } - - pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId { - self.sym_table.insert( - def.base.name, - TypeEnum::ClassType(ClassId(self.class_defs.len())).into(), - ); - self.class_defs.push(def); - ClassId(self.class_defs.len() - 1) - } - - pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId { - let params = def - .params - .iter() - .map(|&v| Rc::new(TypeEnum::TypeVariable(v))) - .collect(); - self.sym_table.insert( - def.base.name, - TypeEnum::ParametricType(ParamId(self.parametric_defs.len()), params).into(), - ); - self.parametric_defs.push(def); - ParamId(self.parametric_defs.len() - 1) - } - - pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId { - self.sym_table.insert( - def.name, - TypeEnum::TypeVariable(VariableId(self.var_defs.len())).into(), - ); - self.add_variable_private(def) - } - - pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId { - self.var_defs.push(def); - self.variables - .push(TypeEnum::TypeVariable(VariableId(self.var_defs.len() - 1)).into()); - VariableId(self.var_defs.len() - 1) - } - - pub fn add_fn(&mut self, name: &'a str, def: FnDef) { - self.fn_table.insert(name, def); - } - - pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> { - self.fn_table.get(name) - } - - pub fn get_primitive_def_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> { - self.primitive_defs.get_mut(id.0).unwrap() - } - - pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef { - self.primitive_defs.get(id.0).unwrap() - } - - pub fn get_class_def_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> { - self.class_defs.get_mut(id.0).unwrap() - } - - pub fn get_class_def(&self, id: ClassId) -> &ClassDef { - self.class_defs.get(id.0).unwrap() - } - - pub fn get_parametric_def_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> { - self.parametric_defs.get_mut(id.0).unwrap() - } - - pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef { - self.parametric_defs.get(id.0).unwrap() - } - - pub fn get_variable_def_mut(&mut self, id: VariableId) -> &mut VarDef<'a> { - self.var_defs.get_mut(id.0).unwrap() - } - - pub fn get_variable_def(&self, id: VariableId) -> &VarDef { - self.var_defs.get(id.0).unwrap() - } - - pub fn get_primitive(&self, id: PrimitiveId) -> Type { - self.primitives.get(id.0).unwrap().clone() - } - - pub fn get_variable(&self, id: VariableId) -> Type { - self.variables.get(id.0).unwrap().clone() - } - - pub fn get_type(&self, name: &str) -> Option { - // TODO: handle parametric types - self.sym_table.get(name).cloned() - } -} diff --git a/nac3core/src/expression_inference.rs b/nac3core/src/expression_inference.rs deleted file mode 100644 index fafe04e5..00000000 --- a/nac3core/src/expression_inference.rs +++ /dev/null @@ -1,922 +0,0 @@ -use crate::context::InferenceContext; -use crate::inference_core::resolve_call; -use crate::magic_methods::*; -use 
crate::primitives::*; -use crate::typedef::{Type, TypeEnum::*}; -use rustpython_parser::ast::{ - Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator, - UnaryOperator, -}; -use std::convert::TryInto; - -type ParserResult = Result, String>; - -pub fn infer_expr<'b: 'a, 'a>( - ctx: &mut InferenceContext<'a>, - expr: &'b Expression, -) -> ParserResult { - match &expr.node { - ExpressionType::Number { value } => infer_constant(ctx, value), - ExpressionType::Identifier { name } => infer_identifier(ctx, name), - ExpressionType::List { elements } => infer_list(ctx, elements), - ExpressionType::Tuple { elements } => infer_tuple(ctx, elements), - ExpressionType::Attribute { value, name } => infer_attribute(ctx, value, name), - ExpressionType::BoolOp { values, .. } => infer_bool_ops(ctx, values), - ExpressionType::Binop { a, b, op } => infer_bin_ops(ctx, op, a, b), - ExpressionType::Unop { op, a } => infer_unary_ops(ctx, op, a), - ExpressionType::Compare { vals, ops } => infer_compare(ctx, vals, ops), - ExpressionType::Call { - args, - function, - keywords, - } => { - if !keywords.is_empty() { - Err("keyword is not supported".into()) - } else { - infer_call(ctx, &args, &function) - } - } - ExpressionType::Subscript { a, b } => infer_subscript(ctx, a, b), - ExpressionType::IfExpression { test, body, orelse } => { - infer_if_expr(ctx, &test, &body, orelse) - } - ExpressionType::Comprehension { kind, generators } => match kind.as_ref() { - ComprehensionKind::List { element } => { - if generators.len() == 1 { - infer_list_comprehension(ctx, element, &generators[0]) - } else { - Err("only 1 generator statement is supported".into()) - } - } - _ => Err("only list comprehension is supported".into()), - }, - ExpressionType::True | ExpressionType::False => Ok(Some(ctx.get_primitive(BOOL_TYPE))), - _ => Err("not supported".into()), - } -} - -fn infer_constant( - ctx: &mut InferenceContext, - value: &rustpython_parser::ast::Number, -) -> ParserResult { - use rustpython_parser::ast::Number; - match value { - Number::Integer { value } => { - let int32: Result = value.try_into(); - if int32.is_ok() { - Ok(Some(ctx.get_primitive(INT32_TYPE))) - } else { - Err("integer out of range".into()) - } - } - Number::Float { .. } => Ok(Some(ctx.get_primitive(FLOAT_TYPE))), - _ => Err("not supported".into()), - } -} - -fn infer_identifier(ctx: &mut InferenceContext, name: &str) -> ParserResult { - Ok(Some(ctx.resolve(name)?)) -} - -fn infer_list<'b: 'a, 'a>( - ctx: &mut InferenceContext<'a>, - elements: &'b [Expression], -) -> ParserResult { - if elements.is_empty() { - return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into())); - } - - let mut types = elements.iter().map(|v| infer_expr(ctx, v)); - - let head = types.next().unwrap()?; - if head.is_none() { - return Err("list elements must have some type".into()); - } - for v in types { - // TODO: try virtual type... - if v? != head { - return Err("inhomogeneous list is not allowed".into()); - } - } - Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into())) -} - -fn infer_tuple<'b: 'a, 'a>( - ctx: &mut InferenceContext<'a>, - elements: &'b [Expression], -) -> ParserResult { - let types: Result>, String> = - elements.iter().map(|v| infer_expr(ctx, v)).collect(); - if let Some(t) = types? 
{ - Ok(Some(ParametricType(TUPLE_TYPE, t).into())) - } else { - Err("tuple elements must have some type".into()) - } -} - -fn infer_attribute<'a>( - ctx: &mut InferenceContext<'a>, - value: &'a Expression, - name: &str, -) -> ParserResult { - let value = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?; - if let TypeVariable(_) = value.as_ref() { - return Err("no fields for type variable".into()); - } - - value - .get_base(ctx) - .and_then(|b| b.fields.get(name).cloned()) - .map_or_else(|| Err("no such field".to_string()), |v| Ok(Some(v))) -} - -fn infer_bool_ops<'a>(ctx: &mut InferenceContext<'a>, values: &'a [Expression]) -> ParserResult { - assert_eq!(values.len(), 2); - let left = infer_expr(ctx, &values[0])?.ok_or_else(|| "no value".to_string())?; - let right = infer_expr(ctx, &values[1])?.ok_or_else(|| "no value".to_string())?; - - let b = ctx.get_primitive(BOOL_TYPE); - if left == b && right == b { - Ok(Some(b)) - } else { - Err("bool operands must be bool".into()) - } -} - -fn infer_bin_ops<'b: 'a, 'a>( - ctx: &mut InferenceContext<'a>, - op: &Operator, - left: &'b Expression, - right: &'b Expression, -) -> ParserResult { - let left = infer_expr(ctx, left)?.ok_or_else(|| "no value".to_string())?; - let right = infer_expr(ctx, right)?.ok_or_else(|| "no value".to_string())?; - let fun = binop_name(op); - resolve_call(ctx, Some(left), fun, &[right]) -} - -fn infer_unary_ops<'b: 'a, 'a>( - ctx: &mut InferenceContext<'a>, - op: &UnaryOperator, - obj: &'b Expression, -) -> ParserResult { - let ty = infer_expr(ctx, obj)?.ok_or_else(|| "no value".to_string())?; - if let UnaryOperator::Not = op { - if ty == ctx.get_primitive(BOOL_TYPE) { - Ok(Some(ty)) - } else { - Err("logical not must be applied to bool".into()) - } - } else { - resolve_call(ctx, Some(ty), unaryop_name(op), &[]) - } -} - -fn infer_compare<'b: 'a, 'a>( - ctx: &mut InferenceContext<'a>, - vals: &'b [Expression], - ops: &'b [Comparison], -) -> ParserResult { - let types: Result>, _> = vals.iter().map(|v| infer_expr(ctx, v)).collect(); - let types = types?; - if types.is_none() { - return Err("comparison operands must have type".into()); - } - let types = types.unwrap(); - let boolean = ctx.get_primitive(BOOL_TYPE); - let left = &types[..types.len() - 1]; - let right = &types[1..]; - - for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) { - let fun = comparison_name(op).ok_or_else(|| "unsupported comparison".to_string())?; - let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?; - if ty.is_none() || ty.unwrap() != boolean { - return Err("comparison result must be boolean".into()); - } - } - Ok(Some(boolean)) -} - -fn infer_call<'b: 'a, 'a>( - ctx: &mut InferenceContext<'a>, - args: &'b [Expression], - function: &'b Expression, -) -> ParserResult { - // TODO: special handling for int64 constant - let types: Result>, _> = args.iter().map(|v| infer_expr(ctx, v)).collect(); - let types = types?; - if types.is_none() { - return Err("function params must have type".into()); - } - - let (obj, fun) = match &function.node { - ExpressionType::Identifier { name } => (None, name), - ExpressionType::Attribute { value, name } => ( - Some(infer_expr(ctx, &value)?.ok_or_else(|| "no value".to_string())?), - name, - ), - _ => return Err("not supported".into()), - }; - resolve_call(ctx, obj, fun.as_str(), &types.unwrap()) -} - -fn infer_subscript<'b: 'a, 'a>( - ctx: &mut InferenceContext<'a>, - a: &'b Expression, - b: &'b Expression, -) -> ParserResult { - let a = infer_expr(ctx, a)?.ok_or_else(|| "no 
value".to_string())?; - let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() { - ls[0].clone() - } else { - return Err("subscript is not supported for types other than list".into()); - }; - - match &b.node { - ExpressionType::Slice { elements } => { - let int32 = ctx.get_primitive(INT32_TYPE); - let types: Result>, _> = elements - .iter() - .map(|v| { - if let ExpressionType::None = v.node { - Ok(Some(int32.clone())) - } else { - infer_expr(ctx, v) - } - }) - .collect(); - let types = types?.ok_or_else(|| "slice must have type".to_string())?; - if types.iter().all(|v| v == &int32) { - Ok(Some(a)) - } else { - Err("slice must be int32 type".into()) - } - } - _ => { - let b = infer_expr(ctx, b)?.ok_or_else(|| "no value".to_string())?; - if b == ctx.get_primitive(INT32_TYPE) { - Ok(Some(t)) - } else { - Err("index must be either slice or int32".into()) - } - } - } -} - -fn infer_if_expr<'b: 'a, 'a>( - ctx: &mut InferenceContext<'a>, - test: &'b Expression, - body: &'b Expression, - orelse: &'b Expression, -) -> ParserResult { - let test = infer_expr(ctx, test)?.ok_or_else(|| "no value".to_string())?; - if test != ctx.get_primitive(BOOL_TYPE) { - return Err("test should be bool".into()); - } - - let body = infer_expr(ctx, body)?; - let orelse = infer_expr(ctx, orelse)?; - if body.as_ref() == orelse.as_ref() { - Ok(body) - } else { - Err("divergent type".into()) - } -} - -fn infer_simple_binding<'a: 'b, 'b>( - ctx: &mut InferenceContext<'b>, - name: &'a Expression, - ty: Type, -) -> Result<(), String> { - match &name.node { - ExpressionType::Identifier { name } => { - if name == "_" { - Ok(()) - } else if ctx.defined(name.as_str()) { - Err("duplicated naming".into()) - } else { - ctx.assign(name.as_str(), ty)?; - Ok(()) - } - } - ExpressionType::Tuple { elements } => { - if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() { - if elements.len() == ls.len() { - for (a, b) in elements.iter().zip(ls.iter()) { - infer_simple_binding(ctx, a, b.clone())?; - } - Ok(()) - } else { - Err("different length".into()) - } - } else { - Err("not supported".into()) - } - } - _ => Err("not supported".into()), - } -} - -fn infer_list_comprehension<'b: 'a, 'a>( - ctx: &mut InferenceContext<'a>, - element: &'b Expression, - comprehension: &'b Comprehension, -) -> ParserResult { - if comprehension.is_async { - return Err("async is not supported".into()); - } - - let iter = infer_expr(ctx, &comprehension.iter)?.ok_or_else(|| "no value".to_string())?; - if let ParametricType(LIST_TYPE, ls) = iter.as_ref() { - ctx.with_scope(|ctx| { - infer_simple_binding(ctx, &comprehension.target, ls[0].clone())?; - - let boolean = ctx.get_primitive(BOOL_TYPE); - for test in comprehension.ifs.iter() { - let result = - infer_expr(ctx, test)?.ok_or_else(|| "no value in test".to_string())?; - if result != boolean { - return Err("test must be bool".into()); - } - } - let result = infer_expr(ctx, element)?.ok_or_else(|| "no value")?; - Ok(Some(ParametricType(LIST_TYPE, vec![result]).into())) - }) - .1 - } else { - Err("iteration is supported for list only".into()) - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::context::*; - use crate::typedef::*; - use rustpython_parser::parser::parse_expression; - use std::collections::HashMap; - use std::rc::Rc; - - fn get_inference_context(ctx: TopLevelContext) -> InferenceContext { - InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into()))) - } - - #[test] - fn test_constants() { - let ctx = basic_ctx(); - let mut ctx = get_inference_context(ctx); - - let ast = 
parse_expression("123").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); - - let ast = parse_expression("2147483647").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); - - let ast = parse_expression("2147483648").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("integer out of range".into())); - // - // let ast = parse_expression("2147483648").unwrap(); - // let result = infer_expr(&mut ctx, &ast); - // assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE)); - - // let ast = parse_expression("9223372036854775807").unwrap(); - // let result = infer_expr(&mut ctx, &ast); - // assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE)); - - // let ast = parse_expression("9223372036854775808").unwrap(); - // let result = infer_expr(&mut ctx, &ast); - // assert_eq!(result, Err("integer out of range".into())); - - let ast = parse_expression("123.456").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(FLOAT_TYPE)); - - let ast = parse_expression("True").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE)); - - let ast = parse_expression("False").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE)); - } - - #[test] - fn test_identifier() { - let ctx = basic_ctx(); - let mut ctx = get_inference_context(ctx); - ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap(); - - let ast = parse_expression("abc").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); - - let ast = parse_expression("ab").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("unbounded identifier".into())); - } - - #[test] - fn test_list() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "foo", - FnDef { - args: vec![], - result: None, - }, - ); - let mut ctx = get_inference_context(ctx); - ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap(); - // def is reserved... 
- ctx.assign("efg", ctx.get_primitive(INT32_TYPE)).unwrap(); - ctx.assign("xyz", ctx.get_primitive(FLOAT_TYPE)).unwrap(); - - let ast = parse_expression("[]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(LIST_TYPE, vec![BotType.into()]).into() - ); - - let ast = parse_expression("[abc]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into() - ); - - let ast = parse_expression("[abc, efg]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into() - ); - - let ast = parse_expression("[abc, efg, xyz]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("inhomogeneous list is not allowed".into())); - - let ast = parse_expression("[foo()]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("list elements must have some type".into())); - } - - #[test] - fn test_tuple() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "foo", - FnDef { - args: vec![], - result: None, - }, - ); - let mut ctx = get_inference_context(ctx); - ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap(); - ctx.assign("efg", ctx.get_primitive(FLOAT_TYPE)).unwrap(); - - let ast = parse_expression("(abc, efg)").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!( - result.unwrap().unwrap(), - ParametricType( - TUPLE_TYPE, - vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)] - ) - .into() - ); - - let ast = parse_expression("(abc, efg, foo())").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("tuple elements must have some type".into())); - } - - #[test] - fn test_attribute() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - let int32 = ctx.get_primitive(INT32_TYPE); - let float = ctx.get_primitive(FLOAT_TYPE); - - let foo = ctx.add_class(ClassDef { - base: TypeDef { - name: "Foo", - fields: HashMap::new(), - methods: HashMap::new(), - }, - parents: vec![], - }); - let foo_def = ctx.get_class_def_mut(foo); - foo_def.base.fields.insert("a", int32.clone()); - foo_def.base.fields.insert("b", ClassType(foo).into()); - foo_def.base.fields.insert("c", int32.clone()); - - let bar = ctx.add_class(ClassDef { - base: TypeDef { - name: "Bar", - fields: HashMap::new(), - methods: HashMap::new(), - }, - parents: vec![], - }); - let bar_def = ctx.get_class_def_mut(bar); - bar_def.base.fields.insert("a", int32); - bar_def.base.fields.insert("b", ClassType(bar).into()); - bar_def.base.fields.insert("c", float); - - let v0 = ctx.add_variable(VarDef { - name: "v0", - bound: vec![], - }); - - let v1 = ctx.add_variable(VarDef { - name: "v1", - bound: vec![ClassType(foo).into(), ClassType(bar).into()], - }); - - let mut ctx = get_inference_context(ctx); - ctx.assign("foo", Rc::new(ClassType(foo))).unwrap(); - ctx.assign("bar", Rc::new(ClassType(bar))).unwrap(); - ctx.assign("foobar", Rc::new(VirtualClassType(foo))) - .unwrap(); - ctx.assign("v0", ctx.get_variable(v0)).unwrap(); - ctx.assign("v1", ctx.get_variable(v1)).unwrap(); - ctx.assign("bot", Rc::new(BotType)).unwrap(); - - let ast = parse_expression("foo.a").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); - - let ast = parse_expression("foo.d").unwrap(); - let 
result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no such field".into())); - - let ast = parse_expression("foobar.a").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); - - let ast = parse_expression("v0.a").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no fields for type variable".into())); - - let ast = parse_expression("v1.a").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no fields for type variable".into())); - - let ast = parse_expression("none().a").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no value".into())); - - let ast = parse_expression("bot.a").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no such field".into())); - } - - #[test] - fn test_bool_ops() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - let mut ctx = get_inference_context(ctx); - - let ast = parse_expression("True and False").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE)); - - let ast = parse_expression("True and none()").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no value".into())); - - let ast = parse_expression("True and 123").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("bool operands must be bool".into())); - } - - #[test] - fn test_bin_ops() { - let mut ctx = basic_ctx(); - let v0 = ctx.add_variable(VarDef { - name: "v0", - bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)], - }); - let mut ctx = get_inference_context(ctx); - ctx.assign("a", TypeVariable(v0).into()).unwrap(); - - let ast = parse_expression("1 + 2 + 3").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); - - let ast = parse_expression("a + a + a").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("not supported".into())); - } - - #[test] - fn test_unary_ops() { - let mut ctx = basic_ctx(); - let v0 = ctx.add_variable(VarDef { - name: "v0", - bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)], - }); - let mut ctx = get_inference_context(ctx); - ctx.assign("a", TypeVariable(v0).into()).unwrap(); - - let ast = parse_expression("-(123)").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); - - let ast = parse_expression("-a").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("not supported".into())); - - let ast = parse_expression("not True").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE)); - - let ast = parse_expression("not (1)").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("logical not must be applied to bool".into())); - } - - #[test] - fn test_compare() { - let mut ctx = basic_ctx(); - let v0 = ctx.add_variable(VarDef { - name: "v0", - bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)], - }); - let mut ctx = get_inference_context(ctx); - ctx.assign("a", TypeVariable(v0).into()).unwrap(); - - let ast = parse_expression("a == a == a").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("not supported".into())); - - let ast = 
parse_expression("a == a == 1").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("not supported".into())); - - let ast = parse_expression("True > False").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no such function".into())); - - let ast = parse_expression("True in False").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("unsupported comparison".into())); - } - - #[test] - fn test_call() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - - let foo = ctx.add_class(ClassDef { - base: TypeDef { - name: "Foo", - fields: HashMap::new(), - methods: HashMap::new(), - }, - parents: vec![], - }); - let foo_def = ctx.get_class_def_mut(foo); - foo_def.base.methods.insert( - "a", - FnDef { - args: vec![], - result: Some(Rc::new(ClassType(foo))), - }, - ); - - let bar = ctx.add_class(ClassDef { - base: TypeDef { - name: "Bar", - fields: HashMap::new(), - methods: HashMap::new(), - }, - parents: vec![], - }); - let bar_def = ctx.get_class_def_mut(bar); - bar_def.base.methods.insert( - "a", - FnDef { - args: vec![], - result: Some(Rc::new(ClassType(bar))), - }, - ); - - let v0 = ctx.add_variable(VarDef { - name: "v0", - bound: vec![], - }); - let v1 = ctx.add_variable(VarDef { - name: "v1", - bound: vec![ClassType(foo).into(), ClassType(bar).into()], - }); - let v2 = ctx.add_variable(VarDef { - name: "v2", - bound: vec![ - ClassType(foo).into(), - ClassType(bar).into(), - ctx.get_primitive(INT32_TYPE), - ], - }); - let mut ctx = get_inference_context(ctx); - ctx.assign("foo", Rc::new(ClassType(foo))).unwrap(); - ctx.assign("bar", Rc::new(ClassType(bar))).unwrap(); - ctx.assign("foobar", Rc::new(VirtualClassType(foo))) - .unwrap(); - ctx.assign("v0", ctx.get_variable(v0)).unwrap(); - ctx.assign("v1", ctx.get_variable(v1)).unwrap(); - ctx.assign("v2", ctx.get_variable(v2)).unwrap(); - ctx.assign("bot", Rc::new(BotType)).unwrap(); - - let ast = parse_expression("foo.a()").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); - - let ast = parse_expression("v1.a()").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("not supported".into())); - - let ast = parse_expression("foobar.a()").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); - - let ast = parse_expression("none().a()").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no value".into())); - - let ast = parse_expression("bot.a()").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("not supported".into())); - - let ast = parse_expression("[][0].a()").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("not supported".into())); - } - - #[test] - fn infer_subscript() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - let mut ctx = get_inference_context(ctx); - - let ast = parse_expression("[1, 2, 3][0]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); - - let ast = parse_expression("[[1]][0][0]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); - - let ast = parse_expression("[1, 2, 3][1:2]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!( - 
result.unwrap().unwrap(), - ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into() - ); - - let ast = parse_expression("[1, 2, 3][1:2:2]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into() - ); - - let ast = parse_expression("[1, 2, 3][1:1.2]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("slice must be int32 type".into())); - - let ast = parse_expression("[1, 2, 3][1:none()]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("slice must have type".into())); - - let ast = parse_expression("[1, 2, 3][1.2]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("index must be either slice or int32".into())); - - let ast = parse_expression("[1, 2, 3][none()]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no value".into())); - - let ast = parse_expression("none()[1.2]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no value".into())); - - let ast = parse_expression("123[1]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!( - result, - Err("subscript is not supported for types other than list".into()) - ); - } - - #[test] - fn test_if_expr() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - let mut ctx = get_inference_context(ctx); - - let ast = parse_expression("1 if True else 0").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); - - let ast = parse_expression("none() if True else none()").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap(), None); - - let ast = parse_expression("none() if 1 else none()").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("test should be bool".into())); - - let ast = parse_expression("1 if True else none()").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("divergent type".into())); - } - - #[test] - fn test_list_comp() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - let int32 = ctx.get_primitive(INT32_TYPE); - let mut ctx = get_inference_context(ctx); - ctx.assign("z", int32.clone()).unwrap(); - - let ast = parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into() - ); - - let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), int32); - - let ast = - parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result.unwrap().unwrap(), int32); - - let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("test must be bool".into())); - - let ast = parse_expression("[y for x in []][0]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("unbounded identifier".into())); - - let ast = parse_expression("[none() for x in []][0]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("no 
value".into())); - - let ast = parse_expression("[z for z in []][0]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!(result, Err("duplicated naming".into())); - - let ast = parse_expression("[x for x in [] for y in []]").unwrap(); - let result = infer_expr(&mut ctx, &ast); - assert_eq!( - result, - Err("only 1 generator statement is supported".into()) - ); - } -} diff --git a/nac3core/src/inference_core.rs b/nac3core/src/inference_core.rs deleted file mode 100644 index 3b6b7d06..00000000 --- a/nac3core/src/inference_core.rs +++ /dev/null @@ -1,525 +0,0 @@ -use crate::context::InferenceContext; -use crate::typedef::{TypeEnum::*, *}; -use std::collections::HashMap; - -fn find_subst( - ctx: &InferenceContext, - valuation: &Option<(VariableId, Type)>, - sub: &mut HashMap, - mut a: Type, - mut b: Type, -) -> Result<(), String> { - // TODO: fix error messages later - if let TypeVariable(id) = a.as_ref() { - if let Some((assumption_id, t)) = valuation { - if assumption_id == id { - a = t.clone(); - } - } - } - - let mut substituted = false; - if let TypeVariable(id) = b.as_ref() { - if let Some(c) = sub.get(&id) { - b = c.clone(); - substituted = true; - } - } - - match (a.as_ref(), b.as_ref()) { - (BotType, _) => Ok(()), - (TypeVariable(id_a), TypeVariable(id_b)) => { - if substituted { - return if id_a == id_b { - Ok(()) - } else { - Err("different variables".to_string()) - }; - } - let v_a = ctx.get_variable_def(*id_a); - let v_b = ctx.get_variable_def(*id_b); - if !v_b.bound.is_empty() { - if v_a.bound.is_empty() { - return Err("unbounded a".to_string()); - } else { - let diff: Vec<_> = v_a - .bound - .iter() - .filter(|x| !v_b.bound.contains(x)) - .collect(); - if !diff.is_empty() { - return Err("different domain".to_string()); - } - } - } - sub.insert(*id_b, a.clone()); - Ok(()) - } - (TypeVariable(id_a), _) => { - let v_a = ctx.get_variable_def(*id_a); - if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() { - Ok(()) - } else { - Err("different domain".to_string()) - } - } - (_, TypeVariable(id_b)) => { - let v_b = ctx.get_variable_def(*id_b); - if v_b.bound.is_empty() || v_b.bound.contains(&a) { - sub.insert(*id_b, a.clone()); - Ok(()) - } else { - Err("different domain".to_string()) - } - } - (_, VirtualClassType(id_b)) => { - let mut parents; - match a.as_ref() { - ClassType(id_a) => { - parents = [*id_a].to_vec(); - } - VirtualClassType(id_a) => { - parents = [*id_a].to_vec(); - } - _ => { - return Err("cannot substitute non-class type into virtual class".to_string()); - } - }; - while !parents.is_empty() { - if *id_b == parents[0] { - return Ok(()); - } - let c = ctx.get_class_def(parents.remove(0)); - parents.extend_from_slice(&c.parents); - } - Err("not subtype".to_string()) - } - (ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => { - if id_a != id_b || param_a.len() != param_b.len() { - Err("different parametric types".to_string()) - } else { - for (x, y) in param_a.iter().zip(param_b.iter()) { - find_subst(ctx, valuation, sub, x.clone(), y.clone())?; - } - Ok(()) - } - } - (_, _) => { - if a == b { - Ok(()) - } else { - Err("not equal".to_string()) - } - } - } -} - -fn resolve_call_rec( - ctx: &InferenceContext, - valuation: &Option<(VariableId, Type)>, - obj: Option, - func: &str, - args: &[Type], -) -> Result, String> { - let mut subst = obj - .as_ref() - .map(|v| v.get_subst(ctx)) - .unwrap_or_else(HashMap::new); - - let fun = match &obj { - Some(obj) => { - let base = match obj.as_ref() { - PrimitiveType(id) => 
&ctx.get_primitive_def(*id), - ClassType(id) | VirtualClassType(id) => &ctx.get_class_def(*id).base, - ParametricType(id, _) => &ctx.get_parametric_def(*id).base, - _ => return Err("not supported".to_string()), - }; - base.methods.get(func) - } - None => ctx.get_fn_def(func), - } - .ok_or_else(|| "no such function".to_string())?; - - if args.len() != fun.args.len() { - return Err("incorrect parameter number".to_string()); - } - for (a, b) in args.iter().zip(fun.args.iter()) { - find_subst(ctx, valuation, &mut subst, a.clone(), b.clone())?; - } - let result = fun.result.as_ref().map(|v| v.subst(&subst)); - Ok(result.map(|result| { - if let SelfType = result { - obj.unwrap() - } else { - result.into() - } - })) -} - -pub fn resolve_call( - ctx: &InferenceContext, - obj: Option, - func: &str, - args: &[Type], -) -> Result, String> { - resolve_call_rec(ctx, &None, obj, func, args) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::context::TopLevelContext; - use crate::primitives::*; - use std::rc::Rc; - - fn get_inference_context(ctx: TopLevelContext) -> InferenceContext { - InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into()))) - } - - #[test] - fn test_simple_generic() { - let mut ctx = basic_ctx(); - let v1 = ctx.add_variable(VarDef { - name: "V1", - bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)], - }); - let v1 = ctx.get_variable(v1); - let v2 = ctx.add_variable(VarDef { - name: "V2", - bound: vec![ - ctx.get_primitive(BOOL_TYPE), - ctx.get_primitive(INT32_TYPE), - ctx.get_primitive(FLOAT_TYPE), - ], - }); - let v2 = ctx.get_variable(v2); - let ctx = get_inference_context(ctx); - - assert_eq!( - resolve_call(&ctx, None, "int32", &[ctx.get_primitive(FLOAT_TYPE)]), - Ok(Some(ctx.get_primitive(INT32_TYPE))) - ); - - assert_eq!( - resolve_call(&ctx, None, "int32", &[ctx.get_primitive(INT32_TYPE)],), - Ok(Some(ctx.get_primitive(INT32_TYPE))) - ); - - assert_eq!( - resolve_call(&ctx, None, "float", &[ctx.get_primitive(INT32_TYPE)]), - Ok(Some(ctx.get_primitive(FLOAT_TYPE))) - ); - - assert_eq!( - resolve_call(&ctx, None, "float", &[ctx.get_primitive(BOOL_TYPE)]), - Err("different domain".to_string()) - ); - - assert_eq!( - resolve_call(&ctx, None, "float", &[]), - Err("incorrect parameter number".to_string()) - ); - - assert_eq!( - resolve_call(&ctx, None, "float", &[v1]), - Ok(Some(ctx.get_primitive(FLOAT_TYPE))) - ); - - assert_eq!( - resolve_call(&ctx, None, "float", &[v2]), - Err("different domain".to_string()) - ); - } - - #[test] - fn test_methods() { - let mut ctx = basic_ctx(); - - let v0 = ctx.add_variable(VarDef { - name: "V0", - bound: vec![], - }); - let v0 = ctx.get_variable(v0); - - let int32 = ctx.get_primitive(INT32_TYPE); - let int64 = ctx.get_primitive(INT64_TYPE); - let ctx = get_inference_context(ctx); - - // simple cases - assert_eq!( - resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), - Ok(Some(int32.clone())) - ); - - assert_ne!( - resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), - Ok(Some(int64.clone())) - ); - - assert_eq!( - resolve_call(&ctx, Some(int32), "__add__", &[int64]), - Err("not equal".to_string()) - ); - - // with type variables - assert_eq!( - resolve_call(&ctx, Some(v0.clone()), "__add__", &[v0.clone()]), - Err("not supported".into()) - ); - } - - #[test] - fn test_multi_generic() { - let mut ctx = basic_ctx(); - let v0 = ctx.add_variable(VarDef { - name: "V0", - bound: vec![], - }); - let v0 = ctx.get_variable(v0); - let v1 = ctx.add_variable(VarDef { - 
name: "V1", - bound: vec![], - }); - let v1 = ctx.get_variable(v1); - let v2 = ctx.add_variable(VarDef { - name: "V2", - bound: vec![], - }); - let v2 = ctx.get_variable(v2); - let v3 = ctx.add_variable(VarDef { - name: "V3", - bound: vec![], - }); - let v3 = ctx.get_variable(v3); - - ctx.add_fn( - "foo", - FnDef { - args: vec![v0.clone(), v0.clone(), v1.clone()], - result: Some(v0.clone()), - }, - ); - - ctx.add_fn( - "foo1", - FnDef { - args: vec![ParametricType(TUPLE_TYPE, vec![v0.clone(), v0.clone(), v1]).into()], - result: Some(v0), - }, - ); - let ctx = get_inference_context(ctx); - - assert_eq!( - resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]), - Ok(Some(v2.clone())) - ); - assert_eq!( - resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]), - Ok(Some(v2.clone())) - ); - assert_eq!( - resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]), - Err("different variables".to_string()) - ); - - assert_eq!( - resolve_call( - &ctx, - None, - "foo1", - &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()] - ), - Ok(Some(v2.clone())) - ); - assert_eq!( - resolve_call( - &ctx, - None, - "foo1", - &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()] - ), - Ok(Some(v2.clone())) - ); - assert_eq!( - resolve_call( - &ctx, - None, - "foo1", - &[ParametricType(TUPLE_TYPE, vec![v2, v3.clone(), v3]).into()] - ), - Err("different variables".to_string()) - ); - } - - #[test] - fn test_class_generics() { - let mut ctx = basic_ctx(); - - let list = ctx.get_parametric_def_mut(LIST_TYPE); - let t = Rc::new(TypeVariable(list.params[0])); - list.base.methods.insert( - "head", - FnDef { - args: vec![], - result: Some(t.clone()), - }, - ); - list.base.methods.insert( - "append", - FnDef { - args: vec![t], - result: None, - }, - ); - - let v0 = ctx.add_variable(VarDef { - name: "V0", - bound: vec![], - }); - let v0 = ctx.get_variable(v0); - let v1 = ctx.add_variable(VarDef { - name: "V1", - bound: vec![], - }); - let v1 = ctx.get_variable(v1); - let ctx = get_inference_context(ctx); - - assert_eq!( - resolve_call( - &ctx, - Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), - "head", - &[] - ), - Ok(Some(v0.clone())) - ); - assert_eq!( - resolve_call( - &ctx, - Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), - "append", - &[v0.clone()] - ), - Ok(None) - ); - assert_eq!( - resolve_call( - &ctx, - Some(ParametricType(LIST_TYPE, vec![v0]).into()), - "append", - &[v1] - ), - Err("different variables".to_string()) - ); - } - - #[test] - fn test_virtual_class() { - let mut ctx = basic_ctx(); - - let foo = ctx.add_class(ClassDef { - base: TypeDef { - name: "Foo", - methods: HashMap::new(), - fields: HashMap::new(), - }, - parents: vec![], - }); - - let foo1 = ctx.add_class(ClassDef { - base: TypeDef { - name: "Foo1", - methods: HashMap::new(), - fields: HashMap::new(), - }, - parents: vec![foo], - }); - - let foo2 = ctx.add_class(ClassDef { - base: TypeDef { - name: "Foo2", - methods: HashMap::new(), - fields: HashMap::new(), - }, - parents: vec![foo1], - }); - - let bar = ctx.add_class(ClassDef { - base: TypeDef { - name: "bar", - methods: HashMap::new(), - fields: HashMap::new(), - }, - parents: vec![], - }); - - ctx.add_fn( - "foo", - FnDef { - args: vec![VirtualClassType(foo).into()], - result: None, - }, - ); - ctx.add_fn( - "foo1", - FnDef { - args: vec![VirtualClassType(foo1).into()], - result: None, - }, - ); - let ctx = get_inference_context(ctx); - - assert_eq!( - resolve_call(&ctx, 
None, "foo", &[ClassType(foo).into()]), - Ok(None) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]), - Ok(None) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]), - Ok(None) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]), - Err("not subtype".to_string()) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]), - Ok(None) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]), - Ok(None) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]), - Err("not subtype".to_string()) - ); - - // virtual class substitution - assert_eq!( - resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]), - Ok(None) - ); - assert_eq!( - resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]), - Ok(None) - ); - assert_eq!( - resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]), - Ok(None) - ); - assert_eq!( - resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]), - Err("not subtype".to_string()) - ); - } -} diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index 5cc15c3c..6d7de6f5 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -1,593 +1,8 @@ #![warn(clippy::all)] -#![allow(clippy::clone_double_ref)] +#![allow(dead_code)] -extern crate num_bigint; -extern crate inkwell; -extern crate rustpython_parser; - -pub mod expression_inference; -pub mod inference_core; -mod magic_methods; -pub mod primitives; -pub mod typedef; -pub mod context; - -use std::error::Error; -use std::fmt; -use std::path::Path; -use std::collections::HashMap; - -use num_traits::cast::ToPrimitive; - -use rustpython_parser::ast; - -use inkwell::OptimizationLevel; -use inkwell::builder::Builder; -use inkwell::context::Context; -use inkwell::module::Module; -use inkwell::targets::*; -use inkwell::types; -use inkwell::types::BasicType; -use inkwell::values; -use inkwell::{IntPredicate, FloatPredicate}; -use inkwell::basic_block; -use inkwell::passes; - - -#[derive(Debug)] -enum CompileErrorKind { - Unsupported(&'static str), - MissingTypeAnnotation, - UnknownTypeAnnotation, - IncompatibleTypes, - UnboundIdentifier, - BreakOutsideLoop, - Internal(&'static str) -} - -impl fmt::Display for CompileErrorKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - CompileErrorKind::Unsupported(feature) - => write!(f, "The following Python feature is not supported by NAC3: {}", feature), - CompileErrorKind::MissingTypeAnnotation - => write!(f, "Missing type annotation"), - CompileErrorKind::UnknownTypeAnnotation - => write!(f, "Unknown type annotation"), - CompileErrorKind::IncompatibleTypes - => write!(f, "Incompatible types"), - CompileErrorKind::UnboundIdentifier - => write!(f, "Unbound identifier"), - CompileErrorKind::BreakOutsideLoop - => write!(f, "Break outside loop"), - CompileErrorKind::Internal(details) - => write!(f, "Internal compiler error: {}", details), - } - } -} - -#[derive(Debug)] -pub struct CompileError { - location: ast::Location, - kind: CompileErrorKind, -} - -impl fmt::Display for CompileError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}, at {}", self.kind, self.location) - } -} - -impl Error for CompileError {} - -type CompileResult = Result; - -pub struct CodeGen<'ctx> { - context: &'ctx Context, - module: Module<'ctx>, - pass_manager: passes::PassManager>, - builder: Builder<'ctx>, - current_source_location: ast::Location, - 
namespace: HashMap>, - break_bb: Option>, -} - -impl<'ctx> CodeGen<'ctx> { - pub fn new(context: &'ctx Context) -> CodeGen<'ctx> { - let module = context.create_module("kernel"); - - let pass_manager = passes::PassManager::create(&module); - pass_manager.add_instruction_combining_pass(); - pass_manager.add_reassociate_pass(); - pass_manager.add_gvn_pass(); - pass_manager.add_cfg_simplification_pass(); - pass_manager.add_basic_alias_analysis_pass(); - pass_manager.add_promote_memory_to_register_pass(); - pass_manager.add_instruction_combining_pass(); - pass_manager.add_reassociate_pass(); - pass_manager.initialize(); - - let i32_type = context.i32_type(); - let fn_type = i32_type.fn_type(&[i32_type.into()], false); - module.add_function("output", fn_type, None); - - CodeGen { - context, module, pass_manager, - builder: context.create_builder(), - current_source_location: ast::Location::default(), - namespace: HashMap::new(), - break_bb: None, - } - } - - fn set_source_location(&mut self, location: ast::Location) { - self.current_source_location = location; - } - - fn compile_error(&self, kind: CompileErrorKind) -> CompileError { - CompileError { - location: self.current_source_location, - kind - } - } - - fn get_basic_type(&self, name: &str) -> CompileResult> { - match name { - "bool" => Ok(self.context.bool_type().into()), - "int32" => Ok(self.context.i32_type().into()), - "int64" => Ok(self.context.i64_type().into()), - "float32" => Ok(self.context.f32_type().into()), - "float64" => Ok(self.context.f64_type().into()), - _ => Err(self.compile_error(CompileErrorKind::UnknownTypeAnnotation)) - } - } - - fn compile_function_def( - &mut self, - name: &str, - args: &ast::Parameters, - body: &ast::Suite, - decorator_list: &[ast::Expression], - returns: &Option, - is_async: bool, - ) -> CompileResult> { - if is_async { - return Err(self.compile_error(CompileErrorKind::Unsupported("async functions"))) - } - for decorator in decorator_list.iter() { - self.set_source_location(decorator.location); - if let ast::ExpressionType::Identifier { name } = &decorator.node { - if name != "kernel" && name != "portable" { - return Err(self.compile_error(CompileErrorKind::Unsupported("custom decorators"))) - } - } else { - return Err(self.compile_error(CompileErrorKind::Unsupported("decorator must be an identifier"))) - } - } - - let args_type = args.args.iter().map(|val| { - self.set_source_location(val.location); - if let Some(annotation) = &val.annotation { - if let ast::ExpressionType::Identifier { name } = &annotation.node { - Ok(self.get_basic_type(&name)?) - } else { - Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier"))) - } - } else { - Err(self.compile_error(CompileErrorKind::MissingTypeAnnotation)) - } - }).collect::>>()?; - let return_type = if let Some(returns) = returns { - self.set_source_location(returns.location); - if let ast::ExpressionType::Identifier { name } = &returns.node { - if name == "None" { None } else { Some(self.get_basic_type(name)?) 
} - } else { - return Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier"))) - } - } else { - None - }; - - let fn_type = match return_type { - Some(ty) => ty.fn_type(&args_type, false), - None => self.context.void_type().fn_type(&args_type, false) - }; - - let function = self.module.add_function(name, fn_type, None); - let basic_block = self.context.append_basic_block(function, "entry"); - self.builder.position_at_end(basic_block); - - for (n, arg) in args.args.iter().enumerate() { - let param = function.get_nth_param(n as u32).unwrap(); - let alloca = self.builder.build_alloca(param.get_type(), &arg.arg); - self.builder.build_store(alloca, param); - self.namespace.insert(arg.arg.clone(), alloca); - } - - self.compile_suite(body, return_type)?; - - Ok(function) - } - - fn compile_expression( - &mut self, - expression: &ast::Expression - ) -> CompileResult> { - self.set_source_location(expression.location); - - match &expression.node { - ast::ExpressionType::True => Ok(self.context.bool_type().const_int(1, false).into()), - ast::ExpressionType::False => Ok(self.context.bool_type().const_int(0, false).into()), - ast::ExpressionType::Number { value: ast::Number::Integer { value } } => { - let mut bits = value.bits(); - if value.sign() == num_bigint::Sign::Minus { - bits += 1; - } - match bits { - 0..=32 => Ok(self.context.i32_type().const_int(value.to_i32().unwrap() as _, true).into()), - 33..=64 => Ok(self.context.i64_type().const_int(value.to_i64().unwrap() as _, true).into()), - _ => Err(self.compile_error(CompileErrorKind::Unsupported("integers larger than 64 bits"))) - } - }, - ast::ExpressionType::Number { value: ast::Number::Float { value } } => { - Ok(self.context.f64_type().const_float(*value).into()) - }, - ast::ExpressionType::Identifier { name } => { - match self.namespace.get(name) { - Some(value) => Ok(self.builder.build_load(*value, name).into()), - None => Err(self.compile_error(CompileErrorKind::UnboundIdentifier)) - } - }, - ast::ExpressionType::Unop { op, a } => { - let a = self.compile_expression(&a)?; - match (op, a) { - (ast::UnaryOperator::Pos, values::BasicValueEnum::IntValue(a)) - => Ok(a.into()), - (ast::UnaryOperator::Pos, values::BasicValueEnum::FloatValue(a)) - => Ok(a.into()), - (ast::UnaryOperator::Neg, values::BasicValueEnum::IntValue(a)) - => Ok(self.builder.build_int_neg(a, "tmpneg").into()), - (ast::UnaryOperator::Neg, values::BasicValueEnum::FloatValue(a)) - => Ok(self.builder.build_float_neg(a, "tmpneg").into()), - (ast::UnaryOperator::Inv, values::BasicValueEnum::IntValue(a)) - => Ok(self.builder.build_not(a, "tmpnot").into()), - (ast::UnaryOperator::Not, values::BasicValueEnum::IntValue(a)) => { - // boolean "not" - if a.get_type().get_bit_width() != 1 { - Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation"))) - } else { - Ok(self.builder.build_not(a, "tmpnot").into()) - } - }, - _ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation"))), - } - }, - ast::ExpressionType::Binop { a, op, b } => { - let a = self.compile_expression(&a)?; - let b = self.compile_expression(&b)?; - if a.get_type() != b.get_type() { - return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); - } - use ast::Operator::*; - match (op, a, b) { - (Add, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) - => Ok(self.builder.build_int_add(a, b, "tmpadd").into()), - (Sub, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) - 
=> Ok(self.builder.build_int_sub(a, b, "tmpsub").into()), - (Mult, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) - => Ok(self.builder.build_int_mul(a, b, "tmpmul").into()), - - (Add, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) - => Ok(self.builder.build_float_add(a, b, "tmpadd").into()), - (Sub, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) - => Ok(self.builder.build_float_sub(a, b, "tmpsub").into()), - (Mult, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) - => Ok(self.builder.build_float_mul(a, b, "tmpmul").into()), - - (Div, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) - => Ok(self.builder.build_float_div(a, b, "tmpdiv").into()), - (FloorDiv, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) - => Ok(self.builder.build_int_signed_div(a, b, "tmpdiv").into()), - _ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented binary operation"))), - } - }, - ast::ExpressionType::Compare { vals, ops } => { - let mut vals = vals.iter(); - let mut ops = ops.iter(); - - let mut result = None; - let mut a = self.compile_expression(vals.next().unwrap())?; - loop { - if let Some(op) = ops.next() { - let b = self.compile_expression(vals.next().unwrap())?; - if a.get_type() != b.get_type() { - return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); - } - let this_result = match (a, b) { - (values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) => { - match op { - ast::Comparison::Equal - => self.builder.build_int_compare(IntPredicate::EQ, a, b, "tmpeq"), - ast::Comparison::NotEqual - => self.builder.build_int_compare(IntPredicate::NE, a, b, "tmpne"), - ast::Comparison::Less - => self.builder.build_int_compare(IntPredicate::SLT, a, b, "tmpslt"), - ast::Comparison::LessOrEqual - => self.builder.build_int_compare(IntPredicate::SLE, a, b, "tmpsle"), - ast::Comparison::Greater - => self.builder.build_int_compare(IntPredicate::SGT, a, b, "tmpsgt"), - ast::Comparison::GreaterOrEqual - => self.builder.build_int_compare(IntPredicate::SGE, a, b, "tmpsge"), - _ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))), - } - }, - (values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) => { - match op { - ast::Comparison::Equal - => self.builder.build_float_compare(FloatPredicate::OEQ, a, b, "tmpoeq"), - ast::Comparison::NotEqual - => self.builder.build_float_compare(FloatPredicate::UNE, a, b, "tmpune"), - ast::Comparison::Less - => self.builder.build_float_compare(FloatPredicate::OLT, a, b, "tmpolt"), - ast::Comparison::LessOrEqual - => self.builder.build_float_compare(FloatPredicate::OLE, a, b, "tmpole"), - ast::Comparison::Greater - => self.builder.build_float_compare(FloatPredicate::OGT, a, b, "tmpogt"), - ast::Comparison::GreaterOrEqual - => self.builder.build_float_compare(FloatPredicate::OGE, a, b, "tmpoge"), - _ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))), - } - }, - _ => return Err(self.compile_error(CompileErrorKind::Unsupported("comparison of non-numerical types"))), - }; - match result { - Some(last) => { - result = Some(self.builder.build_and(last, this_result, "tmpand")); - } - None => { - result = Some(this_result); - } - } - a = b; - } else { - return Ok(result.unwrap().into()) - } - } - }, - ast::ExpressionType::Call { function, args, keywords } => { - if !keywords.is_empty() { - 
return Err(self.compile_error(CompileErrorKind::Unsupported("keyword arguments"))) - } - let args = args.iter().map(|val| self.compile_expression(val)) - .collect::>>()?; - self.set_source_location(expression.location); - if let ast::ExpressionType::Identifier { name } = &function.node { - match (name.as_str(), args[0]) { - ("int32", values::BasicValueEnum::IntValue(a)) => { - let nbits = a.get_type().get_bit_width(); - if nbits < 32 { - Ok(self.builder.build_int_s_extend(a, self.context.i32_type(), "tmpsext").into()) - } else if nbits > 32 { - Ok(self.builder.build_int_truncate(a, self.context.i32_type(), "tmptrunc").into()) - } else { - Ok(a.into()) - } - }, - ("int64", values::BasicValueEnum::IntValue(a)) => { - let nbits = a.get_type().get_bit_width(); - if nbits < 64 { - Ok(self.builder.build_int_s_extend(a, self.context.i64_type(), "tmpsext").into()) - } else { - Ok(a.into()) - } - }, - ("int32", values::BasicValueEnum::FloatValue(a)) => { - Ok(self.builder.build_float_to_signed_int(a, self.context.i32_type(), "tmpfptosi").into()) - }, - ("int64", values::BasicValueEnum::FloatValue(a)) => { - Ok(self.builder.build_float_to_signed_int(a, self.context.i64_type(), "tmpfptosi").into()) - }, - ("float32", values::BasicValueEnum::IntValue(a)) => { - Ok(self.builder.build_signed_int_to_float(a, self.context.f32_type(), "tmpsitofp").into()) - }, - ("float64", values::BasicValueEnum::IntValue(a)) => { - Ok(self.builder.build_signed_int_to_float(a, self.context.f64_type(), "tmpsitofp").into()) - }, - ("float32", values::BasicValueEnum::FloatValue(a)) => { - if a.get_type() == self.context.f64_type() { - Ok(self.builder.build_float_trunc(a, self.context.f32_type(), "tmptrunc").into()) - } else { - Ok(a.into()) - } - }, - ("float64", values::BasicValueEnum::FloatValue(a)) => { - if a.get_type() == self.context.f32_type() { - Ok(self.builder.build_float_ext(a, self.context.f64_type(), "tmpext").into()) - } else { - Ok(a.into()) - } - }, - - ("output", values::BasicValueEnum::IntValue(a)) => { - let fn_value = self.module.get_function("output").unwrap(); - Ok(self.builder.build_call(fn_value, &[a.into()], "call") - .try_as_basic_value().left().unwrap()) - }, - _ => Err(self.compile_error(CompileErrorKind::Unsupported("unrecognized call"))) - } - } else { - return Err(self.compile_error(CompileErrorKind::Unsupported("function must be an identifier"))) - } - }, - _ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented expression"))), - } - } - - fn compile_statement( - &mut self, - statement: &ast::Statement, - return_type: Option - ) -> CompileResult<()> { - self.set_source_location(statement.location); - - use ast::StatementType::*; - match &statement.node { - Assign { targets, value } => { - let value = self.compile_expression(value)?; - for target in targets.iter() { - self.set_source_location(target.location); - if let ast::ExpressionType::Identifier { name } = &target.node { - let builder = &self.builder; - let target = self.namespace.entry(name.clone()).or_insert_with( - || builder.build_alloca(value.get_type(), name)); - if target.get_type() != value.get_type().ptr_type(inkwell::AddressSpace::Generic) { - return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); - } - builder.build_store(*target, value); - } else { - return Err(self.compile_error(CompileErrorKind::Unsupported("assignment target must be an identifier"))) - } - } - }, - Expression { expression } => { self.compile_expression(expression)?; }, - If { test, body, orelse } => { - let test = 
self.compile_expression(test)?; - if test.get_type() != self.context.bool_type().into() { - return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); - } - - let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap(); - let then_bb = self.context.append_basic_block(parent, "then"); - let else_bb = self.context.append_basic_block(parent, "else"); - let cont_bb = self.context.append_basic_block(parent, "ifcont"); - self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb); - - self.builder.position_at_end(then_bb); - self.compile_suite(body, return_type)?; - self.builder.build_unconditional_branch(cont_bb); - - self.builder.position_at_end(else_bb); - if let Some(orelse) = orelse { - self.compile_suite(orelse, return_type)?; - } - self.builder.build_unconditional_branch(cont_bb); - self.builder.position_at_end(cont_bb); - }, - While { test, body, orelse } => { - let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap(); - let test_bb = self.context.append_basic_block(parent, "test"); - self.builder.build_unconditional_branch(test_bb); - self.builder.position_at_end(test_bb); - let test = self.compile_expression(test)?; - if test.get_type() != self.context.bool_type().into() { - return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); - } - - let then_bb = self.context.append_basic_block(parent, "then"); - let else_bb = self.context.append_basic_block(parent, "else"); - let cont_bb = self.context.append_basic_block(parent, "ifcont"); - self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb); - - self.break_bb = Some(cont_bb); - - self.builder.position_at_end(then_bb); - self.compile_suite(body, return_type)?; - self.builder.build_unconditional_branch(test_bb); - - self.builder.position_at_end(else_bb); - if let Some(orelse) = orelse { - self.compile_suite(orelse, return_type)?; - } - self.builder.build_unconditional_branch(cont_bb); - self.builder.position_at_end(cont_bb); - - self.break_bb = None; - }, - Break => { - if let Some(bb) = self.break_bb { - self.builder.build_unconditional_branch(bb); - let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap(); - let unreachable_bb = self.context.append_basic_block(parent, "unreachable"); - self.builder.position_at_end(unreachable_bb); - } else { - return Err(self.compile_error(CompileErrorKind::BreakOutsideLoop)); - } - } - Return { value: Some(value) } => { - if let Some(return_type) = return_type { - let value = self.compile_expression(value)?; - if value.get_type() != return_type { - return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); - } - self.builder.build_return(Some(&value)); - } else { - return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); - } - }, - Return { value: None } => { - if !return_type.is_none() { - return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); - } - self.builder.build_return(None); - }, - Pass => (), - _ => return Err(self.compile_error(CompileErrorKind::Unsupported("special statement"))), - } - Ok(()) - } - - fn compile_suite( - &mut self, - suite: &ast::Suite, - return_type: Option - ) -> CompileResult<()> { - for statement in suite.iter() { - self.compile_statement(statement, return_type)?; - } - Ok(()) - } - - pub fn compile_toplevel(&mut self, statement: &ast::Statement) -> CompileResult<()> { - self.set_source_location(statement.location); - if let ast::StatementType::FunctionDef { - is_async, - name, - args, - body, - decorator_list, - returns, - 
} = &statement.node { - let function = self.compile_function_def(name, args, body, decorator_list, returns, *is_async)?; - self.pass_manager.run_on(&function); - Ok(()) - } else { - Err(self.compile_error(CompileErrorKind::Internal("top-level is not a function definition"))) - } - } - - pub fn print_ir(&self) { - self.module.print_to_stderr(); - } - - pub fn output(&self, filename: &str) { - //let triple = TargetTriple::create("riscv32-none-linux-gnu"); - let triple = TargetMachine::get_default_triple(); - let target = Target::from_triple(&triple) - .expect("couldn't create target from target triple"); - - let target_machine = target - .create_target_machine( - &triple, - "", - "", - OptimizationLevel::Default, - RelocMode::Default, - CodeModel::Default, - ) - .expect("couldn't create target machine"); - - target_machine - .write_to_file(&self.module, FileType::Object, Path::new(filename)) - .expect("couldn't write module to file"); - } -} +mod codegen; +mod location; +mod symbol_resolver; +mod top_level; +mod typecheck; diff --git a/nac3core/src/location.rs b/nac3core/src/location.rs new file mode 100644 index 00000000..424336f2 --- /dev/null +++ b/nac3core/src/location.rs @@ -0,0 +1,31 @@ +use rustpython_parser::ast; +use std::vec::Vec; + +#[derive(Clone, Copy, PartialEq)] +pub struct FileID(u32); + +#[derive(Clone, Copy, PartialEq)] +pub enum Location { + CodeRange(FileID, ast::Location), + Builtin, +} + +pub struct FileRegistry { + files: Vec, +} + +impl FileRegistry { + pub fn new() -> FileRegistry { + FileRegistry { files: Vec::new() } + } + + pub fn add_file(&mut self, path: &str) -> FileID { + let index = self.files.len() as u32; + self.files.push(path.to_owned()); + FileID(index) + } + + pub fn query_file(&self, id: FileID) -> &str { + &self.files[id.0 as usize] + } +} diff --git a/nac3core/src/magic_methods.rs b/nac3core/src/magic_methods.rs deleted file mode 100644 index b0c248b4..00000000 --- a/nac3core/src/magic_methods.rs +++ /dev/null @@ -1,58 +0,0 @@ -use rustpython_parser::ast::{Comparison, Operator, UnaryOperator}; - -pub fn binop_name(op: &Operator) -> &'static str { - match op { - Operator::Add => "__add__", - Operator::Sub => "__sub__", - Operator::Div => "__truediv__", - Operator::Mod => "__mod__", - Operator::Mult => "__mul__", - Operator::Pow => "__pow__", - Operator::BitOr => "__or__", - Operator::BitXor => "__xor__", - Operator::BitAnd => "__and__", - Operator::LShift => "__lshift__", - Operator::RShift => "__rshift__", - Operator::FloorDiv => "__floordiv__", - Operator::MatMult => "__matmul__", - } -} - -pub fn binop_assign_name(op: &Operator) -> &'static str { - match op { - Operator::Add => "__iadd__", - Operator::Sub => "__isub__", - Operator::Div => "__itruediv__", - Operator::Mod => "__imod__", - Operator::Mult => "__imul__", - Operator::Pow => "__ipow__", - Operator::BitOr => "__ior__", - Operator::BitXor => "__ixor__", - Operator::BitAnd => "__iand__", - Operator::LShift => "__ilshift__", - Operator::RShift => "__irshift__", - Operator::FloorDiv => "__ifloordiv__", - Operator::MatMult => "__imatmul__", - } -} - -pub fn unaryop_name(op: &UnaryOperator) -> &'static str { - match op { - UnaryOperator::Pos => "__pos__", - UnaryOperator::Neg => "__neg__", - UnaryOperator::Not => "__not__", - UnaryOperator::Inv => "__inv__", - } -} - -pub fn comparison_name(op: &Comparison) -> Option<&'static str> { - match op { - Comparison::Less => Some("__lt__"), - Comparison::LessOrEqual => Some("__le__"), - Comparison::Greater => Some("__gt__"), - Comparison::GreaterOrEqual 
=> Some("__ge__"), - Comparison::Equal => Some("__eq__"), - Comparison::NotEqual => Some("__ne__"), - _ => None, - } -} diff --git a/nac3core/src/primitives.rs b/nac3core/src/primitives.rs deleted file mode 100644 index e7777491..00000000 --- a/nac3core/src/primitives.rs +++ /dev/null @@ -1,184 +0,0 @@ -use super::typedef::{TypeEnum::*, *}; -use crate::context::*; -use std::collections::HashMap; - -pub const TUPLE_TYPE: ParamId = ParamId(0); -pub const LIST_TYPE: ParamId = ParamId(1); - -pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0); -pub const INT32_TYPE: PrimitiveId = PrimitiveId(1); -pub const INT64_TYPE: PrimitiveId = PrimitiveId(2); -pub const FLOAT_TYPE: PrimitiveId = PrimitiveId(3); - -fn impl_math(def: &mut TypeDef, ty: &Type) { - let result = Some(ty.clone()); - let fun = FnDef { - args: vec![ty.clone()], - result: result.clone(), - }; - def.methods.insert("__add__", fun.clone()); - def.methods.insert("__sub__", fun.clone()); - def.methods.insert("__mul__", fun.clone()); - def.methods.insert( - "__neg__", - FnDef { - args: vec![], - result, - }, - ); - def.methods.insert( - "__truediv__", - FnDef { - args: vec![ty.clone()], - result: Some(PrimitiveType(FLOAT_TYPE).into()), - }, - ); - def.methods.insert("__floordiv__", fun.clone()); - def.methods.insert("__mod__", fun.clone()); - def.methods.insert("__pow__", fun); -} - -fn impl_bits(def: &mut TypeDef, ty: &Type) { - let result = Some(ty.clone()); - let fun = FnDef { - args: vec![PrimitiveType(INT32_TYPE).into()], - result, - }; - - def.methods.insert("__lshift__", fun.clone()); - def.methods.insert("__rshift__", fun); - def.methods.insert( - "__xor__", - FnDef { - args: vec![ty.clone()], - result: Some(ty.clone()), - }, - ); -} - -fn impl_eq(def: &mut TypeDef, ty: &Type) { - let fun = FnDef { - args: vec![ty.clone()], - result: Some(PrimitiveType(BOOL_TYPE).into()), - }; - - def.methods.insert("__eq__", fun.clone()); - def.methods.insert("__ne__", fun); -} - -fn impl_order(def: &mut TypeDef, ty: &Type) { - let fun = FnDef { - args: vec![ty.clone()], - result: Some(PrimitiveType(BOOL_TYPE).into()), - }; - - def.methods.insert("__lt__", fun.clone()); - def.methods.insert("__gt__", fun.clone()); - def.methods.insert("__le__", fun.clone()); - def.methods.insert("__ge__", fun); -} - -pub fn basic_ctx() -> TopLevelContext<'static> { - let primitives = [ - TypeDef { - name: "bool", - fields: HashMap::new(), - methods: HashMap::new(), - }, - TypeDef { - name: "int32", - fields: HashMap::new(), - methods: HashMap::new(), - }, - TypeDef { - name: "int64", - fields: HashMap::new(), - methods: HashMap::new(), - }, - TypeDef { - name: "float", - fields: HashMap::new(), - methods: HashMap::new(), - }, - ] - .to_vec(); - let mut ctx = TopLevelContext::new(primitives); - - let b = ctx.get_primitive(BOOL_TYPE); - let b_def = ctx.get_primitive_def_mut(BOOL_TYPE); - impl_eq(b_def, &b); - let int32 = ctx.get_primitive(INT32_TYPE); - let int32_def = ctx.get_primitive_def_mut(INT32_TYPE); - impl_math(int32_def, &int32); - impl_bits(int32_def, &int32); - impl_order(int32_def, &int32); - impl_eq(int32_def, &int32); - let int64 = ctx.get_primitive(INT64_TYPE); - let int64_def = ctx.get_primitive_def_mut(INT64_TYPE); - impl_math(int64_def, &int64); - impl_bits(int64_def, &int64); - impl_order(int64_def, &int64); - impl_eq(int64_def, &int64); - let float = ctx.get_primitive(FLOAT_TYPE); - let float_def = ctx.get_primitive_def_mut(FLOAT_TYPE); - impl_math(float_def, &float); - impl_order(float_def, &float); - impl_eq(float_def, &float); - - let t = 
ctx.add_variable_private(VarDef { - name: "T", - bound: vec![], - }); - - ctx.add_parametric(ParametricDef { - base: TypeDef { - name: "tuple", - fields: HashMap::new(), - methods: HashMap::new(), - }, - // we have nothing for tuple, so no param def - params: vec![], - }); - - ctx.add_parametric(ParametricDef { - base: TypeDef { - name: "list", - fields: HashMap::new(), - methods: HashMap::new(), - }, - params: vec![t], - }); - - let i = ctx.add_variable_private(VarDef { - name: "I", - bound: vec![ - PrimitiveType(INT32_TYPE).into(), - PrimitiveType(INT64_TYPE).into(), - PrimitiveType(FLOAT_TYPE).into(), - ], - }); - let args = vec![TypeVariable(i).into()]; - ctx.add_fn( - "int32", - FnDef { - args: args.clone(), - result: Some(PrimitiveType(INT32_TYPE).into()), - }, - ); - ctx.add_fn( - "int64", - FnDef { - args: args.clone(), - result: Some(PrimitiveType(INT64_TYPE).into()), - }, - ); - ctx.add_fn( - "float", - FnDef { - args, - result: Some(PrimitiveType(FLOAT_TYPE).into()), - }, - ); - - ctx -} diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs new file mode 100644 index 00000000..5f716ee4 --- /dev/null +++ b/nac3core/src/symbol_resolver.rs @@ -0,0 +1,174 @@ +use std::cell::RefCell; +use std::collections::HashMap; + +use crate::top_level::{DefinitionId, TopLevelContext, TopLevelDef}; +use crate::typecheck::{ + type_inferencer::PrimitiveStore, + typedef::{Type, Unifier}, +}; +use crate::{location::Location, typecheck::typedef::TypeEnum}; +use itertools::{chain, izip}; +use rustpython_parser::ast::Expr; + +#[derive(Clone, PartialEq)] +pub enum SymbolValue { + I32(i32), + I64(i64), + Double(f64), + Bool(bool), + Tuple(Vec), + // we should think about how to implement bytes later... + // Bytes(&'a [u8]), +} + +pub trait SymbolResolver { + // get type of type variable identifier or top-level function type + fn get_symbol_type( + &self, + unifier: &mut Unifier, + primitives: &PrimitiveStore, + str: &str, + ) -> Option; + // get the top-level definition of identifiers + fn get_identifier_def(&self, str: &str) -> Option; + fn get_symbol_value(&self, str: &str) -> Option; + fn get_symbol_location(&self, str: &str) -> Option; + // handle function call etc. +} + +// convert type annotation into type +pub fn parse_type_annotation( + resolver: &dyn SymbolResolver, + top_level: &TopLevelContext, + unifier: &mut Unifier, + primitives: &PrimitiveStore, + expr: &Expr, +) -> Result { + use rustpython_parser::ast::ExprKind::*; + match &expr.node { + Name { id, .. } => match id.as_str() { + "int32" => Ok(primitives.int32), + "int64" => Ok(primitives.int64), + "float" => Ok(primitives.float), + "bool" => Ok(primitives.bool), + "None" => Ok(primitives.none), + x => { + let obj_id = resolver.get_identifier_def(x); + if let Some(obj_id) = obj_id { + let defs = top_level.definitions.read(); + let def = defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. 
} = &*def { + if !type_vars.is_empty() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got 0", + type_vars.len() + )); + } + let fields = RefCell::new( + chain( + fields.iter().map(|(k, v)| (k.clone(), *v)), + methods.iter().map(|(k, v, _)| (k.clone(), *v)), + ) + .collect(), + ); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id, + fields, + params: Default::default(), + })) + } else { + Err("Cannot use function name as type".into()) + } + } else { + // it could be a type variable + let ty = resolver + .get_symbol_type(unifier, primitives, x) + .ok_or_else(|| "Cannot use function name as type".to_owned())?; + if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { + Ok(ty) + } else { + Err(format!("Unknown type annotation {}", x)) + } + } + } + }, + Subscript { value, slice, .. } => { + if let Name { id, .. } = &value.node { + if id == "virtual" { + let ty = + parse_type_annotation(resolver, top_level, unifier, primitives, slice)?; + Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) + } else { + let types = if let Tuple { elts, .. } = &slice.node { + elts.iter() + .map(|v| { + parse_type_annotation(resolver, top_level, unifier, primitives, v) + }) + .collect::, _>>()? + } else { + vec![parse_type_annotation( + resolver, top_level, unifier, primitives, slice, + )?] + }; + + let obj_id = resolver + .get_identifier_def(id) + .ok_or_else(|| format!("Unknown type annotation {}", id))?; + let defs = top_level.definitions.read(); + let def = defs[obj_id.0].read(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { + if types.len() != type_vars.len() { + return Err(format!( + "Unexpected number of type parameters: expected {} but got {}", + type_vars.len(), + types.len() + )); + } + let mut subst = HashMap::new(); + for (var, ty) in izip!(type_vars.iter(), types.iter()) { + let id = if let TypeEnum::TVar { id, .. 
} = &*unifier.get_ty(*var) { + *id + } else { + unreachable!() + }; + subst.insert(id, *ty); + } + let mut fields = fields + .iter() + .map(|(attr, ty)| { + let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (attr.clone(), ty) + }) + .collect::>(); + fields.extend(methods.iter().map(|(attr, ty, _)| { + let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (attr.clone(), ty) + })); + Ok(unifier.add_ty(TypeEnum::TObj { + obj_id, + fields: fields.into(), + params: subst.into(), + })) + } else { + Err("Cannot use function name as type".into()) + } + } + } else { + Err("unsupported type expression".into()) + } + } + _ => Err("unsupported type expression".into()), + } +} + +impl dyn SymbolResolver + Send + Sync { + pub fn parse_type_annotation( + &self, + top_level: &TopLevelContext, + unifier: &mut Unifier, + primitives: &PrimitiveStore, + expr: &Expr, + ) -> Result { + parse_type_annotation(self, top_level, unifier, primitives, expr) + } +} diff --git a/nac3core/src/top_level.rs b/nac3core/src/top_level.rs new file mode 100644 index 00000000..8560cfb3 --- /dev/null +++ b/nac3core/src/top_level.rs @@ -0,0 +1,778 @@ +use std::borrow::BorrowMut; +use std::ops::{Deref, DerefMut}; +use std::{collections::HashMap, collections::HashSet, sync::Arc}; + +use super::typecheck::type_inferencer::PrimitiveStore; +use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier}; +use crate::typecheck::typedef::{FunSignature, FuncArg}; +use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Mapping}; +use itertools::Itertools; +use parking_lot::{Mutex, RwLock}; +use rustpython_parser::ast::{self, Stmt}; + +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +pub struct DefinitionId(pub usize); + +pub enum TopLevelDef { + Class { + // object ID used for TypeEnum + object_id: DefinitionId, + // type variables bounded to the class. + type_vars: Vec, + // class fields + fields: Vec<(String, Type)>, + // class methods, pointing to the corresponding function definition. + methods: Vec<(String, Type, DefinitionId)>, + // ancestor classes, including itself. + ancestors: Vec, + // symbol resolver of the module defined the class, none if it is built-in type + resolver: Option>>, + }, + Function { + // prefix for symbol, should be unique globally, and not ending with numbers + name: String, + // function signature. + signature: Type, + /// Function instance to symbol mapping + /// Key: string representation of type variable values, sorted by variable ID in ascending + /// order, including type variables associated with the class. + /// Value: function symbol name. + instance_to_symbol: HashMap, + /// Function instances to annotated AST mapping + /// Key: string representation of type variable values, sorted by variable ID in ascending + /// order, including type variables associated with the class. Excluding rigid type + /// variables. + /// Value: AST annotated with types together with a unification table index. Could contain + /// rigid type variables that would be substituted when the function is instantiated. + instance_to_stmt: HashMap>, usize)>, + // symbol resolver of the module defined the class + resolver: Option>>, + }, + Initializer { + class_id: DefinitionId, + }, +} + +impl TopLevelDef { + fn get_function_type(&self) -> Result { + if let Self::Function { signature, .. 
} = self { + Ok(*signature) + } else { + Err("only expect function def here".into()) + } + } +} + +pub struct TopLevelContext { + pub definitions: Arc>>>>, + pub unifiers: Arc>>, +} + +pub struct TopLevelComposer { + // list of top level definitions, same as top level context + pub definition_ast_list: Arc>, Option>)>>>, + // start as a primitive unifier, will add more top_level defs inside + pub unifier: Unifier, + // primitive store + pub primitives: PrimitiveStore, + // mangled class method name to def_id + pub class_method_to_def_id: HashMap, + // record the def id of the classes whoses fields and methods are to be analyzed + pub to_be_analyzed_class: Vec, +} + +impl TopLevelComposer { + pub fn to_top_level_context(&self) -> TopLevelContext { + let def_list = + self.definition_ast_list.read().iter().map(|(x, _)| x.clone()).collect::>(); + TopLevelContext { + definitions: RwLock::new(def_list).into(), + // FIXME: all the big unifier or? + unifiers: Default::default(), + } + } + + fn name_mangling(mut class_name: String, method_name: &str) -> String { + class_name.push_str(method_name); + class_name + } + + pub fn make_primitives() -> (PrimitiveStore, Unifier) { + let mut unifier = Unifier::new(); + let int32 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(0), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let int64 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(1), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let float = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(2), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let bool = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(3), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let none = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(4), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let primitives = PrimitiveStore { int32, int64, float, bool, none }; + crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); + (primitives, unifier) + } + + /// return a composer and things to make a "primitive" symbol resolver, so that the symbol + /// resolver can later figure out primitive type definitions when passed a primitive type name + pub fn new() -> (Vec<(String, DefinitionId, Type)>, Self) { + let primitives = Self::make_primitives(); + + let top_level_def_list = vec![ + Arc::new(RwLock::new(Self::make_top_level_class_def(0, None))), + Arc::new(RwLock::new(Self::make_top_level_class_def(1, None))), + Arc::new(RwLock::new(Self::make_top_level_class_def(2, None))), + Arc::new(RwLock::new(Self::make_top_level_class_def(3, None))), + Arc::new(RwLock::new(Self::make_top_level_class_def(4, None))), + ]; + + let ast_list: Vec>> = vec![None, None, None, None, None]; + + let composer = TopLevelComposer { + definition_ast_list: RwLock::new( + top_level_def_list.into_iter().zip(ast_list).collect_vec(), + ) + .into(), + primitives: primitives.0, + unifier: primitives.1, + class_method_to_def_id: Default::default(), + to_be_analyzed_class: Default::default(), + }; + ( + vec![ + ("int32".into(), DefinitionId(0), composer.primitives.int32), + ("int64".into(), DefinitionId(1), composer.primitives.int64), + ("float".into(), DefinitionId(2), composer.primitives.float), + ("bool".into(), DefinitionId(3), composer.primitives.bool), + ("none".into(), DefinitionId(4), composer.primitives.none), + ], + composer, + ) + } + + /// already include the 
definition_id of itself inside the ancestors vector
+    /// when first registering, the type_vars, fields, methods, ancestors are invalid
+    pub fn make_top_level_class_def(
+        index: usize,
+        resolver: Option>>,
+    ) -> TopLevelDef {
+        TopLevelDef::Class {
+            object_id: DefinitionId(index),
+            type_vars: Default::default(),
+            fields: Default::default(),
+            methods: Default::default(),
+            ancestors: vec![DefinitionId(index)],
+            resolver,
+        }
+    }
+
+    /// when first registering, the type is an invalid value
+    pub fn make_top_level_function_def(
+        name: String,
+        ty: Type,
+        resolver: Option>>,
+    ) -> TopLevelDef {
+        TopLevelDef::Function {
+            name,
+            signature: ty,
+            instance_to_symbol: Default::default(),
+            instance_to_stmt: Default::default(),
+            resolver,
+        }
+    }
+
+    /// step 0, register: just remember the names of top level classes/functions
+    pub fn register_top_level(
+        &mut self,
+        ast: ast::Stmt<()>,
+        resolver: Option>>,
+    ) -> Result<(String, DefinitionId), String> {
+        let mut def_list = self.definition_ast_list.write();
+        match &ast.node {
+            ast::StmtKind::ClassDef { name, body, .. } => {
+                let class_name = name.to_string();
+                let class_def_id = def_list.len();
+
+                // add the class to the definition lists
+                // since later when registering class methods, the ast will still be used,
+                // push None here temporarily; the ast will be moved in later
+                let mut class_def_ast = (
+                    Arc::new(RwLock::new(Self::make_top_level_class_def(
+                        class_def_id,
+                        resolver.clone(),
+                    ))),
+                    None,
+                );
+
+                // parse class def body and register class methods into the def list.
+                // the module's symbol resolver would not know the names of the class methods,
+                // thus cannot return their definition_id
+                let mut class_method_name_def_ids: Vec<(
+                    String,
+                    Arc>,
+                    DefinitionId,
+                )> = Vec::new();
+                let mut class_method_index_offset = 0;
+                for b in body {
+                    if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node {
+                        let method_name = Self::name_mangling(class_name.clone(), method_name);
+                        let method_def_id = def_list.len() + {
+                            class_method_index_offset += 1;
+                            class_method_index_offset
+                        };
+
+                        // dummy method definition here;
+                        // the ast of the class method is in the class, so push None into the list here
+                        class_method_name_def_ids.push((
+                            method_name.clone(),
+                            RwLock::new(Self::make_top_level_function_def(
+                                method_name.clone(),
+                                self.primitives.none,
+                                resolver.clone(),
+                            ))
+                            .into(),
+                            DefinitionId(method_def_id),
+                        ));
+                    }
+                }
+                // move the ast to the entry of the class in the ast_list
+                class_def_ast.1 = Some(ast);
+
+                // now class_def_ast and class_method_name_def_ids are ok, put them into the actual def list in the correct order
+                def_list.push(class_def_ast);
+                for (name, def, id) in class_method_name_def_ids {
+                    def_list.push((def, None));
+                    self.class_method_to_def_id.insert(name, id);
+                }
+
+                // put the constructor into the def_list
+                def_list.push((
+                    RwLock::new(TopLevelDef::Initializer { class_id: DefinitionId(class_def_id) })
+                        .into(),
+                    None,
+                ));
+
+                // this is a class, so put its def_id into the to-be-analyzed set
+                self.to_be_analyzed_class.push(DefinitionId(class_def_id));
+
+                Ok((class_name, DefinitionId(class_def_id)))
+            }
+
+            ast::StmtKind::FunctionDef { name, .. } => {
+                let fun_name = name.to_string();
+
+                // add to the definition list
+                def_list.push((
+                    RwLock::new(Self::make_top_level_function_def(
+                        name.into(),
+                        self.primitives.none,
+                        resolver,
+                    ))
+                    .into(),
+                    Some(ast),
+                ));
+
+                // return
+                Ok((fun_name, DefinitionId(def_list.len() - 1)))
+            }
+
+            _ => Err("only registrations of top level classes/functions are supported".into()),
+        }
+    }
+
+    /// step 1, analyze the type vars associated with the top level classes
+    fn analyze_top_level_class_type_var(&mut self) -> Result<(), String> {
+        let mut def_list = self.definition_ast_list.write();
+        let converted_top_level = &self.to_top_level_context();
+        let primitives = &self.primitives;
+        let unifier = &mut self.unifier;
+
+        for (class_def, class_ast) in def_list.iter_mut() {
+            // only deal with class defs here
+            let mut class_def = class_def.write();
+            let (class_bases_ast, class_def_type_vars, class_resolver) = {
+                if let TopLevelDef::Class { type_vars, resolver, .. } = class_def.deref_mut() {
+                    if let Some(ast::Located {
+                        node: ast::StmtKind::ClassDef { bases, .. }, ..
+                    }) = class_ast
+                    {
+                        (bases, type_vars, resolver)
+                    } else {
+                        unreachable!("must be both class")
+                    }
+                } else {
+                    continue;
+                }
+            };
+            let class_resolver = class_resolver.as_ref().unwrap().lock();
+
+            let mut is_generic = false;
+            for b in class_bases_ast {
+                match &b.node {
+                    // analyze the type vars bound to the class;
+                    // only things like `class A(Generic[T, V])` are supported,
+                    // while `class A(Generic[T, V, ImportedModule.T])` is not,
+                    // i.e. only simple names are allowed in the subscript;
+                    // should update TopLevelDef::Class.type_vars and TypeEnum::TObj.params
+                    ast::ExprKind::Subscript { value, slice, .. } if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Generic") =>
+                    {
+                        if !is_generic {
+                            is_generic = true;
+                        } else {
+                            return Err("Only single Generic[...] can be in bases".into());
+                        }
+
+                        // if `class A(Generic[T, V, G])`
+                        if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
+                            // parse the type vars
+                            let type_vars = elts
+                                .iter()
+                                .map(|e| {
+                                    class_resolver.parse_type_annotation(
+                                        converted_top_level,
+                                        unifier.borrow_mut(),
+                                        primitives,
+                                        e,
+                                    )
+                                })
+                                .collect::, _>>()?;
+
+                            // check if all are unique type vars
+                            let mut occured_type_var_id: HashSet = HashSet::new();
+                            let all_unique_type_var = type_vars.iter().all(|x| {
+                                let ty = unifier.get_ty(*x);
+                                if let TypeEnum::TVar { id, .. } = ty.as_ref() {
+                                    occured_type_var_id.insert(*id)
+                                } else {
+                                    false
+                                }
+                            });
+
+                            if !all_unique_type_var {
+                                return Err("expect unique type variables".into());
+                            }
+
+                            // add to TopLevelDef
+                            class_def_type_vars.extend(type_vars);
+
+                        // `class A(Generic[T])`
+                        } else {
+                            let ty = class_resolver.parse_type_annotation(
+                                converted_top_level,
+                                unifier.borrow_mut(),
+                                primitives,
+                                &slice,
+                            )?;
+                            // check if it is a type var
+                            let is_type_var =
+                                matches!(unifier.get_ty(ty).as_ref(), &TypeEnum::TVar { .. });
+                            if !is_type_var {
+                                return Err("expect type variable here".into());
+                            }
+
+                            // add to TopLevelDef
+                            class_def_type_vars.push(ty);
+                        }
+                    }
+
+                    // if others, do nothing in this function
+                    _ => continue,
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// step 2, base classes.
Need to separate step1 and step2 for this reason: + /// `class B(Generic[T, V]); + /// class A(B[int, bool])` + /// if the type var associated with class `B` has not been handled properly, + /// the parse of type annotation of `B[int, bool]` will fail + fn analyze_top_level_class_bases(&mut self) -> Result<(), String> { + let mut def_list = self.definition_ast_list.write(); + let converted_top_level = &self.to_top_level_context(); + let primitives = &self.primitives; + let unifier = &mut self.unifier; + + for (class_def, class_ast) in def_list.iter_mut() { + let mut class_def = class_def.write(); + let (class_bases, class_ancestors, class_resolver) = { + if let TopLevelDef::Class { ancestors, resolver, .. } = class_def.deref_mut() { + if let Some(ast::Located { + node: ast::StmtKind::ClassDef { bases, .. }, .. + }) = class_ast + { + (bases, ancestors, resolver) + } else { + unreachable!("must be both class") + } + } else { + continue; + } + }; + let class_resolver = class_resolver.as_ref().unwrap().lock(); + for b in class_bases { + // type vars have already been handled, so skip on `Generic[...]` + if let ast::ExprKind::Subscript { value, .. } = &b.node { + if let ast::ExprKind::Name { id, .. } = &value.node { + if id == "Generic" { + continue; + } + } + } + // get the def id of the base class + let base_ty = class_resolver.parse_type_annotation( + converted_top_level, + unifier.borrow_mut(), + primitives, + b, + )?; + let base_id = + if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(base_ty).as_ref() { + *obj_id + } else { + return Err("expect concrete class/type to be base class".into()); + }; + + // write to the class ancestors, make sure the uniqueness + if !class_ancestors.contains(&base_id) { + class_ancestors.push(base_id); + } else { + return Err("cannot specify the same base class twice".into()); + } + } + } + Ok(()) + } + + /// step 3, class fields and methods + // FIXME: analyze base classes here + // FIXME: deal with self type + // NOTE: prevent cycles only roughly done + fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> { + let mut def_ast_list = self.definition_ast_list.write(); + let converted_top_level = &self.to_top_level_context(); + let primitives = &self.primitives; + let to_be_analyzed_class = &mut self.to_be_analyzed_class; + let unifier = &mut self.unifier; + + // NOTE: roughly prevent infinite loop + let mut max_iter = to_be_analyzed_class.len() * 4; + 'class: loop { + if to_be_analyzed_class.is_empty() && { + max_iter -= 1; + max_iter > 0 + } { + break; + } + + let class_ind = to_be_analyzed_class.remove(0).0; + let (class_name, class_body_ast, class_bases_ast, class_resolver, class_ancestors) = { + let (class_def, class_ast) = &mut def_ast_list[class_ind]; + if let Some(ast::Located { + node: ast::StmtKind::ClassDef { name, body, bases, .. }, + .. + }) = class_ast.as_ref() + { + if let TopLevelDef::Class { resolver, ancestors, .. 
} = + class_def.write().deref() + { + (name, body, bases, resolver.as_ref().unwrap().clone(), ancestors.clone()) + } else { + unreachable!() + } + } else { + unreachable!("should be class def ast") + } + }; + + let all_base_class_analyzed = { + let not_yet_analyzed = + to_be_analyzed_class.clone().into_iter().collect::>(); + let base = class_ancestors.clone().into_iter().collect::>(); + let intersection = not_yet_analyzed.intersection(&base).collect_vec(); + intersection.is_empty() + }; + if !all_base_class_analyzed { + to_be_analyzed_class.push(DefinitionId(class_ind)); + continue 'class; + } + + // get the bases type, can directly do this since it + // already pass the check in the previous stages + let class_bases_ty = class_bases_ast + .iter() + .filter_map(|x| { + class_resolver + .as_ref() + .lock() + .parse_type_annotation( + converted_top_level, + unifier.borrow_mut(), + primitives, + x, + ) + .ok() + }) + .collect_vec(); + + // need these vectors to check re-defining methods, class fields + // and store the parsed result in case some method cannot be typed for now + let mut class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![]; + let mut class_fields_parsing_result: Vec<(String, Type)> = vec![]; + for b in class_body_ast { + if let ast::StmtKind::FunctionDef { + args: method_args_ast, + body: method_body_ast, + name: method_name, + returns: method_returns_ast, + .. + } = &b.node + { + let arg_name_tys: Vec<(String, Type)> = { + let mut result = vec![]; + for a in &method_args_ast.args { + if a.node.arg != "self" { + let annotation = a + .node + .annotation + .as_ref() + .ok_or_else(|| { + "type annotation for function parameter is needed" + .to_string() + })? + .as_ref(); + + let ty = class_resolver.as_ref().lock().parse_type_annotation( + converted_top_level, + unifier.borrow_mut(), + primitives, + annotation, + )?; + if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) { + to_be_analyzed_class.push(DefinitionId(class_ind)); + continue 'class; + } + result.push((a.node.arg.to_string(), ty)); + } else { + // TODO: handle self, how + unimplemented!() + } + } + result + }; + + let method_type_var = arg_name_tys + .iter() + .filter_map(|(_, ty)| { + let ty_enum = unifier.get_ty(*ty); + if let TypeEnum::TVar { id, .. } = ty_enum.as_ref() { + Some((*id, *ty)) + } else { + None + } + }) + .collect::>(); + + let ret_ty = { + if method_name != "__init__" { + let ty = method_returns_ast + .as_ref() + .map(|x| { + class_resolver.as_ref().lock().parse_type_annotation( + converted_top_level, + unifier.borrow_mut(), + primitives, + x.as_ref(), + ) + }) + .ok_or_else(|| "return type annotation error".to_string())??; + if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) { + to_be_analyzed_class.push(DefinitionId(class_ind)); + continue 'class; + } else { + ty + } + } else { + // TODO: __init__ function, self type, how + unimplemented!() + } + }; + + // handle fields + let class_field_name_tys: Option> = if method_name + == "__init__" + { + let mut result: Vec<(String, Type)> = vec![]; + for body in method_body_ast { + match &body.node { + ast::StmtKind::AnnAssign { target, annotation, .. } + if { + if let ast::ExprKind::Attribute { value, .. } = &target.node + { + matches!( + &value.node, + ast::ExprKind::Name { id, .. 
} if id == "self") + } else { + false + } + } => + { + let field_ty = + class_resolver.as_ref().lock().parse_type_annotation( + converted_top_level, + unifier.borrow_mut(), + primitives, + annotation.as_ref(), + )?; + if !Self::check_ty_analyzed( + field_ty, + unifier, + to_be_analyzed_class, + ) { + to_be_analyzed_class.push(DefinitionId(class_ind)); + continue 'class; + } else { + result.push(( + if let ast::ExprKind::Attribute { attr, .. } = + &target.node + { + attr.to_string() + } else { + unreachable!() + }, + field_ty, + )) + } + } + + // exclude those without type annotation + ast::StmtKind::Assign { targets, .. } + if { + if let ast::ExprKind::Attribute { value, .. } = + &targets[0].node + { + matches!( + &value.node, + ast::ExprKind::Name {id, ..} if id == "self") + } else { + false + } + } => + { + return Err("class fields type annotation needed".into()) + } + + // do nothing + _ => {} + } + } + Some(result) + } else { + None + }; + + // current method all type ok, put the current method into the list + if class_methods_parsing_result.iter().any(|(name, _, _)| name == method_name) { + return Err("duplicate method definition".into()); + } else { + class_methods_parsing_result.push(( + method_name.clone(), + unifier.add_ty(TypeEnum::TFunc( + FunSignature { + ret: ret_ty, + args: arg_name_tys + .into_iter() + .map(|(name, ty)| FuncArg { name, ty, default_value: None }) + .collect_vec(), + vars: method_type_var, + } + .into(), + )), + *self + .class_method_to_def_id + .get(&Self::name_mangling(class_name.clone(), method_name)) + .unwrap(), + )) + } + + // put the fiedlds inside + if let Some(class_field_name_tys) = class_field_name_tys { + assert!(class_fields_parsing_result.is_empty()); + class_fields_parsing_result.extend(class_field_name_tys); + } + } else { + // what should we do with `class A: a = 3`? + // do nothing, continue the for loop to iterate class ast + continue; + } + } + + // now it should be confirmed that every + // methods and fields of the class can be correctly typed, put the results + // into the actual class def method and fields field + let (class_def, _) = &def_ast_list[class_ind]; + let mut class_def = class_def.write(); + if let TopLevelDef::Class { fields, methods, .. } = class_def.deref_mut() { + for (ref n, ref t) in class_fields_parsing_result { + fields.push((n.clone(), *t)); + } + for (n, t, id) in &class_methods_parsing_result { + methods.push((n.clone(), *t, *id)); + } + } else { + unreachable!() + } + + // change the signature field of the class methods + for (_, ty, id) in &class_methods_parsing_result { + let (method_def, _) = &def_ast_list[id.0]; + let mut method_def = method_def.write(); + if let TopLevelDef::Function { signature, .. } = method_def.deref_mut() { + *signature = *ty; + } + } + } + Ok(()) + } + + fn analyze_top_level_function(&mut self) -> Result<(), String> { + unimplemented!() + } + + fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> { + unimplemented!() + } + + fn check_ty_analyzed(ty: Type, unifier: &mut Unifier, to_be_analyzed: &[DefinitionId]) -> bool { + let type_enum = unifier.get_ty(ty); + match type_enum.as_ref() { + TypeEnum::TObj { obj_id, .. } => !to_be_analyzed.contains(obj_id), + TypeEnum::TVirtual { ty } => { + if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(*ty).as_ref() { + !to_be_analyzed.contains(obj_id) + } else { + unreachable!() + } + } + TypeEnum::TVar { .. 
} => true, + _ => unreachable!(), + } + } +} diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs new file mode 100644 index 00000000..92509ccc --- /dev/null +++ b/nac3core/src/typecheck/function_check.rs @@ -0,0 +1,216 @@ +use super::type_inferencer::Inferencer; +use super::typedef::Type; +use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind}; +use std::iter::once; + +impl<'a> Inferencer<'a> { + fn check_pattern( + &mut self, + pattern: &Expr>, + defined_identifiers: &mut Vec, + ) -> Result<(), String> { + match &pattern.node { + ExprKind::Name { id, .. } => { + if !defined_identifiers.contains(id) { + defined_identifiers.push(id.clone()); + } + Ok(()) + } + ExprKind::Tuple { elts, .. } => { + for elt in elts.iter() { + self.check_pattern(elt, defined_identifiers)?; + } + Ok(()) + } + _ => self.check_expr(pattern, defined_identifiers), + } + } + + fn check_expr( + &mut self, + expr: &Expr>, + defined_identifiers: &[String], + ) -> Result<(), String> { + // there are some cases where the custom field is None + if let Some(ty) = &expr.custom { + if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) { + return Err(format!( + "expected concrete type at {} but got {}", + expr.location, + self.unifier.get_ty(*ty).get_type_name() + )); + } + } + match &expr.node { + ExprKind::Name { id, .. } => { + if !defined_identifiers.contains(id) { + return Err(format!( + "unknown identifier {} (use before def?) at {}", + id, expr.location + )); + } + } + ExprKind::List { elts, .. } + | ExprKind::Tuple { elts, .. } + | ExprKind::BoolOp { values: elts, .. } => { + for elt in elts.iter() { + self.check_expr(elt, defined_identifiers)?; + } + } + ExprKind::Attribute { value, .. } => { + self.check_expr(value, defined_identifiers)?; + } + ExprKind::BinOp { left, right, .. } => { + self.check_expr(left, defined_identifiers)?; + self.check_expr(right, defined_identifiers)?; + } + ExprKind::UnaryOp { operand, .. } => { + self.check_expr(operand, defined_identifiers)?; + } + ExprKind::Compare { left, comparators, .. } => { + for elt in once(left.as_ref()).chain(comparators.iter()) { + self.check_expr(elt, defined_identifiers)?; + } + } + ExprKind::Subscript { value, slice, .. } => { + self.check_expr(value, defined_identifiers)?; + self.check_expr(slice, defined_identifiers)?; + } + ExprKind::IfExp { test, body, orelse } => { + self.check_expr(test, defined_identifiers)?; + self.check_expr(body, defined_identifiers)?; + self.check_expr(orelse, defined_identifiers)?; + } + ExprKind::Slice { lower, upper, step } => { + for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { + self.check_expr(elt, defined_identifiers)?; + } + } + ExprKind::Lambda { args, body } => { + let mut defined_identifiers = defined_identifiers.to_vec(); + for arg in args.args.iter() { + if !defined_identifiers.contains(&arg.node.arg) { + defined_identifiers.push(arg.node.arg.clone()); + } + } + self.check_expr(body, &defined_identifiers)?; + } + ExprKind::ListComp { elt, generators, .. } => { + // in our type inference stage, we already make sure that there is only 1 generator + let ast::Comprehension { target, iter, ifs, .. 
} = &generators[0]; + self.check_expr(iter, defined_identifiers)?; + let mut defined_identifiers = defined_identifiers.to_vec(); + self.check_pattern(target, &mut defined_identifiers)?; + for term in once(elt.as_ref()).chain(ifs.iter()) { + self.check_expr(term, &defined_identifiers)?; + } + } + ExprKind::Call { func, args, keywords } => { + for expr in once(func.as_ref()) + .chain(args.iter()) + .chain(keywords.iter().map(|v| v.node.value.as_ref())) + { + self.check_expr(expr, defined_identifiers)?; + } + } + ExprKind::Constant { .. } => {} + _ => { + println!("{:?}", expr.node); + unimplemented!() + } + } + Ok(()) + } + + // check statements for proper identifier def-use and return on all paths + fn check_stmt( + &mut self, + stmt: &Stmt>, + defined_identifiers: &mut Vec, + ) -> Result { + match &stmt.node { + StmtKind::For { target, iter, body, orelse, .. } => { + self.check_expr(iter, defined_identifiers)?; + for stmt in orelse.iter() { + self.check_stmt(stmt, defined_identifiers)?; + } + let mut defined_identifiers = defined_identifiers.clone(); + self.check_pattern(target, &mut defined_identifiers)?; + for stmt in body.iter() { + self.check_stmt(stmt, &mut defined_identifiers)?; + } + Ok(false) + } + StmtKind::If { test, body, orelse } => { + self.check_expr(test, defined_identifiers)?; + let mut body_identifiers = defined_identifiers.clone(); + let mut orelse_identifiers = defined_identifiers.clone(); + let body_returned = self.check_block(body, &mut body_identifiers)?; + let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?; + + for ident in body_identifiers.iter() { + if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) { + defined_identifiers.push(ident.clone()) + } + } + Ok(body_returned && orelse_returned) + } + StmtKind::While { test, body, orelse } => { + self.check_expr(test, defined_identifiers)?; + let mut defined_identifiers = defined_identifiers.clone(); + self.check_block(body, &mut defined_identifiers)?; + self.check_block(orelse, &mut defined_identifiers)?; + Ok(false) + } + StmtKind::Expr { value } => { + self.check_expr(value, defined_identifiers)?; + Ok(false) + } + StmtKind::Assign { targets, value, .. } => { + self.check_expr(value, defined_identifiers)?; + for target in targets { + self.check_pattern(target, defined_identifiers)?; + } + Ok(false) + } + StmtKind::AnnAssign { target, value, .. } => { + if let Some(value) = value { + self.check_expr(value, defined_identifiers)?; + self.check_pattern(target, defined_identifiers)?; + } + Ok(false) + } + StmtKind::Return { value } => { + if let Some(value) = value { + self.check_expr(value, defined_identifiers)?; + } + Ok(true) + } + StmtKind::Raise { exc, .. } => { + if let Some(value) = exc { + self.check_expr(value, defined_identifiers)?; + } + Ok(true) + } + // break, raise, etc. + _ => Ok(false), + } + } + + pub fn check_block( + &mut self, + block: &[Stmt>], + defined_identifiers: &mut Vec, + ) -> Result { + let mut ret = false; + for stmt in block { + if ret { + return Err(format!("dead code at {:?}", stmt.location)); + } + if self.check_stmt(stmt, defined_identifiers)? 
{ + ret = true; + } + } + Ok(ret) + } +} diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs new file mode 100644 index 00000000..30e3c753 --- /dev/null +++ b/nac3core/src/typecheck/magic_methods.rs @@ -0,0 +1,322 @@ +use crate::typecheck::{ + type_inferencer::*, + typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, +}; +use rustpython_parser::ast; +use rustpython_parser::ast::{Cmpop, Operator, Unaryop}; +use std::borrow::Borrow; +use std::collections::HashMap; + +pub fn binop_name(op: &Operator) -> &'static str { + match op { + Operator::Add => "__add__", + Operator::Sub => "__sub__", + Operator::Div => "__truediv__", + Operator::Mod => "__mod__", + Operator::Mult => "__mul__", + Operator::Pow => "__pow__", + Operator::BitOr => "__or__", + Operator::BitXor => "__xor__", + Operator::BitAnd => "__and__", + Operator::LShift => "__lshift__", + Operator::RShift => "__rshift__", + Operator::FloorDiv => "__floordiv__", + Operator::MatMult => "__matmul__", + } +} + +pub fn binop_assign_name(op: &Operator) -> &'static str { + match op { + Operator::Add => "__iadd__", + Operator::Sub => "__isub__", + Operator::Div => "__itruediv__", + Operator::Mod => "__imod__", + Operator::Mult => "__imul__", + Operator::Pow => "__ipow__", + Operator::BitOr => "__ior__", + Operator::BitXor => "__ixor__", + Operator::BitAnd => "__iand__", + Operator::LShift => "__ilshift__", + Operator::RShift => "__irshift__", + Operator::FloorDiv => "__ifloordiv__", + Operator::MatMult => "__imatmul__", + } +} + +pub fn unaryop_name(op: &Unaryop) -> &'static str { + match op { + Unaryop::UAdd => "__pos__", + Unaryop::USub => "__neg__", + Unaryop::Not => "__not__", + Unaryop::Invert => "__inv__", + } +} + +pub fn comparison_name(op: &Cmpop) -> Option<&'static str> { + match op { + Cmpop::Lt => Some("__lt__"), + Cmpop::LtE => Some("__le__"), + Cmpop::Gt => Some("__gt__"), + Cmpop::GtE => Some("__ge__"), + Cmpop::Eq => Some("__eq__"), + Cmpop::NotEq => Some("__ne__"), + _ => None, + } +} + +pub fn impl_binop( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Type, + ops: &[ast::Operator], +) { + if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() { + let (other_ty, other_var_id) = if other_ty.len() == 1 { + (other_ty[0], None) + } else { + let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty); + (ty, Some(var_id)) + }; + let function_vars = if let Some(var_id) = other_var_id { + vec![(var_id, other_ty)].into_iter().collect::>() + } else { + HashMap::new() + }; + for op in ops { + fields.borrow_mut().insert(binop_name(op).into(), { + unifier.add_ty(TypeEnum::TFunc( + FunSignature { + ret: ret_ty, + vars: function_vars.clone(), + args: vec![FuncArg { + ty: other_ty, + default_value: None, + name: "other".into(), + }], + } + .into(), + )) + }); + + fields.borrow_mut().insert(binop_assign_name(op).into(), { + unifier.add_ty(TypeEnum::TFunc( + FunSignature { + ret: store.none, + vars: function_vars.clone(), + args: vec![FuncArg { + ty: other_ty, + default_value: None, + name: "other".into(), + }], + } + .into(), + )) + }); + } + } else { + unreachable!("") + } +} + +pub fn impl_unaryop( + unifier: &mut Unifier, + _store: &PrimitiveStore, + ty: Type, + ret_ty: Type, + ops: &[ast::Unaryop], +) { + if let TypeEnum::TObj { fields, .. 
} = unifier.get_ty(ty).borrow() { + for op in ops { + fields.borrow_mut().insert( + unaryop_name(op).into(), + unifier.add_ty(TypeEnum::TFunc( + FunSignature { ret: ret_ty, vars: HashMap::new(), args: vec![] }.into(), + )), + ); + } + } else { + unreachable!() + } +} + +pub fn impl_cmpop( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: Type, + ops: &[ast::Cmpop], +) { + if let TypeEnum::TObj { fields, .. } = unifier.get_ty(ty).borrow() { + for op in ops { + fields.borrow_mut().insert( + comparison_name(op).unwrap().into(), + unifier.add_ty(TypeEnum::TFunc( + FunSignature { + ret: store.bool, + vars: HashMap::new(), + args: vec![FuncArg { + ty: other_ty, + default_value: None, + name: "other".into(), + }], + } + .into(), + )), + ); + } + } else { + unreachable!() + } +} + +/// Add, Sub, Mult, Pow +pub fn impl_basic_arithmetic( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Type, +) { + impl_binop( + unifier, + store, + ty, + other_ty, + ret_ty, + &[ast::Operator::Add, ast::Operator::Sub, ast::Operator::Mult], + ) +} + +pub fn impl_pow( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Type, +) { + impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Pow]) +} + +/// BitOr, BitXor, BitAnd +pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { + impl_binop( + unifier, + store, + ty, + &[ty], + ty, + &[ast::Operator::BitAnd, ast::Operator::BitOr, ast::Operator::BitXor], + ) +} + +/// LShift, RShift +pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { + impl_binop(unifier, store, ty, &[ty], ty, &[ast::Operator::LShift, ast::Operator::RShift]) +} + +/// Div +pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) { + impl_binop(unifier, store, ty, other_ty, store.float, &[ast::Operator::Div]) +} + +/// FloorDiv +pub fn impl_floordiv( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Type, +) { + impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::FloorDiv]) +} + +/// Mod +pub fn impl_mod( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Type, +) { + impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Mod]) +} + +/// UAdd, USub +pub fn impl_sign(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { + impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::UAdd, ast::Unaryop::USub]) +} + +/// Invert +pub fn impl_invert(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { + impl_unaryop(unifier, store, ty, ty, &[ast::Unaryop::Invert]) +} + +/// Not +pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { + impl_unaryop(unifier, store, ty, store.bool, &[ast::Unaryop::Not]) +} + +/// Lt, LtE, Gt, GtE +pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { + impl_cmpop( + unifier, + store, + ty, + other_ty, + &[ast::Cmpop::Lt, ast::Cmpop::Gt, ast::Cmpop::LtE, ast::Cmpop::GtE], + ) +} + +/// Eq, NotEq +pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { + impl_cmpop(unifier, store, ty, ty, &[ast::Cmpop::Eq, ast::Cmpop::NotEq]) +} + +pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { + let PrimitiveStore { int32: int32_t, int64: int64_t, float: float_t, bool: bool_t, .. 
} = + *store; + /* int32 ======== */ + impl_basic_arithmetic(unifier, store, int32_t, &[int32_t], int32_t); + impl_pow(unifier, store, int32_t, &[int32_t], int32_t); + impl_bitwise_arithmetic(unifier, store, int32_t); + impl_bitwise_shift(unifier, store, int32_t); + impl_div(unifier, store, int32_t, &[int32_t]); + impl_floordiv(unifier, store, int32_t, &[int32_t], int32_t); + impl_mod(unifier, store, int32_t, &[int32_t], int32_t); + impl_sign(unifier, store, int32_t); + impl_invert(unifier, store, int32_t); + impl_not(unifier, store, int32_t); + impl_comparison(unifier, store, int32_t, int32_t); + impl_eq(unifier, store, int32_t); + + /* int64 ======== */ + impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t); + impl_pow(unifier, store, int64_t, &[int64_t], int64_t); + impl_bitwise_arithmetic(unifier, store, int64_t); + impl_bitwise_shift(unifier, store, int64_t); + impl_div(unifier, store, int64_t, &[int64_t]); + impl_floordiv(unifier, store, int64_t, &[int64_t], int64_t); + impl_mod(unifier, store, int64_t, &[int64_t], int64_t); + impl_sign(unifier, store, int64_t); + impl_invert(unifier, store, int64_t); + impl_not(unifier, store, int64_t); + impl_comparison(unifier, store, int64_t, int64_t); + impl_eq(unifier, store, int64_t); + + /* float ======== */ + impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t); + impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t); + impl_div(unifier, store, float_t, &[float_t]); + impl_floordiv(unifier, store, float_t, &[float_t], float_t); + impl_mod(unifier, store, float_t, &[float_t], float_t); + impl_sign(unifier, store, float_t); + impl_not(unifier, store, float_t); + impl_comparison(unifier, store, float_t, float_t); + impl_eq(unifier, store, float_t); + + /* bool ======== */ + impl_not(unifier, store, bool_t); + impl_eq(unifier, store, bool_t); +} diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs new file mode 100644 index 00000000..db7bcaec --- /dev/null +++ b/nac3core/src/typecheck/mod.rs @@ -0,0 +1,5 @@ +mod function_check; +pub mod magic_methods; +pub mod type_inferencer; +pub mod typedef; +mod unification_table; diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs new file mode 100644 index 00000000..72c47f60 --- /dev/null +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -0,0 +1,582 @@ +use std::collections::HashMap; +use std::convert::{From, TryInto}; +use std::iter::once; +use std::{cell::RefCell, sync::Arc}; + +use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}; +use super::{magic_methods::*, typedef::CallId}; +use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext}; +use itertools::izip; +use rustpython_parser::ast::{ + self, + fold::{self, Fold}, + Arguments, Comprehension, ExprKind, Located, Location, +}; + +#[cfg(test)] +mod test; + +#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)] +pub struct CodeLocation { + row: usize, + col: usize, +} + +impl From for CodeLocation { + fn from(loc: Location) -> CodeLocation { + CodeLocation { row: loc.row(), col: loc.column() } + } +} + +#[derive(Clone, Copy)] +pub struct PrimitiveStore { + pub int32: Type, + pub int64: Type, + pub float: Type, + pub bool: Type, + pub none: Type, +} + +pub struct FunctionData { + pub resolver: Arc, + pub return_type: Option, + pub bound_variables: Vec, +} + +pub struct Inferencer<'a> { + pub top_level: &'a TopLevelContext, + pub function_data: &'a mut FunctionData, + pub unifier: &'a mut Unifier, + 
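+    // The fields below are shared analysis state (a descriptive note, not new behaviour):
+    // `primitives` caches the builtin types, `variable_mapping` records the inferred type
+    // of each identifier, `virtual_checks` collects the (argument type, virtual type)
+    // pairs produced by `virtual(...)` for later verification, and `calls` maps source
+    // locations to the call sites recorded during inference.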
pub primitives: &'a PrimitiveStore, + pub virtual_checks: &'a mut Vec<(Type, Type)>, + pub variable_mapping: HashMap, + pub calls: &'a mut HashMap, +} + +struct NaiveFolder(); +impl fold::Fold<()> for NaiveFolder { + type TargetU = Option; + type Error = String; + fn map_user(&mut self, _: ()) -> Result { + Ok(None) + } +} + +impl<'a> fold::Fold<()> for Inferencer<'a> { + type TargetU = Option; + type Error = String; + + fn map_user(&mut self, _: ()) -> Result { + Ok(None) + } + + fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result, Self::Error> { + let stmt = match node.node { + // we don't want fold over type annotation + ast::StmtKind::AnnAssign { target, annotation, value, simple } => { + let target = Box::new(self.fold_expr(*target)?); + let value = if let Some(v) = value { + let ty = Box::new(self.fold_expr(*v)?); + self.unifier.unify(target.custom.unwrap(), ty.custom.unwrap())?; + Some(ty) + } else { + None + }; + let annotation_type = self.function_data.resolver.parse_type_annotation( + self.top_level, + self.unifier, + &self.primitives, + annotation.as_ref(), + )?; + self.unifier.unify(annotation_type, target.custom.unwrap())?; + let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?); + Located { + location: node.location, + custom: None, + node: ast::StmtKind::AnnAssign { target, annotation, value, simple }, + } + } + _ => fold::fold_stmt(self, node)?, + }; + match &stmt.node { + ast::StmtKind::For { target, iter, .. } => { + let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); + self.unifier.unify(list, iter.custom.unwrap())?; + } + ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { + self.unifier.unify(test.custom.unwrap(), self.primitives.bool)?; + } + ast::StmtKind::Assign { targets, value, .. } => { + for target in targets.iter() { + self.unifier.unify(target.custom.unwrap(), value.custom.unwrap())?; + } + } + ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {} + ast::StmtKind::Break | ast::StmtKind::Continue => {} + ast::StmtKind::Return { value } => match (value, self.function_data.return_type) { + (Some(v), Some(v1)) => { + self.unifier.unify(v.custom.unwrap(), v1)?; + } + (Some(_), None) => { + return Err("Unexpected return value".to_string()); + } + (None, Some(_)) => { + return Err("Expected return value".to_string()); + } + (None, None) => {} + }, + _ => return Err("Unsupported statement type".to_string()), + }; + Ok(stmt) + } + + fn fold_expr(&mut self, node: ast::Expr<()>) -> Result, Self::Error> { + let expr = match node.node { + ast::ExprKind::Call { func, args, keywords } => { + return self.fold_call(node.location, *func, args, keywords); + } + ast::ExprKind::Lambda { args, body } => { + return self.fold_lambda(node.location, *args, *body); + } + ast::ExprKind::ListComp { elt, generators } => { + return self.fold_listcomp(node.location, *elt, generators); + } + _ => fold::fold_expr(self, node)?, + }; + let custom = match &expr.node { + ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?), + ast::ExprKind::Name { id, .. } => Some(self.infer_identifier(id)?), + ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), + ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), + ast::ExprKind::Attribute { value, attr, ctx: _ } => { + Some(self.infer_attribute(value, attr)?) + } + ast::ExprKind::BoolOp { values, .. 
} => Some(self.infer_bool_ops(values)?), + ast::ExprKind::BinOp { left, op, right } => Some(self.infer_bin_ops(left, op, right)?), + ast::ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?), + ast::ExprKind::Compare { left, ops, comparators } => { + Some(self.infer_compare(left, ops, comparators)?) + } + ast::ExprKind::Subscript { value, slice, .. } => { + Some(self.infer_subscript(value.as_ref(), slice.as_ref())?) + } + ast::ExprKind::IfExp { test, body, orelse } => { + Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?) + } + ast::ExprKind::ListComp { .. } + | ast::ExprKind::Lambda { .. } + | ast::ExprKind::Call { .. } => expr.custom, // already computed + ast::ExprKind::Slice { .. } => None, // we don't need it for slice + _ => return Err("not supported yet".into()), + }; + Ok(ast::Expr { custom, location: expr.location, node: expr.node }) + } +} + +type InferenceResult = Result; + +impl<'a> Inferencer<'a> { + /// Constrain a <: b + /// Currently implemented as unification + fn constrain(&mut self, a: Type, b: Type) -> Result<(), String> { + self.unifier.unify(a, b) + } + + fn build_method_call( + &mut self, + location: Location, + method: String, + obj: Type, + params: Vec, + ret: Type, + ) -> InferenceResult { + let call = self.unifier.add_call(Call { + posargs: params, + kwargs: HashMap::new(), + ret, + fun: RefCell::new(None), + }); + self.calls.insert(location.into(), call); + let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); + let fields = once((method, call)).collect(); + let record = self.unifier.add_record(fields); + self.constrain(obj, record)?; + Ok(ret) + } + + fn fold_lambda( + &mut self, + location: Location, + args: Arguments, + body: ast::Expr<()>, + ) -> Result>, String> { + if !args.posonlyargs.is_empty() + || args.vararg.is_some() + || !args.kwonlyargs.is_empty() + || args.kwarg.is_some() + || !args.defaults.is_empty() + { + // actually I'm not sure whether programs violating this is a valid python program. + return Err( + "We only support positional or keyword arguments without defaults for lambdas." 
+ .to_string(), + ); + } + + let fn_args: Vec<_> = args + .args + .iter() + .map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0)) + .collect(); + let mut variable_mapping = self.variable_mapping.clone(); + variable_mapping.extend(fn_args.iter().cloned()); + let ret = self.unifier.get_fresh_var().0; + let mut new_context = Inferencer { + function_data: self.function_data, + unifier: self.unifier, + primitives: self.primitives, + virtual_checks: self.virtual_checks, + calls: self.calls, + top_level: self.top_level, + variable_mapping, + }; + let fun = FunSignature { + args: fn_args + .iter() + .map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None }) + .collect(), + ret, + vars: Default::default(), + }; + let body = new_context.fold_expr(body)?; + new_context.unifier.unify(fun.ret, body.custom.unwrap())?; + let mut args = new_context.fold_arguments(args)?; + for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) { + assert_eq!(&arg.node.arg, name); + arg.custom = Some(*ty); + } + Ok(Located { + location, + node: ExprKind::Lambda { args: args.into(), body: body.into() }, + custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun.into()))), + }) + } + + fn fold_listcomp( + &mut self, + location: Location, + elt: ast::Expr<()>, + mut generators: Vec, + ) -> Result>, String> { + if generators.len() != 1 { + return Err( + "Only 1 generator statement for list comprehension is supported.".to_string() + ); + } + let variable_mapping = self.variable_mapping.clone(); + let mut new_context = Inferencer { + function_data: self.function_data, + unifier: self.unifier, + virtual_checks: self.virtual_checks, + top_level: self.top_level, + variable_mapping, + primitives: self.primitives, + calls: self.calls, + }; + let elt = new_context.fold_expr(elt)?; + let generator = generators.pop().unwrap(); + if generator.is_async { + return Err("Async iterator not supported.".to_string()); + } + let target = new_context.fold_expr(*generator.target)?; + let iter = new_context.fold_expr(*generator.iter)?; + let ifs: Vec<_> = generator + .ifs + .into_iter() + .map(|v| new_context.fold_expr(v)) + .collect::>()?; + + // iter should be a list of targets... + // actually it should be an iterator of targets, but we don't have iter type for now + let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); + new_context.unifier.unify(iter.custom.unwrap(), list)?; + // if conditions should be bool + for v in ifs.iter() { + new_context.unifier.unify(v.custom.unwrap(), new_context.primitives.bool)?; + } + + Ok(Located { + location, + custom: Some(new_context.unifier.add_ty(TypeEnum::TList { ty: elt.custom.unwrap() })), + node: ExprKind::ListComp { + elt: Box::new(elt), + generators: vec![ast::Comprehension { + target: Box::new(target), + iter: Box::new(iter), + ifs, + is_async: false, + }], + }, + }) + } + + fn fold_call( + &mut self, + location: Location, + func: ast::Expr<()>, + mut args: Vec>, + keywords: Vec>, + ) -> Result>, String> { + let func = + if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } = + func + { + // handle special functions that cannot be typed in the usual way... 
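+                // A rough sketch of the behaviour implemented below, as exercised by the
+                // unit tests later in this patch: `virtual(Bar(), Bar)` folds to an
+                // expression typed `virtual[Bar]` and pushes (typeof(Bar()), Bar) onto
+                // `virtual_checks`, while `int64(2147483648)` is rewritten into a single
+                // constant typed int64 so the literal is not range-checked as an int32.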
+ if id == "virtual" { + if args.is_empty() || args.len() > 2 || !keywords.is_empty() { + return Err( + "`virtual` can only accept 1/2 positional arguments.".to_string() + ); + } + let arg0 = self.fold_expr(args.remove(0))?; + let ty = if let Some(arg) = args.pop() { + self.function_data.resolver.parse_type_annotation( + self.top_level, + self.unifier, + self.primitives, + &arg, + )? + } else { + self.unifier.get_fresh_var().0 + }; + self.virtual_checks.push((arg0.custom.unwrap(), ty)); + let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); + return Ok(Located { + location, + custom, + node: ExprKind::Call { + func: Box::new(Located { + custom: None, + location: func.location, + node: ExprKind::Name { id, ctx }, + }), + args: vec![arg0], + keywords: vec![], + }, + }); + } + // int64 is special because its argument can be a constant larger than int32 + if id == "int64" && args.len() == 1 { + if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = + &args[0].node + { + let int64: Result = val.try_into(); + let custom; + if int64.is_ok() { + custom = Some(self.primitives.int64); + } else { + return Err("Integer out of bound".into()); + } + return Ok(Located { + location: args[0].location, + custom, + node: ExprKind::Constant { + value: ast::Constant::Int(val.clone()), + kind: kind.clone(), + }, + }); + } + } + Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } + } else { + func + }; + let func = Box::new(self.fold_expr(func)?); + let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + let keywords = keywords + .into_iter() + .map(|v| fold::fold_keyword(self, v)) + .collect::, _>>()?; + let ret = self.unifier.get_fresh_var().0; + let call = self.unifier.add_call(Call { + posargs: args.iter().map(|v| v.custom.unwrap()).collect(), + kwargs: keywords + .iter() + .map(|v| (v.node.arg.as_ref().unwrap().clone(), v.custom.unwrap())) + .collect(), + fun: RefCell::new(None), + ret, + }); + self.calls.insert(location.into(), call); + let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); + self.unifier.unify(func.custom.unwrap(), call)?; + + Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } }) + } + + fn infer_identifier(&mut self, id: &str) -> InferenceResult { + if let Some(ty) = self.variable_mapping.get(id) { + Ok(*ty) + } else { + Ok(self + .function_data + .resolver + .get_symbol_type(self.unifier, self.primitives, id) + .unwrap_or_else(|| { + let ty = self.unifier.get_fresh_var().0; + self.variable_mapping.insert(id.to_string(), ty); + ty + })) + } + } + + fn infer_constant(&mut self, constant: &ast::Constant) -> InferenceResult { + match constant { + ast::Constant::Bool(_) => Ok(self.primitives.bool), + ast::Constant::Int(val) => { + let int32: Result = val.try_into(); + // int64 would be handled separately in functions + if int32.is_ok() { + Ok(self.primitives.int32) + } else { + Err("Integer out of bound".into()) + } + } + ast::Constant::Float(_) => Ok(self.primitives.float), + ast::Constant::Tuple(vals) => { + let ty: Result, _> = vals.iter().map(|x| self.infer_constant(x)).collect(); + Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? 
})) + } + _ => Err("not supported".into()), + } + } + + fn infer_list(&mut self, elts: &[ast::Expr>]) -> InferenceResult { + let (ty, _) = self.unifier.get_fresh_var(); + for t in elts.iter() { + self.unifier.unify(ty, t.custom.unwrap())?; + } + Ok(self.unifier.add_ty(TypeEnum::TList { ty })) + } + + fn infer_tuple(&mut self, elts: &[ast::Expr>]) -> InferenceResult { + let ty = elts.iter().map(|x| x.custom.unwrap()).collect(); + Ok(self.unifier.add_ty(TypeEnum::TTuple { ty })) + } + + fn infer_attribute(&mut self, value: &ast::Expr>, attr: &str) -> InferenceResult { + let (attr_ty, _) = self.unifier.get_fresh_var(); + let fields = once((attr.to_string(), attr_ty)).collect(); + let record = self.unifier.add_record(fields); + self.constrain(value.custom.unwrap(), record)?; + Ok(attr_ty) + } + + fn infer_bool_ops(&mut self, values: &[ast::Expr>]) -> InferenceResult { + let b = self.primitives.bool; + for v in values { + self.constrain(v.custom.unwrap(), b)?; + } + Ok(b) + } + + fn infer_bin_ops( + &mut self, + left: &ast::Expr>, + op: &ast::Operator, + right: &ast::Expr>, + ) -> InferenceResult { + let method = binop_name(op); + let ret = self.unifier.get_fresh_var().0; + self.build_method_call( + left.location, + method.to_string(), + left.custom.unwrap(), + vec![right.custom.unwrap()], + ret, + ) + } + + fn infer_unary_ops( + &mut self, + op: &ast::Unaryop, + operand: &ast::Expr>, + ) -> InferenceResult { + let method = unaryop_name(op); + let ret = self.unifier.get_fresh_var().0; + self.build_method_call( + operand.location, + method.to_string(), + operand.custom.unwrap(), + vec![], + ret, + ) + } + + fn infer_compare( + &mut self, + left: &ast::Expr>, + ops: &[ast::Cmpop], + comparators: &[ast::Expr>], + ) -> InferenceResult { + let boolean = self.primitives.bool; + for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { + let method = + comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string(); + self.build_method_call( + a.location, + method, + a.custom.unwrap(), + vec![b.custom.unwrap()], + boolean, + )?; + } + Ok(boolean) + } + + fn infer_subscript( + &mut self, + value: &ast::Expr>, + slice: &ast::Expr>, + ) -> InferenceResult { + let ty = self.unifier.get_fresh_var().0; + match &slice.node { + ast::ExprKind::Slice { lower, upper, step } => { + for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { + self.constrain(v.custom.unwrap(), self.primitives.int32)?; + } + let list = self.unifier.add_ty(TypeEnum::TList { ty }); + self.constrain(value.custom.unwrap(), list)?; + Ok(list) + } + ast::ExprKind::Constant { value: ast::Constant::Int(val), .. } => { + // the index is a constant, so value can be a sequence. 
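+                // For example, with `t = (1, True)`, the subscript `t[0]` constrains the
+                // type of `t` against the sequence record {0: v}; unifying that record
+                // with the tuple type then forces `v` to be int32. Negative constant
+                // indices are resolved against the tuple length when the sequence is
+                // unified with a TTuple in `Unifier::unify_impl`.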
+ let ind: i32 = val.try_into().map_err(|_| "Index must be int32".to_string())?; + let map = once((ind, ty)).collect(); + let seq = self.unifier.add_sequence(map); + self.constrain(value.custom.unwrap(), seq)?; + Ok(ty) + } + _ => { + // the index is not a constant, so value can only be a list + self.constrain(slice.custom.unwrap(), self.primitives.int32)?; + let list = self.unifier.add_ty(TypeEnum::TList { ty }); + self.constrain(value.custom.unwrap(), list)?; + Ok(ty) + } + } + } + + fn infer_if_expr( + &mut self, + test: &ast::Expr>, + body: &ast::Expr>, + orelse: &ast::Expr>, + ) -> InferenceResult { + self.constrain(test.custom.unwrap(), self.primitives.bool)?; + let ty = self.unifier.get_fresh_var().0; + self.constrain(body.custom.unwrap(), ty)?; + self.constrain(orelse.custom.unwrap(), ty)?; + Ok(ty) + } +} diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs new file mode 100644 index 00000000..6952ef1e --- /dev/null +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -0,0 +1,546 @@ +use super::super::typedef::*; +use super::*; +use crate::symbol_resolver::*; +use crate::top_level::DefinitionId; +use crate::{location::Location, top_level::TopLevelDef}; +use indoc::indoc; +use itertools::zip; +use parking_lot::RwLock; +use rustpython_parser::parser::parse_program; +use test_case::test_case; + +struct Resolver { + id_to_type: HashMap, + id_to_def: HashMap, + class_names: HashMap, +} + +impl SymbolResolver for Resolver { + fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { + self.id_to_type.get(str).cloned() + } + + fn get_symbol_value(&self, _: &str) -> Option { + unimplemented!() + } + + fn get_symbol_location(&self, _: &str) -> Option { + unimplemented!() + } + + fn get_identifier_def(&self, id: &str) -> Option { + self.id_to_def.get(id).cloned() + } +} + +struct TestEnvironment { + pub unifier: Unifier, + pub function_data: FunctionData, + pub primitives: PrimitiveStore, + pub id_to_name: HashMap, + pub identifier_mapping: HashMap, + pub virtual_checks: Vec<(Type, Type)>, + pub calls: HashMap, + pub top_level: TopLevelContext, +} + +impl TestEnvironment { + pub fn basic_test_env() -> TestEnvironment { + let mut unifier = Unifier::new(); + + let int32 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(0), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let int64 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(1), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let float = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(2), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let bool = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(3), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let none = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(4), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let primitives = PrimitiveStore { int32, int64, float, bool, none }; + set_primitives_magic_methods(&primitives, &mut unifier); + + let id_to_name = [ + (0, "int32".to_string()), + (1, "int64".to_string()), + (2, "float".to_string()), + (3, "bool".to_string()), + (4, "none".to_string()), + ] + .iter() + .cloned() + .collect(); + + let mut identifier_mapping = HashMap::new(); + identifier_mapping.insert("None".into(), none); + + let resolver = Arc::new(Resolver { + id_to_type: identifier_mapping.clone(), + id_to_def: Default::default(), + 
class_names: Default::default(), + }) as Arc; + + TestEnvironment { + top_level: TopLevelContext { + definitions: Default::default(), + unifiers: Default::default(), + }, + unifier, + function_data: FunctionData { + resolver, + bound_variables: Vec::new(), + return_type: None, + }, + primitives, + id_to_name, + identifier_mapping, + virtual_checks: Vec::new(), + calls: HashMap::new(), + } + } + + fn new() -> TestEnvironment { + let mut unifier = Unifier::new(); + let mut identifier_mapping = HashMap::new(); + let mut top_level_defs: Vec>> = Vec::new(); + let int32 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(0), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let int64 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(1), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let float = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(2), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let bool = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(3), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + let none = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(4), + fields: HashMap::new().into(), + params: HashMap::new().into(), + }); + identifier_mapping.insert("None".into(), none); + for i in 0..5 { + top_level_defs.push( + RwLock::new(TopLevelDef::Class { + object_id: DefinitionId(i), + type_vars: Default::default(), + fields: Default::default(), + methods: Default::default(), + ancestors: Default::default(), + resolver: None, + }) + .into(), + ); + } + + let primitives = PrimitiveStore { int32, int64, float, bool, none }; + + let (v0, id) = unifier.get_fresh_var(); + + let foo_ty = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(5), + fields: [("a".into(), v0)].iter().cloned().collect::>().into(), + params: [(id, v0)].iter().cloned().collect::>().into(), + }); + top_level_defs.push( + RwLock::new(TopLevelDef::Class { + object_id: DefinitionId(5), + type_vars: vec![v0], + fields: [("a".into(), v0)].into(), + methods: Default::default(), + ancestors: Default::default(), + resolver: None, + }) + .into(), + ); + + identifier_mapping.insert( + "Foo".into(), + unifier.add_ty(TypeEnum::TFunc( + FunSignature { + args: vec![], + ret: foo_ty, + vars: [(id, v0)].iter().cloned().collect(), + } + .into(), + )), + ); + + let fun = unifier.add_ty(TypeEnum::TFunc( + FunSignature { args: vec![], ret: int32, vars: Default::default() }.into(), + )); + let bar = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(6), + fields: [("a".into(), int32), ("b".into(), fun)] + .iter() + .cloned() + .collect::>() + .into(), + params: Default::default(), + }); + top_level_defs.push( + RwLock::new(TopLevelDef::Class { + object_id: DefinitionId(6), + type_vars: Default::default(), + fields: [("a".into(), int32), ("b".into(), fun)].into(), + methods: Default::default(), + ancestors: Default::default(), + resolver: None, + }) + .into(), + ); + identifier_mapping.insert( + "Bar".into(), + unifier.add_ty(TypeEnum::TFunc( + FunSignature { args: vec![], ret: bar, vars: Default::default() }.into(), + )), + ); + + let bar2 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(7), + fields: [("a".into(), bool), ("b".into(), fun)] + .iter() + .cloned() + .collect::>() + .into(), + params: Default::default(), + }); + top_level_defs.push( + RwLock::new(TopLevelDef::Class { + object_id: DefinitionId(7), + type_vars: Default::default(), + fields: [("a".into(), bool), ("b".into(), 
fun)].into(), + methods: Default::default(), + ancestors: Default::default(), + resolver: None, + }) + .into(), + ); + identifier_mapping.insert( + "Bar2".into(), + unifier.add_ty(TypeEnum::TFunc( + FunSignature { args: vec![], ret: bar2, vars: Default::default() }.into(), + )), + ); + let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect(); + + let id_to_name = [ + (0, "int32".to_string()), + (1, "int64".to_string()), + (2, "float".to_string()), + (3, "bool".to_string()), + (4, "none".to_string()), + (5, "Foo".to_string()), + (6, "Bar".to_string()), + (7, "Bar2".to_string()), + ] + .iter() + .cloned() + .collect(); + + let top_level = TopLevelContext { + definitions: Arc::new(RwLock::new(top_level_defs)), + unifiers: Default::default(), + }; + + let resolver = Arc::new(Resolver { + id_to_type: identifier_mapping.clone(), + id_to_def: [ + ("Foo".into(), DefinitionId(5)), + ("Bar".into(), DefinitionId(6)), + ("Bar2".into(), DefinitionId(7)), + ] + .iter() + .cloned() + .collect(), + class_names, + }) as Arc; + + TestEnvironment { + unifier, + top_level, + function_data: FunctionData { + resolver, + bound_variables: Vec::new(), + return_type: None, + }, + primitives, + id_to_name, + identifier_mapping, + virtual_checks: Vec::new(), + calls: HashMap::new(), + } + } + + fn get_inferencer(&mut self) -> Inferencer { + Inferencer { + top_level: &self.top_level, + function_data: &mut self.function_data, + unifier: &mut self.unifier, + variable_mapping: Default::default(), + primitives: &mut self.primitives, + virtual_checks: &mut self.virtual_checks, + calls: &mut self.calls, + } + } +} + +#[test_case(indoc! {" + a = 1234 + b = int64(2147483648) + c = 1.234 + d = True + "}, + [("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(), + &[] + ; "primitives test")] +#[test_case(indoc! {" + a = lambda x, y: x + b = lambda x: a(x, x) + c = 1.234 + d = b(c) + "}, + [("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(), + &[] + ; "lambda test")] +#[test_case(indoc! {" + a = lambda x: x + b = lambda x: x + + foo1 = Foo() + foo2 = Foo() + c = a(foo1.a) + d = b(foo2.a) + + a(True) + b(123) + + "}, + [("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"), + ("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(), + &[] + ; "obj test")] +#[test_case(indoc! {" + f = lambda x: True + a = [1, 2, 3] + b = [f(x) for x in a if f(x)] + "}, + [("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(), + &[] + ; "listcomp test")] +#[test_case(indoc! {" + a = virtual(Bar(), Bar) + b = a.b() + a = virtual(Bar2()) + "}, + [("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(), + &[("Bar", "Bar"), ("Bar2", "Bar")] + ; "virtual test")] +#[test_case(indoc! 
{" + a = [virtual(Bar(), Bar), virtual(Bar2())] + b = [x.b() for x in a] + "}, + [("a", "list[virtual[Bar]]"), ("b", "list[int32]")].iter().cloned().collect(), + &[("Bar", "Bar"), ("Bar2", "Bar")] + ; "virtual list test")] +fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) { + println!("source:\n{}", source); + let mut env = TestEnvironment::new(); + let id_to_name = std::mem::take(&mut env.id_to_name); + let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect(); + defined_identifiers.push("virtual".to_string()); + let mut inferencer = env.get_inferencer(); + let statements = parse_program(source).unwrap(); + let statements = statements + .into_iter() + .map(|v| inferencer.fold_stmt(v)) + .collect::, _>>() + .unwrap(); + + inferencer.check_block(&statements, &mut defined_identifiers).unwrap(); + + for (k, v) in inferencer.variable_mapping.iter() { + let name = inferencer.unifier.stringify( + *v, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + println!("{}: {}", k, name); + } + for (k, v) in mapping.iter() { + let ty = inferencer.variable_mapping.get(*k).unwrap(); + let name = inferencer.unifier.stringify( + *ty, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); + } + assert_eq!(inferencer.virtual_checks.len(), virtuals.len()); + for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) { + let a = inferencer.unifier.stringify( + *a, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + let b = inferencer.unifier.stringify( + *b, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + + assert_eq!(&a, x); + assert_eq!(&b, y); + } +} + +#[test_case(indoc! {" + a = 2 + b = 2 + c = a + b + d = a - b + e = a * b + f = a / b + g = a // b + h = a % b + "}, + [("a", "int32"), + ("b", "int32"), + ("c", "int32"), + ("d", "int32"), + ("e", "int32"), + ("f", "float"), + ("g", "int32"), + ("h", "int32")].iter().cloned().collect() + ; "int32")] +#[test_case( + indoc! {" + a = 2.4 + b = 3.6 + c = a + b + d = a - b + e = a * b + f = a / b + g = a // b + h = a % b + i = a ** b + ii = 3 + j = a ** b + "}, + [("a", "float"), + ("b", "float"), + ("c", "float"), + ("d", "float"), + ("e", "float"), + ("f", "float"), + ("g", "float"), + ("h", "float"), + ("i", "float"), + ("ii", "int32"), + ("j", "float")].iter().cloned().collect() + ; "float" +)] +#[test_case( + indoc! {" + a = int64(12312312312) + b = int64(24242424424) + c = a + b + d = a - b + e = a * b + f = a / b + g = a // b + h = a % b + i = a == b + j = a > b + k = a < b + l = a != b + "}, + [("a", "int64"), + ("b", "int64"), + ("c", "int64"), + ("d", "int64"), + ("e", "int64"), + ("f", "float"), + ("g", "int64"), + ("h", "int64"), + ("i", "bool"), + ("j", "bool"), + ("k", "bool"), + ("l", "bool")].iter().cloned().collect() + ; "int64" +)] +#[test_case( + indoc! 
{" + a = True + b = False + c = a == b + d = not a + e = a != b + "}, + [("a", "bool"), + ("b", "bool"), + ("c", "bool"), + ("d", "bool"), + ("e", "bool")].iter().cloned().collect() + ; "boolean" +)] +fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { + println!("source:\n{}", source); + let mut env = TestEnvironment::basic_test_env(); + let id_to_name = std::mem::take(&mut env.id_to_name); + let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect(); + defined_identifiers.push("virtual".to_string()); + let mut inferencer = env.get_inferencer(); + let statements = parse_program(source).unwrap(); + let statements = statements + .into_iter() + .map(|v| inferencer.fold_stmt(v)) + .collect::, _>>() + .unwrap(); + + inferencer.check_block(&statements, &mut defined_identifiers).unwrap(); + + for (k, v) in inferencer.variable_mapping.iter() { + let name = inferencer.unifier.stringify( + *v, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + println!("{}: {}", k, name); + } + for (k, v) in mapping.iter() { + let ty = inferencer.variable_mapping.get(*k).unwrap(); + let name = inferencer.unifier.stringify( + *ty, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); + } +} diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs new file mode 100644 index 00000000..621ea65b --- /dev/null +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -0,0 +1,947 @@ +use itertools::{chain, zip, Itertools}; +use std::borrow::Cow; +use std::cell::RefCell; +use std::collections::HashMap; +use std::iter::once; +use std::rc::Rc; +use std::sync::{Arc, Mutex}; + +use super::unification_table::{UnificationKey, UnificationTable}; +use crate::symbol_resolver::SymbolValue; +use crate::top_level::DefinitionId; + +#[cfg(test)] +mod test; + +/// Handle for a type, implementated as a key in the unification table. +pub type Type = UnificationKey; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct CallId(usize); + +pub type Mapping = HashMap; +type VarMap = Mapping; + +#[derive(Clone)] +pub struct Call { + pub posargs: Vec, + pub kwargs: HashMap, + pub ret: Type, + pub fun: RefCell>, +} + +#[derive(Clone)] +pub struct FuncArg { + pub name: String, + pub ty: Type, + pub default_value: Option, +} + +#[derive(Clone)] +pub struct FunSignature { + pub args: Vec, + pub ret: Type, + pub vars: VarMap, +} + +#[derive(Clone)] +pub enum TypeVarMeta { + Generic, + Sequence(RefCell>), + Record(RefCell>), +} + +#[derive(Clone)] +pub enum TypeEnum { + TRigidVar { + id: u32, + }, + TVar { + id: u32, + meta: TypeVarMeta, + // empty indicates no restriction + range: RefCell>, + }, + TTuple { + ty: Vec, + }, + TList { + ty: Type, + }, + TObj { + obj_id: DefinitionId, + fields: RefCell>, + params: RefCell, + }, + TVirtual { + ty: Type, + }, + TCall(RefCell>), + TFunc(RefCell), +} + +impl TypeEnum { + pub fn get_type_name(&self) -> &'static str { + match self { + TypeEnum::TRigidVar { .. } => "TRigidVar", + TypeEnum::TVar { .. } => "TVar", + TypeEnum::TTuple { .. } => "TTuple", + TypeEnum::TList { .. } => "TList", + TypeEnum::TObj { .. } => "TObj", + TypeEnum::TVirtual { .. } => "TVirtual", + TypeEnum::TCall { .. } => "TCall", + TypeEnum::TFunc { .. 
} => "TFunc", + } + } +} + +pub type SharedUnifier = Arc, u32, Vec)>>; + +pub struct Unifier { + unification_table: UnificationTable>, + calls: Vec>, + var_id: u32, +} + +impl Unifier { + /// Get an empty unifier + pub fn new() -> Unifier { + Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new() } + } + + /// Determine if the two types are the same + pub fn unioned(&mut self, a: Type, b: Type) -> bool { + self.unification_table.unioned(a, b) + } + + pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier { + let lock = unifier.lock().unwrap(); + Unifier { + unification_table: UnificationTable::from_send(&lock.0), + var_id: lock.1, + calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(), + } + } + + pub fn get_shared_unifier(&self) -> SharedUnifier { + Arc::new(Mutex::new(( + self.unification_table.get_send(), + self.var_id, + self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(), + ))) + } + + /// Register a type to the unifier. + /// Returns a key in the unification_table. + pub fn add_ty(&mut self, a: TypeEnum) -> Type { + self.unification_table.new_key(Rc::new(a)) + } + + pub fn add_record(&mut self, fields: Mapping) -> Type { + let id = self.var_id + 1; + self.var_id += 1; + self.add_ty(TypeEnum::TVar { + id, + range: vec![].into(), + meta: TypeVarMeta::Record(fields.into()), + }) + } + + pub fn add_call(&mut self, call: Call) -> CallId { + let id = CallId(self.calls.len()); + self.calls.push(Rc::new(call)); + id + } + + pub fn get_representative(&mut self, ty: Type) -> Type { + self.unification_table.get_representative(ty) + } + + pub fn add_sequence(&mut self, sequence: Mapping) -> Type { + let id = self.var_id + 1; + self.var_id += 1; + self.add_ty(TypeEnum::TVar { + id, + range: vec![].into(), + meta: TypeVarMeta::Sequence(sequence.into()), + }) + } + + /// Get the TypeEnum of a type. + pub fn get_ty(&mut self, a: Type) -> Rc { + self.unification_table.probe_value(a).clone() + } + + pub fn get_fresh_rigid_var(&mut self) -> (Type, u32) { + let id = self.var_id + 1; + self.var_id += 1; + (self.add_ty(TypeEnum::TRigidVar { id }), id) + } + + pub fn get_fresh_var(&mut self) -> (Type, u32) { + self.get_fresh_var_with_range(&[]) + } + + /// Get a fresh type variable. + pub fn get_fresh_var_with_range(&mut self, range: &[Type]) -> (Type, u32) { + let id = self.var_id + 1; + self.var_id += 1; + let range = range.to_vec().into(); + (self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id) + } + + /// Unification would not unify rigid variables with other types, but we want to do this for + /// function instantiations, so we make it explicit. + pub fn replace_rigid_var(&mut self, rigid: Type, b: Type) { + assert!(matches!(&*self.get_ty(rigid), TypeEnum::TRigidVar { .. })); + self.set_a_to_b(rigid, b); + } + + pub fn get_instantiations(&mut self, ty: Type) -> Option> { + match &*self.get_ty(ty) { + TypeEnum::TVar { range, .. 
} => { + let range = range.borrow(); + if range.is_empty() { + None + } else { + Some( + range + .iter() + .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) + .flatten() + .collect_vec(), + ) + } + } + TypeEnum::TList { ty } => self + .get_instantiations(*ty) + .map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()), + TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| { + ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec() + }), + TypeEnum::TTuple { ty } => { + let tuples = ty + .iter() + .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) + .multi_cartesian_product() + .collect_vec(); + if tuples.len() == 1 { + None + } else { + Some( + tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(), + ) + } + } + TypeEnum::TObj { params, .. } => { + let params = params.borrow(); + let (keys, params): (Vec<&u32>, Vec<&Type>) = params.iter().unzip(); + let params = params + .into_iter() + .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) + .multi_cartesian_product() + .collect_vec(); + if params.len() <= 1 { + None + } else { + Some( + params + .into_iter() + .map(|params| { + self.subst( + ty, + &zip(keys.iter().cloned().cloned(), params.iter().cloned()) + .collect(), + ) + .unwrap_or(ty) + }) + .collect(), + ) + } + } + _ => None, + } + } + + pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { + use TypeEnum::*; + match &*self.get_ty(a) { + TRigidVar { .. } => true, + TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), + TCall { .. } => false, + TList { ty } => self.is_concrete(*ty, allowed_typevars), + TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), + TObj { params: vars, .. } => { + vars.borrow().values().all(|ty| self.is_concrete(*ty, allowed_typevars)) + } + // functions are instantiated for each call sites, so the function type can contain + // type variables. + TFunc { .. } => true, + TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), + } + } + + pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> { + if self.unification_table.unioned(a, b) { + Ok(()) + } else { + self.unify_impl(a, b, false) + } + } + + fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> { + use TypeEnum::*; + use TypeVarMeta::*; + let (ty_a, ty_b) = { + ( + self.unification_table.probe_value(a).clone(), + self.unification_table.probe_value(b).clone(), + ) + }; + match (&*ty_a, &*ty_b) { + (TVar { meta: meta1, range: range1, .. }, TVar { meta: meta2, range: range2, .. 
}) => { + self.occur_check(a, b)?; + self.occur_check(b, a)?; + match (meta1, meta2) { + (Generic, _) => {} + (_, Generic) => { + return self.unify_impl(b, a, true); + } + (Record(fields1), Record(fields2)) => { + let mut fields2 = fields2.borrow_mut(); + for (key, value) in fields1.borrow().iter() { + if let Some(ty) = fields2.get(key) { + self.unify(*ty, *value)?; + } else { + fields2.insert(key.clone(), *value); + } + } + } + (Sequence(map1), Sequence(map2)) => { + let mut map2 = map2.borrow_mut(); + for (key, value) in map1.borrow().iter() { + if let Some(ty) = map2.get(key) { + self.unify(*ty, *value)?; + } else { + map2.insert(*key, *value); + } + } + } + _ => { + return Err("Incompatible".to_string()); + } + } + let range1 = range1.borrow(); + // new range is the intersection of them + // empty range indicates no constraint + if !range1.is_empty() { + let old_range2 = range2.take(); + let mut range2 = range2.borrow_mut(); + if old_range2.is_empty() { + range2.extend_from_slice(&range1); + } + for v1 in old_range2.iter() { + for v2 in range1.iter() { + if let Ok(result) = self.get_intersection(*v1, *v2) { + range2.push(result.unwrap_or(*v2)); + } + } + } + if range2.is_empty() { + return Err( + "cannot unify type variables with incompatible value range".to_string() + ); + } + } + self.set_a_to_b(a, b); + } + (TVar { meta: Generic, id, range, .. }, _) => { + self.occur_check(a, b)?; + // We check for the range of the type variable to see if unification is allowed. + // Note that although b may be compatible with a, we may have to constrain type + // variables in b to make sure that instantiations of b would always be compatible + // with a. + // The return value x of check_var_compatibility would be a new type that is + // guaranteed to be compatible with a under all possible instantiations. So we + // unify x with b to recursively apply the constrains, and then set a to x. + let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + self.unify(x, b)?; + self.set_a_to_b(a, x); + } + (TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => { + self.occur_check(a, b)?; + let len = ty.len() as i32; + for (k, v) in map.borrow().iter() { + // handle negative index + let ind = if *k < 0 { len + *k } else { *k }; + if ind >= len || ind < 0 { + return Err(format!( + "Tuple index out of range. (Length: {}, Index: {})", + len, k + )); + } + self.unify(*v, ty[ind as usize])?; + } + let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + self.unify(x, b)?; + self.set_a_to_b(a, x); + } + (TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => { + self.occur_check(a, b)?; + for v in map.borrow().values() { + self.unify(*v, *ty)?; + } + let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + self.unify(x, b)?; + self.set_a_to_b(a, x); + } + (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { + if ty1.len() != ty2.len() { + return Err(format!( + "Cannot unify tuples with length {} and {}", + ty1.len(), + ty2.len() + )); + } + for (x, y) in ty1.iter().zip(ty2.iter()) { + self.unify(*x, *y)?; + } + self.set_a_to_b(a, b); + } + (TList { ty: ty1 }, TList { ty: ty2 }) => { + self.unify(*ty1, *ty2)?; + self.set_a_to_b(a, b); + } + (TVar { meta: Record(map), id, range, .. }, TObj { fields, .. 
}) => { + self.occur_check(a, b)?; + for (k, v) in map.borrow().iter() { + let ty = fields + .borrow() + .get(k) + .copied() + .ok_or_else(|| format!("No such attribute {}", k))?; + self.unify(ty, *v)?; + } + let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + self.unify(x, b)?; + self.set_a_to_b(a, x); + } + (TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => { + self.occur_check(a, b)?; + let ty = self.get_ty(*ty); + if let TObj { fields, .. } = ty.as_ref() { + for (k, v) in map.borrow().iter() { + let ty = fields + .borrow() + .get(k) + .copied() + .ok_or_else(|| format!("No such attribute {}", k))?; + if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) { + return Err(format!("Cannot access field {} for virtual type", k)); + } + self.unify(*v, ty)?; + } + } else { + // require annotation... + return Err("Requires type annotation for virtual".to_string()); + } + let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + self.unify(x, b)?; + self.set_a_to_b(a, x); + } + ( + TObj { obj_id: id1, params: params1, .. }, + TObj { obj_id: id2, params: params2, .. }, + ) => { + if id1 != id2 { + return Err(format!("Cannot unify objects with ID {} and {}", id1.0, id2.0)); + } + for (x, y) in zip(params1.borrow().values(), params2.borrow().values()) { + self.unify(*x, *y)?; + } + self.set_a_to_b(a, b); + } + (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { + self.unify(*ty1, *ty2)?; + self.set_a_to_b(a, b); + } + (TCall(calls1), TCall(calls2)) => { + // we do not unify individual calls, instead we defer until the unification wtih a + // function definition. + calls2.borrow_mut().extend_from_slice(&calls1.borrow()); + } + (TCall(calls), TFunc(signature)) => { + self.occur_check(a, b)?; + let required: Vec = signature + .borrow() + .args + .iter() + .filter(|v| v.default_value.is_none()) + .map(|v| v.name.clone()) + .rev() + .collect(); + // we unify every calls to the function signature. + for c in calls.borrow().iter() { + let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone(); + let instantiated = self.instantiate_fun(b, &*signature.borrow()); + let r = self.get_ty(instantiated); + let r = r.as_ref(); + let signature; + if let TypeEnum::TFunc(s) = &*r { + signature = s; + } else { + unreachable!(); + } + // we check to make sure that all required arguments (those without default + // arguments) are provided, and do not provide the same argument twice. 
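+                    // For instance, against a signature printed as fn[[a=int32, b=bool], none]
+                    // (names here are purely illustrative), a call f(1, b=True) consumes
+                    // `a` positionally and `b` by keyword, whereas f(1) is rejected with
+                    // "Expected more arguments" and f(1, 2, 3) with "Too many arguments.".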
+ let mut required = required.clone(); + let mut all_names: Vec<_> = signature + .borrow() + .args + .iter() + .map(|v| (v.name.clone(), v.ty)) + .rev() + .collect(); + for (i, t) in posargs.iter().enumerate() { + if signature.borrow().args.len() <= i { + return Err("Too many arguments.".to_string()); + } + if !required.is_empty() { + required.pop(); + } + self.unify(all_names.pop().unwrap().1, *t)?; + } + for (k, t) in kwargs.iter() { + if let Some(i) = required.iter().position(|v| v == k) { + required.remove(i); + } + let i = all_names + .iter() + .position(|v| &v.0 == k) + .ok_or_else(|| format!("Unknown keyword argument {}", k))?; + self.unify(all_names.remove(i).1, *t)?; + } + if !required.is_empty() { + return Err("Expected more arguments".to_string()); + } + self.unify(*ret, signature.borrow().ret)?; + *fun.borrow_mut() = Some(instantiated); + } + self.set_a_to_b(a, b); + } + (TFunc(sign1), TFunc(sign2)) => { + let (sign1, sign2) = (&*sign1.borrow(), &*sign2.borrow()); + if !sign1.vars.is_empty() || !sign2.vars.is_empty() { + return Err("Polymorphic function pointer is prohibited.".to_string()); + } + if sign1.args.len() != sign2.args.len() { + return Err("Functions differ in number of parameters.".to_string()); + } + for (x, y) in sign1.args.iter().zip(sign2.args.iter()) { + if x.name != y.name { + return Err("Functions differ in parameter names.".to_string()); + } + if x.default_value != y.default_value { + return Err("Functions differ in optional parameters.".to_string()); + } + self.unify(x.ty, y.ty)?; + } + self.unify(sign1.ret, sign2.ret)?; + self.set_a_to_b(a, b); + } + _ => { + if swapped { + return self.incompatible_types(&*ty_a, &*ty_b); + } else { + self.unify_impl(b, a, true)?; + } + } + } + Ok(()) + } + + /// Get string representation of the type + pub fn stringify(&mut self, ty: Type, obj_to_name: &mut F, var_to_name: &mut G) -> String + where + F: FnMut(usize) -> String, + G: FnMut(u32) -> String, + { + use TypeVarMeta::*; + let ty = self.unification_table.probe_value(ty).clone(); + match ty.as_ref() { + TypeEnum::TRigidVar { id } => var_to_name(*id), + TypeEnum::TVar { id, meta: Generic, .. } => var_to_name(*id), + TypeEnum::TVar { meta: Sequence(map), .. } => { + let fields = map + .borrow() + .iter() + .map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name))) + .join(", "); + format!("seq[{}]", fields) + } + TypeEnum::TVar { meta: Record(fields), .. } => { + let fields = fields + .borrow() + .iter() + .map(|(k, v)| format!("{}={}", k, self.stringify(*v, obj_to_name, var_to_name))) + .join(", "); + format!("record[{}]", fields) + } + TypeEnum::TTuple { ty } => { + let mut fields = ty.iter().map(|v| self.stringify(*v, obj_to_name, var_to_name)); + format!("tuple[{}]", fields.join(", ")) + } + TypeEnum::TList { ty } => { + format!("list[{}]", self.stringify(*ty, obj_to_name, var_to_name)) + } + TypeEnum::TVirtual { ty } => { + format!("virtual[{}]", self.stringify(*ty, obj_to_name, var_to_name)) + } + TypeEnum::TObj { obj_id, params, .. } => { + let name = obj_to_name(obj_id.0); + let params = params.borrow(); + if !params.is_empty() { + let mut params = + params.values().map(|v| self.stringify(*v, obj_to_name, var_to_name)); + format!("{}[{}]", name, params.join(", ")) + } else { + name + } + } + TypeEnum::TCall { .. 
} => "call".to_owned(), + TypeEnum::TFunc(signature) => { + let params = signature + .borrow() + .args + .iter() + .map(|arg| { + format!("{}={}", arg.name, self.stringify(arg.ty, obj_to_name, var_to_name)) + }) + .join(", "); + let ret = self.stringify(signature.borrow().ret, obj_to_name, var_to_name); + format!("fn[[{}], {}]", params, ret) + } + } + } + + fn set_a_to_b(&mut self, a: Type, b: Type) { + // unify a and b together, and set the value to b's value. + let table = &mut self.unification_table; + let ty_b = table.probe_value(b).clone(); + table.unify(a, b); + table.set_value(a, ty_b) + } + + fn incompatible_types(&self, a: &TypeEnum, b: &TypeEnum) -> Result<(), String> { + Err(format!("Cannot unify {} with {}", a.get_type_name(), b.get_type_name())) + } + + /// Instantiate a function if it hasn't been instantiated. + /// Returns Some(T) where T is the instantiated type. + /// Returns None if the function is already instantiated. + fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { + let mut instantiated = false; + let mut vars = Vec::new(); + for (k, v) in fun.vars.iter() { + if let TypeEnum::TVar { id, range, .. } = + self.unification_table.probe_value(*v).as_ref() + { + if k != id { + instantiated = true; + break; + } + // actually, if the first check succeeded, the function should be uninstatiated. + // The cloned values must be used and would not be wasted. + vars.push((*k, range.clone())); + } else { + instantiated = true; + break; + } + } + if instantiated { + ty + } else { + let mapping = vars + .into_iter() + .map(|(k, range)| (k, self.get_fresh_var_with_range(range.borrow().as_ref()).0)) + .collect(); + self.subst(ty, &mapping).unwrap_or(ty) + } + } + + /// Substitute type variables within a type into other types. + /// If this returns Some(T), T would be the substituted type. + /// If this returns None, the result type would be the original type + /// (no substitution has to be done). + pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option { + use TypeVarMeta::*; + let ty = self.unification_table.probe_value(a).clone(); + // this function would only be called when we instantiate functions. + // function type signature should ONLY contain concrete types and type + // variables, i.e. things like TRecord, TCall should not occur, and we + // should be safe to not implement the substitution for those variants. + match &*ty { + TypeEnum::TRigidVar { .. } => None, + TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(), + TypeEnum::TTuple { ty } => { + let mut new_ty = Cow::from(ty); + for (i, t) in ty.iter().enumerate() { + if let Some(t1) = self.subst(*t, mapping) { + new_ty.to_mut()[i] = t1; + } + } + if matches!(new_ty, Cow::Owned(_)) { + Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() })) + } else { + None + } + } + TypeEnum::TList { ty } => { + self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TList { ty: t })) + } + TypeEnum::TVirtual { ty } => { + self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })) + } + TypeEnum::TObj { obj_id, fields, params } => { + // Type variables in field types must be present in the type parameter. + // If the mapping does not contain any type variables in the + // parameter list, we don't need to substitute the fields. + // This is also used to prevent infinite substitution... + let params = params.borrow(); + let need_subst = params.values().any(|v| { + let ty = self.unification_table.probe_value(*v); + if let TypeEnum::TVar { id, .. 
} = ty.as_ref() { + mapping.contains_key(&id) + } else { + false + } + }); + if need_subst { + let obj_id = *obj_id; + let params = self.subst_map(¶ms, mapping).unwrap_or_else(|| params.clone()); + let fields = self + .subst_map(&fields.borrow(), mapping) + .unwrap_or_else(|| fields.borrow().clone()); + Some(self.add_ty(TypeEnum::TObj { + obj_id, + params: params.into(), + fields: fields.into(), + })) + } else { + None + } + } + TypeEnum::TFunc(sig) => { + let FunSignature { args, ret, vars: params } = &*sig.borrow(); + let new_params = self.subst_map(params, mapping); + let new_ret = self.subst(*ret, mapping); + let mut new_args = Cow::from(args); + for (i, t) in args.iter().enumerate() { + if let Some(t1) = self.subst(t.ty, mapping) { + let mut t = t.clone(); + t.ty = t1; + new_args.to_mut()[i] = t; + } + } + if new_params.is_some() || new_ret.is_some() || matches!(new_args, Cow::Owned(..)) { + let params = new_params.unwrap_or_else(|| params.clone()); + let ret = new_ret.unwrap_or_else(|| *ret); + let args = new_args.into_owned(); + Some( + self.add_ty(TypeEnum::TFunc( + FunSignature { args, ret, vars: params }.into(), + )), + ) + } else { + None + } + } + _ => unimplemented!(), + } + } + + fn subst_map(&mut self, map: &Mapping, mapping: &VarMap) -> Option> + where + K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, + { + let mut map2 = None; + for (k, v) in map.iter() { + if let Some(v1) = self.subst(*v, mapping) { + if map2.is_none() { + map2 = Some(map.clone()); + } + *map2.as_mut().unwrap().get_mut(k).unwrap() = v1; + } + } + map2 + } + + fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> { + use TypeVarMeta::*; + if self.unification_table.unioned(a, b) { + return Err("Recursive type is prohibited.".to_owned()); + } + let ty = self.unification_table.probe_value(b).clone(); + + match ty.as_ref() { + TypeEnum::TRigidVar { .. } | TypeEnum::TVar { meta: Generic, .. } => {} + TypeEnum::TVar { meta: Sequence(map), .. } => { + for t in map.borrow().values() { + self.occur_check(a, *t)?; + } + } + TypeEnum::TVar { meta: Record(map), .. } => { + for t in map.borrow().values() { + self.occur_check(a, *t)?; + } + } + TypeEnum::TCall(calls) => { + let call_store = self.calls.clone(); + for t in calls + .borrow() + .iter() + .map(|call| { + let call = call_store[call.0].as_ref(); + chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret)) + }) + .flatten() + { + self.occur_check(a, *t)?; + } + } + TypeEnum::TTuple { ty } => { + for t in ty.iter() { + self.occur_check(a, *t)?; + } + } + TypeEnum::TList { ty } | TypeEnum::TVirtual { ty } => { + self.occur_check(a, *ty)?; + } + TypeEnum::TObj { params: map, .. } => { + for t in map.borrow().values() { + self.occur_check(a, *t)?; + } + } + TypeEnum::TFunc(sig) => { + let FunSignature { args, ret, vars: params } = &*sig.borrow(); + for t in chain!(args.iter().map(|v| &v.ty), params.values(), once(ret)) { + self.occur_check(a, *t)?; + } + } + } + Ok(()) + } + + fn get_intersection(&mut self, a: Type, b: Type) -> Result, ()> { + use TypeEnum::*; + let x = self.get_ty(a); + let y = self.get_ty(b); + match (x.as_ref(), y.as_ref()) { + (TVar { range: range1, .. }, TVar { meta, range: range2, .. 
}) => { + // we should restrict range2 + let range1 = range1.borrow(); + // new range is the intersection of them + // empty range indicates no constraint + if !range1.is_empty() { + let range2 = range2.borrow(); + let mut range = Vec::new(); + if range2.is_empty() { + range.extend_from_slice(&range1); + } + for v1 in range2.iter() { + for v2 in range1.iter() { + let result = self.get_intersection(*v1, *v2); + if let Ok(result) = result { + range.push(result.unwrap_or(*v2)); + } + } + } + if range.is_empty() { + Err(()) + } else { + let id = self.var_id + 1; + self.var_id += 1; + let ty = TVar { id, meta: meta.clone(), range: range.into() }; + Ok(Some(self.unification_table.new_key(ty.into()))) + } + } else { + Ok(Some(b)) + } + } + (_, TVar { range, .. }) => { + // range should be restricted to the left hand side + let range = range.borrow(); + if range.is_empty() { + Ok(Some(a)) + } else { + for v in range.iter() { + let result = self.get_intersection(a, *v); + if let Ok(result) = result { + return Ok(result.or(Some(a))); + } + } + Err(()) + } + } + (TVar { id, range, .. }, _) => { + self.check_var_compatibility(*id, b, &range.borrow()).or(Err(())) + } + (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { + if ty1.len() != ty2.len() { + return Err(()); + } + let mut need_new = false; + let mut ty = ty1.clone(); + for (a, b) in zip(ty1.iter(), ty2.iter()) { + let result = self.get_intersection(*a, *b)?; + ty.push(result.unwrap_or(*a)); + if result.is_some() { + need_new = true; + } + } + if need_new { + Ok(Some(self.add_ty(TTuple { ty }))) + } else { + Ok(None) + } + } + (TList { ty: ty1 }, TList { ty: ty2 }) => { + Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty }))) + } + (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { + Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) + } + (TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => { + if id1 == id2 { + Ok(None) + } else { + Err(()) + } + } + // don't deal with function shape for now + _ => Err(()), + } + } + + fn check_var_compatibility( + &mut self, + id: u32, + b: Type, + range: &[Type], + ) -> Result, String> { + if range.is_empty() { + return Ok(None); + } + for t in range.iter() { + let result = self.get_intersection(*t, b); + if let Ok(result) = result { + return Ok(result); + } + } + return Err(format!( + "Cannot unify type variable {} with {} due to incompatible value range", + id, + self.get_ty(b).get_type_name() + )); + } +} diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs new file mode 100644 index 00000000..2972b3f9 --- /dev/null +++ b/nac3core/src/typecheck/typedef/test.rs @@ -0,0 +1,534 @@ +use super::*; +use indoc::indoc; +use itertools::Itertools; +use std::collections::HashMap; +use test_case::test_case; + +impl Unifier { + /// Check whether two types are equal. + fn eq(&mut self, a: Type, b: Type) -> bool { + use TypeVarMeta::*; + if a == b { + return true; + } + let (ty_a, ty_b) = { + let table = &mut self.unification_table; + if table.unioned(a, b) { + return true; + } + (table.probe_value(a).clone(), table.probe_value(b).clone()) + }; + + match (&*ty_a, &*ty_b) { + ( + TypeEnum::TVar { meta: Generic, id: id1, .. }, + TypeEnum::TVar { meta: Generic, id: id2, .. }, + ) => id1 == id2, + ( + TypeEnum::TVar { meta: Sequence(map1), .. }, + TypeEnum::TVar { meta: Sequence(map2), .. 
},
+            ) => self.map_eq(&map1.borrow(), &map2.borrow()),
+            (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
+                ty1.len() == ty2.len()
+                    && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
+            }
+            (TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
+            | (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
+                self.eq(*ty1, *ty2)
+            }
+            (
+                TypeEnum::TVar { meta: Record(fields1), .. },
+                TypeEnum::TVar { meta: Record(fields2), .. },
+            ) => self.map_eq(&fields1.borrow(), &fields2.borrow()),
+            (
+                TypeEnum::TObj { obj_id: id1, params: params1, .. },
+                TypeEnum::TObj { obj_id: id2, params: params2, .. },
+            ) => id1 == id2 && self.map_eq(&params1.borrow(), &params2.borrow()),
+            // TCall and TFunc are not yet implemented
+            _ => false,
+        }
+    }
+
+    fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
+    where
+        K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
+    {
+        if map1.len() != map2.len() {
+            return false;
+        }
+        for (k, v) in map1.iter() {
+            if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) {
+                return false;
+            }
+        }
+        true
+    }
+}
+
+struct TestEnvironment {
+    pub unifier: Unifier,
+    type_mapping: HashMap<String, Type>,
+}
+
+impl TestEnvironment {
+    fn new() -> TestEnvironment {
+        let mut unifier = Unifier::new();
+        let mut type_mapping = HashMap::new();
+
+        type_mapping.insert(
+            "int".into(),
+            unifier.add_ty(TypeEnum::TObj {
+                obj_id: DefinitionId(0),
+                fields: HashMap::new().into(),
+                params: HashMap::new().into(),
+            }),
+        );
+        type_mapping.insert(
+            "float".into(),
+            unifier.add_ty(TypeEnum::TObj {
+                obj_id: DefinitionId(1),
+                fields: HashMap::new().into(),
+                params: HashMap::new().into(),
+            }),
+        );
+        type_mapping.insert(
+            "bool".into(),
+            unifier.add_ty(TypeEnum::TObj {
+                obj_id: DefinitionId(2),
+                fields: HashMap::new().into(),
+                params: HashMap::new().into(),
+            }),
+        );
+        let (v0, id) = unifier.get_fresh_var();
+        type_mapping.insert(
+            "Foo".into(),
+            unifier.add_ty(TypeEnum::TObj {
+                obj_id: DefinitionId(3),
+                fields: [("a".into(), v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
+                params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>().into(),
+            }),
+        );
+
+        TestEnvironment { unifier, type_mapping }
+    }
+
+    fn parse(&mut self, typ: &str, mapping: &Mapping<String>) -> Type {
+        let result = self.internal_parse(typ, mapping);
+        assert!(result.1.is_empty());
+        result.0
+    }
+
+    fn internal_parse<'a, 'b>(
+        &'a mut self,
+        typ: &'b str,
+        mapping: &Mapping<String>,
+    ) -> (Type, &'b str) {
+        // for testing only, so we can just panic when the input is malformed
+        let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or_else(|| typ.len());
+        match &typ[..end] {
+            "Tuple" => {
+                let mut s = &typ[end..];
+                assert!(&s[0..1] == "[");
+                let mut ty = Vec::new();
+                while &s[0..1] != "]" {
+                    let result = self.internal_parse(&s[1..], mapping);
+                    ty.push(result.0);
+                    s = result.1;
+                }
+                (self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
+            }
+            "List" => {
+                assert!(&typ[end..end + 1] == "[");
+                let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
+                assert!(&s[0..1] == "]");
+                (self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
+            }
+            "Record" => {
+                let mut s = &typ[end..];
+                assert!(&s[0..1] == "[");
+                let mut fields = HashMap::new();
+                while &s[0..1] != "]" {
+                    let eq = s.find('=').unwrap();
+                    let key = s[1..eq].to_string();
+                    let result = self.internal_parse(&s[eq + 1..], mapping);
+                    fields.insert(key, result.0);
+                    s = result.1;
+                }
+                (self.unifier.add_record(fields), &s[1..])
+            }
+            x => {
+                let mut s = &typ[end..];
+                let ty =
mapping.get(x).cloned().unwrap_or_else(|| { + // mapping should be type variables, type_mapping should be concrete types + // we should not resolve the type of type variables. + let mut ty = *self.type_mapping.get(x).unwrap(); + let te = self.unifier.get_ty(ty); + if let TypeEnum::TObj { params, .. } = &*te.as_ref() { + let params = params.borrow(); + if !params.is_empty() { + assert!(&s[0..1] == "["); + let mut p = Vec::new(); + while &s[0..1] != "]" { + let result = self.internal_parse(&s[1..], mapping); + p.push(result.0); + s = result.1; + } + s = &s[1..]; + ty = self + .unifier + .subst(ty, ¶ms.keys().cloned().zip(p.into_iter()).collect()) + .unwrap_or(ty); + } + } + ty + }); + (ty, s) + } + } + } +} + +#[test_case(2, + &[("v1", "v2"), ("v2", "float")], + &[("v1", "float"), ("v2", "float")] + ; "simple variable" +)] +#[test_case(2, + &[("v1", "List[v2]"), ("v1", "List[float]")], + &[("v1", "List[float]"), ("v2", "float")] + ; "list element" +)] +#[test_case(3, + &[ + ("v1", "Record[a=v3,b=v3]"), + ("v2", "Record[b=float,c=v3]"), + ("v1", "v2") + ], + &[ + ("v1", "Record[a=float,b=float,c=float]"), + ("v2", "Record[a=float,b=float,c=float]"), + ("v3", "float") + ] + ; "record merge" +)] +#[test_case(3, + &[ + ("v1", "Record[a=float]"), + ("v2", "Foo[v3]"), + ("v1", "v2") + ], + &[ + ("v1", "Foo[float]"), + ("v3", "float") + ] + ; "record obj merge" +)] +/// Test cases for valid unifications. +fn test_unify( + variable_count: u32, + unify_pairs: &[(&'static str, &'static str)], + verify_pairs: &[(&'static str, &'static str)], +) { + let unify_count = unify_pairs.len(); + // test all permutations... + for perm in unify_pairs.iter().permutations(unify_count) { + let mut env = TestEnvironment::new(); + let mut mapping = HashMap::new(); + for i in 1..=variable_count { + let v = env.unifier.get_fresh_var(); + mapping.insert(format!("v{}", i), v.0); + } + // unification may have side effect when we do type resolution, so freeze the types + // before doing unification. + let mut pairs = Vec::new(); + for (a, b) in perm.iter() { + let t1 = env.parse(a, &mapping); + let t2 = env.parse(b, &mapping); + pairs.push((t1, t2)); + } + for (t1, t2) in pairs { + env.unifier.unify(t1, t2).unwrap(); + } + for (a, b) in verify_pairs.iter() { + println!("{} = {}", a, b); + let t1 = env.parse(a, &mapping); + let t2 = env.parse(b, &mapping); + assert!(env.unifier.eq(t1, t2)); + } + } +} + +#[test_case(2, + &[ + ("v1", "Tuple[int]"), + ("v2", "List[int]"), + ], + (("v1", "v2"), "Cannot unify TList with TTuple") + ; "type mismatch" +)] +#[test_case(2, + &[ + ("v1", "Tuple[int]"), + ("v2", "Tuple[float]"), + ], + (("v1", "v2"), "Cannot unify objects with ID 0 and 1") + ; "tuple parameter mismatch" +)] +#[test_case(2, + &[ + ("v1", "Tuple[int,int]"), + ("v2", "Tuple[int]"), + ], + (("v1", "v2"), "Cannot unify tuples with length 2 and 1") + ; "tuple length mismatch" +)] +#[test_case(3, + &[ + ("v1", "Record[a=float,b=int]"), + ("v2", "Foo[v3]"), + ], + (("v1", "v2"), "No such attribute b") + ; "record obj merge" +)] +#[test_case(2, + &[ + ("v1", "List[v2]"), + ], + (("v1", "v2"), "Recursive type is prohibited.") + ; "recursive type for lists" +)] +/// Test cases for invalid unifications. 
+fn test_invalid_unification( + variable_count: u32, + unify_pairs: &[(&'static str, &'static str)], + errornous_pair: ((&'static str, &'static str), &'static str), +) { + let mut env = TestEnvironment::new(); + let mut mapping = HashMap::new(); + for i in 1..=variable_count { + let v = env.unifier.get_fresh_var(); + mapping.insert(format!("v{}", i), v.0); + } + // unification may have side effect when we do type resolution, so freeze the types + // before doing unification. + let mut pairs = Vec::new(); + for (a, b) in unify_pairs.iter() { + let t1 = env.parse(a, &mapping); + let t2 = env.parse(b, &mapping); + pairs.push((t1, t2)); + } + let (t1, t2) = + (env.parse(errornous_pair.0 .0, &mapping), env.parse(errornous_pair.0 .1, &mapping)); + for (a, b) in pairs { + env.unifier.unify(a, b).unwrap(); + } + assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string())); +} + +#[test] +fn test_virtual() { + let mut env = TestEnvironment::new(); + let int = env.parse("int", &HashMap::new()); + let fun = env.unifier.add_ty(TypeEnum::TFunc( + FunSignature { args: vec![], ret: int, vars: HashMap::new() }.into(), + )); + let bar = env.unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(5), + fields: [("f".to_string(), fun), ("a".to_string(), int)] + .iter() + .cloned() + .collect::>() + .into(), + params: HashMap::new().into(), + }); + let v0 = env.unifier.get_fresh_var().0; + let v1 = env.unifier.get_fresh_var().0; + + let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); + let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); + let c = env.unifier.add_record([("f".to_string(), v1)].iter().cloned().collect()); + env.unifier.unify(a, b).unwrap(); + env.unifier.unify(b, c).unwrap(); + assert!(env.unifier.eq(v1, fun)); + + let d = env.unifier.add_record([("a".to_string(), v1)].iter().cloned().collect()); + assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string())); + + let d = env.unifier.add_record([("b".to_string(), v1)].iter().cloned().collect()); + assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string())); +} + +#[test] +fn test_typevar_range() { + let mut env = TestEnvironment::new(); + let int = env.parse("int", &HashMap::new()); + let boolean = env.parse("bool", &HashMap::new()); + let float = env.parse("float", &HashMap::new()); + let int_list = env.parse("List[int]", &HashMap::new()); + let float_list = env.parse("List[float]", &HashMap::new()); + + // unification between v and int + // where v in (int, bool) + let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + env.unifier.unify(int, v).unwrap(); + + // unification between v and List[int] + // where v in (int, bool) + let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + assert_eq!( + env.unifier.unify(int_list, v), + Err("Cannot unify type variable 3 with TList due to incompatible value range".to_string()) + ); + + // unification between v and float + // where v in (int, bool) + let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + assert_eq!( + env.unifier.unify(float, v), + Err("Cannot unify type variable 4 with TObj due to incompatible value range".to_string()) + ); + + let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 }); + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0; + // unification between v and int + // where v in (int, List[v1]), v1 in (int, bool) + env.unifier.unify(int, v).unwrap(); + + let v = 
env.unifier.get_fresh_var_with_range(&[int, v1_list]).0; + // unification between v and List[int] + // where v in (int, List[v1]), v1 in (int, bool) + env.unifier.unify(int_list, v).unwrap(); + + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0; + // unification between v and List[float] + // where v in (int, List[v1]), v1 in (int, bool) + assert_eq!( + env.unifier.unify(float_list, v), + Err("Cannot unify type variable 8 with TList due to incompatible value range".to_string()) + ); + + let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + env.unifier.unify(a, b).unwrap(); + env.unifier.unify(a, float).unwrap(); + + let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + env.unifier.unify(a, b).unwrap(); + assert_eq!( + env.unifier.unify(a, int), + Err("Cannot unify type variable 12 with TObj due to incompatible value range".into()) + ); + + let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); + let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0; + let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); + let b_list = env.unifier.get_fresh_var_with_range(&[b_list]).0; + env.unifier.unify(a_list, b_list).unwrap(); + let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float }); + env.unifier.unify(a_list, float_list).unwrap(); + // previous unifications should not affect a and b + env.unifier.unify(a, int).unwrap(); + + let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); + let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); + env.unifier.unify(a_list, b_list).unwrap(); + let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int }); + assert_eq!( + env.unifier.unify(a_list, int_list), + Err("Cannot unify type variable 19 with TObj due to incompatible value range".into()) + ); + + let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; + let b = env.unifier.get_fresh_var().0; + let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); + let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0; + let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); + env.unifier.unify(a_list, b_list).unwrap(); + assert_eq!( + env.unifier.unify(b, boolean), + Err("Cannot unify type variable 21 with TObj due to incompatible value range".into()) + ); +} + +#[test] +fn test_rigid_var() { + let mut env = TestEnvironment::new(); + let a = env.unifier.get_fresh_rigid_var().0; + let b = env.unifier.get_fresh_rigid_var().0; + let x = env.unifier.get_fresh_var().0; + let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a }); + let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x }); + let int = env.parse("int", &HashMap::new()); + let list_int = env.parse("List[int]", &HashMap::new()); + + assert_eq!(env.unifier.unify(a, b), Err("Cannot unify TRigidVar with TRigidVar".to_string())); + env.unifier.unify(list_a, list_x).unwrap(); + assert_eq!( + env.unifier.unify(list_x, list_int), + Err("Cannot unify TObj with TRigidVar".to_string()) + ); + + env.unifier.replace_rigid_var(a, int); + env.unifier.unify(list_x, list_int).unwrap(); +} + +#[test] +fn test_instantiation() { + let mut env = TestEnvironment::new(); + let int 
= env.parse("int", &HashMap::new());
+    let boolean = env.parse("bool", &HashMap::new());
+    let float = env.parse("float", &HashMap::new());
+    let list_int = env.parse("List[int]", &HashMap::new());
+
+    let obj_map: HashMap<_, _> =
+        [(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
+
+    let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
+    let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
+    let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int]).0;
+    let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float]).0;
+    let t = env.unifier.get_fresh_rigid_var().0;
+    let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
+    let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t]).0;
+    // t = TypeVar('t')
+    // v = TypeVar('v', int, bool)
+    // v1 = TypeVar('v1', 'list[v]', int)
+    // v2 = TypeVar('v2', 'list[int]', float)
+    // v3 = TypeVar('v3', tuple[v, v1, v2], t)
+    // what values can v3 take?
+
+    let types = env.unifier.get_instantiations(v3).unwrap();
+    let expected_types = indoc! {"
+        tuple[bool, int, float]
+        tuple[bool, int, list[int]]
+        tuple[bool, list[bool], float]
+        tuple[bool, list[bool], list[int]]
+        tuple[bool, list[int], float]
+        tuple[bool, list[int], list[int]]
+        tuple[int, int, float]
+        tuple[int, int, list[int]]
+        tuple[int, list[bool], float]
+        tuple[int, list[bool], list[int]]
+        tuple[int, list[int], float]
+        tuple[int, list[int], list[int]]
+        v5"
+    }
+    .split('\n')
+    .collect_vec();
+    let types = types
+        .iter()
+        .map(|ty| {
+            env.unifier.stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| {
+                format!("v{}", i)
+            })
+        })
+        .sorted()
+        .collect_vec();
+    assert_eq!(expected_types, types);
+}
diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs
new file mode 100644
index 00000000..7475afce
--- /dev/null
+++ b/nac3core/src/typecheck/unification_table.rs
@@ -0,0 +1,87 @@
+use std::rc::Rc;
+
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
+pub struct UnificationKey(usize);
+
+pub struct UnificationTable<V> {
+    parents: Vec<usize>,
+    ranks: Vec<u32>,
+    values: Vec<V>,
+}
+
+impl<V> UnificationTable<V> {
+    pub fn new() -> UnificationTable<V> {
+        UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new() }
+    }
+
+    pub fn new_key(&mut self, v: V) -> UnificationKey {
+        let index = self.parents.len();
+        self.parents.push(index);
+        self.ranks.push(0);
+        self.values.push(v);
+        UnificationKey(index)
+    }
+
+    pub fn unify(&mut self, a: UnificationKey, b: UnificationKey) {
+        let mut a = self.find(a);
+        let mut b = self.find(b);
+        if a == b {
+            return;
+        }
+        if self.ranks[a] < self.ranks[b] {
+            std::mem::swap(&mut a, &mut b);
+        }
+        self.parents[b] = a;
+        if self.ranks[a] == self.ranks[b] {
+            self.ranks[a] += 1;
+        }
+    }
+
+    pub fn probe_value(&mut self, a: UnificationKey) -> &V {
+        let index = self.find(a);
+        &self.values[index]
+    }
+
+    pub fn set_value(&mut self, a: UnificationKey, v: V) {
+        let index = self.find(a);
+        self.values[index] = v;
+    }
+
+    pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
+        self.find(a) == self.find(b)
+    }
+
+    pub fn get_representative(&mut self, key: UnificationKey) -> UnificationKey {
+        UnificationKey(self.find(key))
+    }
+
+    fn find(&mut self, key: UnificationKey) -> usize {
+        let mut root = key.0;
+        let mut parent = self.parents[root];
+        while root != parent {
+            // a = parent.parent
+            let a = self.parents[parent];
+            // root.parent = parent.parent
+            self.parents[root] = a;
+            root = parent;
+            // parent = root.parent
+            parent = a;
+        }
+        parent
+    }
+}
+
+impl<V> UnificationTable<Rc<V>>
+where
+    V: Clone,
+{
+    pub fn get_send(&self) -> UnificationTable<V> {
+        let values = self.values.iter().map(|v| v.as_ref().clone()).collect();
+        UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values }
+    }
+
+    pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
+        let values = table.values.iter().cloned().map(Rc::new).collect();
+        UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values }
+    }
+}
diff --git a/nac3core/src/typedef.rs b/nac3core/src/typedef.rs
deleted file mode 100644
index bec61fd1..00000000
--- a/nac3core/src/typedef.rs
+++ /dev/null
@@ -1,60 +0,0 @@
-use std::collections::HashMap;
-use std::rc::Rc;
-
-#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
-pub struct PrimitiveId(pub(crate) usize);
-
-#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
-pub struct ClassId(pub(crate) usize);
-
-#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
-pub struct ParamId(pub(crate) usize);
-
-#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
-pub struct VariableId(pub(crate) usize);
-
-#[derive(PartialEq, Eq, Clone, Hash, Debug)]
-pub enum TypeEnum {
-    BotType,
-    SelfType,
-    PrimitiveType(PrimitiveId),
-    ClassType(ClassId),
-    VirtualClassType(ClassId),
-    ParametricType(ParamId, Vec<Rc<TypeEnum>>),
-    TypeVariable(VariableId),
-}
-
-pub type Type = Rc<TypeEnum>;
-
-#[derive(Clone)]
-pub struct FnDef {
-    // we assume methods first argument to be SelfType,
-    // so the first argument is not contained here
-    pub args: Vec<Type>,
-    pub result: Option<Type>,
-}
-
-#[derive(Clone)]
-pub struct TypeDef<'a> {
-    pub name: &'a str,
-    pub fields: HashMap<&'a str, Type>,
-    pub methods: HashMap<&'a str, FnDef>,
-}
-
-#[derive(Clone)]
-pub struct ClassDef<'a> {
-    pub base: TypeDef<'a>,
-    pub parents: Vec<ClassId>,
-}
-
-#[derive(Clone)]
-pub struct ParametricDef<'a> {
-    pub base: TypeDef<'a>,
-    pub params: Vec<VariableId>,
-}
-
-#[derive(Clone)]
-pub struct VarDef<'a> {
-    pub name: &'a str,
-    pub bound: Vec<Type>,
-}
diff --git a/shell.nix b/shell.nix
index 858e68b9..35be2055 100644
--- a/shell.nix
+++ b/shell.nix
@@ -4,6 +4,6 @@ in
 pkgs.stdenv.mkDerivation {
   name = "nac3-env";
   buildInputs = with pkgs; [
-    llvm_10 clang_10 cargo rustc libffi libxml2 clippy
+    llvm_11 clang_11 cargo rustc libffi libxml2 clippy
   ];
 }
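A minimal usage sketch of the UnificationTable introduced above (one assumption here: that the new module is publicly reachable as nac3core::typecheck::unification_table). Keys merged by unify share one representative, so a value stored through either key is observed through both:

// Sketch only; exercises the API defined in unification_table.rs above.
use nac3core::typecheck::unification_table::UnificationTable;

fn main() {
    let mut table: UnificationTable<&'static str> = UnificationTable::new();
    let a = table.new_key("a");
    let b = table.new_key("b");
    let c = table.new_key("c");

    assert!(!table.unioned(a, b));
    table.unify(a, b); // a and b now share one representative
    assert!(table.unioned(a, b));
    assert!(!table.unioned(a, c));

    // Values live on the representative, so an update made through a
    // is visible when probing through b.
    table.set_value(a, "merged");
    assert_eq!(*table.probe_value(b), "merged");
}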