Compare commits

..

16 Commits

Author SHA1 Message Date
pca006132 65d9b620fe Merge remote-tracking branch 'origin/master' 2021-01-27 16:00:47 +08:00
pca006132 075245f4c2 get list of unknowns 2021-01-27 16:00:06 +08:00
pca006132 0e9dea0e65 updated todo 2021-01-25 17:09:17 +08:00
pca006132 7d06c903e3 added get_expression_unknowns 2021-01-15 15:19:41 +08:00
pca006132 5fce6cf069 updated lifetime 2021-01-15 14:37:55 +08:00
pca006132 d466b7bc2b simplified lifetime 2021-01-11 11:12:37 +08:00
pca006132 b6e4e68587 started adding error types
* converted to string to avoid breaking interface, would be fixed later
* tests would not check error messages now, as messages are not
  finalized
2021-01-11 10:40:32 +08:00
pca006132 779288d685 added todo 2021-01-08 17:16:25 +08:00
pca006132 3ecf57a588 fixed bugs 2021-01-08 16:54:34 +08:00
pca006132 ebe1027ffa added resolve signature 2021-01-08 16:24:25 +08:00
pca006132 7c6349520c started getting signatures 2021-01-08 14:52:48 +08:00
pca006132 04f121403a moved type check code to submodule 2021-01-08 12:58:33 +08:00
pca006132 b51168a5ab added statement tests 2021-01-05 13:35:44 +08:00
pca006132 007843c1ef added aug assign for primitives 2021-01-05 13:21:39 +08:00
pca006132 ff41cdb000 implemented statement check 2021-01-05 12:17:45 +08:00
pca006132 e1efb47ad2 statement inference: assignment 2021-01-04 17:07:14 +08:00
226 changed files with 4667 additions and 62344 deletions

View File

@ -1,32 +0,0 @@
BasedOnStyle: LLVM
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -1
AlignEscapedNewlines: Left
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: Yes
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortFunctionsOnASingleLine: Inline
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: AfterColon
BreakInheritanceList: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ContinuationIndentWidth: 4
DerivePointerAlignment: false
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
MaxEmptyLinesToKeep: 1
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SortUsingDeclarations: true
SpaceAfterTemplateKeyword: false
SpacesBeforeTrailingComments: 2
TabWidth: 4
UseTab: Never

View File

@ -1 +0,0 @@
doc-valid-idents = ["CPython", "NumPy", ".."]

2
.gitignore vendored
View File

@ -1,4 +1,2 @@
__pycache__ __pycache__
/target /target
/nac3standalone/demo/linalg/target
nix/windows/msys2

View File

@ -1,24 +0,0 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_stages: [pre-commit]
repos:
- repo: local
hooks:
- id: nac3-cargo-fmt
name: nac3 cargo format
entry: nix
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo fmt on the codebase.
args: [develop, -c, cargo, fmt, --all]
- id: nac3-cargo-clippy
name: nac3 cargo clippy
entry: nix
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo clippy on the codebase.
args: [develop, -c, cargo, clippy, --tests]

1588
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,6 @@
[workspace] [workspace]
members = [ members = [
"nac3ld",
"nac3ast",
"nac3parser",
"nac3core", "nac3core",
"nac3core/nac3core_derive",
"nac3standalone", "nac3standalone",
"nac3artiq", "nac3embedded",
"runkernel",
] ]
resolver = "2"
[profile.release]
debug = true

View File

@ -1,62 +0,0 @@
<div align="center">
![icon](https://git.m-labs.hk/M-Labs/nac3/raw/branch/master/nac3.svg)
</div>
# NAC3
NAC3 is a major, backward-incompatible rewrite of the compiler for the [ARTIQ](https://m-labs.hk/artiq) physics experiment control and data acquisition system. It features greatly improved compilation speeds, a much better type system, and more predictable and transparent operation.
NAC3 has a modular design and its applicability reaches beyond ARTIQ. The ``nac3core`` module does not contain anything specific to ARTIQ, and can be used in any project that requires compiling Python to machine code.
**WARNING: NAC3 is currently experimental software and several important features are not implemented yet.**
## Packaging
NAC3 is packaged using the [Nix](https://nixos.org) Flakes system. Install Nix 2.8+ and enable flakes by adding ``experimental-features = nix-command flakes`` to ``nix.conf`` (e.g. ``~/.config/nix/nix.conf``).
## Try NAC3
### Linux
After setting up Nix as above, use ``nix shell git+https://github.com/m-labs/artiq.git?ref=nac3`` to get a shell with the NAC3 version of ARTIQ. See the ``examples`` directory in ARTIQ (``nac3`` Git branch) for some samples of NAC3 kernel code.
### Windows
Install [MSYS2](https://www.msys2.org/), and open "MSYS2 CLANG64". Edit ``/etc/pacman.conf`` to add:
```
[artiq]
SigLevel = Optional TrustAll
Server = https://msys2.m-labs.hk/artiq-nac3
```
Then run the following commands:
```
pacman -Syu
pacman -S mingw-w64-clang-x86_64-artiq
```
## For developers
This repository contains:
- ``nac3ast``: Python abstract syntax tree definition (based on RustPython).
- ``nac3parser``: Python parser (based on RustPython).
- ``nac3core``: Core compiler library, containing type-checking and code generation.
- ``nac3standalone``: Standalone compiler tool (core language only).
- ``nac3ld``: Minimalist RISC-V and ARM linker.
- ``nac3artiq``: Integration with ARTIQ and implementation of ARTIQ-specific extensions to the core language.
- ``runkernel``: Simple program that runs compiled ARTIQ kernels on the host and displays RTIO operations. Useful for testing without hardware.
Use ``nix develop`` in this repository to enter a development shell.
If you are using a different shell than bash you can use e.g. ``nix develop --command fish``.
Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``.
### Pre-Commit Hooks
You are strongly recommended to use the provided pre-commit hooks to automatically reformat files and check for non-optimal Rust practices using Clippy. Run `pre-commit install` to install the hook and `pre-commit` will automatically run `cargo fmt` and `cargo clippy` for you.
Several things to note:
- If `cargo fmt` or `cargo clippy` returns an error, the pre-commit hook will fail. You should fix all errors before trying to commit again.
- If `cargo fmt` reformats some files, the pre-commit hook will also fail. You should review the changes and, if satisfied, try to commit again.

View File

@ -1,27 +0,0 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1731319897,
"narHash": "sha256-PbABj4tnbWFMfBp6OcUK5iGy1QY+/Z96ZcLpooIbuEI=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "dc460ec76cbff0e66e269457d7b728432263166c",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs"
}
}
},
"root": "root",
"version": 7
}

212
flake.nix
View File

@ -1,212 +0,0 @@
{
description = "The third-generation ARTIQ compiler";
inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-unstable;
outputs = { self, nixpkgs }:
let
pkgs = import nixpkgs { system = "x86_64-linux"; };
pkgs32 = import nixpkgs { system = "i686-linux"; };
in rec {
packages.x86_64-linux = rec {
llvm-nac3 = pkgs.callPackage ./nix/llvm {};
llvm-tools-irrt = pkgs.runCommandNoCC "llvm-tools-irrt" {}
''
mkdir -p $out/bin
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
'';
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
name = "demo-linalg-stub";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
demo-linalg-stub32 = pkgs32.rustPlatform.buildRustPackage {
name = "demo-linalg-stub32";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
nac3artiq = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage rec {
name = "nac3artiq";
outputs = [ "out" "runkernel" "standalone" ];
src = self;
cargoLock = {
lockFile = ./Cargo.lock;
};
passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase =
''
echo "Checking nac3standalone demos..."
pushd nac3standalone/demo
patchShebangs .
export DEMO_LINALG_STUB=${demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${demo-linalg-stub32}/lib/liblinalg.a
./check_demos.sh -i686
popd
echo "Running Cargo tests..."
cargoCheckHook
'';
installPhase =
''
PYTHON_SITEPACKAGES=$out/${pkgs.python3Packages.python.sitePackages}
mkdir -p $PYTHON_SITEPACKAGES
cp target/x86_64-unknown-linux-gnu/release/libnac3artiq.so $PYTHON_SITEPACKAGES/nac3artiq.so
mkdir -p $runkernel/bin
cp target/x86_64-unknown-linux-gnu/release/runkernel $runkernel/bin
mkdir -p $standalone/bin
cp target/x86_64-unknown-linux-gnu/release/nac3standalone $standalone/bin
'';
}
);
python3-mimalloc = pkgs.python3 // rec {
withMimalloc = pkgs.python3.buildEnv.override({ makeWrapperArgs = [ "--set LD_PRELOAD ${pkgs.mimalloc}/lib/libmimalloc.so" ]; });
withPackages = f: let packages = f pkgs.python3.pkgs; in withMimalloc.override { extraLibs = packages; };
};
# LLVM PGO support
llvm-nac3-instrumented = pkgs.callPackage ./nix/llvm {
stdenv = pkgs.llvmPackages_14.stdenv;
extraCmakeFlags = [ "-DLLVM_BUILD_INSTRUMENTED=IR" ];
};
nac3artiq-instrumented = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage {
name = "nac3artiq-instrumented";
src = self;
inherit (nac3artiq) cargoLock;
nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-instrumented ];
buildInputs = [ pkgs.python3 llvm-nac3-instrumented ];
cargoBuildFlags = [ "--package" "nac3artiq" "--features" "init-llvm-profile" ];
doCheck = false;
configurePhase =
''
export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS="-C link-arg=-L${pkgs.llvmPackages_14.compiler-rt}/lib/linux -C link-arg=-lclang_rt.profile-x86_64"
'';
installPhase =
''
TARGET_DIR=$out/${pkgs.python3Packages.python.sitePackages}
mkdir -p $TARGET_DIR
cp target/x86_64-unknown-linux-gnu/release/libnac3artiq.so $TARGET_DIR/nac3artiq.so
'';
}
);
nac3artiq-profile = pkgs.stdenvNoCC.mkDerivation {
name = "nac3artiq-profile";
srcs = [
(pkgs.fetchFromGitHub {
owner = "m-labs";
repo = "sipyco";
rev = "094a6cd63ffa980ef63698920170e50dc9ba77fd";
sha256 = "sha256-PPnAyDedUQ7Og/Cby9x5OT9wMkNGTP8GS53V6N/dk4w=";
})
(pkgs.fetchFromGitHub {
owner = "m-labs";
repo = "artiq";
rev = "28c9de3e251daa89a8c9fd79d5ab64a3ec03bac6";
sha256 = "sha256-vAvpbHc5B+1wtG8zqN7j9dQE1ON+i22v+uqA+tw6Gak=";
})
];
buildInputs = [
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ]))
pkgs.llvmPackages_14.llvm.out
];
phases = [ "buildPhase" "installPhase" ];
buildPhase =
''
srcs=($srcs)
sipyco=''${srcs[0]}
artiq=''${srcs[1]}
export PYTHONPATH=$sipyco:$artiq
python -m artiq.frontend.artiq_ddb_template $artiq/artiq/examples/nac3devices/nac3devices.json > device_db.py
cp $artiq/artiq/examples/nac3devices/nac3devices.py .
python -m artiq.frontend.artiq_compile nac3devices.py
'';
installPhase =
''
mkdir $out
llvm-profdata merge -o $out/llvm.profdata /build/llvm/build/profiles/*
'';
};
llvm-nac3-pgo = pkgs.callPackage ./nix/llvm {
stdenv = pkgs.llvmPackages_14.stdenv;
extraCmakeFlags = [ "-DLLVM_PROFDATA_FILE=${nac3artiq-profile}/llvm.profdata" ];
};
nac3artiq-pgo = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage {
name = "nac3artiq-pgo";
src = self;
inherit (nac3artiq) cargoLock;
nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-pgo ];
buildInputs = [ pkgs.python3 llvm-nac3-pgo ];
cargoBuildFlags = [ "--package" "nac3artiq" ];
cargoTestFlags = [ "--package" "nac3ast" "--package" "nac3parser" "--package" "nac3core" "--package" "nac3artiq" ];
installPhase =
''
TARGET_DIR=$out/${pkgs.python3Packages.python.sitePackages}
mkdir -p $TARGET_DIR
cp target/x86_64-unknown-linux-gnu/release/libnac3artiq.so $TARGET_DIR/nac3artiq.so
'';
}
);
};
packages.x86_64-w64-mingw32 = import ./nix/windows { inherit pkgs; };
devShells.x86_64-linux.default = pkgs.mkShell {
name = "nac3-dev-shell";
buildInputs = with pkgs; [
# build dependencies
packages.x86_64-linux.llvm-nac3
(pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos
packages.x86_64-linux.llvm-tools-irrt
cargo
rustc
# runtime dependencies
lld_14 # for running kernels on the host
(packages.x86_64-linux.python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ]))
# development tools
cargo-insta
clippy
pre-commit
rustfmt
];
shellHook =
''
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
'';
};
devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2";
buildInputs = with pkgs; [
curl
pacman
fakeroot
packages.x86_64-w64-mingw32.wine-msys2
];
};
hydraJobs = {
inherit (packages.x86_64-linux) llvm-nac3 nac3artiq nac3artiq-pgo;
llvm-nac3-msys2 = packages.x86_64-w64-mingw32.llvm-nac3;
nac3artiq-msys2 = packages.x86_64-w64-mingw32.nac3artiq;
nac3artiq-msys2-pkg = packages.x86_64-w64-mingw32.nac3artiq-pkg;
};
};
nixConfig = {
extra-trusted-public-keys = "nixbld.m-labs.hk-1:5aSRVA5b320xbNvu30tqxVPXpld73bhtOeH6uAjRyHc=";
extra-substituters = "https://nixbld.m-labs.hk";
};
}

View File

@ -1,56 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
id="a"
width="128"
height="128"
viewBox="0 0 95.99999 95.99999"
version="1.1"
sodipodi:docname="nac3.svg"
inkscape:version="1.1.1 (3bf5ae0d25, 2021-09-20)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<defs
id="defs11" />
<sodipodi:namedview
id="namedview9"
pagecolor="#ffffff"
bordercolor="#666666"
borderopacity="1.0"
inkscape:pageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:document-units="mm"
showgrid="false"
units="px"
width="128px"
inkscape:zoom="5.9448568"
inkscape:cx="60.472441"
inkscape:cy="60.556547"
inkscape:window-width="2560"
inkscape:window-height="1371"
inkscape:window-x="0"
inkscape:window-y="32"
inkscape:window-maximized="1"
inkscape:current-layer="a" />
<rect
x="40.072601"
y="-26.776209"
width="55.668747"
height="55.668747"
transform="matrix(0.71803815,0.69600374,-0.71803815,0.69600374,0,0)"
style="fill:#be211e;stroke:#000000;stroke-width:4.37375px;stroke-linecap:round;stroke-linejoin:round"
id="rect2" />
<line
x1="38.00692"
y1="63.457153"
x2="57.993061"
y2="63.457153"
style="fill:none;stroke:#000000;stroke-width:4.37269px;stroke-linecap:round;stroke-linejoin:round"
id="line4" />
<path
d="m 48.007301,57.843329 c -1.943097,0 -3.877522,-0.41727 -5.686157,-1.246007 -3.218257,-1.474616 -5.650382,-4.075418 -6.849639,-7.323671 -2.065624,-5.588921 -1.192751,-10.226647 2.575258,-13.827 0.611554,-0.584909 1.518048,-0.773041 2.323689,-0.488206 0.80673,0.286405 1.369495,0.998486 1.447563,1.827234 0.237469,2.549302 2.439719,5.917376 4.28414,6.55273 0.396859,0.13506 0.820953,-0.05859 1.097084,-0.35222 0.339254,-0.360754 0.451065,-0.961893 -1.013597,-3.191372 -2.089851,-3.181137 -4.638728,-8.754903 -0.262407,-15.069853 0.494457,-0.713491 1.384673,-1.068907 2.256469,-0.909156 0.871795,0.161332 1.583757,0.806404 1.752251,1.651189 0.716448,3.591862 2.962357,6.151755 5.199306,8.023138 1.935503,1.61861 4.344688,3.867387 5.435687,7.096643 2.283183,6.758017 -1.202511,14.114988 -8.060822,16.494025 -1.467083,0.509226 -2.98513,0.762536 -4.498836,0.762536 z M 39.358865,40.002192 c -0.304711,0.696206 -0.541636,2.080524 -0.56865,2.237454 -0.330316,1.918771 0.168305,3.803963 0.846157,5.539951 0.856828,2.19436 2.437543,3.942467 4.583411,4.925713 2.143691,0.981675 4.554131,1.097816 6.789992,0.322666 4.571485,-1.586549 6.977584,-6.532238 5.363036,-11.02597 v -5.27e-4 C 55.455481,39.447968 54.023463,38.162043 52.221335,36.65432 50.876945,35.529534 49.409662,33.987726 48.417983,32.135555 48.01343,31.37996 47.79547,30.34303 47.76669,29.413263 c -0.187481,0.669514 -0.212441,2.325923 -0.150396,2.93691 0.179209,1.764456 1.333476,3.644546 2.340611,5.171243 1.311568,1.988179 2.72058,6.037272 0.459681,8.367985 -1.54192,1.58953 -4.038511,2.052034 -5.839973,1.38492 -2.398314,-0.888147 -3.942744,-2.690627 -4.941118,-4.768029 -0.121194,-0.25217 -0.532464,-1.174187 -0.276619,-2.5041 z"
id="path6"
style="stroke-width:1.09317" />
</svg>

Before

Width:  |  Height:  |  Size: 3.3 KiB

View File

@ -1,21 +0,0 @@
[package]
name = "nac3artiq"
version = "0.1.0"
authors = ["M-Labs"]
edition = "2021"
[lib]
name = "nac3artiq"
crate-type = ["cdylib"]
[dependencies]
itertools = "0.13"
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
parking_lot = "0.12"
tempfile = "3.13"
nac3core = { path = "../nac3core" }
nac3ld = { path = "../nac3ld" }
[features]
init-llvm-profile = []
no-escape-analysis = ["nac3core/no-escape-analysis"]

View File

@ -1,26 +0,0 @@
from min_artiq import *
@nac3
class Demo:
core: KernelInvariant[Core]
led0: KernelInvariant[TTLOut]
led1: KernelInvariant[TTLOut]
def __init__(self):
self.core = Core()
self.led0 = TTLOut(self.core, 18)
self.led1 = TTLOut(self.core, 19)
@kernel
def run(self):
self.core.reset()
while True:
with parallel:
self.led0.pulse(100.*ms)
self.led1.pulse(100.*ms)
self.core.delay(100.*ms)
if __name__ == "__main__":
Demo().run()

View File

@ -1,16 +0,0 @@
# python demo.py
# artiq_run module.elf
device_db = {
"core": {
"type": "local",
"module": "artiq.coredevice.core",
"class": "Core",
"arguments": {
"host": "kc705",
"ref_period": 1e-9,
"ref_multiplier": 8,
"target": "rv32g"
}
},
}

View File

@ -1,66 +0,0 @@
class EmbeddingMap:
def __init__(self):
self.object_inverse_map = {}
self.object_map = {}
self.string_map = {}
self.string_reverse_map = {}
self.function_map = {}
self.attributes_writeback = []
# preallocate exception names
self.preallocate_runtime_exception_names(["RuntimeError",
"RTIOUnderflow",
"RTIOOverflow",
"RTIODestinationUnreachable",
"DMAError",
"I2CError",
"CacheError",
"SPIError",
"0:ZeroDivisionError",
"0:IndexError",
"0:ValueError",
"0:RuntimeError",
"0:AssertionError",
"0:KeyError",
"0:NotImplementedError",
"0:OverflowError",
"0:IOError",
"0:UnwrapNoneError"])
def preallocate_runtime_exception_names(self, names):
for i, name in enumerate(names):
if ":" not in name:
name = "0:artiq.coredevice.exceptions." + name
exn_id = self.store_str(name)
assert exn_id == i
def store_function(self, key, fun):
self.function_map[key] = fun
return key
def store_object(self, obj):
obj_id = id(obj)
if obj_id in self.object_inverse_map:
return self.object_inverse_map[obj_id]
key = len(self.object_map) + 1
self.object_map[key] = obj
self.object_inverse_map[obj_id] = key
return key
def store_str(self, s):
if s in self.string_reverse_map:
return self.string_reverse_map[s]
key = len(self.string_map)
self.string_map[key] = s
self.string_reverse_map[s] = key
return key
def retrieve_function(self, key):
return self.function_map[key]
def retrieve_object(self, key):
return self.object_map[key]
def retrieve_str(self, key):
return self.string_map[key]

View File

@ -1,24 +0,0 @@
from min_artiq import *
from numpy import int32
@nac3
class EmptyList:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@rpc
def get_empty(self) -> list[int32]:
return []
@kernel
def run(self):
a: list[int32] = self.get_empty()
if a != []:
raise ValueError
if __name__ == "__main__":
EmptyList().run()

View File

@ -1,300 +0,0 @@
from inspect import getfullargspec
from functools import wraps
from types import SimpleNamespace
from numpy import int32, int64
from typing import Generic, TypeVar
from math import floor, ceil
import nac3artiq
from embedding_map import EmbeddingMap
__all__ = [
"Kernel", "KernelInvariant", "virtual", "ConstGeneric",
"Option", "Some", "none", "UnwrapNoneError",
"round64", "floor64", "ceil64",
"extern", "kernel", "portable", "nac3",
"rpc", "ms", "us", "ns",
"print_int32", "print_int64",
"Core", "TTLOut",
"parallel", "sequential"
]
T = TypeVar('T')
class Kernel(Generic[T]):
pass
class KernelInvariant(Generic[T]):
pass
# The virtual class must exist before nac3artiq.NAC3 is created.
class virtual(Generic[T]):
pass
class Option(Generic[T]):
_nac3_option: T
def __init__(self, v: T):
self._nac3_option = v
def is_none(self):
return self._nac3_option is None
def is_some(self):
return not self.is_none()
def unwrap(self):
if self.is_none():
raise UnwrapNoneError()
return self._nac3_option
def __repr__(self) -> str:
if self.is_none():
return "none"
else:
return "Some({})".format(repr(self._nac3_option))
def __str__(self) -> str:
if self.is_none():
return "none"
else:
return "Some({})".format(str(self._nac3_option))
def Some(v: T) -> Option[T]:
return Option(v)
none = Option(None)
class _ConstGenericMarker:
pass
def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericMarker, constraint)
def round64(x):
return round(x)
def floor64(x):
return floor(x)
def ceil64(x):
return ceil(x)
import device_db
core_arguments = device_db.device_db["core"]["arguments"]
artiq_builtins = {
"none": none,
"virtual": virtual,
"_ConstGenericMarker": _ConstGenericMarker,
"Option": Option,
}
compiler = nac3artiq.NAC3(core_arguments["target"], artiq_builtins)
allow_registration = True
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
registered_functions = set()
registered_classes = set()
def register_function(fun):
assert allow_registration
registered_functions.add(fun)
def register_class(cls):
assert allow_registration
registered_classes.add(cls)
def extern(function):
"""Decorates a function declaration defined by the core device runtime."""
register_function(function)
return function
def rpc(arg=None, flags={}):
"""Decorates a function or method to be executed on the host interpreter."""
if arg is None:
def inner_decorator(function):
return rpc(function, flags)
return inner_decorator
register_function(arg)
return arg
def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device."""
register_function(function_or_method)
argspec = getfullargspec(function_or_method)
if argspec.args and argspec.args[0] == "self":
@wraps(function_or_method)
def run_on_core(self, *args, **kwargs):
fake_method = SimpleNamespace(__self__=self, __name__=function_or_method.__name__)
self.core.run(fake_method, *args, **kwargs)
else:
@wraps(function_or_method)
def run_on_core(*args, **kwargs):
raise RuntimeError("Kernel functions need explicit core.run()")
return run_on_core
def portable(function):
"""Decorates a function or method to be executed on the same device (host/core device) as the caller."""
register_function(function)
return function
def nac3(cls):
"""
Decorates a class to be analyzed by NAC3.
All classes containing kernels or portable methods must use this decorator.
"""
register_class(cls)
return cls
ms = 1e-3
us = 1e-6
ns = 1e-9
@extern
def rtio_init():
raise NotImplementedError("syscall not simulated")
@extern
def rtio_get_counter() -> int64:
raise NotImplementedError("syscall not simulated")
@extern
def rtio_output(target: int32, data: int32):
raise NotImplementedError("syscall not simulated")
@extern
def rtio_input_timestamp(timeout_mu: int64, channel: int32) -> int64:
raise NotImplementedError("syscall not simulated")
@extern
def rtio_input_data(channel: int32) -> int32:
raise NotImplementedError("syscall not simulated")
# These is not part of ARTIQ and only available in runkernel. Defined here for convenience.
@extern
def print_int32(x: int32):
raise NotImplementedError("syscall not simulated")
@extern
def print_int64(x: int64):
raise NotImplementedError("syscall not simulated")
@nac3
class Core:
ref_period: KernelInvariant[float]
def __init__(self):
self.ref_period = core_arguments["ref_period"]
def run(self, method, *args, **kwargs):
global allow_registration
embedding = EmbeddingMap()
if allow_registration:
compiler.analyze(registered_functions, registered_classes, set())
allow_registration = False
if hasattr(method, "__self__"):
obj = method.__self__
name = method.__name__
else:
obj = method
name = ""
compiler.compile_method_to_file(obj, name, args, "module.elf", embedding)
@kernel
def reset(self):
rtio_init()
at_mu(rtio_get_counter() + int64(125000))
@kernel
def break_realtime(self):
min_now = rtio_get_counter() + int64(125000)
if now_mu() < min_now:
at_mu(min_now)
@portable
def seconds_to_mu(self, seconds: float) -> int64:
return int64(round(seconds/self.ref_period))
@portable
def mu_to_seconds(self, mu: int64) -> float:
return float(mu)*self.ref_period
@kernel
def delay(self, dt: float):
delay_mu(self.seconds_to_mu(dt))
@nac3
class TTLOut:
core: KernelInvariant[Core]
channel: KernelInvariant[int32]
target_o: KernelInvariant[int32]
def __init__(self, core: Core, channel: int32):
self.core = core
self.channel = channel
self.target_o = channel << 8
@kernel
def output(self):
pass
@kernel
def set_o(self, o: bool):
rtio_output(self.target_o, 1 if o else 0)
@kernel
def on(self):
self.set_o(True)
@kernel
def off(self):
self.set_o(False)
@kernel
def pulse_mu(self, duration: int64):
self.on()
delay_mu(duration)
self.off()
@kernel
def pulse(self, duration: float):
self.on()
self.core.delay(duration)
self.off()
@nac3
class KernelContextManager:
@kernel
def __enter__(self):
pass
@kernel
def __exit__(self):
pass
@nac3
class UnwrapNoneError(Exception):
"""raised when unwrapping a none value"""
artiq_builtin = True
parallel = KernelContextManager()
sequential = KernelContextManager()

View File

@ -1 +0,0 @@
../../target/release/libnac3artiq.so

View File

@ -1,26 +0,0 @@
from min_artiq import *
from numpy import ndarray, zeros as np_zeros
@nac3
class StrFail:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def hello(self, arg: str):
pass
@kernel
def consume_ndarray(self, arg: ndarray[str, 1]):
pass
def run(self):
self.hello("world")
self.consume_ndarray(np_zeros([10], dtype=str))
if __name__ == "__main__":
StrFail().run()

View File

@ -1,24 +0,0 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
core: KernelInvariant[Core]
attr1: KernelInvariant[str]
attr2: KernelInvariant[int32]
def __init__(self):
self.core = Core()
self.attr2 = 32
self.attr1 = "SAMPLE"
@kernel
def run(self):
print_int32(self.attr2)
self.attr1
if __name__ == "__main__":
Demo().run()

View File

@ -1,40 +0,0 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
attr1: KernelInvariant[int32] = 2
attr2: int32 = 4
attr3: Kernel[int32]
@kernel
def __init__(self):
self.attr3 = 8
@nac3
class NAC3Devices:
core: KernelInvariant[Core]
attr4: KernelInvariant[int32] = 16
def __init__(self):
self.core = Core()
@kernel
def run(self):
Demo.attr1 # Supported
# Demo.attr2 # Field not accessible on Kernel
# Demo.attr3 # Only attributes can be accessed in this way
# Demo.attr1 = 2 # Attributes are immutable
self.attr4 # Attributes can be accessed within class
obj = Demo()
obj.attr1 # Attributes can be accessed by class objects
NAC3Devices.attr4 # Attributes accessible for classes without __init__
if __name__ == "__main__":
NAC3Devices().run()

File diff suppressed because it is too large Load Diff

View File

@ -1,56 +0,0 @@
/* Force ld to make the ELF header as loadable. */
PHDRS
{
headers PT_LOAD FILEHDR PHDRS ;
text PT_LOAD ;
data PT_LOAD ;
dynamic PT_DYNAMIC ;
eh_frame PT_GNU_EH_FRAME ;
}
SECTIONS
{
/* Push back .text section enough so that ld.lld not complain */
. = SIZEOF_HEADERS;
.text :
{
*(.text .text.*)
} : text
.rodata :
{
*(.rodata .rodata.*)
}
.eh_frame :
{
KEEP(*(.eh_frame))
} : text
.eh_frame_hdr :
{
KEEP(*(.eh_frame_hdr))
} : text : eh_frame
.data :
{
*(.data)
} : data
.dynamic :
{
*(.dynamic)
} : data : dynamic
.bss (NOLOAD) : ALIGN(4)
{
__bss_start = .;
*(.sbss .sbss.* .bss .bss.*);
. = ALIGN(4);
_end = .;
}
. = ALIGN(0x1000);
_sstack_guard = .;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,324 +0,0 @@
use itertools::Either;
use nac3core::{
codegen::CodeGenContext,
inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
},
};
/// Functions for manipulating the timeline.
pub trait TimeFns {
/// Emits LLVM IR for `now_mu`.
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>;
/// Emits LLVM IR for `at_mu`.
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>);
/// Emits LLVM IR for `delay_mu`.
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>);
}
pub struct NowPinningTimeFns64 {}
// For FPGA design reasons, on VexRiscv with 64-bit data bus, the "now" CSR is split into two 32-bit
// values that are each padded to 64-bits.
impl TimeFns for NowPinningTimeFns64 {
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}
.unwrap();
let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi =
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap()
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
let time = t.into_int_value();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type,
"",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}
.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}
.unwrap();
let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let dt = dt.into_int_value();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi =
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder
.build_right_shift(time, i64_type.const_int(32, false), false, "")
.unwrap(),
i32_type,
"time.hi",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
}
pub static NOW_PINNING_TIME_FNS_64: NowPinningTimeFns64 = NowPinningTimeFns64 {};
pub struct NowPinningTimeFns {}
impl TimeFns for NowPinningTimeFns {
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let i64_type = ctx.ctx.i64_type();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "now")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i64_32 = i64_type.const_int(32, false);
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
ctx.builder.build_or(now_lo, now_hi, "now_mu").map(Into::into).unwrap()
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
let time = t.into_int_value();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(),
i32_type,
"time.hi",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc").unwrap();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}
.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let dt = dt.into_int_value();
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type,
"now_trunc",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}
.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
}
pub static NOW_PINNING_TIME_FNS: NowPinningTimeFns = NowPinningTimeFns {};
pub struct ExternTimeFns {}
impl TimeFns for ExternTimeFns {
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| {
ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)
});
ctx.builder
.build_call(now_mu, &[], "now_mu")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap()
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
let at_mu = ctx.module.get_function("at_mu").unwrap_or_else(|| {
ctx.module.add_function(
"at_mu",
ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false),
None,
)
});
ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap();
}
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| {
ctx.module.add_function(
"delay_mu",
ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false),
None,
)
});
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu").unwrap();
}
}
pub static EXTERN_TIME_FNS: ExternTimeFns = ExternTimeFns {};

View File

@ -1,15 +0,0 @@
[package]
name = "nac3ast"
version = "0.1.0"
authors = ["RustPython Team", "M-Labs"]
edition = "2021"
[features]
default = ["constant-optimization", "fold"]
constant-optimization = ["fold"]
fold = []
[dependencies]
parking_lot = "0.12"
string-interner = "0.17"
fxhash = "0.2"

View File

@ -1,127 +0,0 @@
-- ASDL's 4 builtin types are:
-- identifier, int, string, constant
module Python
{
mod = Module(stmt* body, type_ignore* type_ignores)
| Interactive(stmt* body)
| Expression(expr body)
| FunctionType(expr* argtypes, expr returns)
stmt = FunctionDef(identifier name, arguments args,
stmt* body, expr* decorator_list, expr? returns,
string? type_comment, identifier* config_comment)
| AsyncFunctionDef(identifier name, arguments args,
stmt* body, expr* decorator_list, expr? returns,
string? type_comment, identifier* config_comment)
| ClassDef(identifier name,
expr* bases,
keyword* keywords,
stmt* body,
expr* decorator_list, identifier* config_comment)
| Return(expr? value, identifier* config_comment)
| Delete(expr* targets, identifier* config_comment)
| Assign(expr* targets, expr value, string? type_comment, identifier* config_comment)
| AugAssign(expr target, operator op, expr value, identifier* config_comment)
-- 'simple' indicates that we annotate simple name without parens
| AnnAssign(expr target, expr annotation, expr? value, bool simple, identifier* config_comment)
-- use 'orelse' because else is a keyword in target languages
| For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment, identifier* config_comment)
| AsyncFor(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment, identifier* config_comment)
| While(expr test, stmt* body, stmt* orelse, identifier* config_comment)
| If(expr test, stmt* body, stmt* orelse, identifier* config_comment)
| With(withitem* items, stmt* body, string? type_comment, identifier* config_comment)
| AsyncWith(withitem* items, stmt* body, string? type_comment, identifier* config_comment)
| Raise(expr? exc, expr? cause, identifier* config_comment)
| Try(stmt* body, excepthandler* handlers, stmt* orelse, stmt* finalbody, identifier* config_comment)
| Assert(expr test, expr? msg, identifier* config_comment)
| Import(alias* names, identifier* config_comment)
| ImportFrom(identifier? module, alias* names, int level, identifier* config_comment)
| Global(identifier* names, identifier* config_comment)
| Nonlocal(identifier* names, identifier* config_comment)
| Expr(expr value, identifier* config_comment)
| Pass(identifier* config_comment)
| Break(identifier* config_comment)
| Continue(identifier* config_comment)
-- col_offset is the byte offset in the utf8 string the parser uses
attributes (int lineno, int col_offset, int? end_lineno, int? end_col_offset)
-- BoolOp() can use left & right?
expr = BoolOp(boolop op, expr* values)
| NamedExpr(expr target, expr value)
| BinOp(expr left, operator op, expr right)
| UnaryOp(unaryop op, expr operand)
| Lambda(arguments args, expr body)
| IfExp(expr test, expr body, expr orelse)
| Dict(expr?* keys, expr* values)
| Set(expr* elts)
| ListComp(expr elt, comprehension* generators)
| SetComp(expr elt, comprehension* generators)
| DictComp(expr key, expr value, comprehension* generators)
| GeneratorExp(expr elt, comprehension* generators)
-- the grammar constrains where yield expressions can occur
| Await(expr value)
| Yield(expr? value)
| YieldFrom(expr value)
-- need sequences for compare to distinguish between
-- x < 4 < 3 and (x < 4) < 3
| Compare(expr left, cmpop* ops, expr* comparators)
| Call(expr func, expr* args, keyword* keywords)
| FormattedValue(expr value, conversion_flag? conversion, expr? format_spec)
| JoinedStr(expr* values)
| Constant(constant value, string? kind)
-- the following expression can appear in assignment context
| Attribute(expr value, identifier attr, expr_context ctx)
| Subscript(expr value, expr slice, expr_context ctx)
| Starred(expr value, expr_context ctx)
| Name(identifier id, expr_context ctx)
| List(expr* elts, expr_context ctx)
| Tuple(expr* elts, expr_context ctx)
-- can appear only in Subscript
| Slice(expr? lower, expr? upper, expr? step)
-- col_offset is the byte offset in the utf8 string the parser uses
attributes (int lineno, int col_offset, int? end_lineno, int? end_col_offset)
expr_context = Load | Store | Del
boolop = And | Or
operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift
| RShift | BitOr | BitXor | BitAnd | FloorDiv
unaryop = Invert | Not | UAdd | USub
cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
comprehension = (expr target, expr iter, expr* ifs, bool is_async)
excepthandler = ExceptHandler(expr? type, identifier? name, stmt* body)
attributes (int lineno, int col_offset, int? end_lineno, int? end_col_offset)
arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs,
expr?* kw_defaults, arg? kwarg, expr* defaults)
arg = (identifier arg, expr? annotation, string? type_comment)
attributes (int lineno, int col_offset, int? end_lineno, int? end_col_offset)
-- keyword arguments supplied to call (NULL identifier for **kwargs)
keyword = (identifier? arg, expr value)
attributes (int lineno, int col_offset, int? end_lineno, int? end_col_offset)
-- import name with optional 'as' alias.
alias = (identifier name, identifier? asname)
withitem = (expr context_expr, expr? optional_vars)
type_ignore = TypeIgnore(int lineno, string tag)
}

View File

@ -1,385 +0,0 @@
#-------------------------------------------------------------------------------
# Parser for ASDL [1] definition files. Reads in an ASDL description and parses
# it into an AST that describes it.
#
# The EBNF we're parsing here: Figure 1 of the paper [1]. Extended to support
# modules and attributes after a product. Words starting with Capital letters
# are terminals. Literal tokens are in "double quotes". Others are
# non-terminals. Id is either TokenId or ConstructorId.
#
# module ::= "module" Id "{" [definitions] "}"
# definitions ::= { TypeId "=" type }
# type ::= product | sum
# product ::= fields ["attributes" fields]
# fields ::= "(" { field, "," } field ")"
# field ::= TypeId ["?" | "*"] [Id]
# sum ::= constructor { "|" constructor } ["attributes" fields]
# constructor ::= ConstructorId [fields]
#
# [1] "The Zephyr Abstract Syntax Description Language" by Wang, et. al. See
# http://asdl.sourceforge.net/
#-------------------------------------------------------------------------------
from collections import namedtuple
import re
__all__ = [
'builtin_types', 'parse', 'AST', 'Module', 'Type', 'Constructor',
'Field', 'Sum', 'Product', 'VisitorBase', 'Check', 'check']
# The following classes define nodes into which the ASDL description is parsed.
# Note: this is a "meta-AST". ASDL files (such as Python.asdl) describe the AST
# structure used by a programming language. But ASDL files themselves need to be
# parsed. This module parses ASDL files and uses a simple AST to represent them.
# See the EBNF at the top of the file to understand the logical connection
# between the various node types.
builtin_types = {'identifier', 'string', 'int', 'constant', 'bool', 'conversion_flag'}
class AST:
def __repr__(self):
raise NotImplementedError
class Module(AST):
def __init__(self, name, dfns):
self.name = name
self.dfns = dfns
self.types = {type.name: type.value for type in dfns}
def __repr__(self):
return 'Module({0.name}, {0.dfns})'.format(self)
class Type(AST):
def __init__(self, name, value):
self.name = name
self.value = value
def __repr__(self):
return 'Type({0.name}, {0.value})'.format(self)
class Constructor(AST):
def __init__(self, name, fields=None):
self.name = name
self.fields = fields or []
def __repr__(self):
return 'Constructor({0.name}, {0.fields})'.format(self)
class Field(AST):
def __init__(self, type, name=None, seq=False, opt=False):
self.type = type
self.name = name
self.seq = seq
self.opt = opt
def __str__(self):
if self.seq:
extra = "*"
elif self.opt:
extra = "?"
else:
extra = ""
return "{}{} {}".format(self.type, extra, self.name)
def __repr__(self):
if self.seq:
extra = ", seq=True"
elif self.opt:
extra = ", opt=True"
else:
extra = ""
if self.name is None:
return 'Field({0.type}{1})'.format(self, extra)
else:
return 'Field({0.type}, {0.name}{1})'.format(self, extra)
class Sum(AST):
def __init__(self, types, attributes=None):
self.types = types
self.attributes = attributes or []
def __repr__(self):
if self.attributes:
return 'Sum({0.types}, {0.attributes})'.format(self)
else:
return 'Sum({0.types})'.format(self)
class Product(AST):
def __init__(self, fields, attributes=None):
self.fields = fields
self.attributes = attributes or []
def __repr__(self):
if self.attributes:
return 'Product({0.fields}, {0.attributes})'.format(self)
else:
return 'Product({0.fields})'.format(self)
# A generic visitor for the meta-AST that describes ASDL. This can be used by
# emitters. Note that this visitor does not provide a generic visit method, so a
# subclass needs to define visit methods from visitModule to as deep as the
# interesting node.
# We also define a Check visitor that makes sure the parsed ASDL is well-formed.
class VisitorBase(object):
"""Generic tree visitor for ASTs."""
def __init__(self):
self.cache = {}
def visit(self, obj, *args):
klass = obj.__class__
meth = self.cache.get(klass)
if meth is None:
methname = "visit" + klass.__name__
meth = getattr(self, methname, None)
self.cache[klass] = meth
if meth:
try:
meth(obj, *args)
except Exception as e:
print("Error visiting %r: %s" % (obj, e))
raise
class Check(VisitorBase):
"""A visitor that checks a parsed ASDL tree for correctness.
Errors are printed and accumulated.
"""
def __init__(self):
super(Check, self).__init__()
self.cons = {}
self.errors = 0
self.types = {}
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type):
self.visit(type.value, str(type.name))
def visitSum(self, sum, name):
for t in sum.types:
self.visit(t, name)
def visitConstructor(self, cons, name):
key = str(cons.name)
conflict = self.cons.get(key)
if conflict is None:
self.cons[key] = name
else:
print('Redefinition of constructor {}'.format(key))
print('Defined in {} and {}'.format(conflict, name))
self.errors += 1
for f in cons.fields:
self.visit(f, key)
def visitField(self, field, name):
key = str(field.type)
l = self.types.setdefault(key, [])
l.append(name)
def visitProduct(self, prod, name):
for f in prod.fields:
self.visit(f, name)
def check(mod):
"""Check the parsed ASDL tree for correctness.
Return True if success. For failure, the errors are printed out and False
is returned.
"""
v = Check()
v.visit(mod)
for t in v.types:
if t not in mod.types and not t in builtin_types:
v.errors += 1
uses = ", ".join(v.types[t])
print('Undefined type {}, used in {}'.format(t, uses))
return not v.errors
# The ASDL parser itself comes next. The only interesting external interface
# here is the top-level parse function.
def parse(filename):
"""Parse ASDL from the given file and return a Module node describing it."""
with open(filename) as f:
parser = ASDLParser()
return parser.parse(f.read())
# Types for describing tokens in an ASDL specification.
class TokenKind:
"""TokenKind is provides a scope for enumerated token kinds."""
(ConstructorId, TypeId, Equals, Comma, Question, Pipe, Asterisk,
LParen, RParen, LBrace, RBrace) = range(11)
operator_table = {
'=': Equals, ',': Comma, '?': Question, '|': Pipe, '(': LParen,
')': RParen, '*': Asterisk, '{': LBrace, '}': RBrace}
Token = namedtuple('Token', 'kind value lineno')
class ASDLSyntaxError(Exception):
def __init__(self, msg, lineno=None):
self.msg = msg
self.lineno = lineno or '<unknown>'
def __str__(self):
return 'Syntax error on line {0.lineno}: {0.msg}'.format(self)
def tokenize_asdl(buf):
"""Tokenize the given buffer. Yield Token objects."""
for lineno, line in enumerate(buf.splitlines(), 1):
for m in re.finditer(r'\s*(\w+|--.*|.)', line.strip()):
c = m.group(1)
if c[0].isalpha():
# Some kind of identifier
if c[0].isupper():
yield Token(TokenKind.ConstructorId, c, lineno)
else:
yield Token(TokenKind.TypeId, c, lineno)
elif c[:2] == '--':
# Comment
break
else:
# Operators
try:
op_kind = TokenKind.operator_table[c]
except KeyError:
raise ASDLSyntaxError('Invalid operator %s' % c, lineno)
yield Token(op_kind, c, lineno)
class ASDLParser:
"""Parser for ASDL files.
Create, then call the parse method on a buffer containing ASDL.
This is a simple recursive descent parser that uses tokenize_asdl for the
lexing.
"""
def __init__(self):
self._tokenizer = None
self.cur_token = None
def parse(self, buf):
"""Parse the ASDL in the buffer and return an AST with a Module root.
"""
self._tokenizer = tokenize_asdl(buf)
self._advance()
return self._parse_module()
def _parse_module(self):
if self._at_keyword('module'):
self._advance()
else:
raise ASDLSyntaxError(
'Expected "module" (found {})'.format(self.cur_token.value),
self.cur_token.lineno)
name = self._match(self._id_kinds)
self._match(TokenKind.LBrace)
defs = self._parse_definitions()
self._match(TokenKind.RBrace)
return Module(name, defs)
def _parse_definitions(self):
defs = []
while self.cur_token.kind == TokenKind.TypeId:
typename = self._advance()
self._match(TokenKind.Equals)
type = self._parse_type()
defs.append(Type(typename, type))
return defs
def _parse_type(self):
if self.cur_token.kind == TokenKind.LParen:
# If we see a (, it's a product
return self._parse_product()
else:
# Otherwise it's a sum. Look for ConstructorId
sumlist = [Constructor(self._match(TokenKind.ConstructorId),
self._parse_optional_fields())]
while self.cur_token.kind == TokenKind.Pipe:
# More constructors
self._advance()
sumlist.append(Constructor(
self._match(TokenKind.ConstructorId),
self._parse_optional_fields()))
return Sum(sumlist, self._parse_optional_attributes())
def _parse_product(self):
return Product(self._parse_fields(), self._parse_optional_attributes())
def _parse_fields(self):
fields = []
self._match(TokenKind.LParen)
while self.cur_token.kind == TokenKind.TypeId:
typename = self._advance()
is_seq, is_opt = self._parse_optional_field_quantifier()
id = (self._advance() if self.cur_token.kind in self._id_kinds
else None)
fields.append(Field(typename, id, seq=is_seq, opt=is_opt))
if self.cur_token.kind == TokenKind.RParen:
break
elif self.cur_token.kind == TokenKind.Comma:
self._advance()
self._match(TokenKind.RParen)
return fields
def _parse_optional_fields(self):
if self.cur_token.kind == TokenKind.LParen:
return self._parse_fields()
else:
return None
def _parse_optional_attributes(self):
if self._at_keyword('attributes'):
self._advance()
return self._parse_fields()
else:
return None
def _parse_optional_field_quantifier(self):
is_seq, is_opt = False, False
if self.cur_token.kind == TokenKind.Question:
is_opt = True
self._advance()
if self.cur_token.kind == TokenKind.Asterisk:
is_seq = True
self._advance()
return is_seq, is_opt
def _advance(self):
""" Return the value of the current token and read the next one into
self.cur_token.
"""
cur_val = None if self.cur_token is None else self.cur_token.value
try:
self.cur_token = next(self._tokenizer)
except StopIteration:
self.cur_token = None
return cur_val
_id_kinds = (TokenKind.ConstructorId, TokenKind.TypeId)
def _match(self, kind):
"""The 'match' primitive of RD parsers.
* Verifies that the current token is of the given kind (kind can
be a tuple, in which the kind must match one of its members).
* Returns the value of the current token
* Reads in the next token
"""
if (isinstance(kind, tuple) and self.cur_token.kind in kind or
self.cur_token.kind == kind
):
value = self.cur_token.value
self._advance()
return value
else:
raise ASDLSyntaxError(
'Unmatched {} (found {})'.format(kind, self.cur_token.kind),
self.cur_token.lineno)
def _at_keyword(self, keyword):
return (self.cur_token.kind == TokenKind.TypeId and
self.cur_token.value == keyword)

View File

@ -1,609 +0,0 @@
#! /usr/bin/env python
"""Generate Rust code from an ASDL description."""
import os
import sys
import textwrap
import json
from argparse import ArgumentParser
from pathlib import Path
import asdl
TABSIZE = 4
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"
builtin_type_mapping = {
'identifier': 'Ident',
'string': 'String',
'int': 'usize',
'constant': 'Constant',
'bool': 'bool',
'conversion_flag': 'ConversionFlag',
}
assert builtin_type_mapping.keys() == asdl.builtin_types
def get_rust_type(name):
"""Return a string for the C name of the type.
This function special cases the default types provided by asdl.
"""
if name in asdl.builtin_types:
return builtin_type_mapping[name]
else:
return "".join(part.capitalize() for part in name.split("_"))
def is_simple(sum):
"""Return True if a sum is a simple.
A sum is simple if its types have no fields, e.g.
unaryop = Invert | Not | UAdd | USub
"""
for t in sum.types:
if t.fields:
return False
return True
def asdl_of(name, obj):
if isinstance(obj, asdl.Product) or isinstance(obj, asdl.Constructor):
fields = ", ".join(map(str, obj.fields))
if fields:
fields = "({})".format(fields)
return "{}{}".format(name, fields)
else:
if is_simple(obj):
types = " | ".join(type.name for type in obj.types)
else:
sep = "\n{}| ".format(" " * (len(name) + 1))
types = sep.join(
asdl_of(type.name, type) for type in obj.types
)
return "{} = {}".format(name, types)
class EmitVisitor(asdl.VisitorBase):
"""Visit that emits lines"""
def __init__(self, file):
self.file = file
self.identifiers = set()
super(EmitVisitor, self).__init__()
def emit_identifier(self, name):
name = str(name)
if name in self.identifiers:
return
self.emit("_Py_IDENTIFIER(%s);" % name, 0)
self.identifiers.add(name)
def emit(self, line, depth):
if line:
line = (" " * TABSIZE * depth) + line
self.file.write(line + "\n")
class TypeInfo:
def __init__(self, name):
self.name = name
self.has_userdata = None
self.children = set()
self.boxed = False
def __repr__(self):
return f"<TypeInfo: {self.name}>"
def determine_userdata(self, typeinfo, stack):
if self.name in stack:
return None
stack.add(self.name)
for child, child_seq in self.children:
if child in asdl.builtin_types:
continue
childinfo = typeinfo[child]
child_has_userdata = childinfo.determine_userdata(typeinfo, stack)
if self.has_userdata is None and child_has_userdata is True:
self.has_userdata = True
stack.remove(self.name)
return self.has_userdata
class FindUserdataTypesVisitor(asdl.VisitorBase):
def __init__(self, typeinfo):
self.typeinfo = typeinfo
super().__init__()
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
stack = set()
for info in self.typeinfo.values():
info.determine_userdata(self.typeinfo, stack)
def visitType(self, type):
self.typeinfo[type.name] = TypeInfo(type.name)
self.visit(type.value, type.name)
def visitSum(self, sum, name):
info = self.typeinfo[name]
if is_simple(sum):
info.has_userdata = False
else:
if len(sum.types) > 1:
info.boxed = True
if sum.attributes:
# attributes means Located, which has the `custom: U` field
info.has_userdata = True
for variant in sum.types:
self.add_children(name, variant.fields)
def visitProduct(self, product, name):
info = self.typeinfo[name]
if product.attributes:
# attributes means Located, which has the `custom: U` field
info.has_userdata = True
if len(product.fields) > 2:
info.boxed = True
self.add_children(name, product.fields)
def add_children(self, name, fields):
self.typeinfo[name].children.update((field.type, field.seq) for field in fields)
def rust_field(field_name):
if field_name == 'type':
return 'type_'
else:
return field_name
class TypeInfoEmitVisitor(EmitVisitor):
def __init__(self, file, typeinfo):
self.typeinfo = typeinfo
super().__init__(file)
def has_userdata(self, typ):
return self.typeinfo[typ].has_userdata
def get_generics(self, typ, *generics):
if self.has_userdata(typ):
return [f"<{g}>" for g in generics]
else:
return ["" for g in generics]
class StructVisitor(TypeInfoEmitVisitor):
"""Visitor to generate typedefs for AST."""
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
if is_simple(sum):
self.simple_sum(sum, name, depth)
else:
self.sum_with_constructors(sum, name, depth)
def emit_attrs(self, depth):
self.emit("#[derive(Clone, Debug, PartialEq)]", depth)
def simple_sum(self, sum, name, depth):
rustname = get_rust_type(name)
self.emit_attrs(depth)
self.emit(f"pub enum {rustname} {{", depth)
for variant in sum.types:
self.emit(f"{variant.name},", depth + 1)
self.emit("}", depth)
self.emit("", depth)
def sum_with_constructors(self, sum, name, depth):
typeinfo = self.typeinfo[name]
generics, generics_applied = self.get_generics(name, "U = ()", "U")
enumname = rustname = get_rust_type(name)
# all the attributes right now are for location, so if it has attrs we
# can just wrap it in Located<>
if sum.attributes:
enumname = rustname + "Kind"
self.emit_attrs(depth)
self.emit(f"pub enum {enumname}{generics} {{", depth)
for t in sum.types:
self.visit(t, typeinfo, depth + 1)
self.emit("}", depth)
if sum.attributes:
self.emit(f"pub type {rustname}<U = ()> = Located<{enumname}{generics_applied}, U>;", depth)
self.emit("", depth)
def visitConstructor(self, cons, parent, depth):
if cons.fields:
self.emit(f"{cons.name} {{", depth)
for f in cons.fields:
self.visit(f, parent, "", depth + 1)
self.emit("},", depth)
else:
self.emit(f"{cons.name},", depth)
def visitField(self, field, parent, vis, depth):
typ = get_rust_type(field.type)
fieldtype = self.typeinfo.get(field.type)
if fieldtype and fieldtype.has_userdata:
typ = f"{typ}<U>"
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
if fieldtype and fieldtype.boxed and (not field.seq or field.opt):
typ = f"Box<{typ}>"
if field.opt:
typ = f"Option<{typ}>"
if field.seq:
typ = f"Vec<{typ}>"
name = rust_field(field.name)
self.emit(f"{vis}{name}: {typ},", depth)
def visitProduct(self, product, name, depth):
typeinfo = self.typeinfo[name]
generics, generics_applied = self.get_generics(name, "U = ()", "U")
dataname = rustname = get_rust_type(name)
if product.attributes:
dataname = rustname + "Data"
self.emit_attrs(depth)
self.emit(f"pub struct {dataname}{generics} {{", depth)
for f in product.fields:
self.visit(f, typeinfo, "pub ", depth + 1)
self.emit("}", depth)
if product.attributes:
# attributes should just be location info
self.emit(f"pub type {rustname}<U = ()> = Located<{dataname}{generics_applied}, U>;", depth);
self.emit("", depth)
class FoldTraitDefVisitor(TypeInfoEmitVisitor):
def visitModule(self, mod, depth):
self.emit("pub trait Fold<U> {", depth)
self.emit("type TargetU;", depth + 1)
self.emit("type Error;", depth + 1)
self.emit("fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;", depth + 2)
for dfn in mod.dfns:
self.visit(dfn, depth + 2)
self.emit("}", depth)
def visitType(self, type, depth):
name = type.name
apply_u, apply_target_u = self.get_generics(name, "U", "Self::TargetU")
enumname = get_rust_type(name)
self.emit(f"fn fold_{name}(&mut self, node: {enumname}{apply_u}) -> Result<{enumname}{apply_target_u}, Self::Error> {{", depth)
self.emit(f"fold_{name}(self, node)", depth + 1)
self.emit("}", depth)
class FoldImplVisitor(TypeInfoEmitVisitor):
def visitModule(self, mod, depth):
self.emit("fn fold_located<U, F: Fold<U> + ?Sized, T, MT>(folder: &mut F, node: Located<T, U>, f: impl FnOnce(&mut F, T) -> Result<MT, F::Error>) -> Result<Located<MT, F::TargetU>, F::Error> {", depth)
self.emit("Ok(Located { custom: folder.map_user(node.custom)?, location: node.location, node: f(folder, node.node)? })", depth + 1)
self.emit("}", depth)
for dfn in mod.dfns:
self.visit(dfn, depth)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
apply_t, apply_u, apply_target_u = self.get_generics(name, "T", "U", "F::TargetU")
enumname = get_rust_type(name)
is_located = bool(sum.attributes)
self.emit(f"impl<T, U> Foldable<T, U> for {enumname}{apply_t} {{", depth)
self.emit(f"type Mapped = {enumname}{apply_u};", depth + 1)
self.emit("fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {", depth + 1)
self.emit(f"folder.fold_{name}(self)", depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
self.emit(f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {enumname}{apply_u}) -> Result<{enumname}{apply_target_u}, F::Error> {{", depth)
if is_located:
self.emit("fold_located(folder, node, |folder, node| {", depth)
enumname += "Kind"
self.emit("match node {", depth + 1)
for cons in sum.types:
fields_pattern = self.make_pattern(cons.fields)
self.emit(f"{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth + 2)
self.gen_construction(f"{enumname}::{cons.name}", cons.fields, depth + 3)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
if is_located:
self.emit("})", depth)
self.emit("}", depth)
def visitProduct(self, product, name, depth):
apply_t, apply_u, apply_target_u = self.get_generics(name, "T", "U", "F::TargetU")
structname = get_rust_type(name)
is_located = bool(product.attributes)
self.emit(f"impl<T, U> Foldable<T, U> for {structname}{apply_t} {{", depth)
self.emit(f"type Mapped = {structname}{apply_u};", depth + 1)
self.emit("fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {", depth + 1)
self.emit(f"folder.fold_{name}(self)", depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
self.emit(f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {structname}{apply_u}) -> Result<{structname}{apply_target_u}, F::Error> {{", depth)
if is_located:
self.emit("fold_located(folder, node, |folder, node| {", depth)
structname += "Data"
fields_pattern = self.make_pattern(product.fields)
self.emit(f"let {structname} {{ {fields_pattern} }} = node;", depth + 1)
self.gen_construction(structname, product.fields, depth + 1)
if is_located:
self.emit("})", depth)
self.emit("}", depth)
def make_pattern(self, fields):
return ",".join(rust_field(f.name) for f in fields)
def gen_construction(self, cons_path, fields, depth):
self.emit(f"Ok({cons_path} {{", depth)
for field in fields:
name = rust_field(field.name)
self.emit(f"{name}: Foldable::fold({name}, folder)?,", depth + 1)
self.emit("})", depth)
class FoldModuleVisitor(TypeInfoEmitVisitor):
def visitModule(self, mod):
depth = 0
self.emit('#[cfg(feature = "fold")]', depth)
self.emit("pub mod fold {", depth)
self.emit("use super::*;", depth + 1)
self.emit("use crate::fold_helpers::Foldable;", depth + 1)
FoldTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
FoldImplVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
self.emit("}", depth)
class ClassDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
for cons in sum.types:
self.visit(cons, sum.attributes, depth)
def visitConstructor(self, cons, attrs, depth):
self.gen_classdef(cons.name, cons.fields, attrs, depth)
def visitProduct(self, product, name, depth):
self.gen_classdef(name, product.fields, product.attributes, depth)
def gen_classdef(self, name, fields, attrs, depth):
structname = "Node" + name
self.emit(f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "AstNode")]', depth)
self.emit(f"struct {structname};", depth)
self.emit("#[pyimpl(flags(HAS_DICT, BASETYPE))]", depth)
self.emit(f"impl {structname} {{", depth)
self.emit(f"#[extend_class]", depth + 1)
self.emit("fn extend_class_with_fields(ctx: &PyContext, class: &PyTypeRef) {", depth + 1)
fields = ",".join(f"ctx.new_str({json.dumps(f.name)})" for f in fields)
self.emit(f'class.set_str_attr("_fields", ctx.new_list(vec![{fields}]));', depth + 2)
attrs = ",".join(f"ctx.new_str({json.dumps(attr.name)})" for attr in attrs)
self.emit(f'class.set_str_attr("_attributes", ctx.new_list(vec![{attrs}]));', depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
class ExtendModuleVisitor(EmitVisitor):
def visitModule(self, mod):
depth = 0
self.emit("pub fn extend_module_nodes(vm: &VirtualMachine, module: &PyObjectRef) {", depth)
self.emit("extend_module!(vm, module, {", depth + 1)
for dfn in mod.dfns:
self.visit(dfn, depth + 2)
self.emit("})", depth + 1)
self.emit("}", depth)
def visitType(self, type, depth):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
for cons in sum.types:
self.visit(cons, depth)
def visitConstructor(self, cons, depth):
self.gen_extension(cons.name, depth)
def visitProduct(self, product, name, depth):
self.gen_extension(name, depth)
def gen_extension(self, name, depth):
self.emit(f"{json.dumps(name)} => Node{name}::make_class(&vm.ctx),", depth)
class TraitImplVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
enumname = get_rust_type(name)
if sum.attributes:
enumname += "Kind"
self.emit(f"impl NamedNode for ast::{enumname} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::{enumname} {{", depth)
self.emit("fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1)
self.emit("match self {", depth + 2)
for variant in sum.types:
self.constructor_to_object(variant, enumname, depth + 3)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
self.emit("fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {", depth + 1)
self.gen_sum_fromobj(sum, name, enumname, depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
def constructor_to_object(self, cons, enumname, depth):
fields_pattern = self.make_pattern(cons.fields)
self.emit(f"ast::{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth)
self.make_node(cons.name, cons.fields, depth + 1)
self.emit("}", depth)
def visitProduct(self, product, name, depth):
structname = get_rust_type(name)
if product.attributes:
structname += "Data"
self.emit(f"impl NamedNode for ast::{structname} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::{structname} {{", depth)
self.emit("fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1)
fields_pattern = self.make_pattern(product.fields)
self.emit(f"let ast::{structname} {{ {fields_pattern} }} = self;", depth + 2)
self.make_node(name, product.fields, depth + 2)
self.emit("}", depth + 1)
self.emit("fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {", depth + 1)
self.gen_product_fromobj(product, name, structname, depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
def make_node(self, variant, fields, depth):
lines = []
self.emit(f"let _node = AstNode.into_ref_with_type(_vm, Node{variant}::static_type().clone()).unwrap();", depth)
if fields:
self.emit("let _dict = _node.as_object().dict().unwrap();", depth)
for f in fields:
self.emit(f"_dict.set_item({json.dumps(f.name)}, {rust_field(f.name)}.ast_to_object(_vm), _vm).unwrap();", depth)
self.emit("_node.into_object()", depth)
def make_pattern(self, fields):
return ",".join(rust_field(f.name) for f in fields)
def gen_sum_fromobj(self, sum, sumname, enumname, depth):
if sum.attributes:
self.extract_location(sumname, depth)
self.emit("let _cls = _object.class();", depth)
self.emit("Ok(", depth)
for cons in sum.types:
self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth)
self.gen_construction(f"{enumname}::{cons.name}", cons, sumname, depth + 1)
self.emit("} else", depth)
self.emit("{", depth)
msg = f'format!("expected some sort of {sumname}, but got {{}}",_vm.to_repr(&_object)?)'
self.emit(f"return Err(_vm.new_type_error({msg}));", depth + 1)
self.emit("})", depth)
def gen_product_fromobj(self, product, prodname, structname, depth):
if product.attributes:
self.extract_location(prodname, depth)
self.emit("Ok(", depth)
self.gen_construction(structname, product, prodname, depth + 1)
self.emit(")", depth)
def gen_construction(self, cons_path, cons, name, depth):
self.emit(f"ast::{cons_path} {{", depth)
for field in cons.fields:
self.emit(f"{rust_field(field.name)}: {self.decode_field(field, name)},", depth + 1)
self.emit("}", depth)
def extract_location(self, typename, depth):
row = self.decode_field(asdl.Field('int', 'lineno'), typename)
column = self.decode_field(asdl.Field('int', 'col_offset'), typename)
self.emit(f"let _location = ast::Location::new({row}, {column});", depth)
def wrap_located_node(self, depth):
self.emit(f"let node = ast::Located::new(_location, node);", depth)
def decode_field(self, field, typename):
name = json.dumps(field.name)
if field.opt and not field.seq:
return f"get_node_field_opt(_vm, &_object, {name})?.map(|obj| Node::ast_from_object(_vm, obj)).transpose()?"
else:
return f"Node::ast_from_object(_vm, get_node_field(_vm, &_object, {name}, {json.dumps(typename)})?)?"
class ChainOfVisitors:
def __init__(self, *visitors):
self.visitors = visitors
def visit(self, object):
for v in self.visitors:
v.visit(object)
v.emit("", 0)
def write_ast_def(mod, typeinfo, f):
f.write('pub use crate::location::Location;\n')
f.write('pub use crate::constant::*;\n')
f.write('\n')
f.write('type Ident = String;\n')
f.write('\n')
StructVisitor(f, typeinfo).emit_attrs(0)
f.write('pub struct Located<T, U = ()> {\n')
f.write(' pub location: Location,\n')
f.write(' pub custom: U,\n')
f.write(' pub node: T,\n')
f.write('}\n')
f.write('\n')
f.write('impl<T> Located<T> {\n')
f.write(' pub fn new(location: Location, node: T) -> Self {\n')
f.write(' Self { location, custom: (), node }\n')
f.write(' }\n')
f.write('}\n')
f.write('\n')
c = ChainOfVisitors(StructVisitor(f, typeinfo),
FoldModuleVisitor(f, typeinfo))
c.visit(mod)
def write_ast_mod(mod, f):
f.write('use super::*;\n')
f.write('\n')
c = ChainOfVisitors(ClassDefVisitor(f),
TraitImplVisitor(f),
ExtendModuleVisitor(f))
c.visit(mod)
def main(input_filename, ast_mod_filename, ast_def_filename, dump_module=False):
auto_gen_msg = AUTOGEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
mod = asdl.parse(input_filename)
if dump_module:
print('Parsed Module:')
print(mod)
if not asdl.check(mod):
sys.exit(1)
typeinfo = {}
FindUserdataTypesVisitor(typeinfo).visit(mod)
with ast_def_filename.open("w") as def_file, \
ast_mod_filename.open("w") as mod_file:
def_file.write(auto_gen_msg)
write_ast_def(mod, typeinfo, def_file)
mod_file.write(auto_gen_msg)
write_ast_mod(mod, mod_file)
print(f"{ast_def_filename}, {ast_mod_filename} regenerated.")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("input_file", type=Path)
parser.add_argument("-M", "--mod-file", type=Path, required=True)
parser.add_argument("-D", "--def-file", type=Path, required=True)
parser.add_argument("-d", "--dump-module", action="store_true")
args = parser.parse_args()
main(args.input_file, args.mod_file, args.def_file, args.dump_module)

File diff suppressed because it is too large Load Diff

View File

@ -1,183 +0,0 @@
#[derive(Clone, Debug, PartialEq)]
pub enum Constant {
None,
Bool(bool),
Str(String),
Bytes(Vec<u8>),
Int(i128),
Tuple(Vec<Constant>),
Float(f64),
Complex { real: f64, imag: f64 },
Ellipsis,
}
impl From<String> for Constant {
fn from(s: String) -> Constant {
Self::Str(s)
}
}
impl From<Vec<u8>> for Constant {
fn from(b: Vec<u8>) -> Constant {
Self::Bytes(b)
}
}
impl From<bool> for Constant {
fn from(b: bool) -> Constant {
Self::Bool(b)
}
}
impl From<i32> for Constant {
fn from(i: i32) -> Constant {
Self::Int(i128::from(i))
}
}
impl From<i64> for Constant {
fn from(i: i64) -> Constant {
Self::Int(i128::from(i))
}
}
/// Transforms a value prior to formatting it.
#[derive(Copy, Clone, Debug, PartialEq)]
#[repr(u8)]
pub enum ConversionFlag {
/// Converts by calling `str(<value>)`.
Str = b's',
/// Converts by calling `ascii(<value>)`.
Ascii = b'a',
/// Converts by calling `repr(<value>)`.
Repr = b'r',
}
impl ConversionFlag {
#[must_use]
pub fn try_from_byte(b: u8) -> Option<Self> {
match b {
b's' => Some(Self::Str),
b'a' => Some(Self::Ascii),
b'r' => Some(Self::Repr),
_ => None,
}
}
}
#[cfg(feature = "constant-optimization")]
#[derive(Default)]
pub struct ConstantOptimizer {
_priv: (),
}
#[cfg(feature = "constant-optimization")]
impl ConstantOptimizer {
#[inline]
#[must_use]
pub fn new() -> Self {
Self { _priv: () }
}
}
#[cfg(feature = "constant-optimization")]
impl<U> crate::fold::Fold<U> for ConstantOptimizer {
type TargetU = U;
type Error = std::convert::Infallible;
#[inline]
fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error> {
Ok(user)
}
fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> {
match node.node {
crate::ExprKind::Tuple { elts, ctx } => {
let elts =
elts.into_iter().map(|x| self.fold_expr(x)).collect::<Result<Vec<_>, _>>()?;
let expr =
if elts.iter().all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) {
let tuple = elts
.into_iter()
.map(|e| match e.node {
crate::ExprKind::Constant { value, .. } => value,
_ => unreachable!(),
})
.collect();
crate::ExprKind::Constant { value: Constant::Tuple(tuple), kind: None }
} else {
crate::ExprKind::Tuple { elts, ctx }
};
Ok(crate::Expr { node: expr, custom: node.custom, location: node.location })
}
_ => crate::fold::fold_expr(self, node),
}
}
}
#[cfg(test)]
mod tests {
#[cfg(feature = "constant-optimization")]
#[test]
fn test_constant_opt() {
use super::*;
use crate::fold::Fold;
use crate::*;
let location = Location::new(0, 0, FileName::default());
let custom = ();
let ast = Located {
location,
custom,
node: ExprKind::Tuple {
ctx: ExprContext::Load,
elts: vec![
Located {
location,
custom,
node: ExprKind::Constant { value: 1.into(), kind: None },
},
Located {
location,
custom,
node: ExprKind::Constant { value: 2.into(), kind: None },
},
Located {
location,
custom,
node: ExprKind::Tuple {
ctx: ExprContext::Load,
elts: vec![
Located {
location,
custom,
node: ExprKind::Constant { value: 3.into(), kind: None },
},
Located {
location,
custom,
node: ExprKind::Constant { value: 4.into(), kind: None },
},
Located {
location,
custom,
node: ExprKind::Constant { value: 5.into(), kind: None },
},
],
},
},
],
},
};
let new_ast = ConstantOptimizer::new().fold_expr(ast).unwrap_or_else(|e| match e {});
assert_eq!(
new_ast,
Located {
location,
custom,
node: ExprKind::Constant {
value: Constant::Tuple(vec![
1.into(),
2.into(),
Constant::Tuple(vec![3.into(), 4.into(), 5.into(),])
]),
kind: None
},
}
);
}
}

View File

@ -1,67 +0,0 @@
use crate::constant;
use crate::fold::Fold;
use crate::StrRef;
pub(crate) trait Foldable<T, U> {
type Mapped;
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
self,
folder: &mut F,
) -> Result<Self::Mapped, F::Error>;
}
impl<T, U, X> Foldable<T, U> for Vec<X>
where
X: Foldable<T, U>,
{
type Mapped = Vec<X::Mapped>;
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
self,
folder: &mut F,
) -> Result<Self::Mapped, F::Error> {
self.into_iter().map(|x| x.fold(folder)).collect()
}
}
impl<T, U, X> Foldable<T, U> for Option<X>
where
X: Foldable<T, U>,
{
type Mapped = Option<X::Mapped>;
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
self,
folder: &mut F,
) -> Result<Self::Mapped, F::Error> {
self.map(|x| x.fold(folder)).transpose()
}
}
impl<T, U, X> Foldable<T, U> for Box<X>
where
X: Foldable<T, U>,
{
type Mapped = Box<X::Mapped>;
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
self,
folder: &mut F,
) -> Result<Self::Mapped, F::Error> {
(*self).fold(folder).map(Box::new)
}
}
macro_rules! simple_fold {
($($t:ty),+$(,)?) => {
$(impl<T, U> $crate::fold_helpers::Foldable<T, U> for $t {
type Mapped = Self;
#[inline]
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
self,
_folder: &mut F,
) -> Result<Self::Mapped, F::Error> {
Ok(self)
}
})+
};
}
simple_fold!(usize, String, bool, StrRef, constant::Constant, constant::ConversionFlag);

View File

@ -1,51 +0,0 @@
use crate::{Constant, ExprKind};
impl<U> ExprKind<U> {
/// Returns a short name for the node suitable for use in error messages.
#[must_use]
pub fn name(&self) -> &'static str {
match self {
ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => {
"operator"
}
ExprKind::Subscript { .. } => "subscript",
ExprKind::Await { .. } => "await expression",
ExprKind::Yield { .. } | ExprKind::YieldFrom { .. } => "yield expression",
ExprKind::Compare { .. } => "comparison",
ExprKind::Attribute { .. } => "attribute",
ExprKind::Call { .. } => "function call",
ExprKind::Constant { value, .. } => match value {
Constant::Str(_)
| Constant::Int(_)
| Constant::Float(_)
| Constant::Complex { .. }
| Constant::Bytes(_) => "literal",
Constant::Tuple(_) => "tuple",
Constant::Bool(_) | Constant::None => "keyword",
Constant::Ellipsis => "ellipsis",
},
ExprKind::List { .. } => "list",
ExprKind::Tuple { .. } => "tuple",
ExprKind::Dict { .. } => "dict display",
ExprKind::Set { .. } => "set display",
ExprKind::ListComp { .. } => "list comprehension",
ExprKind::DictComp { .. } => "dict comprehension",
ExprKind::SetComp { .. } => "set comprehension",
ExprKind::GeneratorExp { .. } => "generator expression",
ExprKind::Starred { .. } => "starred",
ExprKind::Slice { .. } => "slice",
ExprKind::JoinedStr { values } => {
if values.iter().any(|e| matches!(e.node, ExprKind::JoinedStr { .. })) {
"f-string expression"
} else {
"literal"
}
}
ExprKind::FormattedValue { .. } => "f-string expression",
ExprKind::Name { .. } => "name",
ExprKind::Lambda { .. } => "lambda",
ExprKind::IfExp { .. } => "conditional expression",
ExprKind::NamedExpr { .. } => "named expression",
}
}
}

View File

@ -1,21 +0,0 @@
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
#![warn(clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::too_many_lines,
clippy::wildcard_imports
)]
mod ast_gen;
mod constant;
#[cfg(feature = "fold")]
mod fold_helpers;
mod impls;
mod location;
pub use ast_gen::*;
pub use location::{FileName, Location};
pub type Suite<U = ()> = Vec<Stmt<U>>;

View File

@ -1,116 +0,0 @@
//! Datatypes to support source location information.
use crate::ast_gen::StrRef;
use std::cmp::Ordering;
use std::fmt;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FileName(pub StrRef);
impl Default for FileName {
fn default() -> Self {
FileName("unknown".into())
}
}
impl From<String> for FileName {
fn from(s: String) -> Self {
FileName(s.into())
}
}
/// A location somewhere in the sourcecode.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Location {
pub row: usize,
pub column: usize,
pub file: FileName,
}
impl fmt::Display for Location {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}:{}:{}", self.file.0, self.row, self.column)
}
}
impl Ord for Location {
fn cmp(&self, other: &Self) -> Ordering {
let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string());
if file_cmp != Ordering::Equal {
return file_cmp;
}
let row_cmp = self.row.cmp(&other.row);
if row_cmp != Ordering::Equal {
return row_cmp;
}
self.column.cmp(&other.column)
}
}
impl PartialOrd for Location {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Location {
pub fn visualize<'a>(
&self,
line: &'a str,
desc: impl fmt::Display + 'a,
) -> impl fmt::Display + 'a {
struct Visualize<'a, D: fmt::Display> {
loc: Location,
line: &'a str,
desc: D,
}
impl<D: fmt::Display> fmt::Display for Visualize<'_, D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}\n{}\n{arrow:>pad$}",
self.desc,
self.line,
pad = self.loc.column,
arrow = "",
)
}
}
Visualize { loc: *self, line, desc }
}
}
impl Location {
#[must_use]
pub fn new(row: usize, column: usize, file: FileName) -> Self {
Location { row, column, file }
}
#[must_use]
pub fn row(&self) -> usize {
self.row
}
#[must_use]
pub fn column(&self) -> usize {
self.column
}
pub fn reset(&mut self) {
self.row = 1;
self.column = 1;
}
pub fn go_right(&mut self) {
self.column += 1;
}
pub fn go_left(&mut self) {
self.column -= 1;
}
pub fn newline(&mut self) {
self.row += 1;
self.column = 1;
}
}

View File

@ -2,33 +2,15 @@
name = "nac3core" name = "nac3core"
version = "0.1.0" version = "0.1.0"
authors = ["M-Labs"] authors = ["M-Labs"]
edition = "2021" edition = "2018"
[features]
default = ["derive"]
derive = ["dep:nac3core_derive"]
no-escape-analysis = []
[dependencies] [dependencies]
itertools = "0.13" num-bigint = "0.3"
crossbeam = "0.8" num-traits = "0.2"
indexmap = "2.6" thiserror = "1.0"
parking_lot = "0.12" inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] }
rayon = "1.10" rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" }
nac3core_derive = { path = "nac3core_derive", optional = true }
nac3parser = { path = "../nac3parser" }
strum = "0.26"
strum_macros = "0.26"
[dependencies.inkwell]
version = "0.5"
default-features = false
features = ["llvm14-0-prefer-dynamic", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[dev-dependencies] [dev-dependencies]
test-case = "1.2.0" indoc = "1.0"
indoc = "2.0"
insta = "=1.11.0"
[build-dependencies]
regex = "1.10"

View File

@ -1,109 +0,0 @@
use std::{
env,
fs::File,
io::Write,
path::Path,
process::{Command, Stdio},
};
use regex::Regex;
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir);
let irrt_dir = Path::new("irrt");
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
/*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/
let mut flags: Vec<&str> = vec![
"--target=wasm32",
"-x",
"c++",
"-std=c++20",
"-fno-discard-value-names",
"-fno-exceptions",
"-fno-rtti",
"-emit-llvm",
"-S",
"-Wall",
"-Wextra",
"-o",
"-",
"-I",
irrt_dir.to_str().unwrap(),
irrt_cpp_path.to_str().unwrap(),
];
match env::var("PROFILE").as_deref() {
Ok("debug") => {
flags.push("-O0");
flags.push("-DIRRT_DEBUG_ASSERT");
}
Ok("release") => {
flags.push("-O3");
}
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
}
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
// Compile IRRT and capture the LLVM IR output
let output = Command::new("clang-irrt")
.args(flags)
.output()
.inspect(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
})
.unwrap();
// https://github.com/rust-lang/regex/issues/244
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
let mut filtered_output = String::with_capacity(output.len());
// Filter out irrelevant IR
//
// Regex:
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks
// - `(?m:^declare.*?$)` captures LLVM `declare` lines
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
// - `(?m:^@.+?=.+$)` captures global constants
let regex_filter = Regex::new(
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
)
.unwrap();
for f in regex_filter.captures_iter(&output) {
assert_eq!(f.len(), 1);
filtered_output.push_str(&f[0]);
filtered_output.push('\n');
}
let filtered_output = Regex::new("(#\\d+)|(, *![0-9A-Za-z.]+)|(![0-9A-Za-z.]+)|(!\".*?\")")
.unwrap()
.replace_all(&filtered_output, "");
// For debugging
// Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated
const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
if env::var(DEBUG_DUMP_IRRT).is_ok() {
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
file.write_all(output.as_bytes()).unwrap();
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
file.write_all(filtered_output.as_bytes()).unwrap();
}
let mut llvm_as = Command::new("llvm-as-irrt")
.stdin(Stdio::piped())
.arg("-o")
.arg(out_dir.join("irrt.bc"))
.spawn()
.unwrap();
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
assert!(llvm_as.wait().unwrap().success());
}

View File

@ -1,5 +0,0 @@
#include "irrt/exception.hpp"
#include "irrt/list.hpp"
#include "irrt/math.hpp"
#include "irrt/ndarray.hpp"
#include "irrt/slice.hpp"

View File

@ -1,9 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
template<typename SizeT>
struct CSlice {
void* base;
SizeT len;
};

View File

@ -1,25 +0,0 @@
#pragma once
// Set in nac3core/build.rs
#ifdef IRRT_DEBUG_ASSERT
#define IRRT_DEBUG_ASSERT_BOOL true
#else
#define IRRT_DEBUG_ASSERT_BOOL false
#endif
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \
raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
#define debug_assert_eq(SizeT, lhs, rhs) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if ((lhs) != (rhs)) { \
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
} \
}
#define debug_assert(SizeT, expr) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if (!(expr)) { \
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
} \
}

View File

@ -1,85 +0,0 @@
#pragma once
#include "irrt/cslice.hpp"
#include "irrt/int_types.hpp"
/**
* @brief The int type of ARTIQ exception IDs.
*/
using ExceptionId = int32_t;
/*
* Set of exceptions C++ IRRT can use.
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
*/
extern "C" {
ExceptionId EXN_INDEX_ERROR;
ExceptionId EXN_VALUE_ERROR;
ExceptionId EXN_ASSERTION_ERROR;
ExceptionId EXN_TYPE_ERROR;
}
/**
* @brief Extern function to `__nac3_raise`
*
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
*/
extern "C" void __nac3_raise(void* err);
namespace {
/**
* @brief NAC3's Exception struct
*/
template<typename SizeT>
struct Exception {
ExceptionId id;
CSlice<SizeT> filename;
int32_t line;
int32_t column;
CSlice<SizeT> function;
CSlice<SizeT> msg;
int64_t params[3];
};
constexpr int64_t NO_PARAM = 0;
template<typename SizeT>
void _raise_exception_helper(ExceptionId id,
const char* filename,
int32_t line,
const char* function,
const char* msg,
int64_t param0,
int64_t param1,
int64_t param2) {
Exception<SizeT> e = {
.id = id,
.filename = {.base = reinterpret_cast<void*>(const_cast<char*>(filename)),
.len = static_cast<SizeT>(__builtin_strlen(filename))},
.line = line,
.column = 0,
.function = {.base = reinterpret_cast<void*>(const_cast<char*>(function)),
.len = static_cast<SizeT>(__builtin_strlen(function))},
.msg = {.base = reinterpret_cast<void*>(const_cast<char*>(msg)),
.len = static_cast<SizeT>(__builtin_strlen(msg))},
};
e.params[0] = param0;
e.params[1] = param1;
e.params[2] = param2;
__nac3_raise(reinterpret_cast<void*>(&e));
__builtin_unreachable();
}
} // namespace
/**
* @brief Raise an exception with location details (location in the IRRT source files).
* @param SizeT The runtime `size_t` type.
* @param id The ID of the exception to raise.
* @param msg A global constant C-string of the error message.
*
* `param0` to `param2` are optional format arguments of `msg`. They should be set to
* `NO_PARAM` to indicate they are unused.
*/
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)

View File

@ -1,27 +0,0 @@
#pragma once
#if __STDC_VERSION__ >= 202000
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#else
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-type"
using int8_t = _ExtInt(8);
using uint8_t = unsigned _ExtInt(8);
using int32_t = _ExtInt(32);
using uint32_t = unsigned _ExtInt(32);
using int64_t = _ExtInt(64);
using uint64_t = unsigned _ExtInt(64);
#pragma clang diagnostic pop
#endif
// NDArray indices are always `uint32_t`.
using NDIndex = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;

View File

@ -1,81 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
#include "irrt/math_util.hpp"
extern "C" {
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
void* dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
void* src_arr,
SliceIndex src_arr_len,
const SliceIndex size) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0)
return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_start * size,
static_cast<uint8_t*>(src_arr) + src_start * size, src_len * size);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + (dest_start + src_len) * size,
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca = (dest_arr == src_arr)
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca) {
void* tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind, static_cast<uint8_t*>(src_arr) + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 4,
static_cast<uint8_t*>(src_arr) + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 8,
static_cast<uint8_t*>(src_arr) + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
static_cast<uint8_t*>(src_arr) + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
} // extern "C"

View File

@ -1,93 +0,0 @@
#pragma once
namespace {
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template<typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
} // namespace
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp); \
}
extern "C" {
// Putting semicolons here to make clang-format not reformat this into
// a stair shape.
DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t);
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
} // namespace

View File

@ -1,13 +0,0 @@
#pragma once
namespace {
template<typename T>
const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template<typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
} // namespace

View File

@ -1,144 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims,
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C" {
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims,
const uint32_t* rhs_dims,
uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims,
uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
} // namespace

View File

@ -1,28 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
extern "C" {
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
} // namespace

View File

@ -1,21 +0,0 @@
[package]
name = "nac3core_derive"
version = "0.1.0"
edition = "2021"
[lib]
proc-macro = true
[[test]]
name = "structfields_tests"
path = "tests/structfields_test.rs"
[dev-dependencies]
nac3core = { path = ".." }
trybuild = { version = "1.0", features = ["diff"] }
[dependencies]
proc-macro2 = "1.0"
proc-macro-error = "1.0"
syn = "2.0"
quote = "1.0"

View File

@ -1,320 +0,0 @@
use proc_macro::TokenStream;
use proc_macro_error::{abort, proc_macro_error};
use quote::quote;
use syn::{
parse_macro_input, spanned::Spanned, Data, DataStruct, Expr, ExprField, ExprMethodCall,
ExprPath, GenericArgument, Ident, LitStr, Path, PathArguments, Type, TypePath,
};
/// Extracts all generic arguments of a [`Type`] into a [`Vec`].
///
/// Returns [`Some`] of a possibly-empty [`Vec`] if the path of `ty` matches with
/// `expected_ty_name`, otherwise returns [`None`].
fn extract_generic_args(expected_ty_name: &'static str, ty: &Type) -> Option<Vec<GenericArgument>> {
let Type::Path(TypePath { qself: None, path, .. }) = ty else {
return None;
};
let segments = &path.segments;
if segments.len() != 1 {
return None;
};
let segment = segments.iter().next().unwrap();
if segment.ident != expected_ty_name {
return None;
}
let PathArguments::AngleBracketed(path_args) = &segment.arguments else {
return Some(Vec::new());
};
let args = &path_args.args;
Some(args.iter().cloned().collect::<Vec<_>>())
}
/// Maps a `path` matching one of the `target_idents` into the `replacement` [`Ident`].
fn map_path_to_ident(path: &Path, target_idents: &[&str], replacement: &str) -> Option<Ident> {
path.require_ident()
.ok()
.filter(|ident| target_idents.iter().any(|target| ident == target))
.map(|ident| Ident::new(replacement, ident.span()))
}
/// Extracts the left-hand side of a dot-expression.
fn extract_dot_operand(expr: &Expr) -> Option<&Expr> {
match expr {
Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
| Expr::Field(ExprField { base: operand, .. }) => Some(operand),
_ => None,
}
}
/// Replaces the top-level receiver of a dot-expression with an [`Ident`], returning `Some(&mut expr)` if the
/// replacement is performed.
///
/// The top-level receiver is the left-most receiver expression, e.g. the top-level receiver of `a.b.c.foo()` is `a`.
fn replace_top_level_receiver(expr: &mut Expr, ident: Ident) -> Option<&mut Expr> {
if let Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
| Expr::Field(ExprField { base: operand, .. }) = expr
{
return if extract_dot_operand(operand).is_some() {
if replace_top_level_receiver(operand, ident).is_some() {
Some(expr)
} else {
None
}
} else {
*operand = Box::new(Expr::Path(ExprPath {
attrs: Vec::default(),
qself: None,
path: ident.into(),
}));
Some(expr)
};
}
None
}
/// Iterates all operands to the left-hand side of the `.` of an [expression][`Expr`], i.e. the container operand of all
/// [`Expr::Field`] and the receiver operand of all [`Expr::MethodCall`].
///
/// The iterator will return the operand expressions in reverse order of appearance. For example, `a.b.c.func()` will
/// return `vec![c, b, a]`.
fn iter_dot_operands(expr: &Expr) -> impl Iterator<Item = &Expr> {
let mut o = extract_dot_operand(expr);
std::iter::from_fn(move || {
let this = o;
o = o.as_ref().and_then(|o| extract_dot_operand(o));
this
})
}
/// Normalizes a value expression for use when creating an instance of this structure, returning a
/// [`proc_macro2::TokenStream`] of tokens representing the normalized expression.
fn normalize_value_expr(expr: &Expr) -> proc_macro2::TokenStream {
match &expr {
Expr::Path(ExprPath { qself: None, path, .. }) => {
if let Some(ident) = map_path_to_ident(path, &["usize", "size_t"], "llvm_usize") {
quote! { #ident }
} else {
abort!(
path,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
}
}
Expr::Call(_) => {
quote! { ctx.#expr }
}
Expr::MethodCall(_) => {
let base_receiver = iter_dot_operands(expr).last();
match base_receiver {
// `usize.{...}`, `size_t.{...}` -> Rewrite the identifiers to `llvm_usize`
Some(Expr::Path(ExprPath { qself: None, path, .. }))
if map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").is_some() =>
{
let ident =
map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").unwrap();
let mut expr = expr.clone();
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
quote!(#expr)
}
// `ctx.{...}`, `context.{...}` -> Rewrite the identifiers to `ctx`
Some(Expr::Path(ExprPath { qself: None, path, .. }))
if map_path_to_ident(path, &["ctx", "context"], "ctx").is_some() =>
{
let ident = map_path_to_ident(path, &["ctx", "context"], "ctx").unwrap();
let mut expr = expr.clone();
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
quote!(#expr)
}
// No reserved identifier prefix -> Prepend `ctx.` to the entire expression
_ => quote! { ctx.#expr },
}
}
_ => {
abort!(
expr,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
}
}
}
/// Derives an implementation of `codegen::types::structure::StructFields`.
///
/// The benefit of using `#[derive(StructFields)]` is that all index- or order-dependent logic required by
/// `impl StructFields` is automatically generated by this implementation, including the field index as required by
/// `StructField::new` and the fields as returned by `StructFields::to_vec`.
///
/// # Prerequisites
///
/// In order to derive from [`StructFields`], you must implement (or derive) [`Eq`] and [`Copy`] as required by
/// `StructFields`.
///
/// Moreover, `#[derive(StructFields)]` can only be used for `struct`s with named fields, and may only contain fields
/// with either `StructField` or [`PhantomData`] types.
///
/// # Attributes for [`StructFields`]
///
/// Each `StructField` field must be declared with the `#[value_type(...)]` attribute. The argument of `value_type`
/// accepts one of the following:
///
/// - An expression returning an instance of `inkwell::types::BasicType` (with or without the receiver `ctx`/`context`).
/// For example, `context.i8_type()`, `ctx.i8_type()`, and `i8_type()` all refer to `i8`.
/// - The reserved identifiers `usize` and `size_t` referring to an `inkwell::types::IntType` of the platform-dependent
/// integer size. `usize` and `size_t` can also be used as the receiver to other method calls, e.g.
/// `usize.array_type(3)`.
///
/// # Example
///
/// The following is an example of an LLVM slice implemented using `#[derive(StructFields)]`.
///
/// ```rust,ignore
/// use nac3core::{
/// codegen::types::structure::StructField,
/// inkwell::{
/// values::{IntValue, PointerValue},
/// AddressSpace,
/// },
/// };
/// use nac3core_derive::StructFields;
///
/// // All classes that implement StructFields must also implement Eq and Copy
/// #[derive(PartialEq, Eq, Clone, Copy, StructFields)]
/// pub struct SliceValue<'ctx> {
/// // Declares ptr have a value type of i8*
/// //
/// // Can also be written as `ctx.i8_type().ptr_type(...)` or `context.i8_type().ptr_type(...)`
/// #[value_type(i8_type().ptr_type(AddressSpace::default()))]
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
///
/// // Declares len have a value type of usize, depending on the target compilation platform
/// #[value_type(usize)]
/// len: StructField<'ctx, IntValue<'ctx>>,
/// }
/// ```
#[proc_macro_derive(StructFields, attributes(value_type))]
#[proc_macro_error]
pub fn derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as syn::DeriveInput);
let ident = &input.ident;
let Data::Struct(DataStruct { fields, .. }) = &input.data else {
abort!(input, "Only structs with named fields are supported");
};
if let Err(err_span) =
fields
.iter()
.try_for_each(|field| if field.ident.is_some() { Ok(()) } else { Err(field.span()) })
{
abort!(err_span, "Only structs with named fields are supported");
};
// Check if struct<'ctx>
if input.generics.params.len() != 1 {
abort!(input.generics, "Expected exactly 1 generic parameter")
}
let phantom_info = fields
.iter()
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_some())
.map(|field| field.ident.as_ref().unwrap())
.cloned()
.collect::<Vec<_>>();
let field_info = fields
.iter()
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_none())
.map(|field| {
let ident = field.ident.as_ref().unwrap();
let ty = &field.ty;
let Some(_) = extract_generic_args("StructField", ty) else {
abort!(field, "Only StructField and PhantomData are allowed")
};
let attrs = &field.attrs;
let Some(value_type_attr) =
attrs.iter().find(|attr| attr.path().is_ident("value_type"))
else {
abort!(field, "Expected #[value_type(...)] attribute for field");
};
let Ok(value_type_expr) = value_type_attr.parse_args::<Expr>() else {
abort!(value_type_attr, "Expected expression in #[value_type(...)]");
};
let value_expr_toks = normalize_value_expr(&value_type_expr);
(ident.clone(), value_expr_toks)
})
.collect::<Vec<_>>();
// `<*>::new` impl of `StructField` and `PhantomData` for `StructFields::new`
let phantoms_create = phantom_info
.iter()
.map(|id| quote! { #id: ::std::marker::PhantomData })
.collect::<Vec<_>>();
let fields_create = field_info
.iter()
.map(|(id, ty)| {
let id_lit = LitStr::new(&id.to_string(), id.span());
quote! {
#id: ::nac3core::codegen::types::structure::StructField::create(
&mut counter,
#id_lit,
#ty,
)
}
})
.collect::<Vec<_>>();
// `.into()` impl of `StructField` for `StructFields::to_vec`
let fields_into =
field_info.iter().map(|(id, _)| quote! { self.#id.into() }).collect::<Vec<_>>();
let impl_block = quote! {
impl<'ctx> ::nac3core::codegen::types::structure::StructFields<'ctx> for #ident<'ctx> {
fn new(ctx: impl ::nac3core::inkwell::context::AsContextRef<'ctx>, llvm_usize: ::nac3core::inkwell::types::IntType<'ctx>) -> Self {
let ctx = unsafe { ::nac3core::inkwell::context::ContextRef::new(ctx.as_ctx_ref()) };
let mut counter = ::nac3core::codegen::types::structure::FieldIndexCounter::default();
#ident {
#(#fields_create),*
#(#phantoms_create),*
}
}
fn to_vec(&self) -> ::std::vec::Vec<(&'static str, ::nac3core::inkwell::types::BasicTypeEnum<'ctx>)> {
vec![
#(#fields_into),*
]
}
}
};
impl_block.into()
}

View File

@ -1,9 +0,0 @@
use nac3core_derive::StructFields;
use std::marker::PhantomData;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct EmptyValue<'ctx> {
_phantom: PhantomData<&'ctx ()>,
}
fn main() {}

View File

@ -1,20 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDArrayValue<'ctx> {
#[value_type(usize)]
ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
data: StructField<'ctx, PointerValue<'ctx>>,
}
fn main() {}

View File

@ -1,18 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -1,18 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(context.i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -1,18 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(ctx.i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -1,18 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(size_t)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -1,10 +0,0 @@
#[test]
fn test_parse_empty() {
let t = trybuild::TestCases::new();
t.pass("tests/structfields_empty.rs");
t.pass("tests/structfields_slice.rs");
t.pass("tests/structfields_slice_ctx.rs");
t.pass("tests/structfields_slice_context.rs");
t.pass("tests/structfields_slice_sizet.rs");
t.pass("tests/structfields_ndarray.rs");
}

File diff suppressed because it is too large Load Diff

View File

@ -1,325 +0,0 @@
use std::collections::HashMap;
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use crate::{
symbol_resolver::SymbolValue,
toplevel::DefinitionId,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{
into_var_map, FunSignature, FuncArg, Type, TypeEnum, TypeVar, TypeVarId, Unifier,
},
},
};
pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>,
}
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct ConcreteType(usize);
#[derive(Clone, Debug)]
pub struct ConcreteFuncArg {
pub name: StrRef,
pub ty: ConcreteType,
pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
}
#[derive(Clone, Debug)]
pub enum Primitive {
Int32,
Int64,
UInt32,
UInt64,
Float,
Bool,
None,
Range,
Str,
Exception,
}
#[derive(Debug)]
pub enum ConcreteTypeEnum {
TPrimitive(Primitive),
TTuple {
ty: Vec<ConcreteType>,
is_vararg_ctx: bool,
},
TObj {
obj_id: DefinitionId,
fields: HashMap<StrRef, (ConcreteType, bool)>,
params: IndexMap<TypeVarId, ConcreteType>,
},
TVirtual {
ty: ConcreteType,
},
TFunc {
args: Vec<ConcreteFuncArg>,
ret: ConcreteType,
vars: HashMap<TypeVarId, ConcreteType>,
},
TLiteral {
values: Vec<SymbolValue>,
},
}
impl ConcreteTypeStore {
#[must_use]
pub fn new() -> ConcreteTypeStore {
ConcreteTypeStore {
store: vec![
ConcreteTypeEnum::TPrimitive(Primitive::Int32),
ConcreteTypeEnum::TPrimitive(Primitive::Int64),
ConcreteTypeEnum::TPrimitive(Primitive::Float),
ConcreteTypeEnum::TPrimitive(Primitive::Bool),
ConcreteTypeEnum::TPrimitive(Primitive::None),
ConcreteTypeEnum::TPrimitive(Primitive::Range),
ConcreteTypeEnum::TPrimitive(Primitive::Str),
ConcreteTypeEnum::TPrimitive(Primitive::Exception),
ConcreteTypeEnum::TPrimitive(Primitive::UInt32),
ConcreteTypeEnum::TPrimitive(Primitive::UInt64),
],
}
}
#[must_use]
pub fn get(&self, cty: ConcreteType) -> &ConcreteTypeEnum {
&self.store[cty.0]
}
pub fn from_signature(
&mut self,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
signature: &FunSignature,
cache: &mut HashMap<Type, Option<ConcreteType>>,
) -> ConcreteTypeEnum {
ConcreteTypeEnum::TFunc {
args: signature
.args
.iter()
.map(|arg| ConcreteFuncArg {
name: arg.name,
ty: if arg.is_vararg {
let tuple_ty = unifier
.add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true });
self.from_unifier_type(unifier, primitives, tuple_ty, cache)
} else {
self.from_unifier_type(unifier, primitives, arg.ty, cache)
},
default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
})
.collect(),
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
vars: signature
.vars
.iter()
.map(|(id, ty)| (*id, self.from_unifier_type(unifier, primitives, *ty, cache)))
.collect(),
}
}
pub fn from_unifier_type(
&mut self,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
ty: Type,
cache: &mut HashMap<Type, Option<ConcreteType>>,
) -> ConcreteType {
let ty = unifier.get_representative(ty);
if unifier.unioned(ty, primitives.int32) {
ConcreteType(0)
} else if unifier.unioned(ty, primitives.int64) {
ConcreteType(1)
} else if unifier.unioned(ty, primitives.float) {
ConcreteType(2)
} else if unifier.unioned(ty, primitives.bool) {
ConcreteType(3)
} else if unifier.unioned(ty, primitives.none) {
ConcreteType(4)
} else if unifier.unioned(ty, primitives.range) {
ConcreteType(5)
} else if unifier.unioned(ty, primitives.str) {
ConcreteType(6)
} else if unifier.unioned(ty, primitives.exception) {
ConcreteType(7)
} else if unifier.unioned(ty, primitives.uint32) {
ConcreteType(8)
} else if unifier.unioned(ty, primitives.uint64) {
ConcreteType(9)
} else if let Some(cty) = cache.get(&ty) {
if let Some(cty) = cty {
*cty
} else {
let index = self.store.len();
// placeholder
self.store.push(ConcreteTypeEnum::TPrimitive(Primitive::Int32));
let result = ConcreteType(index);
cache.insert(ty, Some(result));
result
}
} else {
cache.insert(ty, None);
let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum {
TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple {
ty: ty
.iter()
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(),
is_vararg_ctx: *is_vararg_ctx,
},
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id,
fields: fields
.iter()
.filter_map(|(name, ty)| {
// here we should not have type vars, but some partial instantiated
// class methods can still have uninstantiated type vars, so
// filter out all the methods, as this will not affect codegen
if let TypeEnum::TFunc(..) = &*unifier.get_ty(ty.0) {
None
} else {
Some((
*name,
(
self.from_unifier_type(unifier, primitives, ty.0, cache),
ty.1,
),
))
}
})
.collect(),
params: params
.iter()
.map(|(id, ty)| {
(*id, self.from_unifier_type(unifier, primitives, *ty, cache))
})
.collect(),
},
TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual {
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
},
TypeEnum::TFunc(signature) => {
self.from_signature(unifier, primitives, signature, cache)
}
TypeEnum::TLiteral { values, .. } => {
ConcreteTypeEnum::TLiteral { values: values.clone() }
}
_ => unreachable!("{:?}", ty_enum.get_type_name()),
};
let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() {
self.store[*index] = result;
*index
} else {
self.store.push(result);
self.store.len() - 1
};
cache.insert(ty, Some(ConcreteType(index)));
ConcreteType(index)
}
}
pub fn to_unifier_type(
&self,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
cty: ConcreteType,
cache: &mut HashMap<ConcreteType, Option<Type>>,
) -> Type {
if let Some(ty) = cache.get_mut(&cty) {
return if let Some(ty) = ty {
*ty
} else {
*ty = Some(unifier.get_dummy_var().ty);
ty.unwrap()
};
}
cache.insert(cty, None);
let result = match &self.store[cty.0] {
ConcreteTypeEnum::TPrimitive(primitive) => {
let ty = match primitive {
Primitive::Int32 => primitives.int32,
Primitive::Int64 => primitives.int64,
Primitive::UInt32 => primitives.uint32,
Primitive::UInt64 => primitives.uint64,
Primitive::Float => primitives.float,
Primitive::Bool => primitives.bool,
Primitive::None => primitives.none,
Primitive::Range => primitives.range,
Primitive::Str => primitives.str,
Primitive::Exception => primitives.exception,
};
*cache.get_mut(&cty).unwrap() = Some(ty);
return ty;
}
ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple {
ty: ty
.iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(),
is_vararg_ctx: *is_vararg_ctx,
},
ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
}
ConcreteTypeEnum::TObj { obj_id, fields, params } => TypeEnum::TObj {
obj_id: *obj_id,
fields: fields
.iter()
.map(|(name, cty)| {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
})
.collect::<HashMap<_, _>>(),
params: into_var_map(params.iter().map(|(&id, cty)| {
let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
TypeVar { id, ty }
})),
},
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args
.iter()
.map(|arg| FuncArg {
name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(),
is_vararg: false,
})
.collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache),
vars: into_var_map(vars.iter().map(|(&id, cty)| {
let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
TypeVar { id, ty }
})),
}),
ConcreteTypeEnum::TLiteral { values, .. } => {
TypeEnum::TLiteral { values: values.clone(), loc: None }
}
};
let result = unifier.add_ty(result);
if let Some(ty) = cache.get(&cty).unwrap() {
unifier.unify(*ty, result).unwrap();
}
cache.insert(cty, Some(result));
result
}
pub fn add_cty(&mut self, cty: ConcreteTypeEnum) -> ConcreteType {
self.store.push(cty);
ConcreteType(self.store.len() - 1)
}
}
impl Default for ConcreteTypeStore {
fn default() -> Self {
Self::new()
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,193 +0,0 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
};
use itertools::Either;
use super::CodeGenContext;
/// Macro to generate extern function
/// Both function return type and function parameter type are `FloatValue`
///
/// Arguments:
/// * `unary/binary`: Whether the extern function requires one (unary) or two (binary) operands
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
///
/// Optional Arguments:
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function.
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly".
/// These will be used unless other attributes are specified
/// * `$(,$args:ident)*`: Operands of the extern function
/// The data type of these operands will be set to `FloatValue`
///
macro_rules! generate_extern_fn {
("unary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("unary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg $(,$attributes)*);
};
("binary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("binary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2 $(,$attributes)*);
};
($fn_name:ident, $extern_fn:literal $(,$args:ident)* $(,$attributes:literal)*) => {
#[doc = concat!("Invokes the [`", stringify!($extern_fn), "`](https://en.cppreference.com/w/c/numeric/math/", stringify!($llvm_name), ") function." )]
pub fn $fn_name<'ctx>(
ctx: &CodeGenContext<'ctx, '_>
$(,$args: FloatValue<'ctx>)*,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = $extern_fn;
let llvm_f64 = ctx.ctx.f64_type();
$(debug_assert_eq!($args.get_type(), llvm_f64);)*
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in [$($attributes),*] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
};
}
generate_extern_fn!("unary", call_tan, "tan");
generate_extern_fn!("unary", call_asin, "asin");
generate_extern_fn!("unary", call_acos, "acos");
generate_extern_fn!("unary", call_atan, "atan");
generate_extern_fn!("unary", call_sinh, "sinh");
generate_extern_fn!("unary", call_cosh, "cosh");
generate_extern_fn!("unary", call_tanh, "tanh");
generate_extern_fn!("unary", call_asinh, "asinh");
generate_extern_fn!("unary", call_acosh, "acosh");
generate_extern_fn!("unary", call_atanh, "atanh");
generate_extern_fn!("unary", call_expm1, "expm1");
generate_extern_fn!(
"unary",
call_cbrt,
"cbrt",
"mustprogress",
"nofree",
"nosync",
"nounwind",
"readonly",
"willreturn"
);
generate_extern_fn!("unary", call_erf, "erf", "nounwind");
generate_extern_fn!("unary", call_erfc, "erfc", "nounwind");
generate_extern_fn!("unary", call_j1, "j1", "nounwind");
generate_extern_fn!("binary", call_atan2, "atan2");
generate_extern_fn!("binary", call_hypot, "hypot", "nounwind");
generate_extern_fn!("binary", call_nextafter, "nextafter", "nounwind");
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
pub fn call_ldexp<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
exp: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "ldexp";
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
debug_assert_eq!(exp.get_type(), llvm_i32);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Macro to generate `np_linalg` and `sp_linalg` functions
/// The function takes as input `NDArray` and returns ()
///
/// Arguments:
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
/// * (2/3/4): Number of `NDArray` that function takes as input
///
/// Note:
/// The operands and resulting `NDArray` are both passed as input to the funcion
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
/// The function changes the content of the output `NDArray` in-place
macro_rules! generate_linalg_extern_fn {
($fn_name:ident, $extern_fn:literal, 2) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
};
($fn_name:ident, $extern_fn:literal, 3) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $extern_fn:literal, 4) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: BasicValueEnum<'ctx>)*,
name: Option<&str>,
){
const FN_NAME: &str = $extern_fn;
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
}
};
}
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);

View File

@ -1,294 +0,0 @@
use inkwell::{
context::Context,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
};
use nac3parser::ast::{Expr, Stmt, StrRef};
use super::{bool_to_i1, bool_to_i8, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext};
use crate::{
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
pub trait CodeGenerator {
/// Return the module name for the code generator.
fn get_name(&self) -> &str;
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>;
/// Generate function call and returns the function return value.
/// - obj: Optional object for method call.
/// - fun: Function signature and definition ID.
/// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method.
fn gen_call<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<Option<BasicValueEnum<'ctx>>, String>
where
Self: Sized,
{
gen_call(self, ctx, obj, fun, params)
}
/// Generate object constructor and returns the constructed object.
/// - signature: Function signature of the constructor.
/// - def: Class definition for the constructor class.
/// - params: Function parameters.
fn gen_constructor<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
signature: &FunSignature,
def: &TopLevelDef,
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<BasicValueEnum<'ctx>, String>
where
Self: Sized,
{
gen_constructor(self, ctx, signature, def, params)
}
/// Generate a function instance.
/// - obj: Optional object for method call.
/// - fun: Function signature, definition ID and the substitution key.
/// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method.
///
/// Note that this function should check if the function is generated in another thread (due to
/// possible race condition), see the default implementation for an example.
fn gen_func_instance<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, &mut TopLevelDef, String),
id: usize,
) -> Result<String, String> {
gen_func_instance(ctx, &obj, fun, id)
}
/// Generate the code for an expression.
fn gen_expr<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
expr: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String>
where
Self: Sized,
{
gen_expr(self, ctx, expr)
}
/// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function.
fn gen_var_alloc<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
name: Option<&str>,
) -> Result<PointerValue<'ctx>, String> {
gen_var(ctx, ty, name)
}
/// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function.
fn gen_array_var_alloc<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Result<ArraySliceValue<'ctx>, String> {
gen_array_var(ctx, ty, size, name)
}
/// Return a pointer pointing to the target of the expression.
fn gen_store_target<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
pattern: &Expr<Option<Type>>,
name: Option<&str>,
) -> Result<Option<PointerValue<'ctx>>, String>
where
Self: Sized,
{
gen_store_target(self, ctx, pattern, name)
}
/// Generate code for an assignment expression.
fn gen_assign<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign(self, ctx, target, value, value_ty)
}
/// Generate code for an assignment expression where LHS is a `"target_list"`.
///
/// See <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
fn gen_assign_target_list<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
targets: &Vec<Expr<Option<Type>>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign_target_list(self, ctx, targets, value, value_ty)
}
/// Generate code for an item assignment.
///
/// i.e., `target[key] = value`
fn gen_setitem<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
key: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_setitem(self, ctx, target, key, value, value_ty)
}
/// Generate code for a while expression.
/// Return true if the while loop must early return
fn gen_while(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where
Self: Sized,
{
gen_while(self, ctx, stmt)
}
/// Generate code for a for expression.
/// Return true if the for loop must early return
fn gen_for(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where
Self: Sized,
{
gen_for(self, ctx, stmt)
}
/// Generate code for an if expression.
/// Return true if the statement must early return
fn gen_if(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where
Self: Sized,
{
gen_if(self, ctx, stmt)
}
fn gen_with(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where
Self: Sized,
{
gen_with(self, ctx, stmt)
}
/// Generate code for a statement
///
/// Return true if the statement must early return
fn gen_stmt(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where
Self: Sized,
{
gen_stmt(self, ctx, stmt)
}
/// Generates code for a block statement.
fn gen_block<'a, I: Iterator<Item = &'a Stmt<Option<Type>>>>(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmts: I,
) -> Result<(), String>
where
Self: Sized,
{
gen_block(self, ctx, stmts)
}
/// See [`bool_to_i1`].
fn bool_to_i1<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> {
bool_to_i1(&ctx.builder, bool_value)
}
/// See [`bool_to_i8`].
fn bool_to_i8<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> {
bool_to_i8(&ctx.builder, ctx.ctx, bool_value)
}
}
pub struct DefaultCodeGenerator {
name: String,
size_t: u32,
}
impl DefaultCodeGenerator {
#[must_use]
pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator {
assert!(matches!(size_t, 32 | 64));
DefaultCodeGenerator { name, size_t }
}
}
impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str {
&self.name
}
/// Returns an LLVM integer type representing `size_t`.
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
// it should be unsigned, but we don't really need unsigned and this could save us from
// having to do a bit cast...
if self.size_t == 32 {
ctx.i32_type()
} else {
ctx.i64_type()
}
}
}

View File

@ -1,162 +0,0 @@
use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use super::calculate_len_for_slice_range;
use crate::codegen::{
macros::codegen_unreachable,
values::{ArrayLikeValue, ListValue},
CodeGenContext, CodeGenerator,
};
/// This function handles 'end' **inclusively**.
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
dest_arr: ListValue<'ctx>,
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
let size_ty = generator.get_size_type(ctx.ctx);
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let int32 = ctx.ctx.i32_type();
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
let slice_assign_fun = {
let ty_vec = vec![
int32.into(), // dest start idx
int32.into(), // dest end idx
int32.into(), // dest step
elem_ptr_type.into(), // dest arr ptr
int32.into(), // dest arr len
int32.into(), // src start idx
int32.into(), // src end idx
int32.into(), // src step
elem_ptr_type.into(), // src arr ptr
int32.into(), // src arr len
int32.into(), // size
];
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
ctx.module.add_function(fun_symbol, fn_t, None)
})
};
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr =
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let src_arr_ptr =
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied
let src_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let dest_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let src_slice_len =
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx
.builder
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
.unwrap();
let src_slt_dest = ctx
.builder
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
.unwrap();
let dest_step_eq_one = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
)
.unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc,
);
let new_len = {
let args = vec![
dest_idx.0.into(), // dest start idx
dest_idx.1.into(), // dest end idx
dest_idx.2.into(), // dest step
dest_arr_ptr.into(), // dest arr ptr
dest_len.into(), // dest arr len
src_idx.0.into(), // src start idx
src_idx.1.into(), // src end idx
src_idx.2.into(), // src step
src_arr_ptr.into(), // src arr ptr
src_len.into(), // src arr len
{
let s = match ty {
BasicTypeEnum::FloatType(t) => t.size_of(),
BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => codegen_unreachable!(ctx),
};
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
}
.into(),
];
ctx.builder
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
};
// update length
let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
ctx.builder.position_at_end(update_bb);
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.position_at_end(cont_bb);
}

View File

@ -1,152 +0,0 @@
use inkwell::{
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
IntPredicate,
};
use itertools::Either;
use crate::codegen::{
macros::codegen_unreachable,
{CodeGenContext, CodeGenerator},
};
// repeated squaring method adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
base: IntValue<'ctx>,
exp: IntValue<'ctx>,
signed: bool,
) -> IntValue<'ctx> {
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
(32, 32, true) => "__nac3_int_exp_int32_t",
(64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t",
_ => codegen_unreachable!(ctx),
};
let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
ctx.module.add_function(symbol, fn_type, None)
});
// throw exception when exp < 0
let ge_zero = ctx
.builder
.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
)
.unwrap();
ctx.make_assert(
generator,
ge_zero,
"0:ValueError",
"integer power must be positive or zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gamma", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gamma")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gammaln", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gammaln")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_j0", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,227 +0,0 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
context::Context,
memory_buffer::MemoryBuffer,
module::Module,
values::{BasicValue, BasicValueEnum, IntValue},
IntPredicate,
};
use nac3parser::ast::Expr;
use super::{CodeGenContext, CodeGenerator};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
pub use list::*;
pub use math::*;
pub use ndarray::*;
pub use slice::*;
mod list;
mod math;
mod ndarray;
mod slice;
#[must_use]
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
let bitcode_buf = MemoryBuffer::create_from_memory_range(
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
"irrt_bitcode_buffer",
);
let irrt_mod = Module::parse_bitcode_from_buffer(&bitcode_buf, ctx).unwrap();
let inline_attr = Attribute::get_named_enum_kind_id("alwaysinline");
for symbol in &[
"__nac3_int_exp_int32_t",
"__nac3_int_exp_int64_t",
"__nac3_range_slice_len",
"__nac3_slice_index_bound",
] {
let function = irrt_mod.get_function(symbol).unwrap();
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
}
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
let exn_id_type = ctx.i32_type();
let errors = &[
("EXN_INDEX_ERROR", "0:IndexError"),
("EXN_VALUE_ERROR", "0:ValueError"),
("EXN_ASSERTION_ERROR", "0:AssertionError"),
("EXN_TYPE_ERROR", "0:TypeError"),
];
for (irrt_name, symbol_name) in errors {
let exn_id = symbol_resolver.get_string_id(symbol_name);
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
});
global.set_initializer(&exn_id);
}
irrt_mod
}
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
/// NO numeric slice in python.
///
/// equivalent code:
/// ```pseudo_code
/// match (start, end, step):
/// case (s, e, None | Some(step)) if step > 0:
/// return (
/// match s:
/// case None:
/// 0
/// case Some(s):
/// handle_in_bound(s)
/// ,match e:
/// case None:
/// length - 1
/// case Some(e):
/// handle_in_bound(e) - 1
/// ,step == None ? 1 : step
/// )
/// case (s, e, Some(step)) if step < 0:
/// return (
/// match s:
/// case None:
/// length - 1
/// case Some(s):
/// s = handle_in_bound(s)
/// if s == length:
/// s - 1
/// else:
/// s
/// ,match e:
/// case None:
/// 0
/// case Some(e):
/// handle_in_bound(e) + 1
/// ,step
/// )
/// ```
pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
start: &Option<Box<Expr<Option<Type>>>>,
end: &Option<Box<Expr<Option<Type>>>>,
step: &Option<Box<Expr<Option<Type>>>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
Ok(Some(match (start, end, step) {
(s, e, None) => (
if let Some(s) = s.as_ref() {
match handle_slice_index_bound(s, ctx, generator, length)? {
Some(v) => v,
None => return Ok(None),
}
} else {
int32.const_zero()
},
{
let e = if let Some(s) = e.as_ref() {
match handle_slice_index_bound(s, ctx, generator, length)? {
Some(v) => v,
None => return Ok(None),
}
} else {
length
};
ctx.builder.build_int_sub(e, one, "final_end").unwrap()
},
one,
),
(s, e, Some(step)) => {
let step = if let Some(v) = generator.gen_expr(ctx, step)? {
v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
} else {
return Ok(None);
};
// assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
)
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"slice step cannot be zero",
[None, None, None],
ctx.current_loc,
);
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
let neg = ctx
.builder
.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")
.unwrap();
(
match s {
Some(s) => {
let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
return Ok(None);
};
ctx.builder
.build_select(
ctx.builder
.build_and(
ctx.builder
.build_int_compare(
IntPredicate::EQ,
s,
length,
"s_eq_len",
)
.unwrap(),
neg,
"should_minus_one",
)
.unwrap(),
ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
s,
"final_start",
)
.map(BasicValueEnum::into_int_value)
.unwrap()
}
None => ctx
.builder
.build_select(neg, len_id, zero, "stt")
.map(BasicValueEnum::into_int_value)
.unwrap(),
},
match e {
Some(e) => {
let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
return Ok(None);
};
ctx.builder
.build_select(
neg,
ctx.builder.build_int_add(e, one, "end_add_one").unwrap(),
ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(),
"final_end",
)
.map(BasicValueEnum::into_int_value)
.unwrap()
}
None => ctx
.builder
.build_select(neg, zero, len_id, "end")
.map(BasicValueEnum::into_int_value)
.unwrap(),
},
step,
)
}
}))
}

View File

@ -1,384 +0,0 @@
use inkwell::{
types::IntType,
values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use crate::codegen::{
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
values::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, NDArrayValue, TypedArrayLikeAccessor,
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
CodeGenContext, CodeGenerator,
};
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension
/// respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.shape();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.shape();
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.shape().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.shape().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.shape().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}

View File

@ -1,76 +0,0 @@
use inkwell::{
values::{BasicValueEnum, CallSiteValue, IntValue},
IntPredicate,
};
use itertools::Either;
use nac3parser::ast::Expr;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
typecheck::typedef::Type,
};
/// this function allows index out of range, since python
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
i: &Expr<Option<Type>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> {
const SYMBOL: &str = "__nac3_slice_index_bound";
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else {
return Ok(None);
};
Ok(Some(
ctx.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap(),
))
}
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>,
end: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
const SYMBOL: &str = "__nac3_range_slice_len";
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
// assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"step must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,345 +0,0 @@
use inkwell::{
context::Context,
intrinsics::Intrinsic,
types::{AnyTypeEnum::IntType, FloatType},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
AddressSpace,
};
use itertools::Either;
use super::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
// Standard LLVM floating-point types
if ft == ctx.f16_type() {
return "f16";
}
if ft == ctx.f32_type() {
return "f32";
}
if ft == ctx.f64_type() {
return "f64";
}
if ft == ctx.f128_type() {
return "f128";
}
// Non-standard floating-point types
if ft == ctx.x86_f80_type() {
return "f80";
}
if ft == ctx.ppc_f128_type() {
return "ppcf128";
}
unreachable!()
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_start";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_end";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic.
pub fn call_stacksave<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> PointerValue<'ctx> {
const FN_NAME: &str = "llvm.stacksave";
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_pointer_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the
/// [`llvm.stackrestore`](https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic) intrinsic.
///
/// - `ptr`: The pointer storing the address to restore the stack to.
pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.stackrestore";
/*
SEE https://github.com/TheDan64/inkwell/issues/496
We want `llvm.stackrestore`, but the following would generate `llvm.stackrestore.p0i8`.
```ignore
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()]))
.unwrap();
```
Temp workaround by manually declaring the intrinsic with the correct function name instead.
*/
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[ptr.into()], "").unwrap();
}
/// Invokes the [`llvm.memcpy`](https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic) intrinsic.
///
/// * `dest` - The pointer to the destination. Must be a pointer to an integer type.
/// * `src` - The pointer to the source. Must be a pointer to an integer type.
/// * `len` - The number of bytes to copy.
/// * `is_volatile` - Whether the `memcpy` operation should be `volatile`.
pub fn call_memcpy<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
const FN_NAME: &str = "llvm.memcpy";
debug_assert!(dest.get_type().get_element_type().is_int_type());
debug_assert!(src.get_type().get_element_type().is_int_type());
debug_assert_eq!(
dest.get_type().get_element_type().into_int_type().get_bit_width(),
src.get_type().get_element_type().into_int_type().get_bit_width(),
);
debug_assert!(matches!(len.get_type().get_bit_width(), 32 | 64));
debug_assert_eq!(is_volatile.get_type().get_bit_width(), 1);
let llvm_dest_t = dest.get_type();
let llvm_src_t = src.get_type();
let llvm_len_t = len.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| {
intrinsic.get_declaration(
&ctx.module,
&[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()],
)
})
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[dest.into(), src.into(), len.into(), is_volatile.into()], "")
.unwrap();
}
/// Invokes the `llvm.memcpy` intrinsic.
///
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
pub fn call_memcpy_generic<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let dest_elem_t = dest.get_type().get_element_type();
let src_elem_t = src.get_type().get_element_type();
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
dest
} else {
ctx.builder
.build_bit_cast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
src
} else {
ctx.builder
.build_bit_cast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
call_memcpy(ctx, dest, src, len, is_volatile);
}
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
///
/// Arguments:
/// * `$ctx:ident`: Reference to the current Code Generation Context
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type).
/// Use `BasicValueEnum::into_int_value` for Integer return type and
/// `BasicValueEnum::into_float_value` for Float return type
/// * `$llvm_ty:ident`: Type of first operand
/// * `,($val:ident)*`: Comma separated list of operands
macro_rules! generate_llvm_intrinsic_fn_body {
($ctx:ident, $name:ident, $llvm_name:literal, $map_fn:expr, $llvm_ty:ident $(,$val:ident)*) => {{
const FN_NAME: &str = concat!("llvm.", $llvm_name);
let intrinsic_fn = Intrinsic::find(FN_NAME).and_then(|intrinsic| intrinsic.get_declaration(&$ctx.module, &[$llvm_ty.into()])).unwrap();
$ctx.builder.build_call(intrinsic_fn, &[$($val.into()),*], $name.unwrap_or_default()).map(CallSiteValue::try_as_basic_value).map(|v| v.map_left($map_fn)).map(Either::unwrap_left).unwrap()
}};
}
/// Macro to generate the llvm intrinsic function using [`generate_llvm_intrinsic_fn_body`].
///
/// Arguments:
/// * `float/int`: Indicates the return and argument type of the function
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function.
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
/// * `$val:ident`: The operand for unary operations
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations
macro_rules! generate_llvm_intrinsic_fn {
("float", $fn_name:ident, $llvm_name:literal, $val:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_ty = $val.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val)
}
};
("float", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: FloatValue<'ctx>,
$val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!($val1.get_type(), $val2.get_type());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val1, $val2)
}
};
("int", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: IntValue<'ctx>,
$val2: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!($val1.get_type().get_bit_width(), $val2.get_type().get_bit_width());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_int_value, llvm_ty, $val1, $val2)
}
};
}
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
///
/// * `src` - The value for which the absolute value is to be returned.
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
pub fn call_int_abs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
is_int_min_poison: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
debug_assert!(is_int_min_poison.is_const());
let src_type = src.get_type();
generate_llvm_intrinsic_fn_body!(
ctx,
name,
"abs",
BasicValueEnum::into_int_value,
src_type,
src,
is_int_min_poison
)
}
generate_llvm_intrinsic_fn!("int", call_int_smax, "smax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_smin, "smin", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umax, "umax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umin, "umin", a, b);
generate_llvm_intrinsic_fn!("int", call_expect, "expect", val, expected_val);
generate_llvm_intrinsic_fn!("float", call_float_sqrt, "sqrt", val);
generate_llvm_intrinsic_fn!("float", call_float_sin, "sin", val);
generate_llvm_intrinsic_fn!("float", call_float_cos, "cos", val);
generate_llvm_intrinsic_fn!("float", call_float_pow, "pow", val, power);
generate_llvm_intrinsic_fn!("float", call_float_exp, "exp", val);
generate_llvm_intrinsic_fn!("float", call_float_exp2, "exp2", val);
generate_llvm_intrinsic_fn!("float", call_float_log, "log", val);
generate_llvm_intrinsic_fn!("float", call_float_log10, "log10", val);
generate_llvm_intrinsic_fn!("float", call_float_log2, "log2", val);
generate_llvm_intrinsic_fn!("float", call_float_fabs, "fabs", src);
generate_llvm_intrinsic_fn!("float", call_float_minnum, "minnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_maxnum, "maxnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_copysign, "copysign", mag, sgn);
generate_llvm_intrinsic_fn!("float", call_float_floor, "floor", val);
generate_llvm_intrinsic_fn!("float", call_float_ceil, "ceil", val);
generate_llvm_intrinsic_fn!("float", call_float_round, "round", val);
generate_llvm_intrinsic_fn!("float", call_float_rint, "rint", val);
/// Invokes the [`llvm.powi`](https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic) intrinsic.
pub fn call_float_powi<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.powi";
let llvm_val_t = val.get_type();
let llvm_power_t = power.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| {
intrinsic.get_declaration(&ctx.module, &[llvm_val_t.into(), llvm_power_t.into()])
})
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,476 +0,0 @@
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use indexmap::IndexMap;
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel,
};
use nac3parser::{
ast::{fold::Fold, FileName, StrRef},
parser::parse_program,
};
use parking_lot::RwLock;
use super::{
concrete_type::ConcreteTypeStore,
types::{ListType, NDArrayType, ProxyType, RangeType},
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator,
DefaultCodeGenerator, WithCall, WorkerRegistry,
};
use crate::{
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, IdentifierInfo, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
};
struct Resolver {
id_to_type: HashMap<StrRef, Type>,
id_to_def: RwLock<HashMap<StrRef, DefinitionId>>,
class_names: HashMap<StrRef, Type>,
}
impl Resolver {
pub fn add_id_def(&self, id: StrRef, def: DefinitionId) {
self.id_to_def.write().insert(id, def);
}
}
impl SymbolResolver for Resolver {
fn get_default_param_value(
&self,
_: &nac3parser::ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!()
}
fn get_symbol_type(
&self,
_: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>],
_: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String> {
self.id_to_type.get(&str).copied().ok_or_else(|| format!("cannot find symbol `{str}`"))
}
fn get_symbol_value<'ctx>(
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.id_to_def
.read()
.get(&id)
.copied()
.ok_or_else(|| HashSet::from([format!("cannot find symbol `{id}`")]))
}
fn get_string_id(&self, _: &str) -> i32 {
unimplemented!()
}
fn get_exception_id(&self, _tyid: usize) -> usize {
unimplemented!()
}
}
#[test]
fn test_primitives() {
let source = indoc! { "
c = a + b
d = a if c == 1 else 0
return d
"};
let statements = parse_program(source, FileName::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone());
let resolver = Arc::new(Resolver {
id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()),
class_names: HashMap::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature {
args: vec![
FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "b".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
],
ret: primitives.int32,
vars: VarMap::new(),
};
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature = store.from_signature(&mut unifier, &primitives, &signature, &mut cache);
let signature = store.add_cty(signature);
let mut function_data = FunctionData {
resolver: resolver.clone(),
bound_variables: Vec::new(),
return_type: Some(primitives.int32),
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashMap<_, _> =
["a".into(), "b".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
unifier: &mut unifier,
variable_mapping: HashMap::default(),
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: identifiers.clone(),
in_handler: false,
};
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])),
personality_symbol: None,
});
let task = CodeGenTask {
subst: Vec::default(),
symbol_name: "testing".into(),
body: Arc::new(statements),
unifier_index: 0,
calls: Arc::new(calls),
resolver,
store,
signature,
id: 0,
};
let f = Arc::new(WithCall::new(Box::new(|module| {
// the following IR is equivalent to
// ```
// ; ModuleID = 'test.ll'
// source_filename = "test"
//
// ; Function Attrs: norecurse nounwind readnone
// define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
// init:
// %add = add i32 %1, %0
// %cmp = icmp eq i32 %add, 1
// %ifexpr = select i1 %cmp, i32 %0, i32 0
// ret i32 %ifexpr
// }
//
// attributes #0 = { norecurse nounwind readnone }
// ```
// after O2 optimization
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
init:
%add = add i32 %1, %0, !dbg !9
%cmp = icmp eq i32 %add, 1, !dbg !10
%. = select i1 %cmp, i32 %0, i32 0, !dbg !11
ret i32 %., !dbg !12
}
attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }
!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2}
!0 = !{i32 2, !\"Debug Info Version\", i32 3}
!1 = !{i32 2, !\"Dwarf Version\", i32 4}
!2 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
!3 = !DIFile(filename: \"unknown\", directory: \"\")
!4 = distinct !DISubprogram(name: \"testing\", linkageName: \"testing\", scope: null, file: !3, line: 1, type: !5, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !8)
!5 = !DISubroutineType(flags: DIFlagPublic, types: !6)
!6 = !{!7}
!7 = !DIBasicType(name: \"_\", flags: DIFlagPublic)
!8 = !{}
!9 = !DILocation(line: 1, column: 9, scope: !4)
!10 = !DILocation(line: 2, column: 15, scope: !4)
!11 = !DILocation(line: 0, scope: !4)
!12 = !DILocation(line: 3, column: 8, scope: !4)
"}
.trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
})));
Target::initialize_all(&InitializationConfig::default());
let llvm_options = CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(),
};
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}
#[test]
fn test_simple_call() {
let source_1 = indoc! { "
a = foo(a)
return a * 2
"};
let statements_1 = parse_program(source_1, FileName::default()).unwrap();
let source_2 = indoc! { "
return a + 1
"};
let statements_2 = parse_program(source_2, FileName::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone());
let signature = FunSignature {
args: vec![FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
}],
ret: primitives.int32,
vars: VarMap::new(),
};
let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone()));
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature = store.from_signature(&mut unifier, &primitives, &signature, &mut cache);
let signature = store.add_cty(signature);
let foo_id = top_level.definitions.read().len();
top_level.definitions.write().push(Arc::new(RwLock::new(TopLevelDef::Function {
name: "foo".to_string(),
simple_name: "foo".into(),
signature: fun_ty,
var_id: vec![],
instance_to_stmt: HashMap::new(),
instance_to_symbol: HashMap::new(),
resolver: None,
codegen_callback: None,
loc: None,
})));
let resolver = Resolver {
id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()),
class_names: HashMap::default(),
};
resolver.add_id_def("foo".into(), DefinitionId(foo_id));
let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>;
if let TopLevelDef::Function { resolver: r, .. } =
&mut *top_level.definitions.read()[foo_id].write()
{
*r = Some(resolver.clone());
} else {
unreachable!()
}
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let mut function_data = FunctionData {
resolver: resolver.clone(),
bound_variables: Vec::new(),
return_type: Some(primitives.int32),
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashMap<_, _> =
["a".into(), "foo".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
unifier: &mut unifier,
variable_mapping: HashMap::default(),
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: identifiers.clone(),
in_handler: false,
};
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("foo".into(), fun_ty);
let statements_1 = statements_1
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let calls1 = inferencer.calls.clone();
inferencer.calls.clear();
let statements_2 = statements_2
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
if let TopLevelDef::Function { instance_to_stmt, .. } =
&mut *top_level.definitions.read()[foo_id].write()
{
instance_to_stmt.insert(
String::new(),
FunInstance {
body: Arc::new(statements_2),
calls: Arc::new(inferencer.calls.clone()),
subst: IndexMap::default(),
unifier_id: 0,
},
);
} else {
unreachable!()
}
inferencer.check_block(&statements_1, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])),
personality_symbol: None,
});
let task = CodeGenTask {
subst: Vec::default(),
symbol_name: "testing".to_string(),
body: Arc::new(statements_1),
calls: Arc::new(calls1),
unifier_index: 0,
resolver,
signature,
store,
id: 0,
};
let f = Arc::new(WithCall::new(Box::new(|module| {
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {
init:
%add.i = shl i32 %0, 1, !dbg !10
%mul = add i32 %add.i, 2, !dbg !10
ret i32 %mul, !dbg !10
}
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @foo.0(i32 %0) local_unnamed_addr #0 !dbg !11 {
init:
%add = add i32 %0, 1, !dbg !12
ret i32 %add, !dbg !12
}
attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }
!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2, !4}
!0 = !{i32 2, !\"Debug Info Version\", i32 3}
!1 = !{i32 2, !\"Dwarf Version\", i32 4}
!2 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
!3 = !DIFile(filename: \"unknown\", directory: \"\")
!4 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
!5 = distinct !DISubprogram(name: \"testing\", linkageName: \"testing\", scope: null, file: !3, line: 1, type: !6, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !9)
!6 = !DISubroutineType(flags: DIFlagPublic, types: !7)
!7 = !{!8}
!8 = !DIBasicType(name: \"_\", flags: DIFlagPublic)
!9 = !{}
!10 = !DILocation(line: 2, column: 12, scope: !5)
!11 = distinct !DISubprogram(name: \"foo.0\", linkageName: \"foo.0\", scope: null, file: !3, line: 1, type: !6, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !4, retainedNodes: !9)
!12 = !DILocation(line: 1, column: 12, scope: !11)
"}
.trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
})));
Target::initialize_all(&InitializationConfig::default());
let llvm_options = CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(),
};
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}
#[test]
fn test_classes_list_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok());
}
#[test]
fn test_classes_range_type_new() {
let ctx = inkwell::context::Context::create();
let llvm_range = RangeType::new(&ctx);
assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok());
}
#[test]
fn test_classes_ndarray_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
}

View File

@ -1,192 +0,0 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue,
AddressSpace,
};
use super::ProxyType;
use crate::codegen::{
values::{ArraySliceValue, ListValue, ProxyValue},
CodeGenContext, CodeGenerator,
};
/// Proxy type for a `list` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct ListType<'ctx> {
ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
}
impl<'ctx> ListType<'ctx> {
/// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let llvm_list_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else {
return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}"));
};
if llvm_list_ty.count_fields() != 2 {
return Err(format!(
"Expected 2 fields in `list`, got {}",
llvm_list_ty.count_fields()
));
}
let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap();
let Ok(_) = PointerType::try_from(list_size_ty) else {
return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}"));
};
let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap();
let Ok(list_data_ty) = IntType::try_from(list_data_ty) else {
return Err(format!("Expected int type for `list.1`, got {list_data_ty}"));
};
if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected {}-bit int type for `list.1`, got {}-bit int",
llvm_usize.get_bit_width(),
list_data_ty.get_bit_width()
));
}
Ok(())
}
/// Creates an LLVM type corresponding to the expected structure of a `List`.
#[must_use]
fn llvm_type(
ctx: &'ctx Context,
element_type: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> PointerType<'ctx> {
// struct List { data: T*, size: size_t }
let field_tys = [element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()];
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`ListType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
element_type: BasicTypeEnum<'ctx>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize);
ListType::from_type(llvm_list, llvm_usize)
}
/// Creates an [`ListType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
ListType { ty: ptr_ty, llvm_usize }
}
/// Returns the type of the `size` field of this `list` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(1)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `list` type.
#[must_use]
pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap()
}
}
impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
type Base = PointerType<'ctx>;
type Value = ListValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> Self::Value {
self.map_value(
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap(),
name,
)
}
fn new_array_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
Self::Value::from_pointer_value(value, self.llvm_usize, name)
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<ListType<'ctx>> for PointerType<'ctx> {
fn from(value: ListType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -1,64 +0,0 @@
use inkwell::{context::Context, types::BasicType, values::IntValue};
use super::{
values::{ArraySliceValue, ProxyValue},
{CodeGenContext, CodeGenerator},
};
pub use list::*;
pub use ndarray::*;
pub use range::*;
mod list;
mod ndarray;
mod range;
pub mod structure;
/// A LLVM type that is used to represent a corresponding type in NAC3.
pub trait ProxyType<'ctx>: Into<Self::Base> {
/// The LLVM type of which values of this type possess. This is usually a
/// [LLVM pointer type][PointerType] for any non-primitive types.
type Base: BasicType<'ctx>;
/// The type of values represented by this type.
type Value: ProxyValue<'ctx, Type = Self>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String>;
/// Checks whether `llvm_ty` can be represented by this [`ProxyType`].
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String>;
/// Creates a new value of this type.
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> Self::Value;
/// Creates a new array value of this type.
fn new_array_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx>;
/// Converts an existing value into a [`ProxyValue`] of this type.
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value;
/// Returns the [base type][Self::Base] of this proxy.
fn as_base_type(&self) -> Self::Base;
}

View File

@ -1,258 +0,0 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use super::{
structure::{StructField, StructFields},
ProxyType,
};
use crate::codegen::{
values::{ArraySliceValue, NDArrayValue, ProxyValue},
{CodeGenContext, CodeGenerator},
};
/// Proxy type for a `ndarray` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDArrayType<'ctx> {
ty: PointerType<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDArrayStructFields<'ctx> {
#[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
};
if llvm_ndarray_ty.count_fields() != 3 {
return Err(format!(
"Expected 3 fields in `NDArray`, got {}",
llvm_ndarray_ty.count_fields()
));
}
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"));
};
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected {}-bit int type for `ndarray.0`, got {}-bit int",
llvm_usize.get_bit_width(),
ndarray_ndims_ty.get_bit_width()
));
}
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
};
let ndarray_dims = ndarray_pdims.get_element_type();
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
));
};
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_dims.get_bit_width()
));
}
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
};
let ndarray_data = ndarray_pdata.get_element_type();
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}"
));
};
if ndarray_data.get_bit_width() != 8 {
return Err(format!(
"Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
ndarray_data.get_bit_width()
));
}
Ok(())
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize)
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(
&self,
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
//
// * data : Pointer to an array containing the array data
// * itemsize: The size of each NDArray elements in bytes
// * ndims : Number of dimensions in the array
// * shape : Pointer to an array containing the shape of the NDArray
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
let field_tys =
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`NDArrayType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
}
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
#[must_use]
pub fn from_type(
ptr_ty: PointerType<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
NDArrayType { ty: ptr_ty, dtype, llvm_usize }
}
/// Returns the type of the `size` field of this `ndarray` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `ndarray` type.
#[must_use]
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.dtype
}
}
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
type Base = PointerType<'ctx>;
type Value = NDArrayValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> Self::Value {
self.map_value(
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap(),
name,
)
}
fn new_array_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name)
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<NDArrayType<'ctx>> for PointerType<'ctx> {
fn from(value: NDArrayType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -1,159 +0,0 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue,
AddressSpace,
};
use super::ProxyType;
use crate::codegen::{
values::{ArraySliceValue, ProxyValue, RangeValue},
{CodeGenContext, CodeGenerator},
};
/// Proxy type for a `range` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct RangeType<'ctx> {
ty: PointerType<'ctx>,
}
impl<'ctx> RangeType<'ctx> {
/// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not.
pub fn is_representable(llvm_ty: PointerType<'ctx>) -> Result<(), String> {
let llvm_range_ty = llvm_ty.get_element_type();
let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else {
return Err(format!("Expected array type for `range` type, got {llvm_range_ty}"));
};
if llvm_range_ty.len() != 3 {
return Err(format!(
"Expected 3 elements for `range` type, got {}",
llvm_range_ty.len()
));
}
let llvm_range_elem_ty = llvm_range_ty.get_element_type();
let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else {
return Err(format!(
"Expected int type for `range` element type, got {llvm_range_elem_ty}"
));
};
if llvm_range_elem_ty.get_bit_width() != 32 {
return Err(format!(
"Expected 32-bit int type for `range` element type, got {}",
llvm_range_elem_ty.get_bit_width()
));
}
Ok(())
}
/// Creates an LLVM type corresponding to the expected structure of a `Range`.
#[must_use]
fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> {
// typedef int32_t Range[3];
let llvm_i32 = ctx.i32_type();
llvm_i32.array_type(3).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`RangeType`].
#[must_use]
pub fn new(ctx: &'ctx Context) -> Self {
let llvm_range = Self::llvm_type(ctx);
RangeType::from_type(llvm_range)
}
/// Creates an [`RangeType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self {
debug_assert!(Self::is_representable(ptr_ty).is_ok());
RangeType { ty: ptr_ty }
}
/// Returns the type of all fields of this `range` type.
#[must_use]
pub fn value_type(&self) -> IntType<'ctx> {
self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type()
}
}
impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
type Base = PointerType<'ctx>;
type Value = RangeValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
_: &G,
_: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty)
}
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> Self::Value {
self.map_value(
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap(),
name,
)
}
fn new_array_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
RangeValue::from_pointer_value(value, name)
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<RangeType<'ctx>> for PointerType<'ctx> {
fn from(value: RangeType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -1,203 +0,0 @@
use std::marker::PhantomData;
use inkwell::{
context::AsContextRef,
types::{BasicTypeEnum, IntType},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
};
use crate::codegen::CodeGenContext;
/// Trait indicating that the structure is a field-wise representation of an LLVM structure.
///
/// # Usage
///
/// For example, for a simple C-slice LLVM structure:
///
/// ```ignore
/// struct CSliceFields<'ctx> {
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
/// len: StructField<'ctx, IntValue<'ctx>>
/// }
/// ```
pub trait StructFields<'ctx>: Eq + Copy {
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self;
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
/// the type definition.
#[must_use]
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>;
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
/// in the type definition.
#[must_use]
fn iter(&self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)> {
self.to_vec().into_iter()
}
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
/// the type definition.
#[must_use]
fn into_vec(self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>
where
Self: Sized,
{
self.to_vec()
}
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
/// in the type definition.
#[must_use]
fn into_iter(self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)>
where
Self: Sized,
{
self.into_vec().into_iter()
}
}
/// A single field of an LLVM structure.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct StructField<'ctx, Value>
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
/// The index of this field within the structure.
index: u32,
/// The name of this field.
name: &'static str,
/// The type of this field.
ty: BasicTypeEnum<'ctx>,
/// Instance of [`PhantomData`] containing [`Value`], used to implement automatic downcasts.
_value_ty: PhantomData<Value>,
}
impl<'ctx, Value> StructField<'ctx, Value>
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
/// Creates an instance of [`StructField`].
///
/// * `idx_counter` - The instance of [`FieldIndexCounter`] used to track the current field
/// index.
/// * `name` - Name of the field.
/// * `ty` - The type of this field.
pub fn create(
idx_counter: &mut FieldIndexCounter,
name: &'static str,
ty: impl Into<BasicTypeEnum<'ctx>>,
) -> Self {
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
}
/// Creates an instance of [`StructField`] with a given index.
///
/// * `index` - The index of this field within its enclosing structure.
/// * `name` - Name of the field.
/// * `ty` - The type of this field.
pub fn create_at(index: u32, name: &'static str, ty: impl Into<BasicTypeEnum<'ctx>>) -> Self {
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
}
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
/// {idx...}, i32 {self.index}`.
pub fn ptr_by_array_gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
idx: &[IntValue<'ctx>],
) -> PointerValue<'ctx> {
unsafe {
ctx.builder.build_in_bounds_gep(
pobj,
&[idx, &[ctx.ctx.i32_type().const_int(u64::from(self.index), false)]].concat(),
"",
)
}
.unwrap()
}
/// Creates a pointer to this field in an arbitrary structure by performing the equivalent of
/// `getelementptr i32 0, i32 {self.index}`.
pub fn ptr_by_gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
obj_name: Option<&'ctx str>,
) -> PointerValue<'ctx> {
ctx.builder
.build_struct_gep(
pobj,
self.index,
&obj_name.map(|name| format!("{name}.{}.addr", self.name)).unwrap_or_default(),
)
.unwrap()
}
/// Gets the value of this field for a given `obj`.
#[must_use]
pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value {
obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap()
}
/// Sets the value of this field for a given `obj`.
pub fn set_from_value(&self, obj: StructValue<'ctx>, value: Value) {
obj.set_field_at_index(self.index, value);
}
/// Gets the value of this field for a pointer-to-structure.
pub fn get(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
obj_name: Option<&'ctx str>,
) -> Value {
ctx.builder
.build_load(
self.ptr_by_gep(ctx, pobj, obj_name),
&obj_name.map(|name| format!("{name}.{}", self.name)).unwrap_or_default(),
)
.map_err(|_| ())
.and_then(|value| Value::try_from(value))
.unwrap()
}
/// Sets the value of this field for a pointer-to-structure.
pub fn set(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
value: Value,
obj_name: Option<&'ctx str>,
) {
ctx.builder.build_store(self.ptr_by_gep(ctx, pobj, obj_name), value).unwrap();
}
}
impl<'ctx, Value> From<StructField<'ctx, Value>> for (&'static str, BasicTypeEnum<'ctx>)
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
fn from(value: StructField<'ctx, Value>) -> Self {
(value.name, value.ty)
}
}
/// A counter that tracks the next index of a field using a monotonically increasing counter.
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
pub struct FieldIndexCounter(u32);
impl FieldIndexCounter {
/// Increments the number stored by this counter, returning the previous value.
///
/// Functionally equivalent to `i++` in C-based languages.
pub fn increment(&mut self) -> u32 {
let v = self.0;
self.0 += 1;
v
}
}

View File

@ -1,426 +0,0 @@
use inkwell::{
types::AnyTypeEnum,
values::{BasicValueEnum, IntValue, PointerValue},
IntPredicate,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
/// elements.
pub trait ArrayLikeValue<'ctx> {
/// Returns the element type of this array-like value.
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx>;
/// Returns the base pointer to the array.
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> PointerValue<'ctx>;
/// Returns the size of this array-like value.
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx>;
/// Returns a [`ArraySliceValue`] representing this value.
fn as_slice_value<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> ArraySliceValue<'ctx> {
ArraySliceValue::from_ptr_val(
self.base_ptr(ctx, generator),
self.size(ctx, generator),
None,
)
}
}
/// An array-like value that can be indexed by memory offset.
pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> {
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx>;
/// Returns the pointer to the data at the `idx`-th index.
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx>;
}
/// An array-like value that can have its array elements accessed as a [`BasicValueEnum`].
pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>:
ArrayLikeIndexer<'ctx, Index>
{
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn get_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) };
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
}
/// Returns the data at the `idx`-th index.
fn get<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset(ctx, generator, idx, name);
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
}
}
/// An array-like value that can have its array elements mutated as a [`BasicValueEnum`].
pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>:
ArrayLikeIndexer<'ctx, Index>
{
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn set_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: BasicValueEnum<'ctx>,
) {
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, None) };
ctx.builder.build_store(ptr, value).unwrap();
}
/// Sets the data at the `idx`-th index.
fn set<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: BasicValueEnum<'ctx>,
) {
let ptr = self.ptr_offset(ctx, generator, idx, None);
ctx.builder.build_store(ptr, value).unwrap();
}
}
/// An array-like value that can have its array elements accessed as an arbitrary type `T`.
pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>:
UntypedArrayLikeAccessor<'ctx, Index>
{
/// Casts an element from [`BasicValueEnum`] into `T`.
fn downcast_to_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> T;
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn get_typed_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> T {
let value = unsafe { self.get_unchecked(ctx, generator, idx, name) };
self.downcast_to_type(ctx, value)
}
/// Returns the data at the `idx`-th index.
fn get_typed<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> T {
let value = self.get(ctx, generator, idx, name);
self.downcast_to_type(ctx, value)
}
}
/// An array-like value that can have its array elements mutated as an arbitrary type `T`.
pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>:
UntypedArrayLikeMutator<'ctx, Index>
{
/// Casts an element from T into [`BasicValueEnum`].
fn upcast_from_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: T,
) -> BasicValueEnum<'ctx>;
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn set_typed_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: T,
) {
let value = self.upcast_from_type(ctx, value);
unsafe { self.set_unchecked(ctx, generator, idx, value) }
}
/// Sets the data at the `idx`-th index.
fn set_typed<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: T,
) {
let value = self.upcast_from_type(ctx, value);
self.set(ctx, generator, idx, value);
}
}
/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`.
type ValueDowncastFn<'ctx, T> =
Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> T>;
/// Type alias for a function that casts a `T` into a [`BasicValueEnum`].
type ValueUpcastFn<'ctx, T> = Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, T) -> BasicValueEnum<'ctx>>;
/// An adapter for constraining untyped array values as typed values.
pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>> {
adapted: Adapted,
downcast_fn: ValueDowncastFn<'ctx, T>,
upcast_fn: ValueUpcastFn<'ctx, T>,
}
impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: ArrayLikeValue<'ctx>,
{
/// Creates a [`TypedArrayLikeAdapter`].
///
/// * `adapted` - The value to be adapted.
/// * `downcast_fn` - The function converting a [`BasicValueEnum`] into a `T`.
/// * `upcast_fn` - The function converting a T into a [`BasicValueEnum`].
pub fn from(
adapted: Adapted,
downcast_fn: ValueDowncastFn<'ctx, T>,
upcast_fn: ValueUpcastFn<'ctx, T>,
) -> Self {
TypedArrayLikeAdapter { adapted, downcast_fn, upcast_fn }
}
}
impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: ArrayLikeValue<'ctx>,
{
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.adapted.element_type(ctx, generator)
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> PointerValue<'ctx> {
self.adapted.base_ptr(ctx, generator)
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx> {
self.adapted.size(ctx, generator)
}
}
impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: ArrayLikeIndexer<'ctx, Index>,
{
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
unsafe { self.adapted.ptr_offset_unchecked(ctx, generator, idx, name) }
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
self.adapted.ptr_offset(ctx, generator, idx, name)
}
}
impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
{
}
impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
{
}
impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
{
fn downcast_to_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> T {
(self.downcast_fn)(ctx, value)
}
}
impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
{
fn upcast_from_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: T,
) -> BasicValueEnum<'ctx> {
(self.upcast_fn)(ctx, value)
}
}
/// An LLVM value representing an array slice, consisting of a pointer to the data and the size of
/// the slice.
#[derive(Copy, Clone)]
pub struct ArraySliceValue<'ctx>(PointerValue<'ctx>, IntValue<'ctx>, Option<&'ctx str>);
impl<'ctx> ArraySliceValue<'ctx> {
/// Creates an [`ArraySliceValue`] from a [`PointerValue`] and its size.
#[must_use]
pub fn from_ptr_val(
ptr: PointerValue<'ctx>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Self {
ArraySliceValue(ptr, size, name)
}
}
impl<'ctx> From<ArraySliceValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ArraySliceValue<'ctx>) -> Self {
value.0
}
}
impl<'ctx> ArrayLikeValue<'ctx> for ArraySliceValue<'ctx> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> AnyTypeEnum<'ctx> {
self.0.get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
self.0
}
fn size<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.1
}
}
impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"list index out of range",
[None, None, None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ArraySliceValue<'ctx> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ArraySliceValue<'ctx> {}

View File

@ -1,241 +0,0 @@
use inkwell::{
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use super::{
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::{
types::ListType,
{CodeGenContext, CodeGenerator},
};
/// Proxy type for accessing a `list` value in LLVM.
#[derive(Copy, Clone)]
pub struct ListValue<'ctx> {
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> ListValue<'ctx> {
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
/// instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
ListType::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`ListValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
ListValue { value: ptr, llvm_usize, name }
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
.unwrap()
}
}
/// Returns the pointer to the field storing the size of this `list`.
fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of data elements `data` into this instance.
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
ctx.builder.build_store(self.pptr_to_data(ctx), data).unwrap();
}
/// Convenience method for creating a new array storing data elements with the given element
/// type `elem_ty` and `size`.
///
/// If `size` is [None], the size stored in the field of this instance is used instead.
pub fn create_data(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: BasicTypeEnum<'ctx>,
size: Option<IntValue<'ctx>>,
) {
let size = size.unwrap_or_else(|| self.load_size(ctx, None));
let data = ctx
.builder
.build_select(
ctx.builder
.build_int_compare(IntPredicate::NE, size, self.llvm_usize.const_zero(), "")
.unwrap(),
ctx.builder.build_array_alloca(elem_ty, size, "").unwrap(),
elem_ty.ptr_type(AddressSpace::default()).const_zero(),
"",
)
.map(BasicValueEnum::into_pointer_value)
.unwrap();
self.store_data(ctx, data);
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
#[must_use]
pub fn data(&self) -> ListDataProxy<'ctx, '_> {
ListDataProxy(self)
}
/// Stores the `size` of this `list` into this instance.
pub fn store_size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
size: IntValue<'ctx>,
) {
debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx));
let psize = self.ptr_to_size(ctx);
ctx.builder.build_store(psize, size).unwrap();
}
/// Returns the size of this `list` as a value.
pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let psize = self.ptr_to_size(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.size")))
.unwrap_or_default();
ctx.builder
.build_load(psize, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
}
impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = ListType<'ctx>;
fn get_type(&self) -> Self::Type {
ListType::from_type(self.as_base_value().get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<ListValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ListValue<'ctx>) -> Self {
value.as_base_value()
}
}
/// Proxy type for accessing the `data` array of an `list` instance in LLVM.
#[derive(Copy, Clone)]
pub struct ListDataProxy<'ctx, 'a>(&'a ListValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> AnyTypeEnum<'ctx> {
self.0.value.get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.pptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_size(ctx, None)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"list index out of range",
[None, None, None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ListDataProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ListDataProxy<'ctx, '_> {}

View File

@ -1,47 +0,0 @@
use inkwell::{context::Context, values::BasicValue};
use super::types::ProxyType;
use crate::codegen::CodeGenerator;
pub use array::*;
pub use list::*;
pub use ndarray::*;
pub use range::*;
mod array;
mod list;
mod ndarray;
mod range;
/// A LLVM type that is used to represent a non-primitive value in NAC3.
pub trait ProxyValue<'ctx>: Into<Self::Base> {
/// The type of LLVM values represented by this instance. This is usually the
/// [LLVM pointer type][PointerValue].
type Base: BasicValue<'ctx>;
/// The type of this value.
type Type: ProxyType<'ctx, Value = Self>;
/// Checks whether `value` can be represented by this [`ProxyValue`].
fn is_instance<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
value: impl BasicValue<'ctx>,
) -> Result<(), String> {
Self::Type::is_type(generator, ctx, value.as_basic_value_enum().get_type())
}
/// Checks whether `value` can be represented by this [`ProxyValue`].
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
value: Self::Base,
) -> Result<(), String> {
Self::is_instance(generator, ctx, value.as_basic_value_enum())
}
/// Returns the [type][ProxyType] of this value.
fn get_type(&self) -> Self::Type;
/// Returns the [base value][Self::Base] of this proxy.
fn as_base_value(&self) -> Self::Base;
}

View File

@ -1,523 +0,0 @@
use inkwell::{
types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use super::{
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
types::NDArrayType,
CodeGenContext, CodeGenerator,
};
/// Proxy type for accessing an `NDArray` value in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayValue<'ctx> {
value: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> NDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
/// instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
NDArrayType::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
NDArrayValue { value: ptr, dtype, llvm_usize, name }
}
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.ndims
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.shape
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_shape(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
NDArrayShapeProxy(self)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.data
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of data elements `data` into this instance.
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
let data = ctx
.builder
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap();
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
}
/// Convenience method for creating a new array storing data elements with the given element
/// type `elem_ty` and `size`.
pub fn create_data(
&self,
ctx: &CodeGenContext<'ctx, '_>,
elem_ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
) {
let itemsize =
ctx.builder.build_int_cast(elem_ty.size_of().unwrap(), size.get_type(), "").unwrap();
let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap();
// TODO: What about alignment?
self.store_data(
ctx,
ctx.builder.build_array_alloca(ctx.ctx.i8_type(), nbytes, "").unwrap(),
);
}
/// Returns a proxy object to the field storing the data of this `NDArray`.
#[must_use]
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
NDArrayDataProxy(self)
}
}
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = NDArrayType<'ctx>;
fn get_type(&self) -> Self::Type {
NDArrayType::from_type(self.as_base_value().get_type(), self.dtype, self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
fn from(value: NDArrayValue<'ctx>) -> Self {
value.as_base_value()
}
}
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> AnyTypeEnum<'ctx> {
self.0.dtype.as_any_type_enum()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx> {
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
}
}
impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let sizeof_elem = ctx
.builder
.build_int_truncate_or_bit_cast(
self.element_type(ctx, generator).size_of().unwrap(),
idx.get_type(),
"",
)
.unwrap();
let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap();
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[idx],
name.unwrap_or_default(),
)
.unwrap()
};
// Current implementation is transparent - The returned pointer type is
// already cast into the expected type, allowing for immediately
// load/store.
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let data_sz = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) };
// Current implementation is transparent - The returned pointer type is
// already cast into the expected type, allowing for immediately
// load/store.
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
indices: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_elem_ty = indices
.ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
.get_type()
.get_element_type();
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}")
};
assert_eq!(
indices_elem_ty.get_bit_width(),
32,
"Expected list[int32] but got list[int{}]",
indices_elem_ty.get_bit_width()
);
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
let sizeof_elem = ctx
.builder
.build_int_truncate_or_bit_cast(
self.element_type(ctx, generator).size_of().unwrap(),
index.get_type(),
"",
)
.unwrap();
let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap();
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[index],
name.unwrap_or_default(),
)
.unwrap()
};
// TODO: Current implementation is transparent
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
indices: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_size = indices.size(ctx, generator);
let nidx_leq_ndims = ctx
.builder
.build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "")
.unwrap();
ctx.make_assert(
generator,
nidx_leq_ndims,
"0:IndexError",
"invalid index to scalar variable",
[None, None, None],
ctx.current_loc,
);
let indices_len = indices.size(ctx, generator);
let ndarray_len = self.0.load_ndims(ctx);
let len = call_int_umin(ctx, indices_len, ndarray_len, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let (dim_idx, dim_sz) = unsafe {
(
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
self.0.shape().get_typed_unchecked(ctx, generator, &i, None),
)
};
let dim_idx = ctx
.builder
.build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "")
.unwrap();
let dim_lt =
ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap();
ctx.make_assert(
generator,
dim_lt,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(dim_idx), Some(dim_sz), None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) };
// TODO: Current implementation is transparent
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
}

View File

@ -1,153 +0,0 @@
use inkwell::values::{BasicValueEnum, IntValue, PointerValue};
use super::ProxyValue;
use crate::codegen::{types::RangeType, CodeGenContext};
/// Proxy type for accessing a `range` value in LLVM.
#[derive(Copy, Clone)]
pub struct RangeValue<'ctx> {
value: PointerValue<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> RangeValue<'ctx> {
/// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance.
pub fn is_representable(value: PointerValue<'ctx>) -> Result<(), String> {
RangeType::is_representable(value.get_type())
}
/// Creates an [`RangeValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
debug_assert!(Self::is_representable(ptr).is_ok());
RangeValue { value: ptr, name }
}
fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.start.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(0, false)],
var_name.as_str(),
)
.unwrap()
}
}
fn ptr_to_end(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.end.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
var_name.as_str(),
)
.unwrap()
}
}
fn ptr_to_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.step.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, false)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the `start` value into this instance.
pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, start: IntValue<'ctx>) {
debug_assert_eq!(start.get_type().get_bit_width(), 32);
let pstart = self.ptr_to_start(ctx);
ctx.builder.build_store(pstart, start).unwrap();
}
/// Returns the `start` value of this `range`.
pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pstart = self.ptr_to_start(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.start")))
.unwrap_or_default();
ctx.builder
.build_load(pstart, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
/// Stores the `end` value into this instance.
pub fn store_end(&self, ctx: &CodeGenContext<'ctx, '_>, end: IntValue<'ctx>) {
debug_assert_eq!(end.get_type().get_bit_width(), 32);
let pend = self.ptr_to_end(ctx);
ctx.builder.build_store(pend, end).unwrap();
}
/// Returns the `end` value of this `range`.
pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pend = self.ptr_to_end(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.end")))
.unwrap_or_default();
ctx.builder.build_load(pend, var_name.as_str()).map(BasicValueEnum::into_int_value).unwrap()
}
/// Stores the `step` value into this instance.
pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, step: IntValue<'ctx>) {
debug_assert_eq!(step.get_type().get_bit_width(), 32);
let pstep = self.ptr_to_step(ctx);
ctx.builder.build_store(pstep, step).unwrap();
}
/// Returns the `step` value of this `range`.
pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pstep = self.ptr_to_step(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.step")))
.unwrap_or_default();
ctx.builder
.build_load(pstep, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
}
impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = RangeType<'ctx>;
fn get_type(&self) -> Self::Type {
RangeType::from_type(self.value.get_type())
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<RangeValue<'ctx>> for PointerValue<'ctx> {
fn from(value: RangeValue<'ctx>) -> Self {
value.as_base_value()
}
}

View File

@ -1,25 +1,588 @@
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)] #![warn(clippy::all)]
#![warn(clippy::pedantic)] #![allow(clippy::clone_double_ref)]
#![allow(
dead_code,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::enum_glob_use,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::similar_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
// users of nac3core need to use the same version of these dependencies, so expose them as nac3core::* extern crate num_bigint;
pub use inkwell; extern crate inkwell;
pub use nac3parser; extern crate rustpython_parser;
pub mod codegen; pub mod type_check;
pub mod symbol_resolver;
pub mod toplevel;
pub mod typecheck;
extern crate self as nac3core; use std::error::Error;
use std::fmt;
use std::path::Path;
use std::collections::HashMap;
use num_traits::cast::ToPrimitive;
use rustpython_parser::ast;
use inkwell::OptimizationLevel;
use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::targets::*;
use inkwell::types;
use inkwell::types::BasicType;
use inkwell::values;
use inkwell::{IntPredicate, FloatPredicate};
use inkwell::basic_block;
use inkwell::passes;
#[derive(Debug)]
enum CompileErrorKind {
Unsupported(&'static str),
MissingTypeAnnotation,
UnknownTypeAnnotation,
IncompatibleTypes,
UnboundIdentifier,
BreakOutsideLoop,
Internal(&'static str)
}
impl fmt::Display for CompileErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CompileErrorKind::Unsupported(feature)
=> write!(f, "The following Python feature is not supported by NAC3: {}", feature),
CompileErrorKind::MissingTypeAnnotation
=> write!(f, "Missing type annotation"),
CompileErrorKind::UnknownTypeAnnotation
=> write!(f, "Unknown type annotation"),
CompileErrorKind::IncompatibleTypes
=> write!(f, "Incompatible types"),
CompileErrorKind::UnboundIdentifier
=> write!(f, "Unbound identifier"),
CompileErrorKind::BreakOutsideLoop
=> write!(f, "Break outside loop"),
CompileErrorKind::Internal(details)
=> write!(f, "Internal compiler error: {}", details),
}
}
}
#[derive(Debug)]
pub struct CompileError {
location: ast::Location,
kind: CompileErrorKind,
}
impl fmt::Display for CompileError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}, at {}", self.kind, self.location)
}
}
impl Error for CompileError {}
type CompileResult<T> = Result<T, CompileError>;
pub struct CodeGen<'ctx> {
context: &'ctx Context,
module: Module<'ctx>,
pass_manager: passes::PassManager<values::FunctionValue<'ctx>>,
builder: Builder<'ctx>,
current_source_location: ast::Location,
namespace: HashMap<String, values::PointerValue<'ctx>>,
break_bb: Option<basic_block::BasicBlock<'ctx>>,
}
impl<'ctx> CodeGen<'ctx> {
pub fn new(context: &'ctx Context) -> CodeGen<'ctx> {
let module = context.create_module("kernel");
let pass_manager = passes::PassManager::create(&module);
pass_manager.add_instruction_combining_pass();
pass_manager.add_reassociate_pass();
pass_manager.add_gvn_pass();
pass_manager.add_cfg_simplification_pass();
pass_manager.add_basic_alias_analysis_pass();
pass_manager.add_promote_memory_to_register_pass();
pass_manager.add_instruction_combining_pass();
pass_manager.add_reassociate_pass();
pass_manager.initialize();
let i32_type = context.i32_type();
let fn_type = i32_type.fn_type(&[i32_type.into()], false);
module.add_function("output", fn_type, None);
CodeGen {
context, module, pass_manager,
builder: context.create_builder(),
current_source_location: ast::Location::default(),
namespace: HashMap::new(),
break_bb: None,
}
}
fn set_source_location(&mut self, location: ast::Location) {
self.current_source_location = location;
}
fn compile_error(&self, kind: CompileErrorKind) -> CompileError {
CompileError {
location: self.current_source_location,
kind
}
}
fn get_basic_type(&self, name: &str) -> CompileResult<types::BasicTypeEnum<'ctx>> {
match name {
"bool" => Ok(self.context.bool_type().into()),
"int32" => Ok(self.context.i32_type().into()),
"int64" => Ok(self.context.i64_type().into()),
"float32" => Ok(self.context.f32_type().into()),
"float64" => Ok(self.context.f64_type().into()),
_ => Err(self.compile_error(CompileErrorKind::UnknownTypeAnnotation))
}
}
fn compile_function_def(
&mut self,
name: &str,
args: &ast::Parameters,
body: &ast::Suite,
decorator_list: &[ast::Expression],
returns: &Option<ast::Expression>,
is_async: bool,
) -> CompileResult<values::FunctionValue<'ctx>> {
if is_async {
return Err(self.compile_error(CompileErrorKind::Unsupported("async functions")))
}
for decorator in decorator_list.iter() {
self.set_source_location(decorator.location);
if let ast::ExpressionType::Identifier { name } = &decorator.node {
if name != "kernel" && name != "portable" {
return Err(self.compile_error(CompileErrorKind::Unsupported("custom decorators")))
}
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("decorator must be an identifier")))
}
}
let args_type = args.args.iter().map(|val| {
self.set_source_location(val.location);
if let Some(annotation) = &val.annotation {
if let ast::ExpressionType::Identifier { name } = &annotation.node {
Ok(self.get_basic_type(&name)?)
} else {
Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
}
} else {
Err(self.compile_error(CompileErrorKind::MissingTypeAnnotation))
}
}).collect::<CompileResult<Vec<types::BasicTypeEnum>>>()?;
let return_type = if let Some(returns) = returns {
self.set_source_location(returns.location);
if let ast::ExpressionType::Identifier { name } = &returns.node {
if name == "None" { None } else { Some(self.get_basic_type(name)?) }
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
}
} else {
None
};
let fn_type = match return_type {
Some(ty) => ty.fn_type(&args_type, false),
None => self.context.void_type().fn_type(&args_type, false)
};
let function = self.module.add_function(name, fn_type, None);
let basic_block = self.context.append_basic_block(function, "entry");
self.builder.position_at_end(basic_block);
for (n, arg) in args.args.iter().enumerate() {
let param = function.get_nth_param(n as u32).unwrap();
let alloca = self.builder.build_alloca(param.get_type(), &arg.arg);
self.builder.build_store(alloca, param);
self.namespace.insert(arg.arg.clone(), alloca);
}
self.compile_suite(body, return_type)?;
Ok(function)
}
fn compile_expression(
&mut self,
expression: &ast::Expression
) -> CompileResult<values::BasicValueEnum<'ctx>> {
self.set_source_location(expression.location);
match &expression.node {
ast::ExpressionType::True => Ok(self.context.bool_type().const_int(1, false).into()),
ast::ExpressionType::False => Ok(self.context.bool_type().const_int(0, false).into()),
ast::ExpressionType::Number { value: ast::Number::Integer { value } } => {
let mut bits = value.bits();
if value.sign() == num_bigint::Sign::Minus {
bits += 1;
}
match bits {
0..=32 => Ok(self.context.i32_type().const_int(value.to_i32().unwrap() as _, true).into()),
33..=64 => Ok(self.context.i64_type().const_int(value.to_i64().unwrap() as _, true).into()),
_ => Err(self.compile_error(CompileErrorKind::Unsupported("integers larger than 64 bits")))
}
},
ast::ExpressionType::Number { value: ast::Number::Float { value } } => {
Ok(self.context.f64_type().const_float(*value).into())
},
ast::ExpressionType::Identifier { name } => {
match self.namespace.get(name) {
Some(value) => Ok(self.builder.build_load(*value, name).into()),
None => Err(self.compile_error(CompileErrorKind::UnboundIdentifier))
}
},
ast::ExpressionType::Unop { op, a } => {
let a = self.compile_expression(&a)?;
match (op, a) {
(ast::UnaryOperator::Pos, values::BasicValueEnum::IntValue(a))
=> Ok(a.into()),
(ast::UnaryOperator::Pos, values::BasicValueEnum::FloatValue(a))
=> Ok(a.into()),
(ast::UnaryOperator::Neg, values::BasicValueEnum::IntValue(a))
=> Ok(self.builder.build_int_neg(a, "tmpneg").into()),
(ast::UnaryOperator::Neg, values::BasicValueEnum::FloatValue(a))
=> Ok(self.builder.build_float_neg(a, "tmpneg").into()),
(ast::UnaryOperator::Inv, values::BasicValueEnum::IntValue(a))
=> Ok(self.builder.build_not(a, "tmpnot").into()),
(ast::UnaryOperator::Not, values::BasicValueEnum::IntValue(a)) => {
// boolean "not"
if a.get_type().get_bit_width() != 1 {
Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation")))
} else {
Ok(self.builder.build_not(a, "tmpnot").into())
}
},
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation"))),
}
},
ast::ExpressionType::Binop { a, op, b } => {
let a = self.compile_expression(&a)?;
let b = self.compile_expression(&b)?;
if a.get_type() != b.get_type() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
use ast::Operator::*;
match (op, a, b) {
(Add, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_add(a, b, "tmpadd").into()),
(Sub, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_sub(a, b, "tmpsub").into()),
(Mult, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_mul(a, b, "tmpmul").into()),
(Add, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_add(a, b, "tmpadd").into()),
(Sub, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_sub(a, b, "tmpsub").into()),
(Mult, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_mul(a, b, "tmpmul").into()),
(Div, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_div(a, b, "tmpdiv").into()),
(FloorDiv, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_signed_div(a, b, "tmpdiv").into()),
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented binary operation"))),
}
},
ast::ExpressionType::Compare { vals, ops } => {
let mut vals = vals.iter();
let mut ops = ops.iter();
let mut result = None;
let mut a = self.compile_expression(vals.next().unwrap())?;
loop {
if let Some(op) = ops.next() {
let b = self.compile_expression(vals.next().unwrap())?;
if a.get_type() != b.get_type() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let this_result = match (a, b) {
(values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) => {
match op {
ast::Comparison::Equal
=> self.builder.build_int_compare(IntPredicate::EQ, a, b, "tmpeq"),
ast::Comparison::NotEqual
=> self.builder.build_int_compare(IntPredicate::NE, a, b, "tmpne"),
ast::Comparison::Less
=> self.builder.build_int_compare(IntPredicate::SLT, a, b, "tmpslt"),
ast::Comparison::LessOrEqual
=> self.builder.build_int_compare(IntPredicate::SLE, a, b, "tmpsle"),
ast::Comparison::Greater
=> self.builder.build_int_compare(IntPredicate::SGT, a, b, "tmpsgt"),
ast::Comparison::GreaterOrEqual
=> self.builder.build_int_compare(IntPredicate::SGE, a, b, "tmpsge"),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
}
},
(values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) => {
match op {
ast::Comparison::Equal
=> self.builder.build_float_compare(FloatPredicate::OEQ, a, b, "tmpoeq"),
ast::Comparison::NotEqual
=> self.builder.build_float_compare(FloatPredicate::UNE, a, b, "tmpune"),
ast::Comparison::Less
=> self.builder.build_float_compare(FloatPredicate::OLT, a, b, "tmpolt"),
ast::Comparison::LessOrEqual
=> self.builder.build_float_compare(FloatPredicate::OLE, a, b, "tmpole"),
ast::Comparison::Greater
=> self.builder.build_float_compare(FloatPredicate::OGT, a, b, "tmpogt"),
ast::Comparison::GreaterOrEqual
=> self.builder.build_float_compare(FloatPredicate::OGE, a, b, "tmpoge"),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
}
},
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("comparison of non-numerical types"))),
};
match result {
Some(last) => {
result = Some(self.builder.build_and(last, this_result, "tmpand"));
}
None => {
result = Some(this_result);
}
}
a = b;
} else {
return Ok(result.unwrap().into())
}
}
},
ast::ExpressionType::Call { function, args, keywords } => {
if !keywords.is_empty() {
return Err(self.compile_error(CompileErrorKind::Unsupported("keyword arguments")))
}
let args = args.iter().map(|val| self.compile_expression(val))
.collect::<CompileResult<Vec<values::BasicValueEnum>>>()?;
self.set_source_location(expression.location);
if let ast::ExpressionType::Identifier { name } = &function.node {
match (name.as_str(), args[0]) {
("int32", values::BasicValueEnum::IntValue(a)) => {
let nbits = a.get_type().get_bit_width();
if nbits < 32 {
Ok(self.builder.build_int_s_extend(a, self.context.i32_type(), "tmpsext").into())
} else if nbits > 32 {
Ok(self.builder.build_int_truncate(a, self.context.i32_type(), "tmptrunc").into())
} else {
Ok(a.into())
}
},
("int64", values::BasicValueEnum::IntValue(a)) => {
let nbits = a.get_type().get_bit_width();
if nbits < 64 {
Ok(self.builder.build_int_s_extend(a, self.context.i64_type(), "tmpsext").into())
} else {
Ok(a.into())
}
},
("int32", values::BasicValueEnum::FloatValue(a)) => {
Ok(self.builder.build_float_to_signed_int(a, self.context.i32_type(), "tmpfptosi").into())
},
("int64", values::BasicValueEnum::FloatValue(a)) => {
Ok(self.builder.build_float_to_signed_int(a, self.context.i64_type(), "tmpfptosi").into())
},
("float32", values::BasicValueEnum::IntValue(a)) => {
Ok(self.builder.build_signed_int_to_float(a, self.context.f32_type(), "tmpsitofp").into())
},
("float64", values::BasicValueEnum::IntValue(a)) => {
Ok(self.builder.build_signed_int_to_float(a, self.context.f64_type(), "tmpsitofp").into())
},
("float32", values::BasicValueEnum::FloatValue(a)) => {
if a.get_type() == self.context.f64_type() {
Ok(self.builder.build_float_trunc(a, self.context.f32_type(), "tmptrunc").into())
} else {
Ok(a.into())
}
},
("float64", values::BasicValueEnum::FloatValue(a)) => {
if a.get_type() == self.context.f32_type() {
Ok(self.builder.build_float_ext(a, self.context.f64_type(), "tmpext").into())
} else {
Ok(a.into())
}
},
("output", values::BasicValueEnum::IntValue(a)) => {
let fn_value = self.module.get_function("output").unwrap();
Ok(self.builder.build_call(fn_value, &[a.into()], "call")
.try_as_basic_value().left().unwrap())
},
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unrecognized call")))
}
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("function must be an identifier")))
}
},
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented expression"))),
}
}
fn compile_statement(
&mut self,
statement: &ast::Statement,
return_type: Option<types::BasicTypeEnum>
) -> CompileResult<()> {
self.set_source_location(statement.location);
use ast::StatementType::*;
match &statement.node {
Assign { targets, value } => {
let value = self.compile_expression(value)?;
for target in targets.iter() {
self.set_source_location(target.location);
if let ast::ExpressionType::Identifier { name } = &target.node {
let builder = &self.builder;
let target = self.namespace.entry(name.clone()).or_insert_with(
|| builder.build_alloca(value.get_type(), name));
if target.get_type() != value.get_type().ptr_type(inkwell::AddressSpace::Generic) {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
builder.build_store(*target, value);
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("assignment target must be an identifier")))
}
}
},
Expression { expression } => { self.compile_expression(expression)?; },
If { test, body, orelse } => {
let test = self.compile_expression(test)?;
if test.get_type() != self.context.bool_type().into() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_bb = self.context.append_basic_block(parent, "then");
let else_bb = self.context.append_basic_block(parent, "else");
let cont_bb = self.context.append_basic_block(parent, "ifcont");
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
self.builder.position_at_end(then_bb);
self.compile_suite(body, return_type)?;
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(else_bb);
if let Some(orelse) = orelse {
self.compile_suite(orelse, return_type)?;
}
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
},
While { test, body, orelse } => {
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let test_bb = self.context.append_basic_block(parent, "test");
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(test_bb);
let test = self.compile_expression(test)?;
if test.get_type() != self.context.bool_type().into() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let then_bb = self.context.append_basic_block(parent, "then");
let else_bb = self.context.append_basic_block(parent, "else");
let cont_bb = self.context.append_basic_block(parent, "ifcont");
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
self.break_bb = Some(cont_bb);
self.builder.position_at_end(then_bb);
self.compile_suite(body, return_type)?;
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(else_bb);
if let Some(orelse) = orelse {
self.compile_suite(orelse, return_type)?;
}
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
self.break_bb = None;
},
Break => {
if let Some(bb) = self.break_bb {
self.builder.build_unconditional_branch(bb);
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let unreachable_bb = self.context.append_basic_block(parent, "unreachable");
self.builder.position_at_end(unreachable_bb);
} else {
return Err(self.compile_error(CompileErrorKind::BreakOutsideLoop));
}
}
Return { value: Some(value) } => {
if let Some(return_type) = return_type {
let value = self.compile_expression(value)?;
if value.get_type() != return_type {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
self.builder.build_return(Some(&value));
} else {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
},
Return { value: None } => {
if !return_type.is_none() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
self.builder.build_return(None);
},
Pass => (),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special statement"))),
}
Ok(())
}
fn compile_suite(
&mut self,
suite: &ast::Suite,
return_type: Option<types::BasicTypeEnum>
) -> CompileResult<()> {
for statement in suite.iter() {
self.compile_statement(statement, return_type)?;
}
Ok(())
}
pub fn compile_toplevel(&mut self, statement: &ast::Statement) -> CompileResult<()> {
self.set_source_location(statement.location);
if let ast::StatementType::FunctionDef {
is_async,
name,
args,
body,
decorator_list,
returns,
} = &statement.node {
let function = self.compile_function_def(name, args, body, decorator_list, returns, *is_async)?;
self.pass_manager.run_on(&function);
Ok(())
} else {
Err(self.compile_error(CompileErrorKind::Internal("top-level is not a function definition")))
}
}
pub fn print_ir(&self) {
self.module.print_to_stderr();
}
pub fn output(&self, filename: &str) {
//let triple = TargetTriple::create("riscv32-none-linux-gnu");
let triple = TargetMachine::get_default_triple();
let target = Target::from_triple(&triple)
.expect("couldn't create target from target triple");
let target_machine = target
.create_target_machine(
&triple,
"",
"",
OptimizationLevel::Default,
RelocMode::Default,
CodeModel::Default,
)
.expect("couldn't create target machine");
target_machine
.write_to_file(&self.module, FileType::Object, Path::new(filename))
.expect("couldn't write module to file");
}
}

View File

@ -1,617 +0,0 @@
use std::{
collections::{HashMap, HashSet},
fmt::{Debug, Display},
rc::Rc,
sync::Arc,
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip, Itertools};
use parking_lot::RwLock;
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use crate::{
codegen::{CodeGenContext, CodeGenerator},
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
#[derive(Clone, PartialEq, Debug)]
pub enum SymbolValue {
I32(i32),
I64(i64),
U32(u32),
U64(u64),
Str(String),
Double(f64),
Bool(bool),
Tuple(Vec<SymbolValue>),
OptionSome(Box<SymbolValue>),
OptionNone,
}
impl SymbolValue {
/// Creates a [`SymbolValue`] from a [`Constant`].
///
/// * `constant` - The constant to create the value from.
/// * `expected_ty` - The expected type of the [`SymbolValue`].
pub fn from_constant(
constant: &Constant,
expected_ty: Type,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> Result<Self, String> {
match constant {
Constant::None => {
if unifier.unioned(expected_ty, primitives.option) {
Ok(SymbolValue::OptionNone)
} else {
Err(format!("Expected {expected_ty:?}, but got Option"))
}
}
Constant::Bool(b) => {
if unifier.unioned(expected_ty, primitives.bool) {
Ok(SymbolValue::Bool(*b))
} else {
Err(format!("Expected {expected_ty:?}, but got bool"))
}
}
Constant::Str(s) => {
if unifier.unioned(expected_ty, primitives.str) {
Ok(SymbolValue::Str(s.to_string()))
} else {
Err(format!("Expected {expected_ty:?}, but got str"))
}
}
Constant::Int(i) => {
if unifier.unioned(expected_ty, primitives.int32) {
i32::try_from(*i).map(SymbolValue::I32).map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.int64) {
i64::try_from(*i).map(SymbolValue::I64).map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint32) {
u32::try_from(*i).map(SymbolValue::U32).map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint64) {
u64::try_from(*i).map(SymbolValue::U64).map_err(|e| e.to_string())
} else {
Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
}
}
Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx } = expected_ty.as_ref() else {
return Err(format!(
"Expected {:?}, but got Tuple",
expected_ty.get_type_name()
));
};
assert!(*is_vararg_ctx || ty.len() == t.len());
let elems = t
.iter()
.zip(ty)
.map(|(constant, ty)| Self::from_constant(constant, *ty, primitives, unifier))
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => {
if unifier.unioned(expected_ty, primitives.float) {
Ok(SymbolValue::Double(*f))
} else {
Err(format!("Expected {expected_ty:?}, but got float"))
}
}
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
/// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value.
///
/// * `constant` - The constant to create the value from.
pub fn from_constant_inferred(constant: &Constant) -> Result<Self, String> {
match constant {
Constant::None => Ok(SymbolValue::OptionNone),
Constant::Bool(b) => Ok(SymbolValue::Bool(*b)),
Constant::Str(s) => Ok(SymbolValue::Str(s.to_string())),
Constant::Int(i) => {
let i = *i;
if i >= 0 {
i32::try_from(i)
.map(SymbolValue::I32)
.or_else(|_| i64::try_from(i).map(SymbolValue::I64))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
} else {
u32::try_from(i)
.map(SymbolValue::U32)
.or_else(|_| u64::try_from(i).map(SymbolValue::U64))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
}
}
Constant::Tuple(t) => {
let elems = t
.iter()
.map(Self::from_constant_inferred)
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => Ok(SymbolValue::Double(*f)),
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
/// Returns the [`Type`] representing the data type of this value.
pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type {
match self {
SymbolValue::I32(_) => primitives.int32,
SymbolValue::I64(_) => primitives.int64,
SymbolValue::U32(_) => primitives.uint32,
SymbolValue::U64(_) => primitives.uint64,
SymbolValue::Str(_) => primitives.str,
SymbolValue::Double(_) => primitives.float,
SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys, is_vararg_ctx: false })
}
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
}
}
/// Returns the [`TypeAnnotation`] representing the data type of this value.
pub fn get_type_annotation(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> TypeAnnotation {
match self {
SymbolValue::Bool(..)
| SymbolValue::Double(..)
| SymbolValue::I32(..)
| SymbolValue::I64(..)
| SymbolValue::U32(..)
| SymbolValue::U64(..)
| SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
SymbolValue::Tuple(vs) => {
let vs_tys = vs
.iter()
.map(|v| v.get_type_annotation(primitives, unifier))
.collect::<Vec<_>>();
TypeAnnotation::Tuple(vs_tys)
}
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(),
params: Vec::default(),
},
SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(),
params: vec![ty],
}
}
}
}
/// Returns the [`TypeEnum`] representing the data type of this value.
pub fn get_type_enum(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> Rc<TypeEnum> {
let ty = self.get_type(primitives, unifier);
unifier.get_ty(ty)
}
}
impl Display for SymbolValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SymbolValue::I32(i) => write!(f, "{i}"),
SymbolValue::I64(i) => write!(f, "int64({i})"),
SymbolValue::U32(i) => write!(f, "uint32({i})"),
SymbolValue::U64(i) => write!(f, "uint64({i})"),
SymbolValue::Str(s) => write!(f, "\"{s}\""),
SymbolValue::Double(d) => write!(f, "{d}"),
SymbolValue::Bool(b) => {
if *b {
write!(f, "True")
} else {
write!(f, "False")
}
}
SymbolValue::Tuple(t) => {
write!(f, "({})", t.iter().map(|v| format!("{v}")).collect::<Vec<_>>().join(", "))
}
SymbolValue::OptionSome(v) => write!(f, "Some({v})"),
SymbolValue::OptionNone => write!(f, "none"),
}
}
}
impl TryFrom<SymbolValue> for u64 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not
/// numeric or if the value cannot be converted into a `u64` without overflow.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::U32(v) => Ok(u64::from(v)),
SymbolValue::U64(v) => Ok(v),
_ => Err(()),
}
}
}
impl TryFrom<SymbolValue> for i128 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not
/// numeric.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => Ok(i128::from(v)),
SymbolValue::I64(v) => Ok(i128::from(v)),
SymbolValue::U32(v) => Ok(i128::from(v)),
SymbolValue::U64(v) => Ok(i128::from(v)),
_ => Err(()),
}
}
}
pub trait StaticValue {
/// Returns a unique identifier for this value.
fn get_unique_identifier(&self) -> u64;
/// Returns the constant object represented by this unique identifier.
fn get_const_obj<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx>;
/// Converts this value to a LLVM [`BasicValueEnum`].
fn to_basic_value_enum<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
expected_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String>;
/// Returns a field within this value.
fn get_field<'ctx>(
&self,
name: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>>;
/// Returns a single element of this tuple.
fn get_tuple_element<'ctx>(&self, index: u32) -> Option<ValueEnum<'ctx>>;
}
#[derive(Clone)]
pub enum ValueEnum<'ctx> {
/// [`ValueEnum`] representing a static value.
Static(Arc<dyn StaticValue + Send + Sync>),
/// [`ValueEnum`] representing a dynamic value.
Dynamic(BasicValueEnum<'ctx>),
}
impl<'ctx> From<BasicValueEnum<'ctx>> for ValueEnum<'ctx> {
fn from(v: BasicValueEnum<'ctx>) -> Self {
ValueEnum::Dynamic(v)
}
}
impl<'ctx> From<PointerValue<'ctx>> for ValueEnum<'ctx> {
fn from(v: PointerValue<'ctx>) -> Self {
ValueEnum::Dynamic(v.into())
}
}
impl<'ctx> From<IntValue<'ctx>> for ValueEnum<'ctx> {
fn from(v: IntValue<'ctx>) -> Self {
ValueEnum::Dynamic(v.into())
}
}
impl<'ctx> From<FloatValue<'ctx>> for ValueEnum<'ctx> {
fn from(v: FloatValue<'ctx>) -> Self {
ValueEnum::Dynamic(v.into())
}
}
impl<'ctx> From<StructValue<'ctx>> for ValueEnum<'ctx> {
fn from(v: StructValue<'ctx>) -> Self {
ValueEnum::Dynamic(v.into())
}
}
impl<'ctx> ValueEnum<'ctx> {
/// Converts this [`ValueEnum`] to a [`BasicValueEnum`].
pub fn to_basic_value_enum<'a>(
self,
ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator,
expected_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> {
match self {
ValueEnum::Static(v) => v.to_basic_value_enum(ctx, generator, expected_ty),
ValueEnum::Dynamic(v) => Ok(v),
}
}
}
pub trait SymbolResolver {
/// Get type of type variable identifier or top-level function type,
fn get_symbol_type(
&self,
unifier: &mut Unifier,
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String>;
/// Get the top-level definition of identifiers.
fn get_identifier_def(&self, str: StrRef) -> Result<DefinitionId, HashSet<String>>;
fn get_symbol_value<'ctx>(
&self,
str: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>>;
fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;
fn get_string_id(&self, s: &str) -> i32;
fn get_exception_id(&self, tyid: usize) -> usize;
fn handle_deferred_eval(
&self,
_unifier: &mut Unifier,
_top_level_defs: &[Arc<RwLock<TopLevelDef>>],
_primitives: &PrimitiveStore,
) -> Result<(), String> {
Ok(())
}
}
thread_local! {
static IDENTIFIER_ID: [StrRef; 11] = [
"int32".into(),
"int64".into(),
"float".into(),
"bool".into(),
"virtual".into(),
"tuple".into(),
"str".into(),
"Exception".into(),
"uint32".into(),
"uint64".into(),
"Literal".into(),
];
}
/// Converts a type annotation into a [Type].
pub fn parse_type_annotation<T>(
resolver: &dyn SymbolResolver,
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, HashSet<String>> {
use nac3parser::ast::ExprKind::*;
let ids = IDENTIFIER_ID.with(|ids| *ids);
let int32_id = ids[0];
let int64_id = ids[1];
let float_id = ids[2];
let bool_id = ids[3];
let virtual_id = ids[4];
let tuple_id = ids[5];
let str_id = ids[6];
let exn_id = ids[7];
let uint32_id = ids[8];
let uint64_id = ids[9];
let literal_id = ids[10];
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
if *id == int32_id {
Ok(primitives.int32)
} else if *id == int64_id {
Ok(primitives.int64)
} else if *id == uint32_id {
Ok(primitives.uint32)
} else if *id == uint64_id {
Ok(primitives.uint64)
} else if *id == float_id {
Ok(primitives.float)
} else if *id == bool_id {
Ok(primitives.bool)
} else if *id == str_id {
Ok(primitives.str)
} else if *id == exn_id {
Ok(primitives.exception)
} else {
let obj_id = resolver.get_identifier_def(*id);
if let Ok(obj_id) = obj_id {
let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if !type_vars.is_empty() {
return Err(HashSet::from([format!(
"Unexpected number of type parameters: expected {} but got 0",
type_vars.len()
)]));
}
let fields = chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
)
.collect();
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() }))
} else {
Err(HashSet::from([format!("Cannot use function name as type at {loc}")]))
}
} else {
let ty =
resolver.get_symbol_type(unifier, top_level_defs, primitives, *id).map_err(
|e| HashSet::from([format!("Unknown type annotation at {loc}: {e}")]),
)?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty)
} else {
Err(HashSet::from([format!("Unknown type annotation {id} at {loc}")]))
}
}
}
};
let subscript_name_handle = |id: &StrRef, slice: &Expr<T>, unifier: &mut Unifier| {
if *id == virtual_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else if *id == tuple_id {
if let Tuple { elts, .. } = &slice.node {
let ty = elts
.iter()
.map(|elt| {
parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }))
} else {
Err(HashSet::from(["Expected multiple elements for tuple".into()]))
}
} else if *id == literal_id {
let mut parse_literal = |elt: &Expr<T>| {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?;
let ty_enum = &*unifier.get_ty_immutable(ty);
match ty_enum {
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
_ => Err(HashSet::from([format!(
"Expected literal in type argument for Literal at {}",
elt.location
)])),
}
};
let values = if let Tuple { elts, .. } = &slice.node {
elts.iter().map(&mut parse_literal).collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_literal(slice)?]
}
.into_iter()
.flatten()
.collect_vec();
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
} else {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|v| {
parse_type_annotation(resolver, top_level_defs, unifier, primitives, v)
})
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?]
};
let obj_id = resolver.get_identifier_def(*id)?;
let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() {
return Err(HashSet::from([format!(
"Unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
types.len()
)]));
}
let mut subst = VarMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
} else {
unreachable!()
};
subst.insert(id, *ty);
}
let mut fields = fields
.iter()
.map(|(attr, ty, is_mutable)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, (ty, *is_mutable))
})
.collect::<HashMap<_, _>>();
fields.extend(methods.iter().map(|(attr, ty, _)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, (ty, false))
}));
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
} else {
Err(HashSet::from(["Cannot use function name as type".into()]))
}
}
};
match &expr.node {
Name { id, .. } => name_handling(id, expr.location, unifier),
Subscript { value, slice, .. } => {
if let Name { id, .. } = &value.node {
subscript_name_handle(id, slice, unifier)
} else {
Err(HashSet::from([format!("unsupported type expression at {}", expr.location)]))
}
}
Constant { value, .. } => SymbolValue::from_constant_inferred(value)
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
.map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([format!("unsupported type expression at {}", expr.location)])),
}
}
impl dyn SymbolResolver + Send + Sync {
pub fn parse_type_annotation<T>(
&self,
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, HashSet<String>> {
parse_type_annotation(self, top_level_defs, unifier, primitives, expr)
}
pub fn get_type_name(
&self,
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
ty: Type,
) -> String {
unifier.internal_stringify(
ty,
&mut |id| {
let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else {
unreachable!("expected class definition")
};
name.to_string()
},
&mut |id| format!("typevar{id}"),
&mut None,
)
}
}
impl Debug for dyn SymbolResolver + Send + Sync {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "")
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,176 +0,0 @@
use std::{
borrow::BorrowMut,
collections::{HashMap, HashSet},
fmt::Debug,
iter::FromIterator,
sync::Arc,
};
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use parking_lot::RwLock;
use nac3parser::ast::{self, Expr, Location, Stmt, StrRef};
use crate::{
codegen::{CodeGenContext, CodeGenerator},
symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{
CallId, FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, TypeVarId, Unifier,
VarMap,
},
},
};
use composer::*;
use type_annotation::*;
pub mod builtins;
pub mod composer;
pub mod helper;
pub mod numpy;
#[cfg(test)]
mod test;
pub mod type_annotation;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize);
type GenCallCallback = dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>,
Option<(Type, ValueEnum<'ctx>)>,
(&FunSignature, DefinitionId),
Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
&mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String>
+ Send
+ Sync;
pub struct GenCall {
fp: Box<GenCallCallback>,
}
impl GenCall {
#[must_use]
pub fn new(fp: Box<GenCallCallback>) -> GenCall {
GenCall { fp }
}
/// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given
/// `reason`.
#[must_use]
pub fn create_dummy(reason: String) -> GenCall {
Self::new(Box::new(move |_, _, _, _, _| unreachable!("{reason}")))
}
pub fn run<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
(self.fp)(ctx, obj, fun, args, generator)
}
}
impl Debug for GenCall {
fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Ok(())
}
}
#[derive(Clone, Debug)]
pub struct FunInstance {
pub body: Arc<Vec<Stmt<Option<Type>>>>,
pub calls: Arc<HashMap<CodeLocation, CallId>>,
pub subst: VarMap,
pub unifier_id: usize,
}
#[derive(Debug, Clone)]
pub enum TopLevelDef {
Class {
/// Name for error messages and symbols.
name: StrRef,
/// Object ID used for [`TypeEnum`].
object_id: DefinitionId,
/// type variables bounded to the class.
type_vars: Vec<Type>,
/// Class fields.
///
/// Name and type is mutable.
fields: Vec<(StrRef, Type, bool)>,
/// Class Attributes.
///
/// Name, type, value.
attributes: Vec<(StrRef, Type, ast::Constant)>,
/// Class methods, pointing to the corresponding function definition.
methods: Vec<(StrRef, Type, DefinitionId)>,
/// Ancestor classes, including itself.
ancestors: Vec<TypeAnnotation>,
/// Symbol resolver of the module defined the class; [None] if it is built-in type.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Constructor type.
constructor: Option<Type>,
/// Definition location.
loc: Option<Location>,
},
Function {
/// Prefix for symbol, should be unique globally.
name: String,
/// Simple name, the same as in method/function definition.
simple_name: StrRef,
/// Function signature.
signature: Type,
/// Instantiated type variable IDs.
var_id: Vec<TypeVarId>,
/// Function instance to symbol mapping
///
/// * Key: String representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class.
/// * Value: Function symbol name.
instance_to_symbol: HashMap<String, String>,
/// Function instances to annotated AST mapping
///
/// * Key: String representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. Excluding rigid type
/// variables.
///
/// Rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, FunInstance>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Custom code generation callback.
codegen_callback: Option<Arc<GenCall>>,
/// Definition location.
loc: Option<Location>,
},
Variable {
/// Qualified name of the global variable, should be unique globally.
name: String,
/// Simple name, the same as in method/function definition.
simple_name: StrRef,
/// Type of the global variable.
ty: Type,
/// The declared type of the global variable, or [`None`] if no type annotation is provided.
ty_decl: Option<Expr>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Definition location.
loc: Option<Location>,
},
}
pub struct TopLevelContext {
pub definitions: Arc<RwLock<Vec<Arc<RwLock<TopLevelDef>>>>>,
pub unifiers: Arc<RwLock<Vec<(SharedUnifier, PrimitiveStore)>>>,
pub personality_symbol: Option<String>,
}

View File

@ -1,84 +0,0 @@
use itertools::Itertools;
use super::helper::PrimDef;
use crate::typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
};
/// Creates a `ndarray` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn make_ndarray_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
if dtype.is_none() && ndims.is_none() {
return ndarray;
}
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new();
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
if let Some(ndims) = ndims {
tvar_subst.insert(tvar_ids[1], ndims);
}
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2);
params
.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
}

View File

@ -1,12 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
]

View File

@ -1,15 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
]

View File

@ -1,13 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
]

View File

@ -1,13 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n",
]

View File

@ -1,17 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n",
]

View File

@ -1,9 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
assertion_line: 549
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nmethods: [],\ntype_vars: []\n}\n",
]

View File

@ -1,837 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use indoc::indoc;
use parking_lot::Mutex;
use test_case::test_case;
use nac3parser::{
ast::{fold::Fold, FileName},
parser::parse_program,
};
use super::{helper::PrimDef, DefinitionId, *};
use crate::{
codegen::CodeGenContext,
symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{into_var_map, Type, Unifier},
},
};
struct ResolverInternal {
id_to_type: Mutex<HashMap<StrRef, Type>>,
id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
class_names: Mutex<HashMap<StrRef, Type>>,
}
impl ResolverInternal {
fn add_id_def(&self, id: StrRef, def: DefinitionId) {
self.id_to_def.lock().insert(id, def);
}
fn add_id_type(&self, id: StrRef, ty: Type) {
self.id_to_type.lock().insert(id, ty);
}
}
struct Resolver(Arc<ResolverInternal>);
impl SymbolResolver for Resolver {
fn get_default_param_value(
&self,
_: &ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!()
}
fn get_symbol_type(
&self,
_: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>],
_: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String> {
self.0
.id_to_type
.lock()
.get(&str)
.copied()
.ok_or_else(|| format!("cannot find symbol `{str}`"))
}
fn get_symbol_value<'ctx>(
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.0
.id_to_def
.lock()
.get(&id)
.copied()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
}
fn get_string_id(&self, _: &str) -> i32 {
unimplemented!()
}
fn get_exception_id(&self, _tyid: usize) -> usize {
unimplemented!()
}
}
#[test_case(
vec![
indoc! {"
def fun(a: int32) -> int32:
return a
"},
indoc! {"
class A:
def __init__(self):
self.a: int32 = 3
"},
indoc! {"
class B:
def __init__(self):
self.b: float = 4.3
def fun(self):
self.b = self.b + 3.0
"},
indoc! {"
def foo(a: float):
a + 1.0
"},
indoc! {"
class C(B):
def __init__(self):
self.c: int32 = 4
self.a: bool = True
"},
];
"register"
)]
fn test_simple_register(source: Vec<&str>) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
for s in source {
let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone();
composer.register_top_level(ast, None, "", false).unwrap();
}
}
#[test_case(
indoc! {"
class A:
def foo(self):
pass
a = A()
"};
"register"
)]
fn test_simple_register_without_constructor(source: &str) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let ast = parse_program(source, FileName::default()).unwrap();
let ast = ast[0].clone();
composer.register_top_level(ast, None, "", true).unwrap();
}
#[test_case(
&[
indoc! {"
def fun(a: int32) -> int32:
return a
"},
indoc! {"
def foo(a: float):
a + 1.0
"},
indoc! {"
def f(b: int64) -> int32:
return 3
"},
],
&[
"fn[[a:0], 0]",
"fn[[a:2], 4]",
"fn[[b:1], 0]",
],
&[
"fun",
"foo",
"f"
];
"function compose"
)]
fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = Arc::new(ResolverInternal {
id_to_def: Mutex::default(),
id_to_type: Mutex::default(),
class_names: Mutex::default(),
});
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) =
composer.register_top_level(ast, Some(resolver.clone()), "", false).unwrap();
internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
}
composer.start_analysis(true).unwrap();
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.builtin_num).enumerate()
{
let def = &*def.read();
if let TopLevelDef::Function { signature, name, .. } = def {
let ty_str = composer.unifier.internal_stringify(
*signature,
&mut |id| id.to_string(),
&mut |id| id.to_string(),
&mut None,
);
assert_eq!(ty_str, tys[i]);
assert_eq!(name, names[i]);
}
}
}
#[test_case(
&[
indoc! {"
class A():
a: int32
def __init__(self):
self.a = 3
def fun(self, b: B):
pass
def foo(self, a: T, b: V):
pass
"},
indoc! {"
class B(C):
def __init__(self):
pass
"},
indoc! {"
class C(A):
def __init__(self):
pass
def fun(self, b: B):
a = 1
pass
"},
indoc! {"
def foo(a: A):
pass
"},
indoc! {"
def ff(a: T) -> V:
pass
"}
],
&[];
"simple class compose"
)]
#[test_case(
&[
indoc! {"
class Generic_A(Generic[V], B):
a: int64
def __init__(self):
self.a = 123123123123
def fun(self, a: int32) -> V:
pass
"},
indoc! {"
class B:
aa: bool
def __init__(self):
self.aa = False
def foo(self, b: T):
pass
"}
],
&[];
"generic class"
)]
#[test_case(
&[
indoc! {"
def foo(a: list[int32], b: tuple[T, float]) -> A[B, bool]:
pass
"},
indoc! {"
class A(Generic[T, V]):
a: T
b: V
def __init__(self, v: V):
self.a = 1
self.b = v
def fun(self, a: T) -> V:
pass
"},
indoc! {"
def gfun(a: A[list[float], int32]):
pass
"},
indoc! {"
class B:
def __init__(self):
pass
"}
],
&[];
"list tuple generic"
)]
#[test_case(
&[
indoc! {"
class A(Generic[T, V]):
a: A[float, bool]
b: B
def __init__(self, a: A[float, bool], b: B):
self.a = a
self.b = b
def fun(self, a: A[float, bool]) -> A[bool, int32]:
pass
"},
indoc! {"
class B(A[int64, bool]):
def __init__(self):
pass
def foo(self, b: B) -> B:
pass
def bar(self, a: A[list[B], int32]) -> tuple[A[virtual[A[B, int32]], bool], B]:
pass
"}
],
&[];
"self1"
)]
#[test_case(
&[
indoc! {"
class A(Generic[T]):
a: int32
b: T
c: A[int64]
def __init__(self, t: T):
self.a = 3
self.b = T
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
pass
def foo(self, c: C):
pass
"},
indoc! {"
class B(Generic[V], A[float]):
d: C
def __init__(self):
pass
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
# override
pass
"},
indoc! {"
class C(B[bool]):
e: int64
def __init__(self):
pass
"}
],
&[];
"inheritance_override"
)]
#[test_case(
&[
indoc! {"
class A(Generic[T]):
def __init__(self):
pass
def fun(self, a: A[T]) -> A[T]:
pass
"}
],
&["application of type vars to generic class is not currently supported (at unknown:4:24)"];
"err no type var in generic app"
)]
#[test_case(
&[
indoc! {"
class A(B):
def __init__(self):
pass
"},
indoc! {"
class B(A):
def __init__(self):
pass
"}
],
&["cyclic inheritance detected"];
"cyclic1"
)]
#[test_case(
&[
indoc! {"
class A(B[bool, int64]):
def __init__(self):
pass
"},
indoc! {"
class B(Generic[V, T], C[int32]):
def __init__(self):
pass
"},
indoc! {"
class C(Generic[T], A):
def __init__(self):
pass
"},
],
&["cyclic inheritance detected"];
"cyclic2"
)]
#[test_case(
&[
indoc! {"
class A:
pass
"}
],
&["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"];
"simple pass in class"
)]
#[test_case(
&[indoc! {"
class A:
def __init__():
pass
"}],
&["__init__ method must have a `self` parameter (at unknown:2:5)"];
"err no self_1"
)]
#[test_case(
&[
indoc! {"
class A(B, Generic[T], C):
def __init__(self):
pass
"},
indoc! {"
class B:
def __init__(self):
pass
"},
indoc! {"
class C:
def __init__(self):
pass
"}
],
&["a class definition can only have at most one base class declaration and one generic declaration (at unknown:1:24)"];
"err multiple inheritance"
)]
#[test_case(
&[
indoc! {"
class A(Generic[T]):
a: int32
b: T
c: A[int64]
def __init__(self, t: T):
self.a = 3
self.b = T
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
pass
"},
indoc! {"
class B(Generic[V], A[float]):
def __init__(self):
pass
def fun(self, a: int32, b: T) -> list[virtual[B[int32]]]:
# override
pass
"}
],
&["method fun has same name as ancestors' method, but incompatible type"];
"err_incompatible_inheritance_method"
)]
#[test_case(
&[
indoc! {"
class A(Generic[T]):
a: int32
b: T
c: A[int64]
def __init__(self, t: T):
self.a = 3
self.b = T
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
pass
"},
indoc! {"
class B(Generic[V], A[float]):
a: int32
def __init__(self):
pass
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
# override
pass
"}
],
&["field `a` has already declared in the ancestor classes"];
"err_incompatible_inheritance_field"
)]
#[test_case(
&[
indoc! {"
class A:
def __init__(self):
pass
"},
indoc! {"
class A:
a: int32
def __init__(self):
pass
"}
],
&["duplicate definition of class `A` (at unknown:1:1)"];
"class same name"
)]
fn test_analyze(source: &[&str], res: &[&str]) {
let print = false;
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar(
vec![
("T".into(), vec![]),
("V".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int32]),
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
],
&mut composer.unifier,
print,
);
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "", false) {
Ok(x) => x,
Err(msg) => {
if print {
println!("{msg}");
} else {
assert_eq!(res[0], msg);
}
return;
}
}
};
internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
}
if let Err(msg) = composer.start_analysis(false) {
if print {
println!("{}", msg.iter().sorted().join("\n----------\n"));
} else {
assert_eq!(res[0], msg.iter().next().unwrap());
}
} else {
// skip 5 to skip primitives
let mut res_vec: Vec<String> = Vec::new();
for (def, _) in composer.definition_ast_list.iter().skip(composer.builtin_num) {
let def = &*def.read();
res_vec.push(format!("{}\n", def.to_string(composer.unifier.borrow_mut())));
}
insta::assert_debug_snapshot!(res_vec);
}
}
#[test_case(
vec![
indoc! {"
def fun(a: int32, b: int32) -> int32:
return a + b
"},
indoc! {"
def fib(n: int32) -> int32:
if n <= 2:
return 1
a = fib(n - 1)
b = fib(n - 2)
return fib(n - 1)
"}
],
&[];
"simple function"
)]
#[test_case(
vec![
indoc! {"
class A:
a: int32
def __init__(self):
self.a = 3
def fun(self) -> int32:
b = self.a + 3
return b * self.a
def clone(self) -> A:
SELF = self
return SELF
def sum(self) -> int32:
if self.a == 0:
return self.a
else:
a = self.a
self.a = self.a - 1
return a + self.sum()
def fib(self, a: int32) -> int32:
if a <= 2:
return 1
return self.fib(a - 1) + self.fib(a - 2)
"},
indoc! {"
def fun(a: A) -> int32:
return a.fun() + 2
"}
],
&[];
"simple class body"
)]
#[test_case(
vec![
indoc! {"
def fun(a: V, c: G, t: T) -> V:
b = a
cc = c
ret = fun(b, cc, t)
return ret * ret
"},
indoc! {"
def sum_three(l: list[V]) -> V:
return l[0] + l[1] + l[2]
"},
indoc! {"
def sum_sq_pair(p: tuple[V, V]) -> list[V]:
a = p[0]
b = p[1]
a = a**a
b = b**b
return [a, b]
"}
],
&[];
"type var fun"
)]
#[test_case(
vec![
indoc! {"
class A(Generic[G]):
a: G
b: bool
def __init__(self, aa: G):
self.a = aa
if 2 > 1:
self.b = True
else:
# self.b = False
pass
def fun(self, a: G) -> list[G]:
ret = [a, self.a]
return ret if self.b else self.fun(self.a)
"}
],
&[];
"type var class"
)]
#[test_case(
vec![
indoc! {"
class A:
def fun(self):
pass
"},
indoc!{"
class B:
a: int32
b: bool
def __init__(self):
# self.b = False
if 3 > 2:
self.a = 3
self.b = False
else:
self.a = 4
self.b = True
"}
],
&[];
"no_init_inst_check"
)]
fn test_inference(source: Vec<&str>, res: &[&str]) {
let print = true;
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar(
vec![
("T".into(), vec![]),
(
"V".into(),
vec![
composer.primitives_ty.float,
composer.primitives_ty.int32,
composer.primitives_ty.int64,
],
),
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
],
&mut composer.unifier,
print,
);
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "", false) {
Ok(x) => x,
Err(msg) => {
if print {
println!("{msg}");
} else {
assert_eq!(res[0], msg);
}
return;
}
}
};
internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
}
if let Err(msg) = composer.start_analysis(true) {
if print {
println!("{}", msg.iter().sorted().join("\n----------\n"));
} else {
assert_eq!(res[0], msg.iter().next().unwrap());
}
} else {
// skip 5 to skip primitives
let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier };
for (def, _) in composer.definition_ast_list.iter().skip(composer.builtin_num) {
let def = &*def.read();
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
println!(
"=========`{}`: number of instances: {}===========",
name,
instance_to_stmt.len()
);
for inst in instance_to_stmt {
let ast = &inst.1.body;
for b in ast.iter() {
println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap());
println!("--------------------");
}
println!("\n");
}
}
}
}
}
fn make_internal_resolver_with_tvar(
tvars: Vec<(StrRef, Vec<Type>)>,
unifier: &mut Unifier,
print: bool,
) -> Arc<ResolverInternal> {
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let res: Arc<ResolverInternal> = ResolverInternal {
id_to_def: Mutex::new(HashMap::from([("list".into(), PrimDef::List.id())])),
id_to_type: tvars
.into_iter()
.map(|(name, range)| {
(name, {
let tvar = unifier.get_fresh_var_with_range(range.as_slice(), None, None);
if print {
println!("{}: {:?}, typevar{}", name, tvar.ty, tvar.id);
}
tvar.ty
})
})
.collect::<HashMap<_, _>>()
.into(),
class_names: Mutex::new(HashMap::from([("list".into(), list)])),
}
.into();
if print {
println!();
}
res
}
struct TypeToStringFolder<'a> {
unifier: &'a mut Unifier,
}
impl<'a> Fold<Option<Type>> for TypeToStringFolder<'a> {
type TargetU = String;
type Error = String;
fn map_user(&mut self, user: Option<Type>) -> Result<Self::TargetU, Self::Error> {
Ok(if let Some(ty) = user {
self.unifier.internal_stringify(
ty,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None,
)
} else {
"None".into()
})
}
}

View File

@ -1,659 +0,0 @@
use strum::IntoEnumIterator;
use nac3parser::ast::Constant;
use super::{
helper::{PrimDef, PrimDefDetails},
*,
};
use crate::{symbol_resolver::SymbolValue, typecheck::typedef::VarMap};
#[derive(Clone, Debug)]
pub enum TypeAnnotation {
Primitive(Type),
// we use type vars kind at params to represent self type
CustomClass {
id: DefinitionId,
// params can also be type var
params: Vec<TypeAnnotation>,
},
// can only be CustomClassKind
Virtual(Box<TypeAnnotation>),
TypeVar(Type),
/// A `Literal` allowing a subset of literals.
Literal(Vec<Constant>),
Tuple(Vec<TypeAnnotation>),
}
impl TypeAnnotation {
pub fn stringify(&self, unifier: &mut Unifier) -> String {
use TypeAnnotation::*;
match self {
Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty),
CustomClass { id, params } => {
let class_name = if let Some(ref top) = unifier.top_level {
if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() {
(*name).into()
} else {
unreachable!()
}
} else {
format!("class_def_{}", id.0)
};
format!("{}{}", class_name, {
let param_list =
params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
if param_list.is_empty() {
String::new()
} else {
format!("[{param_list}]")
}
})
}
Literal(values) => {
format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", "))
}
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
Tuple(types) => {
format!(
"tuple[{}]",
types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")
)
}
}
}
}
/// Parses an AST expression `expr` into a [`TypeAnnotation`].
///
/// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all
/// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [`None`] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here
locked: HashMap<DefinitionId, Vec<Type>, S>,
) -> Result<TypeAnnotation, HashSet<String>> {
let name_handle = |id: &StrRef,
unifier: &mut Unifier,
locked: HashMap<DefinitionId, Vec<Type>, S>| {
if id == &"int32".into() {
Ok(TypeAnnotation::Primitive(primitives.int32))
} else if id == &"int64".into() {
Ok(TypeAnnotation::Primitive(primitives.int64))
} else if id == &"uint32".into() {
Ok(TypeAnnotation::Primitive(primitives.uint32))
} else if id == &"uint64".into() {
Ok(TypeAnnotation::Primitive(primitives.uint64))
} else if id == &"float".into() {
Ok(TypeAnnotation::Primitive(primitives.float))
} else if id == &"bool".into() {
Ok(TypeAnnotation::Primitive(primitives.bool))
} else if id == &"str".into() {
Ok(TypeAnnotation::Primitive(primitives.str))
} else if id == &"Exception".into() {
Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read();
if let Some(def_read) = def_read {
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone()
} else {
return Err(HashSet::from([format!(
"function cannot be used as a type (at {})",
expr.location
)]));
}
} else {
locked.get(&obj_id).unwrap().clone()
}
};
// check param number here
if !type_vars.is_empty() {
return Err(HashSet::from([format!(
"expect {} type variable parameter but got 0 (at {})",
type_vars.len(),
expr.location,
)]));
}
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).ty;
unifier.unify(var, ty).unwrap();
Ok(TypeAnnotation::TypeVar(ty))
} else {
Err(HashSet::from([format!(
"`{}` is not a valid type annotation (at {})",
id, expr.location
)]))
}
} else {
Err(HashSet::from([format!(
"`{}` is not a valid type annotation (at {})",
id, expr.location
)]))
}
};
let class_name_handle =
|id: &StrRef,
slice: &ast::Expr<T>,
unifier: &mut Unifier,
mut locked: HashMap<DefinitionId, Vec<Type>, S>| {
if ["virtual".into(), "Generic".into(), "tuple".into(), "Option".into()].contains(id) {
return Err(HashSet::from([format!(
"keywords cannot be class name (at {})",
expr.location
)]));
}
let obj_id = resolver.get_identifier_def(*id)?;
let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read();
if let Some(def_read) = def_read {
let TopLevelDef::Class { type_vars, .. } = &*def_read else {
unreachable!("must be class here")
};
type_vars.clone()
} else {
locked.get(&obj_id).unwrap().clone()
}
};
// we do not check whether the application of type variables are compatible here
let param_type_infos = {
let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.iter().collect_vec()
} else {
vec![slice]
};
if type_vars.len() != params_ast.len() {
return Err(HashSet::from([format!(
"expect {} type parameters but got {} (at {})",
type_vars.len(),
params_ast.len(),
params_ast[0].location,
)]));
}
let result = params_ast
.iter()
.map(|x| {
parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
x,
{
locked.insert(obj_id, type_vars.clone());
locked.clone()
},
)
})
.collect::<Result<Vec<_>, _>>()?;
// make sure the result do not contain any type vars
let no_type_var =
result.iter().all(|x| get_type_var_contained_in_type_annotation(x).is_empty());
if no_type_var {
result
} else {
return Err(HashSet::from([
format!(
"application of type vars to generic class is not currently supported (at {})",
params_ast[0].location
),
]));
}
};
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
};
match &expr.node {
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
// virtual
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into())
} =>
{
let def = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
if !matches!(def, TypeAnnotation::CustomClass { .. }) {
unreachable!("must be concretized custom class kind in the virtual")
}
Ok(TypeAnnotation::Virtual(def.into()))
}
// option
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Option".into())
} =>
{
let def_ann = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
let id =
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
*obj_id
} else {
unreachable!()
};
Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] })
}
// tuple
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into())
} =>
{
let tup_elts = {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.as_slice()
} else {
std::slice::from_ref(slice.as_ref())
}
};
let type_annotations = tup_elts
.iter()
.map(|e| {
parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(TypeAnnotation::Tuple(type_annotations))
}
// Literal
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into())
} =>
{
let tup_elts = {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.as_slice()
} else {
std::slice::from_ref(slice.as_ref())
}
};
let type_annotations = tup_elts
.iter()
.map(|e| match &e.node {
ast::ExprKind::Constant { value, .. } => {
Ok(TypeAnnotation::Literal(vec![value.clone()]))
}
_ => parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
),
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flat_map(|type_ann| match type_ann {
TypeAnnotation::Literal(values) => values,
_ => unreachable!(),
})
.collect_vec();
if type_annotations.len() == 1 {
Ok(TypeAnnotation::Literal(type_annotations))
} else {
Err(HashSet::from([format!(
"multiple literal bounds are currently unsupported (at {})",
value.location
)]))
}
}
// custom class
ast::ExprKind::Subscript { value, slice, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node {
class_name_handle(id, slice, unifier, locked)
} else {
Err(HashSet::from([format!(
"unsupported expression type for class name (at {})",
value.location
)]))
}
}
ast::ExprKind::Constant { value, .. } => Ok(TypeAnnotation::Literal(vec![value.clone()])),
_ => Err(HashSet::from([format!(
"unsupported expression for type annotation (at {})",
expr.location
)])),
}
}
// no need to have the `locked` parameter, unlike the `parse_ast_to_type_annotation_kinds`, since
// when calling this function, there should be no topleveldef::class being write, and this function
// also only read the toplevedefs
pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, HashSet<String>> {
match ann {
TypeAnnotation::CustomClass { id: obj_id, params } => {
let def_read = top_level_defs[obj_id.0].read();
let class_def: &TopLevelDef = &def_read;
let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else {
unreachable!("should be class def here")
};
if type_vars.len() != params.len() {
return Err(HashSet::from([format!(
"unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
params.len()
)]));
}
let param_ty = params
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
})
.collect::<Result<Vec<_>, _>>()?;
let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) {
// Primitive TopLevelDefs do not contain all fields that are present in their Type
// counterparts, so directly perform subst on the Type instead.
let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else {
unreachable!()
};
let base_ty = get_ty_fn(primitives);
let params =
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) {
params.clone()
} else {
unreachable!()
};
unifier
.subst(
get_ty_fn(primitives),
&params
.iter()
.zip(param_ty)
.map(|(obj_tv, param)| (*obj_tv.0, param))
.collect(),
)
.unwrap_or(base_ty)
} else {
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
}
}
TypeEnum::TVar {
id, range, name, loc, is_const_generic: true, ..
} => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp =
unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
}
}
result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
}
}
ty
};
Ok(ty)
}
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Literal(values) => {
let values = values
.iter()
.map(SymbolValue::from_constant_inferred)
.collect::<Result<Vec<_>, _>>()
.map_err(|err| HashSet::from([err]))?;
let var = unifier.get_fresh_literal(values, None);
Ok(var)
}
TypeAnnotation::Virtual(ty) => {
let ty = get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
ty.as_ref(),
subst_list,
)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
}
TypeAnnotation::Tuple(tys) => {
let tys = tys
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys, is_vararg_ctx: false }))
}
}
}
/// given an def id, return a type annotation of self \
/// ```python
/// class A(Generic[T, V]):
/// def fun(self):
/// ```
/// the type of `self` should be similar to `A[T, V]`, where `T`, `V`
/// considered to be type variables associated with the class \
/// \
/// But note that here we do not make a duplication of `T`, `V`, we directly
/// use them as they are in the [`TopLevelDef::Class`] since those in the
/// `TopLevelDef::Class.type_vars` will be substitute later when seeing applications/instantiations
/// the Type of their fields and methods will also be subst when application/instantiation
#[must_use]
pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation {
TypeAnnotation::CustomClass {
id: object_id,
params: type_vars.iter().map(|ty| TypeAnnotation::TypeVar(*ty)).collect_vec(),
}
}
/// get all the occurences of type vars contained in a type annotation
/// e.g. `A[int, B[T], V, virtual[C[G]]]` => [T, V, G]
/// this function will not make a duplicate of type var
#[must_use]
pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<TypeAnnotation> {
let mut result: Vec<TypeAnnotation> = Vec::new();
match ann {
TypeAnnotation::TypeVar(..) => result.push(ann.clone()),
TypeAnnotation::Virtual(ann) => {
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()));
}
TypeAnnotation::CustomClass { params, .. } => {
for p in params {
result.extend(get_type_var_contained_in_type_annotation(p));
}
}
TypeAnnotation::Tuple(anns) => {
for a in anns {
result.extend(get_type_var_contained_in_type_annotation(a));
}
}
TypeAnnotation::Primitive(..) | TypeAnnotation::Literal { .. } => {}
}
result
}
/// check the type compatibility for overload
pub fn check_overload_type_annotation_compatible(
this: &TypeAnnotation,
other: &TypeAnnotation,
unifier: &mut Unifier,
) -> bool {
match (this, other) {
(TypeAnnotation::Primitive(a), TypeAnnotation::Primitive(b)) => a == b,
(TypeAnnotation::TypeVar(a), TypeAnnotation::TypeVar(b)) => {
let a = unifier.get_ty(*a);
let a = &*a;
let b = unifier.get_ty(*b);
let b = &*b;
let (
TypeEnum::TVar { id: a, fields: None, .. },
TypeEnum::TVar { id: b, fields: None, .. },
) = (a, b)
else {
unreachable!("must be type var")
};
a == b
}
(TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b)) => {
check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier)
}
(TypeAnnotation::Tuple(a), TypeAnnotation::Tuple(b)) => {
a.len() == b.len() && {
a.iter()
.zip(b)
.all(|(a, b)| check_overload_type_annotation_compatible(a, b, unifier))
}
}
(
TypeAnnotation::CustomClass { id: a, params: a_p },
TypeAnnotation::CustomClass { id: b, params: b_p },
) => {
a.0 == b.0 && {
a_p.len() == b_p.len() && {
a_p.iter()
.zip(b_p)
.all(|(a, b)| check_overload_type_annotation_compatible(a, b, unifier))
}
}
}
_ => false,
}
}

View File

@ -0,0 +1,228 @@
use super::super::typedef::*;
use super::TopLevelContext;
use std::boxed::Box;
use std::collections::HashMap;
struct ContextStack<'a> {
/// stack level, starts from 0
level: u32,
/// stack of variable definitions containing (id, def, level) where `def` is the original
/// definition in `level-1`.
var_defs: Vec<(usize, VarDef<'a>, u32)>,
/// stack of symbol definitions containing (name, level) where `level` is the smallest level
/// where the name is assigned a value
sym_def: Vec<(&'a str, u32)>,
}
pub struct InferenceContext<'a> {
/// top level context
top_level: TopLevelContext<'a>,
/// list of primitive instances
primitives: Vec<Type>,
/// list of variable instances
variables: Vec<Type>,
/// identifier to type mapping.
sym_table: HashMap<&'a str, Type>,
/// resolution function reference, that may resolve unbounded identifiers to some type
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
/// stack
stack: ContextStack<'a>,
/// return type
result: Option<Type>,
}
// non-trivial implementations here
impl<'a> InferenceContext<'a> {
/// return a new `InferenceContext` from `TopLevelContext` and resolution function.
pub fn new(
top_level: TopLevelContext,
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
) -> InferenceContext {
let primitives = (0..top_level.primitive_defs.len())
.map(|v| TypeEnum::PrimitiveType(PrimitiveId(v)).into())
.collect();
let variables = (0..top_level.var_defs.len())
.map(|v| TypeEnum::TypeVariable(VariableId(v)).into())
.collect();
InferenceContext {
top_level,
primitives,
variables,
sym_table: HashMap::new(),
resolution_fn,
stack: ContextStack {
level: 0,
var_defs: Vec::new(),
sym_def: Vec::new(),
},
result: None,
}
}
/// execute the function with new scope.
/// variable assignment would be limited within the scope (not readable outside), and type
/// variable type guard would be limited within the scope.
/// returns the list of variables assigned within the scope, and the result of the function
pub fn with_scope<F, R>(&mut self, f: F) -> (Vec<(&'a str, Type)>, R)
where
F: FnOnce(&mut Self) -> R,
{
self.stack.level += 1;
let result = f(self);
self.stack.level -= 1;
while !self.stack.var_defs.is_empty() {
let (_, _, level) = self.stack.var_defs.last().unwrap();
if *level > self.stack.level {
let (id, def, _) = self.stack.var_defs.pop().unwrap();
self.top_level.var_defs[id] = def;
} else {
break;
}
}
let mut poped_names = Vec::new();
while !self.stack.sym_def.is_empty() {
let (_, level) = self.stack.sym_def.last().unwrap();
if *level > self.stack.level {
let (name, _) = self.stack.sym_def.pop().unwrap();
let ty = self.sym_table.remove(name).unwrap();
poped_names.push((name, ty));
} else {
break;
}
}
(poped_names, result)
}
/// assign a type to an identifier.
/// may return error if the identifier was defined but with different type
pub fn assign(&mut self, name: &'a str, ty: Type) -> Result<Type, String> {
if let Some(t) = self.sym_table.get_mut(name) {
if t == &ty {
Ok(ty)
} else {
Err("different types".into())
}
} else {
self.stack.sym_def.push((name, self.stack.level));
self.sym_table.insert(name, ty.clone());
Ok(ty)
}
}
/// check if an identifier is already defined
pub fn defined(&self, name: &str) -> bool {
self.sym_table.get(name).is_some()
}
/// get the type of an identifier
/// may return error if the identifier is not defined, and cannot be resolved with the
/// resolution function.
pub fn resolve(&mut self, name: &str) -> Result<Type, String> {
if let Some(t) = self.sym_table.get(name) {
Ok(t.clone())
} else {
self.resolution_fn.as_mut()(name)
}
}
/// restrict the bound of a type variable by replacing its definition.
/// used for implementing type guard
pub fn restrict(&mut self, id: VariableId, mut def: VarDef<'a>) {
std::mem::swap(self.top_level.var_defs.get_mut(id.0).unwrap(), &mut def);
self.stack.var_defs.push((id.0, def, self.stack.level));
}
pub fn set_result(&mut self, result: Option<Type>) {
self.result = result;
}
}
// trivial getters:
impl<'a> InferenceContext<'a> {
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
self.primitives.get(id.0).unwrap().clone()
}
pub fn get_variable(&self, id: VariableId) -> Type {
self.variables.get(id.0).unwrap().clone()
}
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
self.top_level.fn_table.get(name)
}
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
self.top_level.primitive_defs.get(id.0).unwrap()
}
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
self.top_level.class_defs.get(id.0).unwrap()
}
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
self.top_level.parametric_defs.get(id.0).unwrap()
}
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
self.top_level.var_defs.get(id.0).unwrap()
}
pub fn get_type(&self, name: &str) -> Option<Type> {
self.top_level.get_type(name)
}
pub fn get_result(&self) -> Option<Type> {
self.result.clone()
}
}
impl TypeEnum {
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> TypeEnum {
match self {
TypeEnum::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(),
TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType(
*id,
params
.iter()
.map(|v| v.as_ref().subst(map).into())
.collect(),
),
_ => self.clone(),
}
}
pub fn inv_subst(&self, map: &[(Type, Type)]) -> Type {
for (from, to) in map.iter() {
if self == from.as_ref() {
return to.clone();
}
}
match self {
TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType(
*id,
params.iter().map(|v| v.as_ref().inv_subst(map)).collect(),
),
_ => self.clone(),
}
.into()
}
pub fn get_subst(&self, ctx: &InferenceContext) -> HashMap<VariableId, Type> {
match self {
TypeEnum::ParametricType(id, params) => {
let vars = &ctx.get_parametric_def(*id).params;
vars.iter()
.zip(params)
.map(|(v, p)| (*v, p.as_ref().clone().into()))
.collect()
}
// if this proves to be slow, we can use option type
_ => HashMap::new(),
}
}
pub fn get_base<'a>(&'a self, ctx: &'a InferenceContext) -> Option<&'a TypeDef> {
match self {
TypeEnum::PrimitiveType(id) => Some(ctx.get_primitive_def(*id)),
TypeEnum::ClassType(id) | TypeEnum::VirtualClassType(id) => {
Some(&ctx.get_class_def(*id).base)
}
TypeEnum::ParametricType(id, _) => Some(&ctx.get_parametric_def(*id).base),
_ => None,
}
}
}

View File

@ -0,0 +1,4 @@
mod inference_context;
mod top_level_context;
pub use inference_context::InferenceContext;
pub use top_level_context::TopLevelContext;

View File

@ -0,0 +1,138 @@
use super::super::typedef::*;
use std::collections::HashMap;
use std::rc::Rc;
/// Structure for storing top-level type definitions.
/// Used for collecting type signature from source code.
/// Can be converted to `InferenceContext` for type inference in functions.
pub struct TopLevelContext<'a> {
/// List of primitive definitions.
pub(super) primitive_defs: Vec<TypeDef<'a>>,
/// List of class definitions.
pub(super) class_defs: Vec<ClassDef<'a>>,
/// List of parametric type definitions.
pub(super) parametric_defs: Vec<ParametricDef<'a>>,
/// List of type variable definitions.
pub(super) var_defs: Vec<VarDef<'a>>,
/// Function name to signature mapping.
pub(super) fn_table: HashMap<&'a str, FnDef>,
/// Type name to type mapping.
pub(super) sym_table: HashMap<&'a str, Type>,
primitives: Vec<Type>,
variables: Vec<Type>,
}
impl<'a> TopLevelContext<'a> {
pub fn new(primitive_defs: Vec<TypeDef<'a>>) -> TopLevelContext {
let mut sym_table = HashMap::new();
let mut primitives = Vec::new();
for (i, t) in primitive_defs.iter().enumerate() {
primitives.push(TypeEnum::PrimitiveType(PrimitiveId(i)).into());
sym_table.insert(t.name, TypeEnum::PrimitiveType(PrimitiveId(i)).into());
}
TopLevelContext {
primitive_defs,
class_defs: Vec::new(),
parametric_defs: Vec::new(),
var_defs: Vec::new(),
fn_table: HashMap::new(),
sym_table,
primitives,
variables: Vec::new(),
}
}
pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId {
self.sym_table.insert(
def.base.name,
TypeEnum::ClassType(ClassId(self.class_defs.len())).into(),
);
self.class_defs.push(def);
ClassId(self.class_defs.len() - 1)
}
pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId {
let params = def
.params
.iter()
.map(|&v| Rc::new(TypeEnum::TypeVariable(v)))
.collect();
self.sym_table.insert(
def.base.name,
TypeEnum::ParametricType(ParamId(self.parametric_defs.len()), params).into(),
);
self.parametric_defs.push(def);
ParamId(self.parametric_defs.len() - 1)
}
pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId {
self.sym_table.insert(
def.name,
TypeEnum::TypeVariable(VariableId(self.var_defs.len())).into(),
);
self.add_variable_private(def)
}
pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId {
self.var_defs.push(def);
self.variables
.push(TypeEnum::TypeVariable(VariableId(self.var_defs.len() - 1)).into());
VariableId(self.var_defs.len() - 1)
}
pub fn add_fn(&mut self, name: &'a str, def: FnDef) {
self.fn_table.insert(name, def);
}
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
self.fn_table.get(name)
}
pub fn get_primitive_def_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
self.primitive_defs.get_mut(id.0).unwrap()
}
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
self.primitive_defs.get(id.0).unwrap()
}
pub fn get_class_def_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> {
self.class_defs.get_mut(id.0).unwrap()
}
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
self.class_defs.get(id.0).unwrap()
}
pub fn get_parametric_def_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> {
self.parametric_defs.get_mut(id.0).unwrap()
}
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
self.parametric_defs.get(id.0).unwrap()
}
pub fn get_variable_def_mut(&mut self, id: VariableId) -> &mut VarDef<'a> {
self.var_defs.get_mut(id.0).unwrap()
}
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
self.var_defs.get(id.0).unwrap()
}
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
self.primitives.get(id.0).unwrap().clone()
}
pub fn get_variable(&self, id: VariableId) -> Type {
self.variables.get(id.0).unwrap().clone()
}
pub fn get_type(&self, name: &str) -> Option<Type> {
// TODO: handle name visibility
// possibly by passing a function from outside to tell what names are allowed, and what are
// not...
self.sym_table.get(name).cloned()
}
}

View File

@ -0,0 +1,967 @@
use super::context::InferenceContext;
use super::inference_core::resolve_call;
use super::magic_methods::*;
use super::primitives::*;
use super::typedef::{Type, TypeEnum::*};
use rustpython_parser::ast::{
Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator,
UnaryOperator,
};
use std::convert::TryInto;
type ParserResult = Result<Option<Type>, String>;
pub fn infer_expr<'a>(
ctx: &mut InferenceContext<'a>,
expr: &'a Expression,
) -> ParserResult {
match &expr.node {
ExpressionType::Number { value } => infer_constant(ctx, value),
ExpressionType::Identifier { name } => infer_identifier(ctx, name),
ExpressionType::List { elements } => infer_list(ctx, elements),
ExpressionType::Tuple { elements } => infer_tuple(ctx, elements),
ExpressionType::Attribute { value, name } => infer_attribute(ctx, value, name),
ExpressionType::BoolOp { values, .. } => infer_bool_ops(ctx, values),
ExpressionType::Binop { a, b, op } => infer_bin_ops(ctx, op, a, b),
ExpressionType::Unop { op, a } => infer_unary_ops(ctx, op, a),
ExpressionType::Compare { vals, ops } => infer_compare(ctx, vals, ops),
ExpressionType::Call {
args,
function,
keywords,
} => {
if !keywords.is_empty() {
Err("keyword is not supported".into())
} else {
infer_call(ctx, &args, &function)
}
}
ExpressionType::Subscript { a, b } => infer_subscript(ctx, a, b),
ExpressionType::IfExpression { test, body, orelse } => {
infer_if_expr(ctx, &test, &body, orelse)
}
ExpressionType::Comprehension { kind, generators } => match kind.as_ref() {
ComprehensionKind::List { element } => {
if generators.len() == 1 {
infer_list_comprehension(ctx, element, &generators[0])
} else {
Err("only 1 generator statement is supported".into())
}
}
_ => Err("only list comprehension is supported".into()),
},
ExpressionType::True | ExpressionType::False => Ok(Some(ctx.get_primitive(BOOL_TYPE))),
_ => Err("not supported".into()),
}
}
fn infer_constant(
ctx: &mut InferenceContext,
value: &rustpython_parser::ast::Number,
) -> ParserResult {
use rustpython_parser::ast::Number;
match value {
Number::Integer { value } => {
let int32: Result<i32, _> = value.try_into();
if int32.is_ok() {
Ok(Some(ctx.get_primitive(INT32_TYPE)))
} else {
let int64: Result<i64, _> = value.try_into();
if int64.is_ok() {
Ok(Some(ctx.get_primitive(INT64_TYPE)))
} else {
Err("integer out of range".into())
}
}
}
Number::Float { .. } => Ok(Some(ctx.get_primitive(FLOAT_TYPE))),
_ => Err("not supported".into()),
}
}
fn infer_identifier(ctx: &mut InferenceContext, name: &str) -> ParserResult {
Ok(Some(ctx.resolve(name)?))
}
fn infer_list<'a>(
ctx: &mut InferenceContext<'a>,
elements: &'a [Expression],
) -> ParserResult {
if elements.is_empty() {
return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into()));
}
let mut types = elements.iter().map(|v| infer_expr(ctx, v));
let head = types.next().unwrap()?;
if head.is_none() {
return Err("list elements must have some type".into());
}
for v in types {
if v? != head {
return Err("inhomogeneous list is not allowed".into());
}
}
Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into()))
}
fn infer_tuple<'a>(
ctx: &mut InferenceContext<'a>,
elements: &'a [Expression],
) -> ParserResult {
let types: Result<Option<Vec<_>>, String> =
elements.iter().map(|v| infer_expr(ctx, v)).collect();
if let Some(t) = types? {
Ok(Some(ParametricType(TUPLE_TYPE, t).into()))
} else {
Err("tuple elements must have some type".into())
}
}
fn infer_attribute<'a>(
ctx: &mut InferenceContext<'a>,
value: &'a Expression,
name: &str,
) -> ParserResult {
let value = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?;
if let TypeVariable(id) = value.as_ref() {
let v = ctx.get_variable_def(*id);
if v.bound.is_empty() {
return Err("no fields on unbounded type variable".into());
}
let ty = v.bound[0].get_base(ctx).and_then(|v| v.fields.get(name));
if ty.is_none() {
return Err("unknown field".into());
}
for x in v.bound[1..].iter() {
let ty1 = x.get_base(ctx).and_then(|v| v.fields.get(name));
if ty1 != ty {
return Err("unknown field (type mismatch between variants)".into());
}
}
return Ok(Some(ty.unwrap().clone()));
}
match value.get_base(ctx) {
Some(b) => match b.fields.get(name) {
Some(t) => Ok(Some(t.clone())),
None => Err("no such field".into()),
},
None => Err("this object has no fields".into()),
}
}
fn infer_bool_ops<'a>(ctx: &mut InferenceContext<'a>, values: &'a [Expression]) -> ParserResult {
assert_eq!(values.len(), 2);
let left = infer_expr(ctx, &values[0])?.ok_or_else(|| "no value".to_string())?;
let right = infer_expr(ctx, &values[1])?.ok_or_else(|| "no value".to_string())?;
let b = ctx.get_primitive(BOOL_TYPE);
if left == b && right == b {
Ok(Some(b))
} else {
Err("bool operands must be bool".into())
}
}
fn infer_bin_ops<'a>(
ctx: &mut InferenceContext<'a>,
op: &Operator,
left: &'a Expression,
right: &'a Expression,
) -> ParserResult {
let left = infer_expr(ctx, left)?.ok_or_else(|| "no value".to_string())?;
let right = infer_expr(ctx, right)?.ok_or_else(|| "no value".to_string())?;
let fun = binop_name(op);
resolve_call(ctx, Some(left), fun, &[right])
}
fn infer_unary_ops<'a>(
ctx: &mut InferenceContext<'a>,
op: &UnaryOperator,
obj: &'a Expression,
) -> ParserResult {
let ty = infer_expr(ctx, obj)?.ok_or_else(|| "no value".to_string())?;
if let UnaryOperator::Not = op {
if ty == ctx.get_primitive(BOOL_TYPE) {
Ok(Some(ty))
} else {
Err("logical not must be applied to bool".into())
}
} else {
resolve_call(ctx, Some(ty), unaryop_name(op), &[])
}
}
fn infer_compare<'a>(
ctx: &mut InferenceContext<'a>,
vals: &'a [Expression],
ops: &'a [Comparison],
) -> ParserResult {
let types: Result<Option<Vec<_>>, _> = vals.iter().map(|v| infer_expr(ctx, v)).collect();
let types = types?;
if types.is_none() {
return Err("comparison operands must have type".into());
}
let types = types.unwrap();
let boolean = ctx.get_primitive(BOOL_TYPE);
let left = &types[..types.len() - 1];
let right = &types[1..];
for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) {
let fun = comparison_name(op).ok_or_else(|| "unsupported comparison".to_string())?;
let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?;
if ty.is_none() || ty.unwrap() != boolean {
return Err("comparison result must be boolean".into());
}
}
Ok(Some(boolean))
}
fn infer_call<'a>(
ctx: &mut InferenceContext<'a>,
args: &'a [Expression],
function: &'a Expression,
) -> ParserResult {
let types: Result<Option<Vec<_>>, _> = args.iter().map(|v| infer_expr(ctx, v)).collect();
let types = types?;
if types.is_none() {
return Err("function params must have type".into());
}
let (obj, fun) = match &function.node {
ExpressionType::Identifier { name } => (None, name),
ExpressionType::Attribute { value, name } => (
Some(infer_expr(ctx, &value)?.ok_or_else(|| "no value".to_string())?),
name,
),
_ => return Err("not supported".into()),
};
resolve_call(ctx, obj, fun.as_str(), &types.unwrap())
}
fn infer_subscript<'a>(
ctx: &mut InferenceContext<'a>,
a: &'a Expression,
b: &'a Expression,
) -> ParserResult {
let a = infer_expr(ctx, a)?.ok_or_else(|| "no value".to_string())?;
let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() {
ls[0].clone()
} else {
return Err("subscript is not supported for types other than list".into());
};
match &b.node {
ExpressionType::Slice { elements } => {
let int32 = ctx.get_primitive(INT32_TYPE);
let types: Result<Option<Vec<_>>, _> = elements
.iter()
.map(|v| {
if let ExpressionType::None = v.node {
Ok(Some(int32.clone()))
} else {
infer_expr(ctx, v)
}
})
.collect();
let types = types?.ok_or_else(|| "slice must have type".to_string())?;
if types.iter().all(|v| v == &int32) {
Ok(Some(a))
} else {
Err("slice must be int32 type".into())
}
}
_ => {
let b = infer_expr(ctx, b)?.ok_or_else(|| "no value".to_string())?;
if b == ctx.get_primitive(INT32_TYPE) {
Ok(Some(t))
} else {
Err("index must be either slice or int32".into())
}
}
}
}
fn infer_if_expr<'a>(
ctx: &mut InferenceContext<'a>,
test: &'a Expression,
body: &'a Expression,
orelse: &'a Expression,
) -> ParserResult {
let test = infer_expr(ctx, test)?.ok_or_else(|| "no value".to_string())?;
if test != ctx.get_primitive(BOOL_TYPE) {
return Err("test should be bool".into());
}
let body = infer_expr(ctx, body)?;
let orelse = infer_expr(ctx, orelse)?;
if body.as_ref() == orelse.as_ref() {
Ok(body)
} else {
Err("divergent type".into())
}
}
pub fn infer_simple_binding<'a>(
ctx: &mut InferenceContext<'a>,
name: &'a Expression,
ty: Type,
) -> Result<(), String> {
match &name.node {
ExpressionType::Identifier { name } => {
if name == "_" {
Ok(())
} else if ctx.defined(name.as_str()) {
Err("duplicated naming".into())
} else {
ctx.assign(name.as_str(), ty)?;
Ok(())
}
}
ExpressionType::Tuple { elements } => {
if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() {
if elements.len() == ls.len() {
for (a, b) in elements.iter().zip(ls.iter()) {
infer_simple_binding(ctx, a, b.clone())?;
}
Ok(())
} else {
Err("different length".into())
}
} else {
Err("not supported".into())
}
}
_ => Err("not supported".into()),
}
}
fn infer_list_comprehension<'a>(
ctx: &mut InferenceContext<'a>,
element: &'a Expression,
comprehension: &'a Comprehension,
) -> ParserResult {
if comprehension.is_async {
return Err("async is not supported".into());
}
let iter = infer_expr(ctx, &comprehension.iter)?.ok_or_else(|| "no value".to_string())?;
if let ParametricType(LIST_TYPE, ls) = iter.as_ref() {
ctx.with_scope(|ctx| {
infer_simple_binding(ctx, &comprehension.target, ls[0].clone())?;
let boolean = ctx.get_primitive(BOOL_TYPE);
for test in comprehension.ifs.iter() {
let result =
infer_expr(ctx, test)?.ok_or_else(|| "no value in test".to_string())?;
if result != boolean {
return Err("test must be bool".into());
}
}
let result = infer_expr(ctx, element)?.ok_or_else(|| "no value")?;
Ok(Some(ParametricType(LIST_TYPE, vec![result]).into()))
})
.1
} else {
Err("iteration is supported for list only".into())
}
}
#[cfg(test)]
mod test {
use super::{
super::{context::*, typedef::*},
*,
};
use rustpython_parser::parser::parse_expression;
use std::collections::HashMap;
use std::rc::Rc;
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
}
#[test]
fn test_constants() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("123").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("2147483647").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("2147483648").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
let ast = parse_expression("9223372036854775807").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
let ast = parse_expression("9223372036854775808").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("integer out of range".into()));
let ast = parse_expression("123.456").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(FLOAT_TYPE));
let ast = parse_expression("True").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
}
#[test]
fn test_identifier() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
let ast = parse_expression("abc").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("ab").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unbounded identifier".into()));
}
#[test]
fn test_list() {
let mut ctx = basic_ctx();
ctx.add_fn(
"foo",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
// def is reserved...
ctx.assign("efg", ctx.get_primitive(INT32_TYPE)).unwrap();
ctx.assign("xyz", ctx.get_primitive(FLOAT_TYPE)).unwrap();
let ast = parse_expression("[]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![BotType.into()]).into()
);
let ast = parse_expression("[abc]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[abc, efg]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[abc, efg, xyz]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("inhomogeneous list is not allowed".into()));
let ast = parse_expression("[foo()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("list elements must have some type".into()));
}
#[test]
fn test_tuple() {
let mut ctx = basic_ctx();
ctx.add_fn(
"foo",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
ctx.assign("efg", ctx.get_primitive(FLOAT_TYPE)).unwrap();
let ast = parse_expression("(abc, efg)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(
TUPLE_TYPE,
vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)]
)
.into()
);
let ast = parse_expression("(abc, efg, foo())").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("tuple elements must have some type".into()));
}
#[test]
fn test_attribute() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let int32 = ctx.get_primitive(INT32_TYPE);
let float = ctx.get_primitive(FLOAT_TYPE);
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let foo_def = ctx.get_class_def_mut(foo);
foo_def.base.fields.insert("a", int32.clone());
foo_def.base.fields.insert("b", ClassType(foo).into());
foo_def.base.fields.insert("c", int32.clone());
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "Bar",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let bar_def = ctx.get_class_def_mut(bar);
bar_def.base.fields.insert("a", int32);
bar_def.base.fields.insert("b", ClassType(bar).into());
bar_def.base.fields.insert("c", float);
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![],
});
let v1 = ctx.add_variable(VarDef {
name: "v1",
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
.unwrap();
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
ctx.assign("bot", Rc::new(BotType)).unwrap();
let ast = parse_expression("foo.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("foo.d").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such field".into()));
let ast = parse_expression("foobar.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("v0.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no fields on unbounded type variable".into()));
let ast = parse_expression("v1.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
// shall we support this?
let ast = parse_expression("v1.b").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("unknown field (type mismatch between variants)".into())
);
// assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into());
let ast = parse_expression("v1.c").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("unknown field (type mismatch between variants)".into())
);
let ast = parse_expression("v1.d").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unknown field".into()));
let ast = parse_expression("none().a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("bot.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("this object has no fields".into()));
}
#[test]
fn test_bool_ops() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("True and False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("True and none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("True and 123").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("bool operands must be bool".into()));
}
#[test]
fn test_bin_ops() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("1 + 2 + 3").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("a + a + a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into());
}
#[test]
fn test_unary_ops() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("-(123)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("-a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into());
let ast = parse_expression("not True").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("not (1)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("logical not must be applied to bool".into()));
}
#[test]
fn test_compare() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("a == a == a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("a == a == 1").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("different types".into()));
let ast = parse_expression("True > False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such function".into()));
let ast = parse_expression("True in False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unsupported comparison".into()));
}
#[test]
fn test_call() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let foo_def = ctx.get_class_def_mut(foo);
foo_def.base.methods.insert(
"a",
FnDef {
args: vec![],
result: Some(Rc::new(ClassType(foo))),
},
);
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "Bar",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let bar_def = ctx.get_class_def_mut(bar);
bar_def.base.methods.insert(
"a",
FnDef {
args: vec![],
result: Some(Rc::new(ClassType(bar))),
},
);
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![],
});
let v1 = ctx.add_variable(VarDef {
name: "v1",
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
});
let v2 = ctx.add_variable(VarDef {
name: "v2",
bound: vec![
ClassType(foo).into(),
ClassType(bar).into(),
ctx.get_primitive(INT32_TYPE),
],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
.unwrap();
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
ctx.assign("v2", ctx.get_variable(v2)).unwrap();
ctx.assign("bot", Rc::new(BotType)).unwrap();
let ast = parse_expression("foo.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
let ast = parse_expression("v1.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into());
let ast = parse_expression("foobar.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
let ast = parse_expression("none().a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("bot.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("[][0].a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("v0.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unbounded type var".into()));
let ast = parse_expression("v2.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such function".into()));
}
#[test]
fn infer_subscript() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("[1, 2, 3][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("[[1]][0][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("[1, 2, 3][1:2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[1, 2, 3][1:2:2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[1, 2, 3][1:1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("slice must be int32 type".into()));
let ast = parse_expression("[1, 2, 3][1:none()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("slice must have type".into()));
let ast = parse_expression("[1, 2, 3][1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("index must be either slice or int32".into()));
let ast = parse_expression("[1, 2, 3][none()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("none()[1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("123[1]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("subscript is not supported for types other than list".into())
);
}
#[test]
fn test_if_expr() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("1 if True else 0").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("none() if True else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap(), None);
let ast = parse_expression("none() if 1 else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("test should be bool".into()));
let ast = parse_expression("1 if True else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("divergent type".into()));
}
#[test]
fn test_list_comp() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let int32 = ctx.get_primitive(INT32_TYPE);
let mut ctx = get_inference_context(ctx);
ctx.assign("z", int32.clone()).unwrap();
let ast = parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into()
);
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), int32);
let ast =
parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), int32);
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("test must be bool".into()));
let ast = parse_expression("[y for x in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unbounded identifier".into()));
let ast = parse_expression("[none() for x in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("[z for z in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("duplicated naming".into()));
let ast = parse_expression("[x for x in [] for y in []]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("only 1 generator statement is supported".into())
);
}
}

Some files were not shown because too many files have changed in this diff Show More