forked from M-Labs/nac3
Compare commits
393 Commits
Author | SHA1 | Date |
---|---|---|
lyken | 676412fe6d | |
lyken | 8b9df7252f | |
lyken | 6979843431 | |
lyken | fed1361c6a | |
lyken | aa94e0c8a4 | |
lyken | f523e26227 | |
lyken | f026b48e2a | |
lyken | dc874f2994 | |
lyken | 95de0800b4 | |
lyken | 3d71c6a850 | |
David Mak | be55e2ac80 | |
David Mak | 79c8b759ad | |
David Mak | 4798c53a21 | |
David Mak | 23974feae7 | |
David Mak | 40a3bded36 | |
lyken | c4420e6ab9 | |
lyken | fd36f78005 | |
lyken | 8168692cc3 | |
David Mak | 53d44b9595 | |
David Mak | 6153f94b05 | |
David Mak | 4730b595f3 | |
David Mak | c2fdb12397 | |
David Mak | 82bf14785b | |
David Mak | 2d4329e23c | |
David Mak | 679656f9e1 | |
David Mak | 210d9e2334 | |
David Mak | 181ac3ec1a | |
David Mak | 3acdfb304d | |
David Mak | 6e24da9cc5 | |
David Mak | f0ab1b858a | |
lyken | 08129cc635 | |
David Mak | ad4832dcf4 | |
lyken | 520bbb246b | |
lyken | b857f1e403 | |
Sebastien Bourdeauducq | fa8af37e84 | |
David Mak | 23b2fee4e7 | |
David Mak | ed79d5bb9e | |
David Mak | c35ad06949 | |
David Mak | 135ef557f9 | |
David Mak | a176c3eb70 | |
David Mak | 2cf79510c2 | |
David Mak | b6ff75dcaf | |
David Mak | 588c15f80d | |
David Mak | 82cc693b11 | |
David Mak | 520e1adc56 | |
David Mak | 73e81259f3 | |
David Mak | 7627acea41 | |
David Mak | a777099ea8 | |
David Mak | 876e6ea7b8 | |
David Mak | 30c6cffbad | |
David Mak | 51671800b6 | |
David Mak | 7195476edb | |
David Mak | eecba0b71d | |
David Mak | 7b4253ccd8 | |
David Mak | f58c3a11f8 | |
David Mak | d0766a116f | |
David Mak | 64a3751fc2 | |
David Mak | 9566047241 | |
David Mak | 062e318dd5 | |
David Mak | c4dc36ae99 | |
David Mak | baac348ee6 | |
David Mak | 847615fc2f | |
David Mak | 5dfcc63978 | |
David Mak | 025b3cd02f | |
David Mak | e0f440040c | |
David Mak | f0715e2b6d | |
David Mak | e7fca67786 | |
David Mak | 52c731c312 | |
David Mak | 00d1b9be9b | |
David Mak | 8404d4c4dc | |
David Mak | e614dd4257 | |
David Mak | 937a8b9698 | |
David Mak | 876ad6c59c | |
David Mak | a920fe0501 | |
David Mak | 727a1886b3 | |
David Mak | 6af13a8261 | |
David Mak | 3540d0ab29 | |
David Mak | 3a6c53d760 | |
David Mak | 87bc34f7ec | |
David Mak | f50a5f0345 | |
David Mak | a77fd213e0 | |
David Mak | 8f1497df83 | |
David Mak | 5ca2dbeec8 | |
David Mak | 9a98cde595 | |
David Mak | 5ba8601b39 | |
David Mak | 26a01b14d5 | |
David Mak | d5f4817134 | |
David Mak | 789bfb5a26 | |
David Mak | 4bb0e60981 | |
Sebastien Bourdeauducq | 623fcf85af | |
David Mak | 13f06f3e29 | |
David Mak | f0da9c0283 | |
David Mak | 2c4bf3ce59 | |
David Mak | e980f19c93 | |
David Mak | cfbc37c1ed | |
David Mak | 50264e8750 | |
David Mak | 1b77e62901 | |
David Mak | fd44ee6887 | |
David Mak | c8866b1534 | |
David Mak | 84a888758a | |
David Mak | 9d550725b7 | |
David Mak | 2edc1de0b6 | |
David Mak | c3b122acfc | |
David Mak | a94927a11d | |
David Mak | ebf86cd134 | |
David Mak | cccd8f2d00 | |
David Mak | 3292aed099 | |
David Mak | 96b7f29679 | |
David Mak | 3d2abf73c8 | |
David Mak | f682e9bf7a | |
David Mak | b26cb2b360 | |
David Mak | 2317516cf6 | |
David Mak | 77de24ef74 | |
David Mak | 234a6bde2a | |
David Mak | c3db6297d9 | |
David Mak | 82fdb02d13 | |
David Mak | 4efdd17513 | |
David Mak | 49de81ef1e | |
David Mak | 8492503af2 | |
Sebastien Bourdeauducq | e1dbe2526a | |
Sebastien Bourdeauducq | f37de381ce | |
Sebastien Bourdeauducq | 4452c8986a | |
David Mak | 22e831cb76 | |
David Mak | cc538d221a | |
David Mak | 0d5c53e60c | |
David Mak | 976a9512c1 | |
David Mak | 1eacaf9afa | |
David Mak | 8c7e44098a | |
David Mak | 282a3e1911 | |
David Mak | 5cecb2bb74 | |
David Mak | 1963c30744 | |
David Mak | 27011f385b | |
David Mak | d6302b6ec8 | |
David Mak | fef4b2a5ce | |
David Mak | b3736c3e99 | |
Sebastien Bourdeauducq | e328e44c9a | |
Sebastien Bourdeauducq | 9e4e90f8a0 | |
David Mak | 8470915809 | |
David Mak | 148900302e | |
David Mak | 5ee08b585f | |
David Mak | f1581299fc | |
David Mak | af95ba5012 | |
David Mak | 9c9756be33 | |
David Mak | 2a922c7480 | |
David Mak | e3e2c36ef4 | |
David Mak | 4f9a0110c4 | |
David Mak | 12c0eed0a3 | |
David Mak | c679474f5c | |
Sébastien Bourdeauducq | ab3fa05996 | |
David Mak | 140f8f8a08 | |
David Mak | 27fcf8926e | |
David Mak | afa7d9b100 | |
David Mak | c395472094 | |
David Mak | 03870f222d | |
David Mak | e435b25756 | |
David Mak | bd792904f9 | |
David Mak | 1c3a823670 | |
David Mak | f01d833d48 | |
David Mak | 9d64e606f4 | |
David Mak | 6dccb343bb | |
Sebastien Bourdeauducq | d47534e2ad | |
David Mak | 8886964776 | |
David Mak | f09f3c27a5 | |
David Mak | 0bbc9ce6f5 | |
David Mak | 457d3b6cd7 | |
David Mak | 5f692debd8 | |
David Mak | c7735d935b | |
David Mak | b47ac1b89b | |
David Mak | a19f1065e3 | |
Sebastien Bourdeauducq | 5bf05c6a69 | |
David Mak | 32746c37be | |
David Mak | 1d6291b9ba | |
David Mak | 16655959f2 | |
David Mak | beee3e1f7e | |
David Mak | d4c109b6ef | |
David Mak | 5ffd06dd61 | |
David Mak | 95d0c3c93c | |
David Mak | bd3d67f3d6 | |
David Mak | ddfb532b80 | |
David Mak | 02933753ca | |
David Mak | a1f244834f | |
David Mak | d304afd333 | |
David Mak | ef04696b02 | |
David Mak | 4dc5dbb856 | |
David Mak | fd9f66b8d9 | |
David Mak | 5182453bd9 | |
Sebastien Bourdeauducq | 68556da5fd | |
David Mak | 983f080ea7 | |
David Mak | 031e660f18 | |
David Mak | b6dfcfcc38 | |
David Mak | c93ad152d7 | |
David Mak | 68b97347b1 | |
David Mak | 875d534de4 | |
Sebastien Bourdeauducq | adadf56e2b | |
Sebastien Bourdeauducq | 9f610745b7 | |
Sebastien Bourdeauducq | 98199768e3 | |
Sebastien Bourdeauducq | bfa9ceaae3 | |
Sebastien Bourdeauducq | 120f8da5c7 | |
Sebastien Bourdeauducq | cee62aa6c5 | |
Sebastien Bourdeauducq | fcda360ad6 | |
Sebastien Bourdeauducq | 87c20ada48 | |
Sebastien Bourdeauducq | 38e968cff6 | |
David Mak | 5c5620692f | |
David Mak | 0af1e37e99 | |
David Mak | 854e33ed48 | |
Sebastien Bourdeauducq | f020d61cbb | |
David Mak | 10538b5296 | |
David Mak | d322c91697 | |
David Mak | 3231eb0d78 | |
Sebastien Bourdeauducq | 1ca4de99b9 | |
Sebastien Bourdeauducq | bf4b1aae47 | |
David Mak | 08a5050f9a | |
David Mak | c2ab6b58ff | |
David Mak | 0a84f7ac31 | |
David Mak | fd787ca3f5 | |
David Mak | 4dbe07a0c0 | |
David Mak | 2e055e8ab1 | |
David Mak | 9d737743c1 | |
David Mak | c6b9aefe00 | |
David Mak | 8ad09748d0 | |
David Mak | 7a5a2db842 | |
David Mak | 447eb9c387 | |
David Mak | 92d6f0a5d3 | |
David Mak | 7e4dab15ae | |
David Mak | ff1fed112c | |
David Mak | 36a6a7b8cd | |
David Mak | 2b635a0b97 | |
David Mak | 60ad100fbb | |
David Mak | 316f0824d8 | |
David Mak | 7cf7634985 | |
David Mak | 068f0d9faf | |
David Mak | 95810d4229 | |
David Mak | 630897b779 | |
Sebastien Bourdeauducq | e546535df0 | |
David Mak | 352f70b885 | |
David Mak | e95586f61e | |
David Mak | bb27e3d400 | |
David Mak | bb5147521f | |
David Mak | 9518d3fe14 | |
David Mak | cbd333ab10 | |
David Mak | 65d6104d00 | |
David Mak | 8373a6cb0f | |
David Mak | f75ae78677 | |
Sebastien Bourdeauducq | ea2ab0ef7c | |
David Mak | e49b760e34 | |
David Mak | aa92778363 | |
David Mak | e1487ed335 | |
David Mak | 73500c9081 | |
David Mak | 9ca34c714e | |
David Mak | 7fc2a30c14 | |
David Mak | 950f431483 | |
David Mak | a50c690428 | |
David Mak | 48eb64403f | |
David Mak | 2c44b58bb8 | |
David Mak | 50230e61f3 | |
David Mak | 0205161e35 | |
David Mak | a2fce49b26 | |
David Mak | 60a503a791 | |
David Mak | 0c49b30a90 | |
David Mak | c7de22287e | |
David Mak | 1a54aaa1c0 | |
David Mak | c5629d4eb5 | |
David Mak | a79286113e | |
Sebastien Bourdeauducq | 901e921e00 | |
Sebastien Bourdeauducq | 45a323e969 | |
Sebastien Bourdeauducq | 11759a722f | |
David Mak | 480a4bc0ad | |
Sebastien Bourdeauducq | a1d3093196 | |
Sebastien Bourdeauducq | 85c5f2c044 | |
David Mak | f34c6053d6 | |
David Mak | e8a5f0dfef | |
David Mak | 7140901261 | |
David Mak | 2a775d822e | |
David Mak | 1659c3e724 | |
David Mak | f53cb804ec | |
David Mak | 279376a373 | |
David Mak | b6afd1bfda | |
David Mak | be3e8f50a2 | |
David Mak | 059d3da58b | |
David Mak | 9b28f23d8c | |
Sebastien Bourdeauducq | 119f4d63e9 | |
Sebastien Bourdeauducq | 458fa12788 | |
David Mak | 48c6498d1f | |
David Mak | 2a38d5160e | |
David Mak | b39831b388 | |
David Mak | cb39f61e79 | |
David Mak | 176f250bdb | |
David Mak | acdb1de6fe | |
David Mak | 31dcd2dde9 | |
David Mak | fc93fc2f0e | |
David Mak | dd42022633 | |
David Mak | 6dfc43c8b0 | |
David Mak | ab2360d7a0 | |
David Mak | ee1ee4ab3b | |
David Mak | 3e430b9b40 | |
David Mak | 9e57498958 | |
David Mak | 769fd01df8 | |
David Mak | 411837cacd | |
David Mak | f59d45805f | |
David Mak | 048fcb0a69 | |
David Mak | 676d07657a | |
David Mak | 2482a1ef9b | |
David Mak | eb63f2ad48 | |
Sebastien Bourdeauducq | ff27e22ee6 | |
Sebastien Bourdeauducq | d672ef094b | |
Sebastien Bourdeauducq | d25921230e | |
Sebastien Bourdeauducq | 66f07b5bf4 | |
David Mak | 008d50995c | |
David Mak | 474f9050ce | |
David Mak | 3993a5cf3f | |
David Mak | 39724de598 | |
David Mak | e4940247f3 | |
David Mak | 4481d48709 | |
David Mak | b4983526bd | |
David Mak | b4a9616648 | |
David Mak | e0de82993f | |
David Mak | 6805253515 | |
David Mak | 19915bac79 | |
David Mak | 17b4686260 | |
David Mak | 6de0884dc1 | |
David Mak | f1b0e05b3d | |
David Mak | ff23968544 | |
Sebastien Bourdeauducq | 049908044a | |
David Mak | d37287a33d | |
Sebastien Bourdeauducq | 283bd7c69a | |
Sebastien Bourdeauducq | 3d73f5c129 | |
Sebastien Bourdeauducq | d824c5d8b5 | |
Sebastien Bourdeauducq | b8d637f5c4 | |
Sebastien Bourdeauducq | 3af287d1c4 | |
Sebastien Bourdeauducq | 5b53be0311 | |
Sebastien Bourdeauducq | aead36f0fd | |
Sebastien Bourdeauducq | c269444c0b | |
Sebastien Bourdeauducq | 52cec3c12f | |
Sebastien Bourdeauducq | 2927f2a1d0 | |
Sebastien Bourdeauducq | c1c45373a6 | |
Sebastien Bourdeauducq | 946ea155b8 | |
Sebastien Bourdeauducq | 085c6ee738 | |
Sebastien Bourdeauducq | cfa67c418a | |
Sebastien Bourdeauducq | 813bfa92a7 | |
Sebastien Bourdeauducq | fff4b65169 | |
Sebastien Bourdeauducq | c891fffd75 | |
Sebastien Bourdeauducq | 12acd35e15 | |
Sebastien Bourdeauducq | f66ca02b2d | |
z78078 | b514f91441 | |
z78078 | 8f95b79257 | |
z78078 | ebd25af38b | |
z78078 | 96b3a3bf5c | |
ychenfo | a18d095245 | |
Sebastien Bourdeauducq | b242463548 | |
Sebastien Bourdeauducq | 8e6e4d6715 | |
Sebastien Bourdeauducq | 73c2aefe4b | |
Sebastien Bourdeauducq | 892597cda4 | |
Sebastien Bourdeauducq | 33321c5e9c | |
occheung | 50ed04b787 | |
occheung | 7cb9be0f81 | |
occheung | ac560ba985 | |
occheung | a96371145d | |
ychenfo | 8addf2b55e | |
ychenfo | 5d5e9a5e02 | |
Sebastien Bourdeauducq | 4c39dd240f | |
occheung | 48fc5ceb8e | |
ychenfo | c4ab2855e5 | |
ychenfo | ffac37dc48 | |
ychenfo | 76473152e8 | |
Sebastien Bourdeauducq | b04631e935 | |
ychenfo | 09820e5aed | |
Sebastien Bourdeauducq | 0ec2ed4d91 | |
ychenfo | 2cb725b7ac | |
Sebastien Bourdeauducq | b9259b1907 | |
ychenfo | 096f4b03c0 | |
ychenfo | a022005183 | |
ychenfo | 325ba0a408 | |
ychenfo | ae6434696c | |
Sebastien Bourdeauducq | 3f327113b2 | |
Sebastien Bourdeauducq | 27d509d70e | |
Sebastien Bourdeauducq | a321b13bec | |
ychenfo | 48cb485b89 | |
Sebastien Bourdeauducq | 837aaa95f1 | |
Sebastien Bourdeauducq | a19e9c0bec | |
Sebastien Bourdeauducq | 5dbe1d3d7d | |
Sebastien Bourdeauducq | e9bca3c822 | |
Sebastien Bourdeauducq | 42d1aad507 | |
Sebastien Bourdeauducq | 2777a6e05f | |
Sebastien Bourdeauducq | 05be5e93c4 | |
Sebastien Bourdeauducq | 85f21060e4 | |
Sebastien Bourdeauducq | a308d24caa | |
Sebastien Bourdeauducq | 1eac111d4c | |
ychenfo | 44199781dc | |
ychenfo | 711c3d3303 | |
sb10q | 0975264482 | |
Sebastien Bourdeauducq | 087aded3a3 | |
ychenfo | f14b32be67 | |
David Nadlinger | 879c66cccf |
|
@ -0,0 +1 @@
|
|||
doc-valid-idents = ["CPython", "NumPy", ".."]
|
|
@ -1,3 +1,3 @@
|
|||
__pycache__
|
||||
/target
|
||||
windows/msys2
|
||||
nix/windows/msys2
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
|
||||
default_stages: [commit]
|
||||
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: nac3-cargo-fmt
|
||||
name: nac3 cargo format
|
||||
entry: cargo
|
||||
language: system
|
||||
types: [file, rust]
|
||||
pass_filenames: false
|
||||
description: Runs cargo fmt on the codebase.
|
||||
args: [fmt]
|
||||
- id: nac3-cargo-clippy
|
||||
name: nac3 cargo clippy
|
||||
entry: cargo
|
||||
language: system
|
||||
types: [file, rust]
|
||||
pass_filenames: false
|
||||
description: Runs cargo clippy on the codebase.
|
||||
args: [clippy]
|
File diff suppressed because it is too large
Load Diff
|
@ -1,5 +1,6 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"nac3ld",
|
||||
"nac3ast",
|
||||
"nac3parser",
|
||||
"nac3core",
|
||||
|
@ -7,6 +8,7 @@ members = [
|
|||
"nac3artiq",
|
||||
"runkernel",
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
[profile.release]
|
||||
debug = true
|
||||
|
|
33
README.md
33
README.md
|
@ -1,5 +1,10 @@
|
|||
# NAC3
|
||||
<div align="center">
|
||||
|
||||
![icon](https://git.m-labs.hk/M-Labs/nac3/raw/branch/master/nac3.svg)
|
||||
|
||||
</div>
|
||||
|
||||
# NAC3
|
||||
NAC3 is a major, backward-incompatible rewrite of the compiler for the [ARTIQ](https://m-labs.hk/artiq) physics experiment control and data acquisition system. It features greatly improved compilation speeds, a much better type system, and more predictable and transparent operation.
|
||||
|
||||
NAC3 has a modular design and its applicability reaches beyond ARTIQ. The ``nac3core`` module does not contain anything specific to ARTIQ, and can be used in any project that requires compiling Python to machine code.
|
||||
|
@ -8,7 +13,7 @@ NAC3 has a modular design and its applicability reaches beyond ARTIQ. The ``nac3
|
|||
|
||||
## Packaging
|
||||
|
||||
NAC3 is packaged using the [Nix](https://nixos.org) Flakes system. Install Nix 2.4+ and enable flakes by adding ``experimental-features = nix-command flakes`` to ``nix.conf`` (e.g. ``~/.config/nix/nix.conf``).
|
||||
NAC3 is packaged using the [Nix](https://nixos.org) Flakes system. Install Nix 2.8+ and enable flakes by adding ``experimental-features = nix-command flakes`` to ``nix.conf`` (e.g. ``~/.config/nix/nix.conf``).
|
||||
|
||||
## Try NAC3
|
||||
|
||||
|
@ -18,27 +23,19 @@ After setting up Nix as above, use ``nix shell git+https://github.com/m-labs/art
|
|||
|
||||
### Windows
|
||||
|
||||
Install [MSYS2](https://www.msys2.org/), and open "MSYS2 MinGW x64". Edit ``/etc/pacman.conf`` to add:
|
||||
Install [MSYS2](https://www.msys2.org/), and open "MSYS2 CLANG64". Edit ``/etc/pacman.conf`` to add:
|
||||
```
|
||||
[artiq]
|
||||
SigLevel = Optional TrustAll
|
||||
Server = https://lab.m-labs.hk/msys2
|
||||
Server = https://msys2.m-labs.hk/artiq-nac3
|
||||
```
|
||||
|
||||
Then run the following commands:
|
||||
```
|
||||
pacman -Syu
|
||||
pacman -S mingw-w64-x86_64-artiq
|
||||
pacman -S mingw-w64-clang-x86_64-artiq
|
||||
```
|
||||
|
||||
Install ``lld-msys2`` manually:
|
||||
```
|
||||
wget https://nixbld.m-labs.hk/build/115527/download/1/ld.lld.exe
|
||||
mv ld.lld.exe C:/msys64/mingw64/bin
|
||||
```
|
||||
|
||||
Note: This build of NAC3 cannot be used with Anaconda Python nor the python.org binaries for Windows. Those Python versions are compiled with Visual Studio (MSVC) and their ABI is incompatible with the GNU ABI used in this build. We have no plans to support Visual Studio nor the MSVC ABI. If you need a MSVC build, please install the requisite bloated spyware from Microsoft and compile NAC3 yourself.
|
||||
|
||||
## For developers
|
||||
|
||||
This repository contains:
|
||||
|
@ -46,6 +43,7 @@ This repository contains:
|
|||
- ``nac3parser``: Python parser (based on RustPython).
|
||||
- ``nac3core``: Core compiler library, containing type-checking and code generation.
|
||||
- ``nac3standalone``: Standalone compiler tool (core language only).
|
||||
- ``nac3ld``: Minimalist RISC-V and ARM linker.
|
||||
- ``nac3artiq``: Integration with ARTIQ and implementation of ARTIQ-specific extensions to the core language.
|
||||
- ``runkernel``: Simple program that runs compiled ARTIQ kernels on the host and displays RTIO operations. Useful for testing without hardware.
|
||||
|
||||
|
@ -53,3 +51,12 @@ Use ``nix develop`` in this repository to enter a development shell.
|
|||
If you are using a different shell than bash you can use e.g. ``nix develop --command fish``.
|
||||
|
||||
Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``.
|
||||
|
||||
### Pre-Commit Hooks
|
||||
|
||||
You are strongly recommended to use the provided pre-commit hooks to automatically reformat files and check for non-optimal Rust practices using Clippy. Run `pre-commit install` to install the hook and `pre-commit` will automatically run `cargo fmt` and `cargo clippy` for you.
|
||||
|
||||
Several things to note:
|
||||
|
||||
- If `cargo fmt` or `cargo clippy` returns an error, the pre-commit hook will fail. You should fix all errors before trying to commit again.
|
||||
- If `cargo fmt` reformats some files, the pre-commit hook will also fail. You should review the changes and, if satisfied, try to commit again.
|
||||
|
|
|
@ -2,16 +2,16 @@
|
|||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1649619156,
|
||||
"narHash": "sha256-p0q4zpuKMwrzGF+5ZU7Thnpac5TinhDI9jr2mBxhV4w=",
|
||||
"lastModified": 1717196966,
|
||||
"narHash": "sha256-yZKhxVIKd2lsbOqYd5iDoUIwsRZFqE87smE2Vzf6Ck0=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "e7d63bd0d50df412f5a1d8acfa3caae75522e347",
|
||||
"rev": "57610d2f8f0937f39dbd72251e9614b1561942d8",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-21.11",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
|
|
82
flake.nix
82
flake.nix
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
description = "The third-generation ARTIQ compiler";
|
||||
|
||||
inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-21.11;
|
||||
inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-unstable;
|
||||
|
||||
outputs = { self, nixpkgs }:
|
||||
let
|
||||
|
@ -9,15 +9,24 @@
|
|||
in rec {
|
||||
packages.x86_64-linux = rec {
|
||||
llvm-nac3 = pkgs.callPackage ./nix/llvm {};
|
||||
llvm-tools-irrt = pkgs.runCommandNoCC "llvm-tools-irrt" {}
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
|
||||
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
|
||||
'';
|
||||
nac3artiq = pkgs.python3Packages.toPythonModule (
|
||||
pkgs.rustPlatform.buildRustPackage {
|
||||
pkgs.rustPlatform.buildRustPackage rec {
|
||||
name = "nac3artiq";
|
||||
outputs = [ "out" "runkernel" "standalone" ];
|
||||
src = self;
|
||||
cargoLock = { lockFile = ./Cargo.lock; };
|
||||
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_13.clang-unwrapped llvm-nac3 ];
|
||||
cargoLock = {
|
||||
lockFile = ./Cargo.lock;
|
||||
};
|
||||
passthru.cargoLock = cargoLock;
|
||||
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
|
||||
buildInputs = [ pkgs.python3 llvm-nac3 ];
|
||||
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ])) ];
|
||||
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
|
||||
checkPhase =
|
||||
''
|
||||
echo "Checking nac3standalone demos..."
|
||||
|
@ -49,21 +58,21 @@
|
|||
|
||||
# LLVM PGO support
|
||||
llvm-nac3-instrumented = pkgs.callPackage ./nix/llvm {
|
||||
stdenv = pkgs.llvmPackages_13.stdenv;
|
||||
stdenv = pkgs.llvmPackages_14.stdenv;
|
||||
extraCmakeFlags = [ "-DLLVM_BUILD_INSTRUMENTED=IR" ];
|
||||
};
|
||||
nac3artiq-instrumented = pkgs.python3Packages.toPythonModule (
|
||||
pkgs.rustPlatform.buildRustPackage {
|
||||
name = "nac3artiq-instrumented";
|
||||
src = self;
|
||||
cargoLock = { lockFile = ./Cargo.lock; };
|
||||
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_13.clang-unwrapped llvm-nac3-instrumented ];
|
||||
inherit (nac3artiq) cargoLock;
|
||||
nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-instrumented ];
|
||||
buildInputs = [ pkgs.python3 llvm-nac3-instrumented ];
|
||||
cargoBuildFlags = [ "--package" "nac3artiq" "--features" "init-llvm-profile" ];
|
||||
doCheck = false;
|
||||
configurePhase =
|
||||
''
|
||||
export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS="-C link-arg=-L${pkgs.llvmPackages_13.compiler-rt}/lib/linux -C link-arg=-lclang_rt.profile-x86_64"
|
||||
export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS="-C link-arg=-L${pkgs.llvmPackages_14.compiler-rt}/lib/linux -C link-arg=-lclang_rt.profile-x86_64"
|
||||
'';
|
||||
installPhase =
|
||||
''
|
||||
|
@ -75,11 +84,35 @@
|
|||
);
|
||||
nac3artiq-profile = pkgs.stdenvNoCC.mkDerivation {
|
||||
name = "nac3artiq-profile";
|
||||
src = self;
|
||||
buildInputs = [ (python3-mimalloc.withPackages(ps: [ ps.numpy nac3artiq-instrumented ])) pkgs.lld_13 pkgs.llvmPackages_13.libllvm ];
|
||||
srcs = [
|
||||
(pkgs.fetchFromGitHub {
|
||||
owner = "m-labs";
|
||||
repo = "sipyco";
|
||||
rev = "939f84f9b5eef7efbf7423c735d1834783b6140e";
|
||||
sha256 = "sha256-15Nun4EY35j+6SPZkjzZtyH/ncxLS60KuGJjFh5kSTc=";
|
||||
})
|
||||
(pkgs.fetchFromGitHub {
|
||||
owner = "m-labs";
|
||||
repo = "artiq";
|
||||
rev = "923ca3377d42c815f979983134ec549dc39d3ca0";
|
||||
sha256 = "sha256-oJoEeNEeNFSUyh6jXG8Tzp6qHVikeHS0CzfE+mODPgw=";
|
||||
})
|
||||
];
|
||||
buildInputs = [
|
||||
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb nac3artiq-instrumented ]))
|
||||
pkgs.llvmPackages_14.llvm.out
|
||||
];
|
||||
phases = [ "buildPhase" "installPhase" ];
|
||||
# TODO: get more representative code.
|
||||
buildPhase = "python $src/nac3artiq/demo/demo.py";
|
||||
buildPhase =
|
||||
''
|
||||
srcs=($srcs)
|
||||
sipyco=''${srcs[0]}
|
||||
artiq=''${srcs[1]}
|
||||
export PYTHONPATH=$sipyco:$artiq
|
||||
python -m artiq.frontend.artiq_ddb_template $artiq/artiq/examples/nac3devices/nac3devices.json > device_db.py
|
||||
cp $artiq/artiq/examples/nac3devices/nac3devices.py .
|
||||
python -m artiq.frontend.artiq_compile nac3devices.py
|
||||
'';
|
||||
installPhase =
|
||||
''
|
||||
mkdir $out
|
||||
|
@ -87,15 +120,15 @@
|
|||
'';
|
||||
};
|
||||
llvm-nac3-pgo = pkgs.callPackage ./nix/llvm {
|
||||
stdenv = pkgs.llvmPackages_13.stdenv;
|
||||
stdenv = pkgs.llvmPackages_14.stdenv;
|
||||
extraCmakeFlags = [ "-DLLVM_PROFDATA_FILE=${nac3artiq-profile}/llvm.profdata" ];
|
||||
};
|
||||
nac3artiq-pgo = pkgs.python3Packages.toPythonModule (
|
||||
pkgs.rustPlatform.buildRustPackage {
|
||||
name = "nac3artiq-pgo";
|
||||
src = self;
|
||||
cargoLock = { lockFile = ./Cargo.lock; };
|
||||
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_13.clang-unwrapped llvm-nac3-pgo ];
|
||||
inherit (nac3artiq) cargoLock;
|
||||
nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-pgo ];
|
||||
buildInputs = [ pkgs.python3 llvm-nac3-pgo ];
|
||||
cargoBuildFlags = [ "--package" "nac3artiq" ];
|
||||
cargoTestFlags = [ "--package" "nac3ast" "--package" "nac3parser" "--package" "nac3core" "--package" "nac3artiq" ];
|
||||
|
@ -111,20 +144,22 @@
|
|||
|
||||
packages.x86_64-w64-mingw32 = import ./nix/windows { inherit pkgs; };
|
||||
|
||||
devShell.x86_64-linux = pkgs.mkShell {
|
||||
devShells.x86_64-linux.default = pkgs.mkShell {
|
||||
name = "nac3-dev-shell";
|
||||
buildInputs = with pkgs; [
|
||||
# build dependencies
|
||||
packages.x86_64-linux.llvm-nac3
|
||||
llvmPackages_13.clang-unwrapped # IRRT
|
||||
llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos
|
||||
packages.x86_64-linux.llvm-tools-irrt
|
||||
cargo
|
||||
rustc
|
||||
# runtime dependencies
|
||||
lld_13
|
||||
(packages.x86_64-linux.python3-mimalloc.withPackages(ps: [ ps.numpy ]))
|
||||
lld_14 # for running kernels on the host
|
||||
(packages.x86_64-linux.python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ]))
|
||||
# development tools
|
||||
cargo-insta
|
||||
clippy
|
||||
pre-commit
|
||||
rustfmt
|
||||
];
|
||||
};
|
||||
|
@ -139,16 +174,15 @@
|
|||
};
|
||||
|
||||
hydraJobs = {
|
||||
inherit (packages.x86_64-linux) llvm-nac3 nac3artiq;
|
||||
inherit (packages.x86_64-linux) llvm-nac3 nac3artiq nac3artiq-pgo;
|
||||
llvm-nac3-msys2 = packages.x86_64-w64-mingw32.llvm-nac3;
|
||||
nac3artiq-msys2 = packages.x86_64-w64-mingw32.nac3artiq;
|
||||
nac3artiq-msys2-pkg = packages.x86_64-w64-mingw32.nac3artiq-pkg;
|
||||
lld-msys2 = packages.x86_64-w64-mingw32.lld;
|
||||
};
|
||||
};
|
||||
|
||||
nixConfig = {
|
||||
binaryCachePublicKeys = ["nixbld.m-labs.hk-1:5aSRVA5b320xbNvu30tqxVPXpld73bhtOeH6uAjRyHc="];
|
||||
binaryCaches = ["https://nixbld.m-labs.hk" "https://cache.nixos.org"];
|
||||
extra-trusted-public-keys = "nixbld.m-labs.hk-1:5aSRVA5b320xbNvu30tqxVPXpld73bhtOeH6uAjRyHc=";
|
||||
extra-substituters = "https://nixbld.m-labs.hk";
|
||||
};
|
||||
}
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<svg
|
||||
id="a"
|
||||
width="128"
|
||||
height="128"
|
||||
viewBox="0 0 95.99999 95.99999"
|
||||
version="1.1"
|
||||
sodipodi:docname="nac3.svg"
|
||||
inkscape:version="1.1.1 (3bf5ae0d25, 2021-09-20)"
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
xmlns:svg="http://www.w3.org/2000/svg">
|
||||
<defs
|
||||
id="defs11" />
|
||||
<sodipodi:namedview
|
||||
id="namedview9"
|
||||
pagecolor="#ffffff"
|
||||
bordercolor="#666666"
|
||||
borderopacity="1.0"
|
||||
inkscape:pageshadow="2"
|
||||
inkscape:pageopacity="0.0"
|
||||
inkscape:pagecheckerboard="0"
|
||||
inkscape:document-units="mm"
|
||||
showgrid="false"
|
||||
units="px"
|
||||
width="128px"
|
||||
inkscape:zoom="5.9448568"
|
||||
inkscape:cx="60.472441"
|
||||
inkscape:cy="60.556547"
|
||||
inkscape:window-width="2560"
|
||||
inkscape:window-height="1371"
|
||||
inkscape:window-x="0"
|
||||
inkscape:window-y="32"
|
||||
inkscape:window-maximized="1"
|
||||
inkscape:current-layer="a" />
|
||||
<rect
|
||||
x="40.072601"
|
||||
y="-26.776209"
|
||||
width="55.668747"
|
||||
height="55.668747"
|
||||
transform="matrix(0.71803815,0.69600374,-0.71803815,0.69600374,0,0)"
|
||||
style="fill:#be211e;stroke:#000000;stroke-width:4.37375px;stroke-linecap:round;stroke-linejoin:round"
|
||||
id="rect2" />
|
||||
<line
|
||||
x1="38.00692"
|
||||
y1="63.457153"
|
||||
x2="57.993061"
|
||||
y2="63.457153"
|
||||
style="fill:none;stroke:#000000;stroke-width:4.37269px;stroke-linecap:round;stroke-linejoin:round"
|
||||
id="line4" />
|
||||
<path
|
||||
d="m 48.007301,57.843329 c -1.943097,0 -3.877522,-0.41727 -5.686157,-1.246007 -3.218257,-1.474616 -5.650382,-4.075418 -6.849639,-7.323671 -2.065624,-5.588921 -1.192751,-10.226647 2.575258,-13.827 0.611554,-0.584909 1.518048,-0.773041 2.323689,-0.488206 0.80673,0.286405 1.369495,0.998486 1.447563,1.827234 0.237469,2.549302 2.439719,5.917376 4.28414,6.55273 0.396859,0.13506 0.820953,-0.05859 1.097084,-0.35222 0.339254,-0.360754 0.451065,-0.961893 -1.013597,-3.191372 -2.089851,-3.181137 -4.638728,-8.754903 -0.262407,-15.069853 0.494457,-0.713491 1.384673,-1.068907 2.256469,-0.909156 0.871795,0.161332 1.583757,0.806404 1.752251,1.651189 0.716448,3.591862 2.962357,6.151755 5.199306,8.023138 1.935503,1.61861 4.344688,3.867387 5.435687,7.096643 2.283183,6.758017 -1.202511,14.114988 -8.060822,16.494025 -1.467083,0.509226 -2.98513,0.762536 -4.498836,0.762536 z M 39.358865,40.002192 c -0.304711,0.696206 -0.541636,2.080524 -0.56865,2.237454 -0.330316,1.918771 0.168305,3.803963 0.846157,5.539951 0.856828,2.19436 2.437543,3.942467 4.583411,4.925713 2.143691,0.981675 4.554131,1.097816 6.789992,0.322666 4.571485,-1.586549 6.977584,-6.532238 5.363036,-11.02597 v -5.27e-4 C 55.455481,39.447968 54.023463,38.162043 52.221335,36.65432 50.876945,35.529534 49.409662,33.987726 48.417983,32.135555 48.01343,31.37996 47.79547,30.34303 47.76669,29.413263 c -0.187481,0.669514 -0.212441,2.325923 -0.150396,2.93691 0.179209,1.764456 1.333476,3.644546 2.340611,5.171243 1.311568,1.988179 2.72058,6.037272 0.459681,8.367985 -1.54192,1.58953 -4.038511,2.052034 -5.839973,1.38492 -2.398314,-0.888147 -3.942744,-2.690627 -4.941118,-4.768029 -0.121194,-0.25217 -0.532464,-1.174187 -0.276619,-2.5041 z"
|
||||
id="path6"
|
||||
style="stroke-width:1.09317" />
|
||||
</svg>
|
After Width: | Height: | Size: 3.3 KiB |
|
@ -2,23 +2,25 @@
|
|||
name = "nac3artiq"
|
||||
version = "0.1.0"
|
||||
authors = ["M-Labs"]
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "nac3artiq"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
pyo3 = { version = "0.14", features = ["extension-module"] }
|
||||
parking_lot = "0.11"
|
||||
tempfile = "3"
|
||||
itertools = "0.13"
|
||||
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
|
||||
parking_lot = "0.12"
|
||||
tempfile = "3.10"
|
||||
nac3parser = { path = "../nac3parser" }
|
||||
nac3core = { path = "../nac3core" }
|
||||
nac3ld = { path = "../nac3ld" }
|
||||
|
||||
[dependencies.inkwell]
|
||||
version = "0.1.0-beta.4"
|
||||
version = "0.4"
|
||||
default-features = false
|
||||
features = ["llvm13-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
||||
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
||||
|
||||
[features]
|
||||
init-llvm-profile = []
|
||||
|
|
|
@ -18,6 +18,13 @@ class EmbeddingMap:
|
|||
"SPIError",
|
||||
"0:ZeroDivisionError",
|
||||
"0:IndexError",
|
||||
"0:ValueError",
|
||||
"0:RuntimeError",
|
||||
"0:AssertionError",
|
||||
"0:KeyError",
|
||||
"0:NotImplementedError",
|
||||
"0:OverflowError",
|
||||
"0:IOError",
|
||||
"0:UnwrapNoneError"])
|
||||
|
||||
def preallocate_runtime_exception_names(self, names):
|
||||
|
|
|
@ -10,7 +10,7 @@ from embedding_map import EmbeddingMap
|
|||
|
||||
|
||||
__all__ = [
|
||||
"Kernel", "KernelInvariant", "virtual",
|
||||
"Kernel", "KernelInvariant", "virtual", "ConstGeneric",
|
||||
"Option", "Some", "none", "UnwrapNoneError",
|
||||
"round64", "floor64", "ceil64",
|
||||
"extern", "kernel", "portable", "nac3",
|
||||
|
@ -67,6 +67,12 @@ def Some(v: T) -> Option[T]:
|
|||
|
||||
none = Option(None)
|
||||
|
||||
class _ConstGenericMarker:
|
||||
pass
|
||||
|
||||
def ConstGeneric(name, constraint):
|
||||
return TypeVar(name, _ConstGenericMarker, constraint)
|
||||
|
||||
def round64(x):
|
||||
return round(x)
|
||||
|
||||
|
@ -80,7 +86,13 @@ def ceil64(x):
|
|||
import device_db
|
||||
core_arguments = device_db.device_db["core"]["arguments"]
|
||||
|
||||
compiler = nac3artiq.NAC3(core_arguments["target"])
|
||||
artiq_builtins = {
|
||||
"none": none,
|
||||
"virtual": virtual,
|
||||
"_ConstGenericMarker": _ConstGenericMarker,
|
||||
"Option": Option,
|
||||
}
|
||||
compiler = nac3artiq.NAC3(core_arguments["target"], artiq_builtins)
|
||||
allow_registration = True
|
||||
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
|
||||
registered_functions = set()
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
use nac3core::{
|
||||
codegen::{
|
||||
expr::gen_call,
|
||||
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
|
||||
stmt::{gen_block, gen_with},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{DefinitionId, GenCall},
|
||||
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum}
|
||||
toplevel::{helper::PrimDef, DefinitionId, GenCall},
|
||||
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap},
|
||||
};
|
||||
|
||||
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
||||
|
@ -15,7 +16,10 @@ use inkwell::{
|
|||
context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace,
|
||||
};
|
||||
|
||||
use pyo3::{PyObject, PyResult, Python, types::{PyDict, PyList}};
|
||||
use pyo3::{
|
||||
types::{PyDict, PyList},
|
||||
PyObject, PyResult, Python,
|
||||
};
|
||||
|
||||
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||
|
||||
|
@ -26,13 +30,45 @@ use std::{
|
|||
sync::Arc,
|
||||
};
|
||||
|
||||
/// The parallelism mode within a block.
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
enum ParallelMode {
|
||||
/// No parallelism is currently registered for this context.
|
||||
None,
|
||||
|
||||
/// Legacy (or shallow) parallelism. Default before NAC3.
|
||||
///
|
||||
/// Each statement within the `with` block is treated as statements to be executed in parallel.
|
||||
Legacy,
|
||||
|
||||
/// Deep parallelism. Default since NAC3.
|
||||
///
|
||||
/// Each function call within the `with` block (except those within a nested `sequential` block)
|
||||
/// are treated to be executed in parallel.
|
||||
Deep,
|
||||
}
|
||||
|
||||
pub struct ArtiqCodeGenerator<'a> {
|
||||
name: String,
|
||||
|
||||
/// The size of a `size_t` variable in bits.
|
||||
size_t: u32,
|
||||
|
||||
/// Monotonic counter for naming `start`/`stop` variables used by `with parallel` blocks.
|
||||
name_counter: u32,
|
||||
|
||||
/// Variable for tracking the start of a `with parallel` block.
|
||||
start: Option<Expr<Option<Type>>>,
|
||||
|
||||
/// Variable for tracking the end of a `with parallel` block.
|
||||
end: Option<Expr<Option<Type>>>,
|
||||
timeline: &'a (dyn TimeFns + Sync),
|
||||
|
||||
/// The [ParallelMode] of the current parallel context.
|
||||
///
|
||||
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
|
||||
/// statement, which is used to determine when and how the timeline should be updated.
|
||||
parallel_mode: ParallelMode,
|
||||
}
|
||||
|
||||
impl<'a> ArtiqCodeGenerator<'a> {
|
||||
|
@ -42,7 +78,74 @@ impl<'a> ArtiqCodeGenerator<'a> {
|
|||
timeline: &'a (dyn TimeFns + Sync),
|
||||
) -> ArtiqCodeGenerator<'a> {
|
||||
assert!(size_t == 32 || size_t == 64);
|
||||
ArtiqCodeGenerator { name, size_t, name_counter: 0, start: None, end: None, timeline }
|
||||
ArtiqCodeGenerator {
|
||||
name,
|
||||
size_t,
|
||||
name_counter: 0,
|
||||
start: None,
|
||||
end: None,
|
||||
timeline,
|
||||
parallel_mode: ParallelMode::None,
|
||||
}
|
||||
}
|
||||
|
||||
/// If the generator is currently in a direct-`parallel` block context, emits IR that resets the
|
||||
/// position of the timeline to the initial timeline position before entering the `parallel`
|
||||
/// block.
|
||||
///
|
||||
/// Direct-`parallel` block context refers to when the generator is generating statements whose
|
||||
/// closest parent `with` statement is a `with parallel` block.
|
||||
fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> Result<(), String> {
|
||||
if let Some(start) = self.start.clone() {
|
||||
let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum(
|
||||
ctx,
|
||||
self,
|
||||
start.custom.unwrap(),
|
||||
)?;
|
||||
self.timeline.emit_at_mu(ctx, start_val);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// If the generator is currently in a `parallel` block context, emits IR that updates the
|
||||
/// maximum end position of the `parallel` block as specified by the timeline `end` value.
|
||||
///
|
||||
/// In general the `end` parameter should be set to `self.end` for updating the maximum end
|
||||
/// position for the current `parallel` block. Other values can be passed in to update the
|
||||
/// maximum end position for other `parallel` blocks.
|
||||
///
|
||||
/// `parallel`-block context refers to when the generator is generating statements within a
|
||||
/// (possibly indirect) `parallel` block.
|
||||
///
|
||||
/// * `store_name` - The LLVM value name for the pointer to `end`. `.addr` will be appended to
|
||||
/// the end of the provided value name.
|
||||
fn timeline_update_end_max(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
end: Option<Expr<Option<Type>>>,
|
||||
store_name: Option<&str>,
|
||||
) -> Result<(), String> {
|
||||
if let Some(end) = end {
|
||||
let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum(
|
||||
ctx,
|
||||
self,
|
||||
end.custom.unwrap(),
|
||||
)?;
|
||||
let now = self.timeline.emit_now_mu(ctx);
|
||||
let max =
|
||||
call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax"));
|
||||
let end_store = self
|
||||
.gen_store_target(
|
||||
ctx,
|
||||
&end,
|
||||
store_name.map(|name| format!("{name}.addr")).as_deref(),
|
||||
)?
|
||||
.unwrap();
|
||||
ctx.builder.build_store(end_store, max).unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -59,183 +162,203 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
|||
}
|
||||
}
|
||||
|
||||
fn gen_call<'ctx, 'a>(
|
||||
fn gen_block<'ctx, 'a, 'c, I: Iterator<Item = &'c Stmt<Option<Type>>>>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
stmts: I,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
// Legacy parallel emits timeline end-update/timeline-reset after each top-level statement
|
||||
// in the parallel block
|
||||
if self.parallel_mode == ParallelMode::Legacy {
|
||||
for stmt in stmts {
|
||||
self.gen_stmt(ctx, stmt)?;
|
||||
|
||||
if ctx.is_terminated() {
|
||||
break;
|
||||
}
|
||||
|
||||
self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
|
||||
self.timeline_reset_start(ctx)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
gen_block(self, ctx, stmts)
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_call<'ctx>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
let result = gen_call(self, ctx, obj, fun, params)?;
|
||||
if let Some(end) = self.end.clone() {
|
||||
let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum(ctx, self, end.custom.unwrap())?;
|
||||
let now = self.timeline.emit_now_mu(ctx);
|
||||
let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| {
|
||||
let i64 = ctx.ctx.i64_type();
|
||||
ctx.module.add_function(
|
||||
"llvm.smax.i64",
|
||||
i64.fn_type(&[i64.into(), i64.into()], false),
|
||||
None,
|
||||
)
|
||||
});
|
||||
let max = ctx
|
||||
.builder
|
||||
.build_call(smax, &[old_end.into(), now.into()], "smax")
|
||||
.try_as_basic_value()
|
||||
.left()
|
||||
.unwrap();
|
||||
let end_store = self.gen_store_target(ctx, &end)?;
|
||||
ctx.builder.build_store(end_store, max);
|
||||
}
|
||||
if let Some(start) = self.start.clone() {
|
||||
let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum(ctx, self, start.custom.unwrap())?;
|
||||
self.timeline.emit_at_mu(ctx, start_val);
|
||||
|
||||
// Deep parallel emits timeline end-update/timeline-reset after each function call
|
||||
if self.parallel_mode == ParallelMode::Deep {
|
||||
self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
|
||||
self.timeline_reset_start(ctx)?;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn gen_with<'ctx, 'a>(
|
||||
fn gen_with(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
) -> Result<(), String> {
|
||||
if let StmtKind::With { items, body, .. } = &stmt.node {
|
||||
if items.len() == 1 && items[0].optional_vars.is_none() {
|
||||
let item = &items[0];
|
||||
// Behavior of parallel and sequential:
|
||||
// Each function call (indirectly, can be inside a sequential block) within a parallel
|
||||
// block will update the end variable to the maximum now_mu in the block.
|
||||
// Each function call directly inside a parallel block will reset the timeline after
|
||||
// execution. A parallel block within a sequential block (or not within any block) will
|
||||
// set the timeline to the max now_mu within the block (and the outer max now_mu will also
|
||||
// be updated).
|
||||
//
|
||||
// Implementation: We track the start and end separately.
|
||||
// - If there is a start variable, it indicates that we are directly inside a
|
||||
// parallel block and we have to reset the timeline after every function call.
|
||||
// - If there is a end variable, it indicates that we are (indirectly) inside a
|
||||
// parallel block, and we should update the max end value.
|
||||
if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node {
|
||||
if id == &"parallel".into() {
|
||||
let old_start = self.start.take();
|
||||
let old_end = self.end.take();
|
||||
let now = if let Some(old_start) = &old_start {
|
||||
self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum(ctx, self, old_start.custom.unwrap())?
|
||||
} else {
|
||||
self.timeline.emit_now_mu(ctx)
|
||||
};
|
||||
// Emulate variable allocation, as we need to use the CodeGenContext
|
||||
// HashMap to store our variable due to lifetime limitation
|
||||
// Note: we should be able to store variables directly if generic
|
||||
// associative type is used by limiting the lifetime of CodeGenerator to
|
||||
// the LLVM Context.
|
||||
// The name is guaranteed to be unique as users cannot use this as variable
|
||||
// name.
|
||||
self.start = old_start.clone().map_or_else(
|
||||
|| {
|
||||
let start = format!("with-{}-start", self.name_counter).into();
|
||||
let start_expr = Located {
|
||||
// location does not matter at this point
|
||||
location: stmt.location,
|
||||
node: ExprKind::Name { id: start, ctx: name_ctx.clone() },
|
||||
custom: Some(ctx.primitives.int64),
|
||||
};
|
||||
let start = self.gen_store_target(ctx, &start_expr)?;
|
||||
ctx.builder.build_store(start, now);
|
||||
Ok(Some(start_expr)) as Result<_, String>
|
||||
},
|
||||
|v| Ok(Some(v)),
|
||||
)?;
|
||||
let end = format!("with-{}-end", self.name_counter).into();
|
||||
let end_expr = Located {
|
||||
// location does not matter at this point
|
||||
location: stmt.location,
|
||||
node: ExprKind::Name { id: end, ctx: name_ctx.clone() },
|
||||
custom: Some(ctx.primitives.int64),
|
||||
};
|
||||
let end = self.gen_store_target(ctx, &end_expr)?;
|
||||
ctx.builder.build_store(end, now);
|
||||
self.end = Some(end_expr);
|
||||
self.name_counter += 1;
|
||||
gen_block(self, ctx, body.iter())?;
|
||||
let current = ctx.builder.get_insert_block().unwrap();
|
||||
// if the current block is terminated, move before the terminator
|
||||
// we want to set the timeline before reaching the terminator
|
||||
// TODO: This may be unsound if there are multiple exit paths in the
|
||||
// block... e.g.
|
||||
// if ...:
|
||||
// return
|
||||
// Perhaps we can fix this by using actual with block?
|
||||
let reset_position = if let Some(terminator) = current.get_terminator() {
|
||||
ctx.builder.position_before(&terminator);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
// set duration
|
||||
let end_expr = self.end.take().unwrap();
|
||||
let end_val = self
|
||||
.gen_expr(ctx, &end_expr)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, self, end_expr.custom.unwrap())?;
|
||||
let StmtKind::With { items, body, .. } = &stmt.node else { unreachable!() };
|
||||
|
||||
// inside a sequential block
|
||||
if old_start.is_none() {
|
||||
self.timeline.emit_at_mu(ctx, end_val);
|
||||
}
|
||||
// inside a parallel block, should update the outer max now_mu
|
||||
if let Some(old_end) = &old_end {
|
||||
let outer_end_val = self
|
||||
.gen_expr(ctx, old_end)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, self, old_end.custom.unwrap())?;
|
||||
let smax =
|
||||
ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| {
|
||||
let i64 = ctx.ctx.i64_type();
|
||||
ctx.module.add_function(
|
||||
"llvm.smax.i64",
|
||||
i64.fn_type(&[i64.into(), i64.into()], false),
|
||||
None,
|
||||
)
|
||||
});
|
||||
let max = ctx
|
||||
.builder
|
||||
.build_call(smax, &[end_val.into(), outer_end_val.into()], "smax")
|
||||
.try_as_basic_value()
|
||||
.left()
|
||||
if items.len() == 1 && items[0].optional_vars.is_none() {
|
||||
let item = &items[0];
|
||||
|
||||
// Behavior of parallel and sequential:
|
||||
// Each function call (indirectly, can be inside a sequential block) within a parallel
|
||||
// block will update the end variable to the maximum now_mu in the block.
|
||||
// Each function call directly inside a parallel block will reset the timeline after
|
||||
// execution. A parallel block within a sequential block (or not within any block) will
|
||||
// set the timeline to the max now_mu within the block (and the outer max now_mu will also
|
||||
// be updated).
|
||||
//
|
||||
// Implementation: We track the start and end separately.
|
||||
// - If there is a start variable, it indicates that we are directly inside a
|
||||
// parallel block and we have to reset the timeline after every function call.
|
||||
// - If there is a end variable, it indicates that we are (indirectly) inside a
|
||||
// parallel block, and we should update the max end value.
|
||||
if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node {
|
||||
if id == &"parallel".into() || id == &"legacy_parallel".into() {
|
||||
let old_start = self.start.take();
|
||||
let old_end = self.end.take();
|
||||
let old_parallel_mode = self.parallel_mode;
|
||||
|
||||
let now = if let Some(old_start) = &old_start {
|
||||
self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum(
|
||||
ctx,
|
||||
self,
|
||||
old_start.custom.unwrap(),
|
||||
)?
|
||||
} else {
|
||||
self.timeline.emit_now_mu(ctx)
|
||||
};
|
||||
|
||||
// Emulate variable allocation, as we need to use the CodeGenContext
|
||||
// HashMap to store our variable due to lifetime limitation
|
||||
// Note: we should be able to store variables directly if generic
|
||||
// associative type is used by limiting the lifetime of CodeGenerator to
|
||||
// the LLVM Context.
|
||||
// The name is guaranteed to be unique as users cannot use this as variable
|
||||
// name.
|
||||
self.start = old_start.clone().map_or_else(
|
||||
|| {
|
||||
let start = format!("with-{}-start", self.name_counter).into();
|
||||
let start_expr = Located {
|
||||
// location does not matter at this point
|
||||
location: stmt.location,
|
||||
node: ExprKind::Name { id: start, ctx: *name_ctx },
|
||||
custom: Some(ctx.primitives.int64),
|
||||
};
|
||||
let start = self
|
||||
.gen_store_target(ctx, &start_expr, Some("start.addr"))?
|
||||
.unwrap();
|
||||
let outer_end = self.gen_store_target(ctx, old_end)?;
|
||||
ctx.builder.build_store(outer_end, max);
|
||||
}
|
||||
self.start = old_start;
|
||||
self.end = old_end;
|
||||
if reset_position {
|
||||
ctx.builder.position_at_end(current);
|
||||
}
|
||||
return Ok(());
|
||||
} else if id == &"sequential".into() {
|
||||
let start = self.start.take();
|
||||
for stmt in body.iter() {
|
||||
self.gen_stmt(ctx, stmt)?;
|
||||
if ctx.is_terminated() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.start = start;
|
||||
return Ok(());
|
||||
ctx.builder.build_store(start, now).unwrap();
|
||||
Ok(Some(start_expr)) as Result<_, String>
|
||||
},
|
||||
|v| Ok(Some(v)),
|
||||
)?;
|
||||
let end = format!("with-{}-end", self.name_counter).into();
|
||||
let end_expr = Located {
|
||||
// location does not matter at this point
|
||||
location: stmt.location,
|
||||
node: ExprKind::Name { id: end, ctx: *name_ctx },
|
||||
custom: Some(ctx.primitives.int64),
|
||||
};
|
||||
let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap();
|
||||
ctx.builder.build_store(end, now).unwrap();
|
||||
self.end = Some(end_expr);
|
||||
self.name_counter += 1;
|
||||
self.parallel_mode = match id.to_string().as_str() {
|
||||
"parallel" => ParallelMode::Deep,
|
||||
"legacy_parallel" => ParallelMode::Legacy,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
self.gen_block(ctx, body.iter())?;
|
||||
|
||||
let current = ctx.builder.get_insert_block().unwrap();
|
||||
|
||||
// if the current block is terminated, move before the terminator
|
||||
// we want to set the timeline before reaching the terminator
|
||||
// TODO: This may be unsound if there are multiple exit paths in the
|
||||
// block... e.g.
|
||||
// if ...:
|
||||
// return
|
||||
// Perhaps we can fix this by using actual with block?
|
||||
let reset_position = if let Some(terminator) = current.get_terminator() {
|
||||
ctx.builder.position_before(&terminator);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// set duration
|
||||
let end_expr = self.end.take().unwrap();
|
||||
let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum(
|
||||
ctx,
|
||||
self,
|
||||
end_expr.custom.unwrap(),
|
||||
)?;
|
||||
|
||||
// inside a sequential block
|
||||
if old_start.is_none() {
|
||||
self.timeline.emit_at_mu(ctx, end_val);
|
||||
}
|
||||
|
||||
// inside a parallel block, should update the outer max now_mu
|
||||
self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?;
|
||||
|
||||
self.parallel_mode = old_parallel_mode;
|
||||
self.end = old_end;
|
||||
self.start = old_start;
|
||||
|
||||
if reset_position {
|
||||
ctx.builder.position_at_end(current);
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
} else if id == &"sequential".into() {
|
||||
// For deep parallel, temporarily take away start to avoid function calls in
|
||||
// the block from resetting the timeline.
|
||||
// This does not affect legacy parallel, as the timeline will be reset after
|
||||
// this block finishes execution.
|
||||
let start = self.start.take();
|
||||
self.gen_block(ctx, body.iter())?;
|
||||
self.start = start;
|
||||
|
||||
// Reset the timeline when we are exiting the sequential block
|
||||
// Legacy parallel does not need this, since it will be reset after codegen
|
||||
// for this statement is completed
|
||||
if self.parallel_mode == ParallelMode::Deep {
|
||||
self.timeline_reset_start(ctx)?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
// not parallel/sequential
|
||||
gen_with(self, ctx, stmt)
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
// not parallel/sequential
|
||||
gen_with(self, ctx, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_rpc_tag<'ctx, 'a>(
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
fn gen_rpc_tag(
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
ty: Type,
|
||||
buffer: &mut Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
|
@ -280,26 +403,26 @@ fn gen_rpc_tag<'ctx, 'a>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn rpc_codegen_callback_fn<'ctx, 'a>(
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
let ptr_type = ctx.ctx.i8_type().ptr_type(inkwell::AddressSpace::Generic);
|
||||
let ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||
let size_type = generator.get_size_type(ctx.ctx);
|
||||
let int8 = ctx.ctx.i8_type();
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
||||
|
||||
let service_id = int32.const_int(fun.1.0 as u64, false);
|
||||
let service_id = int32.const_int(fun.1 .0 as u64, false);
|
||||
// -- setup rpc tags
|
||||
let mut tag = Vec::new();
|
||||
if obj.is_some() {
|
||||
tag.push(b'O');
|
||||
}
|
||||
for arg in fun.0.args.iter() {
|
||||
for arg in &fun.0.args {
|
||||
gen_rpc_tag(ctx, arg.ty, &mut tag)?;
|
||||
}
|
||||
tag.push(b':');
|
||||
|
@ -319,7 +442,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
|
|||
format!("tagptr{}", fun.1 .0).as_str(),
|
||||
);
|
||||
tag_arr_ptr.set_initializer(&int8.const_array(
|
||||
&tag.iter().map(|v| int8.const_int(*v as u64, false)).collect::<Vec<_>>(),
|
||||
&tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
|
||||
));
|
||||
tag_arr_ptr.set_linkage(Linkage::Private);
|
||||
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
|
||||
|
@ -335,38 +458,28 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
|
|||
})
|
||||
.as_pointer_value();
|
||||
|
||||
let arg_length = args.len() + if obj.is_some() { 1 } else { 0 };
|
||||
let arg_length = args.len() + usize::from(obj.is_some());
|
||||
|
||||
let stacksave = ctx.module.get_function("llvm.stacksave").unwrap_or_else(|| {
|
||||
ctx.module.add_function("llvm.stacksave", ptr_type.fn_type(&[], false), None)
|
||||
});
|
||||
let stackrestore = ctx.module.get_function("llvm.stackrestore").unwrap_or_else(|| {
|
||||
ctx.module.add_function(
|
||||
"llvm.stackrestore",
|
||||
ctx.ctx.void_type().fn_type(&[ptr_type.into()], false),
|
||||
None,
|
||||
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
|
||||
let args_ptr = ctx
|
||||
.builder
|
||||
.build_array_alloca(
|
||||
ptr_type,
|
||||
ctx.ctx.i32_type().const_int(arg_length as u64, false),
|
||||
"argptr",
|
||||
)
|
||||
});
|
||||
|
||||
let stackptr = ctx.builder.build_call(stacksave, &[], "rpc.stack");
|
||||
let args_ptr = ctx.builder.build_array_alloca(
|
||||
ptr_type,
|
||||
ctx.ctx.i32_type().const_int(arg_length as u64, false),
|
||||
"argptr",
|
||||
);
|
||||
.unwrap();
|
||||
|
||||
// -- rpc args handling
|
||||
let mut keys = fun.0.args.clone();
|
||||
let mut mapping = HashMap::new();
|
||||
for (key, value) in args.into_iter() {
|
||||
for (key, value) in args {
|
||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
||||
}
|
||||
// default value handling
|
||||
for k in keys.into_iter() {
|
||||
mapping.insert(
|
||||
k.name,
|
||||
ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()
|
||||
);
|
||||
for k in keys {
|
||||
mapping
|
||||
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
|
||||
}
|
||||
// reorder the parameters
|
||||
let mut real_params = fun
|
||||
|
@ -385,17 +498,19 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
|
|||
}
|
||||
|
||||
for (i, arg) in real_params.iter().enumerate() {
|
||||
let arg_slot = ctx.builder.build_alloca(arg.get_type(), &format!("rpc.arg{}", i));
|
||||
ctx.builder.build_store(arg_slot, *arg);
|
||||
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg");
|
||||
let arg_slot =
|
||||
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
|
||||
ctx.builder.build_store(arg_slot, *arg).unwrap();
|
||||
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
|
||||
let arg_ptr = unsafe {
|
||||
ctx.builder.build_gep(
|
||||
args_ptr,
|
||||
&[int32.const_int(i as u64, false)],
|
||||
&format!("rpc.arg{}", i),
|
||||
&format!("rpc.arg{i}"),
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(arg_ptr, arg_slot);
|
||||
}
|
||||
.unwrap();
|
||||
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
|
||||
}
|
||||
|
||||
// call
|
||||
|
@ -405,26 +520,20 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
|
|||
ctx.ctx.void_type().fn_type(
|
||||
&[
|
||||
int32.into(),
|
||||
tag_ptr_type.ptr_type(AddressSpace::Generic).into(),
|
||||
ptr_type.ptr_type(AddressSpace::Generic).into(),
|
||||
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||
ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||
],
|
||||
false,
|
||||
),
|
||||
None,
|
||||
)
|
||||
});
|
||||
ctx.builder.build_call(
|
||||
rpc_send,
|
||||
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
|
||||
"rpc.send",
|
||||
);
|
||||
ctx.builder
|
||||
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
|
||||
.unwrap();
|
||||
|
||||
// reclaim stack space used by arguments
|
||||
ctx.builder.build_call(
|
||||
stackrestore,
|
||||
&[stackptr.try_as_basic_value().unwrap_left().into()],
|
||||
"rpc.stackrestore",
|
||||
);
|
||||
call_stackrestore(ctx, stackptr);
|
||||
|
||||
// -- receive value:
|
||||
// T result = {
|
||||
|
@ -450,86 +559,91 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
|
|||
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
|
||||
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
|
||||
|
||||
let ret_ty = ctx.get_llvm_type(generator, fun.0.ret);
|
||||
let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret);
|
||||
let need_load = !ret_ty.is_pointer_type();
|
||||
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot");
|
||||
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr");
|
||||
ctx.builder.build_unconditional_branch(head_bb);
|
||||
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap();
|
||||
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap();
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
ctx.builder.position_at_end(head_bb);
|
||||
|
||||
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr");
|
||||
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&slotgen, prehead_bb)]);
|
||||
let alloc_size = ctx
|
||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let is_done = ctx.builder.build_int_compare(
|
||||
inkwell::IntPredicate::EQ,
|
||||
int32.const_zero(),
|
||||
alloc_size,
|
||||
"rpc.done",
|
||||
);
|
||||
let is_done = ctx
|
||||
.builder
|
||||
.build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
|
||||
.unwrap();
|
||||
|
||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb);
|
||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
||||
ctx.builder.position_at_end(alloc_bb);
|
||||
|
||||
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc");
|
||||
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr");
|
||||
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap();
|
||||
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||
ctx.builder.build_unconditional_branch(head_bb);
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(tail_bb);
|
||||
|
||||
let result = ctx.builder.build_load(slot, "rpc.result");
|
||||
let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
|
||||
if need_load {
|
||||
ctx.builder.build_call(
|
||||
stackrestore,
|
||||
&[stackptr.try_as_basic_value().unwrap_left().into()],
|
||||
"rpc.stackrestore",
|
||||
);
|
||||
call_stackrestore(ctx, stackptr);
|
||||
}
|
||||
Ok(Some(result))
|
||||
}
|
||||
|
||||
pub fn attributes_writeback<'ctx, 'a>(
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
pub fn attributes_writeback(
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
inner_resolver: &InnerResolver,
|
||||
host_attributes: PyObject,
|
||||
host_attributes: &PyObject,
|
||||
) -> Result<(), String> {
|
||||
Python::with_gil(|py| -> PyResult<Result<(), String>> {
|
||||
let host_attributes = host_attributes.cast_as::<PyList>(py)?;
|
||||
let host_attributes: &PyList = host_attributes.downcast(py)?;
|
||||
let top_levels = ctx.top_level.definitions.read();
|
||||
let globals = inner_resolver.global_value_ids.read();
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let zero = int32.const_zero();
|
||||
let mut values = Vec::new();
|
||||
let mut scratch_buffer = Vec::new();
|
||||
for (_, val) in globals.iter() {
|
||||
for val in (*globals).values() {
|
||||
let val = val.as_ref(py);
|
||||
let ty = inner_resolver.get_obj_type(py, val, &mut ctx.unifier, &top_levels, &ctx.primitives)?;
|
||||
let ty = inner_resolver.get_obj_type(
|
||||
py,
|
||||
val,
|
||||
&mut ctx.unifier,
|
||||
&top_levels,
|
||||
&ctx.primitives,
|
||||
)?;
|
||||
if let Err(ty) = ty {
|
||||
return Ok(Err(ty))
|
||||
return Ok(Err(ty));
|
||||
}
|
||||
let ty = ty.unwrap();
|
||||
match &*ctx.unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { fields, obj_id, .. }
|
||||
if *obj_id != ctx.primitives.option.get_obj_id(&ctx.unifier) =>
|
||||
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// we only care about primitive attributes
|
||||
// for non-primitive attributes, they should be in another global
|
||||
let mut attributes = Vec::new();
|
||||
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
|
||||
for (name, (field_ty, is_mutable)) in fields.iter() {
|
||||
for (name, (field_ty, is_mutable)) in fields {
|
||||
if !is_mutable {
|
||||
continue
|
||||
continue;
|
||||
}
|
||||
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
|
||||
attributes.push(name.to_string());
|
||||
let index = ctx.get_attr_index(ty, *name);
|
||||
values.push((*field_ty, ctx.build_gep_and_load(
|
||||
obj.into_pointer_value(),
|
||||
&[zero, int32.const_int(index as u64, false)])));
|
||||
values.push((
|
||||
*field_ty,
|
||||
ctx.build_gep_and_load(
|
||||
obj.into_pointer_value(),
|
||||
&[zero, int32.const_int(index as u64, false)],
|
||||
None,
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
if !attributes.is_empty() {
|
||||
|
@ -538,33 +652,44 @@ pub fn attributes_writeback<'ctx, 'a>(
|
|||
pydict.set_item("fields", attributes)?;
|
||||
host_attributes.append(pydict)?;
|
||||
}
|
||||
},
|
||||
}
|
||||
TypeEnum::TList { ty: elem_ty } => {
|
||||
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() {
|
||||
let pydict = PyDict::new(py);
|
||||
pydict.set_item("obj", val)?;
|
||||
host_attributes.append(pydict)?;
|
||||
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
|
||||
values.push((
|
||||
ty,
|
||||
inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(),
|
||||
));
|
||||
}
|
||||
},
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let fun = FunSignature {
|
||||
args: values.iter().enumerate().map(|(i, (ty, _))| FuncArg {
|
||||
name: i.to_string().into(),
|
||||
ty: *ty,
|
||||
default_value: None
|
||||
}).collect(),
|
||||
args: values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, (ty, _))| FuncArg {
|
||||
name: i.to_string().into(),
|
||||
ty: *ty,
|
||||
default_value: None,
|
||||
})
|
||||
.collect(),
|
||||
ret: ctx.primitives.none,
|
||||
vars: Default::default()
|
||||
vars: VarMap::default(),
|
||||
};
|
||||
let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
||||
if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, DefinitionId(0)), args, generator) {
|
||||
let args: Vec<_> =
|
||||
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
||||
if let Err(e) =
|
||||
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator)
|
||||
{
|
||||
return Ok(Err(e));
|
||||
}
|
||||
Ok(Ok(()))
|
||||
}).unwrap()?;
|
||||
})
|
||||
.unwrap()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,10 +1,20 @@
|
|||
use inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering};
|
||||
use inkwell::{
|
||||
values::{BasicValueEnum, CallSiteValue},
|
||||
AddressSpace, AtomicOrdering,
|
||||
};
|
||||
use itertools::Either;
|
||||
use nac3core::codegen::CodeGenContext;
|
||||
|
||||
/// Functions for manipulating the timeline.
|
||||
pub trait TimeFns {
|
||||
fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx>;
|
||||
fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>);
|
||||
fn emit_delay_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, dt: BasicValueEnum<'ctx>);
|
||||
/// Emits LLVM IR for `now_mu`.
|
||||
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>;
|
||||
|
||||
/// Emits LLVM IR for `at_mu`.
|
||||
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>);
|
||||
|
||||
/// Emits LLVM IR for `delay_mu`.
|
||||
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>);
|
||||
}
|
||||
|
||||
pub struct NowPinningTimeFns64 {}
|
||||
|
@ -12,141 +22,143 @@ pub struct NowPinningTimeFns64 {}
|
|||
// For FPGA design reasons, on VexRiscv with 64-bit data bus, the "now" CSR is split into two 32-bit
|
||||
// values that are each padded to 64-bits.
|
||||
impl TimeFns for NowPinningTimeFns64 {
|
||||
fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx> {
|
||||
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
|
||||
let i64_type = ctx.ctx.i64_type();
|
||||
let i32_type = ctx.ctx.i32_type();
|
||||
let now = ctx
|
||||
.module
|
||||
.get_global("now")
|
||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||
let now_hiptr =
|
||||
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::Generic), "now_hiptr");
|
||||
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
|
||||
let now_loptr = unsafe {
|
||||
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_gep")
|
||||
};
|
||||
if let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = (
|
||||
ctx.builder.build_load(now_hiptr, "now_hi"),
|
||||
ctx.builder.build_load(now_loptr, "now_lo"),
|
||||
) {
|
||||
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "now_zext_hi");
|
||||
let shifted_hi = ctx.builder.build_left_shift(
|
||||
zext_hi,
|
||||
i64_type.const_int(32, false),
|
||||
"now_shifted_zext_hi",
|
||||
);
|
||||
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "now_zext_lo");
|
||||
ctx.builder.build_or(shifted_hi, zext_lo, "now_or").into()
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
} else {
|
||||
unreachable!();
|
||||
let now_hiptr = ctx
|
||||
.builder
|
||||
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
|
||||
let now_loptr = unsafe {
|
||||
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
|
||||
}
|
||||
.unwrap();
|
||||
|
||||
let now_hi = ctx
|
||||
.builder
|
||||
.build_load(now_hiptr, "now.hi")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let now_lo = ctx
|
||||
.builder
|
||||
.build_load(now_loptr, "now.lo")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
|
||||
let shifted_hi =
|
||||
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
|
||||
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
|
||||
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap()
|
||||
}
|
||||
|
||||
fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>) {
|
||||
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
|
||||
let i32_type = ctx.ctx.i32_type();
|
||||
let i64_type = ctx.ctx.i64_type();
|
||||
|
||||
let i64_32 = i64_type.const_int(32, false);
|
||||
if let BasicValueEnum::IntValue(time) = t {
|
||||
let time_hi = ctx.builder.build_int_truncate(
|
||||
ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"),
|
||||
let time = t.into_int_value();
|
||||
|
||||
let time_hi = ctx
|
||||
.builder
|
||||
.build_int_truncate(
|
||||
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
|
||||
i32_type,
|
||||
"now_trunc",
|
||||
);
|
||||
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc");
|
||||
let now = ctx
|
||||
.module
|
||||
.get_global("now")
|
||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||
let now_hiptr = ctx.builder.build_bitcast(
|
||||
now,
|
||||
i32_type.ptr_type(AddressSpace::Generic),
|
||||
"now_bitcast",
|
||||
);
|
||||
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
|
||||
let now_loptr = unsafe {
|
||||
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_gep")
|
||||
};
|
||||
ctx.builder
|
||||
.build_store(now_hiptr, time_hi)
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_loptr, time_lo)
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
} else {
|
||||
unreachable!();
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
|
||||
let now = ctx
|
||||
.module
|
||||
.get_global("now")
|
||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||
let now_hiptr = ctx
|
||||
.builder
|
||||
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
|
||||
let now_loptr = unsafe {
|
||||
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
|
||||
}
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_hiptr, time_hi)
|
||||
.unwrap()
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_loptr, time_lo)
|
||||
.unwrap()
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn emit_delay_mu<'ctx, 'a>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
dt: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
|
||||
let i64_type = ctx.ctx.i64_type();
|
||||
let i32_type = ctx.ctx.i32_type();
|
||||
let now = ctx
|
||||
.module
|
||||
.get_global("now")
|
||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||
let now_hiptr =
|
||||
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::Generic), "now_hiptr");
|
||||
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
|
||||
let now_loptr = unsafe {
|
||||
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_loptr")
|
||||
};
|
||||
if let (
|
||||
BasicValueEnum::IntValue(now_hi),
|
||||
BasicValueEnum::IntValue(now_lo),
|
||||
BasicValueEnum::IntValue(dt),
|
||||
) = (
|
||||
ctx.builder.build_load(now_hiptr, "now_hi"),
|
||||
ctx.builder.build_load(now_loptr, "now_lo"),
|
||||
dt,
|
||||
) {
|
||||
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "now_zext_hi");
|
||||
let shifted_hi = ctx.builder.build_left_shift(
|
||||
zext_hi,
|
||||
i64_type.const_int(32, false),
|
||||
"now_shifted_zext_hi",
|
||||
);
|
||||
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "now_zext_lo");
|
||||
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now_or");
|
||||
let now_hiptr = ctx
|
||||
.builder
|
||||
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
|
||||
let time = ctx.builder.build_int_add(now_val, dt, "now_add");
|
||||
let time_hi = ctx.builder.build_int_truncate(
|
||||
ctx.builder.build_right_shift(
|
||||
time,
|
||||
i64_type.const_int(32, false),
|
||||
false,
|
||||
"now_lshr",
|
||||
),
|
||||
i32_type,
|
||||
"now_trunc",
|
||||
);
|
||||
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc");
|
||||
let now_loptr = unsafe {
|
||||
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
|
||||
}
|
||||
.unwrap();
|
||||
|
||||
let now_hi = ctx
|
||||
.builder
|
||||
.build_load(now_hiptr, "now.hi")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let now_lo = ctx
|
||||
.builder
|
||||
.build_load(now_loptr, "now.lo")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let dt = dt.into_int_value();
|
||||
|
||||
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
|
||||
let shifted_hi =
|
||||
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
|
||||
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
|
||||
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap();
|
||||
|
||||
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
|
||||
let time_hi = ctx
|
||||
.builder
|
||||
.build_int_truncate(
|
||||
ctx.builder
|
||||
.build_store(now_hiptr, time_hi)
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_loptr, time_lo)
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
.build_right_shift(time, i64_type.const_int(32, false), false, "")
|
||||
.unwrap(),
|
||||
i32_type,
|
||||
"time.hi",
|
||||
)
|
||||
.unwrap();
|
||||
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_store(now_hiptr, time_hi)
|
||||
.unwrap()
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_loptr, time_lo)
|
||||
.unwrap()
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -155,68 +167,67 @@ pub static NOW_PINNING_TIME_FNS_64: NowPinningTimeFns64 = NowPinningTimeFns64 {}
|
|||
pub struct NowPinningTimeFns {}
|
||||
|
||||
impl TimeFns for NowPinningTimeFns {
|
||||
fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx> {
|
||||
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
|
||||
let i64_type = ctx.ctx.i64_type();
|
||||
let now = ctx
|
||||
.module
|
||||
.get_global("now")
|
||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now");
|
||||
if let BasicValueEnum::IntValue(now_raw) = now_raw {
|
||||
let i64_32 = i64_type.const_int(32, false);
|
||||
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl");
|
||||
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now_lshr");
|
||||
ctx.builder.build_or(now_lo, now_hi, "now_or").into()
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
let now_raw = ctx
|
||||
.builder
|
||||
.build_load(now.as_pointer_value(), "now")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
let i64_32 = i64_type.const_int(32, false);
|
||||
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
|
||||
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
|
||||
ctx.builder.build_or(now_lo, now_hi, "now_mu").map(Into::into).unwrap()
|
||||
}
|
||||
|
||||
fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>) {
|
||||
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
|
||||
let i32_type = ctx.ctx.i32_type();
|
||||
let i64_type = ctx.ctx.i64_type();
|
||||
let i64_32 = i64_type.const_int(32, false);
|
||||
if let BasicValueEnum::IntValue(time) = t {
|
||||
let time_hi = ctx.builder.build_int_truncate(
|
||||
ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"),
|
||||
|
||||
let time = t.into_int_value();
|
||||
|
||||
let time_hi = ctx
|
||||
.builder
|
||||
.build_int_truncate(
|
||||
ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(),
|
||||
i32_type,
|
||||
"now_trunc",
|
||||
);
|
||||
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc");
|
||||
let now = ctx
|
||||
.module
|
||||
.get_global("now")
|
||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||
let now_hiptr = ctx.builder.build_bitcast(
|
||||
now,
|
||||
i32_type.ptr_type(AddressSpace::Generic),
|
||||
"now_bitcast",
|
||||
);
|
||||
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
|
||||
let now_loptr = unsafe {
|
||||
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now_gep")
|
||||
};
|
||||
ctx.builder
|
||||
.build_store(now_hiptr, time_hi)
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_loptr, time_lo)
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
} else {
|
||||
unreachable!();
|
||||
"time.hi",
|
||||
)
|
||||
.unwrap();
|
||||
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc").unwrap();
|
||||
let now = ctx
|
||||
.module
|
||||
.get_global("now")
|
||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||
let now_hiptr = ctx
|
||||
.builder
|
||||
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
|
||||
let now_loptr = unsafe {
|
||||
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
|
||||
}
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_hiptr, time_hi)
|
||||
.unwrap()
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_loptr, time_lo)
|
||||
.unwrap()
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn emit_delay_mu<'ctx, 'a>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
dt: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
|
||||
let i32_type = ctx.ctx.i32_type();
|
||||
let i64_type = ctx.ctx.i64_type();
|
||||
let i64_32 = i64_type.const_int(32, false);
|
||||
|
@ -224,41 +235,47 @@ impl TimeFns for NowPinningTimeFns {
|
|||
.module
|
||||
.get_global("now")
|
||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now");
|
||||
if let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) {
|
||||
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl");
|
||||
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now_lshr");
|
||||
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_or");
|
||||
let time = ctx.builder.build_int_add(now_val, dt, "now_add");
|
||||
let time_hi = ctx.builder.build_int_truncate(
|
||||
ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"),
|
||||
let now_raw = ctx
|
||||
.builder
|
||||
.build_load(now.as_pointer_value(), "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
let dt = dt.into_int_value();
|
||||
|
||||
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
|
||||
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
|
||||
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap();
|
||||
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
|
||||
let time_hi = ctx
|
||||
.builder
|
||||
.build_int_truncate(
|
||||
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
|
||||
i32_type,
|
||||
"now_trunc",
|
||||
);
|
||||
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc");
|
||||
let now_hiptr = ctx.builder.build_bitcast(
|
||||
now,
|
||||
i32_type.ptr_type(AddressSpace::Generic),
|
||||
"now_bitcast",
|
||||
);
|
||||
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
|
||||
let now_loptr = unsafe {
|
||||
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now_gep")
|
||||
};
|
||||
ctx.builder
|
||||
.build_store(now_hiptr, time_hi)
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_loptr, time_lo)
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
} else {
|
||||
unreachable!();
|
||||
)
|
||||
.unwrap();
|
||||
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
|
||||
let now_hiptr = ctx
|
||||
.builder
|
||||
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
|
||||
let now_loptr = unsafe {
|
||||
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
|
||||
}
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_hiptr, time_hi)
|
||||
.unwrap()
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(now_loptr, time_lo)
|
||||
.unwrap()
|
||||
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -267,14 +284,18 @@ pub static NOW_PINNING_TIME_FNS: NowPinningTimeFns = NowPinningTimeFns {};
|
|||
pub struct ExternTimeFns {}
|
||||
|
||||
impl TimeFns for ExternTimeFns {
|
||||
fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx> {
|
||||
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
|
||||
let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| {
|
||||
ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)
|
||||
});
|
||||
ctx.builder.build_call(now_mu, &[], "now_mu").try_as_basic_value().left().unwrap()
|
||||
ctx.builder
|
||||
.build_call(now_mu, &[], "now_mu")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>) {
|
||||
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
|
||||
let at_mu = ctx.module.get_function("at_mu").unwrap_or_else(|| {
|
||||
ctx.module.add_function(
|
||||
"at_mu",
|
||||
|
@ -282,14 +303,10 @@ impl TimeFns for ExternTimeFns {
|
|||
None,
|
||||
)
|
||||
});
|
||||
ctx.builder.build_call(at_mu, &[t.into()], "at_mu");
|
||||
ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap();
|
||||
}
|
||||
|
||||
fn emit_delay_mu<'ctx, 'a>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
dt: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
|
||||
let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| {
|
||||
ctx.module.add_function(
|
||||
"delay_mu",
|
||||
|
@ -297,7 +314,7 @@ impl TimeFns for ExternTimeFns {
|
|||
None,
|
||||
)
|
||||
});
|
||||
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu");
|
||||
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
name = "nac3ast"
|
||||
version = "0.1.0"
|
||||
authors = ["RustPython Team", "M-Labs"]
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["constant-optimization", "fold"]
|
||||
|
@ -10,7 +10,7 @@ constant-optimization = ["fold"]
|
|||
fold = []
|
||||
|
||||
[dependencies]
|
||||
lazy_static = "1.4.0"
|
||||
parking_lot = "0.11.1"
|
||||
string-interner = "0.13.0"
|
||||
fxhash = "0.2.1"
|
||||
lazy_static = "1.4"
|
||||
parking_lot = "0.12"
|
||||
string-interner = "0.17"
|
||||
fxhash = "0.2"
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -28,12 +28,12 @@ impl From<bool> for Constant {
|
|||
}
|
||||
impl From<i32> for Constant {
|
||||
fn from(i: i32) -> Constant {
|
||||
Self::Int(i as i128)
|
||||
Self::Int(i128::from(i))
|
||||
}
|
||||
}
|
||||
impl From<i64> for Constant {
|
||||
fn from(i: i64) -> Constant {
|
||||
Self::Int(i as i128)
|
||||
Self::Int(i128::from(i))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,6 +50,7 @@ pub enum ConversionFlag {
|
|||
}
|
||||
|
||||
impl ConversionFlag {
|
||||
#[must_use]
|
||||
pub fn try_from_byte(b: u8) -> Option<Self> {
|
||||
match b {
|
||||
b's' => Some(Self::Str),
|
||||
|
@ -69,6 +70,7 @@ pub struct ConstantOptimizer {
|
|||
#[cfg(feature = "constant-optimization")]
|
||||
impl ConstantOptimizer {
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self { _priv: () }
|
||||
}
|
||||
|
@ -85,33 +87,22 @@ impl<U> crate::fold::Fold<U> for ConstantOptimizer {
|
|||
fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> {
|
||||
match node.node {
|
||||
crate::ExprKind::Tuple { elts, ctx } => {
|
||||
let elts = elts
|
||||
.into_iter()
|
||||
.map(|x| self.fold_expr(x))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let expr = if elts
|
||||
.iter()
|
||||
.all(|e| matches!(e.node, crate::ExprKind::Constant { .. }))
|
||||
{
|
||||
let tuple = elts
|
||||
.into_iter()
|
||||
.map(|e| match e.node {
|
||||
crate::ExprKind::Constant { value, .. } => value,
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect();
|
||||
crate::ExprKind::Constant {
|
||||
value: Constant::Tuple(tuple),
|
||||
kind: None,
|
||||
}
|
||||
} else {
|
||||
crate::ExprKind::Tuple { elts, ctx }
|
||||
};
|
||||
Ok(crate::Expr {
|
||||
node: expr,
|
||||
custom: node.custom,
|
||||
location: node.location,
|
||||
})
|
||||
let elts =
|
||||
elts.into_iter().map(|x| self.fold_expr(x)).collect::<Result<Vec<_>, _>>()?;
|
||||
let expr =
|
||||
if elts.iter().all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) {
|
||||
let tuple = elts
|
||||
.into_iter()
|
||||
.map(|e| match e.node {
|
||||
crate::ExprKind::Constant { value, .. } => value,
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect();
|
||||
crate::ExprKind::Constant { value: Constant::Tuple(tuple), kind: None }
|
||||
} else {
|
||||
crate::ExprKind::Tuple { elts, ctx }
|
||||
};
|
||||
Ok(crate::Expr { node: expr, custom: node.custom, location: node.location })
|
||||
}
|
||||
_ => crate::fold::fold_expr(self, node),
|
||||
}
|
||||
|
@ -138,18 +129,12 @@ mod tests {
|
|||
Located {
|
||||
location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: 1.into(),
|
||||
kind: None,
|
||||
},
|
||||
node: ExprKind::Constant { value: 1.into(), kind: None },
|
||||
},
|
||||
Located {
|
||||
location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: 2.into(),
|
||||
kind: None,
|
||||
},
|
||||
node: ExprKind::Constant { value: 2.into(), kind: None },
|
||||
},
|
||||
Located {
|
||||
location,
|
||||
|
@ -160,26 +145,17 @@ mod tests {
|
|||
Located {
|
||||
location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: 3.into(),
|
||||
kind: None,
|
||||
},
|
||||
node: ExprKind::Constant { value: 3.into(), kind: None },
|
||||
},
|
||||
Located {
|
||||
location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: 4.into(),
|
||||
kind: None,
|
||||
},
|
||||
node: ExprKind::Constant { value: 4.into(), kind: None },
|
||||
},
|
||||
Located {
|
||||
location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: 5.into(),
|
||||
kind: None,
|
||||
},
|
||||
node: ExprKind::Constant { value: 5.into(), kind: None },
|
||||
},
|
||||
],
|
||||
},
|
||||
|
@ -187,9 +163,7 @@ mod tests {
|
|||
],
|
||||
},
|
||||
};
|
||||
let new_ast = ConstantOptimizer::new()
|
||||
.fold_expr(ast)
|
||||
.unwrap_or_else(|e| match e {});
|
||||
let new_ast = ConstantOptimizer::new().fold_expr(ast).unwrap_or_else(|e| match e {});
|
||||
assert_eq!(
|
||||
new_ast,
|
||||
Located {
|
||||
|
@ -199,11 +173,7 @@ mod tests {
|
|||
value: Constant::Tuple(vec![
|
||||
1.into(),
|
||||
2.into(),
|
||||
Constant::Tuple(vec![
|
||||
3.into(),
|
||||
4.into(),
|
||||
5.into(),
|
||||
])
|
||||
Constant::Tuple(vec![3.into(), 4.into(), 5.into(),])
|
||||
]),
|
||||
kind: None
|
||||
},
|
||||
|
|
|
@ -64,11 +64,4 @@ macro_rules! simple_fold {
|
|||
};
|
||||
}
|
||||
|
||||
simple_fold!(
|
||||
usize,
|
||||
String,
|
||||
bool,
|
||||
StrRef,
|
||||
constant::Constant,
|
||||
constant::ConversionFlag
|
||||
);
|
||||
simple_fold!(usize, String, bool, StrRef, constant::Constant, constant::ConversionFlag);
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate::{Constant, ExprKind};
|
|||
|
||||
impl<U> ExprKind<U> {
|
||||
/// Returns a short name for the node suitable for use in error messages.
|
||||
#[must_use]
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => {
|
||||
|
@ -34,10 +35,7 @@ impl<U> ExprKind<U> {
|
|||
ExprKind::Starred { .. } => "starred",
|
||||
ExprKind::Slice { .. } => "slice",
|
||||
ExprKind::JoinedStr { values } => {
|
||||
if values
|
||||
.iter()
|
||||
.any(|e| matches!(e.node, ExprKind::JoinedStr { .. }))
|
||||
{
|
||||
if values.iter().any(|e| matches!(e.node, ExprKind::JoinedStr { .. })) {
|
||||
"f-string expression"
|
||||
} else {
|
||||
"literal"
|
||||
|
|
|
@ -1,3 +1,19 @@
|
|||
#![deny(
|
||||
future_incompatible,
|
||||
let_underscore,
|
||||
nonstandard_style,
|
||||
rust_2024_compatibility,
|
||||
clippy::all
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(
|
||||
clippy::missing_errors_doc,
|
||||
clippy::missing_panics_doc,
|
||||
clippy::module_name_repetitions,
|
||||
clippy::too_many_lines,
|
||||
clippy::wildcard_imports
|
||||
)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate lazy_static;
|
||||
|
||||
|
@ -9,6 +25,6 @@ mod impls;
|
|||
mod location;
|
||||
|
||||
pub use ast_gen::*;
|
||||
pub use location::{Location, FileName};
|
||||
pub use location::{FileName, Location};
|
||||
|
||||
pub type Suite<U = ()> = Vec<Stmt<U>>;
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
//! Datatypes to support source location information.
|
||||
use crate::ast_gen::StrRef;
|
||||
use std::cmp::Ordering;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub struct FileName(pub StrRef);
|
||||
impl Default for FileName {
|
||||
fn default() -> Self {
|
||||
|
@ -17,16 +18,38 @@ impl From<String> for FileName {
|
|||
}
|
||||
|
||||
/// A location somewhere in the sourcecode.
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
|
||||
pub struct Location {
|
||||
pub row: usize,
|
||||
pub column: usize,
|
||||
pub file: FileName
|
||||
pub file: FileName,
|
||||
}
|
||||
|
||||
impl fmt::Display for Location {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}: line {} column {}", self.file.0, self.row, self.column)
|
||||
write!(f, "{}:{}:{}", self.file.0, self.row, self.column)
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for Location {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string());
|
||||
if file_cmp != Ordering::Equal {
|
||||
return file_cmp;
|
||||
}
|
||||
|
||||
let row_cmp = self.row.cmp(&other.row);
|
||||
if row_cmp != Ordering::Equal {
|
||||
return row_cmp;
|
||||
}
|
||||
|
||||
self.column.cmp(&other.column)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for Location {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,23 +76,22 @@ impl Location {
|
|||
)
|
||||
}
|
||||
}
|
||||
Visualize {
|
||||
loc: *self,
|
||||
line,
|
||||
desc,
|
||||
}
|
||||
Visualize { loc: *self, line, desc }
|
||||
}
|
||||
}
|
||||
|
||||
impl Location {
|
||||
#[must_use]
|
||||
pub fn new(row: usize, column: usize, file: FileName) -> Self {
|
||||
Location { row, column, file }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn row(&self) -> usize {
|
||||
self.row
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn column(&self) -> usize {
|
||||
self.column
|
||||
}
|
||||
|
|
|
@ -2,24 +2,27 @@
|
|||
name = "nac3core"
|
||||
version = "0.1.0"
|
||||
authors = ["M-Labs"]
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
itertools = "0.10.1"
|
||||
crossbeam = "0.8.1"
|
||||
parking_lot = "0.11.1"
|
||||
rayon = "1.5.1"
|
||||
itertools = "0.13"
|
||||
crossbeam = "0.8"
|
||||
indexmap = "2.2"
|
||||
parking_lot = "0.12"
|
||||
rayon = "1.8"
|
||||
nac3parser = { path = "../nac3parser" }
|
||||
strum = "0.26.2"
|
||||
strum_macros = "0.26.4"
|
||||
|
||||
[dependencies.inkwell]
|
||||
version = "0.1.0-beta.4"
|
||||
version = "0.4"
|
||||
default-features = false
|
||||
features = ["llvm13-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
||||
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
||||
|
||||
[dev-dependencies]
|
||||
test-case = "1.2.0"
|
||||
indoc = "1.0"
|
||||
indoc = "2.0"
|
||||
insta = "=1.11.0"
|
||||
|
||||
[build-dependencies]
|
||||
regex = "1"
|
||||
regex = "1.10"
|
||||
|
|
|
@ -9,19 +9,20 @@ use std::{
|
|||
|
||||
fn main() {
|
||||
const FILE: &str = "src/codegen/irrt/irrt.c";
|
||||
println!("cargo:rerun-if-changed={}", FILE);
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
let out_path = Path::new(&out_dir);
|
||||
|
||||
/*
|
||||
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
||||
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
||||
*/
|
||||
|
||||
const FLAG: &[&str] = &[
|
||||
let flags: &[&str] = &[
|
||||
"--target=wasm32",
|
||||
FILE,
|
||||
"-O3",
|
||||
"-fno-discard-value-names",
|
||||
match env::var("PROFILE").as_deref() {
|
||||
Ok("debug") => "-O0",
|
||||
Ok("release") => "-O3",
|
||||
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
||||
},
|
||||
"-emit-llvm",
|
||||
"-S",
|
||||
"-Wall",
|
||||
|
@ -29,8 +30,13 @@ fn main() {
|
|||
"-o",
|
||||
"-",
|
||||
];
|
||||
let output = Command::new("clang")
|
||||
.args(FLAG)
|
||||
|
||||
println!("cargo:rerun-if-changed={FILE}");
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
let out_path = Path::new(&out_dir);
|
||||
|
||||
let output = Command::new("clang-irrt")
|
||||
.args(flags)
|
||||
.output()
|
||||
.map(|o| {
|
||||
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
|
||||
|
@ -42,9 +48,9 @@ fn main() {
|
|||
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
|
||||
let mut filtered_output = String::with_capacity(output.len());
|
||||
|
||||
let regex_filter = regex::Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap();
|
||||
let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap();
|
||||
for f in regex_filter.captures_iter(&output) {
|
||||
assert!(f.len() == 1);
|
||||
assert_eq!(f.len(), 1);
|
||||
filtered_output.push_str(&f[0]);
|
||||
filtered_output.push('\n');
|
||||
}
|
||||
|
@ -61,12 +67,12 @@ fn main() {
|
|||
file.write_all(filtered_output.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
let mut llvm_as = Command::new("llvm-as")
|
||||
let mut llvm_as = Command::new("llvm-as-irrt")
|
||||
.stdin(Stdio::piped())
|
||||
.arg("-o")
|
||||
.arg(out_path.join("irrt.bc"))
|
||||
.spawn()
|
||||
.unwrap();
|
||||
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
|
||||
assert!(llvm_as.wait().unwrap().success())
|
||||
assert!(llvm_as.wait().unwrap().success());
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -3,10 +3,13 @@ use crate::{
|
|||
toplevel::DefinitionId,
|
||||
typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||
typedef::{
|
||||
into_var_map, FunSignature, FuncArg, Type, TypeEnum, TypeVar, TypeVarId, Unifier,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use nac3parser::ast::StrRef;
|
||||
use std::collections::HashMap;
|
||||
|
||||
|
@ -50,7 +53,7 @@ pub enum ConcreteTypeEnum {
|
|||
TObj {
|
||||
obj_id: DefinitionId,
|
||||
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
||||
params: HashMap<u32, ConcreteType>,
|
||||
params: IndexMap<TypeVarId, ConcreteType>,
|
||||
},
|
||||
TVirtual {
|
||||
ty: ConcreteType,
|
||||
|
@ -58,11 +61,15 @@ pub enum ConcreteTypeEnum {
|
|||
TFunc {
|
||||
args: Vec<ConcreteFuncArg>,
|
||||
ret: ConcreteType,
|
||||
vars: HashMap<u32, ConcreteType>,
|
||||
vars: HashMap<TypeVarId, ConcreteType>,
|
||||
},
|
||||
TLiteral {
|
||||
values: Vec<SymbolValue>,
|
||||
},
|
||||
}
|
||||
|
||||
impl ConcreteTypeStore {
|
||||
#[must_use]
|
||||
pub fn new() -> ConcreteTypeStore {
|
||||
ConcreteTypeStore {
|
||||
store: vec![
|
||||
|
@ -80,6 +87,7 @@ impl ConcreteTypeStore {
|
|||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get(&self, cty: ConcreteType) -> &ConcreteTypeEnum {
|
||||
&self.store[cty.0]
|
||||
}
|
||||
|
@ -194,9 +202,12 @@ impl ConcreteTypeStore {
|
|||
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
||||
},
|
||||
TypeEnum::TFunc(signature) => {
|
||||
self.from_signature(unifier, primitives, &*signature, cache)
|
||||
self.from_signature(unifier, primitives, signature, cache)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
ConcreteTypeEnum::TLiteral { values: values.clone() }
|
||||
}
|
||||
_ => unreachable!("{:?}", ty_enum.get_type_name()),
|
||||
};
|
||||
let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() {
|
||||
self.store[*index] = result;
|
||||
|
@ -221,7 +232,7 @@ impl ConcreteTypeStore {
|
|||
return if let Some(ty) = ty {
|
||||
*ty
|
||||
} else {
|
||||
*ty = Some(unifier.get_dummy_var().0);
|
||||
*ty = Some(unifier.get_dummy_var().ty);
|
||||
ty.unwrap()
|
||||
};
|
||||
}
|
||||
|
@ -263,10 +274,10 @@ impl ConcreteTypeStore {
|
|||
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
|
||||
})
|
||||
.collect::<HashMap<_, _>>(),
|
||||
params: params
|
||||
.iter()
|
||||
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
|
||||
.collect::<HashMap<_, _>>(),
|
||||
params: into_var_map(params.iter().map(|(&id, cty)| {
|
||||
let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
|
||||
TypeVar { id, ty }
|
||||
})),
|
||||
},
|
||||
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
|
||||
args: args
|
||||
|
@ -278,11 +289,14 @@ impl ConcreteTypeStore {
|
|||
})
|
||||
.collect(),
|
||||
ret: self.to_unifier_type(unifier, primitives, *ret, cache),
|
||||
vars: vars
|
||||
.iter()
|
||||
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
|
||||
.collect::<HashMap<_, _>>(),
|
||||
vars: into_var_map(vars.iter().map(|(&id, cty)| {
|
||||
let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
|
||||
TypeVar { id, ty }
|
||||
})),
|
||||
}),
|
||||
ConcreteTypeEnum::TLiteral { values, .. } => {
|
||||
TypeEnum::TLiteral { values: values.clone(), loc: None }
|
||||
}
|
||||
};
|
||||
let result = unifier.add_ty(result);
|
||||
if let Some(ty) = cache.get(&cty).unwrap() {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,613 @@
|
|||
use inkwell::attributes::{Attribute, AttributeLoc};
|
||||
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
|
||||
use itertools::Either;
|
||||
|
||||
use crate::codegen::CodeGenContext;
|
||||
|
||||
/// Invokes the [`tan`](https://en.cppreference.com/w/c/numeric/math/tan) function.
|
||||
pub fn call_tan<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "tan";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`asin`](https://en.cppreference.com/w/c/numeric/math/asin) function.
|
||||
pub fn call_asin<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "asin";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`acos`](https://en.cppreference.com/w/c/numeric/math/acos) function.
|
||||
pub fn call_acos<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "acos";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`atan`](https://en.cppreference.com/w/c/numeric/math/atan) function.
|
||||
pub fn call_atan<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "atan";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`sinh`](https://en.cppreference.com/w/c/numeric/math/sinh) function.
|
||||
pub fn call_sinh<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "sinh";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`cosh`](https://en.cppreference.com/w/c/numeric/math/cosh) function.
|
||||
pub fn call_cosh<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "cosh";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`tanh`](https://en.cppreference.com/w/c/numeric/math/tanh) function.
|
||||
pub fn call_tanh<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "tanh";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`asinh`](https://en.cppreference.com/w/c/numeric/math/asinh) function.
|
||||
pub fn call_asinh<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "asinh";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`acosh`](https://en.cppreference.com/w/c/numeric/math/acosh) function.
|
||||
pub fn call_acosh<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "acosh";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`atanh`](https://en.cppreference.com/w/c/numeric/math/atanh) function.
|
||||
pub fn call_atanh<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "atanh";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`expm1`](https://en.cppreference.com/w/c/numeric/math/expm1) function.
|
||||
pub fn call_expm1<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "expm1";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`cbrt`](https://en.cppreference.com/w/c/numeric/math/cbrt) function.
|
||||
pub fn call_cbrt<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "cbrt";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nosync", "nounwind", "readonly", "willreturn"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`erf`](https://en.cppreference.com/w/c/numeric/math/erf) function.
|
||||
pub fn call_erf<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "erf";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
|
||||
);
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`erfc`](https://en.cppreference.com/w/c/numeric/math/erfc) function.
|
||||
pub fn call_erfc<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "erfc";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
|
||||
);
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`j1`](https://www.gnu.org/software/libc/manual/html_node/Special-Functions.html#index-j1)
|
||||
/// function.
|
||||
pub fn call_j1<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "j1";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
|
||||
);
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`atan2`](https://en.cppreference.com/w/c/numeric/math/atan2) function.
|
||||
pub fn call_atan2<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
y: FloatValue<'ctx>,
|
||||
x: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "atan2";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(y.get_type(), llvm_f64);
|
||||
debug_assert_eq!(x.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[y.into(), x.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
|
||||
pub fn call_ldexp<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
arg: FloatValue<'ctx>,
|
||||
exp: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "ldexp";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
debug_assert_eq!(arg.get_type(), llvm_f64);
|
||||
debug_assert_eq!(exp.get_type(), llvm_i32);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`hypot`](https://en.cppreference.com/w/c/numeric/math/hypot) function.
|
||||
pub fn call_hypot<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
x: FloatValue<'ctx>,
|
||||
y: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "hypot";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(x.get_type(), llvm_f64);
|
||||
debug_assert_eq!(y.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
|
||||
);
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[x.into(), y.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`nextafter`](https://en.cppreference.com/w/c/numeric/math/nextafter) function.
|
||||
pub fn call_nextafter<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
from: FloatValue<'ctx>,
|
||||
to: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "nextafter";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
debug_assert_eq!(from.get_type(), llvm_f64);
|
||||
debug_assert_eq!(to.get_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
|
||||
);
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[from.into(), to.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
codegen::{expr::*, stmt::*, CodeGenContext},
|
||||
codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{DefinitionId, TopLevelDef},
|
||||
typecheck::typedef::{FunSignature, Type},
|
||||
|
@ -7,7 +7,7 @@ use crate::{
|
|||
use inkwell::{
|
||||
context::Context,
|
||||
types::{BasicTypeEnum, IntType},
|
||||
values::{BasicValueEnum, PointerValue},
|
||||
values::{BasicValueEnum, IntValue, PointerValue},
|
||||
};
|
||||
use nac3parser::ast::{Expr, Stmt, StrRef};
|
||||
|
||||
|
@ -22,9 +22,9 @@ pub trait CodeGenerator {
|
|||
/// - fun: Function signature and definition ID.
|
||||
/// - params: Function parameters. Note that this does not include the object even if the
|
||||
/// function is a class method.
|
||||
fn gen_call<'ctx, 'a>(
|
||||
fn gen_call<'ctx>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
|
@ -39,9 +39,9 @@ pub trait CodeGenerator {
|
|||
/// - signature: Function signature of the constructor.
|
||||
/// - def: Class definition for the constructor class.
|
||||
/// - params: Function parameters.
|
||||
fn gen_constructor<'ctx, 'a>(
|
||||
fn gen_constructor<'ctx>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
signature: &FunSignature,
|
||||
def: &TopLevelDef,
|
||||
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
|
@ -59,20 +59,20 @@ pub trait CodeGenerator {
|
|||
/// function is a class method.
|
||||
/// Note that this function should check if the function is generated in another thread (due to
|
||||
/// possible race condition), see the default implementation for an example.
|
||||
fn gen_func_instance<'ctx, 'a>(
|
||||
fn gen_func_instance<'ctx>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, &mut TopLevelDef, String),
|
||||
id: usize,
|
||||
) -> Result<String, String> {
|
||||
gen_func_instance(ctx, obj, fun, id)
|
||||
gen_func_instance(ctx, &obj, fun, id)
|
||||
}
|
||||
|
||||
/// Generate the code for an expression.
|
||||
fn gen_expr<'ctx, 'a>(
|
||||
fn gen_expr<'ctx>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
expr: &Expr<Option<Type>>,
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String>
|
||||
where
|
||||
|
@ -83,30 +83,44 @@ pub trait CodeGenerator {
|
|||
|
||||
/// Allocate memory for a variable and return a pointer pointing to it.
|
||||
/// The default implementation places the allocations at the start of the function.
|
||||
fn gen_var_alloc<'ctx, 'a>(
|
||||
fn gen_var_alloc<'ctx>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: BasicTypeEnum<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
gen_var(ctx, ty)
|
||||
gen_var(ctx, ty, name)
|
||||
}
|
||||
|
||||
/// Allocate memory for a variable and return a pointer pointing to it.
|
||||
/// The default implementation places the allocations at the start of the function.
|
||||
fn gen_array_var_alloc<'ctx>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: BasicTypeEnum<'ctx>,
|
||||
size: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Result<ArraySliceValue<'ctx>, String> {
|
||||
gen_array_var(ctx, ty, size, name)
|
||||
}
|
||||
|
||||
/// Return a pointer pointing to the target of the expression.
|
||||
fn gen_store_target<'ctx, 'a>(
|
||||
fn gen_store_target<'ctx>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
pattern: &Expr<Option<Type>>,
|
||||
) -> Result<PointerValue<'ctx>, String>
|
||||
name: Option<&str>,
|
||||
) -> Result<Option<PointerValue<'ctx>>, String>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
gen_store_target(self, ctx, pattern)
|
||||
gen_store_target(self, ctx, pattern, name)
|
||||
}
|
||||
|
||||
/// Generate code for an assignment expression.
|
||||
fn gen_assign<'ctx, 'a>(
|
||||
fn gen_assign<'ctx>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
target: &Expr<Option<Type>>,
|
||||
value: ValueEnum<'ctx>,
|
||||
) -> Result<(), String>
|
||||
|
@ -118,9 +132,9 @@ pub trait CodeGenerator {
|
|||
|
||||
/// Generate code for a while expression.
|
||||
/// Return true if the while loop must early return
|
||||
fn gen_while<'ctx, 'a>(
|
||||
fn gen_while(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
|
@ -129,11 +143,11 @@ pub trait CodeGenerator {
|
|||
gen_while(self, ctx, stmt)
|
||||
}
|
||||
|
||||
/// Generate code for a while expression.
|
||||
/// Return true if the while loop must early return
|
||||
fn gen_for<'ctx, 'a>(
|
||||
/// Generate code for a for expression.
|
||||
/// Return true if the for loop must early return
|
||||
fn gen_for(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
|
@ -144,9 +158,9 @@ pub trait CodeGenerator {
|
|||
|
||||
/// Generate code for an if expression.
|
||||
/// Return true if the statement must early return
|
||||
fn gen_if<'ctx, 'a>(
|
||||
fn gen_if(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
|
@ -155,9 +169,9 @@ pub trait CodeGenerator {
|
|||
gen_if(self, ctx, stmt)
|
||||
}
|
||||
|
||||
fn gen_with<'ctx, 'a>(
|
||||
fn gen_with(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
|
@ -167,10 +181,11 @@ pub trait CodeGenerator {
|
|||
}
|
||||
|
||||
/// Generate code for a statement
|
||||
///
|
||||
/// Return true if the statement must early return
|
||||
fn gen_stmt<'ctx, 'a>(
|
||||
fn gen_stmt(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
|
@ -178,6 +193,36 @@ pub trait CodeGenerator {
|
|||
{
|
||||
gen_stmt(self, ctx, stmt)
|
||||
}
|
||||
|
||||
/// Generates code for a block statement.
|
||||
fn gen_block<'a, I: Iterator<Item = &'a Stmt<Option<Type>>>>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
stmts: I,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
gen_block(self, ctx, stmts)
|
||||
}
|
||||
|
||||
/// See [`bool_to_i1`].
|
||||
fn bool_to_i1<'ctx>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
bool_value: IntValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
bool_to_i1(&ctx.builder, bool_value)
|
||||
}
|
||||
|
||||
/// See [`bool_to_i8`].
|
||||
fn bool_to_i8<'ctx>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
bool_value: IntValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
bool_to_i8(&ctx.builder, ctx.ctx, bool_value)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DefaultCodeGenerator {
|
||||
|
@ -186,17 +231,20 @@ pub struct DefaultCodeGenerator {
|
|||
}
|
||||
|
||||
impl DefaultCodeGenerator {
|
||||
#[must_use]
|
||||
pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator {
|
||||
assert!(size_t == 32 || size_t == 64);
|
||||
assert!(matches!(size_t, 32 | 64));
|
||||
DefaultCodeGenerator { name, size_t }
|
||||
}
|
||||
}
|
||||
|
||||
impl CodeGenerator for DefaultCodeGenerator {
|
||||
/// Returns the name for this [`CodeGenerator`].
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
/// Returns an LLVM integer type representing `size_t`.
|
||||
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
|
||||
// it should be unsigned, but we don't really need unsigned and this could save us from
|
||||
// having to do a bit cast...
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
typedef _ExtInt(8) int8_t;
|
||||
typedef unsigned _ExtInt(8) uint8_t;
|
||||
typedef _ExtInt(32) int32_t;
|
||||
typedef unsigned _ExtInt(32) uint32_t;
|
||||
typedef _ExtInt(64) int64_t;
|
||||
typedef unsigned _ExtInt(64) uint64_t;
|
||||
typedef _BitInt(8) int8_t;
|
||||
typedef unsigned _BitInt(8) uint8_t;
|
||||
typedef _BitInt(32) int32_t;
|
||||
typedef unsigned _BitInt(32) uint32_t;
|
||||
typedef _BitInt(64) int64_t;
|
||||
typedef unsigned _BitInt(64) uint64_t;
|
||||
|
||||
# define MAX(a, b) (a > b ? a : b)
|
||||
# define MIN(a, b) (a > b ? b : a)
|
||||
|
||||
# define NULL ((void *) 0)
|
||||
|
||||
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||
// need to make sure `exp >= 0` before calling this function
|
||||
#define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
|
||||
|
@ -138,3 +140,250 @@ int32_t __nac3_list_slice_assign_var_size(
|
|||
}
|
||||
return dest_arr_len;
|
||||
}
|
||||
|
||||
int32_t __nac3_isinf(double x) {
|
||||
return __builtin_isinf(x);
|
||||
}
|
||||
|
||||
int32_t __nac3_isnan(double x) {
|
||||
return __builtin_isnan(x);
|
||||
}
|
||||
|
||||
double tgamma(double arg);
|
||||
|
||||
double __nac3_gamma(double z) {
|
||||
// Handling for denormals
|
||||
// | x | Python gamma(x) | C tgamma(x) |
|
||||
// --- | ----------------- | --------------- | ----------- |
|
||||
// (1) | nan | nan | nan |
|
||||
// (2) | -inf | -inf | inf |
|
||||
// (3) | inf | inf | inf |
|
||||
// (4) | 0.0 | inf | inf |
|
||||
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
||||
|
||||
// (1)-(3)
|
||||
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
||||
return z;
|
||||
}
|
||||
|
||||
double v = tgamma(z);
|
||||
|
||||
// (4)-(5)
|
||||
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
||||
}
|
||||
|
||||
double lgamma(double arg);
|
||||
|
||||
double __nac3_gammaln(double x) {
|
||||
// libm's handling of value overflows differs from scipy:
|
||||
// - scipy: gammaln(-inf) -> -inf
|
||||
// - libm : lgamma(-inf) -> inf
|
||||
|
||||
if (__builtin_isinf(x)) {
|
||||
return x;
|
||||
}
|
||||
|
||||
return lgamma(x);
|
||||
}
|
||||
|
||||
double j0(double x);
|
||||
|
||||
double __nac3_j0(double x) {
|
||||
// libm's handling of value overflows differs from scipy:
|
||||
// - scipy: j0(inf) -> nan
|
||||
// - libm : j0(inf) -> 0.0
|
||||
|
||||
if (__builtin_isinf(x)) {
|
||||
return __builtin_nan("");
|
||||
}
|
||||
|
||||
return j0(x);
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_calc_size(
|
||||
const uint64_t *list_data,
|
||||
uint32_t list_len,
|
||||
uint32_t begin_idx,
|
||||
uint32_t end_idx
|
||||
) {
|
||||
__builtin_assume(end_idx <= list_len);
|
||||
|
||||
uint32_t num_elems = 1;
|
||||
for (uint32_t i = begin_idx; i < end_idx; ++i) {
|
||||
uint64_t val = list_data[i];
|
||||
__builtin_assume(val > 0);
|
||||
num_elems *= val;
|
||||
}
|
||||
return num_elems;
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_calc_size64(
|
||||
const uint64_t *list_data,
|
||||
uint64_t list_len,
|
||||
uint64_t begin_idx,
|
||||
uint64_t end_idx
|
||||
) {
|
||||
__builtin_assume(end_idx <= list_len);
|
||||
|
||||
uint64_t num_elems = 1;
|
||||
for (uint64_t i = begin_idx; i < end_idx; ++i) {
|
||||
uint64_t val = list_data[i];
|
||||
__builtin_assume(val > 0);
|
||||
num_elems *= val;
|
||||
}
|
||||
return num_elems;
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices(
|
||||
uint32_t index,
|
||||
const uint32_t* dims,
|
||||
uint32_t num_dims,
|
||||
uint32_t* idxs
|
||||
) {
|
||||
uint32_t stride = 1;
|
||||
for (uint32_t dim = 0; dim < num_dims; dim++) {
|
||||
uint32_t i = num_dims - dim - 1;
|
||||
__builtin_assume(dims[i] > 0);
|
||||
idxs[i] = (index / stride) % dims[i];
|
||||
stride *= dims[i];
|
||||
}
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices64(
|
||||
uint64_t index,
|
||||
const uint64_t* dims,
|
||||
uint64_t num_dims,
|
||||
uint32_t* idxs
|
||||
) {
|
||||
uint64_t stride = 1;
|
||||
for (uint64_t dim = 0; dim < num_dims; dim++) {
|
||||
uint64_t i = num_dims - dim - 1;
|
||||
__builtin_assume(dims[i] > 0);
|
||||
idxs[i] = (uint32_t) ((index / stride) % dims[i]);
|
||||
stride *= dims[i];
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_flatten_index(
|
||||
const uint32_t* dims,
|
||||
uint32_t num_dims,
|
||||
const uint32_t* indices,
|
||||
uint32_t num_indices
|
||||
) {
|
||||
uint32_t idx = 0;
|
||||
uint32_t stride = 1;
|
||||
for (uint32_t i = 0; i < num_dims; ++i) {
|
||||
uint32_t ri = num_dims - i - 1;
|
||||
if (ri < num_indices) {
|
||||
idx += (stride * indices[ri]);
|
||||
}
|
||||
|
||||
__builtin_assume(dims[i] > 0);
|
||||
stride *= dims[ri];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_flatten_index64(
|
||||
const uint64_t* dims,
|
||||
uint64_t num_dims,
|
||||
const uint32_t* indices,
|
||||
uint64_t num_indices
|
||||
) {
|
||||
uint64_t idx = 0;
|
||||
uint64_t stride = 1;
|
||||
for (uint64_t i = 0; i < num_dims; ++i) {
|
||||
uint64_t ri = num_dims - i - 1;
|
||||
if (ri < num_indices) {
|
||||
idx += (stride * indices[ri]);
|
||||
}
|
||||
|
||||
__builtin_assume(dims[i] > 0);
|
||||
stride *= dims[ri];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast(
|
||||
const uint32_t *lhs_dims,
|
||||
uint32_t lhs_ndims,
|
||||
const uint32_t *rhs_dims,
|
||||
uint32_t rhs_ndims,
|
||||
uint32_t *out_dims
|
||||
) {
|
||||
uint32_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||
|
||||
for (uint32_t i = 0; i < max_ndims; ++i) {
|
||||
uint32_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
|
||||
uint32_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
|
||||
uint32_t *out_dim = &out_dims[max_ndims - i - 1];
|
||||
|
||||
if (lhs_dim_sz == NULL) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (rhs_dim_sz == NULL) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == 1) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (*rhs_dim_sz == 1) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast64(
|
||||
const uint64_t *lhs_dims,
|
||||
uint64_t lhs_ndims,
|
||||
const uint64_t *rhs_dims,
|
||||
uint64_t rhs_ndims,
|
||||
uint64_t *out_dims
|
||||
) {
|
||||
uint64_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||
|
||||
for (uint64_t i = 0; i < max_ndims; ++i) {
|
||||
uint64_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
|
||||
uint64_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
|
||||
uint64_t *out_dim = &out_dims[max_ndims - i - 1];
|
||||
|
||||
if (lhs_dim_sz == NULL) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (rhs_dim_sz == NULL) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == 1) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (*rhs_dim_sz == 1) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx(
|
||||
const uint32_t *src_dims,
|
||||
uint32_t src_ndims,
|
||||
const uint32_t *in_idx,
|
||||
uint32_t *out_idx
|
||||
) {
|
||||
for (uint32_t i = 0; i < src_ndims; ++i) {
|
||||
uint32_t src_i = src_ndims - i - 1;
|
||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx64(
|
||||
const uint64_t *src_dims,
|
||||
uint64_t src_ndims,
|
||||
const uint32_t *in_idx,
|
||||
uint32_t *out_idx
|
||||
) {
|
||||
for (uint64_t i = 0; i < src_ndims; ++i) {
|
||||
uint64_t src_i = src_ndims - i - 1;
|
||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : (uint32_t) in_idx[src_i];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,17 +1,27 @@
|
|||
use crate::typecheck::typedef::Type;
|
||||
|
||||
use super::{CodeGenContext, CodeGenerator};
|
||||
use super::{
|
||||
classes::{
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
|
||||
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
||||
},
|
||||
llvm_intrinsics, CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::codegen::classes::TypedArrayLikeAccessor;
|
||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
context::Context,
|
||||
memory_buffer::MemoryBuffer,
|
||||
module::Module,
|
||||
types::BasicTypeEnum,
|
||||
values::{IntValue, PointerValue},
|
||||
types::{BasicTypeEnum, IntType},
|
||||
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
use itertools::Either;
|
||||
use nac3parser::ast::Expr;
|
||||
|
||||
#[must_use]
|
||||
pub fn load_irrt(ctx: &Context) -> Module {
|
||||
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
||||
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
|
||||
|
@ -33,9 +43,9 @@ pub fn load_irrt(ctx: &Context) -> Module {
|
|||
|
||||
// repeated squaring method adapted from GNU Scientific Library:
|
||||
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||
pub fn integer_power<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
base: IntValue<'ctx>,
|
||||
exp: IntValue<'ctx>,
|
||||
signed: bool,
|
||||
|
@ -53,12 +63,15 @@ pub fn integer_power<'ctx, 'a>(
|
|||
ctx.module.add_function(symbol, fn_type, None)
|
||||
});
|
||||
// throw exception when exp < 0
|
||||
let ge_zero = ctx.builder.build_int_compare(
|
||||
IntPredicate::SGE,
|
||||
exp,
|
||||
exp.get_type().const_zero(),
|
||||
"assert_int_pow_ge_0",
|
||||
);
|
||||
let ge_zero = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SGE,
|
||||
exp,
|
||||
exp.get_type().const_zero(),
|
||||
"assert_int_pow_ge_0",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ge_zero,
|
||||
|
@ -69,14 +82,15 @@ pub fn integer_power<'ctx, 'a>(
|
|||
);
|
||||
ctx.builder
|
||||
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
|
||||
.try_as_basic_value()
|
||||
.unwrap_left()
|
||||
.into_int_value()
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn calculate_len_for_slice_range<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
start: IntValue<'ctx>,
|
||||
end: IntValue<'ctx>,
|
||||
step: IntValue<'ctx>,
|
||||
|
@ -89,12 +103,10 @@ pub fn calculate_len_for_slice_range<'ctx, 'a>(
|
|||
});
|
||||
|
||||
// assert step != 0, throw exception if not
|
||||
let not_zero = ctx.builder.build_int_compare(
|
||||
IntPredicate::NE,
|
||||
step,
|
||||
step.get_type().const_zero(),
|
||||
"range_step_ne",
|
||||
);
|
||||
let not_zero = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
|
||||
.unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
not_zero,
|
||||
|
@ -105,10 +117,10 @@ pub fn calculate_len_for_slice_range<'ctx, 'a>(
|
|||
);
|
||||
ctx.builder
|
||||
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
|
||||
.try_as_basic_value()
|
||||
.left()
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
.into_int_value()
|
||||
}
|
||||
|
||||
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
||||
|
@ -151,47 +163,57 @@ pub fn calculate_len_for_slice_range<'ctx, 'a>(
|
|||
/// ,step
|
||||
/// )
|
||||
/// ```
|
||||
pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>(
|
||||
pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
|
||||
start: &Option<Box<Expr<Option<Type>>>>,
|
||||
end: &Option<Box<Expr<Option<Type>>>>,
|
||||
step: &Option<Box<Expr<Option<Type>>>>,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
list: PointerValue<'ctx>,
|
||||
) -> Result<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), String> {
|
||||
length: IntValue<'ctx>,
|
||||
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let zero = int32.const_zero();
|
||||
let one = int32.const_int(1, false);
|
||||
let length = ctx.build_gep_and_load(list, &[zero, one]).into_int_value();
|
||||
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32");
|
||||
Ok(match (start, end, step) {
|
||||
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
|
||||
Ok(Some(match (start, end, step) {
|
||||
(s, e, None) => (
|
||||
s.as_ref().map_or_else(
|
||||
|| Ok(int32.const_zero()),
|
||||
|s| handle_slice_index_bound(s, ctx, generator, length),
|
||||
)?,
|
||||
if let Some(s) = s.as_ref() {
|
||||
match handle_slice_index_bound(s, ctx, generator, length)? {
|
||||
Some(v) => v,
|
||||
None => return Ok(None),
|
||||
}
|
||||
} else {
|
||||
int32.const_zero()
|
||||
},
|
||||
{
|
||||
let e = e.as_ref().map_or_else(
|
||||
|| Ok(length),
|
||||
|e| handle_slice_index_bound(e, ctx, generator, length),
|
||||
)?;
|
||||
ctx.builder.build_int_sub(e, one, "final_end")
|
||||
let e = if let Some(s) = e.as_ref() {
|
||||
match handle_slice_index_bound(s, ctx, generator, length)? {
|
||||
Some(v) => v,
|
||||
None => return Ok(None),
|
||||
}
|
||||
} else {
|
||||
length
|
||||
};
|
||||
ctx.builder.build_int_sub(e, one, "final_end").unwrap()
|
||||
},
|
||||
one,
|
||||
),
|
||||
(s, e, Some(step)) => {
|
||||
let step = generator
|
||||
.gen_expr(ctx, step)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?
|
||||
.into_int_value();
|
||||
let step = if let Some(v) = generator.gen_expr(ctx, step)? {
|
||||
v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
// assert step != 0, throw exception if not
|
||||
let not_zero = ctx.builder.build_int_compare(
|
||||
IntPredicate::NE,
|
||||
step,
|
||||
step.get_type().const_zero(),
|
||||
"range_step_ne",
|
||||
);
|
||||
let not_zero = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::NE,
|
||||
step,
|
||||
step.get_type().const_zero(),
|
||||
"range_step_ne",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
not_zero,
|
||||
|
@ -200,60 +222,81 @@ pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>(
|
|||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1");
|
||||
let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg");
|
||||
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
|
||||
let neg = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")
|
||||
.unwrap();
|
||||
(
|
||||
match s {
|
||||
Some(s) => {
|
||||
let s = handle_slice_index_bound(s, ctx, generator, length)?;
|
||||
let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
ctx.builder
|
||||
.build_select(
|
||||
ctx.builder.build_and(
|
||||
ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
s,
|
||||
length,
|
||||
"s_eq_len",
|
||||
),
|
||||
neg,
|
||||
"should_minus_one",
|
||||
),
|
||||
ctx.builder.build_int_sub(s, one, "s_min"),
|
||||
ctx.builder
|
||||
.build_and(
|
||||
ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
s,
|
||||
length,
|
||||
"s_eq_len",
|
||||
)
|
||||
.unwrap(),
|
||||
neg,
|
||||
"should_minus_one",
|
||||
)
|
||||
.unwrap(),
|
||||
ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
|
||||
s,
|
||||
"final_start",
|
||||
)
|
||||
.into_int_value()
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap()
|
||||
}
|
||||
None => ctx.builder.build_select(neg, len_id, zero, "stt").into_int_value(),
|
||||
None => ctx
|
||||
.builder
|
||||
.build_select(neg, len_id, zero, "stt")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap(),
|
||||
},
|
||||
match e {
|
||||
Some(e) => {
|
||||
let e = handle_slice_index_bound(e, ctx, generator, length)?;
|
||||
let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
ctx.builder
|
||||
.build_select(
|
||||
neg,
|
||||
ctx.builder.build_int_add(e, one, "end_add_one"),
|
||||
ctx.builder.build_int_sub(e, one, "end_sub_one"),
|
||||
ctx.builder.build_int_add(e, one, "end_add_one").unwrap(),
|
||||
ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(),
|
||||
"final_end",
|
||||
)
|
||||
.into_int_value()
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap()
|
||||
}
|
||||
None => ctx.builder.build_select(neg, zero, len_id, "end").into_int_value(),
|
||||
None => ctx
|
||||
.builder
|
||||
.build_select(neg, zero, len_id, "end")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap(),
|
||||
},
|
||||
step,
|
||||
)
|
||||
}
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
/// this function allows index out of range, since python
|
||||
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
|
||||
pub fn handle_slice_index_bound<'a, 'ctx, G: CodeGenerator>(
|
||||
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
|
||||
i: &Expr<Option<Type>>,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
length: IntValue<'ctx>,
|
||||
) -> Result<IntValue<'ctx>, String> {
|
||||
) -> Result<Option<IntValue<'ctx>>, String> {
|
||||
const SYMBOL: &str = "__nac3_slice_index_bound";
|
||||
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
||||
let i32_t = ctx.ctx.i32_type();
|
||||
|
@ -261,30 +304,35 @@ pub fn handle_slice_index_bound<'a, 'ctx, G: CodeGenerator>(
|
|||
ctx.module.add_function(SYMBOL, fn_t, None)
|
||||
});
|
||||
|
||||
let i = generator.gen_expr(ctx, i)?.unwrap().to_basic_value_enum(ctx, generator, i.custom.unwrap())?;
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_call(func, &[i.into(), length.into()], "bounded_ind")
|
||||
.try_as_basic_value()
|
||||
.left()
|
||||
.unwrap()
|
||||
.into_int_value())
|
||||
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
|
||||
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
Ok(Some(
|
||||
ctx.builder
|
||||
.build_call(func, &[i.into(), length.into()], "bounded_ind")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap(),
|
||||
))
|
||||
}
|
||||
|
||||
/// This function handles 'end' **inclusively**.
|
||||
/// Order of tuples assign_idx and value_idx is ('start', 'end', 'step').
|
||||
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
|
||||
/// Negative index should be handled before entering this function
|
||||
pub fn list_slice_assignment<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: BasicTypeEnum<'ctx>,
|
||||
dest_arr: PointerValue<'ctx>,
|
||||
dest_arr: ListValue<'ctx>,
|
||||
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||
src_arr: PointerValue<'ctx>,
|
||||
src_arr: ListValue<'ctx>,
|
||||
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||
) {
|
||||
let size_ty = generator.get_size_type(ctx.ctx);
|
||||
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::Generic);
|
||||
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
|
||||
let slice_assign_fun = {
|
||||
|
@ -309,76 +357,63 @@ pub fn list_slice_assignment<'ctx, 'a>(
|
|||
|
||||
let zero = int32.const_zero();
|
||||
let one = int32.const_int(1, false);
|
||||
let dest_arr_ptr = ctx.build_gep_and_load(dest_arr, &[zero, zero]);
|
||||
let dest_arr_ptr = ctx.builder.build_pointer_cast(
|
||||
dest_arr_ptr.into_pointer_value(),
|
||||
elem_ptr_type,
|
||||
"dest_arr_ptr_cast",
|
||||
);
|
||||
let dest_len = ctx.build_gep_and_load(dest_arr, &[zero, one]).into_int_value();
|
||||
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32");
|
||||
let src_arr_ptr = ctx.build_gep_and_load(src_arr, &[zero, zero]);
|
||||
let src_arr_ptr = ctx.builder.build_pointer_cast(
|
||||
src_arr_ptr.into_pointer_value(),
|
||||
elem_ptr_type,
|
||||
"src_arr_ptr_cast",
|
||||
);
|
||||
let src_len = ctx.build_gep_and_load(src_arr, &[zero, one]).into_int_value();
|
||||
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32");
|
||||
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
|
||||
let dest_arr_ptr =
|
||||
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
|
||||
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
|
||||
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
|
||||
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
|
||||
let src_arr_ptr =
|
||||
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
|
||||
let src_len = src_arr.load_size(ctx, Some("src.len"));
|
||||
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
|
||||
|
||||
// index in bound and positive should be done
|
||||
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
|
||||
// throw exception if not satisfied
|
||||
let src_end = ctx.builder
|
||||
let src_end = ctx
|
||||
.builder
|
||||
.build_select(
|
||||
ctx.builder.build_int_compare(
|
||||
inkwell::IntPredicate::SLT,
|
||||
src_idx.2,
|
||||
zero,
|
||||
"is_neg",
|
||||
),
|
||||
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one"),
|
||||
ctx.builder.build_int_add(src_idx.1, one, "e_add_one"),
|
||||
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
|
||||
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
|
||||
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
|
||||
"final_e",
|
||||
)
|
||||
.into_int_value();
|
||||
let dest_end = ctx.builder
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let dest_end = ctx
|
||||
.builder
|
||||
.build_select(
|
||||
ctx.builder.build_int_compare(
|
||||
inkwell::IntPredicate::SLT,
|
||||
dest_idx.2,
|
||||
zero,
|
||||
"is_neg",
|
||||
),
|
||||
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one"),
|
||||
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one"),
|
||||
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
|
||||
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
|
||||
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
|
||||
"final_e",
|
||||
)
|
||||
.into_int_value();
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let src_slice_len =
|
||||
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
|
||||
let dest_slice_len =
|
||||
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
|
||||
let src_eq_dest = ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
src_slice_len,
|
||||
dest_slice_len,
|
||||
"slice_src_eq_dest",
|
||||
);
|
||||
let src_slt_dest = ctx.builder.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
src_slice_len,
|
||||
dest_slice_len,
|
||||
"slice_src_slt_dest",
|
||||
);
|
||||
let dest_step_eq_one = ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
dest_idx.2,
|
||||
dest_idx.2.get_type().const_int(1, false),
|
||||
"slice_dest_step_eq_one",
|
||||
);
|
||||
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1");
|
||||
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond");
|
||||
let src_eq_dest = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
|
||||
.unwrap();
|
||||
let src_slt_dest = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
|
||||
.unwrap();
|
||||
let dest_step_eq_one = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
dest_idx.2,
|
||||
dest_idx.2.get_type().const_int(1, false),
|
||||
"slice_dest_step_eq_one",
|
||||
)
|
||||
.unwrap();
|
||||
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
|
||||
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
cond,
|
||||
|
@ -408,27 +443,489 @@ pub fn list_slice_assignment<'ctx, 'a>(
|
|||
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size")
|
||||
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
|
||||
}
|
||||
.into(),
|
||||
];
|
||||
ctx.builder
|
||||
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
|
||||
.try_as_basic_value()
|
||||
.unwrap_left()
|
||||
.into_int_value()
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
};
|
||||
// update length
|
||||
let need_update =
|
||||
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update");
|
||||
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
|
||||
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let update_bb = ctx.ctx.append_basic_block(current, "update");
|
||||
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
|
||||
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb);
|
||||
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
|
||||
ctx.builder.position_at_end(update_bb);
|
||||
let dest_len_ptr = unsafe { ctx.builder.build_gep(dest_arr, &[zero, one], "dest_len_ptr") };
|
||||
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len");
|
||||
ctx.builder.build_store(dest_len_ptr, new_len);
|
||||
ctx.builder.build_unconditional_branch(cont_bb);
|
||||
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
|
||||
dest_arr.store_size(ctx, generator, new_len);
|
||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
||||
ctx.builder.position_at_end(cont_bb);
|
||||
}
|
||||
|
||||
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
|
||||
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
v: FloatValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
|
||||
ctx.module.add_function("__nac3_isinf", fn_type, None)
|
||||
});
|
||||
|
||||
let ret = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "isinf")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
|
||||
generator.bool_to_i1(ctx, ret)
|
||||
}
|
||||
|
||||
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
|
||||
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
v: FloatValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
|
||||
ctx.module.add_function("__nac3_isnan", fn_type, None)
|
||||
});
|
||||
|
||||
let ret = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "isnan")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
|
||||
generator.bool_to_i1(ctx, ret)
|
||||
}
|
||||
|
||||
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
|
||||
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_gamma", fn_type, None)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "gamma")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
|
||||
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_gammaln", fn_type, None)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "gammaln")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
|
||||
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_j0", fn_type, None)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "j0")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
||||
/// calculated total size.
|
||||
///
|
||||
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
|
||||
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
|
||||
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
|
||||
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dims: &Dims,
|
||||
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
|
||||
) -> IntValue<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Dims: ArrayLikeIndexer<'ctx>,
|
||||
{
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pi64 = llvm_i64.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_size",
|
||||
64 => "__nac3_ndarray_calc_size64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
|
||||
&[llvm_pi64.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
|
||||
false,
|
||||
);
|
||||
let ndarray_calc_size_fn =
|
||||
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
|
||||
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
||||
});
|
||||
|
||||
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
|
||||
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_size_fn,
|
||||
&[
|
||||
dims.base_ptr(ctx, generator).into(),
|
||||
dims.size(ctx, generator).into(),
|
||||
begin.into(),
|
||||
end.into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
|
||||
/// containing `i32` indices of the flattened index.
|
||||
///
|
||||
/// * `index` - The index to compute the multidimensional index for.
|
||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
index: IntValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_nd_indices",
|
||||
64 => "__nac3_ndarray_calc_nd_indices64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_nd_indices_fn =
|
||||
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_void.fn_type(
|
||||
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.dim_sizes();
|
||||
|
||||
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_nd_indices_fn,
|
||||
&[
|
||||
index.into(),
|
||||
ndarray_dims.base_ptr(ctx, generator).into(),
|
||||
ndarray_num_dims.into(),
|
||||
indices.into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
|
||||
Box::new(|_, v| v.into_int_value()),
|
||||
Box::new(|_, v| v.into()),
|
||||
)
|
||||
}
|
||||
|
||||
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: &Indices,
|
||||
) -> IntValue<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Indices: ArrayLikeIndexer<'ctx>,
|
||||
{
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
debug_assert_eq!(
|
||||
IntType::try_from(indices.element_type(ctx, generator))
|
||||
.map(IntType::get_bit_width)
|
||||
.unwrap_or_default(),
|
||||
llvm_i32.get_bit_width(),
|
||||
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
||||
);
|
||||
debug_assert_eq!(
|
||||
indices.size(ctx, generator).get_type().get_bit_width(),
|
||||
llvm_usize.get_bit_width(),
|
||||
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
|
||||
);
|
||||
|
||||
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_flatten_index",
|
||||
64 => "__nac3_ndarray_flatten_index64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_flatten_index_fn =
|
||||
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.dim_sizes();
|
||||
|
||||
let index = ctx
|
||||
.builder
|
||||
.build_call(
|
||||
ndarray_flatten_index_fn,
|
||||
&[
|
||||
ndarray_dims.base_ptr(ctx, generator).into(),
|
||||
ndarray_num_dims.into(),
|
||||
indices.base_ptr(ctx, generator).into(),
|
||||
indices.size(ctx, generator).into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
|
||||
index
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||
/// multidimensional index.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: &Index,
|
||||
) -> IntValue<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Index: ArrayLikeIndexer<'ctx>,
|
||||
{
|
||||
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
|
||||
/// dimension and size of each dimension of the resultant `ndarray`.
|
||||
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
lhs: NDArrayValue<'ctx>,
|
||||
rhs: NDArrayValue<'ctx>,
|
||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_broadcast",
|
||||
64 => "__nac3_ndarray_calc_broadcast64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_broadcast_fn =
|
||||
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
llvm_usize.const_zero(),
|
||||
(min_ndims, false),
|
||||
|generator, ctx, idx| {
|
||||
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
||||
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
||||
(
|
||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
|
||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
|
||||
)
|
||||
};
|
||||
|
||||
let llvm_usize_const_one = llvm_usize.const_int(1, false);
|
||||
let lhs_eqz = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
|
||||
.unwrap();
|
||||
let rhs_eqz = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
|
||||
.unwrap();
|
||||
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
|
||||
|
||||
let lhs_eq_rhs = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
|
||||
.unwrap();
|
||||
|
||||
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
is_compatible,
|
||||
"0:ValueError",
|
||||
"operands could not be broadcast together",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
|
||||
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
|
||||
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_broadcast_fn,
|
||||
&[
|
||||
lhs_dims.into(),
|
||||
lhs_ndims.into(),
|
||||
rhs_dims.into(),
|
||||
rhs_ndims.into(),
|
||||
out_dims.base_ptr(ctx, generator).into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
TypedArrayLikeAdapter::from(
|
||||
out_dims,
|
||||
Box::new(|_, v| v.into_int_value()),
|
||||
Box::new(|_, v| v.into()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
|
||||
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
|
||||
/// array `broadcast_idx`.
|
||||
pub fn call_ndarray_calc_broadcast_index<
|
||||
'ctx,
|
||||
G: CodeGenerator + ?Sized,
|
||||
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
|
||||
>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
array: NDArrayValue<'ctx>,
|
||||
broadcast_idx: &BroadcastIdx,
|
||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_broadcast_idx",
|
||||
64 => "__nac3_ndarray_calc_broadcast_idx64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_broadcast_fn =
|
||||
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let broadcast_size = broadcast_idx.size(ctx, generator);
|
||||
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
||||
|
||||
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
|
||||
let array_ndims = array.load_ndims(ctx);
|
||||
let broadcast_idx_ptr = unsafe {
|
||||
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_broadcast_fn,
|
||||
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
|
||||
Box::new(|_, v| v.into_int_value()),
|
||||
Box::new(|_, v| v.into()),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,719 @@
|
|||
use crate::codegen::CodeGenContext;
|
||||
use inkwell::context::Context;
|
||||
use inkwell::intrinsics::Intrinsic;
|
||||
use inkwell::types::AnyTypeEnum::IntType;
|
||||
use inkwell::types::FloatType;
|
||||
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
|
||||
use inkwell::AddressSpace;
|
||||
use itertools::Either;
|
||||
|
||||
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
|
||||
/// functions.
|
||||
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
|
||||
// Standard LLVM floating-point types
|
||||
if ft == ctx.f16_type() {
|
||||
return "f16";
|
||||
}
|
||||
if ft == ctx.f32_type() {
|
||||
return "f32";
|
||||
}
|
||||
if ft == ctx.f64_type() {
|
||||
return "f64";
|
||||
}
|
||||
if ft == ctx.f128_type() {
|
||||
return "f128";
|
||||
}
|
||||
|
||||
// Non-standard floating-point types
|
||||
if ft == ctx.x86_f80_type() {
|
||||
return "f80";
|
||||
}
|
||||
if ft == ctx.ppc_f128_type() {
|
||||
return "ppcf128";
|
||||
}
|
||||
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
|
||||
/// intrinsic.
|
||||
pub fn call_stacksave<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.stacksave";
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_pointer_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the
|
||||
/// [`llvm.stackrestore`](https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic) intrinsic.
|
||||
///
|
||||
/// - `ptr`: The pointer storing the address to restore the stack to.
|
||||
pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) {
|
||||
const FN_NAME: &str = "llvm.stackrestore";
|
||||
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder.build_call(intrinsic_fn, &[ptr.into()], "").unwrap();
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
|
||||
///
|
||||
/// * `src` - The value for which the absolute value is to be returned.
|
||||
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
|
||||
pub fn call_int_abs<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
src: IntValue<'ctx>,
|
||||
is_int_min_poison: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.abs";
|
||||
|
||||
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
|
||||
debug_assert!(is_int_min_poison.is_const());
|
||||
|
||||
let llvm_src_t = src.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_src_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[src.into(), is_int_min_poison.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.smax`](https://llvm.org/docs/LangRef.html#llvm-smax-intrinsic) intrinsic.
|
||||
pub fn call_int_smax<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
a: IntValue<'ctx>,
|
||||
b: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.smax";
|
||||
|
||||
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
|
||||
|
||||
let llvm_int_t = a.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_int_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.smin`](https://llvm.org/docs/LangRef.html#llvm-smin-intrinsic) intrinsic.
|
||||
pub fn call_int_smin<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
a: IntValue<'ctx>,
|
||||
b: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.smin";
|
||||
|
||||
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
|
||||
|
||||
let llvm_int_t = a.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_int_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.umax`](https://llvm.org/docs/LangRef.html#llvm-umax-intrinsic) intrinsic.
|
||||
pub fn call_int_umax<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
a: IntValue<'ctx>,
|
||||
b: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.umax";
|
||||
|
||||
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
|
||||
|
||||
let llvm_int_t = a.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_int_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.umin`](https://llvm.org/docs/LangRef.html#llvm-umin-intrinsic) intrinsic.
|
||||
pub fn call_int_umin<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
a: IntValue<'ctx>,
|
||||
b: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.umin";
|
||||
|
||||
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
|
||||
|
||||
let llvm_int_t = a.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_int_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.memcpy`](https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic) intrinsic.
|
||||
///
|
||||
/// * `dest` - The pointer to the destination. Must be a pointer to an integer type.
|
||||
/// * `src` - The pointer to the source. Must be a pointer to an integer type.
|
||||
/// * `len` - The number of bytes to copy.
|
||||
/// * `is_volatile` - Whether the `memcpy` operation should be `volatile`.
|
||||
pub fn call_memcpy<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dest: PointerValue<'ctx>,
|
||||
src: PointerValue<'ctx>,
|
||||
len: IntValue<'ctx>,
|
||||
is_volatile: IntValue<'ctx>,
|
||||
) {
|
||||
const FN_NAME: &str = "llvm.memcpy";
|
||||
|
||||
debug_assert!(dest.get_type().get_element_type().is_int_type());
|
||||
debug_assert!(src.get_type().get_element_type().is_int_type());
|
||||
debug_assert_eq!(
|
||||
dest.get_type().get_element_type().into_int_type().get_bit_width(),
|
||||
src.get_type().get_element_type().into_int_type().get_bit_width(),
|
||||
);
|
||||
debug_assert!(matches!(len.get_type().get_bit_width(), 32 | 64));
|
||||
debug_assert_eq!(is_volatile.get_type().get_bit_width(), 1);
|
||||
|
||||
let llvm_dest_t = dest.get_type();
|
||||
let llvm_src_t = src.get_type();
|
||||
let llvm_len_t = len.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| {
|
||||
intrinsic.get_declaration(
|
||||
&ctx.module,
|
||||
&[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()],
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[dest.into(), src.into(), len.into(), is_volatile.into()], "")
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Invokes the `llvm.memcpy` intrinsic.
|
||||
///
|
||||
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
|
||||
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
|
||||
pub fn call_memcpy_generic<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dest: PointerValue<'ctx>,
|
||||
src: PointerValue<'ctx>,
|
||||
len: IntValue<'ctx>,
|
||||
is_volatile: IntValue<'ctx>,
|
||||
) {
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
|
||||
let dest_elem_t = dest.get_type().get_element_type();
|
||||
let src_elem_t = src.get_type().get_element_type();
|
||||
|
||||
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
|
||||
dest
|
||||
} else {
|
||||
ctx.builder
|
||||
.build_bitcast(dest, llvm_p0i8, "")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
};
|
||||
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
|
||||
src
|
||||
} else {
|
||||
ctx.builder
|
||||
.build_bitcast(src, llvm_p0i8, "")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
call_memcpy(ctx, dest, src, len, is_volatile);
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.sqrt`](https://llvm.org/docs/LangRef.html#llvm-sqrt-intrinsic) intrinsic.
|
||||
pub fn call_float_sqrt<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.sqrt";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.powi`](https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic) intrinsic.
|
||||
pub fn call_float_powi<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
power: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.powi";
|
||||
|
||||
let llvm_val_t = val.get_type();
|
||||
let llvm_power_t = power.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| {
|
||||
intrinsic.get_declaration(&ctx.module, &[llvm_val_t.into(), llvm_power_t.into()])
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.sin`](https://llvm.org/docs/LangRef.html#llvm-sin-intrinsic) intrinsic.
|
||||
pub fn call_float_sin<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.sin";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.cos`](https://llvm.org/docs/LangRef.html#llvm-cos-intrinsic) intrinsic.
|
||||
pub fn call_float_cos<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.cos";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.pow`](https://llvm.org/docs/LangRef.html#llvm-pow-intrinsic) intrinsic.
|
||||
pub fn call_float_pow<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
power: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.pow";
|
||||
|
||||
debug_assert_eq!(val.get_type(), power.get_type());
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.exp`](https://llvm.org/docs/LangRef.html#llvm-exp-intrinsic) intrinsic.
|
||||
pub fn call_float_exp<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.exp";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.exp2`](https://llvm.org/docs/LangRef.html#llvm-exp2-intrinsic) intrinsic.
|
||||
pub fn call_float_exp2<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.exp2";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.log`](https://llvm.org/docs/LangRef.html#llvm-log-intrinsic) intrinsic.
|
||||
pub fn call_float_log<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.log";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.log10`](https://llvm.org/docs/LangRef.html#llvm-log10-intrinsic) intrinsic.
|
||||
pub fn call_float_log10<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.log10";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.log2`](https://llvm.org/docs/LangRef.html#llvm-log2-intrinsic) intrinsic.
|
||||
pub fn call_float_log2<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.log2";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.fabs`](https://llvm.org/docs/LangRef.html#llvm-fabs-intrinsic) intrinsic.
|
||||
pub fn call_float_fabs<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
src: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.fabs";
|
||||
|
||||
let llvm_src_t = src.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_src_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[src.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.minnum`](https://llvm.org/docs/LangRef.html#llvm-minnum-intrinsic) intrinsic.
|
||||
pub fn call_float_minnum<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val1: FloatValue<'ctx>,
|
||||
val2: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.minnum";
|
||||
|
||||
debug_assert_eq!(val1.get_type(), val2.get_type());
|
||||
|
||||
let llvm_float_t = val1.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val1.into(), val2.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.maxnum`](https://llvm.org/docs/LangRef.html#llvm-maxnum-intrinsic) intrinsic.
|
||||
pub fn call_float_maxnum<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val1: FloatValue<'ctx>,
|
||||
val2: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.maxnum";
|
||||
|
||||
debug_assert_eq!(val1.get_type(), val2.get_type());
|
||||
|
||||
let llvm_float_t = val1.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val1.into(), val2.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.copysign`](https://llvm.org/docs/LangRef.html#llvm-copysign-intrinsic) intrinsic.
|
||||
pub fn call_float_copysign<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
mag: FloatValue<'ctx>,
|
||||
sgn: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.copysign";
|
||||
|
||||
debug_assert_eq!(mag.get_type(), sgn.get_type());
|
||||
|
||||
let llvm_float_t = mag.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[mag.into(), sgn.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.floor`](https://llvm.org/docs/LangRef.html#llvm-floor-intrinsic) intrinsic.
|
||||
pub fn call_float_floor<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.floor";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.ceil`](https://llvm.org/docs/LangRef.html#llvm-ceil-intrinsic) intrinsic.
|
||||
pub fn call_float_ceil<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.ceil";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.round`](https://llvm.org/docs/LangRef.html#llvm-round-intrinsic) intrinsic.
|
||||
pub fn call_float_round<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.round";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the
|
||||
/// [`llvm.roundeven`](https://llvm.org/docs/LangRef.html#llvm-roundeven-intrinsic) intrinsic.
|
||||
pub fn call_float_roundeven<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: FloatValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.roundeven";
|
||||
|
||||
let llvm_float_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.expect`](https://llvm.org/docs/LangRef.html#llvm-expect-intrinsic) intrinsic.
|
||||
pub fn call_expect<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: IntValue<'ctx>,
|
||||
expected_val: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.expect";
|
||||
|
||||
debug_assert_eq!(val.get_type().get_bit_width(), expected_val.get_type().get_bit_width());
|
||||
|
||||
let llvm_int_t = val.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_int_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[val.into(), expected_val.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,18 +1,25 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenTask, DefaultCodeGenerator,
|
||||
WithCall, WorkerRegistry,
|
||||
classes::{ListType, NDArrayType, ProxyType, RangeType},
|
||||
concrete_type::ConcreteTypeStore,
|
||||
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask,
|
||||
CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry,
|
||||
},
|
||||
symbol_resolver::{SymbolResolver, ValueEnum},
|
||||
toplevel::{
|
||||
composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
|
||||
composer::{ComposerConfig, TopLevelComposer},
|
||||
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
|
||||
},
|
||||
typecheck::{
|
||||
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||
},
|
||||
};
|
||||
use indoc::indoc;
|
||||
use inkwell::{
|
||||
targets::{InitializationConfig, Target},
|
||||
OptimizationLevel,
|
||||
};
|
||||
use nac3parser::{
|
||||
ast::{fold::Fold, StrRef},
|
||||
parser::parse_program,
|
||||
|
@ -59,12 +66,12 @@ impl SymbolResolver for Resolver {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, String> {
|
||||
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
|
||||
self.id_to_def
|
||||
.read()
|
||||
.get(&id)
|
||||
.cloned()
|
||||
.ok_or_else(|| format!("cannot find symbol `{}`", id))
|
||||
.ok_or_else(|| HashSet::from([format!("cannot find symbol `{}`", id)]))
|
||||
}
|
||||
|
||||
fn get_string_id(&self, _: &str) -> i32 {
|
||||
|
@ -85,7 +92,7 @@ fn test_primitives() {
|
|||
"};
|
||||
let statements = parse_program(source, Default::default()).unwrap();
|
||||
|
||||
let composer: TopLevelComposer = Default::default();
|
||||
let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
|
||||
let mut unifier = composer.unifier.clone();
|
||||
let primitives = composer.primitives_ty;
|
||||
let top_level = Arc::new(composer.make_top_level_context());
|
||||
|
@ -104,7 +111,7 @@ fn test_primitives() {
|
|||
FuncArg { name: "b".into(), ty: primitives.int32, default_value: None },
|
||||
],
|
||||
ret: primitives.int32,
|
||||
vars: HashMap::new(),
|
||||
vars: VarMap::new(),
|
||||
};
|
||||
|
||||
let mut store = ConcreteTypeStore::new();
|
||||
|
@ -181,27 +188,45 @@ fn test_primitives() {
|
|||
; ModuleID = 'test'
|
||||
source_filename = \"test\"
|
||||
|
||||
define i32 @testing(i32 %0, i32 %1) {
|
||||
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
|
||||
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
|
||||
init:
|
||||
%add = add i32 %0, %1
|
||||
%cmp = icmp eq i32 %add, 1
|
||||
br i1 %cmp, label %then, label %else
|
||||
|
||||
then: ; preds = %init
|
||||
br label %cont
|
||||
|
||||
else: ; preds = %init
|
||||
br label %cont
|
||||
|
||||
cont: ; preds = %else, %then
|
||||
%if_exp_result.0 = phi i32 [ %0, %then ], [ 0, %else ]
|
||||
ret i32 %if_exp_result.0
|
||||
%add = add i32 %1, %0, !dbg !9
|
||||
%cmp = icmp eq i32 %add, 1, !dbg !10
|
||||
%. = select i1 %cmp, i32 %0, i32 0, !dbg !11
|
||||
ret i32 %., !dbg !12
|
||||
}
|
||||
"}
|
||||
|
||||
attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }
|
||||
|
||||
!llvm.module.flags = !{!0, !1}
|
||||
!llvm.dbg.cu = !{!2}
|
||||
|
||||
!0 = !{i32 2, !\"Debug Info Version\", i32 3}
|
||||
!1 = !{i32 2, !\"Dwarf Version\", i32 4}
|
||||
!2 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
|
||||
!3 = !DIFile(filename: \"unknown\", directory: \"\")
|
||||
!4 = distinct !DISubprogram(name: \"testing\", linkageName: \"testing\", scope: null, file: !3, line: 1, type: !5, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !8)
|
||||
!5 = !DISubroutineType(flags: DIFlagPublic, types: !6)
|
||||
!6 = !{!7}
|
||||
!7 = !DIBasicType(name: \"_\", flags: DIFlagPublic)
|
||||
!8 = !{}
|
||||
!9 = !DILocation(line: 1, column: 9, scope: !4)
|
||||
!10 = !DILocation(line: 2, column: 15, scope: !4)
|
||||
!11 = !DILocation(line: 0, scope: !4)
|
||||
!12 = !DILocation(line: 3, column: 8, scope: !4)
|
||||
"}
|
||||
.trim();
|
||||
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
|
||||
})));
|
||||
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, f);
|
||||
|
||||
Target::initialize_all(&InitializationConfig::default());
|
||||
|
||||
let llvm_options = CodeGenLLVMOptions {
|
||||
opt_level: OptimizationLevel::Default,
|
||||
target: CodeGenTargetMachineOptions::from_host_triple(),
|
||||
};
|
||||
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
|
||||
registry.add_task(task);
|
||||
registry.wait_tasks_complete(handles);
|
||||
}
|
||||
|
@ -219,7 +244,7 @@ fn test_simple_call() {
|
|||
"};
|
||||
let statements_2 = parse_program(source_2, Default::default()).unwrap();
|
||||
|
||||
let composer: TopLevelComposer = Default::default();
|
||||
let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
|
||||
let mut unifier = composer.unifier.clone();
|
||||
let primitives = composer.primitives_ty;
|
||||
let top_level = Arc::new(composer.make_top_level_context());
|
||||
|
@ -228,7 +253,7 @@ fn test_simple_call() {
|
|||
let signature = FunSignature {
|
||||
args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }],
|
||||
ret: primitives.int32,
|
||||
vars: HashMap::new(),
|
||||
vars: VarMap::new(),
|
||||
};
|
||||
let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone()));
|
||||
let mut store = ConcreteTypeStore::new();
|
||||
|
@ -342,23 +367,83 @@ fn test_simple_call() {
|
|||
; ModuleID = 'test'
|
||||
source_filename = \"test\"
|
||||
|
||||
define i32 @testing(i32 %0) {
|
||||
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
|
||||
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {
|
||||
init:
|
||||
%call = call i32 @foo.0(i32 %0)
|
||||
%mul = mul i32 %call, 2
|
||||
ret i32 %mul
|
||||
%add.i = shl i32 %0, 1, !dbg !10
|
||||
%mul = add i32 %add.i, 2, !dbg !10
|
||||
ret i32 %mul, !dbg !10
|
||||
}
|
||||
|
||||
define i32 @foo.0(i32 %0) {
|
||||
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
|
||||
define i32 @foo.0(i32 %0) local_unnamed_addr #0 !dbg !11 {
|
||||
init:
|
||||
%add = add i32 %0, 1
|
||||
ret i32 %add
|
||||
%add = add i32 %0, 1, !dbg !12
|
||||
ret i32 %add, !dbg !12
|
||||
}
|
||||
"}
|
||||
|
||||
attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }
|
||||
|
||||
!llvm.module.flags = !{!0, !1}
|
||||
!llvm.dbg.cu = !{!2, !4}
|
||||
|
||||
!0 = !{i32 2, !\"Debug Info Version\", i32 3}
|
||||
!1 = !{i32 2, !\"Dwarf Version\", i32 4}
|
||||
!2 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
|
||||
!3 = !DIFile(filename: \"unknown\", directory: \"\")
|
||||
!4 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
|
||||
!5 = distinct !DISubprogram(name: \"testing\", linkageName: \"testing\", scope: null, file: !3, line: 1, type: !6, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !9)
|
||||
!6 = !DISubroutineType(flags: DIFlagPublic, types: !7)
|
||||
!7 = !{!8}
|
||||
!8 = !DIBasicType(name: \"_\", flags: DIFlagPublic)
|
||||
!9 = !{}
|
||||
!10 = !DILocation(line: 2, column: 12, scope: !5)
|
||||
!11 = distinct !DISubprogram(name: \"foo.0\", linkageName: \"foo.0\", scope: null, file: !3, line: 1, type: !6, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !4, retainedNodes: !9)
|
||||
!12 = !DILocation(line: 1, column: 12, scope: !11)
|
||||
"}
|
||||
.trim();
|
||||
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
|
||||
})));
|
||||
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, f);
|
||||
|
||||
Target::initialize_all(&InitializationConfig::default());
|
||||
|
||||
let llvm_options = CodeGenLLVMOptions {
|
||||
opt_level: OptimizationLevel::Default,
|
||||
target: CodeGenTargetMachineOptions::from_host_triple(),
|
||||
};
|
||||
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
|
||||
registry.add_task(task);
|
||||
registry.wait_tasks_complete(handles);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classes_list_type_new() {
|
||||
let ctx = inkwell::context::Context::create();
|
||||
let generator = DefaultCodeGenerator::new(String::new(), 64);
|
||||
|
||||
let llvm_i32 = ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(&ctx);
|
||||
|
||||
let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
|
||||
assert!(ListType::is_type(llvm_list.as_base_type(), llvm_usize).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classes_range_type_new() {
|
||||
let ctx = inkwell::context::Context::create();
|
||||
|
||||
let llvm_range = RangeType::new(&ctx);
|
||||
assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classes_ndarray_type_new() {
|
||||
let ctx = inkwell::context::Context::create();
|
||||
let generator = DefaultCodeGenerator::new(String::new(), 64);
|
||||
|
||||
let llvm_i32 = ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(&ctx);
|
||||
|
||||
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
|
||||
assert!(NDArrayType::is_type(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
|
||||
}
|
||||
|
|
|
@ -1,5 +1,23 @@
|
|||
#![warn(clippy::all)]
|
||||
#![allow(dead_code)]
|
||||
#![deny(
|
||||
future_incompatible,
|
||||
let_underscore,
|
||||
nonstandard_style,
|
||||
rust_2024_compatibility,
|
||||
clippy::all
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(
|
||||
dead_code,
|
||||
clippy::cast_possible_truncation,
|
||||
clippy::cast_sign_loss,
|
||||
clippy::enum_glob_use,
|
||||
clippy::missing_errors_doc,
|
||||
clippy::missing_panics_doc,
|
||||
clippy::module_name_repetitions,
|
||||
clippy::similar_names,
|
||||
clippy::too_many_lines,
|
||||
clippy::wildcard_imports
|
||||
)]
|
||||
|
||||
pub mod codegen;
|
||||
pub mod symbol_resolver;
|
||||
|
|
|
@ -1,22 +1,19 @@
|
|||
use std::fmt::Debug;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::{collections::HashMap, fmt::Display};
|
||||
use std::{collections::HashMap, collections::HashSet, fmt::Display};
|
||||
|
||||
use crate::typecheck::typedef::TypeEnum;
|
||||
use crate::{
|
||||
codegen::CodeGenContext,
|
||||
toplevel::{DefinitionId, TopLevelDef},
|
||||
};
|
||||
use crate::{
|
||||
codegen::CodeGenerator,
|
||||
codegen::{CodeGenContext, CodeGenerator},
|
||||
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
|
||||
typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{Type, Unifier},
|
||||
typedef::{Type, TypeEnum, Unifier, VarMap},
|
||||
},
|
||||
};
|
||||
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
|
||||
use itertools::{chain, izip};
|
||||
use nac3parser::ast::{Expr, Location, StrRef};
|
||||
use itertools::{chain, izip, Itertools};
|
||||
use nac3parser::ast::{Constant, Expr, Location, StrRef};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
|
@ -33,15 +30,192 @@ pub enum SymbolValue {
|
|||
OptionNone,
|
||||
}
|
||||
|
||||
impl SymbolValue {
|
||||
/// Creates a [`SymbolValue`] from a [`Constant`].
|
||||
///
|
||||
/// * `constant` - The constant to create the value from.
|
||||
/// * `expected_ty` - The expected type of the [`SymbolValue`].
|
||||
pub fn from_constant(
|
||||
constant: &Constant,
|
||||
expected_ty: Type,
|
||||
primitives: &PrimitiveStore,
|
||||
unifier: &mut Unifier,
|
||||
) -> Result<Self, String> {
|
||||
match constant {
|
||||
Constant::None => {
|
||||
if unifier.unioned(expected_ty, primitives.option) {
|
||||
Ok(SymbolValue::OptionNone)
|
||||
} else {
|
||||
Err(format!("Expected {expected_ty:?}, but got Option"))
|
||||
}
|
||||
}
|
||||
Constant::Bool(b) => {
|
||||
if unifier.unioned(expected_ty, primitives.bool) {
|
||||
Ok(SymbolValue::Bool(*b))
|
||||
} else {
|
||||
Err(format!("Expected {expected_ty:?}, but got bool"))
|
||||
}
|
||||
}
|
||||
Constant::Str(s) => {
|
||||
if unifier.unioned(expected_ty, primitives.str) {
|
||||
Ok(SymbolValue::Str(s.to_string()))
|
||||
} else {
|
||||
Err(format!("Expected {expected_ty:?}, but got str"))
|
||||
}
|
||||
}
|
||||
Constant::Int(i) => {
|
||||
if unifier.unioned(expected_ty, primitives.int32) {
|
||||
i32::try_from(*i).map(SymbolValue::I32).map_err(|e| e.to_string())
|
||||
} else if unifier.unioned(expected_ty, primitives.int64) {
|
||||
i64::try_from(*i).map(SymbolValue::I64).map_err(|e| e.to_string())
|
||||
} else if unifier.unioned(expected_ty, primitives.uint32) {
|
||||
u32::try_from(*i).map(SymbolValue::U32).map_err(|e| e.to_string())
|
||||
} else if unifier.unioned(expected_ty, primitives.uint64) {
|
||||
u64::try_from(*i).map(SymbolValue::U64).map_err(|e| e.to_string())
|
||||
} else {
|
||||
Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
|
||||
}
|
||||
}
|
||||
Constant::Tuple(t) => {
|
||||
let expected_ty = unifier.get_ty(expected_ty);
|
||||
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
|
||||
return Err(format!(
|
||||
"Expected {:?}, but got Tuple",
|
||||
expected_ty.get_type_name()
|
||||
));
|
||||
};
|
||||
|
||||
assert_eq!(ty.len(), t.len());
|
||||
|
||||
let elems = t
|
||||
.iter()
|
||||
.zip(ty)
|
||||
.map(|(constant, ty)| Self::from_constant(constant, *ty, primitives, unifier))
|
||||
.collect::<Result<Vec<SymbolValue>, _>>()?;
|
||||
Ok(SymbolValue::Tuple(elems))
|
||||
}
|
||||
Constant::Float(f) => {
|
||||
if unifier.unioned(expected_ty, primitives.float) {
|
||||
Ok(SymbolValue::Double(*f))
|
||||
} else {
|
||||
Err(format!("Expected {expected_ty:?}, but got float"))
|
||||
}
|
||||
}
|
||||
_ => Err(format!("Unsupported value type {constant:?}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value.
|
||||
///
|
||||
/// * `constant` - The constant to create the value from.
|
||||
pub fn from_constant_inferred(constant: &Constant) -> Result<Self, String> {
|
||||
match constant {
|
||||
Constant::None => Ok(SymbolValue::OptionNone),
|
||||
Constant::Bool(b) => Ok(SymbolValue::Bool(*b)),
|
||||
Constant::Str(s) => Ok(SymbolValue::Str(s.to_string())),
|
||||
Constant::Int(i) => {
|
||||
let i = *i;
|
||||
if i >= 0 {
|
||||
i32::try_from(i)
|
||||
.map(SymbolValue::I32)
|
||||
.or_else(|_| i64::try_from(i).map(SymbolValue::I64))
|
||||
.map_err(|_| {
|
||||
format!("Literal cannot be expressed as any integral type: {i}")
|
||||
})
|
||||
} else {
|
||||
u32::try_from(i)
|
||||
.map(SymbolValue::U32)
|
||||
.or_else(|_| u64::try_from(i).map(SymbolValue::U64))
|
||||
.map_err(|_| {
|
||||
format!("Literal cannot be expressed as any integral type: {i}")
|
||||
})
|
||||
}
|
||||
}
|
||||
Constant::Tuple(t) => {
|
||||
let elems = t
|
||||
.iter()
|
||||
.map(Self::from_constant_inferred)
|
||||
.collect::<Result<Vec<SymbolValue>, _>>()?;
|
||||
Ok(SymbolValue::Tuple(elems))
|
||||
}
|
||||
Constant::Float(f) => Ok(SymbolValue::Double(*f)),
|
||||
_ => Err(format!("Unsupported value type {constant:?}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`Type`] representing the data type of this value.
|
||||
pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type {
|
||||
match self {
|
||||
SymbolValue::I32(_) => primitives.int32,
|
||||
SymbolValue::I64(_) => primitives.int64,
|
||||
SymbolValue::U32(_) => primitives.uint32,
|
||||
SymbolValue::U64(_) => primitives.uint64,
|
||||
SymbolValue::Str(_) => primitives.str,
|
||||
SymbolValue::Double(_) => primitives.float,
|
||||
SymbolValue::Bool(_) => primitives.bool,
|
||||
SymbolValue::Tuple(vs) => {
|
||||
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
|
||||
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
|
||||
}
|
||||
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`TypeAnnotation`] representing the data type of this value.
|
||||
pub fn get_type_annotation(
|
||||
&self,
|
||||
primitives: &PrimitiveStore,
|
||||
unifier: &mut Unifier,
|
||||
) -> TypeAnnotation {
|
||||
match self {
|
||||
SymbolValue::Bool(..)
|
||||
| SymbolValue::Double(..)
|
||||
| SymbolValue::I32(..)
|
||||
| SymbolValue::I64(..)
|
||||
| SymbolValue::U32(..)
|
||||
| SymbolValue::U64(..)
|
||||
| SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
|
||||
SymbolValue::Tuple(vs) => {
|
||||
let vs_tys = vs
|
||||
.iter()
|
||||
.map(|v| v.get_type_annotation(primitives, unifier))
|
||||
.collect::<Vec<_>>();
|
||||
TypeAnnotation::Tuple(vs_tys)
|
||||
}
|
||||
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
|
||||
id: primitives.option.obj_id(unifier).unwrap(),
|
||||
params: Vec::default(),
|
||||
},
|
||||
SymbolValue::OptionSome(v) => {
|
||||
let ty = v.get_type_annotation(primitives, unifier);
|
||||
TypeAnnotation::CustomClass {
|
||||
id: primitives.option.obj_id(unifier).unwrap(),
|
||||
params: vec![ty],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`TypeEnum`] representing the data type of this value.
|
||||
pub fn get_type_enum(
|
||||
&self,
|
||||
primitives: &PrimitiveStore,
|
||||
unifier: &mut Unifier,
|
||||
) -> Rc<TypeEnum> {
|
||||
let ty = self.get_type(primitives, unifier);
|
||||
unifier.get_ty(ty)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for SymbolValue {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SymbolValue::I32(i) => write!(f, "{}", i),
|
||||
SymbolValue::I64(i) => write!(f, "int64({})", i),
|
||||
SymbolValue::U32(i) => write!(f, "uint32({})", i),
|
||||
SymbolValue::U64(i) => write!(f, "uint64({})", i),
|
||||
SymbolValue::Str(s) => write!(f, "\"{}\"", s),
|
||||
SymbolValue::Double(d) => write!(f, "{}", d),
|
||||
SymbolValue::I32(i) => write!(f, "{i}"),
|
||||
SymbolValue::I64(i) => write!(f, "int64({i})"),
|
||||
SymbolValue::U32(i) => write!(f, "uint32({i})"),
|
||||
SymbolValue::U64(i) => write!(f, "uint64({i})"),
|
||||
SymbolValue::Str(s) => write!(f, "\"{s}\""),
|
||||
SymbolValue::Double(d) => write!(f, "{d}"),
|
||||
SymbolValue::Bool(b) => {
|
||||
if *b {
|
||||
write!(f, "True")
|
||||
|
@ -50,42 +224,82 @@ impl Display for SymbolValue {
|
|||
}
|
||||
}
|
||||
SymbolValue::Tuple(t) => {
|
||||
write!(f, "({})", t.iter().map(|v| format!("{}", v)).collect::<Vec<_>>().join(", "))
|
||||
write!(f, "({})", t.iter().map(|v| format!("{v}")).collect::<Vec<_>>().join(", "))
|
||||
}
|
||||
SymbolValue::OptionSome(v) => write!(f, "Some({})", v),
|
||||
SymbolValue::OptionSome(v) => write!(f, "Some({v})"),
|
||||
SymbolValue::OptionNone => write!(f, "none"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<SymbolValue> for u64 {
|
||||
type Error = ();
|
||||
|
||||
/// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not
|
||||
/// numeric or if the value cannot be converted into a `u64` without overflow.
|
||||
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
|
||||
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
|
||||
SymbolValue::U32(v) => Ok(u64::from(v)),
|
||||
SymbolValue::U64(v) => Ok(v),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<SymbolValue> for i128 {
|
||||
type Error = ();
|
||||
|
||||
/// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not
|
||||
/// numeric.
|
||||
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
SymbolValue::I32(v) => Ok(i128::from(v)),
|
||||
SymbolValue::I64(v) => Ok(i128::from(v)),
|
||||
SymbolValue::U32(v) => Ok(i128::from(v)),
|
||||
SymbolValue::U64(v) => Ok(i128::from(v)),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait StaticValue {
|
||||
/// Returns a unique identifier for this value.
|
||||
fn get_unique_identifier(&self) -> u64;
|
||||
|
||||
fn get_const_obj<'ctx, 'a>(
|
||||
/// Returns the constant object represented by this unique identifier.
|
||||
fn get_const_obj<'ctx>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> BasicValueEnum<'ctx>;
|
||||
|
||||
fn to_basic_value_enum<'ctx, 'a>(
|
||||
/// Converts this value to a LLVM [`BasicValueEnum`].
|
||||
fn to_basic_value_enum<'ctx>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
expected_ty: Type,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>;
|
||||
|
||||
fn get_field<'ctx, 'a>(
|
||||
/// Returns a field within this value.
|
||||
fn get_field<'ctx>(
|
||||
&self,
|
||||
name: StrRef,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Option<ValueEnum<'ctx>>;
|
||||
|
||||
/// Returns a single element of this tuple.
|
||||
fn get_tuple_element<'ctx>(&self, index: u32) -> Option<ValueEnum<'ctx>>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum ValueEnum<'ctx> {
|
||||
/// [ValueEnum] representing a static value.
|
||||
Static(Arc<dyn StaticValue + Send + Sync>),
|
||||
|
||||
/// [ValueEnum] representing a dynamic value.
|
||||
Dynamic(BasicValueEnum<'ctx>),
|
||||
}
|
||||
|
||||
|
@ -120,6 +334,7 @@ impl<'ctx> From<StructValue<'ctx>> for ValueEnum<'ctx> {
|
|||
}
|
||||
|
||||
impl<'ctx> ValueEnum<'ctx> {
|
||||
/// Converts this [`ValueEnum`] to a [`BasicValueEnum`].
|
||||
pub fn to_basic_value_enum<'a>(
|
||||
self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
|
@ -134,7 +349,7 @@ impl<'ctx> ValueEnum<'ctx> {
|
|||
}
|
||||
|
||||
pub trait SymbolResolver {
|
||||
// get type of type variable identifier or top-level function type
|
||||
/// Get type of type variable identifier or top-level function type,
|
||||
fn get_symbol_type(
|
||||
&self,
|
||||
unifier: &mut Unifier,
|
||||
|
@ -143,16 +358,16 @@ pub trait SymbolResolver {
|
|||
str: StrRef,
|
||||
) -> Result<Type, String>;
|
||||
|
||||
// get the top-level definition of identifiers
|
||||
fn get_identifier_def(&self, str: StrRef) -> Result<DefinitionId, String>;
|
||||
/// Get the top-level definition of identifiers.
|
||||
fn get_identifier_def(&self, str: StrRef) -> Result<DefinitionId, HashSet<String>>;
|
||||
|
||||
fn get_symbol_value<'ctx, 'a>(
|
||||
fn get_symbol_value<'ctx>(
|
||||
&self,
|
||||
str: StrRef,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Option<ValueEnum<'ctx>>;
|
||||
|
||||
fn get_default_param_value(&self, expr: &nac3parser::ast::Expr) -> Option<SymbolValue>;
|
||||
fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;
|
||||
fn get_string_id(&self, s: &str) -> i32;
|
||||
fn get_exception_id(&self, tyid: usize) -> usize;
|
||||
|
||||
|
@ -160,14 +375,14 @@ pub trait SymbolResolver {
|
|||
&self,
|
||||
_unifier: &mut Unifier,
|
||||
_top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||
_primitives: &PrimitiveStore
|
||||
_primitives: &PrimitiveStore,
|
||||
) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static IDENTIFIER_ID: [StrRef; 11] = [
|
||||
static IDENTIFIER_ID: [StrRef; 12] = [
|
||||
"int32".into(),
|
||||
"int64".into(),
|
||||
"float".into(),
|
||||
|
@ -179,17 +394,18 @@ thread_local! {
|
|||
"Exception".into(),
|
||||
"uint32".into(),
|
||||
"uint64".into(),
|
||||
"Literal".into(),
|
||||
];
|
||||
}
|
||||
|
||||
// convert type annotation into type
|
||||
/// Converts a type annotation into a [Type].
|
||||
pub fn parse_type_annotation<T>(
|
||||
resolver: &dyn SymbolResolver,
|
||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
expr: &Expr<T>,
|
||||
) -> Result<Type, String> {
|
||||
) -> Result<Type, HashSet<String>> {
|
||||
use nac3parser::ast::ExprKind::*;
|
||||
let ids = IDENTIFIER_ID.with(|ids| *ids);
|
||||
let int32_id = ids[0];
|
||||
|
@ -203,6 +419,7 @@ pub fn parse_type_annotation<T>(
|
|||
let exn_id = ids[8];
|
||||
let uint32_id = ids[9];
|
||||
let uint64_id = ids[10];
|
||||
let literal_id = ids[11];
|
||||
|
||||
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
|
||||
if *id == int32_id {
|
||||
|
@ -223,39 +440,33 @@ pub fn parse_type_annotation<T>(
|
|||
Ok(primitives.exception)
|
||||
} else {
|
||||
let obj_id = resolver.get_identifier_def(*id);
|
||||
match obj_id {
|
||||
Ok(obj_id) => {
|
||||
let def = top_level_defs[obj_id.0].read();
|
||||
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
|
||||
if !type_vars.is_empty() {
|
||||
return Err(format!(
|
||||
"Unexpected number of type parameters: expected {} but got 0",
|
||||
type_vars.len()
|
||||
));
|
||||
}
|
||||
let fields = chain(
|
||||
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
|
||||
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
|
||||
)
|
||||
.collect();
|
||||
Ok(unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id,
|
||||
fields,
|
||||
params: Default::default(),
|
||||
}))
|
||||
} else {
|
||||
Err(format!("Cannot use function name as type at {}", loc))
|
||||
if let Ok(obj_id) = obj_id {
|
||||
let def = top_level_defs[obj_id.0].read();
|
||||
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
|
||||
if !type_vars.is_empty() {
|
||||
return Err(HashSet::from([format!(
|
||||
"Unexpected number of type parameters: expected {} but got 0",
|
||||
type_vars.len()
|
||||
)]));
|
||||
}
|
||||
let fields = chain(
|
||||
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
|
||||
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
|
||||
)
|
||||
.collect();
|
||||
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() }))
|
||||
} else {
|
||||
Err(HashSet::from([format!("Cannot use function name as type at {loc}")]))
|
||||
}
|
||||
Err(_) => {
|
||||
let ty = resolver
|
||||
.get_symbol_type(unifier, top_level_defs, primitives, *id)
|
||||
.map_err(|e| format!("Unknown type annotation at {}: {}", loc, e))?;
|
||||
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
|
||||
Ok(ty)
|
||||
} else {
|
||||
Err(format!("Unknown type annotation {} at {}", id, loc))
|
||||
}
|
||||
} else {
|
||||
let ty =
|
||||
resolver.get_symbol_type(unifier, top_level_defs, primitives, *id).map_err(
|
||||
|e| HashSet::from([format!("Unknown type annotation at {loc}: {e}")]),
|
||||
)?;
|
||||
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
|
||||
Ok(ty)
|
||||
} else {
|
||||
Err(HashSet::from([format!("Unknown type annotation {id} at {loc}")]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -278,8 +489,31 @@ pub fn parse_type_annotation<T>(
|
|||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
|
||||
} else {
|
||||
Err("Expected multiple elements for tuple".into())
|
||||
Err(HashSet::from(["Expected multiple elements for tuple".into()]))
|
||||
}
|
||||
} else if *id == literal_id {
|
||||
let mut parse_literal = |elt: &Expr<T>| {
|
||||
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?;
|
||||
let ty_enum = &*unifier.get_ty_immutable(ty);
|
||||
match ty_enum {
|
||||
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
|
||||
_ => Err(HashSet::from([format!(
|
||||
"Expected literal in type argument for Literal at {}",
|
||||
elt.location
|
||||
)])),
|
||||
}
|
||||
};
|
||||
|
||||
let values = if let Tuple { elts, .. } = &slice.node {
|
||||
elts.iter().map(&mut parse_literal).collect::<Result<Vec<_>, _>>()?
|
||||
} else {
|
||||
vec![parse_literal(slice)?]
|
||||
}
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect_vec();
|
||||
|
||||
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
|
||||
} else {
|
||||
let types = if let Tuple { elts, .. } = &slice.node {
|
||||
elts.iter()
|
||||
|
@ -295,13 +529,13 @@ pub fn parse_type_annotation<T>(
|
|||
let def = top_level_defs[obj_id.0].read();
|
||||
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
|
||||
if types.len() != type_vars.len() {
|
||||
return Err(format!(
|
||||
return Err(HashSet::from([format!(
|
||||
"Unexpected number of type parameters: expected {} but got {}",
|
||||
type_vars.len(),
|
||||
types.len()
|
||||
));
|
||||
)]));
|
||||
}
|
||||
let mut subst = HashMap::new();
|
||||
let mut subst = VarMap::new();
|
||||
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
|
||||
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
|
||||
*id
|
||||
|
@ -323,7 +557,7 @@ pub fn parse_type_annotation<T>(
|
|||
}));
|
||||
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
|
||||
} else {
|
||||
Err("Cannot use function name as type".into())
|
||||
Err(HashSet::from(["Cannot use function name as type".into()]))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -334,10 +568,13 @@ pub fn parse_type_annotation<T>(
|
|||
if let Name { id, .. } = &value.node {
|
||||
subscript_name_handle(id, slice, unifier)
|
||||
} else {
|
||||
Err(format!("unsupported type expression at {}", expr.location))
|
||||
Err(HashSet::from([format!("unsupported type expression at {}", expr.location)]))
|
||||
}
|
||||
}
|
||||
_ => Err(format!("unsupported type expression at {}", expr.location)),
|
||||
Constant { value, .. } => SymbolValue::from_constant_inferred(value)
|
||||
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
|
||||
.map_err(|err| HashSet::from([err])),
|
||||
_ => Err(HashSet::from([format!("unsupported type expression at {}", expr.location)])),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -348,7 +585,7 @@ impl dyn SymbolResolver + Send + Sync {
|
|||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
expr: &Expr<T>,
|
||||
) -> Result<Type, String> {
|
||||
) -> Result<Type, HashSet<String>> {
|
||||
parse_type_annotation(self, top_level_defs, unifier, primitives, expr)
|
||||
}
|
||||
|
||||
|
@ -361,13 +598,13 @@ impl dyn SymbolResolver + Send + Sync {
|
|||
unifier.internal_stringify(
|
||||
ty,
|
||||
&mut |id| {
|
||||
if let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() {
|
||||
name.to_string()
|
||||
} else {
|
||||
let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else {
|
||||
unreachable!("expected class definition")
|
||||
}
|
||||
};
|
||||
|
||||
name.to_string()
|
||||
},
|
||||
&mut |id| format!("typevar{}", id),
|
||||
&mut |id| format!("typevar{id}"),
|
||||
&mut None,
|
||||
)
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,10 +1,274 @@
|
|||
use std::convert::TryInto;
|
||||
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap};
|
||||
use nac3parser::ast::{Constant, Location};
|
||||
use strum::IntoEnumIterator;
|
||||
use strum_macros::EnumIter;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// All primitive types and functions in nac3core.
|
||||
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
|
||||
pub enum PrimDef {
|
||||
Int32,
|
||||
Int64,
|
||||
Float,
|
||||
Bool,
|
||||
None,
|
||||
Range,
|
||||
Str,
|
||||
Exception,
|
||||
UInt32,
|
||||
UInt64,
|
||||
Option,
|
||||
OptionIsSome,
|
||||
OptionIsNone,
|
||||
OptionUnwrap,
|
||||
NDArray,
|
||||
NDArrayCopy,
|
||||
NDArrayFill,
|
||||
FunInt32,
|
||||
FunInt64,
|
||||
FunUInt32,
|
||||
FunUInt64,
|
||||
FunFloat,
|
||||
FunNpNDArray,
|
||||
FunNpEmpty,
|
||||
FunNpZeros,
|
||||
FunNpOnes,
|
||||
FunNpFull,
|
||||
FunNpArray,
|
||||
FunNpEye,
|
||||
FunNpIdentity,
|
||||
FunRound,
|
||||
FunRound64,
|
||||
FunNpRound,
|
||||
FunRange,
|
||||
FunStr,
|
||||
FunBool,
|
||||
FunFloor,
|
||||
FunFloor64,
|
||||
FunNpFloor,
|
||||
FunCeil,
|
||||
FunCeil64,
|
||||
FunNpCeil,
|
||||
FunLen,
|
||||
FunMin,
|
||||
FunNpMin,
|
||||
FunNpMinimum,
|
||||
FunMax,
|
||||
FunNpMax,
|
||||
FunNpMaximum,
|
||||
FunAbs,
|
||||
FunNpIsNan,
|
||||
FunNpIsInf,
|
||||
FunNpSin,
|
||||
FunNpCos,
|
||||
FunNpExp,
|
||||
FunNpExp2,
|
||||
FunNpLog,
|
||||
FunNpLog10,
|
||||
FunNpLog2,
|
||||
FunNpFabs,
|
||||
FunNpSqrt,
|
||||
FunNpRint,
|
||||
FunNpTan,
|
||||
FunNpArcsin,
|
||||
FunNpArccos,
|
||||
FunNpArctan,
|
||||
FunNpSinh,
|
||||
FunNpCosh,
|
||||
FunNpTanh,
|
||||
FunNpArcsinh,
|
||||
FunNpArccosh,
|
||||
FunNpArctanh,
|
||||
FunNpExpm1,
|
||||
FunNpCbrt,
|
||||
FunSpSpecErf,
|
||||
FunSpSpecErfc,
|
||||
FunSpSpecGamma,
|
||||
FunSpSpecGammaln,
|
||||
FunSpSpecJ0,
|
||||
FunSpSpecJ1,
|
||||
FunNpArctan2,
|
||||
FunNpCopysign,
|
||||
FunNpFmax,
|
||||
FunNpFmin,
|
||||
FunNpLdExp,
|
||||
FunNpHypot,
|
||||
FunNpNextAfter,
|
||||
FunSome,
|
||||
}
|
||||
|
||||
/// Associated details of a [`PrimDef`]
|
||||
pub enum PrimDefDetails {
|
||||
PrimFunction { name: &'static str, simple_name: &'static str },
|
||||
PrimClass { name: &'static str },
|
||||
}
|
||||
|
||||
impl PrimDef {
|
||||
/// Get the assigned [`DefinitionId`] of this [`PrimDef`].
|
||||
///
|
||||
/// The assigned definition ID is defined by the position this [`PrimDef`] enum unit variant is defined at,
|
||||
/// with the first `PrimDef`'s definition id being `0`.
|
||||
#[must_use]
|
||||
pub fn id(&self) -> DefinitionId {
|
||||
DefinitionId(*self as usize)
|
||||
}
|
||||
|
||||
/// Check if a definition ID is that of a [`PrimDef`].
|
||||
#[must_use]
|
||||
pub fn contains_id(id: DefinitionId) -> bool {
|
||||
Self::iter().any(|prim| prim.id() == id)
|
||||
}
|
||||
|
||||
/// Get the definition "simple name" of this [`PrimDef`].
|
||||
///
|
||||
/// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::simple_name`].
|
||||
///
|
||||
/// If the [`PrimDef`] is a class, this returns [`None`].
|
||||
#[must_use]
|
||||
pub fn simple_name(&self) -> &'static str {
|
||||
match self.details() {
|
||||
PrimDefDetails::PrimFunction { simple_name, .. } => simple_name,
|
||||
PrimDefDetails::PrimClass { .. } => {
|
||||
panic!("PrimDef {self:?} has no simple_name as it is not a function.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the definition "name" of this [`PrimDef`].
|
||||
///
|
||||
/// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::name`].
|
||||
///
|
||||
/// If the [`PrimDef`] is a class, this corresponds to [`TopLevelDef::Class::name`].
|
||||
#[must_use]
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self.details() {
|
||||
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the associated details of this [`PrimDef`]
|
||||
#[must_use]
|
||||
pub fn details(self) -> PrimDefDetails {
|
||||
fn class(name: &'static str) -> PrimDefDetails {
|
||||
PrimDefDetails::PrimClass { name }
|
||||
}
|
||||
|
||||
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
|
||||
PrimDefDetails::PrimFunction { simple_name: simple_name.unwrap_or(name), name }
|
||||
}
|
||||
|
||||
match self {
|
||||
PrimDef::Int32 => class("int32"),
|
||||
PrimDef::Int64 => class("int64"),
|
||||
PrimDef::Float => class("float"),
|
||||
PrimDef::Bool => class("bool"),
|
||||
PrimDef::None => class("none"),
|
||||
PrimDef::Range => class("range"),
|
||||
PrimDef::Str => class("str"),
|
||||
PrimDef::Exception => class("Exception"),
|
||||
PrimDef::UInt32 => class("uint32"),
|
||||
PrimDef::UInt64 => class("uint64"),
|
||||
PrimDef::Option => class("Option"),
|
||||
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
|
||||
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
|
||||
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
|
||||
PrimDef::NDArray => class("ndarray"),
|
||||
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
|
||||
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
|
||||
PrimDef::FunInt32 => fun("int32", None),
|
||||
PrimDef::FunInt64 => fun("int64", None),
|
||||
PrimDef::FunUInt32 => fun("uint32", None),
|
||||
PrimDef::FunUInt64 => fun("uint64", None),
|
||||
PrimDef::FunFloat => fun("float", None),
|
||||
PrimDef::FunNpNDArray => fun("np_ndarray", None),
|
||||
PrimDef::FunNpEmpty => fun("np_empty", None),
|
||||
PrimDef::FunNpZeros => fun("np_zeros", None),
|
||||
PrimDef::FunNpOnes => fun("np_ones", None),
|
||||
PrimDef::FunNpFull => fun("np_full", None),
|
||||
PrimDef::FunNpArray => fun("np_array", None),
|
||||
PrimDef::FunNpEye => fun("np_eye", None),
|
||||
PrimDef::FunNpIdentity => fun("np_identity", None),
|
||||
PrimDef::FunRound => fun("round", None),
|
||||
PrimDef::FunRound64 => fun("round64", None),
|
||||
PrimDef::FunNpRound => fun("np_round", None),
|
||||
PrimDef::FunRange => fun("range", None),
|
||||
PrimDef::FunStr => fun("str", None),
|
||||
PrimDef::FunBool => fun("bool", None),
|
||||
PrimDef::FunFloor => fun("floor", None),
|
||||
PrimDef::FunFloor64 => fun("floor64", None),
|
||||
PrimDef::FunNpFloor => fun("np_floor", None),
|
||||
PrimDef::FunCeil => fun("ceil", None),
|
||||
PrimDef::FunCeil64 => fun("ceil64", None),
|
||||
PrimDef::FunNpCeil => fun("np_ceil", None),
|
||||
PrimDef::FunLen => fun("len", None),
|
||||
PrimDef::FunMin => fun("min", None),
|
||||
PrimDef::FunNpMin => fun("np_min", None),
|
||||
PrimDef::FunNpMinimum => fun("np_minimum", None),
|
||||
PrimDef::FunMax => fun("max", None),
|
||||
PrimDef::FunNpMax => fun("np_max", None),
|
||||
PrimDef::FunNpMaximum => fun("np_maximum", None),
|
||||
PrimDef::FunAbs => fun("abs", None),
|
||||
PrimDef::FunNpIsNan => fun("np_isnan", None),
|
||||
PrimDef::FunNpIsInf => fun("np_isinf", None),
|
||||
PrimDef::FunNpSin => fun("np_sin", None),
|
||||
PrimDef::FunNpCos => fun("np_cos", None),
|
||||
PrimDef::FunNpExp => fun("np_exp", None),
|
||||
PrimDef::FunNpExp2 => fun("np_exp2", None),
|
||||
PrimDef::FunNpLog => fun("np_log", None),
|
||||
PrimDef::FunNpLog10 => fun("np_log10", None),
|
||||
PrimDef::FunNpLog2 => fun("np_log2", None),
|
||||
PrimDef::FunNpFabs => fun("np_fabs", None),
|
||||
PrimDef::FunNpSqrt => fun("np_sqrt", None),
|
||||
PrimDef::FunNpRint => fun("np_rint", None),
|
||||
PrimDef::FunNpTan => fun("np_tan", None),
|
||||
PrimDef::FunNpArcsin => fun("np_arcsin", None),
|
||||
PrimDef::FunNpArccos => fun("np_arccos", None),
|
||||
PrimDef::FunNpArctan => fun("np_arctan", None),
|
||||
PrimDef::FunNpSinh => fun("np_sinh", None),
|
||||
PrimDef::FunNpCosh => fun("np_cosh", None),
|
||||
PrimDef::FunNpTanh => fun("np_tanh", None),
|
||||
PrimDef::FunNpArcsinh => fun("np_arcsinh", None),
|
||||
PrimDef::FunNpArccosh => fun("np_arccosh", None),
|
||||
PrimDef::FunNpArctanh => fun("np_arctanh", None),
|
||||
PrimDef::FunNpExpm1 => fun("np_expm1", None),
|
||||
PrimDef::FunNpCbrt => fun("np_cbrt", None),
|
||||
PrimDef::FunSpSpecErf => fun("sp_spec_erf", None),
|
||||
PrimDef::FunSpSpecErfc => fun("sp_spec_erfc", None),
|
||||
PrimDef::FunSpSpecGamma => fun("sp_spec_gamma", None),
|
||||
PrimDef::FunSpSpecGammaln => fun("sp_spec_gammaln", None),
|
||||
PrimDef::FunSpSpecJ0 => fun("sp_spec_j0", None),
|
||||
PrimDef::FunSpSpecJ1 => fun("sp_spec_j1", None),
|
||||
PrimDef::FunNpArctan2 => fun("np_arctan2", None),
|
||||
PrimDef::FunNpCopysign => fun("np_copysign", None),
|
||||
PrimDef::FunNpFmax => fun("np_fmax", None),
|
||||
PrimDef::FunNpFmin => fun("np_fmin", None),
|
||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||
PrimDef::FunSome => fun("Some", None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Asserts that a [`PrimDef`] is in an allowlist.
|
||||
///
|
||||
/// Like `debug_assert!`, this statements of this function are only
|
||||
/// enabled if `cfg!(debug_assertions)` is true.
|
||||
pub fn debug_assert_prim_is_allowed(prim: PrimDef, allowlist: &[PrimDef]) {
|
||||
if cfg!(debug_assertions) {
|
||||
let allowed = allowlist.iter().any(|p| *p == prim);
|
||||
assert!(
|
||||
allowed,
|
||||
"Disallowed primitive definition. Got {prim:?}, but expects it to be in {allowlist:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl TopLevelDef {
|
||||
pub fn to_string(&self, unifier: &mut Unifier) -> String {
|
||||
match self {
|
||||
|
@ -43,48 +307,49 @@ impl TopLevelDef {
|
|||
}
|
||||
|
||||
impl TopLevelComposer {
|
||||
pub fn make_primitives() -> (PrimitiveStore, Unifier) {
|
||||
#[must_use]
|
||||
pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) {
|
||||
let mut unifier = Unifier::new();
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
obj_id: PrimDef::Int32.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
obj_id: PrimDef::Int64.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let float = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(2),
|
||||
obj_id: PrimDef::Float.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
obj_id: PrimDef::Bool.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let none = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(4),
|
||||
obj_id: PrimDef::None.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let range = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(5),
|
||||
obj_id: PrimDef::Range.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let str = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(6),
|
||||
obj_id: PrimDef::Str.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let exception = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(7),
|
||||
obj_id: PrimDef::Exception.id(),
|
||||
fields: vec![
|
||||
("__name__".into(), (int32, true)),
|
||||
("__file__".into(), (int32, true)),
|
||||
("__file__".into(), (str, true)),
|
||||
("__line__".into(), (int32, true)),
|
||||
("__col__".into(), (int32, true)),
|
||||
("__func__".into(), (str, true)),
|
||||
|
@ -95,63 +360,102 @@ impl TopLevelComposer {
|
|||
]
|
||||
.into_iter()
|
||||
.collect::<HashMap<_, _>>(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let uint32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(8),
|
||||
obj_id: PrimDef::UInt32.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let uint64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(9),
|
||||
obj_id: PrimDef::UInt64.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
|
||||
let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None);
|
||||
let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![],
|
||||
ret: bool,
|
||||
vars: HashMap::from([(option_type_var.1, option_type_var.0)]),
|
||||
vars: into_var_map([option_type_var]),
|
||||
}));
|
||||
let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![],
|
||||
ret: option_type_var.0,
|
||||
vars: HashMap::from([(option_type_var.1, option_type_var.0)]),
|
||||
ret: option_type_var.ty,
|
||||
vars: into_var_map([option_type_var]),
|
||||
}));
|
||||
let option = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(10),
|
||||
obj_id: PrimDef::Option.id(),
|
||||
fields: vec![
|
||||
("is_some".into(), (is_some_type_fun_ty, true)),
|
||||
("is_none".into(), (is_some_type_fun_ty, true)),
|
||||
("unwrap".into(), (unwrap_fun_ty, true)),
|
||||
(PrimDef::OptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)),
|
||||
(PrimDef::OptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)),
|
||||
(PrimDef::OptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)),
|
||||
]
|
||||
.into_iter()
|
||||
.collect::<HashMap<_, _>>(),
|
||||
params: HashMap::from([(option_type_var.1, option_type_var.0)]),
|
||||
params: into_var_map([option_type_var]),
|
||||
});
|
||||
|
||||
let size_t_ty = match size_t {
|
||||
32 => uint32,
|
||||
64 => uint64,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
||||
let ndarray_ndims_tvar =
|
||||
unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
|
||||
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
|
||||
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![],
|
||||
ret: ndarray_copy_fun_ret_ty.ty,
|
||||
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
||||
}));
|
||||
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg {
|
||||
name: "value".into(),
|
||||
ty: ndarray_dtype_tvar.ty,
|
||||
default_value: None,
|
||||
}],
|
||||
ret: none,
|
||||
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
||||
}));
|
||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PrimDef::NDArray.id(),
|
||||
fields: Mapping::from([
|
||||
(PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)),
|
||||
(PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)),
|
||||
]),
|
||||
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
||||
});
|
||||
|
||||
unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray).unwrap();
|
||||
|
||||
let primitives = PrimitiveStore {
|
||||
int32,
|
||||
int64,
|
||||
uint32,
|
||||
uint64,
|
||||
float,
|
||||
bool,
|
||||
none,
|
||||
range,
|
||||
str,
|
||||
exception,
|
||||
uint32,
|
||||
uint64,
|
||||
option,
|
||||
ndarray,
|
||||
size_t,
|
||||
};
|
||||
unifier.put_primitive_store(&primitives);
|
||||
crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier);
|
||||
(primitives, unifier)
|
||||
}
|
||||
|
||||
/// already include the definition_id of itself inside the ancestors vector
|
||||
/// when first registering, the type_vars, fields, methods, ancestors are invalid
|
||||
/// already include the `definition_id` of itself inside the ancestors vector
|
||||
/// when first registering, the `type_vars`, fields, methods, ancestors are invalid
|
||||
#[must_use]
|
||||
pub fn make_top_level_class_def(
|
||||
index: usize,
|
||||
obj_id: DefinitionId,
|
||||
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
|
||||
name: StrRef,
|
||||
constructor: Option<Type>,
|
||||
|
@ -159,11 +463,11 @@ impl TopLevelComposer {
|
|||
) -> TopLevelDef {
|
||||
TopLevelDef::Class {
|
||||
name,
|
||||
object_id: DefinitionId(index),
|
||||
type_vars: Default::default(),
|
||||
fields: Default::default(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
object_id: obj_id,
|
||||
type_vars: Vec::default(),
|
||||
fields: Vec::default(),
|
||||
methods: Vec::default(),
|
||||
ancestors: Vec::default(),
|
||||
constructor,
|
||||
resolver,
|
||||
loc,
|
||||
|
@ -171,6 +475,7 @@ impl TopLevelComposer {
|
|||
}
|
||||
|
||||
/// when first registering, the type is a invalid value
|
||||
#[must_use]
|
||||
pub fn make_top_level_function_def(
|
||||
name: String,
|
||||
simple_name: StrRef,
|
||||
|
@ -182,15 +487,16 @@ impl TopLevelComposer {
|
|||
name,
|
||||
simple_name,
|
||||
signature: ty,
|
||||
var_id: Default::default(),
|
||||
instance_to_symbol: Default::default(),
|
||||
instance_to_stmt: Default::default(),
|
||||
var_id: Vec::default(),
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver,
|
||||
codegen_callback: None,
|
||||
loc,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String {
|
||||
class_name.push('.');
|
||||
class_name.push_str(method_name);
|
||||
|
@ -200,13 +506,13 @@ impl TopLevelComposer {
|
|||
pub fn get_class_method_def_info(
|
||||
class_methods_def: &[(StrRef, Type, DefinitionId)],
|
||||
method_name: StrRef,
|
||||
) -> Result<(Type, DefinitionId), String> {
|
||||
) -> Result<(Type, DefinitionId), HashSet<String>> {
|
||||
for (name, ty, def_id) in class_methods_def {
|
||||
if name == &method_name {
|
||||
return Ok((*ty, *def_id));
|
||||
}
|
||||
}
|
||||
Err(format!("no method {} in the current class", method_name))
|
||||
Err(HashSet::from([format!("no method {method_name} in the current class")]))
|
||||
}
|
||||
|
||||
/// get all base class def id of a class, excluding itself. \
|
||||
|
@ -217,7 +523,7 @@ impl TopLevelComposer {
|
|||
pub fn get_all_ancestors_helper(
|
||||
child: &TypeAnnotation,
|
||||
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
|
||||
) -> Result<Vec<TypeAnnotation>, String> {
|
||||
) -> Result<Vec<TypeAnnotation>, HashSet<String>> {
|
||||
let mut result: Vec<TypeAnnotation> = Vec::new();
|
||||
let mut parent = Self::get_parent(child, temp_def_list);
|
||||
while let Some(p) = parent {
|
||||
|
@ -229,16 +535,16 @@ impl TopLevelComposer {
|
|||
};
|
||||
// check cycle
|
||||
let no_cycle = result.iter().all(|x| {
|
||||
if let TypeAnnotation::CustomClass { id, .. } = x {
|
||||
id.0 != p_id.0
|
||||
} else {
|
||||
let TypeAnnotation::CustomClass { id, .. } = x else {
|
||||
unreachable!("must be class kind annotation")
|
||||
}
|
||||
};
|
||||
|
||||
id.0 != p_id.0
|
||||
});
|
||||
if no_cycle {
|
||||
result.push(p);
|
||||
} else {
|
||||
return Err("cyclic inheritance detected".into());
|
||||
return Err(HashSet::from(["cyclic inheritance detected".into()]));
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
|
@ -256,23 +562,23 @@ impl TopLevelComposer {
|
|||
};
|
||||
let child_def = temp_def_list.get(child_id.0).unwrap();
|
||||
let child_def = child_def.read();
|
||||
if let TopLevelDef::Class { ancestors, .. } = &*child_def {
|
||||
if !ancestors.is_empty() {
|
||||
Some(ancestors[0].clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
let TopLevelDef::Class { ancestors, .. } = &*child_def else {
|
||||
unreachable!("child must be top level class def")
|
||||
};
|
||||
|
||||
if ancestors.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(ancestors[0].clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// get the var_id of a given TVar type
|
||||
pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<u32, String> {
|
||||
/// get the `var_id` of a given `TVar` type
|
||||
pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<TypeVarId, HashSet<String>> {
|
||||
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
|
||||
Ok(*id)
|
||||
} else {
|
||||
Err("not type var".to_string())
|
||||
Err(HashSet::from(["not type var".to_string()]))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -286,13 +592,17 @@ impl TopLevelComposer {
|
|||
let this = this.as_ref();
|
||||
let other = unifier.get_ty(other);
|
||||
let other = other.as_ref();
|
||||
if let (
|
||||
let (
|
||||
TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }),
|
||||
TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }),
|
||||
) = (this, other)
|
||||
{
|
||||
// check args
|
||||
let args_ok = this_args
|
||||
else {
|
||||
unreachable!("this function must be called with function type")
|
||||
};
|
||||
|
||||
// check args
|
||||
let args_ok =
|
||||
this_args
|
||||
.iter()
|
||||
.map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap()))
|
||||
.zip(other_args.iter().map(|FuncArg { name, ty, .. }| {
|
||||
|
@ -307,18 +617,15 @@ impl TopLevelComposer {
|
|||
}
|
||||
});
|
||||
|
||||
// check rets
|
||||
let ret_ok = check_overload_type_annotation_compatible(
|
||||
type_var_to_concrete_def.get(this_ret).unwrap(),
|
||||
type_var_to_concrete_def.get(other_ret).unwrap(),
|
||||
unifier,
|
||||
);
|
||||
// check rets
|
||||
let ret_ok = check_overload_type_annotation_compatible(
|
||||
type_var_to_concrete_def.get(this_ret).unwrap(),
|
||||
type_var_to_concrete_def.get(other_ret).unwrap(),
|
||||
unifier,
|
||||
);
|
||||
|
||||
// return
|
||||
args_ok && ret_ok
|
||||
} else {
|
||||
unreachable!("this function must be called with function type")
|
||||
}
|
||||
// return
|
||||
args_ok && ret_ok
|
||||
}
|
||||
|
||||
pub fn check_overload_field_type(
|
||||
|
@ -334,7 +641,7 @@ impl TopLevelComposer {
|
|||
)
|
||||
}
|
||||
|
||||
pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result<HashSet<StrRef>, String> {
|
||||
pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> {
|
||||
let mut result = HashSet::new();
|
||||
for s in stmts {
|
||||
match &s.node {
|
||||
|
@ -351,10 +658,10 @@ impl TopLevelComposer {
|
|||
}
|
||||
} =>
|
||||
{
|
||||
return Err(format!(
|
||||
return Err(HashSet::from([format!(
|
||||
"redundant type annotation for class fields at {}",
|
||||
s.location
|
||||
))
|
||||
)]))
|
||||
}
|
||||
ast::StmtKind::Assign { targets, .. } => {
|
||||
for t in targets {
|
||||
|
@ -376,14 +683,14 @@ impl TopLevelComposer {
|
|||
ast::StmtKind::If { body, orelse, .. } => {
|
||||
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
||||
.cloned()
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
result.extend(inited_for_sure);
|
||||
}
|
||||
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
|
||||
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
||||
.cloned()
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
result.extend(inited_for_sure);
|
||||
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
|
||||
|
@ -391,9 +698,9 @@ impl TopLevelComposer {
|
|||
ast::StmtKind::With { body, .. } => {
|
||||
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
||||
}
|
||||
ast::StmtKind::Pass { .. } => {}
|
||||
ast::StmtKind::Assert { .. } => {}
|
||||
ast::StmtKind::Expr { .. } => {}
|
||||
ast::StmtKind::Pass { .. }
|
||||
| ast::StmtKind::Assert { .. }
|
||||
| ast::StmtKind::Expr { .. } => {}
|
||||
|
||||
_ => {
|
||||
unimplemented!()
|
||||
|
@ -406,7 +713,7 @@ impl TopLevelComposer {
|
|||
pub fn parse_parameter_default_value(
|
||||
default: &ast::Expr,
|
||||
resolver: &(dyn SymbolResolver + Send + Sync),
|
||||
) -> Result<SymbolValue, String> {
|
||||
) -> Result<SymbolValue, HashSet<String>> {
|
||||
parse_parameter_default_value(default, resolver)
|
||||
}
|
||||
|
||||
|
@ -416,40 +723,6 @@ impl TopLevelComposer {
|
|||
primitive: &PrimitiveStore,
|
||||
unifier: &mut Unifier,
|
||||
) -> Result<(), String> {
|
||||
fn type_default_param(
|
||||
val: &SymbolValue,
|
||||
primitive: &PrimitiveStore,
|
||||
unifier: &mut Unifier,
|
||||
) -> TypeAnnotation {
|
||||
match val {
|
||||
SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitive.bool),
|
||||
SymbolValue::Double(..) => TypeAnnotation::Primitive(primitive.float),
|
||||
SymbolValue::I32(..) => TypeAnnotation::Primitive(primitive.int32),
|
||||
SymbolValue::I64(..) => TypeAnnotation::Primitive(primitive.int64),
|
||||
SymbolValue::U32(..) => TypeAnnotation::Primitive(primitive.uint32),
|
||||
SymbolValue::U64(..) => TypeAnnotation::Primitive(primitive.uint64),
|
||||
SymbolValue::Str(..) => TypeAnnotation::Primitive(primitive.str),
|
||||
SymbolValue::Tuple(vs) => {
|
||||
let vs_tys = vs
|
||||
.iter()
|
||||
.map(|v| type_default_param(v, primitive, unifier))
|
||||
.collect::<Vec<_>>();
|
||||
TypeAnnotation::Tuple(vs_tys)
|
||||
}
|
||||
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
|
||||
id: primitive.option.get_obj_id(unifier),
|
||||
params: Default::default(),
|
||||
},
|
||||
SymbolValue::OptionSome(v) => {
|
||||
let ty = type_default_param(v, primitive, unifier);
|
||||
TypeAnnotation::CustomClass {
|
||||
id: primitive.option.get_obj_id(unifier),
|
||||
params: vec![ty],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_compatible(
|
||||
found: &TypeAnnotation,
|
||||
expect: &TypeAnnotation,
|
||||
|
@ -465,7 +738,7 @@ impl TopLevelComposer {
|
|||
TypeAnnotation::CustomClass { id: e_id, params: e_param },
|
||||
) => {
|
||||
*f_id == *e_id
|
||||
&& *f_id == primitive.option.get_obj_id(unifier)
|
||||
&& *f_id == primitive.option.obj_id(unifier).unwrap()
|
||||
&& (f_param.is_empty()
|
||||
|| (f_param.len() == 1
|
||||
&& e_param.len() == 1
|
||||
|
@ -481,15 +754,15 @@ impl TopLevelComposer {
|
|||
}
|
||||
}
|
||||
|
||||
let found = type_default_param(val, primitive, unifier);
|
||||
if !is_compatible(&found, ty, unifier, primitive) {
|
||||
let found = val.get_type_annotation(primitive, unifier);
|
||||
if is_compatible(&found, ty, unifier, primitive) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!(
|
||||
"incompatible default parameter type, expect {}, found {}",
|
||||
ty.stringify(unifier),
|
||||
found.stringify(unifier),
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -497,14 +770,14 @@ impl TopLevelComposer {
|
|||
pub fn parse_parameter_default_value(
|
||||
default: &ast::Expr,
|
||||
resolver: &(dyn SymbolResolver + Send + Sync),
|
||||
) -> Result<SymbolValue, String> {
|
||||
fn handle_constant(val: &Constant, loc: &Location) -> Result<SymbolValue, String> {
|
||||
) -> Result<SymbolValue, HashSet<String>> {
|
||||
fn handle_constant(val: &Constant, loc: &Location) -> Result<SymbolValue, HashSet<String>> {
|
||||
match val {
|
||||
Constant::Int(v) => {
|
||||
if let Ok(v) = (*v).try_into() {
|
||||
Ok(SymbolValue::I32(v))
|
||||
} else {
|
||||
Err(format!("integer value out of range at {}", loc))
|
||||
Err(HashSet::from([format!("integer value out of range at {loc}")]))
|
||||
}
|
||||
}
|
||||
Constant::Float(v) => Ok(SymbolValue::Double(*v)),
|
||||
|
@ -512,74 +785,122 @@ pub fn parse_parameter_default_value(
|
|||
Constant::Tuple(tuple) => Ok(SymbolValue::Tuple(
|
||||
tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?,
|
||||
)),
|
||||
Constant::None => Err(format!(
|
||||
"`None` is not supported, use `none` for option type instead ({})",
|
||||
loc
|
||||
)),
|
||||
Constant::None => Err(HashSet::from([format!(
|
||||
"`None` is not supported, use `none` for option type instead ({loc})"
|
||||
)])),
|
||||
_ => unimplemented!("this constant is not supported at {}", loc),
|
||||
}
|
||||
}
|
||||
match &default.node {
|
||||
ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location),
|
||||
ast::ExprKind::Call { func, args, .. } if args.len() == 1 => {
|
||||
match &func.node {
|
||||
ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node {
|
||||
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
|
||||
let v: Result<i64, _> = (*v).try_into();
|
||||
match v {
|
||||
Ok(v) => Ok(SymbolValue::I64(v)),
|
||||
_ => Err(format!("default param value out of range at {}", default.location)),
|
||||
}
|
||||
ast::ExprKind::Call { func, args, .. } if args.len() == 1 => match &func.node {
|
||||
ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node {
|
||||
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
|
||||
let v: Result<i64, _> = (*v).try_into();
|
||||
match v {
|
||||
Ok(v) => Ok(SymbolValue::I64(v)),
|
||||
_ => Err(HashSet::from([format!(
|
||||
"default param value out of range at {}",
|
||||
default.location
|
||||
)])),
|
||||
}
|
||||
_ => Err(format!("only allow constant integer here at {}", default.location))
|
||||
}
|
||||
ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node {
|
||||
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
|
||||
let v: Result<u32, _> = (*v).try_into();
|
||||
match v {
|
||||
Ok(v) => Ok(SymbolValue::U32(v)),
|
||||
_ => Err(format!("default param value out of range at {}", default.location)),
|
||||
}
|
||||
_ => Err(HashSet::from([format!(
|
||||
"only allow constant integer here at {}",
|
||||
default.location
|
||||
)])),
|
||||
},
|
||||
ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node {
|
||||
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
|
||||
let v: Result<u32, _> = (*v).try_into();
|
||||
match v {
|
||||
Ok(v) => Ok(SymbolValue::U32(v)),
|
||||
_ => Err(HashSet::from([format!(
|
||||
"default param value out of range at {}",
|
||||
default.location
|
||||
)])),
|
||||
}
|
||||
_ => Err(format!("only allow constant integer here at {}", default.location))
|
||||
}
|
||||
ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node {
|
||||
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
|
||||
let v: Result<u64, _> = (*v).try_into();
|
||||
match v {
|
||||
Ok(v) => Ok(SymbolValue::U64(v)),
|
||||
_ => Err(format!("default param value out of range at {}", default.location)),
|
||||
}
|
||||
_ => Err(HashSet::from([format!(
|
||||
"only allow constant integer here at {}",
|
||||
default.location
|
||||
)])),
|
||||
},
|
||||
ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node {
|
||||
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
|
||||
let v: Result<u64, _> = (*v).try_into();
|
||||
match v {
|
||||
Ok(v) => Ok(SymbolValue::U64(v)),
|
||||
_ => Err(HashSet::from([format!(
|
||||
"default param value out of range at {}",
|
||||
default.location
|
||||
)])),
|
||||
}
|
||||
_ => Err(format!("only allow constant integer here at {}", default.location))
|
||||
}
|
||||
ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok(
|
||||
SymbolValue::OptionSome(
|
||||
Box::new(parse_parameter_default_value(&args[0], resolver)?)
|
||||
)
|
||||
),
|
||||
_ => Err(format!("unsupported default parameter at {}", default.location)),
|
||||
}
|
||||
}
|
||||
ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts
|
||||
.iter()
|
||||
.map(|x| parse_parameter_default_value(x, resolver))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
_ => Err(HashSet::from([format!(
|
||||
"only allow constant integer here at {}",
|
||||
default.location
|
||||
)])),
|
||||
},
|
||||
ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok(SymbolValue::OptionSome(
|
||||
Box::new(parse_parameter_default_value(&args[0], resolver)?),
|
||||
)),
|
||||
_ => Err(HashSet::from([format!(
|
||||
"unsupported default parameter at {}",
|
||||
default.location
|
||||
)])),
|
||||
},
|
||||
ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(
|
||||
elts.iter()
|
||||
.map(|x| parse_parameter_default_value(x, resolver))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
)),
|
||||
ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone),
|
||||
ast::ExprKind::Name { id, .. } => {
|
||||
resolver.get_default_param_value(default).ok_or_else(
|
||||
|| format!(
|
||||
resolver.get_default_param_value(default).ok_or_else(|| {
|
||||
HashSet::from([format!(
|
||||
"`{}` cannot be used as a default parameter at {} \
|
||||
(not primitive type, option or tuple / not defined?)",
|
||||
id,
|
||||
default.location
|
||||
)
|
||||
)
|
||||
(not primitive type, option or tuple / not defined?)",
|
||||
id, default.location
|
||||
)])
|
||||
})
|
||||
}
|
||||
_ => Err(format!(
|
||||
_ => Err(HashSet::from([format!(
|
||||
"unsupported default parameter (not primitive type, option or tuple) at {}",
|
||||
default.location
|
||||
))
|
||||
)])),
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtains the element type of an array-like type.
|
||||
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
|
||||
match &*unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
unpack_ndarray_var_tys(unifier, ty).0
|
||||
}
|
||||
|
||||
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
|
||||
_ => ty,
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtains the number of dimensions of an array-like type.
|
||||
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
|
||||
match &*unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let ndims = unpack_ndarray_var_tys(unifier, ty).1;
|
||||
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
|
||||
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
|
||||
};
|
||||
|
||||
if values.len() > 1 {
|
||||
todo!("Getting num of dimensions for ndarray with more than one ndim bound is unimplemented")
|
||||
}
|
||||
|
||||
u64::try_from(values[0].clone()).unwrap()
|
||||
}
|
||||
|
||||
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1,
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,17 +3,21 @@ use std::{
|
|||
collections::{HashMap, HashSet},
|
||||
fmt::Debug,
|
||||
iter::FromIterator,
|
||||
ops::{Deref, DerefMut},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use super::codegen::CodeGenContext;
|
||||
use super::typecheck::type_inferencer::PrimitiveStore;
|
||||
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier};
|
||||
use super::typecheck::typedef::{
|
||||
FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap,
|
||||
};
|
||||
use crate::{
|
||||
codegen::CodeGenerator,
|
||||
symbol_resolver::{SymbolResolver, ValueEnum},
|
||||
typecheck::{type_inferencer::CodeLocation, typedef::CallId},
|
||||
typecheck::{
|
||||
type_inferencer::CodeLocation,
|
||||
typedef::{CallId, TypeVarId},
|
||||
},
|
||||
};
|
||||
use inkwell::values::BasicValueEnum;
|
||||
use itertools::{izip, Itertools};
|
||||
|
@ -26,36 +30,43 @@ pub struct DefinitionId(pub usize);
|
|||
pub mod builtins;
|
||||
pub mod composer;
|
||||
pub mod helper;
|
||||
pub mod numpy;
|
||||
pub mod type_annotation;
|
||||
use composer::*;
|
||||
use type_annotation::*;
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
type GenCallCallback = Box<
|
||||
dyn for<'ctx, 'a> Fn(
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Option<(Type, ValueEnum<'ctx>)>,
|
||||
(&FunSignature, DefinitionId),
|
||||
Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
&mut dyn CodeGenerator,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>;
|
||||
type GenCallCallback = dyn for<'ctx, 'a> Fn(
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Option<(Type, ValueEnum<'ctx>)>,
|
||||
(&FunSignature, DefinitionId),
|
||||
Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
&mut dyn CodeGenerator,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String>
|
||||
+ Send
|
||||
+ Sync;
|
||||
|
||||
pub struct GenCall {
|
||||
fp: GenCallCallback,
|
||||
fp: Box<GenCallCallback>,
|
||||
}
|
||||
|
||||
impl GenCall {
|
||||
pub fn new(fp: GenCallCallback) -> GenCall {
|
||||
#[must_use]
|
||||
pub fn new(fp: Box<GenCallCallback>) -> GenCall {
|
||||
GenCall { fp }
|
||||
}
|
||||
|
||||
pub fn run<'ctx, 'a>(
|
||||
/// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given
|
||||
/// `reason`.
|
||||
#[must_use]
|
||||
pub fn create_dummy(reason: String) -> GenCall {
|
||||
Self::new(Box::new(move |_, _, _, _, _| unreachable!("{reason}")))
|
||||
}
|
||||
|
||||
pub fn run<'ctx>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
|
@ -75,58 +86,62 @@ impl Debug for GenCall {
|
|||
pub struct FunInstance {
|
||||
pub body: Arc<Vec<Stmt<Option<Type>>>>,
|
||||
pub calls: Arc<HashMap<CodeLocation, CallId>>,
|
||||
pub subst: HashMap<u32, Type>,
|
||||
pub subst: VarMap,
|
||||
pub unifier_id: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TopLevelDef {
|
||||
Class {
|
||||
// name for error messages and symbols
|
||||
/// Name for error messages and symbols.
|
||||
name: StrRef,
|
||||
// object ID used for TypeEnum
|
||||
/// Object ID used for [TypeEnum].
|
||||
object_id: DefinitionId,
|
||||
/// type variables bounded to the class.
|
||||
type_vars: Vec<Type>,
|
||||
// class fields
|
||||
// name, type, is mutable
|
||||
/// Class fields.
|
||||
///
|
||||
/// Name and type is mutable.
|
||||
fields: Vec<(StrRef, Type, bool)>,
|
||||
// class methods, pointing to the corresponding function definition.
|
||||
/// Class methods, pointing to the corresponding function definition.
|
||||
methods: Vec<(StrRef, Type, DefinitionId)>,
|
||||
// ancestor classes, including itself.
|
||||
/// Ancestor classes, including itself.
|
||||
ancestors: Vec<TypeAnnotation>,
|
||||
// symbol resolver of the module defined the class, none if it is built-in type
|
||||
/// Symbol resolver of the module defined the class; [None] if it is built-in type.
|
||||
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
|
||||
// constructor type
|
||||
/// Constructor type.
|
||||
constructor: Option<Type>,
|
||||
// definition location
|
||||
/// Definition location.
|
||||
loc: Option<Location>,
|
||||
},
|
||||
Function {
|
||||
// prefix for symbol, should be unique globally
|
||||
/// Prefix for symbol, should be unique globally.
|
||||
name: String,
|
||||
// simple name, the same as in method/function definition
|
||||
/// Simple name, the same as in method/function definition.
|
||||
simple_name: StrRef,
|
||||
// function signature.
|
||||
/// Function signature.
|
||||
signature: Type,
|
||||
// instantiated type variable IDs
|
||||
var_id: Vec<u32>,
|
||||
/// Instantiated type variable IDs.
|
||||
var_id: Vec<TypeVarId>,
|
||||
/// Function instance to symbol mapping
|
||||
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||
///
|
||||
/// * Key: String representation of type variable values, sorted by variable ID in ascending
|
||||
/// order, including type variables associated with the class.
|
||||
/// Value: function symbol name.
|
||||
/// * Value: Function symbol name.
|
||||
instance_to_symbol: HashMap<String, String>,
|
||||
/// Function instances to annotated AST mapping
|
||||
/// Key: string representation of type variable values, sorted by variable ID in ascending
|
||||
///
|
||||
/// * Key: String representation of type variable values, sorted by variable ID in ascending
|
||||
/// order, including type variables associated with the class. Excluding rigid type
|
||||
/// variables.
|
||||
/// rigid type variables that would be substituted when the function is instantiated.
|
||||
///
|
||||
/// Rigid type variables that would be substituted when the function is instantiated.
|
||||
instance_to_stmt: HashMap<String, FunInstance>,
|
||||
// symbol resolver of the module defined the class
|
||||
/// Symbol resolver of the module defined the class.
|
||||
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
|
||||
// custom codegen callback
|
||||
/// Custom code generation callback.
|
||||
codegen_callback: Option<Arc<GenCall>>,
|
||||
// definition location
|
||||
/// Definition location.
|
||||
loc: Option<Location>,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
use crate::{
|
||||
toplevel::helper::PrimDef,
|
||||
typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||
},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
/// Creates a `ndarray` [`Type`] with the given type arguments.
|
||||
///
|
||||
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
pub fn make_ndarray_ty(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
dtype: Option<Type>,
|
||||
ndims: Option<Type>,
|
||||
) -> Type {
|
||||
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
|
||||
}
|
||||
|
||||
/// Substitutes type variables in `ndarray`.
|
||||
///
|
||||
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
pub fn subst_ndarray_tvars(
|
||||
unifier: &mut Unifier,
|
||||
ndarray: Type,
|
||||
dtype: Option<Type>,
|
||||
ndims: Option<Type>,
|
||||
) -> Type {
|
||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||
};
|
||||
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
|
||||
|
||||
if dtype.is_none() && ndims.is_none() {
|
||||
return ndarray;
|
||||
}
|
||||
|
||||
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
|
||||
debug_assert_eq!(tvar_ids.len(), 2);
|
||||
|
||||
let mut tvar_subst = VarMap::new();
|
||||
if let Some(dtype) = dtype {
|
||||
tvar_subst.insert(tvar_ids[0], dtype);
|
||||
}
|
||||
if let Some(ndims) = ndims {
|
||||
tvar_subst.insert(tvar_ids[1], ndims);
|
||||
}
|
||||
|
||||
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
||||
}
|
||||
|
||||
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
|
||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||
};
|
||||
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
|
||||
debug_assert_eq!(params.len(), 2);
|
||||
|
||||
params
|
||||
.iter()
|
||||
.sorted_by_key(|(obj_id, _)| *obj_id)
|
||||
.map(|(var_id, ty)| (*var_id, *ty))
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
|
||||
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
|
||||
/// respectively.
|
||||
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
|
||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
|
||||
}
|
||||
|
||||
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
||||
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
||||
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
|
||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
|
||||
}
|
|
@ -1,13 +1,11 @@
|
|||
---
|
||||
source: nac3core/src/toplevel/test.rs
|
||||
assertion_line: 549
|
||||
expression: res_vec
|
||||
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [18]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(239)]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
---
|
||||
source: nac3core/src/toplevel/test.rs
|
||||
assertion_line: 549
|
||||
expression: res_vec
|
||||
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar7]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar7\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar228]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar228\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
---
|
||||
source: nac3core/src/toplevel/test.rs
|
||||
assertion_line: 549
|
||||
expression: res_vec
|
||||
|
||||
---
|
||||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [20]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [25]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(241)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
]
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
---
|
||||
source: nac3core/src/toplevel/test.rs
|
||||
assertion_line: 549
|
||||
expression: res_vec
|
||||
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar6, typevar7]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar6\", \"typevar7\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar227, typevar228]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar227\", \"typevar228\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n",
|
||||
]
|
||||
|
|
|
@ -1,19 +1,17 @@
|
|||
---
|
||||
source: nac3core/src/toplevel/test.rs
|
||||
assertion_line: 549
|
||||
expression: res_vec
|
||||
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [26]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [34]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n",
|
||||
]
|
||||
|
|
|
@ -36,7 +36,7 @@ struct Resolver(Arc<ResolverInternal>);
|
|||
impl SymbolResolver for Resolver {
|
||||
fn get_default_param_value(
|
||||
&self,
|
||||
_: &nac3parser::ast::Expr,
|
||||
_: &ast::Expr,
|
||||
) -> Option<crate::symbol_resolver::SymbolValue> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
@ -64,8 +64,13 @@ impl SymbolResolver for Resolver {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, String> {
|
||||
self.0.id_to_def.lock().get(&id).cloned().ok_or_else(|| "Unknown identifier".to_string())
|
||||
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
|
||||
self.0
|
||||
.id_to_def
|
||||
.lock()
|
||||
.get(&id)
|
||||
.cloned()
|
||||
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
|
||||
}
|
||||
|
||||
fn get_string_id(&self, _: &str) -> i32 {
|
||||
|
@ -105,21 +110,37 @@ impl SymbolResolver for Resolver {
|
|||
def __init__(self):
|
||||
self.c: int32 = 4
|
||||
self.a: bool = True
|
||||
"}
|
||||
"},
|
||||
];
|
||||
"register"
|
||||
)]
|
||||
fn test_simple_register(source: Vec<&str>) {
|
||||
let mut composer: TopLevelComposer = Default::default();
|
||||
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
|
||||
|
||||
for s in source {
|
||||
let ast = parse_program(s, Default::default()).unwrap();
|
||||
let ast = ast[0].clone();
|
||||
|
||||
composer.register_top_level(ast, None, "".into()).unwrap();
|
||||
composer.register_top_level(ast, None, "".into(), false).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[test_case(
|
||||
indoc! {"
|
||||
class A:
|
||||
def foo(self):
|
||||
pass
|
||||
a = A()
|
||||
"};
|
||||
"register"
|
||||
)]
|
||||
fn test_simple_register_without_constructor(source: &str) {
|
||||
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
|
||||
let ast = parse_program(source, Default::default()).unwrap();
|
||||
let ast = ast[0].clone();
|
||||
composer.register_top_level(ast, None, "".into(), true).unwrap();
|
||||
}
|
||||
|
||||
#[test_case(
|
||||
vec![
|
||||
indoc! {"
|
||||
|
@ -148,7 +169,7 @@ fn test_simple_register(source: Vec<&str>) {
|
|||
"function compose"
|
||||
)]
|
||||
fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&str>) {
|
||||
let mut composer: TopLevelComposer = Default::default();
|
||||
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
|
||||
|
||||
let internal_resolver = Arc::new(ResolverInternal {
|
||||
id_to_def: Default::default(),
|
||||
|
@ -163,7 +184,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
|
|||
let ast = ast[0].clone();
|
||||
|
||||
let (id, def_id, ty) =
|
||||
composer.register_top_level(ast, Some(resolver.clone()), "".into()).unwrap();
|
||||
composer.register_top_level(ast, Some(resolver.clone()), "".into(), false).unwrap();
|
||||
internal_resolver.add_id_def(id, def_id);
|
||||
if let Some(ty) = ty {
|
||||
internal_resolver.add_id_type(id, ty);
|
||||
|
@ -345,7 +366,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
|
|||
pass
|
||||
"}
|
||||
],
|
||||
vec!["application of type vars to generic class is not currently supported (at unknown: line 4 column 24)"];
|
||||
vec!["application of type vars to generic class is not currently supported (at unknown:4:24)"];
|
||||
"err no type var in generic app"
|
||||
)]
|
||||
#[test_case(
|
||||
|
@ -401,7 +422,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
|
|||
def __init__():
|
||||
pass
|
||||
"}],
|
||||
vec!["__init__ method must have a `self` parameter (at unknown: line 2 column 5)"];
|
||||
vec!["__init__ method must have a `self` parameter (at unknown:2:5)"];
|
||||
"err no self_1"
|
||||
)]
|
||||
#[test_case(
|
||||
|
@ -423,7 +444,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
|
|||
"}
|
||||
|
||||
],
|
||||
vec!["a class definition can only have at most one base class declaration and one generic declaration (at unknown: line 1 column 24)"];
|
||||
vec!["a class definition can only have at most one base class declaration and one generic declaration (at unknown:1:24)"];
|
||||
"err multiple inheritance"
|
||||
)]
|
||||
#[test_case(
|
||||
|
@ -491,12 +512,12 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
|
|||
pass
|
||||
"}
|
||||
],
|
||||
vec!["duplicate definition of class `A` (at unknown: line 1 column 1)"];
|
||||
vec!["duplicate definition of class `A` (at unknown:1:1)"];
|
||||
"class same name"
|
||||
)]
|
||||
fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||
let print = false;
|
||||
let mut composer: TopLevelComposer = Default::default();
|
||||
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
|
||||
|
||||
let internal_resolver = make_internal_resolver_with_tvar(
|
||||
vec![
|
||||
|
@ -515,7 +536,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
|||
let ast = ast[0].clone();
|
||||
|
||||
let (id, def_id, ty) = {
|
||||
match composer.register_top_level(ast, Some(resolver.clone()), "".into()) {
|
||||
match composer.register_top_level(ast, Some(resolver.clone()), "".into(), false) {
|
||||
Ok(x) => x,
|
||||
Err(msg) => {
|
||||
if print {
|
||||
|
@ -535,9 +556,9 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
|||
|
||||
if let Err(msg) = composer.start_analysis(false) {
|
||||
if print {
|
||||
println!("{}", msg);
|
||||
println!("{}", msg.iter().sorted().join("\n----------\n"));
|
||||
} else {
|
||||
assert_eq!(res[0], msg);
|
||||
assert_eq!(res[0], msg.iter().next().unwrap());
|
||||
}
|
||||
} else {
|
||||
// skip 5 to skip primitives
|
||||
|
@ -673,7 +694,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
|||
)]
|
||||
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||
let print = true;
|
||||
let mut composer: TopLevelComposer = Default::default();
|
||||
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
|
||||
|
||||
let internal_resolver = make_internal_resolver_with_tvar(
|
||||
vec![
|
||||
|
@ -699,7 +720,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
|||
let ast = ast[0].clone();
|
||||
|
||||
let (id, def_id, ty) = {
|
||||
match composer.register_top_level(ast, Some(resolver.clone()), "".into()) {
|
||||
match composer.register_top_level(ast, Some(resolver.clone()), "".into(), false) {
|
||||
Ok(x) => x,
|
||||
Err(msg) => {
|
||||
if print {
|
||||
|
@ -719,9 +740,9 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
|||
|
||||
if let Err(msg) = composer.start_analysis(true) {
|
||||
if print {
|
||||
println!("{}", msg);
|
||||
println!("{}", msg.iter().sorted().join("\n----------\n"));
|
||||
} else {
|
||||
assert_eq!(res[0], msg);
|
||||
assert_eq!(res[0], msg.iter().next().unwrap());
|
||||
}
|
||||
} else {
|
||||
// skip 5 to skip primitives
|
||||
|
@ -761,11 +782,11 @@ fn make_internal_resolver_with_tvar(
|
|||
.into_iter()
|
||||
.map(|(name, range)| {
|
||||
(name, {
|
||||
let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice(), None, None);
|
||||
let tvar = unifier.get_fresh_var_with_range(range.as_slice(), None, None);
|
||||
if print {
|
||||
println!("{}: {:?}, typevar{}", name, ty, id);
|
||||
println!("{}: {:?}, typevar{}", name, tvar.ty, tvar.id);
|
||||
}
|
||||
ty
|
||||
tvar.ty
|
||||
})
|
||||
})
|
||||
.collect::<HashMap<_, _>>()
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
use super::*;
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::typecheck::typedef::VarMap;
|
||||
use nac3parser::ast::Constant;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum TypeAnnotation {
|
||||
|
@ -12,6 +16,8 @@ pub enum TypeAnnotation {
|
|||
// can only be CustomClassKind
|
||||
Virtual(Box<TypeAnnotation>),
|
||||
TypeVar(Type),
|
||||
/// A `Literal` allowing a subset of literals.
|
||||
Literal(Vec<Constant>),
|
||||
List(Box<TypeAnnotation>),
|
||||
Tuple(Vec<TypeAnnotation>),
|
||||
}
|
||||
|
@ -22,52 +28,58 @@ impl TypeAnnotation {
|
|||
match self {
|
||||
Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty),
|
||||
CustomClass { id, params } => {
|
||||
let class_name = match unifier.top_level {
|
||||
Some(ref top) => {
|
||||
if let TopLevelDef::Class { name, .. } =
|
||||
&*top.definitions.read()[id.0].read()
|
||||
{
|
||||
(*name).into()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
let class_name = if let Some(ref top) = unifier.top_level {
|
||||
if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() {
|
||||
(*name).into()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
None => format!("class_def_{}", id.0),
|
||||
} else {
|
||||
format!("class_def_{}", id.0)
|
||||
};
|
||||
format!(
|
||||
"{}{}",
|
||||
class_name,
|
||||
{
|
||||
let param_list = params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
|
||||
if param_list.is_empty() {
|
||||
"".into()
|
||||
} else {
|
||||
format!("[{}]", param_list)
|
||||
}
|
||||
format!("{}{}", class_name, {
|
||||
let param_list =
|
||||
params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
|
||||
if param_list.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("[{param_list}]")
|
||||
}
|
||||
)
|
||||
})
|
||||
}
|
||||
Literal(values) => {
|
||||
format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", "))
|
||||
}
|
||||
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
|
||||
List(ty) => format!("list[{}]", ty.stringify(unifier)),
|
||||
Tuple(types) => {
|
||||
format!("tuple[{}]", types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", "))
|
||||
format!(
|
||||
"tuple[{}]",
|
||||
types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||
/// Parses an AST expression `expr` into a [`TypeAnnotation`].
|
||||
///
|
||||
/// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all
|
||||
/// generic variables associated with the definition.
|
||||
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
|
||||
/// [`None`] when this function is invoked externally.
|
||||
pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
resolver: &(dyn SymbolResolver + Send + Sync),
|
||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
expr: &ast::Expr<T>,
|
||||
// the key stores the type_var of this topleveldef::class, we only need this field here
|
||||
locked: HashMap<DefinitionId, Vec<Type>>,
|
||||
) -> Result<TypeAnnotation, String> {
|
||||
locked: HashMap<DefinitionId, Vec<Type>, S>,
|
||||
) -> Result<TypeAnnotation, HashSet<String>> {
|
||||
let name_handle = |id: &StrRef,
|
||||
unifier: &mut Unifier,
|
||||
locked: HashMap<DefinitionId, Vec<Type>>| {
|
||||
locked: HashMap<DefinitionId, Vec<Type>, S>| {
|
||||
if id == &"int32".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.int32))
|
||||
} else if id == &"int64".into() {
|
||||
|
@ -83,7 +95,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
|||
} else if id == &"str".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.str))
|
||||
} else if id == &"Exception".into() {
|
||||
Ok(TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() })
|
||||
Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
|
||||
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
|
||||
let type_vars = {
|
||||
let def_read = top_level_defs[obj_id.0].try_read();
|
||||
|
@ -91,10 +103,10 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
|||
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
|
||||
type_vars.clone()
|
||||
} else {
|
||||
return Err(format!(
|
||||
return Err(HashSet::from([format!(
|
||||
"function cannot be used as a type (at {})",
|
||||
expr.location
|
||||
));
|
||||
)]));
|
||||
}
|
||||
} else {
|
||||
locked.get(&obj_id).unwrap().clone()
|
||||
|
@ -102,23 +114,29 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
|||
};
|
||||
// check param number here
|
||||
if !type_vars.is_empty() {
|
||||
return Err(format!(
|
||||
return Err(HashSet::from([format!(
|
||||
"expect {} type variable parameter but got 0 (at {})",
|
||||
type_vars.len(),
|
||||
expr.location,
|
||||
));
|
||||
)]));
|
||||
}
|
||||
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
|
||||
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
|
||||
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
|
||||
let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).0;
|
||||
let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).ty;
|
||||
unifier.unify(var, ty).unwrap();
|
||||
Ok(TypeAnnotation::TypeVar(ty))
|
||||
} else {
|
||||
Err(format!("`{}` is not a valid type annotation (at {})", id, expr.location))
|
||||
Err(HashSet::from([format!(
|
||||
"`{}` is not a valid type annotation (at {})",
|
||||
id, expr.location
|
||||
)]))
|
||||
}
|
||||
} else {
|
||||
Err(format!("`{}` is not a valid type annotation (at {})", id, expr.location))
|
||||
Err(HashSet::from([format!(
|
||||
"`{}` is not a valid type annotation (at {})",
|
||||
id, expr.location
|
||||
)]))
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -126,20 +144,24 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
|||
|id: &StrRef,
|
||||
slice: &ast::Expr<T>,
|
||||
unifier: &mut Unifier,
|
||||
mut locked: HashMap<DefinitionId, Vec<Type>>| {
|
||||
if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()].contains(id)
|
||||
mut locked: HashMap<DefinitionId, Vec<Type>, S>| {
|
||||
if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()]
|
||||
.contains(id)
|
||||
{
|
||||
return Err(format!("keywords cannot be class name (at {})", expr.location));
|
||||
return Err(HashSet::from([format!(
|
||||
"keywords cannot be class name (at {})",
|
||||
expr.location
|
||||
)]));
|
||||
}
|
||||
let obj_id = resolver.get_identifier_def(*id)?;
|
||||
let type_vars = {
|
||||
let def_read = top_level_defs[obj_id.0].try_read();
|
||||
if let Some(def_read) = def_read {
|
||||
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
|
||||
type_vars.clone()
|
||||
} else {
|
||||
let TopLevelDef::Class { type_vars, .. } = &*def_read else {
|
||||
unreachable!("must be class here")
|
||||
}
|
||||
};
|
||||
|
||||
type_vars.clone()
|
||||
} else {
|
||||
locked.get(&obj_id).unwrap().clone()
|
||||
}
|
||||
|
@ -152,12 +174,12 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
|||
vec![slice]
|
||||
};
|
||||
if type_vars.len() != params_ast.len() {
|
||||
return Err(format!(
|
||||
return Err(HashSet::from([format!(
|
||||
"expect {} type parameters but got {} (at {})",
|
||||
type_vars.len(),
|
||||
params_ast.len(),
|
||||
params_ast[0].location,
|
||||
));
|
||||
)]));
|
||||
}
|
||||
let result = params_ast
|
||||
.iter()
|
||||
|
@ -181,15 +203,17 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
|||
if no_type_var {
|
||||
result
|
||||
} else {
|
||||
return Err(format!(
|
||||
"application of type vars to generic class \
|
||||
is not currently supported (at {})",
|
||||
params_ast[0].location
|
||||
));
|
||||
return Err(HashSet::from([
|
||||
format!(
|
||||
"application of type vars to generic class is not currently supported (at {})",
|
||||
params_ast[0].location
|
||||
),
|
||||
]));
|
||||
}
|
||||
};
|
||||
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
|
||||
};
|
||||
|
||||
match &expr.node {
|
||||
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
|
||||
// virtual
|
||||
|
@ -281,16 +305,70 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
|||
Ok(TypeAnnotation::Tuple(type_annotations))
|
||||
}
|
||||
|
||||
// Literal
|
||||
ast::ExprKind::Subscript { value, slice, .. }
|
||||
if {
|
||||
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into())
|
||||
} =>
|
||||
{
|
||||
let tup_elts = {
|
||||
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
|
||||
elts.as_slice()
|
||||
} else {
|
||||
std::slice::from_ref(slice.as_ref())
|
||||
}
|
||||
};
|
||||
let type_annotations = tup_elts
|
||||
.iter()
|
||||
.map(|e| match &e.node {
|
||||
ast::ExprKind::Constant { value, .. } => {
|
||||
Ok(TypeAnnotation::Literal(vec![value.clone()]))
|
||||
}
|
||||
_ => parse_ast_to_type_annotation_kinds(
|
||||
resolver,
|
||||
top_level_defs,
|
||||
unifier,
|
||||
primitives,
|
||||
e,
|
||||
locked.clone(),
|
||||
),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flat_map(|type_ann| match type_ann {
|
||||
TypeAnnotation::Literal(values) => values,
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
if type_annotations.len() == 1 {
|
||||
Ok(TypeAnnotation::Literal(type_annotations))
|
||||
} else {
|
||||
Err(HashSet::from([format!(
|
||||
"multiple literal bounds are currently unsupported (at {})",
|
||||
value.location
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
// custom class
|
||||
ast::ExprKind::Subscript { value, slice, .. } => {
|
||||
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||
class_name_handle(id, slice, unifier, locked)
|
||||
} else {
|
||||
Err(format!("unsupported expression type for class name (at {})", value.location))
|
||||
Err(HashSet::from([format!(
|
||||
"unsupported expression type for class name (at {})",
|
||||
value.location
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
_ => Err(format!("unsupported expression for type annotation (at {})", expr.location)),
|
||||
ast::ExprKind::Constant { value, .. } => Ok(TypeAnnotation::Literal(vec![value.clone()])),
|
||||
|
||||
_ => Err(HashSet::from([format!(
|
||||
"unsupported expression for type annotation (at {})",
|
||||
expr.location
|
||||
)])),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -300,109 +378,140 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
|||
pub fn get_type_from_type_annotation_kinds(
|
||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
ann: &TypeAnnotation,
|
||||
subst_list: &mut Option<Vec<Type>>
|
||||
) -> Result<Type, String> {
|
||||
subst_list: &mut Option<Vec<Type>>,
|
||||
) -> Result<Type, HashSet<String>> {
|
||||
match ann {
|
||||
TypeAnnotation::CustomClass { id: obj_id, params } => {
|
||||
let def_read = top_level_defs[obj_id.0].read();
|
||||
let class_def: &TopLevelDef = def_read.deref();
|
||||
if let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def {
|
||||
if type_vars.len() != params.len() {
|
||||
Err(format!(
|
||||
"unexpected number of type parameters: expected {} but got {}",
|
||||
type_vars.len(),
|
||||
params.len()
|
||||
))
|
||||
} else {
|
||||
let param_ty = params
|
||||
.iter()
|
||||
.map(|x| {
|
||||
get_type_from_type_annotation_kinds(
|
||||
top_level_defs,
|
||||
unifier,
|
||||
primitives,
|
||||
x,
|
||||
subst_list
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let class_def: &TopLevelDef = &def_read;
|
||||
let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else {
|
||||
unreachable!("should be class def here")
|
||||
};
|
||||
|
||||
let subst = {
|
||||
// check for compatible range
|
||||
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
|
||||
let mut result: HashMap<u32, Type> = HashMap::new();
|
||||
for (tvar, p) in type_vars.iter().zip(param_ty) {
|
||||
if let TypeEnum::TVar { id, range, fields: None, name, loc } =
|
||||
unifier.get_ty(*tvar).as_ref()
|
||||
{
|
||||
let ok: bool = {
|
||||
// create a temp type var and unify to check compatibility
|
||||
p == *tvar || {
|
||||
let temp = unifier.get_fresh_var_with_range(
|
||||
range.as_slice(),
|
||||
*name,
|
||||
*loc,
|
||||
);
|
||||
unifier.unify(temp.0, p).is_ok()
|
||||
}
|
||||
};
|
||||
if ok {
|
||||
result.insert(*id, p);
|
||||
} else {
|
||||
return Err(format!(
|
||||
"cannot apply type {} to type variable with id {:?}",
|
||||
unifier.internal_stringify(
|
||||
p,
|
||||
&mut |id| format!("class{}", id),
|
||||
&mut |id| format!("typevar{}", id),
|
||||
&mut None
|
||||
),
|
||||
*id
|
||||
));
|
||||
if type_vars.len() != params.len() {
|
||||
return Err(HashSet::from([format!(
|
||||
"unexpected number of type parameters: expected {} but got {}",
|
||||
type_vars.len(),
|
||||
params.len()
|
||||
)]));
|
||||
}
|
||||
|
||||
let param_ty = params
|
||||
.iter()
|
||||
.map(|x| {
|
||||
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let subst = {
|
||||
// check for compatible range
|
||||
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
|
||||
let mut result = VarMap::new();
|
||||
for (tvar, p) in type_vars.iter().zip(param_ty) {
|
||||
match unifier.get_ty(*tvar).as_ref() {
|
||||
TypeEnum::TVar {
|
||||
id,
|
||||
range,
|
||||
fields: None,
|
||||
name,
|
||||
loc,
|
||||
is_const_generic: false,
|
||||
} => {
|
||||
let ok: bool = {
|
||||
// create a temp type var and unify to check compatibility
|
||||
p == *tvar || {
|
||||
let temp = unifier.get_fresh_var_with_range(
|
||||
range.as_slice(),
|
||||
*name,
|
||||
*loc,
|
||||
);
|
||||
unifier.unify(temp.ty, p).is_ok()
|
||||
}
|
||||
};
|
||||
if ok {
|
||||
result.insert(*id, p);
|
||||
} else {
|
||||
unreachable!("must be generic type var")
|
||||
return Err(HashSet::from([format!(
|
||||
"cannot apply type {} to type variable with id {:?}",
|
||||
unifier.internal_stringify(
|
||||
p,
|
||||
&mut |id| format!("class{id}"),
|
||||
&mut |id| format!("typevar{id}"),
|
||||
&mut None
|
||||
),
|
||||
*id
|
||||
)]));
|
||||
}
|
||||
}
|
||||
result
|
||||
};
|
||||
let mut tobj_fields = methods
|
||||
.iter()
|
||||
.map(|(name, ty, _)| {
|
||||
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
// methods are immutable
|
||||
(*name, (subst_ty, false))
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
|
||||
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
(*name, (subst_ty, *mutability))
|
||||
}));
|
||||
let need_subst = !subst.is_empty();
|
||||
let ty = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: *obj_id,
|
||||
fields: tobj_fields,
|
||||
params: subst,
|
||||
});
|
||||
if need_subst {
|
||||
subst_list.as_mut().map(|wl| wl.push(ty));
|
||||
|
||||
TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
|
||||
let ty = range[0];
|
||||
let ok: bool = {
|
||||
// create a temp type var and unify to check compatibility
|
||||
p == *tvar || {
|
||||
let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
|
||||
unifier.unify(temp.ty, p).is_ok()
|
||||
}
|
||||
};
|
||||
if ok {
|
||||
result.insert(*id, p);
|
||||
} else {
|
||||
return Err(HashSet::from([format!(
|
||||
"cannot apply type {} to type variable {}",
|
||||
unifier.stringify(p),
|
||||
name.unwrap_or_else(|| format!("typevar{id}").into()),
|
||||
)]));
|
||||
}
|
||||
}
|
||||
|
||||
_ => unreachable!("must be generic type var"),
|
||||
}
|
||||
Ok(ty)
|
||||
}
|
||||
} else {
|
||||
unreachable!("should be class def here")
|
||||
result
|
||||
};
|
||||
let mut tobj_fields = methods
|
||||
.iter()
|
||||
.map(|(name, ty, _)| {
|
||||
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
// methods are immutable
|
||||
(*name, (subst_ty, false))
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
|
||||
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||
(*name, (subst_ty, *mutability))
|
||||
}));
|
||||
let need_subst = !subst.is_empty();
|
||||
let ty = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: *obj_id,
|
||||
fields: tobj_fields,
|
||||
params: subst,
|
||||
});
|
||||
if need_subst {
|
||||
if let Some(wl) = subst_list.as_mut() {
|
||||
wl.push(ty);
|
||||
}
|
||||
}
|
||||
Ok(ty)
|
||||
}
|
||||
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
|
||||
TypeAnnotation::Literal(values) => {
|
||||
let values = values
|
||||
.iter()
|
||||
.map(SymbolValue::from_constant_inferred)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|err| HashSet::from([err]))?;
|
||||
|
||||
let var = unifier.get_fresh_literal(values, None);
|
||||
Ok(var)
|
||||
}
|
||||
TypeAnnotation::Virtual(ty) => {
|
||||
let ty = get_type_from_type_annotation_kinds(
|
||||
top_level_defs,
|
||||
unifier,
|
||||
primitives,
|
||||
ty.as_ref(),
|
||||
subst_list
|
||||
subst_list,
|
||||
)?;
|
||||
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
|
||||
}
|
||||
|
@ -410,9 +519,8 @@ pub fn get_type_from_type_annotation_kinds(
|
|||
let ty = get_type_from_type_annotation_kinds(
|
||||
top_level_defs,
|
||||
unifier,
|
||||
primitives,
|
||||
ty.as_ref(),
|
||||
subst_list
|
||||
subst_list,
|
||||
)?;
|
||||
Ok(unifier.add_ty(TypeEnum::TList { ty }))
|
||||
}
|
||||
|
@ -420,7 +528,7 @@ pub fn get_type_from_type_annotation_kinds(
|
|||
let tys = tys
|
||||
.iter()
|
||||
.map(|x| {
|
||||
get_type_from_type_annotation_kinds(top_level_defs, unifier, primitives, x, subst_list)
|
||||
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))
|
||||
|
@ -437,9 +545,10 @@ pub fn get_type_from_type_annotation_kinds(
|
|||
/// considered to be type variables associated with the class \
|
||||
/// \
|
||||
/// But note that here we do not make a duplication of `T`, `V`, we directly
|
||||
/// use them as they are in the TopLevelDef::Class since those in the
|
||||
/// TopLevelDef::Class.type_vars will be substitute later when seeing applications/instantiations
|
||||
/// use them as they are in the [`TopLevelDef::Class`] since those in the
|
||||
/// `TopLevelDef::Class.type_vars` will be substitute later when seeing applications/instantiations
|
||||
/// the Type of their fields and methods will also be subst when application/instantiation
|
||||
#[must_use]
|
||||
pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation {
|
||||
TypeAnnotation::CustomClass {
|
||||
id: object_id,
|
||||
|
@ -450,27 +559,25 @@ pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) ->
|
|||
/// get all the occurences of type vars contained in a type annotation
|
||||
/// e.g. `A[int, B[T], V, virtual[C[G]]]` => [T, V, G]
|
||||
/// this function will not make a duplicate of type var
|
||||
#[must_use]
|
||||
pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<TypeAnnotation> {
|
||||
let mut result: Vec<TypeAnnotation> = Vec::new();
|
||||
match ann {
|
||||
TypeAnnotation::TypeVar(..) => result.push(ann.clone()),
|
||||
TypeAnnotation::Virtual(ann) => {
|
||||
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()))
|
||||
TypeAnnotation::Virtual(ann) | TypeAnnotation::List(ann) => {
|
||||
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()));
|
||||
}
|
||||
TypeAnnotation::CustomClass { params, .. } => {
|
||||
for p in params {
|
||||
result.extend(get_type_var_contained_in_type_annotation(p));
|
||||
}
|
||||
}
|
||||
TypeAnnotation::List(ann) => {
|
||||
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()))
|
||||
}
|
||||
TypeAnnotation::Tuple(anns) => {
|
||||
for a in anns {
|
||||
result.extend(get_type_var_contained_in_type_annotation(a));
|
||||
}
|
||||
}
|
||||
TypeAnnotation::Primitive(..) => {}
|
||||
TypeAnnotation::Primitive(..) | TypeAnnotation::Literal { .. } => {}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
@ -485,18 +592,18 @@ pub fn check_overload_type_annotation_compatible(
|
|||
(TypeAnnotation::Primitive(a), TypeAnnotation::Primitive(b)) => a == b,
|
||||
(TypeAnnotation::TypeVar(a), TypeAnnotation::TypeVar(b)) => {
|
||||
let a = unifier.get_ty(*a);
|
||||
let a = a.deref();
|
||||
let a = &*a;
|
||||
let b = unifier.get_ty(*b);
|
||||
let b = b.deref();
|
||||
if let (
|
||||
let b = &*b;
|
||||
let (
|
||||
TypeEnum::TVar { id: a, fields: None, .. },
|
||||
TypeEnum::TVar { id: b, fields: None, .. },
|
||||
) = (a, b)
|
||||
{
|
||||
a == b
|
||||
} else {
|
||||
else {
|
||||
unreachable!("must be type var")
|
||||
}
|
||||
};
|
||||
|
||||
a == b
|
||||
}
|
||||
(TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b))
|
||||
| (TypeAnnotation::List(a), TypeAnnotation::List(b)) => {
|
||||
|
|
|
@ -2,13 +2,17 @@ use crate::typecheck::typedef::TypeEnum;
|
|||
|
||||
use super::type_inferencer::Inferencer;
|
||||
use super::typedef::Type;
|
||||
use nac3parser::ast::{self, Expr, ExprKind, Stmt, StmtKind, StrRef};
|
||||
use nac3parser::ast::{
|
||||
self, Constant, Expr, ExprKind,
|
||||
Operator::{LShift, RShift},
|
||||
Stmt, StmtKind, StrRef,
|
||||
};
|
||||
use std::{collections::HashSet, iter::once};
|
||||
|
||||
impl<'a> Inferencer<'a> {
|
||||
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), String> {
|
||||
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> {
|
||||
if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) {
|
||||
Err(format!("Error at {}: cannot have value none", expr.location))
|
||||
Err(HashSet::from([format!("Error at {}: cannot have value none", expr.location)]))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
|
@ -18,10 +22,11 @@ impl<'a> Inferencer<'a> {
|
|||
&mut self,
|
||||
pattern: &Expr<Option<Type>>,
|
||||
defined_identifiers: &mut HashSet<StrRef>,
|
||||
) -> Result<(), String> {
|
||||
) -> Result<(), HashSet<String>> {
|
||||
match &pattern.node {
|
||||
ast::ExprKind::Name { id, .. } if id == &"none".into() =>
|
||||
Err(format!("cannot assign to a `none` (at {})", pattern.location)),
|
||||
ExprKind::Name { id, .. } if id == &"none".into() => {
|
||||
Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)]))
|
||||
}
|
||||
ExprKind::Name { id, .. } => {
|
||||
if !defined_identifiers.contains(id) {
|
||||
defined_identifiers.insert(*id);
|
||||
|
@ -30,7 +35,7 @@ impl<'a> Inferencer<'a> {
|
|||
Ok(())
|
||||
}
|
||||
ExprKind::Tuple { elts, .. } => {
|
||||
for elt in elts.iter() {
|
||||
for elt in elts {
|
||||
self.check_pattern(elt, defined_identifiers)?;
|
||||
self.should_have_value(elt)?;
|
||||
}
|
||||
|
@ -41,16 +46,17 @@ impl<'a> Inferencer<'a> {
|
|||
self.should_have_value(value)?;
|
||||
self.check_expr(slice, defined_identifiers)?;
|
||||
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
return Err(format!(
|
||||
return Err(HashSet::from([format!(
|
||||
"Error at {}: cannot assign to tuple element",
|
||||
value.location
|
||||
));
|
||||
)]));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
ExprKind::Constant { .. } => {
|
||||
Err(format!("cannot assign to a constant (at {})", pattern.location))
|
||||
}
|
||||
ExprKind::Constant { .. } => Err(HashSet::from([format!(
|
||||
"cannot assign to a constant (at {})",
|
||||
pattern.location
|
||||
)])),
|
||||
_ => self.check_expr(pattern, defined_identifiers),
|
||||
}
|
||||
}
|
||||
|
@ -59,15 +65,17 @@ impl<'a> Inferencer<'a> {
|
|||
&mut self,
|
||||
expr: &Expr<Option<Type>>,
|
||||
defined_identifiers: &mut HashSet<StrRef>,
|
||||
) -> Result<(), String> {
|
||||
) -> Result<(), HashSet<String>> {
|
||||
// there are some cases where the custom field is None
|
||||
if let Some(ty) = &expr.custom {
|
||||
if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
|
||||
return Err(format!(
|
||||
if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. })
|
||||
&& !self.unifier.is_concrete(*ty, &self.function_data.bound_variables)
|
||||
{
|
||||
return Err(HashSet::from([format!(
|
||||
"expected concrete type at {} but got {}",
|
||||
expr.location,
|
||||
self.unifier.get_ty(*ty).get_type_name()
|
||||
));
|
||||
)]));
|
||||
}
|
||||
}
|
||||
match &expr.node {
|
||||
|
@ -87,10 +95,10 @@ impl<'a> Inferencer<'a> {
|
|||
self.defined_identifiers.insert(*id);
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(format!(
|
||||
return Err(HashSet::from([format!(
|
||||
"type error at identifier `{}` ({}) at {}",
|
||||
id, e, expr.location
|
||||
));
|
||||
)]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -98,7 +106,7 @@ impl<'a> Inferencer<'a> {
|
|||
ExprKind::List { elts, .. }
|
||||
| ExprKind::Tuple { elts, .. }
|
||||
| ExprKind::BoolOp { values: elts, .. } => {
|
||||
for elt in elts.iter() {
|
||||
for elt in elts {
|
||||
self.check_expr(elt, defined_identifiers)?;
|
||||
self.should_have_value(elt)?;
|
||||
}
|
||||
|
@ -107,11 +115,25 @@ impl<'a> Inferencer<'a> {
|
|||
self.check_expr(value, defined_identifiers)?;
|
||||
self.should_have_value(value)?;
|
||||
}
|
||||
ExprKind::BinOp { left, right, .. } => {
|
||||
ExprKind::BinOp { left, op, right } => {
|
||||
self.check_expr(left, defined_identifiers)?;
|
||||
self.check_expr(right, defined_identifiers)?;
|
||||
self.should_have_value(left)?;
|
||||
self.should_have_value(right)?;
|
||||
|
||||
// Check whether a bitwise shift has a negative RHS constant value
|
||||
if *op == LShift || *op == RShift {
|
||||
if let ExprKind::Constant { value, .. } = &right.node {
|
||||
let Constant::Int(rhs_val) = value else { unreachable!() };
|
||||
|
||||
if *rhs_val < 0 {
|
||||
return Err(HashSet::from([format!(
|
||||
"shift count is negative at {}",
|
||||
right.location
|
||||
)]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ExprKind::UnaryOp { operand, .. } => {
|
||||
self.check_expr(operand, defined_identifiers)?;
|
||||
|
@ -141,7 +163,7 @@ impl<'a> Inferencer<'a> {
|
|||
}
|
||||
ExprKind::Lambda { args, body } => {
|
||||
let mut defined_identifiers = defined_identifiers.clone();
|
||||
for arg in args.args.iter() {
|
||||
for arg in &args.args {
|
||||
// TODO: should we check the types here?
|
||||
if !defined_identifiers.contains(&arg.node.arg) {
|
||||
defined_identifiers.insert(arg.node.arg);
|
||||
|
@ -179,24 +201,45 @@ impl<'a> Inferencer<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Check that the return value is a non-`alloca` type, effectively only allowing primitive types.
|
||||
///
|
||||
/// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
|
||||
/// is freed when the function returns.
|
||||
fn check_return_value_ty(&mut self, ret_ty: Type) -> bool {
|
||||
match &*self.unifier.get_ty_immutable(ret_ty) {
|
||||
TypeEnum::TObj { .. } => [
|
||||
self.primitives.int32,
|
||||
self.primitives.int64,
|
||||
self.primitives.uint32,
|
||||
self.primitives.uint64,
|
||||
self.primitives.float,
|
||||
self.primitives.bool,
|
||||
]
|
||||
.iter()
|
||||
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
|
||||
TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
// check statements for proper identifier def-use and return on all paths
|
||||
fn check_stmt(
|
||||
&mut self,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
defined_identifiers: &mut HashSet<StrRef>,
|
||||
) -> Result<bool, String> {
|
||||
) -> Result<bool, HashSet<String>> {
|
||||
match &stmt.node {
|
||||
StmtKind::For { target, iter, body, orelse, .. } => {
|
||||
self.check_expr(iter, defined_identifiers)?;
|
||||
self.should_have_value(iter)?;
|
||||
let mut local_defined_identifiers = defined_identifiers.clone();
|
||||
for stmt in orelse.iter() {
|
||||
for stmt in orelse {
|
||||
self.check_stmt(stmt, &mut local_defined_identifiers)?;
|
||||
}
|
||||
let mut local_defined_identifiers = defined_identifiers.clone();
|
||||
self.check_pattern(target, &mut local_defined_identifiers)?;
|
||||
self.should_have_value(target)?;
|
||||
for stmt in body.iter() {
|
||||
for stmt in body {
|
||||
self.check_stmt(stmt, &mut local_defined_identifiers)?;
|
||||
}
|
||||
Ok(false)
|
||||
|
@ -209,7 +252,7 @@ impl<'a> Inferencer<'a> {
|
|||
let body_returned = self.check_block(body, &mut body_identifiers)?;
|
||||
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
|
||||
|
||||
for ident in body_identifiers.iter() {
|
||||
for ident in &body_identifiers {
|
||||
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
|
||||
defined_identifiers.insert(*ident);
|
||||
}
|
||||
|
@ -226,7 +269,7 @@ impl<'a> Inferencer<'a> {
|
|||
}
|
||||
StmtKind::With { items, body, .. } => {
|
||||
let mut new_defined_identifiers = defined_identifiers.clone();
|
||||
for item in items.iter() {
|
||||
for item in items {
|
||||
self.check_expr(&item.context_expr, defined_identifiers)?;
|
||||
if let Some(var) = item.optional_vars.as_ref() {
|
||||
self.check_pattern(var, &mut new_defined_identifiers)?;
|
||||
|
@ -238,7 +281,7 @@ impl<'a> Inferencer<'a> {
|
|||
StmtKind::Try { body, handlers, orelse, finalbody, .. } => {
|
||||
self.check_block(body, &mut defined_identifiers.clone())?;
|
||||
self.check_block(orelse, &mut defined_identifiers.clone())?;
|
||||
for handler in handlers.iter() {
|
||||
for handler in handlers {
|
||||
let mut defined_identifiers = defined_identifiers.clone();
|
||||
let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node;
|
||||
if let Some(name) = name {
|
||||
|
@ -273,6 +316,30 @@ impl<'a> Inferencer<'a> {
|
|||
if let Some(value) = value {
|
||||
self.check_expr(value, defined_identifiers)?;
|
||||
self.should_have_value(value)?;
|
||||
|
||||
// Check that the return value is a non-`alloca` type, effectively only allowing primitive types.
|
||||
// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
|
||||
// is freed when the function returns.
|
||||
if let Some(ret_ty) = value.custom {
|
||||
// Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually
|
||||
// inferred and just generates an unconditional assertion
|
||||
if matches!(
|
||||
value.node,
|
||||
ExprKind::Constant { value: Constant::Ellipsis, .. }
|
||||
) {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
if !self.check_return_value_ty(ret_ty) {
|
||||
return Err(HashSet::from([
|
||||
format!(
|
||||
"return value of type {} must be a primitive or a tuple of primitives at {}",
|
||||
self.unifier.stringify(ret_ty),
|
||||
value.location,
|
||||
),
|
||||
]));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
@ -291,11 +358,11 @@ impl<'a> Inferencer<'a> {
|
|||
&mut self,
|
||||
block: &[Stmt<Option<Type>>],
|
||||
defined_identifiers: &mut HashSet<StrRef>,
|
||||
) -> Result<bool, String> {
|
||||
) -> Result<bool, HashSet<String>> {
|
||||
let mut ret = false;
|
||||
for stmt in block {
|
||||
if ret {
|
||||
return Err(format!("dead code at {:?}", stmt.location));
|
||||
eprintln!("warning: dead code at {}\n", stmt.location);
|
||||
}
|
||||
if self.check_stmt(stmt, defined_identifiers)? {
|
||||
ret = true;
|
||||
|
|
|
@ -1,13 +1,20 @@
|
|||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
||||
use crate::typecheck::{
|
||||
type_inferencer::*,
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||
};
|
||||
use nac3parser::ast::{self, StrRef};
|
||||
use itertools::Itertools;
|
||||
use nac3parser::ast::StrRef;
|
||||
use nac3parser::ast::{Cmpop, Operator, Unaryop};
|
||||
use std::cmp::max;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use strum::IntoEnumIterator;
|
||||
|
||||
pub fn binop_name(op: &Operator) -> &'static str {
|
||||
#[must_use]
|
||||
pub fn binop_name(op: Operator) -> &'static str {
|
||||
match op {
|
||||
Operator::Add => "__add__",
|
||||
Operator::Sub => "__sub__",
|
||||
|
@ -25,7 +32,8 @@ pub fn binop_name(op: &Operator) -> &'static str {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn binop_assign_name(op: &Operator) -> &'static str {
|
||||
#[must_use]
|
||||
pub fn binop_assign_name(op: Operator) -> &'static str {
|
||||
match op {
|
||||
Operator::Add => "__iadd__",
|
||||
Operator::Sub => "__isub__",
|
||||
|
@ -43,7 +51,8 @@ pub fn binop_assign_name(op: &Operator) -> &'static str {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn unaryop_name(op: &Unaryop) -> &'static str {
|
||||
#[must_use]
|
||||
pub fn unaryop_name(op: Unaryop) -> &'static str {
|
||||
match op {
|
||||
Unaryop::UAdd => "__pos__",
|
||||
Unaryop::USub => "__neg__",
|
||||
|
@ -52,7 +61,8 @@ pub fn unaryop_name(op: &Unaryop) -> &'static str {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> {
|
||||
#[must_use]
|
||||
pub fn comparison_name(op: Cmpop) -> Option<&'static str> {
|
||||
match op {
|
||||
Cmpop::Lt => Some("__lt__"),
|
||||
Cmpop::LtE => Some("__le__"),
|
||||
|
@ -83,26 +93,30 @@ where
|
|||
|
||||
pub fn impl_binop(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
_store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Type,
|
||||
ops: &[ast::Operator],
|
||||
ret_ty: Option<Type>,
|
||||
ops: &[Operator],
|
||||
) {
|
||||
with_fields(unifier, ty, |unifier, fields| {
|
||||
let (other_ty, other_var_id) = if other_ty.len() == 1 {
|
||||
(other_ty[0], None)
|
||||
} else {
|
||||
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
|
||||
(ty, Some(var_id))
|
||||
let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
|
||||
(tvar.ty, Some(tvar.id))
|
||||
};
|
||||
|
||||
let function_vars = if let Some(var_id) = other_var_id {
|
||||
vec![(var_id, other_ty)].into_iter().collect::<HashMap<_, _>>()
|
||||
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
|
||||
} else {
|
||||
HashMap::new()
|
||||
VarMap::new()
|
||||
};
|
||||
|
||||
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
|
||||
|
||||
for op in ops {
|
||||
fields.insert(binop_name(op).into(), {
|
||||
fields.insert(binop_name(*op).into(), {
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
ret: ret_ty,
|
||||
|
@ -117,10 +131,10 @@ pub fn impl_binop(
|
|||
)
|
||||
});
|
||||
|
||||
fields.insert(binop_assign_name(op).into(), {
|
||||
fields.insert(binop_assign_name(*op).into(), {
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
ret: store.none,
|
||||
ret: ret_ty,
|
||||
vars: function_vars.clone(),
|
||||
args: vec![FuncArg {
|
||||
ty: other_ty,
|
||||
|
@ -135,15 +149,17 @@ pub fn impl_binop(
|
|||
});
|
||||
}
|
||||
|
||||
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[ast::Unaryop]) {
|
||||
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) {
|
||||
with_fields(unifier, ty, |unifier, fields| {
|
||||
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
|
||||
|
||||
for op in ops {
|
||||
fields.insert(
|
||||
unaryop_name(op).into(),
|
||||
unaryop_name(*op).into(),
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
ret: ret_ty,
|
||||
vars: HashMap::new(),
|
||||
vars: VarMap::new(),
|
||||
args: vec![],
|
||||
})),
|
||||
false,
|
||||
|
@ -155,19 +171,35 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[ast::U
|
|||
|
||||
pub fn impl_cmpop(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
_store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: Type,
|
||||
ops: &[ast::Cmpop],
|
||||
other_ty: &[Type],
|
||||
ops: &[Cmpop],
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
with_fields(unifier, ty, |unifier, fields| {
|
||||
let (other_ty, other_var_id) = if other_ty.len() == 1 {
|
||||
(other_ty[0], None)
|
||||
} else {
|
||||
let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
|
||||
(tvar.ty, Some(tvar.id))
|
||||
};
|
||||
|
||||
let function_vars = if let Some(var_id) = other_var_id {
|
||||
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
|
||||
} else {
|
||||
VarMap::new()
|
||||
};
|
||||
|
||||
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
|
||||
|
||||
for op in ops {
|
||||
fields.insert(
|
||||
comparison_name(op).unwrap().into(),
|
||||
comparison_name(*op).unwrap().into(),
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
ret: store.bool,
|
||||
vars: HashMap::new(),
|
||||
ret: ret_ty,
|
||||
vars: function_vars.clone(),
|
||||
args: vec![FuncArg {
|
||||
ty: other_ty,
|
||||
default_value: None,
|
||||
|
@ -181,13 +213,13 @@ pub fn impl_cmpop(
|
|||
});
|
||||
}
|
||||
|
||||
/// Add, Sub, Mult
|
||||
/// `Add`, `Sub`, `Mult`
|
||||
pub fn impl_basic_arithmetic(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Type,
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
impl_binop(
|
||||
unifier,
|
||||
|
@ -195,94 +227,373 @@ pub fn impl_basic_arithmetic(
|
|||
ty,
|
||||
other_ty,
|
||||
ret_ty,
|
||||
&[ast::Operator::Add, ast::Operator::Sub, ast::Operator::Mult],
|
||||
)
|
||||
&[Operator::Add, Operator::Sub, Operator::Mult],
|
||||
);
|
||||
}
|
||||
|
||||
/// Pow
|
||||
/// `Pow`
|
||||
pub fn impl_pow(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Type,
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Pow])
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]);
|
||||
}
|
||||
|
||||
/// BitOr, BitXor, BitAnd
|
||||
/// `BitOr`, `BitXor`, `BitAnd`
|
||||
pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_binop(
|
||||
unifier,
|
||||
store,
|
||||
ty,
|
||||
&[ty],
|
||||
ty,
|
||||
&[ast::Operator::BitAnd, ast::Operator::BitOr, ast::Operator::BitXor],
|
||||
)
|
||||
Some(ty),
|
||||
&[Operator::BitAnd, Operator::BitOr, Operator::BitXor],
|
||||
);
|
||||
}
|
||||
|
||||
/// LShift, RShift
|
||||
/// `LShift`, `RShift`
|
||||
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_binop(unifier, store, ty, &[ty], ty, &[ast::Operator::LShift, ast::Operator::RShift])
|
||||
impl_binop(
|
||||
unifier,
|
||||
store,
|
||||
ty,
|
||||
&[store.int32, store.uint32],
|
||||
Some(ty),
|
||||
&[Operator::LShift, Operator::RShift],
|
||||
);
|
||||
}
|
||||
|
||||
/// Div
|
||||
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) {
|
||||
impl_binop(unifier, store, ty, other_ty, store.float, &[ast::Operator::Div])
|
||||
/// `Div`
|
||||
pub fn impl_div(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]);
|
||||
}
|
||||
|
||||
/// FloorDiv
|
||||
/// `FloorDiv`
|
||||
pub fn impl_floordiv(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Type,
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::FloorDiv])
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]);
|
||||
}
|
||||
|
||||
/// Mod
|
||||
/// `Mod`
|
||||
pub fn impl_mod(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Type,
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Mod])
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
|
||||
}
|
||||
|
||||
/// UAdd, USub
|
||||
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
|
||||
impl_unaryop(unifier, ty, ty, &[ast::Unaryop::UAdd, ast::Unaryop::USub])
|
||||
/// [`Operator::MatMult`]
|
||||
pub fn impl_matmul(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult]);
|
||||
}
|
||||
|
||||
/// Invert
|
||||
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
|
||||
impl_unaryop(unifier, ty, ty, &[ast::Unaryop::Invert])
|
||||
/// `UAdd`, `USub`
|
||||
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
|
||||
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
|
||||
}
|
||||
|
||||
/// Not
|
||||
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_unaryop(unifier, ty, store.bool, &[ast::Unaryop::Not])
|
||||
/// `Invert`
|
||||
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
|
||||
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Invert]);
|
||||
}
|
||||
|
||||
/// Lt, LtE, Gt, GtE
|
||||
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
|
||||
/// `Not`
|
||||
pub fn impl_not(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
|
||||
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]);
|
||||
}
|
||||
|
||||
/// `Lt`, `LtE`, `Gt`, `GtE`
|
||||
pub fn impl_comparison(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
impl_cmpop(
|
||||
unifier,
|
||||
store,
|
||||
ty,
|
||||
other_ty,
|
||||
&[ast::Cmpop::Lt, ast::Cmpop::Gt, ast::Cmpop::LtE, ast::Cmpop::GtE],
|
||||
)
|
||||
&[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE],
|
||||
ret_ty,
|
||||
);
|
||||
}
|
||||
|
||||
/// Eq, NotEq
|
||||
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_cmpop(unifier, store, ty, ty, &[ast::Cmpop::Eq, ast::Cmpop::NotEq])
|
||||
/// `Eq`, `NotEq`
|
||||
pub fn impl_eq(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
|
||||
}
|
||||
|
||||
/// Returns the expected return type of binary operations with at least one `ndarray` operand.
|
||||
pub fn typeof_ndarray_broadcast(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
left: Type,
|
||||
right: Type,
|
||||
) -> Result<Type, String> {
|
||||
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
assert!(is_left_ndarray || is_right_ndarray);
|
||||
|
||||
if is_left_ndarray && is_right_ndarray {
|
||||
// Perform broadcasting on two ndarray operands.
|
||||
|
||||
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
|
||||
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
|
||||
|
||||
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
|
||||
|
||||
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => values.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => values.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let res_ndims = left_ty_ndims
|
||||
.into_iter()
|
||||
.cartesian_product(right_ty_ndims)
|
||||
.map(|(left, right)| {
|
||||
let left_val = u64::try_from(left).unwrap();
|
||||
let right_val = u64::try_from(right).unwrap();
|
||||
|
||||
max(left_val, right_val)
|
||||
})
|
||||
.unique()
|
||||
.map(SymbolValue::U64)
|
||||
.collect_vec();
|
||||
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
|
||||
|
||||
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
|
||||
} else {
|
||||
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
|
||||
|
||||
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
|
||||
|
||||
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
||||
Ok(ndarray_ty)
|
||||
} else {
|
||||
let (expected_ty, actual_ty) = if is_left_ndarray {
|
||||
(ndarray_ty_dtype, scalar_ty)
|
||||
} else {
|
||||
(scalar_ty, ndarray_ty_dtype)
|
||||
};
|
||||
|
||||
Err(format!(
|
||||
"Expected right-hand side operand to be {}, got {}",
|
||||
unifier.stringify(expected_ty),
|
||||
unifier.stringify(actual_ty),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the return type given a binary operator and its primitive operands.
|
||||
pub fn typeof_binop(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
op: Operator,
|
||||
lhs: Type,
|
||||
rhs: Type,
|
||||
) -> Result<Option<Type>, String> {
|
||||
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
Ok(Some(match op {
|
||||
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
|
||||
if is_left_ndarray || is_right_ndarray {
|
||||
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||
} else if unifier.unioned(lhs, rhs) {
|
||||
lhs
|
||||
} else {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
Operator::MatMult => {
|
||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
assert_eq!(values.len(), 1);
|
||||
u64::try_from(values[0].clone()).unwrap()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
||||
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
assert_eq!(values.len(), 1);
|
||||
u64::try_from(values[0].clone()).unwrap()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
match (lhs_ndims, rhs_ndims) {
|
||||
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
||||
(lhs, rhs) if lhs == 0 || rhs == 0 => {
|
||||
return Err(format!(
|
||||
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
|
||||
u8::from(rhs == 0)
|
||||
))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
return Err(format!(
|
||||
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Operator::Div => {
|
||||
if is_left_ndarray || is_right_ndarray {
|
||||
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||
} else if unifier.unioned(lhs, rhs) {
|
||||
primitives.float
|
||||
} else {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
Operator::Pow => {
|
||||
if is_left_ndarray || is_right_ndarray {
|
||||
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||
} else if [
|
||||
primitives.int32,
|
||||
primitives.int64,
|
||||
primitives.uint32,
|
||||
primitives.uint64,
|
||||
primitives.float,
|
||||
]
|
||||
.into_iter()
|
||||
.any(|ty| unifier.unioned(lhs, ty))
|
||||
{
|
||||
lhs
|
||||
} else {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
Operator::LShift | Operator::RShift => lhs,
|
||||
Operator::BitOr | Operator::BitXor | Operator::BitAnd => {
|
||||
if unifier.unioned(lhs, rhs) {
|
||||
lhs
|
||||
} else {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn typeof_unaryop(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
op: Unaryop,
|
||||
operand: Type,
|
||||
) -> Result<Option<Type>, String> {
|
||||
let operand_obj_id = operand.obj_id(unifier);
|
||||
|
||||
if op == Unaryop::Not
|
||||
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap())
|
||||
{
|
||||
return Err(
|
||||
"The truth value of an array with more than one element is ambiguous".to_string()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(match op {
|
||||
Unaryop::Not => match operand_obj_id {
|
||||
Some(v) if v == PrimDef::NDArray.id() => Some(operand),
|
||||
Some(_) => Some(primitives.bool),
|
||||
_ => None,
|
||||
},
|
||||
|
||||
Unaryop::Invert => {
|
||||
if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||
Some(primitives.int32)
|
||||
} else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
|
||||
Some(operand)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
Unaryop::UAdd | Unaryop::USub => {
|
||||
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
|
||||
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||
return Err(if op == Unaryop::UAdd {
|
||||
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
|
||||
} else {
|
||||
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
|
||||
});
|
||||
}
|
||||
|
||||
Some(operand)
|
||||
} else if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||
Some(primitives.int32)
|
||||
} else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
|
||||
Some(operand)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the return type given a comparison operator and its primitive operands.
|
||||
pub fn typeof_cmpop(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
_op: Cmpop,
|
||||
lhs: Type,
|
||||
rhs: Type,
|
||||
) -> Result<Option<Type>, String> {
|
||||
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
Ok(Some(if is_left_ndarray || is_right_ndarray {
|
||||
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
|
||||
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
|
||||
|
||||
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
|
||||
} else if unifier.unioned(lhs, rhs) {
|
||||
primitives.bool
|
||||
} else {
|
||||
return Ok(None);
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
||||
|
@ -293,38 +604,71 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
bool: bool_t,
|
||||
uint32: uint32_t,
|
||||
uint64: uint64_t,
|
||||
ndarray: ndarray_t,
|
||||
..
|
||||
} = *store;
|
||||
let size_t = store.usize();
|
||||
|
||||
/* int ======== */
|
||||
for t in [int32_t, int64_t, uint32_t, uint64_t] {
|
||||
impl_basic_arithmetic(unifier, store, t, &[t], t);
|
||||
impl_pow(unifier, store, t, &[t], t);
|
||||
let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
|
||||
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_bitwise_arithmetic(unifier, store, t);
|
||||
impl_bitwise_shift(unifier, store, t);
|
||||
impl_div(unifier, store, t, &[t]);
|
||||
impl_floordiv(unifier, store, t, &[t], t);
|
||||
impl_mod(unifier, store, t, &[t], t);
|
||||
impl_invert(unifier, store, t);
|
||||
impl_not(unifier, store, t);
|
||||
impl_comparison(unifier, store, t, t);
|
||||
impl_eq(unifier, store, t);
|
||||
impl_div(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_invert(unifier, store, t, Some(t));
|
||||
impl_not(unifier, store, t, Some(bool_t));
|
||||
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
}
|
||||
for t in [int32_t, int64_t] {
|
||||
impl_sign(unifier, store, t);
|
||||
impl_sign(unifier, store, t, Some(t));
|
||||
}
|
||||
|
||||
/* float ======== */
|
||||
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
|
||||
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
|
||||
impl_div(unifier, store, float_t, &[float_t]);
|
||||
impl_floordiv(unifier, store, float_t, &[float_t], float_t);
|
||||
impl_mod(unifier, store, float_t, &[float_t], float_t);
|
||||
impl_sign(unifier, store, float_t);
|
||||
impl_not(unifier, store, float_t);
|
||||
impl_comparison(unifier, store, float_t, float_t);
|
||||
impl_eq(unifier, store, float_t);
|
||||
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
|
||||
let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
|
||||
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
|
||||
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_sign(unifier, store, float_t, Some(float_t));
|
||||
impl_not(unifier, store, float_t, Some(bool_t));
|
||||
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
|
||||
/* bool ======== */
|
||||
impl_not(unifier, store, bool_t);
|
||||
impl_eq(unifier, store, bool_t);
|
||||
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
|
||||
impl_invert(unifier, store, bool_t, Some(int32_t));
|
||||
impl_not(unifier, store, bool_t, Some(bool_t));
|
||||
impl_sign(unifier, store, bool_t, Some(int32_t));
|
||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||
|
||||
/* ndarray ===== */
|
||||
let ndarray_usized_ndims_tvar =
|
||||
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
||||
let ndarray_unsized_t =
|
||||
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
|
||||
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
|
||||
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
|
||||
impl_basic_arithmetic(
|
||||
unifier,
|
||||
store,
|
||||
ndarray_t,
|
||||
&[ndarray_unsized_t, ndarray_unsized_dtype_t],
|
||||
None,
|
||||
);
|
||||
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
|
||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
}
|
||||
|
|
|
@ -43,15 +43,18 @@ pub struct TypeError {
|
|||
}
|
||||
|
||||
impl TypeError {
|
||||
#[must_use]
|
||||
pub fn new(kind: TypeErrorKind, loc: Option<Location>) -> TypeError {
|
||||
TypeError { kind, loc }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn at(mut self, loc: Option<Location>) -> TypeError {
|
||||
self.loc = self.loc.or(loc);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn to_display(self, unifier: &Unifier) -> DisplayTypeError {
|
||||
DisplayTypeError { err: self, unifier }
|
||||
}
|
||||
|
@ -64,8 +67,8 @@ pub struct DisplayTypeError<'a> {
|
|||
|
||||
fn loc_to_str(loc: Option<Location>) -> String {
|
||||
match loc {
|
||||
Some(loc) => format!("(in {})", loc),
|
||||
None => "".to_string(),
|
||||
Some(loc) => format!("(in {loc})"),
|
||||
None => String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -75,22 +78,18 @@ impl<'a> Display for DisplayTypeError<'a> {
|
|||
let mut notes = Some(HashMap::new());
|
||||
match &self.err.kind {
|
||||
TooManyArguments { expected, got } => {
|
||||
write!(f, "Too many arguments. Expected {} but got {}", expected, got)
|
||||
write!(f, "Too many arguments. Expected {expected} but got {got}")
|
||||
}
|
||||
MissingArgs(args) => {
|
||||
write!(f, "Missing arguments: {}", args)
|
||||
write!(f, "Missing arguments: {args}")
|
||||
}
|
||||
UnknownArgName(name) => {
|
||||
write!(f, "Unknown argument name: {}", name)
|
||||
write!(f, "Unknown argument name: {name}")
|
||||
}
|
||||
IncorrectArgType { name, expected, got } => {
|
||||
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
|
||||
let got = self.unifier.stringify_with_notes(*got, &mut notes);
|
||||
write!(
|
||||
f,
|
||||
"Incorrect argument type for {}. Expected {}, but got {}",
|
||||
name, expected, got
|
||||
)
|
||||
write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}")
|
||||
}
|
||||
FieldUnificationError { field, types, loc } => {
|
||||
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
|
||||
|
@ -126,7 +125,7 @@ impl<'a> Display for DisplayTypeError<'a> {
|
|||
);
|
||||
if let Some(loc) = loc {
|
||||
result?;
|
||||
write!(f, " (in {})", loc)?;
|
||||
write!(f, " (in {loc})")?;
|
||||
return Ok(());
|
||||
}
|
||||
result
|
||||
|
@ -136,12 +135,12 @@ impl<'a> Display for DisplayTypeError<'a> {
|
|||
{
|
||||
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
|
||||
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
|
||||
write!(f, "Tuple length mismatch: got {} and {}", t1, t2)
|
||||
write!(f, "Tuple length mismatch: got {t1} and {t2}")
|
||||
}
|
||||
_ => {
|
||||
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
|
||||
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
|
||||
write!(f, "Incompatible types: {} and {}", t1, t2)
|
||||
write!(f, "Incompatible types: {t1} and {t2}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -150,18 +149,17 @@ impl<'a> Display for DisplayTypeError<'a> {
|
|||
write!(f, "Cannot assign to an element of a tuple")
|
||||
} else {
|
||||
let t = self.unifier.stringify_with_notes(*t, &mut notes);
|
||||
write!(f, "Cannot assign to field {} of {}, which is immutable", name, t)
|
||||
write!(f, "Cannot assign to field {name} of {t}, which is immutable")
|
||||
}
|
||||
}
|
||||
NoSuchField(name, t) => {
|
||||
let t = self.unifier.stringify_with_notes(*t, &mut notes);
|
||||
write!(f, "`{}::{}` field/method does not exist", t, name)
|
||||
write!(f, "`{t}::{name}` field/method does not exist")
|
||||
}
|
||||
TupleIndexOutOfBounds { index, len } => {
|
||||
write!(
|
||||
f,
|
||||
"Tuple index out of bounds. Got {} but tuple has only {} elements",
|
||||
index, len
|
||||
"Tuple index out of bounds. Got {index} but tuple has only {len} elements"
|
||||
)
|
||||
}
|
||||
RequiresTypeAnn => {
|
||||
|
@ -172,13 +170,13 @@ impl<'a> Display for DisplayTypeError<'a> {
|
|||
}
|
||||
}?;
|
||||
if let Some(loc) = self.err.loc {
|
||||
write!(f, " at {}", loc)?;
|
||||
write!(f, " at {loc}")?;
|
||||
}
|
||||
let notes = notes.unwrap();
|
||||
if !notes.is_empty() {
|
||||
write!(f, "\n\nNotes:")?;
|
||||
for line in notes.values() {
|
||||
write!(f, "\n {}", line)?;
|
||||
write!(f, "\n {line}")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -3,12 +3,12 @@ use super::*;
|
|||
use crate::{
|
||||
codegen::CodeGenContext,
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{DefinitionId, TopLevelDef},
|
||||
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
|
||||
};
|
||||
use indoc::indoc;
|
||||
use itertools::zip;
|
||||
use nac3parser::parser::parse_program;
|
||||
use parking_lot::RwLock;
|
||||
use std::iter::zip;
|
||||
use test_case::test_case;
|
||||
|
||||
struct Resolver {
|
||||
|
@ -20,7 +20,7 @@ struct Resolver {
|
|||
impl SymbolResolver for Resolver {
|
||||
fn get_default_param_value(
|
||||
&self,
|
||||
_: &nac3parser::ast::Expr,
|
||||
_: &ast::Expr,
|
||||
) -> Option<crate::symbol_resolver::SymbolValue> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
@ -43,8 +43,11 @@ impl SymbolResolver for Resolver {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, String> {
|
||||
self.id_to_def.get(&id).cloned().ok_or_else(|| "Unknown identifier".to_string())
|
||||
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
|
||||
self.id_to_def
|
||||
.get(&id)
|
||||
.cloned()
|
||||
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
|
||||
}
|
||||
|
||||
fn get_string_id(&self, _: &str) -> i32 {
|
||||
|
@ -62,7 +65,7 @@ struct TestEnvironment {
|
|||
pub primitives: PrimitiveStore,
|
||||
pub id_to_name: HashMap<usize, StrRef>,
|
||||
pub identifier_mapping: HashMap<StrRef, Type>,
|
||||
pub virtual_checks: Vec<(Type, Type, nac3parser::ast::Location)>,
|
||||
pub virtual_checks: Vec<(Type, Type, Location)>,
|
||||
pub calls: HashMap<CodeLocation, CallId>,
|
||||
pub top_level: TopLevelContext,
|
||||
}
|
||||
|
@ -72,67 +75,75 @@ impl TestEnvironment {
|
|||
let mut unifier = Unifier::new();
|
||||
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
obj_id: PrimDef::Int32.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
with_fields(&mut unifier, int32, |unifier, fields| {
|
||||
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
|
||||
ret: int32,
|
||||
vars: HashMap::new(),
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
fields.insert("__add__".into(), (add_ty, false));
|
||||
});
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
obj_id: PrimDef::Int64.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let float = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(2),
|
||||
obj_id: PrimDef::Float.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
obj_id: PrimDef::Bool.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let none = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(4),
|
||||
obj_id: PrimDef::None.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let range = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(5),
|
||||
obj_id: PrimDef::Range.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let str = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(6),
|
||||
obj_id: PrimDef::Str.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let exception = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(7),
|
||||
obj_id: PrimDef::Exception.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let uint32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(8),
|
||||
obj_id: PrimDef::UInt32.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let uint64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(9),
|
||||
obj_id: PrimDef::UInt64.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let option = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(10),
|
||||
obj_id: PrimDef::Option.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
||||
let ndarray_ndims_tvar =
|
||||
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
|
||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PrimDef::NDArray.id(),
|
||||
fields: HashMap::new(),
|
||||
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
||||
});
|
||||
let primitives = PrimitiveStore {
|
||||
int32,
|
||||
|
@ -146,7 +157,10 @@ impl TestEnvironment {
|
|||
uint32,
|
||||
uint64,
|
||||
option,
|
||||
ndarray,
|
||||
size_t: 64,
|
||||
};
|
||||
unifier.put_primitive_store(&primitives);
|
||||
set_primitives_magic_methods(&primitives, &mut unifier);
|
||||
|
||||
let id_to_name = [
|
||||
|
@ -197,67 +211,72 @@ impl TestEnvironment {
|
|||
let mut identifier_mapping = HashMap::new();
|
||||
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
obj_id: PrimDef::Int32.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
with_fields(&mut unifier, int32, |unifier, fields| {
|
||||
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
|
||||
ret: int32,
|
||||
vars: HashMap::new(),
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
fields.insert("__add__".into(), (add_ty, false));
|
||||
});
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
obj_id: PrimDef::Int64.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let float = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(2),
|
||||
obj_id: PrimDef::Float.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
obj_id: PrimDef::Bool.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let none = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(4),
|
||||
obj_id: PrimDef::None.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let range = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(5),
|
||||
obj_id: PrimDef::Range.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let str = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(6),
|
||||
obj_id: PrimDef::Str.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let exception = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(7),
|
||||
obj_id: PrimDef::Exception.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let uint32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(8),
|
||||
obj_id: PrimDef::UInt32.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let uint64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(9),
|
||||
obj_id: PrimDef::UInt64.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let option = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(10),
|
||||
obj_id: PrimDef::Option.id(),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PrimDef::NDArray.id(),
|
||||
fields: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
identifier_mapping.insert("None".into(), none);
|
||||
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
|
||||
|
@ -293,21 +312,25 @@ impl TestEnvironment {
|
|||
uint32,
|
||||
uint64,
|
||||
option,
|
||||
ndarray,
|
||||
size_t: 64,
|
||||
};
|
||||
|
||||
let (v0, id) = unifier.get_dummy_var();
|
||||
unifier.put_primitive_store(&primitives);
|
||||
|
||||
let tvar = unifier.get_dummy_var();
|
||||
|
||||
let foo_ty = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(defs + 1),
|
||||
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(),
|
||||
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(),
|
||||
fields: [("a".into(), (tvar.ty, true))].iter().cloned().collect::<HashMap<_, _>>(),
|
||||
params: into_var_map([tvar]),
|
||||
});
|
||||
top_level_defs.push(
|
||||
RwLock::new(TopLevelDef::Class {
|
||||
name: "Foo".into(),
|
||||
object_id: DefinitionId(defs + 1),
|
||||
type_vars: vec![v0],
|
||||
fields: [("a".into(), v0, true)].into(),
|
||||
type_vars: vec![tvar.ty],
|
||||
fields: [("a".into(), tvar.ty, true)].into(),
|
||||
methods: Default::default(),
|
||||
ancestors: Default::default(),
|
||||
resolver: None,
|
||||
|
@ -322,7 +345,7 @@ impl TestEnvironment {
|
|||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![],
|
||||
ret: foo_ty,
|
||||
vars: [(id, v0)].iter().cloned().collect(),
|
||||
vars: into_var_map([tvar]),
|
||||
})),
|
||||
);
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
use itertools::{zip, Itertools};
|
||||
use indexmap::IndexMap;
|
||||
use itertools::Itertools;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use std::fmt::{self, Display};
|
||||
use std::iter::zip;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::{borrow::Cow, collections::HashSet};
|
||||
|
@ -12,6 +14,7 @@ use super::type_error::{TypeError, TypeErrorKind};
|
|||
use super::unification_table::{UnificationKey, UnificationTable};
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
|
||||
use crate::typecheck::type_inferencer::PrimitiveStore;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
@ -23,7 +26,52 @@ pub type Type = UnificationKey;
|
|||
pub struct CallId(pub(super) usize);
|
||||
|
||||
pub type Mapping<K, V = Type> = HashMap<K, V>;
|
||||
type VarMap = Mapping<u32>;
|
||||
pub type IndexMapping<K, V = Type> = IndexMap<K, V>;
|
||||
|
||||
/// ID of a Python type variable. Specific to `nac3core`.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct TypeVarId(u32);
|
||||
|
||||
impl From<TypeVarId> for u32 {
|
||||
fn from(value: TypeVarId) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for TypeVarId {
|
||||
// NOTE: Must output the string of the ID value. Certain unit tests rely on string comparisons.
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_fmt(format_args!("{}", self.0))
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python type variable. Used by `nac3core` during type inference.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct TypeVar {
|
||||
/// `nac3core`'s internal [`TypeVarId`] of this type variable.
|
||||
pub id: TypeVarId,
|
||||
|
||||
/// The assigned [`Type`] of this Python type variable.
|
||||
pub ty: Type,
|
||||
}
|
||||
|
||||
/// The mapping between [`TypeVarId`] and [unifier type][`Type`].
|
||||
pub type VarMap = IndexMapping<TypeVarId>;
|
||||
|
||||
/// Build a [`VarMap`] from an iterator of [`TypeVar`]
|
||||
///
|
||||
/// The resulting [`VarMap`] will have the same order as the input iterator.
|
||||
pub fn into_var_map<I>(vars: I) -> VarMap
|
||||
where
|
||||
I: IntoIterator<Item = TypeVar>,
|
||||
{
|
||||
vars.into_iter().map(|var| (var.id, var.ty)).collect()
|
||||
}
|
||||
|
||||
/// Get an iterator of [`TypeVar`]s from a [`VarMap`]
|
||||
pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator<Item = TypeVar> + '_ {
|
||||
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty })
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Call {
|
||||
|
@ -55,13 +103,14 @@ pub enum RecordKey {
|
|||
}
|
||||
|
||||
impl Type {
|
||||
// a wrapper function for cleaner code so that we don't need to
|
||||
// write this long pattern matching just to get the field `obj_id`
|
||||
pub fn get_obj_id(self, unifier: &Unifier) -> DefinitionId {
|
||||
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty_immutable(self).as_ref() {
|
||||
*obj_id
|
||||
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching
|
||||
/// just to get the field `obj_id`.
|
||||
#[must_use]
|
||||
pub fn obj_id(self, unifier: &Unifier) -> Option<DefinitionId> {
|
||||
if let TypeEnum::TObj { obj_id, .. } = &*unifier.get_ty_immutable(self) {
|
||||
Some(*obj_id)
|
||||
} else {
|
||||
unreachable!("expect a object type")
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -96,8 +145,8 @@ impl From<i32> for RecordKey {
|
|||
impl Display for RecordKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
RecordKey::Str(s) => write!(f, "{}", s),
|
||||
RecordKey::Int(i) => write!(f, "{}", i),
|
||||
RecordKey::Str(s) => write!(f, "{s}"),
|
||||
RecordKey::Int(i) => write!(f, "{i}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -110,50 +159,85 @@ pub struct RecordField {
|
|||
}
|
||||
|
||||
impl RecordField {
|
||||
#[must_use]
|
||||
pub fn new(ty: Type, mutable: bool, loc: Option<Location>) -> RecordField {
|
||||
RecordField { ty, mutable, loc }
|
||||
}
|
||||
}
|
||||
|
||||
/// Category of variable and value types.
|
||||
#[derive(Clone)]
|
||||
pub enum TypeEnum {
|
||||
TRigidVar {
|
||||
id: u32,
|
||||
id: TypeVarId,
|
||||
name: Option<StrRef>,
|
||||
loc: Option<Location>,
|
||||
},
|
||||
|
||||
/// A type variable.
|
||||
TVar {
|
||||
id: u32,
|
||||
id: TypeVarId,
|
||||
// empty indicates this is not a struct/tuple/list
|
||||
fields: Option<Mapping<RecordKey, RecordField>>,
|
||||
// empty indicates no restriction
|
||||
range: Vec<Type>,
|
||||
name: Option<StrRef>,
|
||||
loc: Option<Location>,
|
||||
/// Whether this type variable refers to a const-generic variable.
|
||||
is_const_generic: bool,
|
||||
},
|
||||
|
||||
/// A literal generic type matching `typing.Literal`.
|
||||
TLiteral {
|
||||
/// The value of the constant.
|
||||
values: Vec<SymbolValue>,
|
||||
loc: Option<Location>,
|
||||
},
|
||||
|
||||
/// A tuple type.
|
||||
TTuple {
|
||||
/// The types of elements present in this tuple.
|
||||
ty: Vec<Type>,
|
||||
},
|
||||
|
||||
/// A list type.
|
||||
TList {
|
||||
/// The type of elements present in this list.
|
||||
ty: Type,
|
||||
},
|
||||
|
||||
/// An object type.
|
||||
TObj {
|
||||
/// The [DefintionId] of this object type.
|
||||
obj_id: DefinitionId,
|
||||
|
||||
/// The fields present in this object type.
|
||||
///
|
||||
/// The key of the [Mapping] is the identifier of the field, while the value is a tuple
|
||||
/// containing the [Type] of the field, and a `bool` indicating whether the field is a
|
||||
/// variable (as opposed to a function).
|
||||
fields: Mapping<StrRef, (Type, bool)>,
|
||||
|
||||
/// Mapping between the ID of type variables and the [Type] representing the type variables
|
||||
/// of this object type.
|
||||
params: VarMap,
|
||||
},
|
||||
TVirtual {
|
||||
ty: Type,
|
||||
},
|
||||
TCall(Vec<CallId>),
|
||||
|
||||
/// A function type.
|
||||
TFunc(FunSignature),
|
||||
}
|
||||
|
||||
impl TypeEnum {
|
||||
#[must_use]
|
||||
pub fn get_type_name(&self) -> &'static str {
|
||||
match self {
|
||||
TypeEnum::TRigidVar { .. } => "TRigidVar",
|
||||
TypeEnum::TVar { .. } => "TVar",
|
||||
TypeEnum::TLiteral { .. } => "TConstant",
|
||||
TypeEnum::TTuple { .. } => "TTuple",
|
||||
TypeEnum::TList { .. } => "TList",
|
||||
TypeEnum::TObj { .. } => "TObj",
|
||||
|
@ -171,9 +255,10 @@ pub struct Unifier {
|
|||
pub(crate) top_level: Option<Arc<TopLevelContext>>,
|
||||
pub(crate) unification_table: UnificationTable<Rc<TypeEnum>>,
|
||||
pub(crate) calls: Vec<Rc<Call>>,
|
||||
var_id: u32,
|
||||
var_id_counter: u32,
|
||||
unify_cache: HashSet<(Type, Type)>,
|
||||
snapshot: Option<(usize, u32)>
|
||||
snapshot: Option<(usize, u32)>,
|
||||
primitive_store: Option<PrimitiveStore>,
|
||||
}
|
||||
|
||||
impl Default for Unifier {
|
||||
|
@ -184,17 +269,36 @@ impl Default for Unifier {
|
|||
|
||||
impl Unifier {
|
||||
/// Get an empty unifier
|
||||
#[must_use]
|
||||
pub fn new() -> Unifier {
|
||||
Unifier {
|
||||
unification_table: UnificationTable::new(),
|
||||
var_id: 0,
|
||||
var_id_counter: 0,
|
||||
calls: Vec::new(),
|
||||
unify_cache: HashSet::new(),
|
||||
top_level: None,
|
||||
snapshot: None,
|
||||
primitive_store: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the [`PrimitiveStore`] instance within this `Unifier`.
|
||||
///
|
||||
/// This function can only be invoked once. Any subsequent invocations will result in an
|
||||
/// assertion error.
|
||||
pub fn put_primitive_store(&mut self, primitives: &PrimitiveStore) {
|
||||
assert!(self.primitive_store.is_none());
|
||||
self.primitive_store.replace(*primitives);
|
||||
}
|
||||
|
||||
/// Returns the [`UnificationTable`] associated with this `Unifier`.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The use of this function is discouraged under most circumstances. Only use this function if
|
||||
/// in-place manipulation of type variables and/or type fields is necessary, otherwise prefer to
|
||||
/// [add a new type][`Unifier::add_ty`] and [unify the type][`Unifier::unify`] with an existing
|
||||
/// type.
|
||||
pub unsafe fn get_unification_table(&mut self) -> &mut UnificationTable<Rc<TypeEnum>> {
|
||||
&mut self.unification_table
|
||||
}
|
||||
|
@ -208,37 +312,39 @@ impl Unifier {
|
|||
let lock = unifier.lock().unwrap();
|
||||
Unifier {
|
||||
unification_table: UnificationTable::from_send(&lock.0),
|
||||
var_id: lock.1,
|
||||
var_id_counter: lock.1,
|
||||
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
|
||||
top_level: None,
|
||||
unify_cache: HashSet::new(),
|
||||
snapshot: None,
|
||||
primitive_store: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get_shared_unifier(&self) -> SharedUnifier {
|
||||
Arc::new(Mutex::new((
|
||||
self.unification_table.get_send(),
|
||||
self.var_id,
|
||||
self.var_id_counter,
|
||||
self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// Register a type to the unifier.
|
||||
/// Returns a key in the unification_table.
|
||||
/// Returns a key in the `unification_table`.
|
||||
pub fn add_ty(&mut self, a: TypeEnum) -> Type {
|
||||
self.unification_table.new_key(Rc::new(a))
|
||||
}
|
||||
|
||||
pub fn add_record(&mut self, fields: Mapping<RecordKey, RecordField>) -> Type {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
let id = self.generate_var_id();
|
||||
self.add_ty(TypeEnum::TVar {
|
||||
id,
|
||||
range: vec![],
|
||||
fields: Some(fields),
|
||||
name: None,
|
||||
loc: None,
|
||||
is_const_generic: false,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -257,6 +363,7 @@ impl Unifier {
|
|||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get_call_signature_immutable(&self, id: CallId) -> Option<FunSignature> {
|
||||
let fun = self.calls.get(id.0).unwrap().fun.borrow().unwrap();
|
||||
if let TypeEnum::TFunc(sign) = &*self.get_ty_immutable(fun) {
|
||||
|
@ -270,44 +377,79 @@ impl Unifier {
|
|||
self.unification_table.get_representative(ty)
|
||||
}
|
||||
|
||||
/// Get the TypeEnum of a type.
|
||||
/// Get the `TypeEnum` of a type.
|
||||
pub fn get_ty(&mut self, a: Type) -> Rc<TypeEnum> {
|
||||
self.unification_table.probe_value(a).clone()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get_ty_immutable(&self, a: Type) -> Rc<TypeEnum> {
|
||||
self.unification_table.probe_value_immutable(a).clone()
|
||||
}
|
||||
|
||||
pub fn get_fresh_rigid_var(
|
||||
&mut self,
|
||||
name: Option<StrRef>,
|
||||
loc: Option<Location>,
|
||||
) -> (Type, u32) {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
(self.add_ty(TypeEnum::TRigidVar { id, name, loc }), id)
|
||||
pub fn get_fresh_rigid_var(&mut self, name: Option<StrRef>, loc: Option<Location>) -> TypeVar {
|
||||
let id = self.generate_var_id();
|
||||
let ty = self.add_ty(TypeEnum::TRigidVar { id, name, loc });
|
||||
TypeVar { id, ty }
|
||||
}
|
||||
|
||||
pub fn get_dummy_var(&mut self) -> (Type, u32) {
|
||||
pub fn get_dummy_var(&mut self) -> TypeVar {
|
||||
self.get_fresh_var_with_range(&[], None, None)
|
||||
}
|
||||
|
||||
pub fn get_fresh_var(&mut self, name: Option<StrRef>, loc: Option<Location>) -> (Type, u32) {
|
||||
/// Returns a fresh [type variable][TypeEnum::TVar] with no associated range.
|
||||
///
|
||||
/// This type variable can be instantiated by any type.
|
||||
pub fn get_fresh_var(&mut self, name: Option<StrRef>, loc: Option<Location>) -> TypeVar {
|
||||
self.get_fresh_var_with_range(&[], name, loc)
|
||||
}
|
||||
|
||||
/// Get a fresh type variable.
|
||||
/// Returns a fresh [type variable][TypeEnum::TVar] with the range specified by `range`.
|
||||
///
|
||||
/// This type variable can be instantiated by any type present in `range`.
|
||||
pub fn get_fresh_var_with_range(
|
||||
&mut self,
|
||||
range: &[Type],
|
||||
name: Option<StrRef>,
|
||||
loc: Option<Location>,
|
||||
) -> (Type, u32) {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
) -> TypeVar {
|
||||
let range = range.to_vec();
|
||||
(self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc }), id)
|
||||
|
||||
let id = self.generate_var_id();
|
||||
let ty = self.add_ty(TypeEnum::TVar {
|
||||
id,
|
||||
range,
|
||||
fields: None,
|
||||
name,
|
||||
loc,
|
||||
is_const_generic: false,
|
||||
});
|
||||
TypeVar { id, ty }
|
||||
}
|
||||
|
||||
/// Returns a fresh type representing a constant generic variable with the given underlying type `ty`.
|
||||
pub fn get_fresh_const_generic_var(
|
||||
&mut self,
|
||||
ty: Type,
|
||||
name: Option<StrRef>,
|
||||
loc: Option<Location>,
|
||||
) -> TypeVar {
|
||||
let id = self.generate_var_id();
|
||||
let ty = self.add_ty(TypeEnum::TVar {
|
||||
id,
|
||||
range: vec![ty],
|
||||
fields: None,
|
||||
name,
|
||||
loc,
|
||||
is_const_generic: true,
|
||||
});
|
||||
TypeVar { id, ty }
|
||||
}
|
||||
|
||||
/// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`.
|
||||
pub fn get_fresh_literal(&mut self, values: Vec<SymbolValue>, loc: Option<Location>) -> Type {
|
||||
let ty_enum = TypeEnum::TLiteral { values: values.into_iter().dedup().collect(), loc };
|
||||
self.add_ty(ty_enum)
|
||||
}
|
||||
|
||||
/// Unification would not unify rigid variables with other types, but we want to do this for
|
||||
|
@ -326,8 +468,9 @@ impl Unifier {
|
|||
Some(
|
||||
range
|
||||
.iter()
|
||||
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||
.flatten()
|
||||
.flat_map(|ty| {
|
||||
self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])
|
||||
})
|
||||
.collect_vec(),
|
||||
)
|
||||
}
|
||||
|
@ -353,7 +496,7 @@ impl Unifier {
|
|||
}
|
||||
}
|
||||
TypeEnum::TObj { params, .. } => {
|
||||
let (keys, params): (Vec<u32>, Vec<Type>) = params.iter().unzip();
|
||||
let (keys, params): (Vec<TypeVarId>, Vec<Type>) = params.iter().unzip();
|
||||
let params = params
|
||||
.into_iter()
|
||||
.map(|ty| self.get_instantiations(ty).unwrap_or_else(|| vec![ty]))
|
||||
|
@ -368,7 +511,7 @@ impl Unifier {
|
|||
.map(|params| {
|
||||
self.subst(
|
||||
ty,
|
||||
&zip(keys.iter().cloned(), params.iter().cloned()).collect(),
|
||||
&zip(keys.iter().copied(), params.iter().copied()).collect(),
|
||||
)
|
||||
.unwrap_or(ty)
|
||||
})
|
||||
|
@ -383,18 +526,21 @@ impl Unifier {
|
|||
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
||||
use TypeEnum::*;
|
||||
match &*self.get_ty(a) {
|
||||
TRigidVar { .. } => true,
|
||||
TRigidVar { .. }
|
||||
| TLiteral { .. }
|
||||
// functions are instantiated for each call sites, so the function type can contain
|
||||
// type variables.
|
||||
| TFunc { .. } => true,
|
||||
|
||||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||
TCall { .. } => false,
|
||||
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||
TList { ty }
|
||||
| TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||
|
||||
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||
TObj { params: vars, .. } => {
|
||||
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
||||
}
|
||||
// functions are instantiated for each call sites, so the function type can contain
|
||||
// type variables.
|
||||
TFunc { .. } => true,
|
||||
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -424,15 +570,10 @@ impl Unifier {
|
|||
}
|
||||
|
||||
let Call { posargs, kwargs, ret, fun, loc } = call;
|
||||
let instantiated = self.instantiate_fun(b, &*signature);
|
||||
let instantiated = self.instantiate_fun(b, signature);
|
||||
let r = self.get_ty(instantiated);
|
||||
let r = r.as_ref();
|
||||
let signature;
|
||||
if let TypeEnum::TFunc(s) = &*r {
|
||||
signature = s;
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
let TypeEnum::TFunc(signature) = r else { unreachable!() };
|
||||
// we check to make sure that all required arguments (those without default
|
||||
// arguments) are provided, and do not provide the same argument twice.
|
||||
let mut required = required.to_vec();
|
||||
|
@ -455,17 +596,14 @@ impl Unifier {
|
|||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
||||
})?;
|
||||
}
|
||||
for (k, t) in kwargs.iter() {
|
||||
for (k, t) in kwargs {
|
||||
if let Some(i) = required.iter().position(|v| v == k) {
|
||||
required.remove(i);
|
||||
}
|
||||
let i = all_names
|
||||
.iter()
|
||||
.position(|v| &v.0 == k)
|
||||
.ok_or_else(|| {
|
||||
self.restore_snapshot();
|
||||
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
|
||||
})?;
|
||||
let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| {
|
||||
self.restore_snapshot();
|
||||
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
|
||||
})?;
|
||||
let (name, expected) = all_names.remove(i);
|
||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
||||
self.restore_snapshot();
|
||||
|
@ -531,8 +669,17 @@ impl Unifier {
|
|||
};
|
||||
match (&*ty_a, &*ty_b) {
|
||||
(
|
||||
TVar { fields: fields1, id, name: name1, loc: loc1, .. },
|
||||
TVar { fields: fields2, id: id2, name: name2, loc: loc2, .. },
|
||||
TVar {
|
||||
fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, ..
|
||||
},
|
||||
TVar {
|
||||
fields: fields2,
|
||||
id: id2,
|
||||
name: name2,
|
||||
loc: loc2,
|
||||
is_const_generic: false,
|
||||
..
|
||||
},
|
||||
) => {
|
||||
let new_fields = match (fields1, fields2) {
|
||||
(None, None) => None,
|
||||
|
@ -542,7 +689,7 @@ impl Unifier {
|
|||
}
|
||||
(Some(fields1), Some(fields2)) => {
|
||||
let mut new_fields: Mapping<_, _> = fields2.clone();
|
||||
for (key, val1) in fields1.iter() {
|
||||
for (key, val1) in fields1 {
|
||||
if let Some(val2) = fields2.get(key) {
|
||||
self.unify_impl(val1.ty, val2.ty, false).map_err(|_| {
|
||||
TypeError::new(
|
||||
|
@ -571,9 +718,9 @@ impl Unifier {
|
|||
};
|
||||
let intersection = self
|
||||
.get_intersection(a, b)
|
||||
.map_err(|_| TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))?
|
||||
.map_err(|()| TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))?
|
||||
.unwrap();
|
||||
let range = if let TypeEnum::TVar { range, .. } = &*self.get_ty(intersection) {
|
||||
let range = if let TVar { range, .. } = &*self.get_ty(intersection) {
|
||||
range.clone()
|
||||
} else {
|
||||
unreachable!()
|
||||
|
@ -581,16 +728,17 @@ impl Unifier {
|
|||
self.unification_table.unify(a, b);
|
||||
self.unification_table.set_value(
|
||||
a,
|
||||
Rc::new(TypeEnum::TVar {
|
||||
Rc::new(TVar {
|
||||
id: name1.map_or(*id2, |_| *id),
|
||||
fields: new_fields,
|
||||
range,
|
||||
name: name1.or(*name2),
|
||||
loc: loc1.or(*loc2),
|
||||
is_const_generic: false,
|
||||
}),
|
||||
);
|
||||
}
|
||||
(TVar { fields: None, range, .. }, _) => {
|
||||
(TVar { fields: None, range, is_const_generic: false, .. }, _) => {
|
||||
// We check for the range of the type variable to see if unification is allowed.
|
||||
// Note that although b may be compatible with a, we may have to constrain type
|
||||
// variables in b to make sure that instantiations of b would always be compatible
|
||||
|
@ -607,9 +755,9 @@ impl Unifier {
|
|||
self.unify_impl(x, b, false)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TVar { fields: Some(fields), range, .. }, TTuple { ty }) => {
|
||||
let len = ty.len() as i32;
|
||||
for (k, v) in fields.iter() {
|
||||
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
|
||||
let len = i32::try_from(ty.len()).unwrap();
|
||||
for (k, v) in fields {
|
||||
match *k {
|
||||
RecordKey::Int(i) => {
|
||||
if v.mutable {
|
||||
|
@ -637,11 +785,11 @@ impl Unifier {
|
|||
self.unify_impl(x, b, false)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TVar { fields: Some(fields), range, .. }, TList { ty }) => {
|
||||
for (k, v) in fields.iter() {
|
||||
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => {
|
||||
for (k, v) in fields {
|
||||
match *k {
|
||||
RecordKey::Int(_) => {
|
||||
self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))?
|
||||
self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))?;
|
||||
}
|
||||
RecordKey::Str(_) => {
|
||||
return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc))
|
||||
|
@ -652,21 +800,93 @@ impl Unifier {
|
|||
self.unify_impl(x, b, false)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
|
||||
(
|
||||
TVar { id: id1, range: ty1, is_const_generic: true, .. },
|
||||
TVar { id: id2, range: ty2, .. },
|
||||
) => {
|
||||
let ty1 = ty1[0];
|
||||
let ty2 = ty2[0];
|
||||
|
||||
if id1 != id2 {
|
||||
self.unify_impl(ty1, ty2, false)?;
|
||||
}
|
||||
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
|
||||
(TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => {
|
||||
assert_eq!(tys.len(), 1);
|
||||
assert_eq!(values.len(), 1);
|
||||
|
||||
let primitives =
|
||||
&self.primitive_store.expect("Expected PrimitiveStore to be present");
|
||||
|
||||
let ty = tys[0];
|
||||
let value = &values[0];
|
||||
let value_ty = value.get_type(primitives, self);
|
||||
|
||||
// If the types don't match, try to implicitly promote integers
|
||||
if !self.unioned(ty, value_ty) {
|
||||
let Ok(num_val) = i128::try_from(value.clone()) else {
|
||||
return Self::incompatible_types(a, b);
|
||||
};
|
||||
|
||||
let can_convert = if self.unioned(ty, primitives.int32) {
|
||||
i32::try_from(num_val).is_ok()
|
||||
} else if self.unioned(ty, primitives.int64) {
|
||||
i64::try_from(num_val).is_ok()
|
||||
} else if self.unioned(ty, primitives.uint32) {
|
||||
u32::try_from(num_val).is_ok()
|
||||
} else if self.unioned(ty, primitives.uint64) {
|
||||
u64::try_from(num_val).is_ok()
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if !can_convert {
|
||||
return Self::incompatible_types(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
|
||||
(TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => {
|
||||
for (v1, v2) in zip(val1, val2) {
|
||||
if v1 != v2 {
|
||||
// Try performing integer promotion on literals
|
||||
let v1i = i128::try_from(v1.clone()).ok();
|
||||
let v2i = i128::try_from(v2.clone()).ok();
|
||||
|
||||
if v1i != v2i {
|
||||
return Self::incompatible_types(a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
|
||||
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
||||
if ty1.len() != ty2.len() {
|
||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||
}
|
||||
for (x, y) in ty1.iter().zip(ty2.iter()) {
|
||||
self.unify_impl(*x, *y, false)?;
|
||||
if self.unify_impl(*x, *y, false).is_err() {
|
||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||
}
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
||||
self.unify_impl(*ty1, *ty2, false)?;
|
||||
if self.unify_impl(*ty1, *ty2, false).is_err() {
|
||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => {
|
||||
for (k, field) in map.iter() {
|
||||
for (k, field) in map {
|
||||
match *k {
|
||||
RecordKey::Str(s) => {
|
||||
let (ty, mutable) = fields.get(&s).copied().ok_or_else(|| {
|
||||
|
@ -698,7 +918,7 @@ impl Unifier {
|
|||
(TVar { fields: Some(map), range, .. }, TVirtual { ty }) => {
|
||||
let ty = self.get_ty(*ty);
|
||||
if let TObj { fields, .. } = ty.as_ref() {
|
||||
for (k, field) in map.iter() {
|
||||
for (k, field) in map {
|
||||
match *k {
|
||||
RecordKey::Str(s) => {
|
||||
let (ty, _) = fields.get(&s).copied().ok_or_else(|| {
|
||||
|
@ -740,21 +960,32 @@ impl Unifier {
|
|||
TObj { obj_id: id2, params: params2, .. },
|
||||
) => {
|
||||
if id1 != id2 {
|
||||
self.incompatible_types(a, b)?;
|
||||
Self::incompatible_types(a, b)?;
|
||||
}
|
||||
for (x, y) in zip(params1.values(), params2.values()) {
|
||||
self.unify_impl(*x, *y, false)?;
|
||||
|
||||
// Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits
|
||||
// all K-V pairs "in arbitrary order"
|
||||
let (tv1, tv2) = (
|
||||
params1.iter().map(|(_, v)| v).collect_vec(),
|
||||
params2.iter().map(|(_, v)| v).collect_vec(),
|
||||
);
|
||||
for (x, y) in zip(tv1, tv2) {
|
||||
if self.unify_impl(*x, *y, false).is_err() {
|
||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||
};
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
||||
self.unify_impl(*ty1, *ty2, false)?;
|
||||
if self.unify_impl(*ty1, *ty2, false).is_err() {
|
||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||
};
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TCall(calls1), TCall(calls2)) => {
|
||||
// we do not unify individual calls, instead we defer until the unification wtih a
|
||||
// function definition.
|
||||
let calls = calls1.iter().chain(calls2.iter()).cloned().collect();
|
||||
let calls = calls1.iter().chain(calls2.iter()).copied().collect();
|
||||
self.set_a_to_b(a, b);
|
||||
self.unification_table.set_value(b, Rc::new(TCall(calls)));
|
||||
}
|
||||
|
@ -767,7 +998,7 @@ impl Unifier {
|
|||
.rev()
|
||||
.collect();
|
||||
// we unify every calls to the function signature.
|
||||
for c in calls.iter() {
|
||||
for c in calls {
|
||||
let call = self.calls[c.0].clone();
|
||||
self.unify_call(&call, b, signature, &required)?;
|
||||
}
|
||||
|
@ -784,9 +1015,13 @@ impl Unifier {
|
|||
if x.name != y.name || x.default_value != y.default_value {
|
||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||
}
|
||||
self.unify_impl(x.ty, y.ty, false)?;
|
||||
if self.unify_impl(x.ty, y.ty, false).is_err() {
|
||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||
};
|
||||
}
|
||||
self.unify_impl(sign1.ret, sign2.ret, false)?;
|
||||
if self.unify_impl(sign1.ret, sign2.ret, false).is_err() {
|
||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||
};
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVar { fields: Some(fields), .. }, _) => {
|
||||
|
@ -795,10 +1030,10 @@ impl Unifier {
|
|||
}
|
||||
_ => {
|
||||
if swapped {
|
||||
return self.incompatible_types(a, b);
|
||||
} else {
|
||||
self.unify_impl(b, a, true)?;
|
||||
return Self::incompatible_types(a, b);
|
||||
}
|
||||
|
||||
self.unify_impl(b, a, true)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -811,26 +1046,25 @@ impl Unifier {
|
|||
pub fn stringify_with_notes(
|
||||
&self,
|
||||
ty: Type,
|
||||
notes: &mut Option<HashMap<u32, String>>,
|
||||
notes: &mut Option<HashMap<TypeVarId, String>>,
|
||||
) -> String {
|
||||
let top_level = self.top_level.clone();
|
||||
self.internal_stringify(
|
||||
ty,
|
||||
&mut |id| {
|
||||
top_level.as_ref().map_or_else(
|
||||
|| format!("{}", id),
|
||||
|| format!("{id}"),
|
||||
|top_level| {
|
||||
if let TopLevelDef::Class { name, .. } =
|
||||
&*top_level.definitions.read()[id].read()
|
||||
{
|
||||
name.to_string()
|
||||
} else {
|
||||
let top_level_def = &top_level.definitions.read()[id];
|
||||
let TopLevelDef::Class { name, .. } = &*top_level_def.read() else {
|
||||
unreachable!("expected class definition")
|
||||
}
|
||||
};
|
||||
|
||||
name.to_string()
|
||||
},
|
||||
)
|
||||
},
|
||||
&mut |id| format!("typevar{}", id),
|
||||
&mut |id| format!("typevar{id}"),
|
||||
notes,
|
||||
)
|
||||
}
|
||||
|
@ -841,11 +1075,11 @@ impl Unifier {
|
|||
ty: Type,
|
||||
obj_to_name: &mut F,
|
||||
var_to_name: &mut G,
|
||||
notes: &mut Option<HashMap<u32, String>>,
|
||||
notes: &mut Option<HashMap<TypeVarId, String>>,
|
||||
) -> String
|
||||
where
|
||||
F: FnMut(usize) -> String,
|
||||
G: FnMut(u32) -> String,
|
||||
G: FnMut(TypeVarId) -> String,
|
||||
{
|
||||
let ty = self.unification_table.probe_value_immutable(ty).clone();
|
||||
match ty.as_ref() {
|
||||
|
@ -873,7 +1107,7 @@ impl Unifier {
|
|||
if !range.is_empty() && notes.is_some() && !notes.as_ref().unwrap().contains_key(id)
|
||||
{
|
||||
// just in case if there is any cyclic dependency
|
||||
notes.as_mut().unwrap().insert(*id, "".into());
|
||||
notes.as_mut().unwrap().insert(*id, String::new());
|
||||
let body = format!(
|
||||
"{} ∈ {{{}}}",
|
||||
n,
|
||||
|
@ -887,6 +1121,9 @@ impl Unifier {
|
|||
};
|
||||
n
|
||||
}
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", "))
|
||||
}
|
||||
TypeEnum::TTuple { ty } => {
|
||||
let mut fields =
|
||||
ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
|
||||
|
@ -903,15 +1140,13 @@ impl Unifier {
|
|||
}
|
||||
TypeEnum::TObj { obj_id, params, .. } => {
|
||||
let name = obj_to_name(obj_id.0);
|
||||
if !params.is_empty() {
|
||||
let params = params
|
||||
if params.is_empty() {
|
||||
name
|
||||
} else {
|
||||
let mut params = params
|
||||
.iter()
|
||||
.map(|(_, v)| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
|
||||
// sort to preserve order
|
||||
let mut params = params.sorted();
|
||||
format!("{}[{}]", name, params.join(", "))
|
||||
} else {
|
||||
name
|
||||
}
|
||||
}
|
||||
TypeEnum::TCall { .. } => "call".to_owned(),
|
||||
|
@ -937,20 +1172,20 @@ impl Unifier {
|
|||
})
|
||||
.join(", ");
|
||||
let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes);
|
||||
format!("fn[[{}], {}]", params, ret)
|
||||
format!("fn[[{params}], {ret}]")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unifies `a` and `b` together, and set the value to the value of `b`.
|
||||
fn set_a_to_b(&mut self, a: Type, b: Type) {
|
||||
// unify a and b together, and set the value to b's value.
|
||||
let table = &mut self.unification_table;
|
||||
let ty_b = table.probe_value(b).clone();
|
||||
table.unify(a, b);
|
||||
table.set_value(a, ty_b)
|
||||
table.set_value(a, ty_b);
|
||||
}
|
||||
|
||||
fn incompatible_types(&mut self, a: Type, b: Type) -> Result<(), TypeError> {
|
||||
fn incompatible_types(a: Type, b: Type) -> Result<(), TypeError> {
|
||||
Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None))
|
||||
}
|
||||
|
||||
|
@ -960,7 +1195,7 @@ impl Unifier {
|
|||
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
|
||||
let mut instantiated = true;
|
||||
let mut vars = Vec::new();
|
||||
for (k, v) in fun.vars.iter() {
|
||||
for (k, v) in &fun.vars {
|
||||
if let TypeEnum::TVar { id, name, loc, range, .. } =
|
||||
self.unification_table.probe_value(*v).as_ref()
|
||||
{
|
||||
|
@ -979,7 +1214,7 @@ impl Unifier {
|
|||
let mapping = vars
|
||||
.into_iter()
|
||||
.map(|(k, range, name, loc)| {
|
||||
(k, self.get_fresh_var_with_range(range.as_ref(), name, loc).0)
|
||||
(k, self.get_fresh_var_with_range(range.as_ref(), name, loc).ty)
|
||||
})
|
||||
.collect();
|
||||
self.subst(ty, &mapping).unwrap_or(ty)
|
||||
|
@ -1003,7 +1238,7 @@ impl Unifier {
|
|||
let cached = cache.get_mut(&a);
|
||||
if let Some(cached) = cached {
|
||||
if cached.is_none() {
|
||||
*cached = Some(self.get_fresh_var(None, None).0);
|
||||
*cached = Some(self.get_fresh_var(None, None).ty);
|
||||
}
|
||||
return *cached;
|
||||
}
|
||||
|
@ -1014,8 +1249,8 @@ impl Unifier {
|
|||
// variables, i.e. things like TRecord, TCall should not occur, and we
|
||||
// should be safe to not implement the substitution for those variants.
|
||||
match &*ty {
|
||||
TypeEnum::TRigidVar { .. } => None,
|
||||
TypeEnum::TVar { id, .. } => mapping.get(id).cloned(),
|
||||
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None,
|
||||
TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
|
||||
TypeEnum::TTuple { ty } => {
|
||||
let mut new_ty = Cow::from(ty);
|
||||
for (i, t) in ty.iter().enumerate() {
|
||||
|
@ -1077,14 +1312,14 @@ impl Unifier {
|
|||
}
|
||||
if new_params.is_some() || new_ret.is_some() || matches!(new_args, Cow::Owned(..)) {
|
||||
let params = new_params.unwrap_or_else(|| params.clone());
|
||||
let ret = new_ret.unwrap_or_else(|| *ret);
|
||||
let ret = new_ret.unwrap_or(*ret);
|
||||
let args = new_args.into_owned();
|
||||
Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, vars: params })))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
TypeEnum::TCall(_) => {
|
||||
unreachable!("{} not expected", ty.get_type_name())
|
||||
}
|
||||
}
|
||||
|
@ -1092,15 +1327,15 @@ impl Unifier {
|
|||
|
||||
fn subst_map<K>(
|
||||
&mut self,
|
||||
map: &Mapping<K>,
|
||||
map: &IndexMapping<K>,
|
||||
mapping: &VarMap,
|
||||
cache: &mut HashMap<Type, Option<Type>>,
|
||||
) -> Option<Mapping<K>>
|
||||
) -> Option<IndexMapping<K>>
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
K: std::hash::Hash + Eq + Clone,
|
||||
{
|
||||
let mut map2 = None;
|
||||
for (k, v) in map.iter() {
|
||||
for (k, v) in map {
|
||||
if let Some(v1) = self.subst_impl(*v, mapping, cache) {
|
||||
if map2.is_none() {
|
||||
map2 = Some(map.clone());
|
||||
|
@ -1118,10 +1353,10 @@ impl Unifier {
|
|||
cache: &mut HashMap<Type, Option<Type>>,
|
||||
) -> Option<Mapping<K, (Type, bool)>>
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
K: std::hash::Hash + Eq + Clone,
|
||||
{
|
||||
let mut map2 = None;
|
||||
for (k, (v, mutability)) in map.iter() {
|
||||
for (k, (v, mutability)) in map {
|
||||
if let Some(v1) = self.subst_impl(*v, mapping, cache) {
|
||||
if map2.is_none() {
|
||||
map2 = Some(map.clone());
|
||||
|
@ -1158,14 +1393,14 @@ impl Unifier {
|
|||
if range.is_empty() {
|
||||
Err(())
|
||||
} else {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
let id = self.generate_var_id();
|
||||
let ty = TVar {
|
||||
id,
|
||||
fields: fields.clone(),
|
||||
range,
|
||||
name: name2.or(*name),
|
||||
loc: loc2.or(*loc),
|
||||
is_const_generic: false,
|
||||
};
|
||||
Ok(Some(self.unification_table.new_key(ty.into())))
|
||||
}
|
||||
|
@ -1176,7 +1411,7 @@ impl Unifier {
|
|||
if range.is_empty() {
|
||||
Ok(Some(a))
|
||||
} else {
|
||||
for v in range.iter() {
|
||||
for v in range {
|
||||
let result = self.get_intersection(a, *v);
|
||||
if let Ok(result) = result {
|
||||
return Ok(result.or(Some(a)));
|
||||
|
@ -1192,7 +1427,7 @@ impl Unifier {
|
|||
.try_collect()?;
|
||||
if ty.iter().any(Option::is_some) {
|
||||
Ok(Some(self.add_ty(TTuple {
|
||||
ty: zip(ty.into_iter(), ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
|
||||
ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
|
||||
})))
|
||||
} else {
|
||||
Ok(None)
|
||||
|
@ -1218,7 +1453,7 @@ impl Unifier {
|
|||
if range.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
for t in range.iter() {
|
||||
for t in range {
|
||||
let result = self.get_intersection(*t, b);
|
||||
if let Ok(result) = result {
|
||||
return Ok(result);
|
||||
|
@ -1226,4 +1461,10 @@ impl Unifier {
|
|||
}
|
||||
Err(TypeError::new(TypeErrorKind::IncompatibleRange(b, range.to_vec()), None))
|
||||
}
|
||||
|
||||
/// Generate a new [`TypeVarId`] from [`Unifier::var_id_counter`]
|
||||
fn generate_var_id(&mut self) -> TypeVarId {
|
||||
self.var_id_counter += 1;
|
||||
TypeVarId(self.var_id_counter)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,14 +40,14 @@ impl Unifier {
|
|||
TypeEnum::TObj { obj_id: id1, params: params1, .. },
|
||||
TypeEnum::TObj { obj_id: id2, params: params2, .. },
|
||||
) => id1 == id2 && self.map_eq(params1, params2),
|
||||
// TCall and TFunc are not yet implemented
|
||||
// TLiteral, TCall and TFunc are not yet implemented
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
|
||||
fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
K: std::hash::Hash + Eq + Clone,
|
||||
{
|
||||
if map1.len() != map2.len() {
|
||||
return false;
|
||||
|
@ -62,7 +62,7 @@ impl Unifier {
|
|||
|
||||
fn map_eq2<K>(&mut self, map1: &Mapping<K, RecordField>, map2: &Mapping<K, RecordField>) -> bool
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
K: std::hash::Hash + Eq + Clone,
|
||||
{
|
||||
if map1.len() != map2.len() {
|
||||
return false;
|
||||
|
@ -91,7 +91,7 @@ impl TestEnvironment {
|
|||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
}),
|
||||
);
|
||||
type_mapping.insert(
|
||||
|
@ -99,7 +99,7 @@ impl TestEnvironment {
|
|||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
}),
|
||||
);
|
||||
type_mapping.insert(
|
||||
|
@ -107,16 +107,16 @@ impl TestEnvironment {
|
|||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(2),
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
}),
|
||||
);
|
||||
let (v0, id) = unifier.get_dummy_var();
|
||||
let tvar = unifier.get_dummy_var();
|
||||
type_mapping.insert(
|
||||
"Foo".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(),
|
||||
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(),
|
||||
fields: [("a".into(), (tvar.ty, true))].iter().cloned().collect::<HashMap<_, _>>(),
|
||||
params: into_var_map([tvar]),
|
||||
}),
|
||||
);
|
||||
|
||||
|
@ -139,7 +139,7 @@ impl TestEnvironment {
|
|||
match &typ[..end] {
|
||||
"tuple" => {
|
||||
let mut s = &typ[end..];
|
||||
assert!(&s[0..1] == "[");
|
||||
assert_eq!(&s[0..1], "[");
|
||||
let mut ty = Vec::new();
|
||||
while &s[0..1] != "]" {
|
||||
let result = self.internal_parse(&s[1..], mapping);
|
||||
|
@ -149,14 +149,14 @@ impl TestEnvironment {
|
|||
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
|
||||
}
|
||||
"list" => {
|
||||
assert!(&typ[end..end + 1] == "[");
|
||||
assert_eq!(&typ[end..end + 1], "[");
|
||||
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
|
||||
assert!(&s[0..1] == "]");
|
||||
assert_eq!(&s[0..1], "]");
|
||||
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
|
||||
}
|
||||
"Record" => {
|
||||
let mut s = &typ[end..];
|
||||
assert!(&s[0..1] == "[");
|
||||
assert_eq!(&s[0..1], "[");
|
||||
let mut fields = HashMap::new();
|
||||
while &s[0..1] != "]" {
|
||||
let eq = s.find('=').unwrap();
|
||||
|
@ -176,7 +176,7 @@ impl TestEnvironment {
|
|||
let te = self.unifier.get_ty(ty);
|
||||
if let TypeEnum::TObj { params, .. } = &*te.as_ref() {
|
||||
if !params.is_empty() {
|
||||
assert!(&s[0..1] == "[");
|
||||
assert_eq!(&s[0..1], "[");
|
||||
let mut p = Vec::new();
|
||||
while &s[0..1] != "]" {
|
||||
let result = self.internal_parse(&s[1..], mapping);
|
||||
|
@ -250,7 +250,7 @@ fn test_unify(
|
|||
let mut mapping = HashMap::new();
|
||||
for i in 1..=variable_count {
|
||||
let v = env.unifier.get_dummy_var();
|
||||
mapping.insert(format!("v{}", i), v.0);
|
||||
mapping.insert(format!("v{}", i), v.ty);
|
||||
}
|
||||
// unification may have side effect when we do type resolution, so freeze the types
|
||||
// before doing unification.
|
||||
|
@ -286,7 +286,7 @@ fn test_unify(
|
|||
("v1", "tuple[int]"),
|
||||
("v2", "tuple[float]"),
|
||||
],
|
||||
(("v1", "v2"), "Incompatible types: 0 and 1")
|
||||
(("v1", "v2"), "Incompatible types: tuple[0] and tuple[1]")
|
||||
; "tuple parameter mismatch"
|
||||
)]
|
||||
#[test_case(2,
|
||||
|
@ -315,7 +315,7 @@ fn test_invalid_unification(
|
|||
let mut mapping = HashMap::new();
|
||||
for i in 1..=variable_count {
|
||||
let v = env.unifier.get_dummy_var();
|
||||
mapping.insert(format!("v{}", i), v.0);
|
||||
mapping.insert(format!("v{}", i), v.ty);
|
||||
}
|
||||
// unification may have side effect when we do type resolution, so freeze the types
|
||||
// before doing unification.
|
||||
|
@ -339,23 +339,17 @@ fn test_recursive_subst() {
|
|||
let int = *env.type_mapping.get("int").unwrap();
|
||||
let foo_id = *env.type_mapping.get("Foo").unwrap();
|
||||
let foo_ty = env.unifier.get_ty(foo_id);
|
||||
let mapping: HashMap<_, _>;
|
||||
with_fields(&mut env.unifier, foo_id, |_unifier, fields| {
|
||||
fields.insert("rec".into(), (foo_id, true));
|
||||
});
|
||||
if let TypeEnum::TObj { params, .. } = &*foo_ty {
|
||||
mapping = params.iter().map(|(id, _)| (*id, int)).collect();
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
let TypeEnum::TObj { params, .. } = &*foo_ty else { unreachable!() };
|
||||
let mapping = params.iter().map(|(id, _)| (*id, int)).collect();
|
||||
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
|
||||
let instantiated_ty = env.unifier.get_ty(instantiated);
|
||||
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
|
||||
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
|
||||
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { unreachable!() };
|
||||
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
|
||||
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -365,7 +359,7 @@ fn test_virtual() {
|
|||
let fun = env.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![],
|
||||
ret: int,
|
||||
vars: HashMap::new(),
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
let bar = env.unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(5),
|
||||
|
@ -373,10 +367,10 @@ fn test_virtual() {
|
|||
.iter()
|
||||
.cloned()
|
||||
.collect::<HashMap<StrRef, _>>(),
|
||||
params: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let v0 = env.unifier.get_dummy_var().0;
|
||||
let v1 = env.unifier.get_dummy_var().0;
|
||||
let v0 = env.unifier.get_dummy_var().ty;
|
||||
let v1 = env.unifier.get_dummy_var().ty;
|
||||
|
||||
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
|
||||
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
|
||||
|
@ -409,12 +403,12 @@ fn test_typevar_range() {
|
|||
|
||||
// unification between v and int
|
||||
// where v in (int, bool)
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
|
||||
env.unifier.unify(int, v).unwrap();
|
||||
|
||||
// unification between v and list[int]
|
||||
// where v in (int, bool)
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
|
||||
assert_eq!(
|
||||
env.unify(int_list, v),
|
||||
Err("Expected any one of these types: 0, 2, but got list[0]".to_string())
|
||||
|
@ -422,25 +416,25 @@ fn test_typevar_range() {
|
|||
|
||||
// unification between v and float
|
||||
// where v in (int, bool)
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
|
||||
assert_eq!(
|
||||
env.unify(float, v),
|
||||
Err("Expected any one of these types: 0, 2, but got 1".to_string())
|
||||
);
|
||||
|
||||
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
|
||||
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
|
||||
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 });
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0;
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
|
||||
// unification between v and int
|
||||
// where v in (int, list[v1]), v1 in (int, bool)
|
||||
env.unifier.unify(int, v).unwrap();
|
||||
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0;
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
|
||||
// unification between v and list[int]
|
||||
// where v in (int, list[v1]), v1 in (int, bool)
|
||||
env.unifier.unify(int_list, v).unwrap();
|
||||
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0;
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
|
||||
// unification between v and list[float]
|
||||
// where v in (int, list[v1]), v1 in (int, bool)
|
||||
assert_eq!(
|
||||
|
@ -448,43 +442,45 @@ fn test_typevar_range() {
|
|||
Err("Expected any one of these types: 0, list[typevar5], but got list[1]\n\nNotes:\n typevar5 ∈ {0, 2}".to_string())
|
||||
);
|
||||
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
|
||||
env.unifier.unify(a, b).unwrap();
|
||||
env.unifier.unify(a, float).unwrap();
|
||||
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
|
||||
env.unifier.unify(a, b).unwrap();
|
||||
assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into()));
|
||||
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
|
||||
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0;
|
||||
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
|
||||
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).0;
|
||||
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).ty;
|
||||
env.unifier.unify(a_list, b_list).unwrap();
|
||||
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float });
|
||||
env.unifier.unify(a_list, float_list).unwrap();
|
||||
// previous unifications should not affect a and b
|
||||
env.unifier.unify(a, int).unwrap();
|
||||
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
|
||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
|
||||
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||
env.unifier.unify(a_list, b_list).unwrap();
|
||||
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int });
|
||||
assert_eq!(
|
||||
env.unify(a_list, int_list),
|
||||
Err("Expected any one of these types: 1, but got 0".into())
|
||||
Err("Incompatible types: list[typevar22] and list[0]\
|
||||
\n\nNotes:\n typevar22 ∈ {1}"
|
||||
.into())
|
||||
);
|
||||
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
|
||||
let b = env.unifier.get_dummy_var().0;
|
||||
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
|
||||
let b = env.unifier.get_dummy_var().ty;
|
||||
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0;
|
||||
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
|
||||
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||
env.unifier.unify(a_list, b_list).unwrap();
|
||||
assert_eq!(
|
||||
|
@ -496,9 +492,9 @@ fn test_typevar_range() {
|
|||
#[test]
|
||||
fn test_rigid_var() {
|
||||
let mut env = TestEnvironment::new();
|
||||
let a = env.unifier.get_fresh_rigid_var(None, None).0;
|
||||
let b = env.unifier.get_fresh_rigid_var(None, None).0;
|
||||
let x = env.unifier.get_dummy_var().0;
|
||||
let a = env.unifier.get_fresh_rigid_var(None, None).ty;
|
||||
let b = env.unifier.get_fresh_rigid_var(None, None).ty;
|
||||
let x = env.unifier.get_dummy_var().ty;
|
||||
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x });
|
||||
let int = env.parse("int", &HashMap::new());
|
||||
|
@ -506,7 +502,10 @@ fn test_rigid_var() {
|
|||
|
||||
assert_eq!(env.unify(a, b), Err("Incompatible types: typevar3 and typevar2".to_string()));
|
||||
env.unifier.unify(list_a, list_x).unwrap();
|
||||
assert_eq!(env.unify(list_x, list_int), Err("Incompatible types: 0 and typevar2".to_string()));
|
||||
assert_eq!(
|
||||
env.unify(list_x, list_int),
|
||||
Err("Incompatible types: list[typevar2] and list[0]".to_string())
|
||||
);
|
||||
|
||||
env.unifier.replace_rigid_var(a, int);
|
||||
env.unifier.unify(list_x, list_int).unwrap();
|
||||
|
@ -523,13 +522,13 @@ fn test_instantiation() {
|
|||
let obj_map: HashMap<_, _> =
|
||||
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
|
||||
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
|
||||
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
|
||||
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).0;
|
||||
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).0;
|
||||
let t = env.unifier.get_dummy_var().0;
|
||||
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty;
|
||||
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty;
|
||||
let t = env.unifier.get_dummy_var().ty;
|
||||
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
|
||||
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).0;
|
||||
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty;
|
||||
// t = TypeVar('t')
|
||||
// v = TypeVar('v', int, bool)
|
||||
// v1 = TypeVar('v1', 'list[v]', int)
|
||||
|
|
|
@ -16,21 +16,10 @@ pub struct UnificationTable<V> {
|
|||
|
||||
#[derive(Clone, Debug)]
|
||||
enum Action<V> {
|
||||
Parent {
|
||||
key: usize,
|
||||
original_parent: usize,
|
||||
},
|
||||
Value {
|
||||
key: usize,
|
||||
original_value: Option<V>,
|
||||
},
|
||||
Rank {
|
||||
key: usize,
|
||||
original_rank: u32,
|
||||
},
|
||||
Marker {
|
||||
generation: u32,
|
||||
}
|
||||
Parent { key: usize, original_parent: usize },
|
||||
Value { key: usize, original_value: Option<V> },
|
||||
Rank { key: usize, original_rank: u32 },
|
||||
Marker { generation: u32 },
|
||||
}
|
||||
|
||||
impl<V> Default for UnificationTable<V> {
|
||||
|
@ -41,7 +30,13 @@ impl<V> Default for UnificationTable<V> {
|
|||
|
||||
impl<V> UnificationTable<V> {
|
||||
pub fn new() -> UnificationTable<V> {
|
||||
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 }
|
||||
UnificationTable {
|
||||
parents: Vec::new(),
|
||||
ranks: Vec::new(),
|
||||
values: Vec::new(),
|
||||
log: Vec::new(),
|
||||
generation: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_key(&mut self, v: V) -> UnificationKey {
|
||||
|
@ -125,7 +120,10 @@ impl<V> UnificationTable<V> {
|
|||
pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) {
|
||||
let (log_len, generation) = snapshot;
|
||||
assert!(self.log.len() >= log_len, "snapshot restoration error");
|
||||
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error");
|
||||
assert!(
|
||||
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
|
||||
"snapshot restoration error"
|
||||
);
|
||||
for action in self.log.drain(log_len - 1..).rev() {
|
||||
match action {
|
||||
Action::Parent { key, original_parent } => {
|
||||
|
@ -145,7 +143,10 @@ impl<V> UnificationTable<V> {
|
|||
pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
|
||||
let (log_len, generation) = snapshot;
|
||||
assert!(self.log.len() >= log_len, "snapshot discard error");
|
||||
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot discard error");
|
||||
assert!(
|
||||
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
|
||||
"snapshot discard error"
|
||||
);
|
||||
self.log.clear();
|
||||
}
|
||||
}
|
||||
|
@ -159,11 +160,23 @@ where
|
|||
.enumerate()
|
||||
.map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None })
|
||||
.collect();
|
||||
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 }
|
||||
UnificationTable {
|
||||
parents: self.parents.clone(),
|
||||
ranks: self.ranks.clone(),
|
||||
values,
|
||||
log: Vec::new(),
|
||||
generation: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
|
||||
let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect();
|
||||
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 }
|
||||
UnificationTable {
|
||||
parents: table.parents.clone(),
|
||||
ranks: table.ranks.clone(),
|
||||
values,
|
||||
log: Vec::new(),
|
||||
generation: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
[package]
|
||||
name = "nac3ld"
|
||||
version = "0.1.0"
|
||||
authors = ["M-Labs"]
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
byteorder = { version = "1.5", default-features = false }
|
|
@ -0,0 +1,509 @@
|
|||
#![allow(non_camel_case_types, non_upper_case_globals)]
|
||||
|
||||
use std::mem;
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
pub const DW_EH_PE_omit: u8 = 0xFF;
|
||||
pub const DW_EH_PE_absptr: u8 = 0x00;
|
||||
|
||||
pub const DW_EH_PE_uleb128: u8 = 0x01;
|
||||
pub const DW_EH_PE_udata2: u8 = 0x02;
|
||||
pub const DW_EH_PE_udata4: u8 = 0x03;
|
||||
pub const DW_EH_PE_udata8: u8 = 0x04;
|
||||
pub const DW_EH_PE_sleb128: u8 = 0x09;
|
||||
pub const DW_EH_PE_sdata2: u8 = 0x0A;
|
||||
pub const DW_EH_PE_sdata4: u8 = 0x0B;
|
||||
pub const DW_EH_PE_sdata8: u8 = 0x0C;
|
||||
|
||||
pub const DW_EH_PE_pcrel: u8 = 0x10;
|
||||
pub const DW_EH_PE_textrel: u8 = 0x20;
|
||||
pub const DW_EH_PE_datarel: u8 = 0x30;
|
||||
pub const DW_EH_PE_funcrel: u8 = 0x40;
|
||||
pub const DW_EH_PE_aligned: u8 = 0x50;
|
||||
|
||||
pub const DW_EH_PE_indirect: u8 = 0x80;
|
||||
|
||||
pub struct DwarfReader<'a> {
|
||||
pub slice: &'a [u8],
|
||||
pub virt_addr: u32,
|
||||
base_slice: &'a [u8],
|
||||
base_virt_addr: u32,
|
||||
}
|
||||
|
||||
impl<'a> DwarfReader<'a> {
|
||||
pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader {
|
||||
DwarfReader { slice, virt_addr, base_slice: slice, base_virt_addr: virt_addr }
|
||||
}
|
||||
|
||||
/// Creates a new instance from another instance of [DwarfReader], optionally removing any
|
||||
/// offsets previously applied to the other instance.
|
||||
pub fn from_reader(other: &DwarfReader<'a>, reset_offset: bool) -> DwarfReader<'a> {
|
||||
if reset_offset {
|
||||
DwarfReader::new(other.base_slice, other.base_virt_addr)
|
||||
} else {
|
||||
DwarfReader::new(other.slice, other.virt_addr)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn offset(&mut self, offset: u32) {
|
||||
self.slice = &self.slice[offset as usize..];
|
||||
self.virt_addr = self.virt_addr.wrapping_add(offset);
|
||||
}
|
||||
|
||||
/// ULEB128 and SLEB128 encodings are defined in Section 7.6 - "Variable Length Data" of the
|
||||
/// [DWARF-4 Manual](https://dwarfstd.org/doc/DWARF4.pdf).
|
||||
pub fn read_uleb128(&mut self) -> u64 {
|
||||
let mut shift: usize = 0;
|
||||
let mut result: u64 = 0;
|
||||
let mut byte: u8;
|
||||
loop {
|
||||
byte = self.read_u8();
|
||||
result |= u64::from(byte & 0x7F) << shift;
|
||||
shift += 7;
|
||||
if byte & 0x80 == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn read_sleb128(&mut self) -> i64 {
|
||||
let mut shift: u32 = 0;
|
||||
let mut result: u64 = 0;
|
||||
let mut byte: u8;
|
||||
loop {
|
||||
byte = self.read_u8();
|
||||
result |= u64::from(byte & 0x7F) << shift;
|
||||
shift += 7;
|
||||
if byte & 0x80 == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// sign-extend
|
||||
if shift < u64::BITS && (byte & 0x40) != 0 {
|
||||
result |= (!0u64) << shift;
|
||||
}
|
||||
result as i64
|
||||
}
|
||||
|
||||
pub fn read_u8(&mut self) -> u8 {
|
||||
let val = self.slice[0];
|
||||
self.slice = &self.slice[1..];
|
||||
val
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_read_fn {
|
||||
( $($type: ty, $byteorder_fn: ident);* ) => {
|
||||
impl<'a> DwarfReader<'a> {
|
||||
$(
|
||||
pub fn $byteorder_fn(&mut self) -> $type {
|
||||
let val = LittleEndian::$byteorder_fn(self.slice);
|
||||
self.slice = &self.slice[mem::size_of::<$type>()..];
|
||||
val
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_read_fn!(
|
||||
u16, read_u16;
|
||||
u32, read_u32;
|
||||
u64, read_u64;
|
||||
i16, read_i16;
|
||||
i32, read_i32;
|
||||
i64, read_i64
|
||||
);
|
||||
|
||||
pub struct DwarfWriter<'a> {
|
||||
pub slice: &'a mut [u8],
|
||||
pub offset: usize,
|
||||
}
|
||||
|
||||
impl<'a> DwarfWriter<'a> {
|
||||
pub fn new(slice: &mut [u8]) -> DwarfWriter {
|
||||
DwarfWriter { slice, offset: 0 }
|
||||
}
|
||||
|
||||
pub fn write_u8(&mut self, data: u8) {
|
||||
self.slice[self.offset] = data;
|
||||
self.offset += 1;
|
||||
}
|
||||
|
||||
pub fn write_u32(&mut self, data: u32) {
|
||||
LittleEndian::write_u32(&mut self.slice[self.offset..], data);
|
||||
self.offset += 4;
|
||||
}
|
||||
}
|
||||
|
||||
fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result<usize, ()> {
|
||||
if encoding == DW_EH_PE_omit {
|
||||
return Err(());
|
||||
}
|
||||
|
||||
// DW_EH_PE_aligned implies it's an absolute pointer value
|
||||
// However, we are linking library for 32-bits architecture
|
||||
// The size of variable should be 4 bytes instead
|
||||
if encoding == DW_EH_PE_aligned {
|
||||
let shifted_virt_addr = round_up(reader.virt_addr as usize, mem::size_of::<u32>())?;
|
||||
let addr_inc = shifted_virt_addr - reader.virt_addr as usize;
|
||||
|
||||
reader.slice = &reader.slice[addr_inc..];
|
||||
reader.virt_addr = shifted_virt_addr as u32;
|
||||
return Ok(reader.read_u32() as usize);
|
||||
}
|
||||
|
||||
match encoding & 0x0F {
|
||||
DW_EH_PE_absptr | DW_EH_PE_udata4 => Ok(reader.read_u32() as usize),
|
||||
DW_EH_PE_uleb128 => Ok(reader.read_uleb128() as usize),
|
||||
DW_EH_PE_udata2 => Ok(reader.read_u16() as usize),
|
||||
DW_EH_PE_udata8 => Ok(reader.read_u64() as usize),
|
||||
DW_EH_PE_sleb128 => Ok(reader.read_sleb128() as usize),
|
||||
DW_EH_PE_sdata2 => Ok(reader.read_i16() as usize),
|
||||
DW_EH_PE_sdata4 => Ok(reader.read_i32() as usize),
|
||||
DW_EH_PE_sdata8 => Ok(reader.read_i64() as usize),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_encoded_pointer_with_pc(reader: &mut DwarfReader, encoding: u8) -> Result<usize, ()> {
|
||||
let entry_virt_addr = reader.virt_addr;
|
||||
let mut result = read_encoded_pointer(reader, encoding)?;
|
||||
|
||||
// DW_EH_PE_aligned implies it's an absolute pointer value
|
||||
if encoding == DW_EH_PE_aligned {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
result = match encoding & 0x70 {
|
||||
DW_EH_PE_pcrel => result.wrapping_add(entry_virt_addr as usize),
|
||||
|
||||
// .eh_frame normally would not have these kinds of relocations
|
||||
// These would not be supported by a dedicated linker relocation schemes for RISC-V
|
||||
DW_EH_PE_textrel | DW_EH_PE_datarel | DW_EH_PE_funcrel | DW_EH_PE_aligned => {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
// Other values should be impossible
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
if encoding & DW_EH_PE_indirect != 0 {
|
||||
// There should not be a need for indirect addressing, as assembly code from
|
||||
// the dynamic library should not be freely moved relative to the EH frame.
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn round_up(unrounded: usize, align: usize) -> Result<usize, ()> {
|
||||
if align.is_power_of_two() {
|
||||
Ok((unrounded + align - 1) & !(align - 1))
|
||||
} else {
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Minimalistic structure to store everything needed for parsing FDEs to synthesize `.eh_frame_hdr`
|
||||
/// section.
|
||||
///
|
||||
/// Refer to [The Linux Standard Base Core Specification, Generic Part](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html)
|
||||
/// for more information.
|
||||
pub struct EH_Frame<'a> {
|
||||
reader: DwarfReader<'a>,
|
||||
}
|
||||
|
||||
impl<'a> EH_Frame<'a> {
|
||||
/// Creates an [EH_Frame] using the bytes in the `.eh_frame` section and its address in the ELF
|
||||
/// file.
|
||||
pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> EH_Frame {
|
||||
EH_Frame { reader: DwarfReader::new(eh_frame_slice, eh_frame_addr) }
|
||||
}
|
||||
|
||||
/// Returns an [Iterator] over all Call Frame Information (CFI) records.
|
||||
pub fn cfi_records(&self) -> CFI_Records<'a> {
|
||||
let reader = DwarfReader::from_reader(&self.reader, true);
|
||||
let len = reader.slice.len();
|
||||
|
||||
CFI_Records { reader, available: len }
|
||||
}
|
||||
}
|
||||
|
||||
/// A single Call Frame Information (CFI) record.
|
||||
///
|
||||
/// From the [specification](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html):
|
||||
///
|
||||
/// > Each CFI record contains a Common Information Entry (CIE) record followed by 1 or more Frame
|
||||
/// Description Entry (FDE) records.
|
||||
pub struct CFI_Record<'a> {
|
||||
// It refers to the augmentation data that corresponds to 'R' in the augmentation string
|
||||
fde_pointer_encoding: u8,
|
||||
fde_reader: DwarfReader<'a>,
|
||||
}
|
||||
|
||||
impl<'a> CFI_Record<'a> {
|
||||
pub fn from_reader(cie_reader: &mut DwarfReader<'a>) -> Result<CFI_Record<'a>, ()> {
|
||||
let length = cie_reader.read_u32();
|
||||
let fde_reader = match length {
|
||||
// eh_frame with 0 lengths means the CIE is terminated
|
||||
0 => panic!("Cannot create an EH_Frame from a termination CIE"),
|
||||
|
||||
// length == u32::MAX means that the length is only representable with 64 bits,
|
||||
// which does not make sense in a system with 32-bit address.
|
||||
0xFFFF_FFFF => unimplemented!(),
|
||||
|
||||
_ => {
|
||||
let mut fde_reader = DwarfReader::from_reader(cie_reader, false);
|
||||
fde_reader.offset(length);
|
||||
fde_reader
|
||||
}
|
||||
};
|
||||
|
||||
// Routine check on the .eh_frame well-formness, in terms of CIE ID & Version args.
|
||||
let cie_ptr = cie_reader.read_u32();
|
||||
assert_eq!(cie_ptr, 0);
|
||||
assert_eq!(cie_reader.read_u8(), 1);
|
||||
|
||||
// Parse augmentation string
|
||||
// The first character must be 'z', there is no way to proceed otherwise
|
||||
assert_eq!(cie_reader.read_u8(), b'z');
|
||||
|
||||
// Establish a pointer that skips ahead of the string
|
||||
// Skip code/data alignment factors & return address register along the way as well
|
||||
// We only tackle the case where 'z' and 'R' are part of the augmentation string, otherwise
|
||||
// we cannot get the addresses to make .eh_frame_hdr
|
||||
let mut aug_data_reader = DwarfReader::from_reader(cie_reader, false);
|
||||
let mut aug_str_len = 0;
|
||||
loop {
|
||||
if aug_data_reader.read_u8() == b'\0' {
|
||||
break;
|
||||
}
|
||||
aug_str_len += 1;
|
||||
}
|
||||
if aug_str_len == 0 {
|
||||
unimplemented!();
|
||||
}
|
||||
aug_data_reader.read_uleb128(); // Code alignment factor
|
||||
aug_data_reader.read_sleb128(); // Data alignment factor
|
||||
aug_data_reader.read_uleb128(); // Return address register
|
||||
aug_data_reader.read_uleb128(); // Augmentation data length
|
||||
let mut fde_pointer_encoding = DW_EH_PE_omit;
|
||||
for _ in 0..aug_str_len {
|
||||
match cie_reader.read_u8() {
|
||||
b'L' => {
|
||||
aug_data_reader.read_u8();
|
||||
}
|
||||
|
||||
b'P' => {
|
||||
let encoding = aug_data_reader.read_u8();
|
||||
read_encoded_pointer(&mut aug_data_reader, encoding)?;
|
||||
}
|
||||
|
||||
b'R' => {
|
||||
fde_pointer_encoding = aug_data_reader.read_u8();
|
||||
}
|
||||
|
||||
// Other characters are not supported
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
assert_ne!(fde_pointer_encoding, DW_EH_PE_omit);
|
||||
|
||||
Ok(CFI_Record { fde_pointer_encoding, fde_reader })
|
||||
}
|
||||
|
||||
/// Returns a [DwarfReader] initialized to the first Frame Description Entry (FDE) of this CFI
|
||||
/// record.
|
||||
pub fn get_fde_reader(&self) -> DwarfReader<'a> {
|
||||
DwarfReader::from_reader(&self.fde_reader, true)
|
||||
}
|
||||
|
||||
/// Returns an [Iterator] over all Frame Description Entries (FDEs).
|
||||
pub fn fde_records(&self) -> FDE_Records<'a> {
|
||||
let reader = self.get_fde_reader();
|
||||
let len = reader.slice.len();
|
||||
|
||||
FDE_Records { pointer_encoding: self.fde_pointer_encoding, reader, available: len }
|
||||
}
|
||||
}
|
||||
|
||||
/// [Iterator] over Call Frame Information (CFI) records in an
|
||||
/// [Exception Handling (EH) frame][EH_Frame].
|
||||
pub struct CFI_Records<'a> {
|
||||
reader: DwarfReader<'a>,
|
||||
available: usize,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for CFI_Records<'a> {
|
||||
type Item = CFI_Record<'a>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
loop {
|
||||
if self.available == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut this_reader = DwarfReader::from_reader(&self.reader, false);
|
||||
|
||||
// Remove the length of the header and the content from the counter
|
||||
let length = self.reader.read_u32();
|
||||
let length = match length {
|
||||
// eh_frame with 0-length means the CIE is terminated
|
||||
0 => return None,
|
||||
0xFFFF_FFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
|
||||
other => other,
|
||||
} as usize;
|
||||
|
||||
// Remove the length of the header and the content from the counter
|
||||
self.available -= length + mem::size_of::<u32>();
|
||||
let mut next_reader = DwarfReader::from_reader(&self.reader, false);
|
||||
next_reader.offset(length as u32);
|
||||
|
||||
let cie_ptr = self.reader.read_u32();
|
||||
|
||||
self.reader = next_reader;
|
||||
|
||||
// Skip this record if it is a FDE
|
||||
if cie_ptr == 0 {
|
||||
// Rewind back to the start of the CFI Record
|
||||
return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// [Iterator] over Frame Description Entries (FDEs) in an
|
||||
/// [Exception Handling (EH) frame][EH_Frame].
|
||||
pub struct FDE_Records<'a> {
|
||||
pointer_encoding: u8,
|
||||
reader: DwarfReader<'a>,
|
||||
available: usize,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for FDE_Records<'a> {
|
||||
type Item = (u32, u32);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
// Parse each FDE to obtain the starting address that the FDE applies to
|
||||
// Send the FDE offset and the mentioned address to a callback that write up the
|
||||
// .eh_frame_hdr section
|
||||
|
||||
if self.available == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Remove the length of the header and the content from the counter
|
||||
let length = match self.reader.read_u32() {
|
||||
// eh_frame with 0-length means the CIE is terminated
|
||||
0 => return None,
|
||||
0xFFFF_FFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
|
||||
other => other,
|
||||
} as usize;
|
||||
|
||||
// Remove the length of the header and the content from the counter
|
||||
self.available -= length + mem::size_of::<u32>();
|
||||
let mut next_fde_reader = DwarfReader::from_reader(&self.reader, false);
|
||||
next_fde_reader.offset(length as u32);
|
||||
|
||||
let cie_ptr = self.reader.read_u32();
|
||||
let next_val = if cie_ptr != 0 {
|
||||
let pc_begin = read_encoded_pointer_with_pc(&mut self.reader, self.pointer_encoding)
|
||||
.expect("Failed to read PC Begin");
|
||||
Some((pc_begin as u32, self.reader.virt_addr))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
self.reader = next_fde_reader;
|
||||
|
||||
next_val
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EH_Frame_Hdr<'a> {
|
||||
fde_writer: DwarfWriter<'a>,
|
||||
eh_frame_hdr_addr: u32,
|
||||
fdes: Vec<(u32, u32)>,
|
||||
}
|
||||
|
||||
impl<'a> EH_Frame_Hdr<'a> {
|
||||
/// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory.
|
||||
///
|
||||
/// Load address is not known at this point.
|
||||
pub fn new(
|
||||
eh_frame_hdr_slice: &mut [u8],
|
||||
eh_frame_hdr_addr: u32,
|
||||
eh_frame_addr: u32,
|
||||
) -> EH_Frame_Hdr {
|
||||
let mut writer = DwarfWriter::new(eh_frame_hdr_slice);
|
||||
|
||||
writer.write_u8(1); // version
|
||||
writer.write_u8(0x1B); // eh_frame_ptr_enc - PC-relative 4-byte signed value
|
||||
writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value
|
||||
writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value
|
||||
|
||||
let eh_frame_offset = eh_frame_addr.wrapping_sub(
|
||||
eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4),
|
||||
);
|
||||
writer.write_u32(eh_frame_offset); // eh_frame_ptr
|
||||
writer.write_u32(0); // `fde_count`, will be written in finalize_fde
|
||||
|
||||
EH_Frame_Hdr { fde_writer: writer, eh_frame_hdr_addr, fdes: Vec::new() }
|
||||
}
|
||||
|
||||
/// The offset of the `fde_count` value relative to the start of the `.eh_frame_hdr` section in
|
||||
/// bytes.
|
||||
fn fde_count_offset() -> usize {
|
||||
8
|
||||
}
|
||||
|
||||
pub fn add_fde(&mut self, init_loc: u32, addr: u32) {
|
||||
self.fdes.push((
|
||||
init_loc.wrapping_sub(self.eh_frame_hdr_addr),
|
||||
addr.wrapping_sub(self.eh_frame_hdr_addr),
|
||||
));
|
||||
}
|
||||
|
||||
pub fn finalize_fde(mut self) {
|
||||
self.fdes
|
||||
.sort_by(|(left_init_loc, _), (right_init_loc, _)| left_init_loc.cmp(right_init_loc));
|
||||
for (init_loc, addr) in &self.fdes {
|
||||
self.fde_writer.write_u32(*init_loc);
|
||||
self.fde_writer.write_u32(*addr);
|
||||
}
|
||||
LittleEndian::write_u32(
|
||||
&mut self.fde_writer.slice[Self::fde_count_offset()..],
|
||||
self.fdes.len() as u32,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn size_from_eh_frame(eh_frame: &[u8]) -> usize {
|
||||
// The virtual address of the EH frame does not matter in this case
|
||||
// Calculation of size does not involve modifying any headers
|
||||
let mut reader = DwarfReader::new(eh_frame, 0);
|
||||
let mut fde_count = 0;
|
||||
while !reader.slice.is_empty() {
|
||||
// The original length field should be able to hold the entire value.
|
||||
// The device memory space is limited to 32-bits addresses anyway.
|
||||
let entry_length = reader.read_u32();
|
||||
if entry_length == 0 || entry_length == 0xFFFF_FFFF {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
// This slot stores the CIE ID (for CIE)/CIE Pointer (for FDE).
|
||||
// This value must be non-zero for FDEs.
|
||||
let cie_ptr = reader.read_u32();
|
||||
if cie_ptr != 0 {
|
||||
fde_count += 1;
|
||||
}
|
||||
|
||||
reader.offset(entry_length - mem::size_of::<u32>() as u32);
|
||||
}
|
||||
|
||||
12 + fde_count * 8
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -5,20 +5,20 @@ description = "Parser for python code."
|
|||
authors = [ "RustPython Team", "M-Labs" ]
|
||||
build = "build.rs"
|
||||
license = "MIT"
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
|
||||
[build-dependencies]
|
||||
lalrpop = "0.19.6"
|
||||
lalrpop = "0.20"
|
||||
|
||||
[dependencies]
|
||||
nac3ast = { path = "../nac3ast" }
|
||||
lalrpop-util = "0.19.6"
|
||||
log = "0.4.1"
|
||||
lalrpop-util = "0.20"
|
||||
log = "0.4"
|
||||
unic-emoji-char = "0.9"
|
||||
unic-ucd-ident = "0.9"
|
||||
unicode_names2 = "0.4"
|
||||
phf = { version = "0.9", features = ["macros"] }
|
||||
ahash = "0.7.2"
|
||||
unicode_names2 = "1.2"
|
||||
phf = { version = "0.11", features = ["macros"] }
|
||||
ahash = "0.8"
|
||||
|
||||
[dev-dependencies]
|
||||
insta = "=1.11.0"
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
use lalrpop_util::ParseError;
|
||||
use nac3ast::*;
|
||||
use crate::ast::Ident;
|
||||
use crate::ast::Location;
|
||||
use crate::token::Tok;
|
||||
use crate::error::*;
|
||||
use crate::token::Tok;
|
||||
use lalrpop_util::ParseError;
|
||||
use nac3ast::*;
|
||||
|
||||
pub fn make_config_comment(
|
||||
com_loc: Location,
|
||||
stmt_loc: Location,
|
||||
nac3com_above: Vec<(Ident, Tok)>,
|
||||
nac3com_end: Option<Ident>
|
||||
nac3com_end: Option<Ident>,
|
||||
) -> Result<Vec<Ident>, ParseError<Location, Tok, LexicalError>> {
|
||||
if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() {
|
||||
return Err(ParseError::User {
|
||||
|
@ -17,24 +17,25 @@ pub fn make_config_comment(
|
|||
location: com_loc,
|
||||
error: LexicalErrorType::OtherError(
|
||||
format!(
|
||||
"config comment at top must have the same indentation with what it applies (comment at {}, statement at {})",
|
||||
com_loc,
|
||||
stmt_loc,
|
||||
"config comment at top must have the same indentation with what it applies (comment at {com_loc}, statement at {stmt_loc})",
|
||||
)
|
||||
)
|
||||
}
|
||||
})
|
||||
});
|
||||
};
|
||||
Ok(
|
||||
nac3com_above
|
||||
.into_iter()
|
||||
.map(|(com, _)| com)
|
||||
.chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter()))
|
||||
.collect()
|
||||
)
|
||||
Ok(nac3com_above
|
||||
.into_iter()
|
||||
.map(|(com, _)| com)
|
||||
.chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter()))
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, Tok)>, nac3com_end: Option<Ident>, com_above_loc: Location) -> Result<(), ParseError<Location, Tok, LexicalError>> {
|
||||
pub fn handle_small_stmt<U>(
|
||||
stmts: &mut [Stmt<U>],
|
||||
nac3com_above: Vec<(Ident, Tok)>,
|
||||
nac3com_end: Option<Ident>,
|
||||
com_above_loc: Location,
|
||||
) -> Result<(), ParseError<Location, Tok, LexicalError>> {
|
||||
if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() {
|
||||
return Err(ParseError::User {
|
||||
error: LexicalError {
|
||||
|
@ -47,17 +48,12 @@ pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, To
|
|||
)
|
||||
)
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
apply_config_comments(
|
||||
&mut stmts[0],
|
||||
nac3com_above
|
||||
.into_iter()
|
||||
.map(|(com, _)| com).collect()
|
||||
);
|
||||
apply_config_comments(&mut stmts[0], nac3com_above.into_iter().map(|(com, _)| com).collect());
|
||||
apply_config_comments(
|
||||
stmts.last_mut().unwrap(),
|
||||
nac3com_end.map_or_else(Vec::new, |com| vec![com])
|
||||
nac3com_end.map_or_else(Vec::new, |com| vec![com]),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -80,6 +76,8 @@ fn apply_config_comments<U>(stmt: &mut Stmt<U>, comments: Vec<Ident>) {
|
|||
| StmtKind::Nonlocal { config_comment, .. }
|
||||
| StmtKind::Assert { config_comment, .. } => config_comment.extend(comments),
|
||||
|
||||
_ => { unreachable!("only small statements should call this function") }
|
||||
_ => {
|
||||
unreachable!("only small statements should call this function")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ impl fmt::Display for LexicalErrorType {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
LexicalErrorType::StringError => write!(f, "Got unexpected string"),
|
||||
LexicalErrorType::FStringError(error) => write!(f, "Got error in f-string: {}", error),
|
||||
LexicalErrorType::FStringError(error) => write!(f, "Got error in f-string: {error}"),
|
||||
LexicalErrorType::UnicodeError => write!(f, "Got unexpected unicode"),
|
||||
LexicalErrorType::NestingError => write!(f, "Got unexpected nesting"),
|
||||
LexicalErrorType::IndentationError => {
|
||||
|
@ -59,13 +59,13 @@ impl fmt::Display for LexicalErrorType {
|
|||
write!(f, "positional argument follows keyword argument")
|
||||
}
|
||||
LexicalErrorType::UnrecognizedToken { tok } => {
|
||||
write!(f, "Got unexpected token {}", tok)
|
||||
write!(f, "Got unexpected token {tok}")
|
||||
}
|
||||
LexicalErrorType::LineContinuationError => {
|
||||
write!(f, "unexpected character after line continuation character")
|
||||
}
|
||||
LexicalErrorType::Eof => write!(f, "unexpected EOF while parsing"),
|
||||
LexicalErrorType::OtherError(msg) => write!(f, "{}", msg),
|
||||
LexicalErrorType::OtherError(msg) => write!(f, "{msg}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ impl fmt::Display for FStringErrorType {
|
|||
FStringErrorType::UnopenedRbrace => write!(f, "Unopened '}}'"),
|
||||
FStringErrorType::ExpectedRbrace => write!(f, "Expected '}}' after conversion flag."),
|
||||
FStringErrorType::InvalidExpression(error) => {
|
||||
write!(f, "Invalid expression: {}", error)
|
||||
write!(f, "Invalid expression: {error}")
|
||||
}
|
||||
FStringErrorType::InvalidConversionFlag => write!(f, "Invalid conversion flag"),
|
||||
FStringErrorType::EmptyExpression => write!(f, "Empty expression"),
|
||||
|
@ -144,36 +144,27 @@ pub enum ParseErrorType {
|
|||
impl From<LalrpopError<Location, Tok, LexicalError>> for ParseError {
|
||||
fn from(err: LalrpopError<Location, Tok, LexicalError>) -> Self {
|
||||
match err {
|
||||
// TODO: Are there cases where this isn't an EOF?
|
||||
LalrpopError::InvalidToken { location } => ParseError {
|
||||
error: ParseErrorType::Eof,
|
||||
location,
|
||||
},
|
||||
LalrpopError::ExtraToken { token } => ParseError {
|
||||
error: ParseErrorType::ExtraToken(token.1),
|
||||
location: token.0,
|
||||
},
|
||||
LalrpopError::User { error } => ParseError {
|
||||
error: ParseErrorType::Lexical(error.error),
|
||||
location: error.location,
|
||||
},
|
||||
LalrpopError::ExtraToken { token } => {
|
||||
ParseError { error: ParseErrorType::ExtraToken(token.1), location: token.0 }
|
||||
}
|
||||
LalrpopError::User { error } => {
|
||||
ParseError { error: ParseErrorType::Lexical(error.error), location: error.location }
|
||||
}
|
||||
LalrpopError::UnrecognizedToken { token, expected } => {
|
||||
// Hacky, but it's how CPython does it. See PyParser_AddToken,
|
||||
// in particular "Only one possible expected token" comment.
|
||||
let expected = if expected.len() == 1 {
|
||||
Some(expected[0].clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let expected = if expected.len() == 1 { Some(expected[0].clone()) } else { None };
|
||||
ParseError {
|
||||
error: ParseErrorType::UnrecognizedToken(token.1, expected),
|
||||
location: token.0,
|
||||
}
|
||||
}
|
||||
LalrpopError::UnrecognizedEOF { location, .. } => ParseError {
|
||||
error: ParseErrorType::Eof,
|
||||
location,
|
||||
},
|
||||
|
||||
LalrpopError::UnrecognizedEof { location, .. }
|
||||
// TODO: Are there cases where this isn't an EOF?
|
||||
| LalrpopError::InvalidToken { location } => {
|
||||
ParseError { error: ParseErrorType::Eof, location }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -188,7 +179,7 @@ impl fmt::Display for ParseErrorType {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
ParseErrorType::Eof => write!(f, "Got unexpected EOF"),
|
||||
ParseErrorType::ExtraToken(ref tok) => write!(f, "Got extraneous token: {:?}", tok),
|
||||
ParseErrorType::ExtraToken(ref tok) => write!(f, "Got extraneous token: {tok:?}"),
|
||||
ParseErrorType::InvalidToken => write!(f, "Got invalid token"),
|
||||
ParseErrorType::UnrecognizedToken(ref tok, ref expected) => {
|
||||
if *tok == Tok::Indent {
|
||||
|
@ -196,10 +187,10 @@ impl fmt::Display for ParseErrorType {
|
|||
} else if expected.as_deref() == Some("Indent") {
|
||||
write!(f, "expected an indented block")
|
||||
} else {
|
||||
write!(f, "Got unexpected token {}", tok)
|
||||
write!(f, "Got unexpected token {tok}")
|
||||
}
|
||||
}
|
||||
ParseErrorType::Lexical(ref error) => write!(f, "{}", error),
|
||||
ParseErrorType::Lexical(ref error) => write!(f, "{error}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -207,6 +198,7 @@ impl fmt::Display for ParseErrorType {
|
|||
impl Error for ParseErrorType {}
|
||||
|
||||
impl ParseErrorType {
|
||||
#[must_use]
|
||||
pub fn is_indentation_error(&self) -> bool {
|
||||
match self {
|
||||
ParseErrorType::Lexical(LexicalErrorType::IndentationError) => true,
|
||||
|
@ -216,11 +208,11 @@ impl ParseErrorType {
|
|||
_ => false,
|
||||
}
|
||||
}
|
||||
#[must_use]
|
||||
pub fn is_tab_error(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
ParseErrorType::Lexical(LexicalErrorType::TabError)
|
||||
| ParseErrorType::Lexical(LexicalErrorType::TabsAfterSpaces)
|
||||
ParseErrorType::Lexical(LexicalErrorType::TabError | LexicalErrorType::TabsAfterSpaces)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,10 +15,7 @@ struct FStringParser<'a> {
|
|||
|
||||
impl<'a> FStringParser<'a> {
|
||||
fn new(source: &'a str, str_location: Location) -> Self {
|
||||
Self {
|
||||
chars: source.chars().peekable(),
|
||||
str_location,
|
||||
}
|
||||
Self { chars: source.chars().peekable(), str_location }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -133,10 +130,10 @@ impl<'a> FStringParser<'a> {
|
|||
)
|
||||
} else {
|
||||
Box::new(self.expr(ExprKind::Constant {
|
||||
value: spec_expression.to_owned().into(),
|
||||
value: spec_expression.clone().into(),
|
||||
kind: None,
|
||||
}))
|
||||
})
|
||||
});
|
||||
}
|
||||
'(' | '{' | '[' => {
|
||||
expression.push(ch);
|
||||
|
@ -251,17 +248,11 @@ impl<'a> FStringParser<'a> {
|
|||
}
|
||||
|
||||
if !content.is_empty() {
|
||||
values.push(self.expr(ExprKind::Constant {
|
||||
value: content.into(),
|
||||
kind: None,
|
||||
}))
|
||||
values.push(self.expr(ExprKind::Constant { value: content.into(), kind: None }));
|
||||
}
|
||||
|
||||
let s = match values.len() {
|
||||
0 => self.expr(ExprKind::Constant {
|
||||
value: String::new().into(),
|
||||
kind: None,
|
||||
}),
|
||||
0 => self.expr(ExprKind::Constant { value: String::new().into(), kind: None }),
|
||||
1 => values.into_iter().next().unwrap(),
|
||||
_ => self.expr(ExprKind::JoinedStr { values }),
|
||||
};
|
||||
|
@ -270,16 +261,14 @@ impl<'a> FStringParser<'a> {
|
|||
}
|
||||
|
||||
fn parse_fstring_expr(source: &str) -> Result<Expr, ParseError> {
|
||||
let fstring_body = format!("({})", source);
|
||||
let fstring_body = format!("({source})");
|
||||
parse_expression(&fstring_body)
|
||||
}
|
||||
|
||||
/// Parse an fstring from a string, located at a certain position in the sourcecode.
|
||||
/// In case of errors, we will get the location and the error returned.
|
||||
pub fn parse_located_fstring(source: &str, location: Location) -> Result<Expr, FStringError> {
|
||||
FStringParser::new(source, location)
|
||||
.parse()
|
||||
.map_err(|error| FStringError { error, location })
|
||||
FStringParser::new(source, location).parse().map_err(|error| FStringError { error, location })
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -54,38 +54,32 @@ pub fn parse_args(func_args: Vec<FunctionArgument>) -> Result<ArgumentList, Lexi
|
|||
|
||||
let mut keyword_names = HashSet::with_capacity_and_hasher(func_args.len(), RandomState::new());
|
||||
for (name, value) in func_args {
|
||||
match name {
|
||||
Some((location, name)) => {
|
||||
if let Some(keyword_name) = &name {
|
||||
if keyword_names.contains(keyword_name) {
|
||||
return Err(LexicalError {
|
||||
error: LexicalErrorType::DuplicateKeywordArgumentError,
|
||||
location,
|
||||
});
|
||||
}
|
||||
|
||||
keyword_names.insert(keyword_name.clone());
|
||||
}
|
||||
|
||||
keywords.push(ast::Keyword::new(
|
||||
location,
|
||||
ast::KeywordData {
|
||||
arg: name.map(|name| name.into()),
|
||||
value: Box::new(value),
|
||||
},
|
||||
));
|
||||
}
|
||||
None => {
|
||||
// Allow starred args after keyword arguments.
|
||||
if !keywords.is_empty() && !is_starred(&value) {
|
||||
if let Some((location, name)) = name {
|
||||
if let Some(keyword_name) = &name {
|
||||
if keyword_names.contains(keyword_name) {
|
||||
return Err(LexicalError {
|
||||
error: LexicalErrorType::PositionalArgumentError,
|
||||
location: value.location,
|
||||
error: LexicalErrorType::DuplicateKeywordArgumentError,
|
||||
location,
|
||||
});
|
||||
}
|
||||
|
||||
args.push(value);
|
||||
keyword_names.insert(keyword_name.clone());
|
||||
}
|
||||
|
||||
keywords.push(ast::Keyword::new(
|
||||
location,
|
||||
ast::KeywordData { arg: name.map(String::into), value: Box::new(value) },
|
||||
));
|
||||
} else {
|
||||
// Allow starred args after keyword arguments.
|
||||
if !keywords.is_empty() && !is_starred(&value) {
|
||||
return Err(LexicalError {
|
||||
error: LexicalErrorType::PositionalArgumentError,
|
||||
location: value.location,
|
||||
});
|
||||
}
|
||||
|
||||
args.push(value);
|
||||
}
|
||||
}
|
||||
Ok(ArgumentList { args, keywords })
|
||||
|
|
|
@ -3,12 +3,12 @@
|
|||
//! This means source code is translated into separate tokens.
|
||||
|
||||
pub use super::token::Tok;
|
||||
use crate::ast::{Location, FileName};
|
||||
use crate::ast::{FileName, Location};
|
||||
use crate::error::{LexicalError, LexicalErrorType};
|
||||
use std::char;
|
||||
use std::cmp::Ordering;
|
||||
use std::str::FromStr;
|
||||
use std::num::IntErrorKind;
|
||||
use std::str::FromStr;
|
||||
use unic_emoji_char::is_emoji_presentation;
|
||||
use unic_ucd_ident::{is_xid_continue, is_xid_start};
|
||||
|
||||
|
@ -32,20 +32,14 @@ impl IndentationLevel {
|
|||
if self.spaces <= other.spaces {
|
||||
Ok(Ordering::Less)
|
||||
} else {
|
||||
Err(LexicalError {
|
||||
location,
|
||||
error: LexicalErrorType::TabError,
|
||||
})
|
||||
Err(LexicalError { location, error: LexicalErrorType::TabError })
|
||||
}
|
||||
}
|
||||
Ordering::Greater => {
|
||||
if self.spaces >= other.spaces {
|
||||
Ok(Ordering::Greater)
|
||||
} else {
|
||||
Err(LexicalError {
|
||||
location,
|
||||
error: LexicalErrorType::TabError,
|
||||
})
|
||||
Err(LexicalError { location, error: LexicalErrorType::TabError })
|
||||
}
|
||||
}
|
||||
Ordering::Equal => Ok(self.spaces.cmp(&other.spaces)),
|
||||
|
@ -63,7 +57,7 @@ pub struct Lexer<T: Iterator<Item = char>> {
|
|||
chr1: Option<char>,
|
||||
chr2: Option<char>,
|
||||
location: Location,
|
||||
config_comment_prefix: Option<&'static str>
|
||||
config_comment_prefix: Option<&'static str>,
|
||||
}
|
||||
|
||||
pub static KEYWORDS: phf::Map<&'static str, Tok> = phf::phf_map! {
|
||||
|
@ -136,11 +130,7 @@ where
|
|||
T: Iterator<Item = char>,
|
||||
{
|
||||
pub fn new(source: T) -> Self {
|
||||
let mut nlh = NewlineHandler {
|
||||
source,
|
||||
chr0: None,
|
||||
chr1: None,
|
||||
};
|
||||
let mut nlh = NewlineHandler { source, chr0: None, chr1: None };
|
||||
nlh.shift();
|
||||
nlh.shift();
|
||||
nlh
|
||||
|
@ -169,7 +159,7 @@ where
|
|||
self.shift();
|
||||
} else {
|
||||
// Transform MAC EOL into \n
|
||||
self.chr0 = Some('\n')
|
||||
self.chr0 = Some('\n');
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
|
@ -189,13 +179,13 @@ where
|
|||
chars: input,
|
||||
at_begin_of_line: true,
|
||||
nesting: 0,
|
||||
indentation_stack: vec![Default::default()],
|
||||
indentation_stack: vec![IndentationLevel::default()],
|
||||
pending: Vec::new(),
|
||||
chr0: None,
|
||||
location: start,
|
||||
chr1: None,
|
||||
chr2: None,
|
||||
config_comment_prefix: Some(" nac3:")
|
||||
config_comment_prefix: Some(" nac3:"),
|
||||
};
|
||||
lxr.next_char();
|
||||
lxr.next_char();
|
||||
|
@ -217,11 +207,9 @@ where
|
|||
let mut saw_f = false;
|
||||
loop {
|
||||
// Detect r"", f"", b"" and u""
|
||||
if !(saw_b || saw_u || saw_f) && matches!(self.chr0, Some('b') | Some('B')) {
|
||||
if !(saw_b || saw_u || saw_f) && matches!(self.chr0, Some('b' | 'B')) {
|
||||
saw_b = true;
|
||||
} else if !(saw_b || saw_r || saw_u || saw_f)
|
||||
&& matches!(self.chr0, Some('u') | Some('U'))
|
||||
{
|
||||
} else if !(saw_b || saw_r || saw_u || saw_f) && matches!(self.chr0, Some('u' | 'U')) {
|
||||
saw_u = true;
|
||||
} else if !(saw_r || saw_u) && (self.chr0 == Some('r') || self.chr0 == Some('R')) {
|
||||
saw_r = true;
|
||||
|
@ -287,15 +275,15 @@ where
|
|||
let end_pos = self.get_pos();
|
||||
let value = match i128::from_str_radix(&value_text, radix) {
|
||||
Ok(value) => value,
|
||||
Err(e) => {
|
||||
match e.kind() {
|
||||
IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX,
|
||||
_ => return Err(LexicalError {
|
||||
error: LexicalErrorType::OtherError(format!("{:?}", e)),
|
||||
Err(e) => match e.kind() {
|
||||
IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX,
|
||||
_ => {
|
||||
return Err(LexicalError {
|
||||
error: LexicalErrorType::OtherError(format!("{e:?}")),
|
||||
location: start_pos,
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
Ok((start_pos, Tok::Int { value }, end_pos))
|
||||
}
|
||||
|
@ -338,14 +326,7 @@ where
|
|||
if self.chr0 == Some('j') || self.chr0 == Some('J') {
|
||||
self.next_char();
|
||||
let end_pos = self.get_pos();
|
||||
Ok((
|
||||
start_pos,
|
||||
Tok::Complex {
|
||||
real: 0.0,
|
||||
imag: value,
|
||||
},
|
||||
end_pos,
|
||||
))
|
||||
Ok((start_pos, Tok::Complex { real: 0.0, imag: value }, end_pos))
|
||||
} else {
|
||||
let end_pos = self.get_pos();
|
||||
Ok((start_pos, Tok::Float { value }, end_pos))
|
||||
|
@ -364,7 +345,7 @@ where
|
|||
let value = value_text.parse::<i128>().ok();
|
||||
let nonzero = match value {
|
||||
Some(value) => value != 0i128,
|
||||
None => true
|
||||
None => true,
|
||||
};
|
||||
if start_is_zero && nonzero {
|
||||
return Err(LexicalError {
|
||||
|
@ -379,7 +360,7 @@ where
|
|||
|
||||
/// Consume a sequence of numbers with the given radix,
|
||||
/// the digits can be decorated with underscores
|
||||
/// like this: '1_2_3_4' == '1234'
|
||||
/// like this: `'1_2_3_4'` == `'1234'`
|
||||
fn radix_run(&mut self, radix: u32) -> String {
|
||||
let mut value_text = String::new();
|
||||
|
||||
|
@ -412,7 +393,7 @@ where
|
|||
2 => matches!(c, Some('0'..='1')),
|
||||
8 => matches!(c, Some('0'..='7')),
|
||||
10 => matches!(c, Some('0'..='9')),
|
||||
16 => matches!(c, Some('0'..='9') | Some('a'..='f') | Some('A'..='F')),
|
||||
16 => matches!(c, Some('0'..='9' | 'a'..='f' | 'A'..='F')),
|
||||
other => unimplemented!("Radix not implemented: {}", other),
|
||||
}
|
||||
}
|
||||
|
@ -420,8 +401,8 @@ where
|
|||
/// Test if we face '[eE][-+]?[0-9]+'
|
||||
fn at_exponent(&self) -> bool {
|
||||
match self.chr0 {
|
||||
Some('e') | Some('E') => match self.chr1 {
|
||||
Some('+') | Some('-') => matches!(self.chr2, Some('0'..='9')),
|
||||
Some('e' | 'E') => match self.chr1 {
|
||||
Some('+' | '-') => matches!(self.chr2, Some('0'..='9')),
|
||||
Some('0'..='9') => true,
|
||||
_ => false,
|
||||
},
|
||||
|
@ -433,19 +414,17 @@ where
|
|||
fn lex_comment(&mut self) -> Option<Spanned> {
|
||||
self.next_char();
|
||||
// if possibly nac3 pseudocomment, special handling for `# nac3:`
|
||||
let (mut prefix, mut is_comment) = self
|
||||
.config_comment_prefix
|
||||
.map_or_else(|| ("".chars(), false), |v| (v.chars(), true));
|
||||
let (mut prefix, mut is_comment) =
|
||||
self.config_comment_prefix.map_or_else(|| ("".chars(), false), |v| (v.chars(), true));
|
||||
// for the correct location of config comment
|
||||
let mut start_loc = self.location;
|
||||
start_loc.go_left();
|
||||
loop {
|
||||
match self.chr0 {
|
||||
Some('\n') => return None,
|
||||
None => return None,
|
||||
Some('\n') | None => return None,
|
||||
Some(c) => {
|
||||
if let (true, Some(p)) = (is_comment, prefix.next()) {
|
||||
is_comment = is_comment && c == p
|
||||
is_comment = is_comment && c == p;
|
||||
} else {
|
||||
// done checking prefix, if is comment then return the spanned
|
||||
if is_comment {
|
||||
|
@ -460,22 +439,20 @@ where
|
|||
return Some((
|
||||
start_loc,
|
||||
Tok::ConfigComment { content: content.trim().into() },
|
||||
self.location
|
||||
self.location,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
self.next_char();
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn unicode_literal(&mut self, literal_number: usize) -> Result<char, LexicalError> {
|
||||
let mut p: u32 = 0u32;
|
||||
let unicode_error = LexicalError {
|
||||
error: LexicalErrorType::UnicodeError,
|
||||
location: self.get_pos(),
|
||||
};
|
||||
let unicode_error =
|
||||
LexicalError { error: LexicalErrorType::UnicodeError, location: self.get_pos() };
|
||||
for i in 1..=literal_number {
|
||||
match self.next_char() {
|
||||
Some(c) => match c.to_digit(16) {
|
||||
|
@ -486,8 +463,8 @@ where
|
|||
}
|
||||
}
|
||||
match p {
|
||||
0xD800..=0xDFFF => Ok(std::char::REPLACEMENT_CHARACTER),
|
||||
_ => std::char::from_u32(p).ok_or(unicode_error),
|
||||
0xD800..=0xDFFF => Ok(char::REPLACEMENT_CHARACTER),
|
||||
_ => char::from_u32(p).ok_or(unicode_error),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -496,7 +473,7 @@ where
|
|||
octet_content.push(first);
|
||||
while octet_content.len() < 3 {
|
||||
if let Some('0'..='7') = self.chr0 {
|
||||
octet_content.push(self.next_char().unwrap())
|
||||
octet_content.push(self.next_char().unwrap());
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
@ -530,10 +507,8 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
unicode_names2::character(&name).ok_or(LexicalError {
|
||||
error: LexicalErrorType::UnicodeError,
|
||||
location: start_pos,
|
||||
})
|
||||
unicode_names2::character(&name)
|
||||
.ok_or(LexicalError { error: LexicalErrorType::UnicodeError, location: start_pos })
|
||||
}
|
||||
|
||||
fn lex_string(
|
||||
|
@ -566,7 +541,7 @@ where
|
|||
} else if is_raw {
|
||||
string_content.push('\\');
|
||||
if let Some(c) = self.next_char() {
|
||||
string_content.push(c)
|
||||
string_content.push(c);
|
||||
} else {
|
||||
return Err(LexicalError {
|
||||
error: LexicalErrorType::StringError,
|
||||
|
@ -599,7 +574,7 @@ where
|
|||
Some('u') if !is_bytes => string_content.push(self.unicode_literal(4)?),
|
||||
Some('U') if !is_bytes => string_content.push(self.unicode_literal(8)?),
|
||||
Some('N') if !is_bytes => {
|
||||
string_content.push(self.parse_unicode_name()?)
|
||||
string_content.push(self.parse_unicode_name()?);
|
||||
}
|
||||
Some(c) => {
|
||||
string_content.push('\\');
|
||||
|
@ -650,20 +625,15 @@ where
|
|||
let end_pos = self.get_pos();
|
||||
|
||||
let tok = if is_bytes {
|
||||
Tok::Bytes {
|
||||
value: string_content.chars().map(|c| c as u8).collect(),
|
||||
}
|
||||
Tok::Bytes { value: string_content.chars().map(|c| c as u8).collect() }
|
||||
} else {
|
||||
Tok::String {
|
||||
value: string_content,
|
||||
is_fstring,
|
||||
}
|
||||
Tok::String { value: string_content, is_fstring }
|
||||
};
|
||||
|
||||
Ok((start_pos, tok, end_pos))
|
||||
}
|
||||
|
||||
fn is_identifier_start(&self, c: char) -> bool {
|
||||
fn is_identifier_start(c: char) -> bool {
|
||||
match c {
|
||||
'_' | 'a'..='z' | 'A'..='Z' => true,
|
||||
'+' | '-' | '*' | '/' | '=' | ' ' | '<' | '>' => false,
|
||||
|
@ -835,18 +805,14 @@ where
|
|||
// Check if we have some character:
|
||||
if let Some(c) = self.chr0 {
|
||||
// First check identifier:
|
||||
if self.is_identifier_start(c) {
|
||||
if Self::is_identifier_start(c) {
|
||||
let identifier = self.lex_identifier()?;
|
||||
self.emit(identifier);
|
||||
} else if is_emoji_presentation(c) {
|
||||
let tok_start = self.get_pos();
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((
|
||||
tok_start,
|
||||
Tok::Name { name: c.to_string().into() },
|
||||
tok_end,
|
||||
));
|
||||
self.emit((tok_start, Tok::Name { name: c.to_string().into() }, tok_end));
|
||||
} else {
|
||||
self.consume_character(c)?;
|
||||
}
|
||||
|
@ -899,16 +865,13 @@ where
|
|||
'=' => {
|
||||
let tok_start = self.get_pos();
|
||||
self.next_char();
|
||||
match self.chr0 {
|
||||
Some('=') => {
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::EqEqual, tok_end));
|
||||
}
|
||||
_ => {
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::Equal, tok_end));
|
||||
}
|
||||
if let Some('=') = self.chr0 {
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::EqEqual, tok_end));
|
||||
} else {
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::Equal, tok_end));
|
||||
}
|
||||
}
|
||||
'+' => {
|
||||
|
@ -934,16 +897,13 @@ where
|
|||
}
|
||||
Some('*') => {
|
||||
self.next_char();
|
||||
match self.chr0 {
|
||||
Some('=') => {
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::DoubleStarEqual, tok_end));
|
||||
}
|
||||
_ => {
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::DoubleStar, tok_end));
|
||||
}
|
||||
if let Some('=') = self.chr0 {
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::DoubleStarEqual, tok_end));
|
||||
} else {
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::DoubleStar, tok_end));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
|
@ -963,16 +923,13 @@ where
|
|||
}
|
||||
Some('/') => {
|
||||
self.next_char();
|
||||
match self.chr0 {
|
||||
Some('=') => {
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::DoubleSlashEqual, tok_end));
|
||||
}
|
||||
_ => {
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::DoubleSlash, tok_end));
|
||||
}
|
||||
if let Some('=') = self.chr0 {
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::DoubleSlashEqual, tok_end));
|
||||
} else {
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::DoubleSlash, tok_end));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
|
@ -1141,16 +1098,13 @@ where
|
|||
match self.chr0 {
|
||||
Some('<') => {
|
||||
self.next_char();
|
||||
match self.chr0 {
|
||||
Some('=') => {
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::LeftShiftEqual, tok_end));
|
||||
}
|
||||
_ => {
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::LeftShift, tok_end));
|
||||
}
|
||||
if let Some('=') = self.chr0 {
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::LeftShiftEqual, tok_end));
|
||||
} else {
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::LeftShift, tok_end));
|
||||
}
|
||||
}
|
||||
Some('=') => {
|
||||
|
@ -1170,16 +1124,13 @@ where
|
|||
match self.chr0 {
|
||||
Some('>') => {
|
||||
self.next_char();
|
||||
match self.chr0 {
|
||||
Some('=') => {
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::RightShiftEqual, tok_end));
|
||||
}
|
||||
_ => {
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::RightShift, tok_end));
|
||||
}
|
||||
if let Some('=') = self.chr0 {
|
||||
self.next_char();
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::RightShiftEqual, tok_end));
|
||||
} else {
|
||||
let tok_end = self.get_pos();
|
||||
self.emit((tok_start, Tok::RightShift, tok_end));
|
||||
}
|
||||
}
|
||||
Some('=') => {
|
||||
|
@ -1439,14 +1390,8 @@ class Foo(A, B):
|
|||
assert_eq!(
|
||||
tokens,
|
||||
vec![
|
||||
Tok::String {
|
||||
value: "\\\\".to_owned(),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String {
|
||||
value: "\\".to_owned(),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String { value: "\\\\".to_owned(), is_fstring: false },
|
||||
Tok::String { value: "\\".to_owned(), is_fstring: false },
|
||||
Tok::Newline,
|
||||
]
|
||||
);
|
||||
|
@ -1459,27 +1404,13 @@ class Foo(A, B):
|
|||
assert_eq!(
|
||||
tokens,
|
||||
vec![
|
||||
Tok::Int {
|
||||
value: 47i128,
|
||||
},
|
||||
Tok::Int {
|
||||
value: 13i128,
|
||||
},
|
||||
Tok::Int {
|
||||
value: 0i128,
|
||||
},
|
||||
Tok::Int {
|
||||
value: 123i128,
|
||||
},
|
||||
Tok::Int { value: 47i128 },
|
||||
Tok::Int { value: 13i128 },
|
||||
Tok::Int { value: 0i128 },
|
||||
Tok::Int { value: 123i128 },
|
||||
Tok::Float { value: 0.2 },
|
||||
Tok::Complex {
|
||||
real: 0.0,
|
||||
imag: 2.0,
|
||||
},
|
||||
Tok::Complex {
|
||||
real: 0.0,
|
||||
imag: 2.2,
|
||||
},
|
||||
Tok::Complex { real: 0.0, imag: 2.0 },
|
||||
Tok::Complex { real: 0.0, imag: 2.2 },
|
||||
Tok::Newline,
|
||||
]
|
||||
);
|
||||
|
@ -1539,21 +1470,13 @@ class Foo(A, B):
|
|||
assert_eq!(
|
||||
tokens,
|
||||
vec![
|
||||
Tok::Name {
|
||||
name: String::from("avariable").into(),
|
||||
},
|
||||
Tok::Name { name: String::from("avariable").into() },
|
||||
Tok::Equal,
|
||||
Tok::Int {
|
||||
value: 99i128
|
||||
},
|
||||
Tok::Int { value: 99i128 },
|
||||
Tok::Plus,
|
||||
Tok::Int {
|
||||
value: 2i128
|
||||
},
|
||||
Tok::Int { value: 2i128 },
|
||||
Tok::Minus,
|
||||
Tok::Int {
|
||||
value: 0i128
|
||||
},
|
||||
Tok::Int { value: 0i128 },
|
||||
Tok::Newline,
|
||||
]
|
||||
);
|
||||
|
@ -1740,42 +1663,15 @@ class Foo(A, B):
|
|||
assert_eq!(
|
||||
tokens,
|
||||
vec![
|
||||
Tok::String {
|
||||
value: String::from("double"),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String {
|
||||
value: String::from("single"),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String {
|
||||
value: String::from("can't"),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String {
|
||||
value: String::from("\\\""),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String {
|
||||
value: String::from("\t\r\n"),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String {
|
||||
value: String::from("\\g"),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String {
|
||||
value: String::from("raw\\'"),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String {
|
||||
value: String::from("Đ"),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String {
|
||||
value: String::from("\u{80}\u{0}a"),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::String { value: String::from("double"), is_fstring: false },
|
||||
Tok::String { value: String::from("single"), is_fstring: false },
|
||||
Tok::String { value: String::from("can't"), is_fstring: false },
|
||||
Tok::String { value: String::from("\\\""), is_fstring: false },
|
||||
Tok::String { value: String::from("\t\r\n"), is_fstring: false },
|
||||
Tok::String { value: String::from("\\g"), is_fstring: false },
|
||||
Tok::String { value: String::from("raw\\'"), is_fstring: false },
|
||||
Tok::String { value: String::from("Đ"), is_fstring: false },
|
||||
Tok::String { value: String::from("\u{80}\u{0}a"), is_fstring: false },
|
||||
Tok::Newline,
|
||||
]
|
||||
);
|
||||
|
@ -1840,41 +1736,17 @@ class Foo(A, B):
|
|||
fn test_raw_byte_literal() {
|
||||
let source = r"rb'\x1z'";
|
||||
let tokens = lex_source(source);
|
||||
assert_eq!(
|
||||
tokens,
|
||||
vec![
|
||||
Tok::Bytes {
|
||||
value: b"\\x1z".to_vec()
|
||||
},
|
||||
Tok::Newline
|
||||
]
|
||||
);
|
||||
assert_eq!(tokens, vec![Tok::Bytes { value: b"\\x1z".to_vec() }, Tok::Newline]);
|
||||
let source = r"rb'\\'";
|
||||
let tokens = lex_source(source);
|
||||
assert_eq!(
|
||||
tokens,
|
||||
vec![
|
||||
Tok::Bytes {
|
||||
value: b"\\\\".to_vec()
|
||||
},
|
||||
Tok::Newline
|
||||
]
|
||||
)
|
||||
assert_eq!(tokens, vec![Tok::Bytes { value: b"\\\\".to_vec() }, Tok::Newline])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_escape_octet() {
|
||||
let source = r##"b'\43a\4\1234'"##;
|
||||
let tokens = lex_source(source);
|
||||
assert_eq!(
|
||||
tokens,
|
||||
vec![
|
||||
Tok::Bytes {
|
||||
value: b"#a\x04S4".to_vec()
|
||||
},
|
||||
Tok::Newline
|
||||
]
|
||||
)
|
||||
assert_eq!(tokens, vec![Tok::Bytes { value: b"#a\x04S4".to_vec() }, Tok::Newline])
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -1883,13 +1755,7 @@ class Foo(A, B):
|
|||
let tokens = lex_source(source);
|
||||
assert_eq!(
|
||||
tokens,
|
||||
vec![
|
||||
Tok::String {
|
||||
value: "\u{2002}".to_owned(),
|
||||
is_fstring: false,
|
||||
},
|
||||
Tok::Newline
|
||||
]
|
||||
vec![Tok::String { value: "\u{2002}".to_owned(), is_fstring: false }, Tok::Newline]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,24 @@
|
|||
//!
|
||||
//! ```
|
||||
|
||||
#![deny(
|
||||
future_incompatible,
|
||||
let_underscore,
|
||||
nonstandard_style,
|
||||
rust_2024_compatibility,
|
||||
clippy::all
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(
|
||||
clippy::enum_glob_use,
|
||||
clippy::fn_params_excessive_bools,
|
||||
clippy::missing_errors_doc,
|
||||
clippy::missing_panics_doc,
|
||||
clippy::module_name_repetitions,
|
||||
clippy::too_many_lines,
|
||||
clippy::wildcard_imports
|
||||
)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
use lalrpop_util::lalrpop_mod;
|
||||
|
@ -27,9 +45,16 @@ pub mod lexer;
|
|||
pub mod mode;
|
||||
pub mod parser;
|
||||
lalrpop_mod!(
|
||||
#[allow(clippy::all)]
|
||||
#[allow(unused)]
|
||||
#[allow(
|
||||
future_incompatible,
|
||||
let_underscore,
|
||||
nonstandard_style,
|
||||
rust_2024_compatibility,
|
||||
unused,
|
||||
clippy::all,
|
||||
clippy::pedantic
|
||||
)]
|
||||
python
|
||||
);
|
||||
pub mod token;
|
||||
pub mod config_comment_helper;
|
||||
pub mod token;
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
//! parse a whole program, a single statement, or a single
|
||||
//! expression.
|
||||
|
||||
use nac3ast::Location;
|
||||
use std::iter;
|
||||
|
||||
use crate::ast::{self, FileName};
|
||||
|
@ -63,7 +64,7 @@ pub fn parse_program(source: &str, file: FileName) -> Result<ast::Suite, ParseEr
|
|||
///
|
||||
/// ```
|
||||
pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> {
|
||||
parse(source, Mode::Expression, Default::default()).map(|top| match top {
|
||||
parse(source, Mode::Expression, FileName::default()).map(|top| match top {
|
||||
ast::Mod::Expression { body } => *body,
|
||||
_ => unreachable!(),
|
||||
})
|
||||
|
@ -72,12 +73,10 @@ pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> {
|
|||
// Parse a given source code
|
||||
pub fn parse(source: &str, mode: Mode, file: FileName) -> Result<ast::Mod, ParseError> {
|
||||
let lxr = lexer::make_tokenizer(source, file);
|
||||
let marker_token = (Default::default(), mode.to_marker(), Default::default());
|
||||
let marker_token = (Location::default(), mode.to_marker(), Location::default());
|
||||
let tokenizer = iter::once(Ok(marker_token)).chain(lxr);
|
||||
|
||||
python::TopParser::new()
|
||||
.parse(tokenizer)
|
||||
.map_err(ParseError::from)
|
||||
python::TopParser::new().parse(tokenizer).map_err(ParseError::from)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
//! Different token definitions.
|
||||
//! Loosely based on token.h from CPython source:
|
||||
use std::fmt::{self, Write};
|
||||
use crate::ast;
|
||||
use std::fmt::{self, Write};
|
||||
|
||||
/// Python source code can be tokenized in a sequence of these tokens.
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
|
@ -111,15 +111,23 @@ impl fmt::Display for Tok {
|
|||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
use Tok::*;
|
||||
match self {
|
||||
Name { name } => write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name)),
|
||||
Int { value } => if *value != i128::MAX { write!(f, "'{}'", value) } else { write!(f, "'#OFL#'") },
|
||||
Float { value } => write!(f, "'{}'", value),
|
||||
Complex { real, imag } => write!(f, "{}j{}", real, imag),
|
||||
Name { name } => {
|
||||
write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name))
|
||||
}
|
||||
Int { value } => {
|
||||
if *value == i128::MAX {
|
||||
write!(f, "'#OFL#'")
|
||||
} else {
|
||||
write!(f, "'{value}'")
|
||||
}
|
||||
}
|
||||
Float { value } => write!(f, "'{value}'"),
|
||||
Complex { real, imag } => write!(f, "{real}j{imag}"),
|
||||
String { value, is_fstring } => {
|
||||
if *is_fstring {
|
||||
write!(f, "f")?
|
||||
write!(f, "f")?;
|
||||
}
|
||||
write!(f, "{:?}", value)
|
||||
write!(f, "{value:?}")
|
||||
}
|
||||
Bytes { value } => {
|
||||
write!(f, "b\"")?;
|
||||
|
@ -129,12 +137,16 @@ impl fmt::Display for Tok {
|
|||
10 => f.write_str("\\n")?,
|
||||
13 => f.write_str("\\r")?,
|
||||
32..=126 => f.write_char(*i as char)?,
|
||||
_ => write!(f, "\\x{:02x}", i)?,
|
||||
_ => write!(f, "\\x{i:02x}")?,
|
||||
}
|
||||
}
|
||||
f.write_str("\"")
|
||||
}
|
||||
ConfigComment { content } => write!(f, "ConfigComment: '{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)),
|
||||
ConfigComment { content } => write!(
|
||||
f,
|
||||
"ConfigComment: '{}'",
|
||||
ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)
|
||||
),
|
||||
Newline => f.write_str("Newline"),
|
||||
Indent => f.write_str("Indent"),
|
||||
Dedent => f.write_str("Dedent"),
|
||||
|
|
|
@ -2,14 +2,18 @@
|
|||
name = "nac3standalone"
|
||||
version = "0.1.0"
|
||||
authors = ["M-Labs"]
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
parking_lot = "0.11.1"
|
||||
parking_lot = "0.12"
|
||||
nac3parser = { path = "../nac3parser" }
|
||||
nac3core = { path = "../nac3core" }
|
||||
|
||||
[dependencies.clap]
|
||||
version = "4.5"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.inkwell]
|
||||
version = "0.1.0-beta.4"
|
||||
version = "0.4"
|
||||
default-features = false
|
||||
features = ["llvm13-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
||||
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
*.bc
|
||||
*.o
|
||||
/demo
|
|
@ -0,0 +1,25 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
if [ -z "$1" ]; then
|
||||
echo "Requires at least one argument"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
declare -a nac3args
|
||||
while [ $# -gt 1 ]; do
|
||||
nac3args+=("$1")
|
||||
shift
|
||||
done
|
||||
demo="$1"
|
||||
|
||||
echo -n "Checking $demo... "
|
||||
./interpret_demo.py "$demo" > interpreted.log
|
||||
./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
|
||||
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
|
||||
diff -Nau interpreted.log run.log
|
||||
diff -Nau interpreted.log run_lli.log
|
||||
echo "ok"
|
||||
|
||||
rm -f interpreted.log run.log run_lli.log
|
|
@ -4,12 +4,8 @@ set -e
|
|||
|
||||
count=0
|
||||
for demo in src/*.py; do
|
||||
echo -n "checking $demo... "
|
||||
./interpret_demo.py $demo > interpreted.log
|
||||
./run_demo.sh $demo > run.log
|
||||
diff -Nau interpreted.log run.log
|
||||
echo "ok"
|
||||
let "count+=1"
|
||||
./check_demo.sh "$@" "$demo"
|
||||
((count += 1))
|
||||
done
|
||||
|
||||
echo "Ran $count demo checks - PASSED"
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
#include <inttypes.h>
|
||||
#include <math.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#define usize size_t
|
||||
|
||||
double dbl_nan(void) {
|
||||
return NAN;
|
||||
}
|
||||
|
||||
double dbl_inf(void) {
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
void output_bool(bool x) {
|
||||
puts(x ? "True" : "False");
|
||||
}
|
||||
|
||||
void output_int32(int32_t x) {
|
||||
printf("%"PRId32"\n", x);
|
||||
}
|
||||
|
||||
void output_int64(int64_t x) {
|
||||
printf("%"PRId64"\n", x);
|
||||
}
|
||||
|
||||
void output_uint32(uint32_t x) {
|
||||
printf("%"PRIu32"\n", x);
|
||||
}
|
||||
|
||||
void output_uint64(uint64_t x) {
|
||||
printf("%"PRIu64"\n", x);
|
||||
}
|
||||
|
||||
void output_float64(double x) {
|
||||
if (isnan(x)) {
|
||||
puts("nan");
|
||||
} else {
|
||||
printf("%f\n", x);
|
||||
}
|
||||
}
|
||||
|
||||
void output_asciiart(int32_t x) {
|
||||
static const char *chars = " .,-:;i+hHM$*#@ ";
|
||||
if (x < 0) {
|
||||
putchar('\n');
|
||||
} else {
|
||||
putchar(chars[x]);
|
||||
}
|
||||
}
|
||||
|
||||
struct cslice {
|
||||
void *data;
|
||||
usize len;
|
||||
};
|
||||
|
||||
void output_int32_list(struct cslice *slice) {
|
||||
const int32_t *data = (int32_t *) slice->data;
|
||||
|
||||
putchar('[');
|
||||
for (usize i = 0; i < slice->len; ++i) {
|
||||
if (i == slice->len - 1) {
|
||||
printf("%d", data[i]);
|
||||
} else {
|
||||
printf("%d, ", data[i]);
|
||||
}
|
||||
}
|
||||
putchar(']');
|
||||
putchar('\n');
|
||||
}
|
||||
|
||||
void output_str(struct cslice *slice) {
|
||||
const char *data = (const char *) slice->data;
|
||||
|
||||
for (usize i = 0; i < slice->len; ++i) {
|
||||
putchar(data[i]);
|
||||
}
|
||||
putchar('\n');
|
||||
}
|
||||
|
||||
uint64_t dbg_stack_address(__attribute__((unused)) struct cslice *slice) {
|
||||
int i;
|
||||
void *ptr = (void *) &i;
|
||||
return (uintptr_t) ptr;
|
||||
}
|
||||
|
||||
uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) {
|
||||
printf("__nac3_personality(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
|
||||
exit(101);
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) {
|
||||
printf("__nac3_raise(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
|
||||
exit(101);
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
void __nac3_end_catch(void) {}
|
||||
|
||||
extern int32_t run(void);
|
||||
|
||||
int main(void) {
|
||||
run();
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
mod cslice {
|
||||
// copied from https://github.com/dherman/cslice
|
||||
use std::marker::PhantomData;
|
||||
use std::slice;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct CSlice<'a, T> {
|
||||
base: *const T,
|
||||
len: usize,
|
||||
marker: PhantomData<&'a ()>,
|
||||
}
|
||||
|
||||
impl<'a, T> AsRef<[T]> for CSlice<'a, T> {
|
||||
fn as_ref(&self) -> &[T] {
|
||||
unsafe { slice::from_raw_parts(self.base, self.len) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn output_int32(x: i32) {
|
||||
println!("{}", x);
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn output_int64(x: i64) {
|
||||
println!("{}", x);
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn output_uint32(x: u32) {
|
||||
println!("{}", x);
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn output_uint64(x: u64) {
|
||||
println!("{}", x);
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn output_float64(x: f64) {
|
||||
// debug output to preserve the digits after the decimal points
|
||||
// to match python `print` function
|
||||
println!("{:?}", x);
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn output_asciiart(x: i32) {
|
||||
let chars = " .,-:;i+hHM$*#@ ";
|
||||
if x < 0 {
|
||||
println!("");
|
||||
} else {
|
||||
print!("{}", chars.chars().nth(x as usize).unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn output_int32_list(x: &cslice::CSlice<i32>) {
|
||||
print!("[");
|
||||
let mut it = x.as_ref().iter().peekable();
|
||||
while let Some(e) = it.next() {
|
||||
if it.peek().is_none() {
|
||||
print!("{}", e);
|
||||
} else {
|
||||
print!("{}, ", e);
|
||||
}
|
||||
}
|
||||
println!("]");
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn __nac3_personality(_state: u32, _exception_object: u32, _context: u32) -> u32 {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn __nac3_raise(_state: u32, _exception_object: u32, _context: u32) -> u32 {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
fn run() -> i32;
|
||||
}
|
||||
|
||||
fn main() {
|
||||
unsafe {
|
||||
run();
|
||||
}
|
||||
}
|
|
@ -3,10 +3,14 @@
|
|||
import sys
|
||||
import importlib.util
|
||||
import importlib.machinery
|
||||
import math
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pathlib
|
||||
|
||||
from numpy import int32, int64, uint32, uint64
|
||||
from typing import TypeVar, Generic
|
||||
from scipy import special
|
||||
from typing import TypeVar, Generic, Literal, Union
|
||||
|
||||
T = TypeVar('T')
|
||||
class Option(Generic[T]):
|
||||
|
@ -41,26 +45,93 @@ def Some(v: T) -> Option[T]:
|
|||
|
||||
none = Option(None)
|
||||
|
||||
class _ConstGenericMarker:
|
||||
pass
|
||||
|
||||
def ConstGeneric(name, constraint):
|
||||
return TypeVar(name, _ConstGenericMarker, constraint)
|
||||
|
||||
N = TypeVar("N", bound=np.uint64)
|
||||
class _NDArrayDummy(Generic[T, N]):
|
||||
pass
|
||||
|
||||
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
|
||||
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
|
||||
|
||||
def _bool(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.bool_(x)
|
||||
else:
|
||||
return bool(x)
|
||||
|
||||
def _float(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.float_(x)
|
||||
else:
|
||||
return float(x)
|
||||
|
||||
def round_away_zero(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.vectorize(round_away_zero)(x)
|
||||
else:
|
||||
if x >= 0.0:
|
||||
return math.floor(x + 0.5)
|
||||
else:
|
||||
return math.ceil(x - 0.5)
|
||||
|
||||
def _floor(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.vectorize(_floor)(x)
|
||||
else:
|
||||
return math.floor(x)
|
||||
|
||||
def _ceil(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.vectorize(_ceil)(x)
|
||||
else:
|
||||
return math.ceil(x)
|
||||
|
||||
def patch(module):
|
||||
def dbl_nan():
|
||||
return np.nan
|
||||
|
||||
def dbl_inf():
|
||||
return np.inf
|
||||
|
||||
def output_asciiart(x):
|
||||
if x < 0:
|
||||
sys.stdout.write("\n")
|
||||
else:
|
||||
sys.stdout.write(" .,-:;i+hHM$*#@ "[x])
|
||||
|
||||
def output_float(x):
|
||||
print("%f" % x)
|
||||
|
||||
def dbg_stack_address(_):
|
||||
return 0
|
||||
|
||||
def extern(fun):
|
||||
name = fun.__name__
|
||||
if name == "output_asciiart":
|
||||
if name == "dbl_nan":
|
||||
return dbl_nan
|
||||
elif name == "dbl_inf":
|
||||
return dbl_inf
|
||||
elif name == "output_asciiart":
|
||||
return output_asciiart
|
||||
elif name == "output_float64":
|
||||
return output_float
|
||||
elif name in {
|
||||
"output_bool",
|
||||
"output_int32",
|
||||
"output_int64",
|
||||
"output_int32_list",
|
||||
"output_uint32",
|
||||
"output_uint64",
|
||||
"output_float64"
|
||||
"output_str",
|
||||
}:
|
||||
return print
|
||||
elif name == "dbg_stack_address":
|
||||
return dbg_stack_address
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -68,13 +139,93 @@ def patch(module):
|
|||
module.int64 = int64
|
||||
module.uint32 = uint32
|
||||
module.uint64 = uint64
|
||||
module.bool = _bool
|
||||
module.float = _float
|
||||
module.TypeVar = TypeVar
|
||||
module.ConstGeneric = ConstGeneric
|
||||
module.Generic = Generic
|
||||
module.Literal = Literal
|
||||
module.extern = extern
|
||||
module.Option = Option
|
||||
module.Some = Some
|
||||
module.none = none
|
||||
|
||||
# Builtin Math functions
|
||||
module.round = round_away_zero
|
||||
module.round64 = round_away_zero
|
||||
module.np_round = np.round
|
||||
module.floor = _floor
|
||||
module.floor64 = _floor
|
||||
module.np_floor = np.floor
|
||||
module.ceil = _ceil
|
||||
module.ceil64 = _ceil
|
||||
module.np_ceil = np.ceil
|
||||
|
||||
# NumPy ndarray functions
|
||||
module.ndarray = NDArray
|
||||
module.np_ndarray = np.ndarray
|
||||
module.np_empty = np.empty
|
||||
module.np_zeros = np.zeros
|
||||
module.np_ones = np.ones
|
||||
module.np_full = np.full
|
||||
module.np_eye = np.eye
|
||||
module.np_identity = np.identity
|
||||
module.np_array = np.array
|
||||
|
||||
# NumPy Math functions
|
||||
module.np_isnan = np.isnan
|
||||
module.np_isinf = np.isinf
|
||||
module.np_min = np.min
|
||||
module.np_minimum = np.minimum
|
||||
module.np_max = np.max
|
||||
module.np_maximum = np.maximum
|
||||
module.np_sin = np.sin
|
||||
module.np_cos = np.cos
|
||||
module.np_exp = np.exp
|
||||
module.np_exp2 = np.exp2
|
||||
module.np_log = np.log
|
||||
module.np_log10 = np.log10
|
||||
module.np_log2 = np.log2
|
||||
module.np_fabs = np.fabs
|
||||
module.np_trunc = np.trunc
|
||||
module.np_sqrt = np.sqrt
|
||||
module.np_rint = np.rint
|
||||
module.np_tan = np.tan
|
||||
module.np_arcsin = np.arcsin
|
||||
module.np_arccos = np.arccos
|
||||
module.np_arctan = np.arctan
|
||||
module.np_sinh = np.sinh
|
||||
module.np_cosh = np.cosh
|
||||
module.np_tanh = np.tanh
|
||||
module.np_arcsinh = np.arcsinh
|
||||
module.np_arccosh = np.arccosh
|
||||
module.np_arctanh = np.arctanh
|
||||
module.np_expm1 = np.expm1
|
||||
module.np_cbrt = np.cbrt
|
||||
module.np_arctan2 = np.arctan2
|
||||
module.np_copysign = np.copysign
|
||||
module.np_fmax = np.fmax
|
||||
module.np_fmin = np.fmin
|
||||
module.np_ldexp = np.ldexp
|
||||
module.np_hypot = np.hypot
|
||||
module.np_nextafter = np.nextafter
|
||||
|
||||
# SciPy Math Functions
|
||||
module.sp_spec_erf = special.erf
|
||||
module.sp_spec_erfc = special.erfc
|
||||
module.sp_spec_gamma = special.gamma
|
||||
module.sp_spec_gammaln = special.gammaln
|
||||
module.sp_spec_j0 = special.j0
|
||||
module.sp_spec_j1 = special.j1
|
||||
|
||||
# NumPy NDArray Functions
|
||||
module.np_ndarray = np.ndarray
|
||||
module.np_empty = np.empty
|
||||
module.np_zeros = np.zeros
|
||||
module.np_ones = np.ones
|
||||
module.np_full = np.full
|
||||
module.np_eye = np.eye
|
||||
module.np_identity = np.identity
|
||||
|
||||
def file_import(filename, prefix="file_import_"):
|
||||
filename = pathlib.Path(filename)
|
||||
|
|
|
@ -7,14 +7,72 @@ if [ -z "$1" ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
if [ -e ../../target/release/nac3standalone ]; then
|
||||
declare -a nac3args
|
||||
while [ $# -ge 1 ]; do
|
||||
case "$1" in
|
||||
--help)
|
||||
echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--lli] [--debug] -- [NAC3ARGS...]"
|
||||
exit
|
||||
;;
|
||||
--out)
|
||||
shift
|
||||
outfile="$1"
|
||||
;;
|
||||
--lli)
|
||||
use_lli=1
|
||||
;;
|
||||
--debug)
|
||||
debug=1
|
||||
;;
|
||||
--)
|
||||
shift
|
||||
break
|
||||
;;
|
||||
*)
|
||||
break
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
while [ $# -ge 1 ]; do
|
||||
nac3args+=("$1")
|
||||
shift
|
||||
done
|
||||
|
||||
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
|
||||
nac3standalone=../../target/debug/nac3standalone
|
||||
elif [ -e ../../target/release/nac3standalone ]; then
|
||||
nac3standalone=../../target/release/nac3standalone
|
||||
else
|
||||
# used by Nix builds
|
||||
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
|
||||
fi
|
||||
|
||||
rm -f *.o
|
||||
$nac3standalone $1
|
||||
rustc -o demo demo.rs -Crelocation-model=static -Clink-arg=./module.o
|
||||
./demo
|
||||
rm -f ./*.o ./*.bc demo
|
||||
if [ -z "$use_lli" ]; then
|
||||
$nac3standalone "${nac3args[@]}"
|
||||
|
||||
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
|
||||
clang -lm -o demo module.o demo.o
|
||||
|
||||
if [ -z "$outfile" ]; then
|
||||
./demo
|
||||
else
|
||||
./demo > "$outfile"
|
||||
fi
|
||||
else
|
||||
$nac3standalone --emit-llvm "${nac3args[@]}"
|
||||
|
||||
clang -c -std=gnu11 -Wall -Wextra -O3 -emit-llvm -o demo.bc demo.c
|
||||
|
||||
shopt -s nullglob
|
||||
llvm-link -o nac3out.bc module*.bc main.bc
|
||||
shopt -u nullglob
|
||||
|
||||
if [ -z "$outfile" ]; then
|
||||
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc
|
||||
else
|
||||
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile"
|
||||
fi
|
||||
fi
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# Different cases for using boolean variables in boolean contexts.
|
||||
# Tests whether all boolean variables (expressed as i8s) are lowered into i1s before used in branching instruction (`br`)
|
||||
|
||||
def bfunc(b: bool) -> bool:
|
||||
return not b
|
||||
|
||||
def run() -> int32:
|
||||
b1 = True
|
||||
b2 = False
|
||||
|
||||
if b1:
|
||||
pass
|
||||
|
||||
if not b2:
|
||||
pass
|
||||
|
||||
while b2:
|
||||
pass
|
||||
|
||||
l = [i for i in range(10) if b2]
|
||||
|
||||
b_and = True and False
|
||||
b_or = True or False
|
||||
|
||||
b_and = b1 and b2
|
||||
b_or = b1 or b2
|
||||
|
||||
bfunc(b1)
|
||||
|
||||
return 0
|
|
@ -23,8 +23,8 @@ class A:
|
|||
def get_a(self) -> int32:
|
||||
return self.a
|
||||
|
||||
def get_b(self) -> B:
|
||||
return self.b
|
||||
# def get_b(self) -> B:
|
||||
# return self.b
|
||||
|
||||
|
||||
def run() -> int32:
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
A = ConstGeneric("A", int32)
|
||||
B = ConstGeneric("B", uint32)
|
||||
T = TypeVar("T")
|
||||
|
||||
class ConstGenericClass(Generic[A]):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
class ConstGeneric2Class(Generic[A, B]):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
class HybridGenericClass2(Generic[A, T]):
|
||||
pass
|
||||
|
||||
class HybridGenericClass3(Generic[T, A, B]):
|
||||
pass
|
||||
|
||||
def make_generic_2() -> ConstGenericClass[Literal[2]]:
|
||||
return ...
|
||||
|
||||
def make_generic2_1_2() -> ConstGeneric2Class[Literal[1], Literal[2]]:
|
||||
return ...
|
||||
|
||||
def make_hybrid_class_2_int32() -> HybridGenericClass2[Literal[2], int32]:
|
||||
return ...
|
||||
|
||||
def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]:
|
||||
return ...
|
||||
|
||||
def consume_generic_2(instance: ConstGenericClass[Literal[2]]):
|
||||
pass
|
||||
|
||||
def consume_generic2_1_2(instance: ConstGeneric2Class[Literal[1], Literal[2]]):
|
||||
pass
|
||||
|
||||
def consume_hybrid_class_2_i32(instance: HybridGenericClass2[Literal[2], int32]):
|
||||
pass
|
||||
|
||||
def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, Literal[0], Literal[1]]):
|
||||
pass
|
||||
|
||||
def f():
|
||||
consume_generic_2(make_generic_2())
|
||||
consume_generic2_1_2(make_generic2_1_2())
|
||||
consume_hybrid_class_2_i32(make_hybrid_class_2_int32())
|
||||
consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1())
|
||||
|
||||
def run() -> int32:
|
||||
return 0
|
|
@ -0,0 +1,8 @@
|
|||
def f():
|
||||
return
|
||||
return
|
||||
|
||||
def run() -> int32:
|
||||
f()
|
||||
|
||||
return 0
|
|
@ -0,0 +1,83 @@
|
|||
@extern
|
||||
def output_bool(x: bool):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int64(x: int64):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_uint32(x: uint32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_uint64(x: uint64):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_float64(x: float):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int32_list(x: list[int32]):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_asciiart(x: int32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_str(x: str):
|
||||
...
|
||||
|
||||
def test_output_bool():
|
||||
output_bool(True)
|
||||
output_bool(False)
|
||||
|
||||
def test_output_int32():
|
||||
output_int32(-128)
|
||||
|
||||
def test_output_int64():
|
||||
output_int64(int64(-256))
|
||||
|
||||
def test_output_uint32():
|
||||
output_uint32(uint32(128))
|
||||
|
||||
def test_output_uint64():
|
||||
output_uint64(uint64(256))
|
||||
|
||||
def test_output_float64():
|
||||
output_float64(0.0)
|
||||
output_float64(1.0)
|
||||
output_float64(-1.0)
|
||||
output_float64(128.0)
|
||||
output_float64(-128.0)
|
||||
output_float64(16.25)
|
||||
output_float64(-16.25)
|
||||
|
||||
def test_output_asciiart():
|
||||
for i in range(17):
|
||||
output_asciiart(i)
|
||||
output_asciiart(0)
|
||||
|
||||
def test_output_int32_list():
|
||||
output_int32_list([0, 1, 3, 5, 10])
|
||||
|
||||
def test_output_str_family():
|
||||
output_str("hello world")
|
||||
|
||||
def run() -> int32:
|
||||
test_output_bool()
|
||||
test_output_int32()
|
||||
test_output_int64()
|
||||
test_output_uint32()
|
||||
test_output_uint64()
|
||||
test_output_float64()
|
||||
test_output_asciiart()
|
||||
test_output_int32_list()
|
||||
test_output_str_family()
|
||||
return 0
|
|
@ -0,0 +1,32 @@
|
|||
from __future__ import annotations
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
class A:
|
||||
a: int32
|
||||
|
||||
def __init__(self, a: int32):
|
||||
self.a = a
|
||||
|
||||
def f1(self):
|
||||
self.f2()
|
||||
|
||||
def f2(self):
|
||||
output_int32(self.a)
|
||||
|
||||
class B(A):
|
||||
b: int32
|
||||
|
||||
def __init__(self, b: int32):
|
||||
self.a = b + 1
|
||||
self.b = b
|
||||
|
||||
|
||||
def run() -> int32:
|
||||
aaa = A(5)
|
||||
bbb = B(2)
|
||||
aaa.f1()
|
||||
bbb.f1()
|
||||
return 0
|
|
@ -0,0 +1,17 @@
|
|||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int32_list(x: list[int32]):
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
bl = [True, False]
|
||||
|
||||
bl1 = bl[:]
|
||||
bl1[1:] = [True]
|
||||
output_int32_list([int32(b) for b in bl1])
|
||||
output_int32_list([int32(b) for b in bl1])
|
||||
|
||||
return 0
|
|
@ -1,9 +1,12 @@
|
|||
# For Loop using an increasing range() expression as its iterable
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
for _ in range(10):
|
||||
output_int32(_)
|
||||
_ = 0
|
||||
i = 0
|
||||
for i in range(10):
|
||||
output_int32(i)
|
||||
output_int32(i)
|
||||
return 0
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
for i in range(4):
|
||||
output_int32(i)
|
||||
if i < 2:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
n = [0, 1, 2, 3]
|
||||
for i in n:
|
||||
output_int32(i)
|
||||
if i < 2:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
return 0
|
|
@ -0,0 +1,12 @@
|
|||
# For Loop using a decreasing range() expression as its iterable
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
i = 0
|
||||
for i in range(10, 0, -1):
|
||||
output_int32(i)
|
||||
output_int32(i)
|
||||
return 0
|
|
@ -0,0 +1,17 @@
|
|||
# For Loop using a list as its iterable
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
l = [0, 1, 2, 3, 4]
|
||||
|
||||
# i: int32 # declaration-without-initializer not yet supported
|
||||
i = 0 # i must be declared before the loop; this is not necessary in Python
|
||||
for i in l:
|
||||
output_int32(i)
|
||||
i = 0
|
||||
output_int32(i)
|
||||
output_int32(i)
|
||||
return 0
|
|
@ -0,0 +1,14 @@
|
|||
# For Loop using an range() expression as its iterable, additionally reassigning the target on each iteration
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
i = 0
|
||||
for i in range(10):
|
||||
output_int32(i)
|
||||
i = 0
|
||||
output_int32(i)
|
||||
output_int32(i)
|
||||
return 0
|
|
@ -0,0 +1,33 @@
|
|||
# Break within try statement within a loop
|
||||
# Taken from https://book.pythontips.com/en/latest/for_-_else.html
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_float64(x: float):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_str(x: str):
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
for n in range(2, 10):
|
||||
for x in range(2, n):
|
||||
try:
|
||||
if n % x == 0:
|
||||
output_int32(n)
|
||||
output_str(" equals ")
|
||||
output_int32(x)
|
||||
output_str(" * ")
|
||||
output_float64(n / x)
|
||||
except: # Assume this is intended to catch x == 0
|
||||
break
|
||||
else:
|
||||
# loop fell through without finding a factor
|
||||
output_int32(n)
|
||||
output_str(" is a prime number")
|
||||
|
||||
return 0
|
|
@ -0,0 +1,274 @@
|
|||
@extern
|
||||
def output_bool(x: bool):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int64(x: int64):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_float64(x: float):
|
||||
...
|
||||
|
||||
@extern
|
||||
def dbl_nan() -> float:
|
||||
...
|
||||
|
||||
@extern
|
||||
def dbl_inf() -> float:
|
||||
...
|
||||
|
||||
def dbl_pi() -> float:
|
||||
return 3.1415926535897932384626433
|
||||
|
||||
def dbl_e() -> float:
|
||||
return 2.71828182845904523536028747135266249775724709369995
|
||||
|
||||
def test_round():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||
output_int32(round(x))
|
||||
|
||||
def test_round64():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||
output_int64(round64(x))
|
||||
|
||||
def test_np_round():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_round(x))
|
||||
|
||||
def test_np_isnan():
|
||||
for x in [dbl_nan(), 0.0, dbl_inf()]:
|
||||
output_bool(np_isnan(x))
|
||||
|
||||
def test_np_isinf():
|
||||
for x in [dbl_inf(), -dbl_inf(), 0.0, dbl_nan()]:
|
||||
output_bool(np_isinf(x))
|
||||
|
||||
def test_np_sin():
|
||||
pi = dbl_pi()
|
||||
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_sin(x))
|
||||
|
||||
def test_np_cos():
|
||||
pi = dbl_pi()
|
||||
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_cos(x))
|
||||
|
||||
def test_np_exp():
|
||||
for x in [0.0, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_exp(x))
|
||||
|
||||
def test_np_exp2():
|
||||
for x in [0.0, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_exp2(x))
|
||||
|
||||
def test_np_log():
|
||||
e = dbl_e()
|
||||
for x in [1.0, e, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_log(x))
|
||||
|
||||
def test_np_log10():
|
||||
for x in [1.0, 10.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_log10(x))
|
||||
|
||||
def test_np_log2():
|
||||
for x in [1.0, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_log2(x))
|
||||
|
||||
def test_np_fabs():
|
||||
for x in [-1.0, 0.0, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_fabs(x))
|
||||
|
||||
def test_floor():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||
output_int32(floor(x))
|
||||
|
||||
def test_floor64():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||
output_int64(floor64(x))
|
||||
|
||||
def test_np_floor():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_floor(x))
|
||||
|
||||
def test_ceil():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||
output_int32(ceil(x))
|
||||
|
||||
def test_ceil64():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5]:
|
||||
output_int64(ceil64(x))
|
||||
|
||||
def test_np_ceil():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_ceil(x))
|
||||
|
||||
def test_np_sqrt():
|
||||
for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_sqrt(x))
|
||||
|
||||
def test_np_rint():
|
||||
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_rint(x))
|
||||
|
||||
def test_np_tan():
|
||||
pi = dbl_pi()
|
||||
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_tan(x))
|
||||
|
||||
def test_np_arcsin():
|
||||
for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_arcsin(x))
|
||||
|
||||
def test_np_arccos():
|
||||
for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_arccos(x))
|
||||
|
||||
def test_np_arctan():
|
||||
for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_arctan(x))
|
||||
|
||||
def test_np_sinh():
|
||||
for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_sinh(x))
|
||||
|
||||
def test_np_cosh():
|
||||
for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_cosh(x))
|
||||
|
||||
def test_np_tanh():
|
||||
for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_tanh(x))
|
||||
|
||||
def test_np_arcsinh():
|
||||
for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_arcsinh(x))
|
||||
|
||||
def test_np_arccosh():
|
||||
for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_arccosh(x))
|
||||
|
||||
def test_np_arctanh():
|
||||
for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_arctanh(x))
|
||||
|
||||
def test_np_expm1():
|
||||
for x in [0.0, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_expm1(x))
|
||||
|
||||
def test_np_cbrt():
|
||||
for x in [1.0, 8.0, 27.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_cbrt(x))
|
||||
|
||||
def test_sp_spec_erf():
|
||||
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(sp_spec_erf(x))
|
||||
|
||||
def test_sp_spec_erfc():
|
||||
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(sp_spec_erfc(x))
|
||||
|
||||
def test_sp_spec_gamma():
|
||||
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(sp_spec_gamma(x))
|
||||
|
||||
def test_sp_spec_gammaln():
|
||||
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(sp_spec_gammaln(x))
|
||||
|
||||
def test_sp_spec_j0():
|
||||
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(sp_spec_j0(x))
|
||||
|
||||
def test_sp_spec_j1():
|
||||
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]:
|
||||
output_float64(sp_spec_j1(x))
|
||||
|
||||
def test_np_arctan2():
|
||||
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_arctan2(x1, x2))
|
||||
|
||||
def test_np_copysign():
|
||||
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_copysign(x1, x2))
|
||||
|
||||
def test_np_fmax():
|
||||
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_fmax(x1, x2))
|
||||
|
||||
def test_np_fmin():
|
||||
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_fmin(x1, x2))
|
||||
|
||||
def test_np_ldexp():
|
||||
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
for x2 in [-2, -1, 0, 1, 2]:
|
||||
output_float64(np_ldexp(x1, x2))
|
||||
|
||||
def test_np_hypot():
|
||||
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_hypot(x1, x2))
|
||||
|
||||
def test_np_nextafter():
|
||||
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
|
||||
output_float64(np_nextafter(x1, x2))
|
||||
|
||||
def run() -> int32:
|
||||
test_round()
|
||||
test_round64()
|
||||
test_np_round()
|
||||
test_np_isnan()
|
||||
test_np_isinf()
|
||||
test_np_sin()
|
||||
test_np_cos()
|
||||
test_np_exp()
|
||||
test_np_exp2()
|
||||
test_np_log()
|
||||
test_np_log10()
|
||||
test_np_log2()
|
||||
test_np_fabs()
|
||||
test_floor()
|
||||
test_floor64()
|
||||
test_np_floor()
|
||||
test_ceil()
|
||||
test_ceil64()
|
||||
test_np_ceil()
|
||||
test_np_sqrt()
|
||||
test_np_rint()
|
||||
test_np_tan()
|
||||
test_np_arcsin()
|
||||
test_np_arccos()
|
||||
test_np_arctan()
|
||||
test_np_sinh()
|
||||
test_np_cosh()
|
||||
test_np_tanh()
|
||||
test_np_arcsinh()
|
||||
test_np_arccosh()
|
||||
test_np_arctanh()
|
||||
test_np_expm1()
|
||||
test_np_cbrt()
|
||||
test_sp_spec_erf()
|
||||
test_sp_spec_erfc()
|
||||
test_sp_spec_gamma()
|
||||
test_sp_spec_gammaln()
|
||||
test_sp_spec_j0()
|
||||
test_sp_spec_j1()
|
||||
test_np_arctan2()
|
||||
test_np_copysign()
|
||||
test_np_fmax()
|
||||
test_np_fmin()
|
||||
test_np_ldexp()
|
||||
test_np_hypot()
|
||||
test_np_nextafter()
|
||||
|
||||
return 0
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,184 @@
|
|||
@extern
|
||||
def output_bool(x: bool):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int64(x: int64):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_uint32(x: uint32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_uint64(x: uint64):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_float64(x: float):
|
||||
...
|
||||
|
||||
def u32_min() -> uint32:
|
||||
return uint32(0)
|
||||
|
||||
def u32_max() -> uint32:
|
||||
return ~uint32(0)
|
||||
|
||||
def i32_min() -> int32:
|
||||
return int32(1 << 31)
|
||||
|
||||
def i32_max() -> int32:
|
||||
return int32(~(1 << 31))
|
||||
|
||||
def u64_min() -> uint64:
|
||||
return uint64(0)
|
||||
|
||||
def u64_max() -> uint64:
|
||||
return ~uint64(0)
|
||||
|
||||
def i64_min() -> int64:
|
||||
return int64(1) << 63
|
||||
|
||||
def i64_max() -> int64:
|
||||
return ~(int64(1) << 63)
|
||||
|
||||
def test_u32_bnot():
|
||||
output_uint32(~uint32(0))
|
||||
|
||||
def test_u64_bnot():
|
||||
output_uint64(~uint64(0))
|
||||
|
||||
def test_conv_from_i32():
|
||||
for x in [
|
||||
i32_min(),
|
||||
i32_min() + 1,
|
||||
-1,
|
||||
0,
|
||||
1,
|
||||
i32_max() - 1,
|
||||
i32_max()
|
||||
]:
|
||||
output_int64(int64(x))
|
||||
output_uint32(uint32(x))
|
||||
output_uint64(uint64(x))
|
||||
output_float64(float(x))
|
||||
|
||||
def test_conv_from_u32():
|
||||
for x in [
|
||||
u32_min(),
|
||||
u32_min() + uint32(1),
|
||||
u32_max() - uint32(1),
|
||||
u32_max()
|
||||
]:
|
||||
output_uint64(uint64(x))
|
||||
output_int32(int32(x))
|
||||
output_int64(int64(x))
|
||||
output_float64(float(x))
|
||||
|
||||
def test_conv_from_i64():
|
||||
for x in [
|
||||
i64_min(),
|
||||
i64_min() + int64(1),
|
||||
int64(-1),
|
||||
int64(0),
|
||||
int64(1),
|
||||
i64_max() - int64(1),
|
||||
i64_max()
|
||||
]:
|
||||
output_int32(int32(x))
|
||||
output_uint64(uint64(x))
|
||||
output_uint32(uint32(x))
|
||||
output_float64(float(x))
|
||||
|
||||
def test_conv_from_u64():
|
||||
for x in [
|
||||
u64_min(),
|
||||
u64_min() + uint64(1),
|
||||
u64_max() - uint64(1),
|
||||
u64_max()
|
||||
]:
|
||||
output_uint32(uint32(x))
|
||||
output_int64(int64(x))
|
||||
output_int32(int32(x))
|
||||
output_float64(float(x))
|
||||
|
||||
def test_f64toi32():
|
||||
for x in [
|
||||
float(i32_min()) - 1.0,
|
||||
float(i32_min()),
|
||||
float(i32_min()) + 1.0,
|
||||
-1.5,
|
||||
-0.5,
|
||||
0.5,
|
||||
1.5,
|
||||
float(i32_max()) - 1.0,
|
||||
float(i32_max()),
|
||||
float(i32_max()) + 1.0
|
||||
]:
|
||||
output_int32(int32(x))
|
||||
|
||||
def test_f64toi64():
|
||||
for x in [
|
||||
float(i64_min()),
|
||||
float(i64_min()) + 1.0,
|
||||
-1.5,
|
||||
-0.5,
|
||||
0.5,
|
||||
1.5,
|
||||
# 2^53 is the highest integral power-of-two of which uint64 and float have a one-to-one correspondence
|
||||
float(uint64(2) ** uint64(52)) - 1.0,
|
||||
float(uint64(2) ** uint64(52)),
|
||||
float(uint64(2) ** uint64(52)) + 1.0,
|
||||
]:
|
||||
output_int64(int64(x))
|
||||
|
||||
def test_f64tou32():
|
||||
for x in [
|
||||
-1.5,
|
||||
float(u32_min()) - 1.0,
|
||||
-0.5,
|
||||
float(u32_min()),
|
||||
0.5,
|
||||
float(u32_min()) + 1.0,
|
||||
1.5,
|
||||
float(u32_max()) - 1.0,
|
||||
float(u32_max()),
|
||||
float(u32_max()) + 1.0
|
||||
]:
|
||||
output_uint32(uint32(x))
|
||||
|
||||
def test_f64tou64():
|
||||
for x in [
|
||||
-1.5,
|
||||
float(u64_min()) - 1.0,
|
||||
-0.5,
|
||||
float(u64_min()),
|
||||
0.5,
|
||||
float(u64_min()) + 1.0,
|
||||
1.5,
|
||||
# 2^53 is the highest integral power-of-two of which uint64 and float have a one-to-one correspondence
|
||||
float(uint64(2) ** uint64(52)) - 1.0,
|
||||
float(uint64(2) ** uint64(52)),
|
||||
float(uint64(2) ** uint64(52)) + 1.0,
|
||||
]:
|
||||
output_uint64(uint64(x))
|
||||
|
||||
def run() -> int32:
|
||||
test_u32_bnot()
|
||||
test_u64_bnot()
|
||||
|
||||
test_conv_from_i32()
|
||||
test_conv_from_u32()
|
||||
test_conv_from_i64()
|
||||
test_conv_from_u64()
|
||||
|
||||
test_f64toi32()
|
||||
test_f64toi64()
|
||||
test_f64tou32()
|
||||
test_f64tou64()
|
||||
|
||||
return 0
|
|
@ -0,0 +1,281 @@
|
|||
from __future__ import annotations
|
||||
|
||||
@extern
|
||||
def output_bool(x: bool):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
@extern
|
||||
def output_uint32(x: uint32):
|
||||
...
|
||||
@extern
|
||||
def output_int64(x: int64):
|
||||
...
|
||||
@extern
|
||||
def output_uint64(x: uint64):
|
||||
...
|
||||
@extern
|
||||
def output_float64(x: float):
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
test_bool()
|
||||
test_int32()
|
||||
test_uint32()
|
||||
test_int64()
|
||||
test_uint64()
|
||||
# test_A()
|
||||
# test_B()
|
||||
return 0
|
||||
|
||||
def test_bool():
|
||||
t = True
|
||||
f = False
|
||||
output_bool(not t)
|
||||
output_bool(not f)
|
||||
output_int32(~t)
|
||||
output_int32(~f)
|
||||
output_int32(+t)
|
||||
output_int32(+f)
|
||||
output_int32(-t)
|
||||
output_int32(-f)
|
||||
|
||||
def test_int32():
|
||||
a = 17
|
||||
b = 3
|
||||
output_int32(a + b)
|
||||
output_int32(a - b)
|
||||
output_int32(a * b)
|
||||
output_int32(a // b)
|
||||
output_int32(a % b)
|
||||
output_int32(a | b)
|
||||
output_int32(a ^ b)
|
||||
output_int32(a & b)
|
||||
output_int32(a << b)
|
||||
output_int32(a << uint32(b))
|
||||
output_int32(a >> b)
|
||||
output_int32(a >> uint32(b))
|
||||
output_float64(a / b)
|
||||
a += b
|
||||
output_int32(a)
|
||||
a -= b
|
||||
output_int32(a)
|
||||
a *= b
|
||||
output_int32(a)
|
||||
a //= b
|
||||
output_int32(a)
|
||||
a %= b
|
||||
output_int32(a)
|
||||
a |= b
|
||||
output_int32(a)
|
||||
a ^= b
|
||||
output_int32(a)
|
||||
a &= b
|
||||
output_int32(a)
|
||||
a <<= b
|
||||
output_int32(a)
|
||||
a >>= b
|
||||
output_int32(a)
|
||||
# fail because (a / b) is float
|
||||
# a /= b
|
||||
|
||||
def test_uint32():
|
||||
a = uint32(17)
|
||||
b = uint32(3)
|
||||
output_uint32(a + b)
|
||||
output_uint32(a - b)
|
||||
output_uint32(a * b)
|
||||
output_uint32(a // b)
|
||||
output_uint32(a % b)
|
||||
output_uint32(a | b)
|
||||
output_uint32(a ^ b)
|
||||
output_uint32(a & b)
|
||||
output_uint32(a << b)
|
||||
output_uint32(a << int32(b))
|
||||
output_uint32(a >> b)
|
||||
output_uint32(a >> int32(b))
|
||||
output_float64(a / b)
|
||||
a += b
|
||||
output_uint32(a)
|
||||
a -= b
|
||||
output_uint32(a)
|
||||
a *= b
|
||||
output_uint32(a)
|
||||
a //= b
|
||||
output_uint32(a)
|
||||
a %= b
|
||||
output_uint32(a)
|
||||
a |= b
|
||||
output_uint32(a)
|
||||
a ^= b
|
||||
output_uint32(a)
|
||||
a &= b
|
||||
output_uint32(a)
|
||||
a <<= b
|
||||
output_uint32(a)
|
||||
a >>= b
|
||||
output_uint32(a)
|
||||
|
||||
def test_int64():
|
||||
a = int64(17)
|
||||
b = int64(3)
|
||||
output_int64(a + b)
|
||||
output_int64(a - b)
|
||||
output_int64(a * b)
|
||||
output_int64(a // b)
|
||||
output_int64(a % b)
|
||||
output_int64(a | b)
|
||||
output_int64(a ^ b)
|
||||
output_int64(a & b)
|
||||
output_int64(a << int32(b))
|
||||
output_int64(a << uint32(b))
|
||||
output_int64(a >> int32(b))
|
||||
output_int64(a >> uint32(b))
|
||||
output_float64(a / b)
|
||||
a += b
|
||||
output_int64(a)
|
||||
a -= b
|
||||
output_int64(a)
|
||||
a *= b
|
||||
output_int64(a)
|
||||
a //= b
|
||||
output_int64(a)
|
||||
a %= b
|
||||
output_int64(a)
|
||||
a |= b
|
||||
output_int64(a)
|
||||
a ^= b
|
||||
output_int64(a)
|
||||
a &= b
|
||||
output_int64(a)
|
||||
a <<= int32(b)
|
||||
output_int64(a)
|
||||
a >>= int32(b)
|
||||
output_int64(a)
|
||||
|
||||
def test_uint64():
|
||||
a = uint64(17)
|
||||
b = uint64(3)
|
||||
output_uint64(a + b)
|
||||
output_uint64(a - b)
|
||||
output_uint64(a * b)
|
||||
output_uint64(a // b)
|
||||
output_uint64(a % b)
|
||||
output_uint64(a | b)
|
||||
output_uint64(a ^ b)
|
||||
output_uint64(a & b)
|
||||
output_uint64(a << uint32(b))
|
||||
output_uint64(a >> uint32(b))
|
||||
output_float64(a / b)
|
||||
a += b
|
||||
output_uint64(a)
|
||||
a -= b
|
||||
output_uint64(a)
|
||||
a *= b
|
||||
output_uint64(a)
|
||||
a //= b
|
||||
output_uint64(a)
|
||||
a %= b
|
||||
output_uint64(a)
|
||||
a |= b
|
||||
output_uint64(a)
|
||||
a ^= b
|
||||
output_uint64(a)
|
||||
a &= b
|
||||
output_uint64(a)
|
||||
a <<= uint32(b)
|
||||
output_uint64(a)
|
||||
a >>= uint32(b)
|
||||
output_uint64(a)
|
||||
|
||||
# FIXME Fix returning objects of non-primitive types; Currently this is disabled in the function checker
|
||||
# class A:
|
||||
# a: int32
|
||||
# def __init__(self, a: int32):
|
||||
# self.a = a
|
||||
#
|
||||
# def __add__(self, other: A) -> A:
|
||||
# output_int32(self.a + other.a)
|
||||
# return A(self.a + other.a)
|
||||
#
|
||||
# def __sub__(self, other: A) -> A:
|
||||
# output_int32(self.a - other.a)
|
||||
# return A(self.a - other.a)
|
||||
#
|
||||
# def test_A():
|
||||
# a = A(17)
|
||||
# b = A(3)
|
||||
#
|
||||
# c = a + b
|
||||
# # fail due to alloca in __add__ function
|
||||
# # output_int32(c.a)
|
||||
#
|
||||
# a += b
|
||||
# # fail due to alloca in __add__ function
|
||||
# # output_int32(a.a)
|
||||
#
|
||||
# a = A(17)
|
||||
# b = A(3)
|
||||
# d = a - b
|
||||
# # fail due to alloca in __add__ function
|
||||
# # output_int32(c.a)
|
||||
#
|
||||
# a -= b
|
||||
# # fail due to alloca in __add__ function
|
||||
# # output_int32(a.a)
|
||||
#
|
||||
# a = A(17)
|
||||
# b = A(3)
|
||||
# a.__add__(b)
|
||||
# a.__sub__(b)
|
||||
#
|
||||
#
|
||||
# class B:
|
||||
# a: int32
|
||||
# def __init__(self, a: int32):
|
||||
# self.a = a
|
||||
#
|
||||
# def __add__(self, other: B) -> B:
|
||||
# output_int32(self.a + other.a)
|
||||
# return B(self.a + other.a)
|
||||
#
|
||||
# def __sub__(self, other: B) -> B:
|
||||
# output_int32(self.a - other.a)
|
||||
# return B(self.a - other.a)
|
||||
#
|
||||
# def __iadd__(self, other: B) -> B:
|
||||
# output_int32(self.a + other.a + 24)
|
||||
# return B(self.a + other.a + 24)
|
||||
#
|
||||
# def __isub__(self, other: B) -> B:
|
||||
# output_int32(self.a - other.a - 24)
|
||||
# return B(self.a - other.a - 24)
|
||||
#
|
||||
# def test_B():
|
||||
# a = B(17)
|
||||
# b = B(3)
|
||||
#
|
||||
# c = a + b
|
||||
# # fail due to alloca in __add__ function
|
||||
# # output_int32(c.a)
|
||||
#
|
||||
# a += b
|
||||
# # fail due to alloca in __add__ function
|
||||
# # output_int32(a.a)
|
||||
#
|
||||
# a = B(17)
|
||||
# b = B(3)
|
||||
# d = a - b
|
||||
# # fail due to alloca in __add__ function
|
||||
# # output_int32(c.a)
|
||||
#
|
||||
# a -= b
|
||||
# # fail due to alloca in __add__ function
|
||||
# # output_int32(a.a)
|
||||
#
|
||||
# a = B(17)
|
||||
# b = B(3)
|
||||
# a.__add__(b)
|
||||
# a.__sub__(b)
|
|
@ -0,0 +1,36 @@
|
|||
from __future__ import annotations
|
||||
|
||||
@extern
|
||||
def output_int32(a: int32):
|
||||
...
|
||||
|
||||
class A:
|
||||
d: int32
|
||||
a: list[B]
|
||||
def __init__(self, b: list[B]):
|
||||
self.d = 123
|
||||
self.a = b
|
||||
|
||||
def f(self):
|
||||
output_int32(self.d)
|
||||
|
||||
class B:
|
||||
a: A
|
||||
def __init__(self, a: A):
|
||||
self.a = a
|
||||
|
||||
def ff(self):
|
||||
self.a.f()
|
||||
|
||||
class Demo:
|
||||
a: A
|
||||
def __init__(self, a: A):
|
||||
self.a = a
|
||||
|
||||
def run() -> int32:
|
||||
aa = A([])
|
||||
bb = B(aa)
|
||||
aa.a = [bb]
|
||||
d = Demo(aa)
|
||||
d.a.a[0].ff()
|
||||
return 0
|
|
@ -0,0 +1,15 @@
|
|||
@extern
|
||||
def output_bool(x: bool):
|
||||
...
|
||||
|
||||
@extern
|
||||
def dbg_stack_address(x: str) -> uint64:
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
a = dbg_stack_address("a")
|
||||
b = dbg_stack_address("b")
|
||||
|
||||
output_bool(a == b)
|
||||
|
||||
return 0
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue