Compare commits

..

16 Commits

Author SHA1 Message Date
pca006132 65d9b620fe Merge remote-tracking branch 'origin/master' 2021-01-27 16:00:47 +08:00
pca006132 075245f4c2 get list of unknowns 2021-01-27 16:00:06 +08:00
pca006132 0e9dea0e65 updated todo 2021-01-25 17:09:17 +08:00
pca006132 7d06c903e3 added get_expression_unknowns 2021-01-15 15:19:41 +08:00
pca006132 5fce6cf069 updated lifetime 2021-01-15 14:37:55 +08:00
pca006132 d466b7bc2b simplified lifetime 2021-01-11 11:12:37 +08:00
pca006132 b6e4e68587 started adding error types
* converted to string to avoid breaking interface, would be fixed later
* tests would not check error messages now, as messages are not
  finalized
2021-01-11 10:40:32 +08:00
pca006132 779288d685 added todo 2021-01-08 17:16:25 +08:00
pca006132 3ecf57a588 fixed bugs 2021-01-08 16:54:34 +08:00
pca006132 ebe1027ffa added resolve signature 2021-01-08 16:24:25 +08:00
pca006132 7c6349520c started getting signatures 2021-01-08 14:52:48 +08:00
pca006132 04f121403a moved type check code to submodule 2021-01-08 12:58:33 +08:00
pca006132 b51168a5ab added statement tests 2021-01-05 13:35:44 +08:00
pca006132 007843c1ef added aug assign for primitives 2021-01-05 13:21:39 +08:00
pca006132 ff41cdb000 implemented statement check 2021-01-05 12:17:45 +08:00
pca006132 e1efb47ad2 statement inference: assignment 2021-01-04 17:07:14 +08:00
181 changed files with 4660 additions and 53799 deletions

View File

@ -1 +0,0 @@
doc-valid-idents = ["NumPy", ".."]

1
.gitignore vendored
View File

@ -1,3 +1,2 @@
__pycache__
/target
nix/windows/msys2

1400
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,14 +1,6 @@
[workspace]
members = [
"nac3ld",
"nac3ast",
"nac3parser",
"nac3core",
"nac3standalone",
"nac3artiq",
"runkernel",
"nac3embedded",
]
resolver = "2"
[profile.release]
debug = true

View File

@ -1,53 +0,0 @@
<div align="center">
![icon](https://git.m-labs.hk/M-Labs/nac3/raw/branch/master/nac3.svg)
</div>
# NAC3
NAC3 is a major, backward-incompatible rewrite of the compiler for the [ARTIQ](https://m-labs.hk/artiq) physics experiment control and data acquisition system. It features greatly improved compilation speeds, a much better type system, and more predictable and transparent operation.
NAC3 has a modular design and its applicability reaches beyond ARTIQ. The ``nac3core`` module does not contain anything specific to ARTIQ, and can be used in any project that requires compiling Python to machine code.
**WARNING: NAC3 is currently experimental software and several important features are not implemented yet.**
## Packaging
NAC3 is packaged using the [Nix](https://nixos.org) Flakes system. Install Nix 2.8+ and enable flakes by adding ``experimental-features = nix-command flakes`` to ``nix.conf`` (e.g. ``~/.config/nix/nix.conf``).
## Try NAC3
### Linux
After setting up Nix as above, use ``nix shell git+https://github.com/m-labs/artiq.git?ref=nac3`` to get a shell with the NAC3 version of ARTIQ. See the ``examples`` directory in ARTIQ (``nac3`` Git branch) for some samples of NAC3 kernel code.
### Windows
Install [MSYS2](https://www.msys2.org/), and open "MSYS2 CLANG64". Edit ``/etc/pacman.conf`` to add:
```
[artiq]
SigLevel = Optional TrustAll
Server = https://msys2.m-labs.hk/artiq-nac3
```
Then run the following commands:
```
pacman -Syu
pacman -S mingw-w64-clang-x86_64-artiq
```
## For developers
This repository contains:
- ``nac3ast``: Python abstract syntax tree definition (based on RustPython).
- ``nac3parser``: Python parser (based on RustPython).
- ``nac3core``: Core compiler library, containing type-checking and code generation.
- ``nac3standalone``: Standalone compiler tool (core language only).
- ``nac3ld``: Minimalist RISC-V and ARM linker.
- ``nac3artiq``: Integration with ARTIQ and implementation of ARTIQ-specific extensions to the core language.
- ``runkernel``: Simple program that runs compiled ARTIQ kernels on the host and displays RTIO operations. Useful for testing without hardware.
Use ``nix develop`` in this repository to enter a development shell.
If you are using a different shell than bash you can use e.g. ``nix develop --command fish``.
Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``.

View File

@ -1,27 +0,0 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1708296515,
"narHash": "sha256-FyF489fYNAUy7b6dkYV6rGPyzp+4tThhr80KNAaF/yY=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "b98a4e1746acceb92c509bc496ef3d0e5ad8d4aa",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs"
}
}
},
"root": "root",
"version": 7
}

187
flake.nix
View File

@ -1,187 +0,0 @@
{
description = "The third-generation ARTIQ compiler";
inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-unstable;
outputs = { self, nixpkgs }:
let
pkgs = import nixpkgs { system = "x86_64-linux"; };
in rec {
packages.x86_64-linux = rec {
llvm-nac3 = pkgs.callPackage ./nix/llvm {};
llvm-tools-irrt = pkgs.runCommandNoCC "llvm-tools-irrt" {}
''
mkdir -p $out/bin
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
'';
nac3artiq = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage rec {
name = "nac3artiq";
outputs = [ "out" "runkernel" "standalone" ];
src = self;
cargoLock = {
lockFile = ./Cargo.lock;
};
passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase =
''
echo "Checking nac3standalone demos..."
pushd nac3standalone/demo
patchShebangs .
./check_demos.sh
popd
echo "Running Cargo tests..."
cargoCheckHook
'';
installPhase =
''
PYTHON_SITEPACKAGES=$out/${pkgs.python3Packages.python.sitePackages}
mkdir -p $PYTHON_SITEPACKAGES
cp target/x86_64-unknown-linux-gnu/release/libnac3artiq.so $PYTHON_SITEPACKAGES/nac3artiq.so
mkdir -p $runkernel/bin
cp target/x86_64-unknown-linux-gnu/release/runkernel $runkernel/bin
mkdir -p $standalone/bin
cp target/x86_64-unknown-linux-gnu/release/nac3standalone $standalone/bin
'';
}
);
python3-mimalloc = pkgs.python3 // rec {
withMimalloc = pkgs.python3.buildEnv.override({ makeWrapperArgs = [ "--set LD_PRELOAD ${pkgs.mimalloc}/lib/libmimalloc.so" ]; });
withPackages = f: let packages = f pkgs.python3.pkgs; in withMimalloc.override { extraLibs = packages; };
};
# LLVM PGO support
llvm-nac3-instrumented = pkgs.callPackage ./nix/llvm {
stdenv = pkgs.llvmPackages_14.stdenv;
extraCmakeFlags = [ "-DLLVM_BUILD_INSTRUMENTED=IR" ];
};
nac3artiq-instrumented = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage {
name = "nac3artiq-instrumented";
src = self;
inherit (nac3artiq) cargoLock;
nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-instrumented ];
buildInputs = [ pkgs.python3 llvm-nac3-instrumented ];
cargoBuildFlags = [ "--package" "nac3artiq" "--features" "init-llvm-profile" ];
doCheck = false;
configurePhase =
''
export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS="-C link-arg=-L${pkgs.llvmPackages_14.compiler-rt}/lib/linux -C link-arg=-lclang_rt.profile-x86_64"
'';
installPhase =
''
TARGET_DIR=$out/${pkgs.python3Packages.python.sitePackages}
mkdir -p $TARGET_DIR
cp target/x86_64-unknown-linux-gnu/release/libnac3artiq.so $TARGET_DIR/nac3artiq.so
'';
}
);
nac3artiq-profile = pkgs.stdenvNoCC.mkDerivation {
name = "nac3artiq-profile";
srcs = [
(pkgs.fetchFromGitHub {
owner = "m-labs";
repo = "sipyco";
rev = "939f84f9b5eef7efbf7423c735d1834783b6140e";
sha256 = "sha256-15Nun4EY35j+6SPZkjzZtyH/ncxLS60KuGJjFh5kSTc=";
})
(pkgs.fetchFromGitHub {
owner = "m-labs";
repo = "artiq";
rev = "923ca3377d42c815f979983134ec549dc39d3ca0";
sha256 = "sha256-oJoEeNEeNFSUyh6jXG8Tzp6qHVikeHS0CzfE+mODPgw=";
})
];
buildInputs = [
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb nac3artiq-instrumented ]))
pkgs.llvmPackages_14.llvm.out
];
phases = [ "buildPhase" "installPhase" ];
buildPhase =
''
srcs=($srcs)
sipyco=''${srcs[0]}
artiq=''${srcs[1]}
export PYTHONPATH=$sipyco:$artiq
python -m artiq.frontend.artiq_ddb_template $artiq/artiq/examples/nac3devices/nac3devices.json > device_db.py
cp $artiq/artiq/examples/nac3devices/nac3devices.py .
python -m artiq.frontend.artiq_compile nac3devices.py
'';
installPhase =
''
mkdir $out
llvm-profdata merge -o $out/llvm.profdata /build/llvm/build/profiles/*
'';
};
llvm-nac3-pgo = pkgs.callPackage ./nix/llvm {
stdenv = pkgs.llvmPackages_14.stdenv;
extraCmakeFlags = [ "-DLLVM_PROFDATA_FILE=${nac3artiq-profile}/llvm.profdata" ];
};
nac3artiq-pgo = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage {
name = "nac3artiq-pgo";
src = self;
inherit (nac3artiq) cargoLock;
nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-pgo ];
buildInputs = [ pkgs.python3 llvm-nac3-pgo ];
cargoBuildFlags = [ "--package" "nac3artiq" ];
cargoTestFlags = [ "--package" "nac3ast" "--package" "nac3parser" "--package" "nac3core" "--package" "nac3artiq" ];
installPhase =
''
TARGET_DIR=$out/${pkgs.python3Packages.python.sitePackages}
mkdir -p $TARGET_DIR
cp target/x86_64-unknown-linux-gnu/release/libnac3artiq.so $TARGET_DIR/nac3artiq.so
'';
}
);
};
packages.x86_64-w64-mingw32 = import ./nix/windows { inherit pkgs; };
devShells.x86_64-linux.default = pkgs.mkShell {
name = "nac3-dev-shell";
buildInputs = with pkgs; [
# build dependencies
packages.x86_64-linux.llvm-nac3
llvmPackages_14.clang # demo
packages.x86_64-linux.llvm-tools-irrt
cargo
rustc
# runtime dependencies
lld_14 # for running kernels on the host
(packages.x86_64-linux.python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ]))
# development tools
cargo-insta
clippy
rustfmt
];
};
devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2";
buildInputs = with pkgs; [
curl
pacman
fakeroot
packages.x86_64-w64-mingw32.wine-msys2
];
};
hydraJobs = {
inherit (packages.x86_64-linux) llvm-nac3 nac3artiq nac3artiq-pgo;
llvm-nac3-msys2 = packages.x86_64-w64-mingw32.llvm-nac3;
nac3artiq-msys2 = packages.x86_64-w64-mingw32.nac3artiq;
nac3artiq-msys2-pkg = packages.x86_64-w64-mingw32.nac3artiq-pkg;
};
};
nixConfig = {
extra-trusted-public-keys = "nixbld.m-labs.hk-1:5aSRVA5b320xbNvu30tqxVPXpld73bhtOeH6uAjRyHc=";
extra-substituters = "https://nixbld.m-labs.hk";
};
}

View File

@ -1,56 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
id="a"
width="128"
height="128"
viewBox="0 0 95.99999 95.99999"
version="1.1"
sodipodi:docname="nac3.svg"
inkscape:version="1.1.1 (3bf5ae0d25, 2021-09-20)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<defs
id="defs11" />
<sodipodi:namedview
id="namedview9"
pagecolor="#ffffff"
bordercolor="#666666"
borderopacity="1.0"
inkscape:pageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:document-units="mm"
showgrid="false"
units="px"
width="128px"
inkscape:zoom="5.9448568"
inkscape:cx="60.472441"
inkscape:cy="60.556547"
inkscape:window-width="2560"
inkscape:window-height="1371"
inkscape:window-x="0"
inkscape:window-y="32"
inkscape:window-maximized="1"
inkscape:current-layer="a" />
<rect
x="40.072601"
y="-26.776209"
width="55.668747"
height="55.668747"
transform="matrix(0.71803815,0.69600374,-0.71803815,0.69600374,0,0)"
style="fill:#be211e;stroke:#000000;stroke-width:4.37375px;stroke-linecap:round;stroke-linejoin:round"
id="rect2" />
<line
x1="38.00692"
y1="63.457153"
x2="57.993061"
y2="63.457153"
style="fill:none;stroke:#000000;stroke-width:4.37269px;stroke-linecap:round;stroke-linejoin:round"
id="line4" />
<path
d="m 48.007301,57.843329 c -1.943097,0 -3.877522,-0.41727 -5.686157,-1.246007 -3.218257,-1.474616 -5.650382,-4.075418 -6.849639,-7.323671 -2.065624,-5.588921 -1.192751,-10.226647 2.575258,-13.827 0.611554,-0.584909 1.518048,-0.773041 2.323689,-0.488206 0.80673,0.286405 1.369495,0.998486 1.447563,1.827234 0.237469,2.549302 2.439719,5.917376 4.28414,6.55273 0.396859,0.13506 0.820953,-0.05859 1.097084,-0.35222 0.339254,-0.360754 0.451065,-0.961893 -1.013597,-3.191372 -2.089851,-3.181137 -4.638728,-8.754903 -0.262407,-15.069853 0.494457,-0.713491 1.384673,-1.068907 2.256469,-0.909156 0.871795,0.161332 1.583757,0.806404 1.752251,1.651189 0.716448,3.591862 2.962357,6.151755 5.199306,8.023138 1.935503,1.61861 4.344688,3.867387 5.435687,7.096643 2.283183,6.758017 -1.202511,14.114988 -8.060822,16.494025 -1.467083,0.509226 -2.98513,0.762536 -4.498836,0.762536 z M 39.358865,40.002192 c -0.304711,0.696206 -0.541636,2.080524 -0.56865,2.237454 -0.330316,1.918771 0.168305,3.803963 0.846157,5.539951 0.856828,2.19436 2.437543,3.942467 4.583411,4.925713 2.143691,0.981675 4.554131,1.097816 6.789992,0.322666 4.571485,-1.586549 6.977584,-6.532238 5.363036,-11.02597 v -5.27e-4 C 55.455481,39.447968 54.023463,38.162043 52.221335,36.65432 50.876945,35.529534 49.409662,33.987726 48.417983,32.135555 48.01343,31.37996 47.79547,30.34303 47.76669,29.413263 c -0.187481,0.669514 -0.212441,2.325923 -0.150396,2.93691 0.179209,1.764456 1.333476,3.644546 2.340611,5.171243 1.311568,1.988179 2.72058,6.037272 0.459681,8.367985 -1.54192,1.58953 -4.038511,2.052034 -5.839973,1.38492 -2.398314,-0.888147 -3.942744,-2.690627 -4.941118,-4.768029 -0.121194,-0.25217 -0.532464,-1.174187 -0.276619,-2.5041 z"
id="path6"
style="stroke-width:1.09317" />
</svg>

Before

Width:  |  Height:  |  Size: 3.3 KiB

View File

@ -1,26 +0,0 @@
[package]
name = "nac3artiq"
version = "0.1.0"
authors = ["M-Labs"]
edition = "2021"
[lib]
name = "nac3artiq"
crate-type = ["cdylib"]
[dependencies]
itertools = "0.12"
pyo3 = { version = "0.20", features = ["extension-module"] }
parking_lot = "0.12"
tempfile = "3.10"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" }
nac3ld = { path = "../nac3ld" }
[dependencies.inkwell]
version = "0.4"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[features]
init-llvm-profile = []

View File

@ -1,26 +0,0 @@
from min_artiq import *
@nac3
class Demo:
core: KernelInvariant[Core]
led0: KernelInvariant[TTLOut]
led1: KernelInvariant[TTLOut]
def __init__(self):
self.core = Core()
self.led0 = TTLOut(self.core, 18)
self.led1 = TTLOut(self.core, 19)
@kernel
def run(self):
self.core.reset()
while True:
with parallel:
self.led0.pulse(100.*ms)
self.led1.pulse(100.*ms)
self.core.delay(100.*ms)
if __name__ == "__main__":
Demo().run()

View File

@ -1,16 +0,0 @@
# python demo.py
# artiq_run module.elf
device_db = {
"core": {
"type": "local",
"module": "artiq.coredevice.core",
"class": "Core",
"arguments": {
"host": "kc705",
"ref_period": 1e-9,
"ref_multiplier": 8,
"target": "rv32g"
}
},
}

View File

@ -1,66 +0,0 @@
class EmbeddingMap:
def __init__(self):
self.object_inverse_map = {}
self.object_map = {}
self.string_map = {}
self.string_reverse_map = {}
self.function_map = {}
self.attributes_writeback = []
# preallocate exception names
self.preallocate_runtime_exception_names(["RuntimeError",
"RTIOUnderflow",
"RTIOOverflow",
"RTIODestinationUnreachable",
"DMAError",
"I2CError",
"CacheError",
"SPIError",
"0:ZeroDivisionError",
"0:IndexError",
"0:ValueError",
"0:RuntimeError",
"0:AssertionError",
"0:KeyError",
"0:NotImplementedError",
"0:OverflowError",
"0:IOError",
"0:UnwrapNoneError"])
def preallocate_runtime_exception_names(self, names):
for i, name in enumerate(names):
if ":" not in name:
name = "0:artiq.coredevice.exceptions." + name
exn_id = self.store_str(name)
assert exn_id == i
def store_function(self, key, fun):
self.function_map[key] = fun
return key
def store_object(self, obj):
obj_id = id(obj)
if obj_id in self.object_inverse_map:
return self.object_inverse_map[obj_id]
key = len(self.object_map) + 1
self.object_map[key] = obj
self.object_inverse_map[obj_id] = key
return key
def store_str(self, s):
if s in self.string_reverse_map:
return self.string_reverse_map[s]
key = len(self.string_map)
self.string_map[key] = s
self.string_reverse_map[s] = key
return key
def retrieve_function(self, key):
return self.function_map[key]
def retrieve_object(self, key):
return self.object_map[key]
def retrieve_str(self, key):
return self.string_map[key]

View File

@ -1,295 +0,0 @@
from inspect import getfullargspec
from functools import wraps
from types import SimpleNamespace
from numpy import int32, int64
from typing import Generic, TypeVar
from math import floor, ceil
import nac3artiq
from embedding_map import EmbeddingMap
__all__ = [
"Kernel", "KernelInvariant", "virtual", "ConstGeneric",
"Option", "Some", "none", "UnwrapNoneError",
"round64", "floor64", "ceil64",
"extern", "kernel", "portable", "nac3",
"rpc", "ms", "us", "ns",
"print_int32", "print_int64",
"Core", "TTLOut",
"parallel", "sequential"
]
T = TypeVar('T')
class Kernel(Generic[T]):
pass
class KernelInvariant(Generic[T]):
pass
# The virtual class must exist before nac3artiq.NAC3 is created.
class virtual(Generic[T]):
pass
class Option(Generic[T]):
_nac3_option: T
def __init__(self, v: T):
self._nac3_option = v
def is_none(self):
return self._nac3_option is None
def is_some(self):
return not self.is_none()
def unwrap(self):
if self.is_none():
raise UnwrapNoneError()
return self._nac3_option
def __repr__(self) -> str:
if self.is_none():
return "none"
else:
return "Some({})".format(repr(self._nac3_option))
def __str__(self) -> str:
if self.is_none():
return "none"
else:
return "Some({})".format(str(self._nac3_option))
def Some(v: T) -> Option[T]:
return Option(v)
none = Option(None)
class _ConstGenericMarker:
pass
def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericMarker, constraint)
def round64(x):
return round(x)
def floor64(x):
return floor(x)
def ceil64(x):
return ceil(x)
import device_db
core_arguments = device_db.device_db["core"]["arguments"]
artiq_builtins = {
"none": none,
"virtual": virtual,
"_ConstGenericMarker": _ConstGenericMarker,
"Option": Option,
}
compiler = nac3artiq.NAC3(core_arguments["target"], artiq_builtins)
allow_registration = True
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
registered_functions = set()
registered_classes = set()
def register_function(fun):
assert allow_registration
registered_functions.add(fun)
def register_class(cls):
assert allow_registration
registered_classes.add(cls)
def extern(function):
"""Decorates a function declaration defined by the core device runtime."""
register_function(function)
return function
def rpc(function):
"""Decorates a function declaration defined by the core device runtime."""
register_function(function)
return function
def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device."""
register_function(function_or_method)
argspec = getfullargspec(function_or_method)
if argspec.args and argspec.args[0] == "self":
@wraps(function_or_method)
def run_on_core(self, *args, **kwargs):
fake_method = SimpleNamespace(__self__=self, __name__=function_or_method.__name__)
self.core.run(fake_method, *args, **kwargs)
else:
@wraps(function_or_method)
def run_on_core(*args, **kwargs):
raise RuntimeError("Kernel functions need explicit core.run()")
return run_on_core
def portable(function):
"""Decorates a function or method to be executed on the same device (host/core device) as the caller."""
register_function(function)
return function
def nac3(cls):
"""
Decorates a class to be analyzed by NAC3.
All classes containing kernels or portable methods must use this decorator.
"""
register_class(cls)
return cls
ms = 1e-3
us = 1e-6
ns = 1e-9
@extern
def rtio_init():
raise NotImplementedError("syscall not simulated")
@extern
def rtio_get_counter() -> int64:
raise NotImplementedError("syscall not simulated")
@extern
def rtio_output(target: int32, data: int32):
raise NotImplementedError("syscall not simulated")
@extern
def rtio_input_timestamp(timeout_mu: int64, channel: int32) -> int64:
raise NotImplementedError("syscall not simulated")
@extern
def rtio_input_data(channel: int32) -> int32:
raise NotImplementedError("syscall not simulated")
# These is not part of ARTIQ and only available in runkernel. Defined here for convenience.
@extern
def print_int32(x: int32):
raise NotImplementedError("syscall not simulated")
@extern
def print_int64(x: int64):
raise NotImplementedError("syscall not simulated")
@nac3
class Core:
ref_period: KernelInvariant[float]
def __init__(self):
self.ref_period = core_arguments["ref_period"]
def run(self, method, *args, **kwargs):
global allow_registration
embedding = EmbeddingMap()
if allow_registration:
compiler.analyze(registered_functions, registered_classes)
allow_registration = False
if hasattr(method, "__self__"):
obj = method.__self__
name = method.__name__
else:
obj = method
name = ""
compiler.compile_method_to_file(obj, name, args, "module.elf", embedding)
@kernel
def reset(self):
rtio_init()
at_mu(rtio_get_counter() + int64(125000))
@kernel
def break_realtime(self):
min_now = rtio_get_counter() + int64(125000)
if now_mu() < min_now:
at_mu(min_now)
@portable
def seconds_to_mu(self, seconds: float) -> int64:
return int64(round(seconds/self.ref_period))
@portable
def mu_to_seconds(self, mu: int64) -> float:
return float(mu)*self.ref_period
@kernel
def delay(self, dt: float):
delay_mu(self.seconds_to_mu(dt))
@nac3
class TTLOut:
core: KernelInvariant[Core]
channel: KernelInvariant[int32]
target_o: KernelInvariant[int32]
def __init__(self, core: Core, channel: int32):
self.core = core
self.channel = channel
self.target_o = channel << 8
@kernel
def output(self):
pass
@kernel
def set_o(self, o: bool):
rtio_output(self.target_o, 1 if o else 0)
@kernel
def on(self):
self.set_o(True)
@kernel
def off(self):
self.set_o(False)
@kernel
def pulse_mu(self, duration: int64):
self.on()
delay_mu(duration)
self.off()
@kernel
def pulse(self, duration: float):
self.on()
self.core.delay(duration)
self.off()
@nac3
class KernelContextManager:
@kernel
def __enter__(self):
pass
@kernel
def __exit__(self):
pass
@nac3
class UnwrapNoneError(Exception):
"""raised when unwrapping a none value"""
artiq_builtin = True
parallel = KernelContextManager()
sequential = KernelContextManager()

View File

@ -1 +0,0 @@
../../target/release/libnac3artiq.so

View File

@ -1,685 +0,0 @@
use nac3core::{
codegen::{
expr::gen_call,
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_with},
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, GenCall, helper::PRIMITIVE_DEF_IDS},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap}
};
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
use inkwell::{
context::Context,
module::Linkage,
types::IntType,
values::BasicValueEnum,
AddressSpace,
};
use pyo3::{PyObject, PyResult, Python, types::{PyDict, PyList}};
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use std::{
collections::hash_map::DefaultHasher,
collections::HashMap,
hash::{Hash, Hasher},
sync::Arc,
};
/// The parallelism mode within a block.
#[derive(Copy, Clone, Eq, PartialEq)]
enum ParallelMode {
/// No parallelism is currently registered for this context.
None,
/// Legacy (or shallow) parallelism. Default before NAC3.
///
/// Each statement within the `with` block is treated as statements to be executed in parallel.
Legacy,
/// Deep parallelism. Default since NAC3.
///
/// Each function call within the `with` block (except those within a nested `sequential` block)
/// are treated to be executed in parallel.
Deep
}
pub struct ArtiqCodeGenerator<'a> {
name: String,
/// The size of a `size_t` variable in bits.
size_t: u32,
/// Monotonic counter for naming `start`/`stop` variables used by `with parallel` blocks.
name_counter: u32,
/// Variable for tracking the start of a `with parallel` block.
start: Option<Expr<Option<Type>>>,
/// Variable for tracking the end of a `with parallel` block.
end: Option<Expr<Option<Type>>>,
timeline: &'a (dyn TimeFns + Sync),
/// The [ParallelMode] of the current parallel context.
///
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
/// statement, which is used to determine when and how the timeline should be updated.
parallel_mode: ParallelMode,
}
impl<'a> ArtiqCodeGenerator<'a> {
pub fn new(
name: String,
size_t: u32,
timeline: &'a (dyn TimeFns + Sync),
) -> ArtiqCodeGenerator<'a> {
assert!(size_t == 32 || size_t == 64);
ArtiqCodeGenerator {
name,
size_t,
name_counter: 0,
start: None,
end: None,
timeline,
parallel_mode: ParallelMode::None,
}
}
/// If the generator is currently in a direct-`parallel` block context, emits IR that resets the
/// position of the timeline to the initial timeline position before entering the `parallel`
/// block.
///
/// Direct-`parallel` block context refers to when the generator is generating statements whose
/// closest parent `with` statement is a `with parallel` block.
fn timeline_reset_start(
&mut self,
ctx: &mut CodeGenContext<'_, '_>
) -> Result<(), String> {
if let Some(start) = self.start.clone() {
let start_val = self.gen_expr(ctx, &start)?
.unwrap()
.to_basic_value_enum(ctx, self, start.custom.unwrap())?;
self.timeline.emit_at_mu(ctx, start_val);
}
Ok(())
}
/// If the generator is currently in a `parallel` block context, emits IR that updates the
/// maximum end position of the `parallel` block as specified by the timeline `end` value.
///
/// In general the `end` parameter should be set to `self.end` for updating the maximum end
/// position for the current `parallel` block. Other values can be passed in to update the
/// maximum end position for other `parallel` blocks.
///
/// `parallel`-block context refers to when the generator is generating statements within a
/// (possibly indirect) `parallel` block.
///
/// * `store_name` - The LLVM value name for the pointer to `end`. `.addr` will be appended to
/// the end of the provided value name.
fn timeline_update_end_max(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
end: Option<Expr<Option<Type>>>,
store_name: Option<&str>,
) -> Result<(), String> {
if let Some(end) = end {
let old_end = self.gen_expr(ctx, &end)?
.unwrap()
.to_basic_value_enum(ctx, self, end.custom.unwrap())?;
let now = self.timeline.emit_now_mu(ctx);
let max = call_int_smax(
ctx,
old_end.into_int_value(),
now.into_int_value(),
Some("smax")
);
let end_store = self.gen_store_target(
ctx,
&end,
store_name.map(|name| format!("{name}.addr")).as_deref())?
.unwrap();
ctx.builder.build_store(end_store, max).unwrap();
}
Ok(())
}
}
impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
fn get_name(&self) -> &str {
&self.name
}
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
if self.size_t == 32 {
ctx.i32_type()
} else {
ctx.i64_type()
}
}
fn gen_block<'ctx, 'a, 'c, I: Iterator<Item=&'c Stmt<Option<Type>>>>(
&mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmts: I
) -> Result<(), String> where Self: Sized {
// Legacy parallel emits timeline end-update/timeline-reset after each top-level statement
// in the parallel block
if self.parallel_mode == ParallelMode::Legacy {
for stmt in stmts {
self.gen_stmt(ctx, stmt)?;
if ctx.is_terminated() {
break;
}
self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
self.timeline_reset_start(ctx)?;
}
Ok(())
} else {
gen_block(self, ctx, stmts)
}
}
fn gen_call<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let result = gen_call(self, ctx, obj, fun, params)?;
// Deep parallel emits timeline end-update/timeline-reset after each function call
if self.parallel_mode == ParallelMode::Deep {
self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
self.timeline_reset_start(ctx)?;
}
Ok(result)
}
fn gen_with(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
let StmtKind::With { items, body, .. } = &stmt.node else {
unreachable!()
};
if items.len() == 1 && items[0].optional_vars.is_none() {
let item = &items[0];
// Behavior of parallel and sequential:
// Each function call (indirectly, can be inside a sequential block) within a parallel
// block will update the end variable to the maximum now_mu in the block.
// Each function call directly inside a parallel block will reset the timeline after
// execution. A parallel block within a sequential block (or not within any block) will
// set the timeline to the max now_mu within the block (and the outer max now_mu will also
// be updated).
//
// Implementation: We track the start and end separately.
// - If there is a start variable, it indicates that we are directly inside a
// parallel block and we have to reset the timeline after every function call.
// - If there is a end variable, it indicates that we are (indirectly) inside a
// parallel block, and we should update the max end value.
if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node {
if id == &"parallel".into() || id == &"legacy_parallel".into() {
let old_start = self.start.take();
let old_end = self.end.take();
let old_parallel_mode = self.parallel_mode;
let now = if let Some(old_start) = &old_start {
self.gen_expr(ctx, old_start)?
.unwrap()
.to_basic_value_enum(ctx, self, old_start.custom.unwrap())?
} else {
self.timeline.emit_now_mu(ctx)
};
// Emulate variable allocation, as we need to use the CodeGenContext
// HashMap to store our variable due to lifetime limitation
// Note: we should be able to store variables directly if generic
// associative type is used by limiting the lifetime of CodeGenerator to
// the LLVM Context.
// The name is guaranteed to be unique as users cannot use this as variable
// name.
self.start = old_start.clone().map_or_else(
|| {
let start = format!("with-{}-start", self.name_counter).into();
let start_expr = Located {
// location does not matter at this point
location: stmt.location,
node: ExprKind::Name { id: start, ctx: name_ctx.clone() },
custom: Some(ctx.primitives.int64),
};
let start = self
.gen_store_target(ctx, &start_expr, Some("start.addr"))?
.unwrap();
ctx.builder.build_store(start, now).unwrap();
Ok(Some(start_expr)) as Result<_, String>
},
|v| Ok(Some(v)),
)?;
let end = format!("with-{}-end", self.name_counter).into();
let end_expr = Located {
// location does not matter at this point
location: stmt.location,
node: ExprKind::Name { id: end, ctx: name_ctx.clone() },
custom: Some(ctx.primitives.int64),
};
let end = self
.gen_store_target(ctx, &end_expr, Some("end.addr"))?
.unwrap();
ctx.builder.build_store(end, now).unwrap();
self.end = Some(end_expr);
self.name_counter += 1;
self.parallel_mode = match id.to_string().as_str() {
"parallel" => ParallelMode::Deep,
"legacy_parallel" => ParallelMode::Legacy,
_ => unreachable!(),
};
self.gen_block(ctx, body.iter())?;
let current = ctx.builder.get_insert_block().unwrap();
// if the current block is terminated, move before the terminator
// we want to set the timeline before reaching the terminator
// TODO: This may be unsound if there are multiple exit paths in the
// block... e.g.
// if ...:
// return
// Perhaps we can fix this by using actual with block?
let reset_position = if let Some(terminator) = current.get_terminator() {
ctx.builder.position_before(&terminator);
true
} else {
false
};
// set duration
let end_expr = self.end.take().unwrap();
let end_val = self
.gen_expr(ctx, &end_expr)?
.unwrap()
.to_basic_value_enum(ctx, self, end_expr.custom.unwrap())?;
// inside a sequential block
if old_start.is_none() {
self.timeline.emit_at_mu(ctx, end_val);
}
// inside a parallel block, should update the outer max now_mu
self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?;
self.parallel_mode = old_parallel_mode;
self.end = old_end;
self.start = old_start;
if reset_position {
ctx.builder.position_at_end(current);
}
return Ok(());
} else if id == &"sequential".into() {
// For deep parallel, temporarily take away start to avoid function calls in
// the block from resetting the timeline.
// This does not affect legacy parallel, as the timeline will be reset after
// this block finishes execution.
let start = self.start.take();
self.gen_block(ctx, body.iter())?;
self.start = start;
// Reset the timeline when we are exiting the sequential block
// Legacy parallel does not need this, since it will be reset after codegen
// for this statement is completed
if self.parallel_mode == ParallelMode::Deep {
self.timeline_reset_start(ctx)?;
}
return Ok(());
}
}
}
// not parallel/sequential
gen_with(self, ctx, stmt)
}
}
fn gen_rpc_tag(
ctx: &mut CodeGenContext<'_, '_>,
ty: Type,
buffer: &mut Vec<u8>,
) -> Result<(), String> {
use nac3core::typecheck::typedef::TypeEnum::*;
let int32 = ctx.primitives.int32;
let int64 = ctx.primitives.int64;
let float = ctx.primitives.float;
let bool = ctx.primitives.bool;
let str = ctx.primitives.str;
let none = ctx.primitives.none;
if ctx.unifier.unioned(ty, int32) {
buffer.push(b'i');
} else if ctx.unifier.unioned(ty, int64) {
buffer.push(b'I');
} else if ctx.unifier.unioned(ty, float) {
buffer.push(b'f');
} else if ctx.unifier.unioned(ty, bool) {
buffer.push(b'b');
} else if ctx.unifier.unioned(ty, str) {
buffer.push(b's');
} else if ctx.unifier.unioned(ty, none) {
buffer.push(b'n');
} else {
let ty_enum = ctx.unifier.get_ty(ty);
match &*ty_enum {
TTuple { ty } => {
buffer.push(b't');
buffer.push(ty.len() as u8);
for ty in ty {
gen_rpc_tag(ctx, *ty, buffer)?;
}
}
TList { ty } => {
buffer.push(b'l');
gen_rpc_tag(ctx, *ty, buffer)?;
}
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
}
}
Ok(())
}
fn rpc_codegen_callback_fn<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let size_type = generator.get_size_type(ctx.ctx);
let int8 = ctx.ctx.i8_type();
let int32 = ctx.ctx.i32_type();
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1.0 as u64, false);
// -- setup rpc tags
let mut tag = Vec::new();
if obj.is_some() {
tag.push(b'O');
}
for arg in &fun.0.args {
gen_rpc_tag(ctx, arg.ty, &mut tag)?;
}
tag.push(b':');
gen_rpc_tag(ctx, fun.0.ret, &mut tag)?;
let mut hasher = DefaultHasher::new();
tag.hash(&mut hasher);
let hash = format!("{}", hasher.finish());
let tag_ptr = ctx
.module
.get_global(hash.as_str())
.unwrap_or_else(|| {
let tag_arr_ptr = ctx.module.add_global(
int8.array_type(tag.len() as u32),
None,
format!("tagptr{}", fun.1 .0).as_str(),
);
tag_arr_ptr.set_initializer(&int8.const_array(
&tag.iter().map(|v| int8.const_int(*v as u64, false)).collect::<Vec<_>>(),
));
tag_arr_ptr.set_linkage(Linkage::Private);
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
tag_ptr.set_linkage(Linkage::Private);
tag_ptr.set_initializer(&ctx.ctx.const_struct(
&[
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
size_type.const_int(tag.len() as u64, false).into(),
],
false,
));
tag_ptr
})
.as_pointer_value();
let arg_length = args.len() + usize::from(obj.is_some());
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
let args_ptr = ctx.builder
.build_array_alloca(
ptr_type,
ctx.ctx.i32_type().const_int(arg_length as u64, false),
"argptr",
)
.unwrap();
// -- rpc args handling
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
for (key, value) in args {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
}
// default value handling
for k in keys {
mapping.insert(
k.name,
ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()
);
}
// reorder the parameters
let mut real_params = fun
.0
.args
.iter()
.map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator, arg.ty))
.collect::<Result<Vec<_>, _>>()?;
if let Some(obj) = obj {
if let ValueEnum::Static(obj) = obj.1 {
real_params.insert(0, obj.get_const_obj(ctx, generator));
} else {
// should be an error here...
panic!("only host object is allowed");
}
}
for (i, arg) in real_params.iter().enumerate() {
let arg_slot = generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
ctx.builder.build_store(arg_slot, *arg).unwrap();
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
let arg_ptr = unsafe {
ctx.builder.build_gep(
args_ptr,
&[int32.const_int(i as u64, false)],
&format!("rpc.arg{i}"),
)
}.unwrap();
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
}
// call
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false,
),
None,
)
});
ctx.builder
.build_call(
rpc_send,
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
"rpc.send",
)
.unwrap();
// reclaim stack space used by arguments
call_stackrestore(ctx, stackptr);
// -- receive value:
// T result = {
// void *ret_ptr = alloca(sizeof(T));
// void *ptr = ret_ptr;
// loop: int size = rpc_recv(ptr);
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
// if(size) { ptr = alloca(size); goto loop; }
// else *(T*)ret_ptr
// }
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
});
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv");
return Ok(None);
}
let prehead_bb = ctx.builder.get_insert_block().unwrap();
let current_function = prehead_bb.get_parent().unwrap();
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret);
let need_load = !ret_ty.is_pointer_type();
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap();
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap();
phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx.builder
.build_int_compare(
inkwell::IntPredicate::EQ,
int32.const_zero(),
alloc_size,
"rpc.done",
)
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
if need_load {
call_stackrestore(ctx, stackptr);
}
Ok(Some(result))
}
pub fn attributes_writeback(
ctx: &mut CodeGenContext<'_, '_>,
generator: &mut dyn CodeGenerator,
inner_resolver: &InnerResolver,
host_attributes: &PyObject,
) -> Result<(), String> {
Python::with_gil(|py| -> PyResult<Result<(), String>> {
let host_attributes: &PyList = host_attributes.downcast(py)?;
let top_levels = ctx.top_level.definitions.read();
let globals = inner_resolver.global_value_ids.read();
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let mut values = Vec::new();
let mut scratch_buffer = Vec::new();
for val in (*globals).values() {
let val = val.as_ref(py);
let ty = inner_resolver.get_obj_type(py, val, &mut ctx.unifier, &top_levels, &ctx.primitives)?;
if let Err(ty) = ty {
return Ok(Err(ty))
}
let ty = ty.unwrap();
match &*ctx.unifier.get_ty(ty) {
TypeEnum::TObj { fields, obj_id, .. }
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
{
// we only care about primitive attributes
// for non-primitive attributes, they should be in another global
let mut attributes = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
for (name, (field_ty, is_mutable)) in fields {
if !is_mutable {
continue
}
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
attributes.push(name.to_string());
let index = ctx.get_attr_index(ty, *name);
values.push((*field_ty, ctx.build_gep_and_load(
obj.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)], None)));
}
}
if !attributes.is_empty() {
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
pydict.set_item("fields", attributes)?;
host_attributes.append(pydict)?;
}
},
TypeEnum::TList { ty: elem_ty } => {
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() {
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
host_attributes.append(pydict)?;
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
}
},
_ => {}
}
}
let fun = FunSignature {
args: values.iter().enumerate().map(|(i, (ty, _))| FuncArg {
name: i.to_string().into(),
ty: *ty,
default_value: None
}).collect(),
ret: ctx.primitives.none,
vars: VarMap::default()
};
let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, PRIMITIVE_DEF_IDS.int32), args, generator) {
return Ok(Err(e));
}
Ok(Ok(()))
}).unwrap()?;
Ok(())
}
pub fn rpc_codegen_callback() -> Arc<GenCall> {
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| {
rpc_codegen_callback_fn(ctx, obj, fun, args, generator)
})))
}

View File

@ -1,56 +0,0 @@
/* Force ld to make the ELF header as loadable. */
PHDRS
{
headers PT_LOAD FILEHDR PHDRS ;
text PT_LOAD ;
data PT_LOAD ;
dynamic PT_DYNAMIC ;
eh_frame PT_GNU_EH_FRAME ;
}
SECTIONS
{
/* Push back .text section enough so that ld.lld not complain */
. = SIZEOF_HEADERS;
.text :
{
*(.text .text.*)
} : text
.rodata :
{
*(.rodata .rodata.*)
}
.eh_frame :
{
KEEP(*(.eh_frame))
} : text
.eh_frame_hdr :
{
KEEP(*(.eh_frame_hdr))
} : text : eh_frame
.data :
{
*(.data)
} : data
.dynamic :
{
*(.dynamic)
} : data : dynamic
.bss (NOLOAD) : ALIGN(4)
{
__bss_start = .;
*(.sbss .sbss.* .bss .bss.*);
. = ALIGN(4);
_end = .;
}
. = ALIGN(0x1000);
_sstack_guard = .;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,312 +0,0 @@
use inkwell::{values::{BasicValueEnum, CallSiteValue}, AddressSpace, AtomicOrdering};
use itertools::Either;
use nac3core::codegen::CodeGenContext;
/// Functions for manipulating the timeline.
pub trait TimeFns {
/// Emits LLVM IR for `now_mu`.
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>;
/// Emits LLVM IR for `at_mu`.
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>);
/// Emits LLVM IR for `delay_mu`.
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>);
}
pub struct NowPinningTimeFns64 {}
// For FPGA design reasons, on VexRiscv with 64-bit data bus, the "now" CSR is split into two 32-bit
// values that are each padded to 64-bits.
impl TimeFns for NowPinningTimeFns64 {
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}.unwrap();
let now_hi = ctx.builder.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx.builder.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi = ctx.builder
.build_left_shift(zext_hi, i64_type.const_int(32, false), "")
.unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap()
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
let time = t.into_int_value();
let time_hi = ctx.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type,
"",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}.unwrap();
let now_hi = ctx.builder.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx.builder.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let dt = dt.into_int_value();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi = ctx.builder
.build_left_shift(zext_hi, i64_type.const_int(32, false), "")
.unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx.builder
.build_int_truncate(
ctx.builder.build_right_shift(
time,
i64_type.const_int(32, false),
false,
"",
).unwrap(),
i32_type,
"time.hi",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
}
pub static NOW_PINNING_TIME_FNS_64: NowPinningTimeFns64 = NowPinningTimeFns64 {};
pub struct NowPinningTimeFns {}
impl TimeFns for NowPinningTimeFns {
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let i64_type = ctx.ctx.i64_type();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i64_32 = i64_type.const_int(32, false);
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
ctx.builder.build_or(now_lo, now_hi, "now_mu")
.map(Into::into)
.unwrap()
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
let time = t.into_int_value();
let time_hi = ctx.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(),
i32_type,
"time.hi",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc").unwrap();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder
.build_load(now.as_pointer_value(), "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let dt = dt.into_int_value();
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type,
"now_trunc",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now_hiptr = ctx.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
}
pub static NOW_PINNING_TIME_FNS: NowPinningTimeFns = NowPinningTimeFns {};
pub struct ExternTimeFns {}
impl TimeFns for ExternTimeFns {
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| {
ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)
});
ctx.builder.build_call(now_mu, &[], "now_mu")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap()
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
let at_mu = ctx.module.get_function("at_mu").unwrap_or_else(|| {
ctx.module.add_function(
"at_mu",
ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false),
None,
)
});
ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap();
}
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| {
ctx.module.add_function(
"delay_mu",
ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false),
None,
)
});
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu").unwrap();
}
}
pub static EXTERN_TIME_FNS: ExternTimeFns = ExternTimeFns {};

View File

@ -1,16 +0,0 @@
[package]
name = "nac3ast"
version = "0.1.0"
authors = ["RustPython Team", "M-Labs"]
edition = "2021"
[features]
default = ["constant-optimization", "fold"]
constant-optimization = ["fold"]
fold = []
[dependencies]
lazy_static = "1.4"
parking_lot = "0.12"
string-interner = "0.15"
fxhash = "0.2"

View File

@ -1,127 +0,0 @@
-- ASDL's 4 builtin types are:
-- identifier, int, string, constant
module Python
{
mod = Module(stmt* body, type_ignore* type_ignores)
| Interactive(stmt* body)
| Expression(expr body)
| FunctionType(expr* argtypes, expr returns)
stmt = FunctionDef(identifier name, arguments args,
stmt* body, expr* decorator_list, expr? returns,
string? type_comment, identifier* config_comment)
| AsyncFunctionDef(identifier name, arguments args,
stmt* body, expr* decorator_list, expr? returns,
string? type_comment, identifier* config_comment)
| ClassDef(identifier name,
expr* bases,
keyword* keywords,
stmt* body,
expr* decorator_list, identifier* config_comment)
| Return(expr? value, identifier* config_comment)
| Delete(expr* targets, identifier* config_comment)
| Assign(expr* targets, expr value, string? type_comment, identifier* config_comment)
| AugAssign(expr target, operator op, expr value, identifier* config_comment)
-- 'simple' indicates that we annotate simple name without parens
| AnnAssign(expr target, expr annotation, expr? value, bool simple, identifier* config_comment)
-- use 'orelse' because else is a keyword in target languages
| For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment, identifier* config_comment)
| AsyncFor(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment, identifier* config_comment)
| While(expr test, stmt* body, stmt* orelse, identifier* config_comment)
| If(expr test, stmt* body, stmt* orelse, identifier* config_comment)
| With(withitem* items, stmt* body, string? type_comment, identifier* config_comment)
| AsyncWith(withitem* items, stmt* body, string? type_comment, identifier* config_comment)
| Raise(expr? exc, expr? cause, identifier* config_comment)
| Try(stmt* body, excepthandler* handlers, stmt* orelse, stmt* finalbody, identifier* config_comment)
| Assert(expr test, expr? msg, identifier* config_comment)
| Import(alias* names, identifier* config_comment)
| ImportFrom(identifier? module, alias* names, int level, identifier* config_comment)
| Global(identifier* names, identifier* config_comment)
| Nonlocal(identifier* names, identifier* config_comment)
| Expr(expr value, identifier* config_comment)
| Pass(identifier* config_comment)
| Break(identifier* config_comment)
| Continue(identifier* config_comment)
-- col_offset is the byte offset in the utf8 string the parser uses
attributes (int lineno, int col_offset, int? end_lineno, int? end_col_offset)
-- BoolOp() can use left & right?
expr = BoolOp(boolop op, expr* values)
| NamedExpr(expr target, expr value)
| BinOp(expr left, operator op, expr right)
| UnaryOp(unaryop op, expr operand)
| Lambda(arguments args, expr body)
| IfExp(expr test, expr body, expr orelse)
| Dict(expr?* keys, expr* values)
| Set(expr* elts)
| ListComp(expr elt, comprehension* generators)
| SetComp(expr elt, comprehension* generators)
| DictComp(expr key, expr value, comprehension* generators)
| GeneratorExp(expr elt, comprehension* generators)
-- the grammar constrains where yield expressions can occur
| Await(expr value)
| Yield(expr? value)
| YieldFrom(expr value)
-- need sequences for compare to distinguish between
-- x < 4 < 3 and (x < 4) < 3
| Compare(expr left, cmpop* ops, expr* comparators)
| Call(expr func, expr* args, keyword* keywords)
| FormattedValue(expr value, conversion_flag? conversion, expr? format_spec)
| JoinedStr(expr* values)
| Constant(constant value, string? kind)
-- the following expression can appear in assignment context
| Attribute(expr value, identifier attr, expr_context ctx)
| Subscript(expr value, expr slice, expr_context ctx)
| Starred(expr value, expr_context ctx)
| Name(identifier id, expr_context ctx)
| List(expr* elts, expr_context ctx)
| Tuple(expr* elts, expr_context ctx)
-- can appear only in Subscript
| Slice(expr? lower, expr? upper, expr? step)
-- col_offset is the byte offset in the utf8 string the parser uses
attributes (int lineno, int col_offset, int? end_lineno, int? end_col_offset)
expr_context = Load | Store | Del
boolop = And | Or
operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift
| RShift | BitOr | BitXor | BitAnd | FloorDiv
unaryop = Invert | Not | UAdd | USub
cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
comprehension = (expr target, expr iter, expr* ifs, bool is_async)
excepthandler = ExceptHandler(expr? type, identifier? name, stmt* body)
attributes (int lineno, int col_offset, int? end_lineno, int? end_col_offset)
arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs,
expr?* kw_defaults, arg? kwarg, expr* defaults)
arg = (identifier arg, expr? annotation, string? type_comment)
attributes (int lineno, int col_offset, int? end_lineno, int? end_col_offset)
-- keyword arguments supplied to call (NULL identifier for **kwargs)
keyword = (identifier? arg, expr value)
attributes (int lineno, int col_offset, int? end_lineno, int? end_col_offset)
-- import name with optional 'as' alias.
alias = (identifier name, identifier? asname)
withitem = (expr context_expr, expr? optional_vars)
type_ignore = TypeIgnore(int lineno, string tag)
}

View File

@ -1,385 +0,0 @@
#-------------------------------------------------------------------------------
# Parser for ASDL [1] definition files. Reads in an ASDL description and parses
# it into an AST that describes it.
#
# The EBNF we're parsing here: Figure 1 of the paper [1]. Extended to support
# modules and attributes after a product. Words starting with Capital letters
# are terminals. Literal tokens are in "double quotes". Others are
# non-terminals. Id is either TokenId or ConstructorId.
#
# module ::= "module" Id "{" [definitions] "}"
# definitions ::= { TypeId "=" type }
# type ::= product | sum
# product ::= fields ["attributes" fields]
# fields ::= "(" { field, "," } field ")"
# field ::= TypeId ["?" | "*"] [Id]
# sum ::= constructor { "|" constructor } ["attributes" fields]
# constructor ::= ConstructorId [fields]
#
# [1] "The Zephyr Abstract Syntax Description Language" by Wang, et. al. See
# http://asdl.sourceforge.net/
#-------------------------------------------------------------------------------
from collections import namedtuple
import re
__all__ = [
'builtin_types', 'parse', 'AST', 'Module', 'Type', 'Constructor',
'Field', 'Sum', 'Product', 'VisitorBase', 'Check', 'check']
# The following classes define nodes into which the ASDL description is parsed.
# Note: this is a "meta-AST". ASDL files (such as Python.asdl) describe the AST
# structure used by a programming language. But ASDL files themselves need to be
# parsed. This module parses ASDL files and uses a simple AST to represent them.
# See the EBNF at the top of the file to understand the logical connection
# between the various node types.
builtin_types = {'identifier', 'string', 'int', 'constant', 'bool', 'conversion_flag'}
class AST:
def __repr__(self):
raise NotImplementedError
class Module(AST):
def __init__(self, name, dfns):
self.name = name
self.dfns = dfns
self.types = {type.name: type.value for type in dfns}
def __repr__(self):
return 'Module({0.name}, {0.dfns})'.format(self)
class Type(AST):
def __init__(self, name, value):
self.name = name
self.value = value
def __repr__(self):
return 'Type({0.name}, {0.value})'.format(self)
class Constructor(AST):
def __init__(self, name, fields=None):
self.name = name
self.fields = fields or []
def __repr__(self):
return 'Constructor({0.name}, {0.fields})'.format(self)
class Field(AST):
def __init__(self, type, name=None, seq=False, opt=False):
self.type = type
self.name = name
self.seq = seq
self.opt = opt
def __str__(self):
if self.seq:
extra = "*"
elif self.opt:
extra = "?"
else:
extra = ""
return "{}{} {}".format(self.type, extra, self.name)
def __repr__(self):
if self.seq:
extra = ", seq=True"
elif self.opt:
extra = ", opt=True"
else:
extra = ""
if self.name is None:
return 'Field({0.type}{1})'.format(self, extra)
else:
return 'Field({0.type}, {0.name}{1})'.format(self, extra)
class Sum(AST):
def __init__(self, types, attributes=None):
self.types = types
self.attributes = attributes or []
def __repr__(self):
if self.attributes:
return 'Sum({0.types}, {0.attributes})'.format(self)
else:
return 'Sum({0.types})'.format(self)
class Product(AST):
def __init__(self, fields, attributes=None):
self.fields = fields
self.attributes = attributes or []
def __repr__(self):
if self.attributes:
return 'Product({0.fields}, {0.attributes})'.format(self)
else:
return 'Product({0.fields})'.format(self)
# A generic visitor for the meta-AST that describes ASDL. This can be used by
# emitters. Note that this visitor does not provide a generic visit method, so a
# subclass needs to define visit methods from visitModule to as deep as the
# interesting node.
# We also define a Check visitor that makes sure the parsed ASDL is well-formed.
class VisitorBase(object):
"""Generic tree visitor for ASTs."""
def __init__(self):
self.cache = {}
def visit(self, obj, *args):
klass = obj.__class__
meth = self.cache.get(klass)
if meth is None:
methname = "visit" + klass.__name__
meth = getattr(self, methname, None)
self.cache[klass] = meth
if meth:
try:
meth(obj, *args)
except Exception as e:
print("Error visiting %r: %s" % (obj, e))
raise
class Check(VisitorBase):
"""A visitor that checks a parsed ASDL tree for correctness.
Errors are printed and accumulated.
"""
def __init__(self):
super(Check, self).__init__()
self.cons = {}
self.errors = 0
self.types = {}
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type):
self.visit(type.value, str(type.name))
def visitSum(self, sum, name):
for t in sum.types:
self.visit(t, name)
def visitConstructor(self, cons, name):
key = str(cons.name)
conflict = self.cons.get(key)
if conflict is None:
self.cons[key] = name
else:
print('Redefinition of constructor {}'.format(key))
print('Defined in {} and {}'.format(conflict, name))
self.errors += 1
for f in cons.fields:
self.visit(f, key)
def visitField(self, field, name):
key = str(field.type)
l = self.types.setdefault(key, [])
l.append(name)
def visitProduct(self, prod, name):
for f in prod.fields:
self.visit(f, name)
def check(mod):
"""Check the parsed ASDL tree for correctness.
Return True if success. For failure, the errors are printed out and False
is returned.
"""
v = Check()
v.visit(mod)
for t in v.types:
if t not in mod.types and not t in builtin_types:
v.errors += 1
uses = ", ".join(v.types[t])
print('Undefined type {}, used in {}'.format(t, uses))
return not v.errors
# The ASDL parser itself comes next. The only interesting external interface
# here is the top-level parse function.
def parse(filename):
"""Parse ASDL from the given file and return a Module node describing it."""
with open(filename) as f:
parser = ASDLParser()
return parser.parse(f.read())
# Types for describing tokens in an ASDL specification.
class TokenKind:
"""TokenKind is provides a scope for enumerated token kinds."""
(ConstructorId, TypeId, Equals, Comma, Question, Pipe, Asterisk,
LParen, RParen, LBrace, RBrace) = range(11)
operator_table = {
'=': Equals, ',': Comma, '?': Question, '|': Pipe, '(': LParen,
')': RParen, '*': Asterisk, '{': LBrace, '}': RBrace}
Token = namedtuple('Token', 'kind value lineno')
class ASDLSyntaxError(Exception):
def __init__(self, msg, lineno=None):
self.msg = msg
self.lineno = lineno or '<unknown>'
def __str__(self):
return 'Syntax error on line {0.lineno}: {0.msg}'.format(self)
def tokenize_asdl(buf):
"""Tokenize the given buffer. Yield Token objects."""
for lineno, line in enumerate(buf.splitlines(), 1):
for m in re.finditer(r'\s*(\w+|--.*|.)', line.strip()):
c = m.group(1)
if c[0].isalpha():
# Some kind of identifier
if c[0].isupper():
yield Token(TokenKind.ConstructorId, c, lineno)
else:
yield Token(TokenKind.TypeId, c, lineno)
elif c[:2] == '--':
# Comment
break
else:
# Operators
try:
op_kind = TokenKind.operator_table[c]
except KeyError:
raise ASDLSyntaxError('Invalid operator %s' % c, lineno)
yield Token(op_kind, c, lineno)
class ASDLParser:
"""Parser for ASDL files.
Create, then call the parse method on a buffer containing ASDL.
This is a simple recursive descent parser that uses tokenize_asdl for the
lexing.
"""
def __init__(self):
self._tokenizer = None
self.cur_token = None
def parse(self, buf):
"""Parse the ASDL in the buffer and return an AST with a Module root.
"""
self._tokenizer = tokenize_asdl(buf)
self._advance()
return self._parse_module()
def _parse_module(self):
if self._at_keyword('module'):
self._advance()
else:
raise ASDLSyntaxError(
'Expected "module" (found {})'.format(self.cur_token.value),
self.cur_token.lineno)
name = self._match(self._id_kinds)
self._match(TokenKind.LBrace)
defs = self._parse_definitions()
self._match(TokenKind.RBrace)
return Module(name, defs)
def _parse_definitions(self):
defs = []
while self.cur_token.kind == TokenKind.TypeId:
typename = self._advance()
self._match(TokenKind.Equals)
type = self._parse_type()
defs.append(Type(typename, type))
return defs
def _parse_type(self):
if self.cur_token.kind == TokenKind.LParen:
# If we see a (, it's a product
return self._parse_product()
else:
# Otherwise it's a sum. Look for ConstructorId
sumlist = [Constructor(self._match(TokenKind.ConstructorId),
self._parse_optional_fields())]
while self.cur_token.kind == TokenKind.Pipe:
# More constructors
self._advance()
sumlist.append(Constructor(
self._match(TokenKind.ConstructorId),
self._parse_optional_fields()))
return Sum(sumlist, self._parse_optional_attributes())
def _parse_product(self):
return Product(self._parse_fields(), self._parse_optional_attributes())
def _parse_fields(self):
fields = []
self._match(TokenKind.LParen)
while self.cur_token.kind == TokenKind.TypeId:
typename = self._advance()
is_seq, is_opt = self._parse_optional_field_quantifier()
id = (self._advance() if self.cur_token.kind in self._id_kinds
else None)
fields.append(Field(typename, id, seq=is_seq, opt=is_opt))
if self.cur_token.kind == TokenKind.RParen:
break
elif self.cur_token.kind == TokenKind.Comma:
self._advance()
self._match(TokenKind.RParen)
return fields
def _parse_optional_fields(self):
if self.cur_token.kind == TokenKind.LParen:
return self._parse_fields()
else:
return None
def _parse_optional_attributes(self):
if self._at_keyword('attributes'):
self._advance()
return self._parse_fields()
else:
return None
def _parse_optional_field_quantifier(self):
is_seq, is_opt = False, False
if self.cur_token.kind == TokenKind.Question:
is_opt = True
self._advance()
if self.cur_token.kind == TokenKind.Asterisk:
is_seq = True
self._advance()
return is_seq, is_opt
def _advance(self):
""" Return the value of the current token and read the next one into
self.cur_token.
"""
cur_val = None if self.cur_token is None else self.cur_token.value
try:
self.cur_token = next(self._tokenizer)
except StopIteration:
self.cur_token = None
return cur_val
_id_kinds = (TokenKind.ConstructorId, TokenKind.TypeId)
def _match(self, kind):
"""The 'match' primitive of RD parsers.
* Verifies that the current token is of the given kind (kind can
be a tuple, in which the kind must match one of its members).
* Returns the value of the current token
* Reads in the next token
"""
if (isinstance(kind, tuple) and self.cur_token.kind in kind or
self.cur_token.kind == kind
):
value = self.cur_token.value
self._advance()
return value
else:
raise ASDLSyntaxError(
'Unmatched {} (found {})'.format(kind, self.cur_token.kind),
self.cur_token.lineno)
def _at_keyword(self, keyword):
return (self.cur_token.kind == TokenKind.TypeId and
self.cur_token.value == keyword)

View File

@ -1,609 +0,0 @@
#! /usr/bin/env python
"""Generate Rust code from an ASDL description."""
import os
import sys
import textwrap
import json
from argparse import ArgumentParser
from pathlib import Path
import asdl
TABSIZE = 4
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"
builtin_type_mapping = {
'identifier': 'Ident',
'string': 'String',
'int': 'usize',
'constant': 'Constant',
'bool': 'bool',
'conversion_flag': 'ConversionFlag',
}
assert builtin_type_mapping.keys() == asdl.builtin_types
def get_rust_type(name):
"""Return a string for the C name of the type.
This function special cases the default types provided by asdl.
"""
if name in asdl.builtin_types:
return builtin_type_mapping[name]
else:
return "".join(part.capitalize() for part in name.split("_"))
def is_simple(sum):
"""Return True if a sum is a simple.
A sum is simple if its types have no fields, e.g.
unaryop = Invert | Not | UAdd | USub
"""
for t in sum.types:
if t.fields:
return False
return True
def asdl_of(name, obj):
if isinstance(obj, asdl.Product) or isinstance(obj, asdl.Constructor):
fields = ", ".join(map(str, obj.fields))
if fields:
fields = "({})".format(fields)
return "{}{}".format(name, fields)
else:
if is_simple(obj):
types = " | ".join(type.name for type in obj.types)
else:
sep = "\n{}| ".format(" " * (len(name) + 1))
types = sep.join(
asdl_of(type.name, type) for type in obj.types
)
return "{} = {}".format(name, types)
class EmitVisitor(asdl.VisitorBase):
"""Visit that emits lines"""
def __init__(self, file):
self.file = file
self.identifiers = set()
super(EmitVisitor, self).__init__()
def emit_identifier(self, name):
name = str(name)
if name in self.identifiers:
return
self.emit("_Py_IDENTIFIER(%s);" % name, 0)
self.identifiers.add(name)
def emit(self, line, depth):
if line:
line = (" " * TABSIZE * depth) + line
self.file.write(line + "\n")
class TypeInfo:
def __init__(self, name):
self.name = name
self.has_userdata = None
self.children = set()
self.boxed = False
def __repr__(self):
return f"<TypeInfo: {self.name}>"
def determine_userdata(self, typeinfo, stack):
if self.name in stack:
return None
stack.add(self.name)
for child, child_seq in self.children:
if child in asdl.builtin_types:
continue
childinfo = typeinfo[child]
child_has_userdata = childinfo.determine_userdata(typeinfo, stack)
if self.has_userdata is None and child_has_userdata is True:
self.has_userdata = True
stack.remove(self.name)
return self.has_userdata
class FindUserdataTypesVisitor(asdl.VisitorBase):
def __init__(self, typeinfo):
self.typeinfo = typeinfo
super().__init__()
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
stack = set()
for info in self.typeinfo.values():
info.determine_userdata(self.typeinfo, stack)
def visitType(self, type):
self.typeinfo[type.name] = TypeInfo(type.name)
self.visit(type.value, type.name)
def visitSum(self, sum, name):
info = self.typeinfo[name]
if is_simple(sum):
info.has_userdata = False
else:
if len(sum.types) > 1:
info.boxed = True
if sum.attributes:
# attributes means Located, which has the `custom: U` field
info.has_userdata = True
for variant in sum.types:
self.add_children(name, variant.fields)
def visitProduct(self, product, name):
info = self.typeinfo[name]
if product.attributes:
# attributes means Located, which has the `custom: U` field
info.has_userdata = True
if len(product.fields) > 2:
info.boxed = True
self.add_children(name, product.fields)
def add_children(self, name, fields):
self.typeinfo[name].children.update((field.type, field.seq) for field in fields)
def rust_field(field_name):
if field_name == 'type':
return 'type_'
else:
return field_name
class TypeInfoEmitVisitor(EmitVisitor):
def __init__(self, file, typeinfo):
self.typeinfo = typeinfo
super().__init__(file)
def has_userdata(self, typ):
return self.typeinfo[typ].has_userdata
def get_generics(self, typ, *generics):
if self.has_userdata(typ):
return [f"<{g}>" for g in generics]
else:
return ["" for g in generics]
class StructVisitor(TypeInfoEmitVisitor):
"""Visitor to generate typedefs for AST."""
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
if is_simple(sum):
self.simple_sum(sum, name, depth)
else:
self.sum_with_constructors(sum, name, depth)
def emit_attrs(self, depth):
self.emit("#[derive(Clone, Debug, PartialEq)]", depth)
def simple_sum(self, sum, name, depth):
rustname = get_rust_type(name)
self.emit_attrs(depth)
self.emit(f"pub enum {rustname} {{", depth)
for variant in sum.types:
self.emit(f"{variant.name},", depth + 1)
self.emit("}", depth)
self.emit("", depth)
def sum_with_constructors(self, sum, name, depth):
typeinfo = self.typeinfo[name]
generics, generics_applied = self.get_generics(name, "U = ()", "U")
enumname = rustname = get_rust_type(name)
# all the attributes right now are for location, so if it has attrs we
# can just wrap it in Located<>
if sum.attributes:
enumname = rustname + "Kind"
self.emit_attrs(depth)
self.emit(f"pub enum {enumname}{generics} {{", depth)
for t in sum.types:
self.visit(t, typeinfo, depth + 1)
self.emit("}", depth)
if sum.attributes:
self.emit(f"pub type {rustname}<U = ()> = Located<{enumname}{generics_applied}, U>;", depth)
self.emit("", depth)
def visitConstructor(self, cons, parent, depth):
if cons.fields:
self.emit(f"{cons.name} {{", depth)
for f in cons.fields:
self.visit(f, parent, "", depth + 1)
self.emit("},", depth)
else:
self.emit(f"{cons.name},", depth)
def visitField(self, field, parent, vis, depth):
typ = get_rust_type(field.type)
fieldtype = self.typeinfo.get(field.type)
if fieldtype and fieldtype.has_userdata:
typ = f"{typ}<U>"
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
if fieldtype and fieldtype.boxed and (not field.seq or field.opt):
typ = f"Box<{typ}>"
if field.opt:
typ = f"Option<{typ}>"
if field.seq:
typ = f"Vec<{typ}>"
name = rust_field(field.name)
self.emit(f"{vis}{name}: {typ},", depth)
def visitProduct(self, product, name, depth):
typeinfo = self.typeinfo[name]
generics, generics_applied = self.get_generics(name, "U = ()", "U")
dataname = rustname = get_rust_type(name)
if product.attributes:
dataname = rustname + "Data"
self.emit_attrs(depth)
self.emit(f"pub struct {dataname}{generics} {{", depth)
for f in product.fields:
self.visit(f, typeinfo, "pub ", depth + 1)
self.emit("}", depth)
if product.attributes:
# attributes should just be location info
self.emit(f"pub type {rustname}<U = ()> = Located<{dataname}{generics_applied}, U>;", depth);
self.emit("", depth)
class FoldTraitDefVisitor(TypeInfoEmitVisitor):
def visitModule(self, mod, depth):
self.emit("pub trait Fold<U> {", depth)
self.emit("type TargetU;", depth + 1)
self.emit("type Error;", depth + 1)
self.emit("fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;", depth + 2)
for dfn in mod.dfns:
self.visit(dfn, depth + 2)
self.emit("}", depth)
def visitType(self, type, depth):
name = type.name
apply_u, apply_target_u = self.get_generics(name, "U", "Self::TargetU")
enumname = get_rust_type(name)
self.emit(f"fn fold_{name}(&mut self, node: {enumname}{apply_u}) -> Result<{enumname}{apply_target_u}, Self::Error> {{", depth)
self.emit(f"fold_{name}(self, node)", depth + 1)
self.emit("}", depth)
class FoldImplVisitor(TypeInfoEmitVisitor):
def visitModule(self, mod, depth):
self.emit("fn fold_located<U, F: Fold<U> + ?Sized, T, MT>(folder: &mut F, node: Located<T, U>, f: impl FnOnce(&mut F, T) -> Result<MT, F::Error>) -> Result<Located<MT, F::TargetU>, F::Error> {", depth)
self.emit("Ok(Located { custom: folder.map_user(node.custom)?, location: node.location, node: f(folder, node.node)? })", depth + 1)
self.emit("}", depth)
for dfn in mod.dfns:
self.visit(dfn, depth)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
apply_t, apply_u, apply_target_u = self.get_generics(name, "T", "U", "F::TargetU")
enumname = get_rust_type(name)
is_located = bool(sum.attributes)
self.emit(f"impl<T, U> Foldable<T, U> for {enumname}{apply_t} {{", depth)
self.emit(f"type Mapped = {enumname}{apply_u};", depth + 1)
self.emit("fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {", depth + 1)
self.emit(f"folder.fold_{name}(self)", depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
self.emit(f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {enumname}{apply_u}) -> Result<{enumname}{apply_target_u}, F::Error> {{", depth)
if is_located:
self.emit("fold_located(folder, node, |folder, node| {", depth)
enumname += "Kind"
self.emit("match node {", depth + 1)
for cons in sum.types:
fields_pattern = self.make_pattern(cons.fields)
self.emit(f"{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth + 2)
self.gen_construction(f"{enumname}::{cons.name}", cons.fields, depth + 3)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
if is_located:
self.emit("})", depth)
self.emit("}", depth)
def visitProduct(self, product, name, depth):
apply_t, apply_u, apply_target_u = self.get_generics(name, "T", "U", "F::TargetU")
structname = get_rust_type(name)
is_located = bool(product.attributes)
self.emit(f"impl<T, U> Foldable<T, U> for {structname}{apply_t} {{", depth)
self.emit(f"type Mapped = {structname}{apply_u};", depth + 1)
self.emit("fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {", depth + 1)
self.emit(f"folder.fold_{name}(self)", depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
self.emit(f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {structname}{apply_u}) -> Result<{structname}{apply_target_u}, F::Error> {{", depth)
if is_located:
self.emit("fold_located(folder, node, |folder, node| {", depth)
structname += "Data"
fields_pattern = self.make_pattern(product.fields)
self.emit(f"let {structname} {{ {fields_pattern} }} = node;", depth + 1)
self.gen_construction(structname, product.fields, depth + 1)
if is_located:
self.emit("})", depth)
self.emit("}", depth)
def make_pattern(self, fields):
return ",".join(rust_field(f.name) for f in fields)
def gen_construction(self, cons_path, fields, depth):
self.emit(f"Ok({cons_path} {{", depth)
for field in fields:
name = rust_field(field.name)
self.emit(f"{name}: Foldable::fold({name}, folder)?,", depth + 1)
self.emit("})", depth)
class FoldModuleVisitor(TypeInfoEmitVisitor):
def visitModule(self, mod):
depth = 0
self.emit('#[cfg(feature = "fold")]', depth)
self.emit("pub mod fold {", depth)
self.emit("use super::*;", depth + 1)
self.emit("use crate::fold_helpers::Foldable;", depth + 1)
FoldTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
FoldImplVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
self.emit("}", depth)
class ClassDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
for cons in sum.types:
self.visit(cons, sum.attributes, depth)
def visitConstructor(self, cons, attrs, depth):
self.gen_classdef(cons.name, cons.fields, attrs, depth)
def visitProduct(self, product, name, depth):
self.gen_classdef(name, product.fields, product.attributes, depth)
def gen_classdef(self, name, fields, attrs, depth):
structname = "Node" + name
self.emit(f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "AstNode")]', depth)
self.emit(f"struct {structname};", depth)
self.emit("#[pyimpl(flags(HAS_DICT, BASETYPE))]", depth)
self.emit(f"impl {structname} {{", depth)
self.emit(f"#[extend_class]", depth + 1)
self.emit("fn extend_class_with_fields(ctx: &PyContext, class: &PyTypeRef) {", depth + 1)
fields = ",".join(f"ctx.new_str({json.dumps(f.name)})" for f in fields)
self.emit(f'class.set_str_attr("_fields", ctx.new_list(vec![{fields}]));', depth + 2)
attrs = ",".join(f"ctx.new_str({json.dumps(attr.name)})" for attr in attrs)
self.emit(f'class.set_str_attr("_attributes", ctx.new_list(vec![{attrs}]));', depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
class ExtendModuleVisitor(EmitVisitor):
def visitModule(self, mod):
depth = 0
self.emit("pub fn extend_module_nodes(vm: &VirtualMachine, module: &PyObjectRef) {", depth)
self.emit("extend_module!(vm, module, {", depth + 1)
for dfn in mod.dfns:
self.visit(dfn, depth + 2)
self.emit("})", depth + 1)
self.emit("}", depth)
def visitType(self, type, depth):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
for cons in sum.types:
self.visit(cons, depth)
def visitConstructor(self, cons, depth):
self.gen_extension(cons.name, depth)
def visitProduct(self, product, name, depth):
self.gen_extension(name, depth)
def gen_extension(self, name, depth):
self.emit(f"{json.dumps(name)} => Node{name}::make_class(&vm.ctx),", depth)
class TraitImplVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
enumname = get_rust_type(name)
if sum.attributes:
enumname += "Kind"
self.emit(f"impl NamedNode for ast::{enumname} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::{enumname} {{", depth)
self.emit("fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1)
self.emit("match self {", depth + 2)
for variant in sum.types:
self.constructor_to_object(variant, enumname, depth + 3)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
self.emit("fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {", depth + 1)
self.gen_sum_fromobj(sum, name, enumname, depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
def constructor_to_object(self, cons, enumname, depth):
fields_pattern = self.make_pattern(cons.fields)
self.emit(f"ast::{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth)
self.make_node(cons.name, cons.fields, depth + 1)
self.emit("}", depth)
def visitProduct(self, product, name, depth):
structname = get_rust_type(name)
if product.attributes:
structname += "Data"
self.emit(f"impl NamedNode for ast::{structname} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::{structname} {{", depth)
self.emit("fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1)
fields_pattern = self.make_pattern(product.fields)
self.emit(f"let ast::{structname} {{ {fields_pattern} }} = self;", depth + 2)
self.make_node(name, product.fields, depth + 2)
self.emit("}", depth + 1)
self.emit("fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {", depth + 1)
self.gen_product_fromobj(product, name, structname, depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
def make_node(self, variant, fields, depth):
lines = []
self.emit(f"let _node = AstNode.into_ref_with_type(_vm, Node{variant}::static_type().clone()).unwrap();", depth)
if fields:
self.emit("let _dict = _node.as_object().dict().unwrap();", depth)
for f in fields:
self.emit(f"_dict.set_item({json.dumps(f.name)}, {rust_field(f.name)}.ast_to_object(_vm), _vm).unwrap();", depth)
self.emit("_node.into_object()", depth)
def make_pattern(self, fields):
return ",".join(rust_field(f.name) for f in fields)
def gen_sum_fromobj(self, sum, sumname, enumname, depth):
if sum.attributes:
self.extract_location(sumname, depth)
self.emit("let _cls = _object.class();", depth)
self.emit("Ok(", depth)
for cons in sum.types:
self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth)
self.gen_construction(f"{enumname}::{cons.name}", cons, sumname, depth + 1)
self.emit("} else", depth)
self.emit("{", depth)
msg = f'format!("expected some sort of {sumname}, but got {{}}",_vm.to_repr(&_object)?)'
self.emit(f"return Err(_vm.new_type_error({msg}));", depth + 1)
self.emit("})", depth)
def gen_product_fromobj(self, product, prodname, structname, depth):
if product.attributes:
self.extract_location(prodname, depth)
self.emit("Ok(", depth)
self.gen_construction(structname, product, prodname, depth + 1)
self.emit(")", depth)
def gen_construction(self, cons_path, cons, name, depth):
self.emit(f"ast::{cons_path} {{", depth)
for field in cons.fields:
self.emit(f"{rust_field(field.name)}: {self.decode_field(field, name)},", depth + 1)
self.emit("}", depth)
def extract_location(self, typename, depth):
row = self.decode_field(asdl.Field('int', 'lineno'), typename)
column = self.decode_field(asdl.Field('int', 'col_offset'), typename)
self.emit(f"let _location = ast::Location::new({row}, {column});", depth)
def wrap_located_node(self, depth):
self.emit(f"let node = ast::Located::new(_location, node);", depth)
def decode_field(self, field, typename):
name = json.dumps(field.name)
if field.opt and not field.seq:
return f"get_node_field_opt(_vm, &_object, {name})?.map(|obj| Node::ast_from_object(_vm, obj)).transpose()?"
else:
return f"Node::ast_from_object(_vm, get_node_field(_vm, &_object, {name}, {json.dumps(typename)})?)?"
class ChainOfVisitors:
def __init__(self, *visitors):
self.visitors = visitors
def visit(self, object):
for v in self.visitors:
v.visit(object)
v.emit("", 0)
def write_ast_def(mod, typeinfo, f):
f.write('pub use crate::location::Location;\n')
f.write('pub use crate::constant::*;\n')
f.write('\n')
f.write('type Ident = String;\n')
f.write('\n')
StructVisitor(f, typeinfo).emit_attrs(0)
f.write('pub struct Located<T, U = ()> {\n')
f.write(' pub location: Location,\n')
f.write(' pub custom: U,\n')
f.write(' pub node: T,\n')
f.write('}\n')
f.write('\n')
f.write('impl<T> Located<T> {\n')
f.write(' pub fn new(location: Location, node: T) -> Self {\n')
f.write(' Self { location, custom: (), node }\n')
f.write(' }\n')
f.write('}\n')
f.write('\n')
c = ChainOfVisitors(StructVisitor(f, typeinfo),
FoldModuleVisitor(f, typeinfo))
c.visit(mod)
def write_ast_mod(mod, f):
f.write('use super::*;\n')
f.write('\n')
c = ChainOfVisitors(ClassDefVisitor(f),
TraitImplVisitor(f),
ExtendModuleVisitor(f))
c.visit(mod)
def main(input_filename, ast_mod_filename, ast_def_filename, dump_module=False):
auto_gen_msg = AUTOGEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
mod = asdl.parse(input_filename)
if dump_module:
print('Parsed Module:')
print(mod)
if not asdl.check(mod):
sys.exit(1)
typeinfo = {}
FindUserdataTypesVisitor(typeinfo).visit(mod)
with ast_def_filename.open("w") as def_file, \
ast_mod_filename.open("w") as mod_file:
def_file.write(auto_gen_msg)
write_ast_def(mod, typeinfo, def_file)
mod_file.write(auto_gen_msg)
write_ast_mod(mod, mod_file)
print(f"{ast_def_filename}, {ast_mod_filename} regenerated.")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("input_file", type=Path)
parser.add_argument("-M", "--mod-file", type=Path, required=True)
parser.add_argument("-D", "--def-file", type=Path, required=True)
parser.add_argument("-d", "--dump-module", action="store_true")
args = parser.parse_args()
main(args.input_file, args.mod_file, args.def_file, args.dump_module)

File diff suppressed because it is too large Load Diff

View File

@ -1,213 +0,0 @@
#[derive(Clone, Debug, PartialEq)]
pub enum Constant {
None,
Bool(bool),
Str(String),
Bytes(Vec<u8>),
Int(i128),
Tuple(Vec<Constant>),
Float(f64),
Complex { real: f64, imag: f64 },
Ellipsis,
}
impl From<String> for Constant {
fn from(s: String) -> Constant {
Self::Str(s)
}
}
impl From<Vec<u8>> for Constant {
fn from(b: Vec<u8>) -> Constant {
Self::Bytes(b)
}
}
impl From<bool> for Constant {
fn from(b: bool) -> Constant {
Self::Bool(b)
}
}
impl From<i32> for Constant {
fn from(i: i32) -> Constant {
Self::Int(i as i128)
}
}
impl From<i64> for Constant {
fn from(i: i64) -> Constant {
Self::Int(i as i128)
}
}
/// Transforms a value prior to formatting it.
#[derive(Copy, Clone, Debug, PartialEq)]
#[repr(u8)]
pub enum ConversionFlag {
/// Converts by calling `str(<value>)`.
Str = b's',
/// Converts by calling `ascii(<value>)`.
Ascii = b'a',
/// Converts by calling `repr(<value>)`.
Repr = b'r',
}
impl ConversionFlag {
pub fn try_from_byte(b: u8) -> Option<Self> {
match b {
b's' => Some(Self::Str),
b'a' => Some(Self::Ascii),
b'r' => Some(Self::Repr),
_ => None,
}
}
}
#[cfg(feature = "constant-optimization")]
#[derive(Default)]
pub struct ConstantOptimizer {
_priv: (),
}
#[cfg(feature = "constant-optimization")]
impl ConstantOptimizer {
#[inline]
pub fn new() -> Self {
Self { _priv: () }
}
}
#[cfg(feature = "constant-optimization")]
impl<U> crate::fold::Fold<U> for ConstantOptimizer {
type TargetU = U;
type Error = std::convert::Infallible;
#[inline]
fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error> {
Ok(user)
}
fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> {
match node.node {
crate::ExprKind::Tuple { elts, ctx } => {
let elts = elts
.into_iter()
.map(|x| self.fold_expr(x))
.collect::<Result<Vec<_>, _>>()?;
let expr = if elts
.iter()
.all(|e| matches!(e.node, crate::ExprKind::Constant { .. }))
{
let tuple = elts
.into_iter()
.map(|e| match e.node {
crate::ExprKind::Constant { value, .. } => value,
_ => unreachable!(),
})
.collect();
crate::ExprKind::Constant {
value: Constant::Tuple(tuple),
kind: None,
}
} else {
crate::ExprKind::Tuple { elts, ctx }
};
Ok(crate::Expr {
node: expr,
custom: node.custom,
location: node.location,
})
}
_ => crate::fold::fold_expr(self, node),
}
}
}
#[cfg(test)]
mod tests {
#[cfg(feature = "constant-optimization")]
#[test]
fn test_constant_opt() {
use super::*;
use crate::fold::Fold;
use crate::*;
let location = Location::new(0, 0, Default::default());
let custom = ();
let ast = Located {
location,
custom,
node: ExprKind::Tuple {
ctx: ExprContext::Load,
elts: vec![
Located {
location,
custom,
node: ExprKind::Constant {
value: 1.into(),
kind: None,
},
},
Located {
location,
custom,
node: ExprKind::Constant {
value: 2.into(),
kind: None,
},
},
Located {
location,
custom,
node: ExprKind::Tuple {
ctx: ExprContext::Load,
elts: vec![
Located {
location,
custom,
node: ExprKind::Constant {
value: 3.into(),
kind: None,
},
},
Located {
location,
custom,
node: ExprKind::Constant {
value: 4.into(),
kind: None,
},
},
Located {
location,
custom,
node: ExprKind::Constant {
value: 5.into(),
kind: None,
},
},
],
},
},
],
},
};
let new_ast = ConstantOptimizer::new()
.fold_expr(ast)
.unwrap_or_else(|e| match e {});
assert_eq!(
new_ast,
Located {
location,
custom,
node: ExprKind::Constant {
value: Constant::Tuple(vec![
1.into(),
2.into(),
Constant::Tuple(vec![
3.into(),
4.into(),
5.into(),
])
]),
kind: None
},
}
);
}
}

View File

@ -1,74 +0,0 @@
use crate::constant;
use crate::fold::Fold;
use crate::StrRef;
pub(crate) trait Foldable<T, U> {
type Mapped;
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
self,
folder: &mut F,
) -> Result<Self::Mapped, F::Error>;
}
impl<T, U, X> Foldable<T, U> for Vec<X>
where
X: Foldable<T, U>,
{
type Mapped = Vec<X::Mapped>;
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
self,
folder: &mut F,
) -> Result<Self::Mapped, F::Error> {
self.into_iter().map(|x| x.fold(folder)).collect()
}
}
impl<T, U, X> Foldable<T, U> for Option<X>
where
X: Foldable<T, U>,
{
type Mapped = Option<X::Mapped>;
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
self,
folder: &mut F,
) -> Result<Self::Mapped, F::Error> {
self.map(|x| x.fold(folder)).transpose()
}
}
impl<T, U, X> Foldable<T, U> for Box<X>
where
X: Foldable<T, U>,
{
type Mapped = Box<X::Mapped>;
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
self,
folder: &mut F,
) -> Result<Self::Mapped, F::Error> {
(*self).fold(folder).map(Box::new)
}
}
macro_rules! simple_fold {
($($t:ty),+$(,)?) => {
$(impl<T, U> $crate::fold_helpers::Foldable<T, U> for $t {
type Mapped = Self;
#[inline]
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
self,
_folder: &mut F,
) -> Result<Self::Mapped, F::Error> {
Ok(self)
}
})+
};
}
simple_fold!(
usize,
String,
bool,
StrRef,
constant::Constant,
constant::ConversionFlag
);

View File

@ -1,53 +0,0 @@
use crate::{Constant, ExprKind};
impl<U> ExprKind<U> {
/// Returns a short name for the node suitable for use in error messages.
pub fn name(&self) -> &'static str {
match self {
ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => {
"operator"
}
ExprKind::Subscript { .. } => "subscript",
ExprKind::Await { .. } => "await expression",
ExprKind::Yield { .. } | ExprKind::YieldFrom { .. } => "yield expression",
ExprKind::Compare { .. } => "comparison",
ExprKind::Attribute { .. } => "attribute",
ExprKind::Call { .. } => "function call",
ExprKind::Constant { value, .. } => match value {
Constant::Str(_)
| Constant::Int(_)
| Constant::Float(_)
| Constant::Complex { .. }
| Constant::Bytes(_) => "literal",
Constant::Tuple(_) => "tuple",
Constant::Bool(_) | Constant::None => "keyword",
Constant::Ellipsis => "ellipsis",
},
ExprKind::List { .. } => "list",
ExprKind::Tuple { .. } => "tuple",
ExprKind::Dict { .. } => "dict display",
ExprKind::Set { .. } => "set display",
ExprKind::ListComp { .. } => "list comprehension",
ExprKind::DictComp { .. } => "dict comprehension",
ExprKind::SetComp { .. } => "set comprehension",
ExprKind::GeneratorExp { .. } => "generator expression",
ExprKind::Starred { .. } => "starred",
ExprKind::Slice { .. } => "slice",
ExprKind::JoinedStr { values } => {
if values
.iter()
.any(|e| matches!(e.node, ExprKind::JoinedStr { .. }))
{
"f-string expression"
} else {
"literal"
}
}
ExprKind::FormattedValue { .. } => "f-string expression",
ExprKind::Name { .. } => "name",
ExprKind::Lambda { .. } => "lambda",
ExprKind::IfExp { .. } => "conditional expression",
ExprKind::NamedExpr { .. } => "named expression",
}
}
}

View File

@ -1,14 +0,0 @@
#[macro_use]
extern crate lazy_static;
mod ast_gen;
mod constant;
#[cfg(feature = "fold")]
mod fold_helpers;
mod impls;
mod location;
pub use ast_gen::*;
pub use location::{Location, FileName};
pub type Suite<U = ()> = Vec<Stmt<U>>;

View File

@ -1,117 +0,0 @@
//! Datatypes to support source location information.
use std::cmp::Ordering;
use crate::ast_gen::StrRef;
use std::fmt;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FileName(pub StrRef);
impl Default for FileName {
fn default() -> Self {
FileName("unknown".into())
}
}
impl From<String> for FileName {
fn from(s: String) -> Self {
FileName(s.into())
}
}
/// A location somewhere in the sourcecode.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Location {
pub row: usize,
pub column: usize,
pub file: FileName
}
impl fmt::Display for Location {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}:{}:{}", self.file.0, self.row, self.column)
}
}
impl Ord for Location {
fn cmp(&self, other: &Self) -> Ordering {
let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string());
if file_cmp != Ordering::Equal {
return file_cmp
}
let row_cmp = self.row.cmp(&other.row);
if row_cmp != Ordering::Equal {
return row_cmp
}
self.column.cmp(&other.column)
}
}
impl PartialOrd for Location {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Location {
pub fn visualize<'a>(
&self,
line: &'a str,
desc: impl fmt::Display + 'a,
) -> impl fmt::Display + 'a {
struct Visualize<'a, D: fmt::Display> {
loc: Location,
line: &'a str,
desc: D,
}
impl<D: fmt::Display> fmt::Display for Visualize<'_, D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}\n{}\n{arrow:>pad$}",
self.desc,
self.line,
pad = self.loc.column,
arrow = "",
)
}
}
Visualize {
loc: *self,
line,
desc,
}
}
}
impl Location {
pub fn new(row: usize, column: usize, file: FileName) -> Self {
Location { row, column, file }
}
pub fn row(&self) -> usize {
self.row
}
pub fn column(&self) -> usize {
self.column
}
pub fn reset(&mut self) {
self.row = 1;
self.column = 1;
}
pub fn go_right(&mut self) {
self.column += 1;
}
pub fn go_left(&mut self) {
self.column -= 1;
}
pub fn newline(&mut self) {
self.row += 1;
self.column = 1;
}
}

View File

@ -2,25 +2,15 @@
name = "nac3core"
version = "0.1.0"
authors = ["M-Labs"]
edition = "2021"
edition = "2018"
[dependencies]
itertools = "0.12"
crossbeam = "0.8"
indexmap = "2.2"
parking_lot = "0.12"
rayon = "1.8"
nac3parser = { path = "../nac3parser" }
[dependencies.inkwell]
version = "0.4"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
num-bigint = "0.3"
num-traits = "0.2"
thiserror = "1.0"
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] }
rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" }
[dev-dependencies]
test-case = "1.2.0"
indoc = "2.0"
insta = "=1.11.0"
indoc = "1.0"
[build-dependencies]
regex = "1.10"

View File

@ -1,78 +0,0 @@
use regex::Regex;
use std::{
env,
fs::File,
io::Write,
path::Path,
process::{Command, Stdio},
};
fn main() {
const FILE: &str = "src/codegen/irrt/irrt.c";
/*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/
let flags: &[&str] = &[
"--target=wasm32",
FILE,
"-fno-discard-value-names",
match env::var("PROFILE").as_deref() {
Ok("debug") => "-O0",
Ok("release") => "-O3",
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
},
"-emit-llvm",
"-S",
"-Wall",
"-Wextra",
"-o",
"-",
];
println!("cargo:rerun-if-changed={FILE}");
let out_dir = env::var("OUT_DIR").unwrap();
let out_path = Path::new(&out_dir);
let output = Command::new("clang-irrt")
.args(flags)
.output()
.map(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
o
})
.unwrap();
// https://github.com/rust-lang/regex/issues/244
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
let mut filtered_output = String::with_capacity(output.len());
let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap();
for f in regex_filter.captures_iter(&output) {
assert_eq!(f.len(), 1);
filtered_output.push_str(&f[0]);
filtered_output.push('\n');
}
let filtered_output = Regex::new("(#\\d+)|(, *![0-9A-Za-z.]+)|(![0-9A-Za-z.]+)|(!\".*?\")")
.unwrap()
.replace_all(&filtered_output, "");
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT");
if env::var("DEBUG_DUMP_IRRT").is_ok() {
let mut file = File::create(out_path.join("irrt.ll")).unwrap();
file.write_all(output.as_bytes()).unwrap();
let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap();
file.write_all(filtered_output.as_bytes()).unwrap();
}
let mut llvm_as = Command::new("llvm-as-irrt")
.stdin(Stdio::piped())
.arg("-o")
.arg(out_path.join("irrt.bc"))
.spawn()
.unwrap();
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
assert!(llvm_as.wait().unwrap().success());
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,318 +0,0 @@
use crate::{
symbol_resolver::SymbolValue,
toplevel::DefinitionId,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
};
use nac3parser::ast::StrRef;
use std::collections::HashMap;
use indexmap::IndexMap;
pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>,
}
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct ConcreteType(usize);
#[derive(Clone, Debug)]
pub struct ConcreteFuncArg {
pub name: StrRef,
pub ty: ConcreteType,
pub default_value: Option<SymbolValue>,
}
#[derive(Clone, Debug)]
pub enum Primitive {
Int32,
Int64,
UInt32,
UInt64,
Float,
Bool,
None,
Range,
Str,
Exception,
}
#[derive(Debug)]
pub enum ConcreteTypeEnum {
TPrimitive(Primitive),
TTuple {
ty: Vec<ConcreteType>,
},
TList {
ty: ConcreteType,
},
TObj {
obj_id: DefinitionId,
fields: HashMap<StrRef, (ConcreteType, bool)>,
params: IndexMap<u32, ConcreteType>,
},
TVirtual {
ty: ConcreteType,
},
TFunc {
args: Vec<ConcreteFuncArg>,
ret: ConcreteType,
vars: HashMap<u32, ConcreteType>,
},
TLiteral {
values: Vec<SymbolValue>,
},
}
impl ConcreteTypeStore {
#[must_use]
pub fn new() -> ConcreteTypeStore {
ConcreteTypeStore {
store: vec![
ConcreteTypeEnum::TPrimitive(Primitive::Int32),
ConcreteTypeEnum::TPrimitive(Primitive::Int64),
ConcreteTypeEnum::TPrimitive(Primitive::Float),
ConcreteTypeEnum::TPrimitive(Primitive::Bool),
ConcreteTypeEnum::TPrimitive(Primitive::None),
ConcreteTypeEnum::TPrimitive(Primitive::Range),
ConcreteTypeEnum::TPrimitive(Primitive::Str),
ConcreteTypeEnum::TPrimitive(Primitive::Exception),
ConcreteTypeEnum::TPrimitive(Primitive::UInt32),
ConcreteTypeEnum::TPrimitive(Primitive::UInt64),
],
}
}
#[must_use]
pub fn get(&self, cty: ConcreteType) -> &ConcreteTypeEnum {
&self.store[cty.0]
}
pub fn from_signature(
&mut self,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
signature: &FunSignature,
cache: &mut HashMap<Type, Option<ConcreteType>>,
) -> ConcreteTypeEnum {
ConcreteTypeEnum::TFunc {
args: signature
.args
.iter()
.map(|arg| ConcreteFuncArg {
name: arg.name,
ty: self.from_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(),
})
.collect(),
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
vars: signature
.vars
.iter()
.map(|(id, ty)| (*id, self.from_unifier_type(unifier, primitives, *ty, cache)))
.collect(),
}
}
pub fn from_unifier_type(
&mut self,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
ty: Type,
cache: &mut HashMap<Type, Option<ConcreteType>>,
) -> ConcreteType {
let ty = unifier.get_representative(ty);
if unifier.unioned(ty, primitives.int32) {
ConcreteType(0)
} else if unifier.unioned(ty, primitives.int64) {
ConcreteType(1)
} else if unifier.unioned(ty, primitives.float) {
ConcreteType(2)
} else if unifier.unioned(ty, primitives.bool) {
ConcreteType(3)
} else if unifier.unioned(ty, primitives.none) {
ConcreteType(4)
} else if unifier.unioned(ty, primitives.range) {
ConcreteType(5)
} else if unifier.unioned(ty, primitives.str) {
ConcreteType(6)
} else if unifier.unioned(ty, primitives.exception) {
ConcreteType(7)
} else if unifier.unioned(ty, primitives.uint32) {
ConcreteType(8)
} else if unifier.unioned(ty, primitives.uint64) {
ConcreteType(9)
} else if let Some(cty) = cache.get(&ty) {
if let Some(cty) = cty {
*cty
} else {
let index = self.store.len();
// placeholder
self.store.push(ConcreteTypeEnum::TPrimitive(Primitive::Int32));
let result = ConcreteType(index);
cache.insert(ty, Some(result));
result
}
} else {
cache.insert(ty, None);
let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum {
TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple {
ty: ty
.iter()
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(),
},
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
},
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id,
fields: fields
.iter()
.filter_map(|(name, ty)| {
// here we should not have type vars, but some partial instantiated
// class methods can still have uninstantiated type vars, so
// filter out all the methods, as this will not affect codegen
if let TypeEnum::TFunc(..) = &*unifier.get_ty(ty.0) {
None
} else {
Some((
*name,
(
self.from_unifier_type(unifier, primitives, ty.0, cache),
ty.1,
),
))
}
})
.collect(),
params: params
.iter()
.map(|(id, ty)| {
(*id, self.from_unifier_type(unifier, primitives, *ty, cache))
})
.collect(),
},
TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual {
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
},
TypeEnum::TFunc(signature) => {
self.from_signature(unifier, primitives, signature, cache)
}
TypeEnum::TLiteral { values, .. } => ConcreteTypeEnum::TLiteral {
values: values.clone(),
},
_ => unreachable!("{:?}", ty_enum.get_type_name()),
};
let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() {
self.store[*index] = result;
*index
} else {
self.store.push(result);
self.store.len() - 1
};
cache.insert(ty, Some(ConcreteType(index)));
ConcreteType(index)
}
}
pub fn to_unifier_type(
&self,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
cty: ConcreteType,
cache: &mut HashMap<ConcreteType, Option<Type>>,
) -> Type {
if let Some(ty) = cache.get_mut(&cty) {
return if let Some(ty) = ty {
*ty
} else {
*ty = Some(unifier.get_dummy_var().0);
ty.unwrap()
};
}
cache.insert(cty, None);
let result = match &self.store[cty.0] {
ConcreteTypeEnum::TPrimitive(primitive) => {
let ty = match primitive {
Primitive::Int32 => primitives.int32,
Primitive::Int64 => primitives.int64,
Primitive::UInt32 => primitives.uint32,
Primitive::UInt64 => primitives.uint64,
Primitive::Float => primitives.float,
Primitive::Bool => primitives.bool,
Primitive::None => primitives.none,
Primitive::Range => primitives.range,
Primitive::Str => primitives.str,
Primitive::Exception => primitives.exception,
};
*cache.get_mut(&cty).unwrap() = Some(ty);
return ty;
}
ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple {
ty: ty
.iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(),
},
ConcreteTypeEnum::TList { ty } => {
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
}
ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
}
ConcreteTypeEnum::TObj { obj_id, fields, params } => TypeEnum::TObj {
obj_id: *obj_id,
fields: fields
.iter()
.map(|(name, cty)| {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
})
.collect::<HashMap<_, _>>(),
params: params
.iter()
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
.collect::<VarMap>(),
},
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args
.iter()
.map(|arg| FuncArg {
name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(),
})
.collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache),
vars: vars
.iter()
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
.collect::<VarMap>(),
}),
ConcreteTypeEnum::TLiteral { values, .. } => TypeEnum::TLiteral {
values: values.clone(),
loc: None,
}
};
let result = unifier.add_ty(result);
if let Some(ty) = cache.get(&cty).unwrap() {
unifier.unify(*ty, result).unwrap();
}
cache.insert(cty, Some(result));
result
}
pub fn add_cty(&mut self, cty: ConcreteTypeEnum) -> ConcreteType {
self.store.push(cty);
ConcreteType(self.store.len() - 1)
}
}
impl Default for ConcreteTypeStore {
fn default() -> Self {
Self::new()
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,613 +0,0 @@
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
use itertools::Either;
use crate::codegen::CodeGenContext;
/// Invokes the [`tan`](https://en.cppreference.com/w/c/numeric/math/tan) function.
pub fn call_tan<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "tan";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`asin`](https://en.cppreference.com/w/c/numeric/math/asin) function.
pub fn call_asin<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "asin";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`acos`](https://en.cppreference.com/w/c/numeric/math/acos) function.
pub fn call_acos<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "acos";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`atan`](https://en.cppreference.com/w/c/numeric/math/atan) function.
pub fn call_atan<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "atan";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`sinh`](https://en.cppreference.com/w/c/numeric/math/sinh) function.
pub fn call_sinh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "sinh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`cosh`](https://en.cppreference.com/w/c/numeric/math/cosh) function.
pub fn call_cosh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "cosh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`tanh`](https://en.cppreference.com/w/c/numeric/math/tanh) function.
pub fn call_tanh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "tanh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`asinh`](https://en.cppreference.com/w/c/numeric/math/asinh) function.
pub fn call_asinh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "asinh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`acosh`](https://en.cppreference.com/w/c/numeric/math/acosh) function.
pub fn call_acosh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "acosh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`atanh`](https://en.cppreference.com/w/c/numeric/math/atanh) function.
pub fn call_atanh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "atanh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`expm1`](https://en.cppreference.com/w/c/numeric/math/expm1) function.
pub fn call_expm1<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "expm1";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`cbrt`](https://en.cppreference.com/w/c/numeric/math/cbrt) function.
pub fn call_cbrt<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "cbrt";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nosync", "nounwind", "readonly", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`erf`](https://en.cppreference.com/w/c/numeric/math/erf) function.
pub fn call_erf<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "erf";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
);
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`erfc`](https://en.cppreference.com/w/c/numeric/math/erfc) function.
pub fn call_erfc<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "erfc";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
);
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`j1`](https://www.gnu.org/software/libc/manual/html_node/Special-Functions.html#index-j1)
/// function.
pub fn call_j1<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "j1";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
);
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`atan2`](https://en.cppreference.com/w/c/numeric/math/atan2) function.
pub fn call_atan2<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
y: FloatValue<'ctx>,
x: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "atan2";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(y.get_type(), llvm_f64);
debug_assert_eq!(x.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[y.into(), x.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
pub fn call_ldexp<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
exp: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "ldexp";
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
debug_assert_eq!(exp.get_type(), llvm_i32);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`hypot`](https://en.cppreference.com/w/c/numeric/math/hypot) function.
pub fn call_hypot<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
x: FloatValue<'ctx>,
y: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "hypot";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(x.get_type(), llvm_f64);
debug_assert_eq!(y.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
);
func
});
ctx.builder
.build_call(extern_fn, &[x.into(), y.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`nextafter`](https://en.cppreference.com/w/c/numeric/math/nextafter) function.
pub fn call_nextafter<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
from: FloatValue<'ctx>,
to: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "nextafter";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(from.get_type(), llvm_f64);
debug_assert_eq!(to.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
);
func
});
ctx.builder
.build_call(extern_fn, &[from.into(), to.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,258 +0,0 @@
use crate::{
codegen::{classes::ArraySliceValue, expr::*, stmt::*, bool_to_i1, bool_to_i8, CodeGenContext},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
use inkwell::{
context::Context,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
};
use nac3parser::ast::{Expr, Stmt, StrRef};
pub trait CodeGenerator {
/// Return the module name for the code generator.
fn get_name(&self) -> &str;
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>;
/// Generate function call and returns the function return value.
/// - obj: Optional object for method call.
/// - fun: Function signature and definition ID.
/// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method.
fn gen_call<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<Option<BasicValueEnum<'ctx>>, String>
where
Self: Sized,
{
gen_call(self, ctx, obj, fun, params)
}
/// Generate object constructor and returns the constructed object.
/// - signature: Function signature of the constructor.
/// - def: Class definition for the constructor class.
/// - params: Function parameters.
fn gen_constructor<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
signature: &FunSignature,
def: &TopLevelDef,
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<BasicValueEnum<'ctx>, String>
where
Self: Sized,
{
gen_constructor(self, ctx, signature, def, params)
}
/// Generate a function instance.
/// - obj: Optional object for method call.
/// - fun: Function signature, definition ID and the substitution key.
/// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method.
/// Note that this function should check if the function is generated in another thread (due to
/// possible race condition), see the default implementation for an example.
fn gen_func_instance<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, &mut TopLevelDef, String),
id: usize,
) -> Result<String, String> {
gen_func_instance(ctx, &obj, fun, id)
}
/// Generate the code for an expression.
fn gen_expr<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
expr: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String>
where
Self: Sized,
{
gen_expr(self, ctx, expr)
}
/// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function.
fn gen_var_alloc<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
name: Option<&str>,
) -> Result<PointerValue<'ctx>, String> {
gen_var(ctx, ty, name)
}
/// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function.
fn gen_array_var_alloc<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Result<ArraySliceValue<'ctx>, String> {
gen_array_var(ctx, ty, size, name)
}
/// Return a pointer pointing to the target of the expression.
fn gen_store_target<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
pattern: &Expr<Option<Type>>,
name: Option<&str>,
) -> Result<Option<PointerValue<'ctx>>, String>
where
Self: Sized,
{
gen_store_target(self, ctx, pattern, name)
}
/// Generate code for an assignment expression.
fn gen_assign<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign(self, ctx, target, value)
}
/// Generate code for a while expression.
/// Return true if the while loop must early return
fn gen_while(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where
Self: Sized,
{
gen_while(self, ctx, stmt)
}
/// Generate code for a while expression.
/// Return true if the while loop must early return
fn gen_for(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where
Self: Sized,
{
gen_for(self, ctx, stmt)
}
/// Generate code for an if expression.
/// Return true if the statement must early return
fn gen_if(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where
Self: Sized,
{
gen_if(self, ctx, stmt)
}
fn gen_with(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where
Self: Sized,
{
gen_with(self, ctx, stmt)
}
/// Generate code for a statement
///
/// Return true if the statement must early return
fn gen_stmt(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String>
where
Self: Sized,
{
gen_stmt(self, ctx, stmt)
}
/// Generates code for a block statement.
fn gen_block<'a, I: Iterator<Item = &'a Stmt<Option<Type>>>>(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmts: I,
) -> Result<(), String>
where
Self: Sized,
{
gen_block(self, ctx, stmts)
}
/// See [`bool_to_i1`].
fn bool_to_i1<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>
) -> IntValue<'ctx> {
bool_to_i1(&ctx.builder, bool_value)
}
/// See [`bool_to_i8`].
fn bool_to_i8<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>
) -> IntValue<'ctx> {
bool_to_i8(&ctx.builder, ctx.ctx, bool_value)
}
}
pub struct DefaultCodeGenerator {
name: String,
size_t: u32,
}
impl DefaultCodeGenerator {
#[must_use]
pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator {
assert!(matches!(size_t, 32 | 64));
DefaultCodeGenerator { name, size_t }
}
}
impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str {
&self.name
}
/// Returns an LLVM integer type representing `size_t`.
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
// it should be unsigned, but we don't really need unsigned and this could save us from
// having to do a bit cast...
if self.size_t == 32 {
ctx.i32_type()
} else {
ctx.i64_type()
}
}
}

View File

@ -1,381 +0,0 @@
typedef _BitInt(8) int8_t;
typedef unsigned _BitInt(8) uint8_t;
typedef _BitInt(32) int32_t;
typedef unsigned _BitInt(32) uint32_t;
typedef _BitInt(64) int64_t;
typedef unsigned _BitInt(64) uint64_t;
# define MAX(a, b) (a > b ? a : b)
# define MIN(a, b) (a > b ? b : a)
# define NULL ((void *) 0)
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
#define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
T base, \
T exp \
) { \
T res = (T)1; \
/* repeated squaring method */ \
do { \
if (exp & 1) res *= base; /* for n odd */ \
exp >>= 1; \
base *= base; \
} while (exp); \
return res; \
} \
DEF_INT_EXP(int32_t)
DEF_INT_EXP(int64_t)
DEF_INT_EXP(uint32_t)
DEF_INT_EXP(uint64_t)
int32_t __nac3_slice_index_bound(int32_t i, const int32_t len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
int32_t __nac3_range_slice_len(const int32_t start, const int32_t end, const int32_t step) {
int32_t diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
int32_t __nac3_list_slice_assign_var_size(
int32_t dest_start,
int32_t dest_end,
int32_t dest_step,
uint8_t *dest_arr,
int32_t dest_arr_len,
int32_t src_start,
int32_t src_end,
int32_t src_step,
uint8_t *src_arr,
int32_t src_arr_len,
const int32_t size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const int32_t src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const int32_t dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(
dest_arr + dest_start * size,
src_arr + src_start * size,
src_len * size
);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(
dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr)
&& !(
MAX(dest_start, dest_end) < MIN(src_start, src_end)
|| MAX(src_start, src_end) < MIN(dest_start, dest_end)
);
if (need_alloca) {
uint8_t *tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
int32_t src_ind = src_start;
int32_t dest_ind = dest_start;
for (;
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
uint32_t __nac3_ndarray_calc_size(
const uint64_t *list_data,
uint32_t list_len
) {
uint32_t num_elems = 1;
for (uint32_t i = 0; i < list_len; ++i) {
uint64_t val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
uint64_t __nac3_ndarray_calc_size64(
const uint64_t *list_data,
uint64_t list_len
) {
uint64_t num_elems = 1;
for (uint64_t i = 0; i < list_len; ++i) {
uint64_t val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
void __nac3_ndarray_calc_nd_indices(
uint32_t index,
const uint32_t* dims,
uint32_t num_dims,
uint32_t* idxs
) {
uint32_t stride = 1;
for (uint32_t dim = 0; dim < num_dims; dim++) {
uint32_t i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
void __nac3_ndarray_calc_nd_indices64(
uint64_t index,
const uint64_t* dims,
uint64_t num_dims,
uint32_t* idxs
) {
uint64_t stride = 1;
for (uint64_t dim = 0; dim < num_dims; dim++) {
uint64_t i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (uint32_t) ((index / stride) % dims[i]);
stride *= dims[i];
}
}
uint32_t __nac3_ndarray_flatten_index(
const uint32_t* dims,
uint32_t num_dims,
const uint32_t* indices,
uint32_t num_indices
) {
uint32_t idx = 0;
uint32_t stride = 1;
for (uint32_t i = 0; i < num_dims; ++i) {
uint32_t ri = num_dims - i - 1;
if (ri < num_indices) {
idx += (stride * indices[ri]);
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
uint64_t __nac3_ndarray_flatten_index64(
const uint64_t* dims,
uint64_t num_dims,
const uint32_t* indices,
uint64_t num_indices
) {
uint64_t idx = 0;
uint64_t stride = 1;
for (uint64_t i = 0; i < num_dims; ++i) {
uint64_t ri = num_dims - i - 1;
if (ri < num_indices) {
idx += (stride * indices[ri]);
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
void __nac3_ndarray_calc_broadcast(
const uint32_t *lhs_dims,
uint32_t lhs_ndims,
const uint32_t *rhs_dims,
uint32_t rhs_ndims,
uint32_t *out_dims
) {
uint32_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (uint32_t i = 0; i < max_ndims; ++i) {
uint32_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
uint32_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
uint32_t *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == NULL) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == NULL) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
void __nac3_ndarray_calc_broadcast64(
const uint64_t *lhs_dims,
uint64_t lhs_ndims,
const uint64_t *rhs_dims,
uint64_t rhs_ndims,
uint64_t *out_dims
) {
uint64_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (uint64_t i = 0; i < max_ndims; ++i) {
uint64_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
uint64_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
uint64_t *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == NULL) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == NULL) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
void __nac3_ndarray_calc_broadcast_idx(
const uint32_t *src_dims,
uint32_t src_ndims,
const uint32_t *in_idx,
uint32_t *out_idx
) {
for (uint32_t i = 0; i < src_ndims; ++i) {
uint32_t src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
void __nac3_ndarray_calc_broadcast_idx64(
const uint64_t *src_dims,
uint64_t src_ndims,
const uint32_t *in_idx,
uint32_t *out_idx
) {
for (uint64_t i = 0; i < src_ndims; ++i) {
uint64_t src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : (uint32_t) in_idx[src_i];
}
}

View File

@ -1,981 +0,0 @@
use crate::typecheck::typedef::Type;
use super::{
classes::{
ArrayLikeIndexer,
ArrayLikeValue,
ArraySliceValue,
ListValue,
NDArrayValue,
TypedArrayLikeAdapter,
UntypedArrayLikeAccessor,
},
CodeGenContext,
CodeGenerator,
llvm_intrinsics,
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
context::Context,
memory_buffer::MemoryBuffer,
module::Module,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use nac3parser::ast::Expr;
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;
#[must_use]
pub fn load_irrt(ctx: &Context) -> Module {
let bitcode_buf = MemoryBuffer::create_from_memory_range(
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
"irrt_bitcode_buffer",
);
let irrt_mod = Module::parse_bitcode_from_buffer(&bitcode_buf, ctx).unwrap();
let inline_attr = Attribute::get_named_enum_kind_id("alwaysinline");
for symbol in &[
"__nac3_int_exp_int32_t",
"__nac3_int_exp_int64_t",
"__nac3_range_slice_len",
"__nac3_slice_index_bound",
] {
let function = irrt_mod.get_function(symbol).unwrap();
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
}
irrt_mod
}
// repeated squaring method adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
base: IntValue<'ctx>,
exp: IntValue<'ctx>,
signed: bool,
) -> IntValue<'ctx> {
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
(32, 32, true) => "__nac3_int_exp_int32_t",
(64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t",
_ => unreachable!(),
};
let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
ctx.module.add_function(symbol, fn_type, None)
});
// throw exception when exp < 0
let ge_zero = ctx.builder.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
).unwrap();
ctx.make_assert(
generator,
ge_zero,
"0:ValueError",
"integer power must be positive or zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>,
end: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
const SYMBOL: &str = "__nac3_range_slice_len";
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
// assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
).unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"step must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
/// NO numeric slice in python.
///
/// equivalent code:
/// ```pseudo_code
/// match (start, end, step):
/// case (s, e, None | Some(step)) if step > 0:
/// return (
/// match s:
/// case None:
/// 0
/// case Some(s):
/// handle_in_bound(s)
/// ,match e:
/// case None:
/// length - 1
/// case Some(e):
/// handle_in_bound(e) - 1
/// ,step == None ? 1 : step
/// )
/// case (s, e, Some(step)) if step < 0:
/// return (
/// match s:
/// case None:
/// length - 1
/// case Some(s):
/// s = handle_in_bound(s)
/// if s == length:
/// s - 1
/// else:
/// s
/// ,match e:
/// case None:
/// 0
/// case Some(e):
/// handle_in_bound(e) + 1
/// ,step
/// )
/// ```
pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
start: &Option<Box<Expr<Option<Type>>>>,
end: &Option<Box<Expr<Option<Type>>>>,
step: &Option<Box<Expr<Option<Type>>>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
list: ListValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let length = list.load_size(ctx, Some("length"));
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
Ok(Some(match (start, end, step) {
(s, e, None) => (
if let Some(s) = s.as_ref() {
match handle_slice_index_bound(s, ctx, generator, length)? {
Some(v) => v,
None => return Ok(None),
}
} else {
int32.const_zero()
},
{
let e = if let Some(s) = e.as_ref() {
match handle_slice_index_bound(s, ctx, generator, length)? {
Some(v) => v,
None => return Ok(None),
}
} else {
length
};
ctx.builder.build_int_sub(e, one, "final_end").unwrap()
},
one,
),
(s, e, Some(step)) => {
let step = if let Some(v) = generator.gen_expr(ctx, step)? {
v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
} else {
return Ok(None)
};
// assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
).unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"slice step cannot be zero",
[None, None, None],
ctx.current_loc,
);
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg").unwrap();
(
match s {
Some(s) => {
let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
return Ok(None)
};
ctx.builder
.build_select(
ctx.builder.build_and(
ctx.builder.build_int_compare(
IntPredicate::EQ,
s,
length,
"s_eq_len",
).unwrap(),
neg,
"should_minus_one",
).unwrap(),
ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
s,
"final_start",
)
.map(BasicValueEnum::into_int_value)
.unwrap()
}
None => ctx.builder.build_select(neg, len_id, zero, "stt")
.map(BasicValueEnum::into_int_value)
.unwrap(),
},
match e {
Some(e) => {
let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
return Ok(None)
};
ctx.builder
.build_select(
neg,
ctx.builder.build_int_add(e, one, "end_add_one").unwrap(),
ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(),
"final_end",
)
.map(BasicValueEnum::into_int_value)
.unwrap()
}
None => ctx.builder.build_select(neg, zero, len_id, "end")
.map(BasicValueEnum::into_int_value)
.unwrap(),
},
step,
)
}
}))
}
/// this function allows index out of range, since python
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
i: &Expr<Option<Type>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> {
const SYMBOL: &str = "__nac3_slice_index_bound";
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else {
return Ok(None)
};
Ok(Some(ctx
.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()))
}
/// This function handles 'end' **inclusively**.
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
dest_arr: ListValue<'ctx>,
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
let size_ty = generator.get_size_type(ctx.ctx);
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let int32 = ctx.ctx.i32_type();
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
let slice_assign_fun = {
let ty_vec = vec![
int32.into(), // dest start idx
int32.into(), // dest end idx
int32.into(), // dest step
elem_ptr_type.into(), // dest arr ptr
int32.into(), // dest arr len
int32.into(), // src start idx
int32.into(), // src end idx
int32.into(), // src step
elem_ptr_type.into(), // src arr ptr
int32.into(), // src arr len
int32.into(), // size
];
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
ctx.module.add_function(fun_symbol, fn_t, None)
})
};
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr = ctx.builder.build_pointer_cast(
dest_arr_ptr,
elem_ptr_type,
"dest_arr_ptr_cast",
).unwrap();
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let src_arr_ptr = ctx.builder.build_pointer_cast(
src_arr_ptr,
elem_ptr_type,
"src_arr_ptr_cast",
).unwrap();
let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied
let src_end = ctx.builder
.build_select(
ctx.builder.build_int_compare(
IntPredicate::SLT,
src_idx.2,
zero,
"is_neg",
).unwrap(),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let dest_end = ctx.builder
.build_select(
ctx.builder.build_int_compare(
IntPredicate::SLT,
dest_idx.2,
zero,
"is_neg",
).unwrap(),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let src_slice_len =
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx.builder.build_int_compare(
IntPredicate::EQ,
src_slice_len,
dest_slice_len,
"slice_src_eq_dest",
).unwrap();
let src_slt_dest = ctx.builder.build_int_compare(
IntPredicate::SLT,
src_slice_len,
dest_slice_len,
"slice_src_slt_dest",
).unwrap();
let dest_step_eq_one = ctx.builder.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
).unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc,
);
let new_len = {
let args = vec![
dest_idx.0.into(), // dest start idx
dest_idx.1.into(), // dest end idx
dest_idx.2.into(), // dest step
dest_arr_ptr.into(), // dest arr ptr
dest_len.into(), // dest arr len
src_idx.0.into(), // src start idx
src_idx.1.into(), // src end idx
src_idx.2.into(), // src step
src_arr_ptr.into(), // src arr ptr
src_len.into(), // src arr len
{
let s = match ty {
BasicTypeEnum::FloatType(t) => t.size_of(),
BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => unreachable!(),
};
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
}
.into(),
];
ctx.builder
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
};
// update length
let need_update = ctx.builder
.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update")
.unwrap();
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
ctx.builder.position_at_end(update_bb);
let new_len = ctx.builder
.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len")
.unwrap();
dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.position_at_end(cont_bb);
}
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
let ret = ctx.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
let ret = ctx.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gamma", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gamma")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gammaln", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gammaln")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_j0", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `num_dims` - An [`IntValue`] containing the number of dimensions.
/// * `dims` - A [`PointerValue`] to an array containing the size of each dimension.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>, {
let llvm_i64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi64 = llvm_i64.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[
llvm_pi64.into(),
llvm_usize.into(),
],
false,
);
let ndarray_calc_size_fn = ctx.module.get_function(ndarray_calc_size_fn_name)
.unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pi32.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca(
llvm_i32,
ndarray_num_dims,
"",
).unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>, {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pi32.into(),
llvm_usize.into(),
],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let index = ctx.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>, {
call_ndarray_flatten_index_impl(
generator,
ctx,
ndarray,
indices,
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx.builder.build_int_compare(
IntPredicate::EQ,
lhs_dim_sz,
llvm_usize_const_one,
"",
).unwrap();
let rhs_eqz = ctx.builder.build_int_compare(
IntPredicate::EQ,
rhs_dim_sz,
llvm_usize_const_one,
"",
).unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(
lhs_eqz,
rhs_eqz,
""
).unwrap();
let lhs_eq_rhs = ctx.builder.build_int_compare(
IntPredicate::EQ,
lhs_dim_sz,
rhs_dim_sz,
""
).unwrap();
let is_compatible = ctx.builder.build_or(
lhs_or_rhs_eqz,
lhs_eq_rhs,
""
).unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
).unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pi32.into(),
llvm_pi32.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None
)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
array_dims.into(),
array_ndims.into(),
broadcast_idx_ptr.into(),
out_idx.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}

View File

@ -1,773 +0,0 @@
use inkwell::AddressSpace;
use inkwell::context::Context;
use inkwell::types::AnyTypeEnum::IntType;
use inkwell::types::FloatType;
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use itertools::Either;
use crate::codegen::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
// Standard LLVM floating-point types
if ft == ctx.f16_type() {
return "f16"
}
if ft == ctx.f32_type() {
return "f32"
}
if ft == ctx.f64_type() {
return "f64"
}
if ft == ctx.f128_type() {
return "f128"
}
// Non-standard floating-point types
if ft == ctx.x86_f80_type() {
return "f80"
}
if ft == ctx.ppc_f128_type() {
return "ppcf128"
}
unreachable!()
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic.
pub fn call_stacksave<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> PointerValue<'ctx> {
const FN_NAME: &str = "llvm.stacksave";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_p0i8.fn_type(&[], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_pointer_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the
/// [`llvm.stackrestore`](https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic) intrinsic.
pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.stackrestore";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[ptr.into()], "")
.unwrap();
}
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
///
/// * `src` - The value for which the absolute value is to be returned.
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
pub fn call_int_abs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
is_int_min_poison: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
debug_assert!(is_int_min_poison.is_const());
let llvm_src_t = src.get_type();
let fn_name = format!("llvm.abs.i{}", llvm_src_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let llvm_i1 = ctx.ctx.bool_type();
let fn_type = llvm_src_t.fn_type(&[llvm_src_t.into(), llvm_i1.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[src.into(), is_int_min_poison.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.smax`](https://llvm.org/docs/LangRef.html#llvm-smax-intrinsic) intrinsic.
pub fn call_int_smax<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let fn_name = format!("llvm.smax.i{}", llvm_int_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.smin`](https://llvm.org/docs/LangRef.html#llvm-smin-intrinsic) intrinsic.
pub fn call_int_smin<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let fn_name = format!("llvm.smin.i{}", llvm_int_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.umax`](https://llvm.org/docs/LangRef.html#llvm-umax-intrinsic) intrinsic.
pub fn call_int_umax<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let fn_name = format!("llvm.umax.i{}", llvm_int_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.umin`](https://llvm.org/docs/LangRef.html#llvm-umin-intrinsic) intrinsic.
pub fn call_int_umin<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let fn_name = format!("llvm.umin.i{}", llvm_int_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.memcpy`](https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic) intrinsic.
///
/// * `dest` - The pointer to the destination. Must be a pointer to an integer type.
/// * `src` - The pointer to the source. Must be a pointer to an integer type.
/// * `len` - The number of bytes to copy.
/// * `is_volatile` - Whether the `memcpy` operation should be `volatile`.
pub fn call_memcpy<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
debug_assert!(dest.get_type().get_element_type().is_int_type());
debug_assert!(src.get_type().get_element_type().is_int_type());
debug_assert_eq!(
dest.get_type().get_element_type().into_int_type().get_bit_width(),
src.get_type().get_element_type().into_int_type().get_bit_width(),
);
debug_assert!(matches!(len.get_type().get_bit_width(), 32 | 64));
debug_assert_eq!(is_volatile.get_type().get_bit_width(), 1);
let llvm_dest_t = dest.get_type();
let llvm_src_t = src.get_type();
let llvm_len_t = len.get_type();
let fn_name = format!(
"llvm.memcpy.p0i{}.p0i{}.i{}",
llvm_dest_t.get_element_type().into_int_type().get_bit_width(),
llvm_src_t.get_element_type().into_int_type().get_bit_width(),
llvm_len_t.get_bit_width(),
);
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let fn_type = llvm_void.fn_type(
&[
llvm_dest_t.into(),
llvm_src_t.into(),
llvm_len_t.into(),
is_volatile.get_type().into(),
],
false,
);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[dest.into(), src.into(), len.into(), is_volatile.into()], "")
.unwrap();
}
/// Invokes the `llvm.memcpy` intrinsic.
///
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
pub fn call_memcpy_generic<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let dest_elem_t = dest.get_type().get_element_type();
let src_elem_t = src.get_type().get_element_type();
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
dest
} else {
ctx.builder
.build_bitcast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
src
} else {
ctx.builder
.build_bitcast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
call_memcpy(ctx, dest, src, len, is_volatile);
}
/// Invokes the [`llvm.sqrt`](https://llvm.org/docs/LangRef.html#llvm-sqrt-intrinsic) intrinsic.
pub fn call_float_sqrt<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.sqrt.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.powi`](https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic) intrinsic.
pub fn call_float_powi<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_val_t = val.get_type();
let llvm_power_t = power.get_type();
let fn_name = format!(
"llvm.powi.{}.i{}",
get_float_intrinsic_repr(ctx.ctx, llvm_val_t),
llvm_power_t.get_bit_width(),
);
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_val_t.fn_type(&[llvm_val_t.into(), llvm_power_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.sin`](https://llvm.org/docs/LangRef.html#llvm-sin-intrinsic) intrinsic.
pub fn call_float_sin<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.sin.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.cos`](https://llvm.org/docs/LangRef.html#llvm-cos-intrinsic) intrinsic.
pub fn call_float_cos<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.cos.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.pow`](https://llvm.org/docs/LangRef.html#llvm-pow-intrinsic) intrinsic.
pub fn call_float_pow<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!(val.get_type(), power.get_type());
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.pow.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into(), llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.exp`](https://llvm.org/docs/LangRef.html#llvm-exp-intrinsic) intrinsic.
pub fn call_float_exp<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.exp.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.exp2`](https://llvm.org/docs/LangRef.html#llvm-exp2-intrinsic) intrinsic.
pub fn call_float_exp2<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.exp2.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.log`](https://llvm.org/docs/LangRef.html#llvm-log-intrinsic) intrinsic.
pub fn call_float_log<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.log.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.log10`](https://llvm.org/docs/LangRef.html#llvm-log10-intrinsic) intrinsic.
pub fn call_float_log10<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.log10.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.log2`](https://llvm.org/docs/LangRef.html#llvm-log2-intrinsic) intrinsic.
pub fn call_float_log2<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.log2.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.fabs`](https://llvm.org/docs/LangRef.html#llvm-fabs-intrinsic) intrinsic.
pub fn call_float_fabs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_src_t = src.get_type();
let fn_name = format!("llvm.fabs.{}", get_float_intrinsic_repr(ctx.ctx, llvm_src_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_src_t.fn_type(&[llvm_src_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[src.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.minnum`](https://llvm.org/docs/LangRef.html#llvm-minnum-intrinsic) intrinsic.
pub fn call_float_minnum<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val1: FloatValue<'ctx>,
val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!(val1.get_type(), val2.get_type());
let llvm_float_t = val1.get_type();
let fn_name = format!("llvm.minnum.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into(), llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val1.into(), val2.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.maxnum`](https://llvm.org/docs/LangRef.html#llvm-maxnum-intrinsic) intrinsic.
pub fn call_float_maxnum<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val1: FloatValue<'ctx>,
val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!(val1.get_type(), val2.get_type());
let llvm_float_t = val1.get_type();
let fn_name = format!("llvm.maxnum.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into(), llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val1.into(), val2.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.copysign`](https://llvm.org/docs/LangRef.html#llvm-copysign-intrinsic) intrinsic.
pub fn call_float_copysign<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
mag: FloatValue<'ctx>,
sgn: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!(mag.get_type(), sgn.get_type());
let llvm_float_t = mag.get_type();
let fn_name = format!("llvm.copysign.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into(), llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[mag.into(), sgn.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.floor`](https://llvm.org/docs/LangRef.html#llvm-floor-intrinsic) intrinsic.
pub fn call_float_floor<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.floor.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.ceil`](https://llvm.org/docs/LangRef.html#llvm-ceil-intrinsic) intrinsic.
pub fn call_float_ceil<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.ceil.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.round`](https://llvm.org/docs/LangRef.html#llvm-round-intrinsic) intrinsic.
pub fn call_float_round<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.round.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the
/// [`llvm.roundeven`](https://llvm.org/docs/LangRef.html#llvm-roundeven-intrinsic) intrinsic.
pub fn call_float_roundeven<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.roundeven.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.expect`](https://llvm.org/docs/LangRef.html#llvm-expect-intrinsic) intrinsic.
pub fn call_expect<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: IntValue<'ctx>,
expected_val: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(val.get_type().get_bit_width(), expected_val.get_type().get_bit_width());
let llvm_int_t = val.get_type();
let fn_name = format!("llvm.expect.i{}", llvm_int_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into(), expected_val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,427 +0,0 @@
use crate::{
codegen::{
concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenLLVMOptions,
CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::{ComposerConfig, TopLevelComposer}, DefinitionId, FunInstance, TopLevelContext,
TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
};
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel
};
use nac3parser::{
ast::{fold::Fold, StrRef},
parser::parse_program,
};
use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
struct Resolver {
id_to_type: HashMap<StrRef, Type>,
id_to_def: RwLock<HashMap<StrRef, DefinitionId>>,
class_names: HashMap<StrRef, Type>,
}
impl Resolver {
pub fn add_id_def(&self, id: StrRef, def: DefinitionId) {
self.id_to_def.write().insert(id, def);
}
}
impl SymbolResolver for Resolver {
fn get_default_param_value(
&self,
_: &nac3parser::ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!()
}
fn get_symbol_type(
&self,
_: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>],
_: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String> {
self.id_to_type.get(&str).cloned().ok_or_else(|| format!("cannot find symbol `{}`", str))
}
fn get_symbol_value<'ctx, 'a>(
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, 'a>,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.id_to_def
.read()
.get(&id)
.cloned()
.ok_or_else(|| HashSet::from([
format!("cannot find symbol `{}`", id),
]))
}
fn get_string_id(&self, _: &str) -> i32 {
unimplemented!()
}
fn get_exception_id(&self, _tyid: usize) -> usize {
unimplemented!()
}
}
#[test]
fn test_primitives() {
let source = indoc! { "
c = a + b
d = a if c == 1 else 0
return d
"};
let statements = parse_program(source, Default::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone());
let resolver = Arc::new(Resolver {
id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature {
args: vec![
FuncArg { name: "a".into(), ty: primitives.int32, default_value: None },
FuncArg { name: "b".into(), ty: primitives.int32, default_value: None },
],
ret: primitives.int32,
vars: VarMap::new(),
};
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature = store.from_signature(&mut unifier, &primitives, &signature, &mut cache);
let signature = store.add_cty(signature);
let mut function_data = FunctionData {
resolver: resolver.clone(),
bound_variables: Vec::new(),
return_type: Some(primitives.int32),
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".into(), "b".into()].iter().cloned().collect();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
unifier: &mut unifier,
variable_mapping: Default::default(),
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: identifiers.clone(),
in_handler: false,
};
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])),
personality_symbol: None,
});
let task = CodeGenTask {
subst: Default::default(),
symbol_name: "testing".into(),
body: Arc::new(statements),
unifier_index: 0,
calls: Arc::new(calls),
resolver,
store,
signature,
id: 0,
};
let f = Arc::new(WithCall::new(Box::new(|module| {
// the following IR is equivalent to
// ```
// ; ModuleID = 'test.ll'
// source_filename = "test"
//
// ; Function Attrs: norecurse nounwind readnone
// define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
// init:
// %add = add i32 %1, %0
// %cmp = icmp eq i32 %add, 1
// %ifexpr = select i1 %cmp, i32 %0, i32 0
// ret i32 %ifexpr
// }
//
// attributes #0 = { norecurse nounwind readnone }
// ```
// after O2 optimization
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
init:
%add = add i32 %1, %0, !dbg !9
%cmp = icmp eq i32 %add, 1, !dbg !10
%. = select i1 %cmp, i32 %0, i32 0, !dbg !11
ret i32 %., !dbg !12
}
attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }
!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2}
!0 = !{i32 2, !\"Debug Info Version\", i32 3}
!1 = !{i32 2, !\"Dwarf Version\", i32 4}
!2 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
!3 = !DIFile(filename: \"unknown\", directory: \"\")
!4 = distinct !DISubprogram(name: \"testing\", linkageName: \"testing\", scope: null, file: !3, line: 1, type: !5, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !8)
!5 = !DISubroutineType(flags: DIFlagPublic, types: !6)
!6 = !{!7}
!7 = !DIBasicType(name: \"_\", flags: DIFlagPublic)
!8 = !{}
!9 = !DILocation(line: 1, column: 9, scope: !4)
!10 = !DILocation(line: 2, column: 15, scope: !4)
!11 = !DILocation(line: 0, scope: !4)
!12 = !DILocation(line: 3, column: 8, scope: !4)
"}
.trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
})));
Target::initialize_all(&InitializationConfig::default());
let llvm_options = CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(),
};
let (registry, handles) = WorkerRegistry::create_workers(
threads,
top_level,
&llvm_options,
&f
);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}
#[test]
fn test_simple_call() {
let source_1 = indoc! { "
a = foo(a)
return a * 2
"};
let statements_1 = parse_program(source_1, Default::default()).unwrap();
let source_2 = indoc! { "
return a + 1
"};
let statements_2 = parse_program(source_2, Default::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone());
let signature = FunSignature {
args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }],
ret: primitives.int32,
vars: VarMap::new(),
};
let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone()));
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature = store.from_signature(&mut unifier, &primitives, &signature, &mut cache);
let signature = store.add_cty(signature);
let foo_id = top_level.definitions.read().len();
top_level.definitions.write().push(Arc::new(RwLock::new(TopLevelDef::Function {
name: "foo".to_string(),
simple_name: "foo".into(),
signature: fun_ty,
var_id: vec![],
instance_to_stmt: HashMap::new(),
instance_to_symbol: HashMap::new(),
resolver: None,
codegen_callback: None,
loc: None,
})));
let resolver = Resolver {
id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(),
};
resolver.add_id_def("foo".into(), DefinitionId(foo_id));
let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>;
if let TopLevelDef::Function { resolver: r, .. } =
&mut *top_level.definitions.read()[foo_id].write()
{
*r = Some(resolver.clone());
} else {
unreachable!()
}
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let mut function_data = FunctionData {
resolver: resolver.clone(),
bound_variables: Vec::new(),
return_type: Some(primitives.int32),
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].iter().cloned().collect();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
unifier: &mut unifier,
variable_mapping: Default::default(),
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: identifiers.clone(),
in_handler: false,
};
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("foo".into(), fun_ty);
let statements_1 = statements_1
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let calls1 = inferencer.calls.clone();
inferencer.calls.clear();
let statements_2 = statements_2
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
if let TopLevelDef::Function { instance_to_stmt, .. } =
&mut *top_level.definitions.read()[foo_id].write()
{
instance_to_stmt.insert(
"".to_string(),
FunInstance {
body: Arc::new(statements_2),
calls: Arc::new(inferencer.calls.clone()),
subst: Default::default(),
unifier_id: 0,
},
);
} else {
unreachable!()
}
inferencer.check_block(&statements_1, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])),
personality_symbol: None,
});
let task = CodeGenTask {
subst: Default::default(),
symbol_name: "testing".to_string(),
body: Arc::new(statements_1),
calls: Arc::new(calls1),
unifier_index: 0,
resolver,
signature,
store,
id: 0,
};
let f = Arc::new(WithCall::new(Box::new(|module| {
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {
init:
%add.i = shl i32 %0, 1, !dbg !10
%mul = add i32 %add.i, 2, !dbg !10
ret i32 %mul, !dbg !10
}
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @foo.0(i32 %0) local_unnamed_addr #0 !dbg !11 {
init:
%add = add i32 %0, 1, !dbg !12
ret i32 %add, !dbg !12
}
attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }
!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2, !4}
!0 = !{i32 2, !\"Debug Info Version\", i32 3}
!1 = !{i32 2, !\"Dwarf Version\", i32 4}
!2 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
!3 = !DIFile(filename: \"unknown\", directory: \"\")
!4 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
!5 = distinct !DISubprogram(name: \"testing\", linkageName: \"testing\", scope: null, file: !3, line: 1, type: !6, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !9)
!6 = !DISubroutineType(flags: DIFlagPublic, types: !7)
!7 = !{!8}
!8 = !DIBasicType(name: \"_\", flags: DIFlagPublic)
!9 = !{}
!10 = !DILocation(line: 2, column: 12, scope: !5)
!11 = distinct !DISubprogram(name: \"foo.0\", linkageName: \"foo.0\", scope: null, file: !3, line: 1, type: !6, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !4, retainedNodes: !9)
!12 = !DILocation(line: 1, column: 12, scope: !11)
"}
.trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
})));
Target::initialize_all(&InitializationConfig::default());
let llvm_options = CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(),
};
let (registry, handles) = WorkerRegistry::create_workers(
threads,
top_level,
&llvm_options,
&f
);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}

View File

@ -1,7 +1,588 @@
#![warn(clippy::all)]
#![allow(dead_code)]
#![allow(clippy::clone_double_ref)]
pub mod codegen;
pub mod symbol_resolver;
pub mod toplevel;
pub mod typecheck;
extern crate num_bigint;
extern crate inkwell;
extern crate rustpython_parser;
pub mod type_check;
use std::error::Error;
use std::fmt;
use std::path::Path;
use std::collections::HashMap;
use num_traits::cast::ToPrimitive;
use rustpython_parser::ast;
use inkwell::OptimizationLevel;
use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::targets::*;
use inkwell::types;
use inkwell::types::BasicType;
use inkwell::values;
use inkwell::{IntPredicate, FloatPredicate};
use inkwell::basic_block;
use inkwell::passes;
#[derive(Debug)]
enum CompileErrorKind {
Unsupported(&'static str),
MissingTypeAnnotation,
UnknownTypeAnnotation,
IncompatibleTypes,
UnboundIdentifier,
BreakOutsideLoop,
Internal(&'static str)
}
impl fmt::Display for CompileErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CompileErrorKind::Unsupported(feature)
=> write!(f, "The following Python feature is not supported by NAC3: {}", feature),
CompileErrorKind::MissingTypeAnnotation
=> write!(f, "Missing type annotation"),
CompileErrorKind::UnknownTypeAnnotation
=> write!(f, "Unknown type annotation"),
CompileErrorKind::IncompatibleTypes
=> write!(f, "Incompatible types"),
CompileErrorKind::UnboundIdentifier
=> write!(f, "Unbound identifier"),
CompileErrorKind::BreakOutsideLoop
=> write!(f, "Break outside loop"),
CompileErrorKind::Internal(details)
=> write!(f, "Internal compiler error: {}", details),
}
}
}
#[derive(Debug)]
pub struct CompileError {
location: ast::Location,
kind: CompileErrorKind,
}
impl fmt::Display for CompileError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}, at {}", self.kind, self.location)
}
}
impl Error for CompileError {}
type CompileResult<T> = Result<T, CompileError>;
pub struct CodeGen<'ctx> {
context: &'ctx Context,
module: Module<'ctx>,
pass_manager: passes::PassManager<values::FunctionValue<'ctx>>,
builder: Builder<'ctx>,
current_source_location: ast::Location,
namespace: HashMap<String, values::PointerValue<'ctx>>,
break_bb: Option<basic_block::BasicBlock<'ctx>>,
}
impl<'ctx> CodeGen<'ctx> {
pub fn new(context: &'ctx Context) -> CodeGen<'ctx> {
let module = context.create_module("kernel");
let pass_manager = passes::PassManager::create(&module);
pass_manager.add_instruction_combining_pass();
pass_manager.add_reassociate_pass();
pass_manager.add_gvn_pass();
pass_manager.add_cfg_simplification_pass();
pass_manager.add_basic_alias_analysis_pass();
pass_manager.add_promote_memory_to_register_pass();
pass_manager.add_instruction_combining_pass();
pass_manager.add_reassociate_pass();
pass_manager.initialize();
let i32_type = context.i32_type();
let fn_type = i32_type.fn_type(&[i32_type.into()], false);
module.add_function("output", fn_type, None);
CodeGen {
context, module, pass_manager,
builder: context.create_builder(),
current_source_location: ast::Location::default(),
namespace: HashMap::new(),
break_bb: None,
}
}
fn set_source_location(&mut self, location: ast::Location) {
self.current_source_location = location;
}
fn compile_error(&self, kind: CompileErrorKind) -> CompileError {
CompileError {
location: self.current_source_location,
kind
}
}
fn get_basic_type(&self, name: &str) -> CompileResult<types::BasicTypeEnum<'ctx>> {
match name {
"bool" => Ok(self.context.bool_type().into()),
"int32" => Ok(self.context.i32_type().into()),
"int64" => Ok(self.context.i64_type().into()),
"float32" => Ok(self.context.f32_type().into()),
"float64" => Ok(self.context.f64_type().into()),
_ => Err(self.compile_error(CompileErrorKind::UnknownTypeAnnotation))
}
}
fn compile_function_def(
&mut self,
name: &str,
args: &ast::Parameters,
body: &ast::Suite,
decorator_list: &[ast::Expression],
returns: &Option<ast::Expression>,
is_async: bool,
) -> CompileResult<values::FunctionValue<'ctx>> {
if is_async {
return Err(self.compile_error(CompileErrorKind::Unsupported("async functions")))
}
for decorator in decorator_list.iter() {
self.set_source_location(decorator.location);
if let ast::ExpressionType::Identifier { name } = &decorator.node {
if name != "kernel" && name != "portable" {
return Err(self.compile_error(CompileErrorKind::Unsupported("custom decorators")))
}
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("decorator must be an identifier")))
}
}
let args_type = args.args.iter().map(|val| {
self.set_source_location(val.location);
if let Some(annotation) = &val.annotation {
if let ast::ExpressionType::Identifier { name } = &annotation.node {
Ok(self.get_basic_type(&name)?)
} else {
Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
}
} else {
Err(self.compile_error(CompileErrorKind::MissingTypeAnnotation))
}
}).collect::<CompileResult<Vec<types::BasicTypeEnum>>>()?;
let return_type = if let Some(returns) = returns {
self.set_source_location(returns.location);
if let ast::ExpressionType::Identifier { name } = &returns.node {
if name == "None" { None } else { Some(self.get_basic_type(name)?) }
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("type annotation must be an identifier")))
}
} else {
None
};
let fn_type = match return_type {
Some(ty) => ty.fn_type(&args_type, false),
None => self.context.void_type().fn_type(&args_type, false)
};
let function = self.module.add_function(name, fn_type, None);
let basic_block = self.context.append_basic_block(function, "entry");
self.builder.position_at_end(basic_block);
for (n, arg) in args.args.iter().enumerate() {
let param = function.get_nth_param(n as u32).unwrap();
let alloca = self.builder.build_alloca(param.get_type(), &arg.arg);
self.builder.build_store(alloca, param);
self.namespace.insert(arg.arg.clone(), alloca);
}
self.compile_suite(body, return_type)?;
Ok(function)
}
fn compile_expression(
&mut self,
expression: &ast::Expression
) -> CompileResult<values::BasicValueEnum<'ctx>> {
self.set_source_location(expression.location);
match &expression.node {
ast::ExpressionType::True => Ok(self.context.bool_type().const_int(1, false).into()),
ast::ExpressionType::False => Ok(self.context.bool_type().const_int(0, false).into()),
ast::ExpressionType::Number { value: ast::Number::Integer { value } } => {
let mut bits = value.bits();
if value.sign() == num_bigint::Sign::Minus {
bits += 1;
}
match bits {
0..=32 => Ok(self.context.i32_type().const_int(value.to_i32().unwrap() as _, true).into()),
33..=64 => Ok(self.context.i64_type().const_int(value.to_i64().unwrap() as _, true).into()),
_ => Err(self.compile_error(CompileErrorKind::Unsupported("integers larger than 64 bits")))
}
},
ast::ExpressionType::Number { value: ast::Number::Float { value } } => {
Ok(self.context.f64_type().const_float(*value).into())
},
ast::ExpressionType::Identifier { name } => {
match self.namespace.get(name) {
Some(value) => Ok(self.builder.build_load(*value, name).into()),
None => Err(self.compile_error(CompileErrorKind::UnboundIdentifier))
}
},
ast::ExpressionType::Unop { op, a } => {
let a = self.compile_expression(&a)?;
match (op, a) {
(ast::UnaryOperator::Pos, values::BasicValueEnum::IntValue(a))
=> Ok(a.into()),
(ast::UnaryOperator::Pos, values::BasicValueEnum::FloatValue(a))
=> Ok(a.into()),
(ast::UnaryOperator::Neg, values::BasicValueEnum::IntValue(a))
=> Ok(self.builder.build_int_neg(a, "tmpneg").into()),
(ast::UnaryOperator::Neg, values::BasicValueEnum::FloatValue(a))
=> Ok(self.builder.build_float_neg(a, "tmpneg").into()),
(ast::UnaryOperator::Inv, values::BasicValueEnum::IntValue(a))
=> Ok(self.builder.build_not(a, "tmpnot").into()),
(ast::UnaryOperator::Not, values::BasicValueEnum::IntValue(a)) => {
// boolean "not"
if a.get_type().get_bit_width() != 1 {
Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation")))
} else {
Ok(self.builder.build_not(a, "tmpnot").into())
}
},
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented unary operation"))),
}
},
ast::ExpressionType::Binop { a, op, b } => {
let a = self.compile_expression(&a)?;
let b = self.compile_expression(&b)?;
if a.get_type() != b.get_type() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
use ast::Operator::*;
match (op, a, b) {
(Add, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_add(a, b, "tmpadd").into()),
(Sub, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_sub(a, b, "tmpsub").into()),
(Mult, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_mul(a, b, "tmpmul").into()),
(Add, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_add(a, b, "tmpadd").into()),
(Sub, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_sub(a, b, "tmpsub").into()),
(Mult, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_mul(a, b, "tmpmul").into()),
(Div, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
=> Ok(self.builder.build_float_div(a, b, "tmpdiv").into()),
(FloorDiv, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
=> Ok(self.builder.build_int_signed_div(a, b, "tmpdiv").into()),
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented binary operation"))),
}
},
ast::ExpressionType::Compare { vals, ops } => {
let mut vals = vals.iter();
let mut ops = ops.iter();
let mut result = None;
let mut a = self.compile_expression(vals.next().unwrap())?;
loop {
if let Some(op) = ops.next() {
let b = self.compile_expression(vals.next().unwrap())?;
if a.get_type() != b.get_type() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let this_result = match (a, b) {
(values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) => {
match op {
ast::Comparison::Equal
=> self.builder.build_int_compare(IntPredicate::EQ, a, b, "tmpeq"),
ast::Comparison::NotEqual
=> self.builder.build_int_compare(IntPredicate::NE, a, b, "tmpne"),
ast::Comparison::Less
=> self.builder.build_int_compare(IntPredicate::SLT, a, b, "tmpslt"),
ast::Comparison::LessOrEqual
=> self.builder.build_int_compare(IntPredicate::SLE, a, b, "tmpsle"),
ast::Comparison::Greater
=> self.builder.build_int_compare(IntPredicate::SGT, a, b, "tmpsgt"),
ast::Comparison::GreaterOrEqual
=> self.builder.build_int_compare(IntPredicate::SGE, a, b, "tmpsge"),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
}
},
(values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) => {
match op {
ast::Comparison::Equal
=> self.builder.build_float_compare(FloatPredicate::OEQ, a, b, "tmpoeq"),
ast::Comparison::NotEqual
=> self.builder.build_float_compare(FloatPredicate::UNE, a, b, "tmpune"),
ast::Comparison::Less
=> self.builder.build_float_compare(FloatPredicate::OLT, a, b, "tmpolt"),
ast::Comparison::LessOrEqual
=> self.builder.build_float_compare(FloatPredicate::OLE, a, b, "tmpole"),
ast::Comparison::Greater
=> self.builder.build_float_compare(FloatPredicate::OGT, a, b, "tmpogt"),
ast::Comparison::GreaterOrEqual
=> self.builder.build_float_compare(FloatPredicate::OGE, a, b, "tmpoge"),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special comparison"))),
}
},
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("comparison of non-numerical types"))),
};
match result {
Some(last) => {
result = Some(self.builder.build_and(last, this_result, "tmpand"));
}
None => {
result = Some(this_result);
}
}
a = b;
} else {
return Ok(result.unwrap().into())
}
}
},
ast::ExpressionType::Call { function, args, keywords } => {
if !keywords.is_empty() {
return Err(self.compile_error(CompileErrorKind::Unsupported("keyword arguments")))
}
let args = args.iter().map(|val| self.compile_expression(val))
.collect::<CompileResult<Vec<values::BasicValueEnum>>>()?;
self.set_source_location(expression.location);
if let ast::ExpressionType::Identifier { name } = &function.node {
match (name.as_str(), args[0]) {
("int32", values::BasicValueEnum::IntValue(a)) => {
let nbits = a.get_type().get_bit_width();
if nbits < 32 {
Ok(self.builder.build_int_s_extend(a, self.context.i32_type(), "tmpsext").into())
} else if nbits > 32 {
Ok(self.builder.build_int_truncate(a, self.context.i32_type(), "tmptrunc").into())
} else {
Ok(a.into())
}
},
("int64", values::BasicValueEnum::IntValue(a)) => {
let nbits = a.get_type().get_bit_width();
if nbits < 64 {
Ok(self.builder.build_int_s_extend(a, self.context.i64_type(), "tmpsext").into())
} else {
Ok(a.into())
}
},
("int32", values::BasicValueEnum::FloatValue(a)) => {
Ok(self.builder.build_float_to_signed_int(a, self.context.i32_type(), "tmpfptosi").into())
},
("int64", values::BasicValueEnum::FloatValue(a)) => {
Ok(self.builder.build_float_to_signed_int(a, self.context.i64_type(), "tmpfptosi").into())
},
("float32", values::BasicValueEnum::IntValue(a)) => {
Ok(self.builder.build_signed_int_to_float(a, self.context.f32_type(), "tmpsitofp").into())
},
("float64", values::BasicValueEnum::IntValue(a)) => {
Ok(self.builder.build_signed_int_to_float(a, self.context.f64_type(), "tmpsitofp").into())
},
("float32", values::BasicValueEnum::FloatValue(a)) => {
if a.get_type() == self.context.f64_type() {
Ok(self.builder.build_float_trunc(a, self.context.f32_type(), "tmptrunc").into())
} else {
Ok(a.into())
}
},
("float64", values::BasicValueEnum::FloatValue(a)) => {
if a.get_type() == self.context.f32_type() {
Ok(self.builder.build_float_ext(a, self.context.f64_type(), "tmpext").into())
} else {
Ok(a.into())
}
},
("output", values::BasicValueEnum::IntValue(a)) => {
let fn_value = self.module.get_function("output").unwrap();
Ok(self.builder.build_call(fn_value, &[a.into()], "call")
.try_as_basic_value().left().unwrap())
},
_ => Err(self.compile_error(CompileErrorKind::Unsupported("unrecognized call")))
}
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("function must be an identifier")))
}
},
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented expression"))),
}
}
fn compile_statement(
&mut self,
statement: &ast::Statement,
return_type: Option<types::BasicTypeEnum>
) -> CompileResult<()> {
self.set_source_location(statement.location);
use ast::StatementType::*;
match &statement.node {
Assign { targets, value } => {
let value = self.compile_expression(value)?;
for target in targets.iter() {
self.set_source_location(target.location);
if let ast::ExpressionType::Identifier { name } = &target.node {
let builder = &self.builder;
let target = self.namespace.entry(name.clone()).or_insert_with(
|| builder.build_alloca(value.get_type(), name));
if target.get_type() != value.get_type().ptr_type(inkwell::AddressSpace::Generic) {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
builder.build_store(*target, value);
} else {
return Err(self.compile_error(CompileErrorKind::Unsupported("assignment target must be an identifier")))
}
}
},
Expression { expression } => { self.compile_expression(expression)?; },
If { test, body, orelse } => {
let test = self.compile_expression(test)?;
if test.get_type() != self.context.bool_type().into() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_bb = self.context.append_basic_block(parent, "then");
let else_bb = self.context.append_basic_block(parent, "else");
let cont_bb = self.context.append_basic_block(parent, "ifcont");
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
self.builder.position_at_end(then_bb);
self.compile_suite(body, return_type)?;
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(else_bb);
if let Some(orelse) = orelse {
self.compile_suite(orelse, return_type)?;
}
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
},
While { test, body, orelse } => {
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let test_bb = self.context.append_basic_block(parent, "test");
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(test_bb);
let test = self.compile_expression(test)?;
if test.get_type() != self.context.bool_type().into() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
let then_bb = self.context.append_basic_block(parent, "then");
let else_bb = self.context.append_basic_block(parent, "else");
let cont_bb = self.context.append_basic_block(parent, "ifcont");
self.builder.build_conditional_branch(test.into_int_value(), then_bb, else_bb);
self.break_bb = Some(cont_bb);
self.builder.position_at_end(then_bb);
self.compile_suite(body, return_type)?;
self.builder.build_unconditional_branch(test_bb);
self.builder.position_at_end(else_bb);
if let Some(orelse) = orelse {
self.compile_suite(orelse, return_type)?;
}
self.builder.build_unconditional_branch(cont_bb);
self.builder.position_at_end(cont_bb);
self.break_bb = None;
},
Break => {
if let Some(bb) = self.break_bb {
self.builder.build_unconditional_branch(bb);
let parent = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let unreachable_bb = self.context.append_basic_block(parent, "unreachable");
self.builder.position_at_end(unreachable_bb);
} else {
return Err(self.compile_error(CompileErrorKind::BreakOutsideLoop));
}
}
Return { value: Some(value) } => {
if let Some(return_type) = return_type {
let value = self.compile_expression(value)?;
if value.get_type() != return_type {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
self.builder.build_return(Some(&value));
} else {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
},
Return { value: None } => {
if !return_type.is_none() {
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
}
self.builder.build_return(None);
},
Pass => (),
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special statement"))),
}
Ok(())
}
fn compile_suite(
&mut self,
suite: &ast::Suite,
return_type: Option<types::BasicTypeEnum>
) -> CompileResult<()> {
for statement in suite.iter() {
self.compile_statement(statement, return_type)?;
}
Ok(())
}
pub fn compile_toplevel(&mut self, statement: &ast::Statement) -> CompileResult<()> {
self.set_source_location(statement.location);
if let ast::StatementType::FunctionDef {
is_async,
name,
args,
body,
decorator_list,
returns,
} = &statement.node {
let function = self.compile_function_def(name, args, body, decorator_list, returns, *is_async)?;
self.pass_manager.run_on(&function);
Ok(())
} else {
Err(self.compile_error(CompileErrorKind::Internal("top-level is not a function definition")))
}
}
pub fn print_ir(&self) {
self.module.print_to_stderr();
}
pub fn output(&self, filename: &str) {
//let triple = TargetTriple::create("riscv32-none-linux-gnu");
let triple = TargetMachine::get_default_triple();
let target = Target::from_triple(&triple)
.expect("couldn't create target from target triple");
let target_machine = target
.create_target_machine(
&triple,
"",
"",
OptimizationLevel::Default,
RelocMode::Default,
CodeModel::Default,
)
.expect("couldn't create target machine");
target_machine
.write_to_file(&self.module, FileType::Object, Path::new(filename))
.expect("couldn't write module to file");
}
}

View File

@ -1,635 +0,0 @@
use std::fmt::Debug;
use std::sync::Arc;
use std::{collections::HashMap, collections::HashSet, fmt::Display};
use std::rc::Rc;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, Itertools, izip};
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)]
pub enum SymbolValue {
I32(i32),
I64(i64),
U32(u32),
U64(u64),
Str(String),
Double(f64),
Bool(bool),
Tuple(Vec<SymbolValue>),
OptionSome(Box<SymbolValue>),
OptionNone,
}
impl SymbolValue {
/// Creates a [`SymbolValue`] from a [`Constant`].
///
/// * `constant` - The constant to create the value from.
/// * `expected_ty` - The expected type of the [`SymbolValue`].
pub fn from_constant(
constant: &Constant,
expected_ty: Type,
primitives: &PrimitiveStore,
unifier: &mut Unifier
) -> Result<Self, String> {
match constant {
Constant::None => {
if unifier.unioned(expected_ty, primitives.option) {
Ok(SymbolValue::OptionNone)
} else {
Err(format!("Expected {expected_ty:?}, but got Option"))
}
}
Constant::Bool(b) => {
if unifier.unioned(expected_ty, primitives.bool) {
Ok(SymbolValue::Bool(*b))
} else {
Err(format!("Expected {expected_ty:?}, but got bool"))
}
}
Constant::Str(s) => {
if unifier.unioned(expected_ty, primitives.str) {
Ok(SymbolValue::Str(s.to_string()))
} else {
Err(format!("Expected {expected_ty:?}, but got str"))
}
},
Constant::Int(i) => {
if unifier.unioned(expected_ty, primitives.int32) {
i32::try_from(*i)
.map(SymbolValue::I32)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.int64) {
i64::try_from(*i)
.map(SymbolValue::I64)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint32) {
u32::try_from(*i)
.map(SymbolValue::U32)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint64) {
u64::try_from(*i)
.map(SymbolValue::U64)
.map_err(|e| e.to_string())
} else {
Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
}
}
Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
return Err(format!("Expected {:?}, but got Tuple", expected_ty.get_type_name()))
};
assert_eq!(ty.len(), t.len());
let elems = t
.iter()
.zip(ty)
.map(|(constant, ty)| Self::from_constant(constant, *ty, primitives, unifier))
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => {
if unifier.unioned(expected_ty, primitives.float) {
Ok(SymbolValue::Double(*f))
} else {
Err(format!("Expected {expected_ty:?}, but got float"))
}
},
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
/// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value.
///
/// * `constant` - The constant to create the value from.
pub fn from_constant_inferred(
constant: &Constant,
) -> Result<Self, String> {
match constant {
Constant::None => Ok(SymbolValue::OptionNone),
Constant::Bool(b) => Ok(SymbolValue::Bool(*b)),
Constant::Str(s) => Ok(SymbolValue::Str(s.to_string())),
Constant::Int(i) => {
let i = *i;
if i >= 0 {
i32::try_from(i).map(SymbolValue::I32)
.or_else(|_| i64::try_from(i).map(SymbolValue::I64))
.map_err(|_| format!("Literal cannot be expressed as any integral type: {i}"))
} else {
u32::try_from(i).map(SymbolValue::U32)
.or_else(|_| u64::try_from(i).map(SymbolValue::U64))
.map_err(|_| format!("Literal cannot be expressed as any integral type: {i}"))
}
}
Constant::Tuple(t) => {
let elems = t
.iter()
.map(Self::from_constant_inferred)
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => Ok(SymbolValue::Double(*f)),
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
/// Returns the [`Type`] representing the data type of this value.
pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type {
match self {
SymbolValue::I32(_) => primitives.int32,
SymbolValue::I64(_) => primitives.int64,
SymbolValue::U32(_) => primitives.uint32,
SymbolValue::U64(_) => primitives.uint64,
SymbolValue::Str(_) => primitives.str,
SymbolValue::Double(_) => primitives.float,
SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => {
let vs_tys = vs
.iter()
.map(|v| v.get_type(primitives, unifier))
.collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple {
ty: vs_tys,
})
}
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
}
}
/// Returns the [`TypeAnnotation`] representing the data type of this value.
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation {
match self {
SymbolValue::Bool(..)
| SymbolValue::Double(..)
| SymbolValue::I32(..)
| SymbolValue::I64(..)
| SymbolValue::U32(..)
| SymbolValue::U64(..)
| SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
SymbolValue::Tuple(vs) => {
let vs_tys = vs
.iter()
.map(|v| v.get_type_annotation(primitives, unifier))
.collect::<Vec<_>>();
TypeAnnotation::Tuple(vs_tys)
}
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(),
params: Vec::default(),
},
SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(),
params: vec![ty],
}
}
}
}
/// Returns the [`TypeEnum`] representing the data type of this value.
pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc<TypeEnum> {
let ty = self.get_type(primitives, unifier);
unifier.get_ty(ty)
}
}
impl Display for SymbolValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SymbolValue::I32(i) => write!(f, "{i}"),
SymbolValue::I64(i) => write!(f, "int64({i})"),
SymbolValue::U32(i) => write!(f, "uint32({i})"),
SymbolValue::U64(i) => write!(f, "uint64({i})"),
SymbolValue::Str(s) => write!(f, "\"{s}\""),
SymbolValue::Double(d) => write!(f, "{d}"),
SymbolValue::Bool(b) => {
if *b {
write!(f, "True")
} else {
write!(f, "False")
}
}
SymbolValue::Tuple(t) => {
write!(f, "({})", t.iter().map(|v| format!("{v}")).collect::<Vec<_>>().join(", "))
}
SymbolValue::OptionSome(v) => write!(f, "Some({v})"),
SymbolValue::OptionNone => write!(f, "none"),
}
}
}
impl TryFrom<SymbolValue> for u64 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not
/// numeric or if the value cannot be converted into a `u64` without overflow.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::U32(v) => Ok(v as u64),
SymbolValue::U64(v) => Ok(v),
_ => Err(()),
}
}
}
impl TryFrom<SymbolValue> for i128 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not
/// numeric.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => Ok(v as i128),
SymbolValue::I64(v) => Ok(v as i128),
SymbolValue::U32(v) => Ok(v as i128),
SymbolValue::U64(v) => Ok(v as i128),
_ => Err(()),
}
}
}
pub trait StaticValue {
/// Returns a unique identifier for this value.
fn get_unique_identifier(&self) -> u64;
/// Returns the constant object represented by this unique identifier.
fn get_const_obj<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx>;
/// Converts this value to a LLVM [`BasicValueEnum`].
fn to_basic_value_enum<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
expected_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String>;
/// Returns a field within this value.
fn get_field<'ctx>(
&self,
name: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>>;
/// Returns a single element of this tuple.
fn get_tuple_element<'ctx>(&self, index: u32) -> Option<ValueEnum<'ctx>>;
}
#[derive(Clone)]
pub enum ValueEnum<'ctx> {
/// [ValueEnum] representing a static value.
Static(Arc<dyn StaticValue + Send + Sync>),
/// [ValueEnum] representing a dynamic value.
Dynamic(BasicValueEnum<'ctx>),
}
impl<'ctx> From<BasicValueEnum<'ctx>> for ValueEnum<'ctx> {
fn from(v: BasicValueEnum<'ctx>) -> Self {
ValueEnum::Dynamic(v)
}
}
impl<'ctx> From<PointerValue<'ctx>> for ValueEnum<'ctx> {
fn from(v: PointerValue<'ctx>) -> Self {
ValueEnum::Dynamic(v.into())
}
}
impl<'ctx> From<IntValue<'ctx>> for ValueEnum<'ctx> {
fn from(v: IntValue<'ctx>) -> Self {
ValueEnum::Dynamic(v.into())
}
}
impl<'ctx> From<FloatValue<'ctx>> for ValueEnum<'ctx> {
fn from(v: FloatValue<'ctx>) -> Self {
ValueEnum::Dynamic(v.into())
}
}
impl<'ctx> From<StructValue<'ctx>> for ValueEnum<'ctx> {
fn from(v: StructValue<'ctx>) -> Self {
ValueEnum::Dynamic(v.into())
}
}
impl<'ctx> ValueEnum<'ctx> {
/// Converts this [`ValueEnum`] to a [`BasicValueEnum`].
pub fn to_basic_value_enum<'a>(
self,
ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator,
expected_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> {
match self {
ValueEnum::Static(v) => v.to_basic_value_enum(ctx, generator, expected_ty),
ValueEnum::Dynamic(v) => Ok(v),
}
}
}
pub trait SymbolResolver {
/// Get type of type variable identifier or top-level function type,
fn get_symbol_type(
&self,
unifier: &mut Unifier,
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String>;
/// Get the top-level definition of identifiers.
fn get_identifier_def(&self, str: StrRef) -> Result<DefinitionId, HashSet<String>>;
fn get_symbol_value<'ctx>(
&self,
str: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>>;
fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;
fn get_string_id(&self, s: &str) -> i32;
fn get_exception_id(&self, tyid: usize) -> usize;
fn handle_deferred_eval(
&self,
_unifier: &mut Unifier,
_top_level_defs: &[Arc<RwLock<TopLevelDef>>],
_primitives: &PrimitiveStore
) -> Result<(), String> {
Ok(())
}
}
thread_local! {
static IDENTIFIER_ID: [StrRef; 12] = [
"int32".into(),
"int64".into(),
"float".into(),
"bool".into(),
"virtual".into(),
"list".into(),
"tuple".into(),
"str".into(),
"Exception".into(),
"uint32".into(),
"uint64".into(),
"Literal".into(),
];
}
/// Converts a type annotation into a [Type].
pub fn parse_type_annotation<T>(
resolver: &dyn SymbolResolver,
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, HashSet<String>> {
use nac3parser::ast::ExprKind::*;
let ids = IDENTIFIER_ID.with(|ids| *ids);
let int32_id = ids[0];
let int64_id = ids[1];
let float_id = ids[2];
let bool_id = ids[3];
let virtual_id = ids[4];
let list_id = ids[5];
let tuple_id = ids[6];
let str_id = ids[7];
let exn_id = ids[8];
let uint32_id = ids[9];
let uint64_id = ids[10];
let literal_id = ids[11];
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
if *id == int32_id {
Ok(primitives.int32)
} else if *id == int64_id {
Ok(primitives.int64)
} else if *id == uint32_id {
Ok(primitives.uint32)
} else if *id == uint64_id {
Ok(primitives.uint64)
} else if *id == float_id {
Ok(primitives.float)
} else if *id == bool_id {
Ok(primitives.bool)
} else if *id == str_id {
Ok(primitives.str)
} else if *id == exn_id {
Ok(primitives.exception)
} else {
let obj_id = resolver.get_identifier_def(*id);
if let Ok(obj_id) = obj_id {
let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if !type_vars.is_empty() {
return Err(HashSet::from([
format!(
"Unexpected number of type parameters: expected {} but got 0",
type_vars.len()
),
]))
}
let fields = chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
)
.collect();
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields,
params: VarMap::default(),
}))
} else {
Err(HashSet::from([
format!("Cannot use function name as type at {loc}"),
]))
}
} else {
let ty = resolver
.get_symbol_type(unifier, top_level_defs, primitives, *id)
.map_err(|e| HashSet::from([
format!("Unknown type annotation at {loc}: {e}"),
]))?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty)
} else {
Err(HashSet::from([
format!("Unknown type annotation {id} at {loc}"),
]))
}
}
}
};
let subscript_name_handle = |id: &StrRef, slice: &Expr<T>, unifier: &mut Unifier| {
if *id == virtual_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else if *id == list_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
} else if *id == tuple_id {
if let Tuple { elts, .. } = &slice.node {
let ty = elts
.iter()
.map(|elt| {
parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else {
Err(HashSet::from([
"Expected multiple elements for tuple".into()
]))
}
} else if *id == literal_id {
let mut parse_literal = |elt: &Expr<T>| {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?;
let ty_enum = &*unifier.get_ty_immutable(ty);
match ty_enum {
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
_ => Err(HashSet::from([
format!("Expected literal in type argument for Literal at {}", elt.location),
]))
}
};
let values = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(&mut parse_literal)
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_literal(slice)?]
}.into_iter().flatten().collect_vec();
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
} else {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|v| {
parse_type_annotation(resolver, top_level_defs, unifier, primitives, v)
})
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?]
};
let obj_id = resolver.get_identifier_def(*id)?;
let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() {
return Err(HashSet::from([
format!(
"Unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
types.len()
),
]))
}
let mut subst = VarMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
} else {
unreachable!()
};
subst.insert(id, *ty);
}
let mut fields = fields
.iter()
.map(|(attr, ty, is_mutable)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, (ty, *is_mutable))
})
.collect::<HashMap<_, _>>();
fields.extend(methods.iter().map(|(attr, ty, _)| {
let ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*attr, (ty, false))
}));
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
} else {
Err(HashSet::from([
"Cannot use function name as type".into(),
]))
}
}
};
match &expr.node {
Name { id, .. } => name_handling(id, expr.location, unifier),
Subscript { value, slice, .. } => {
if let Name { id, .. } = &value.node {
subscript_name_handle(id, slice, unifier)
} else {
Err(HashSet::from([
format!("unsupported type expression at {}", expr.location),
]))
}
}
Constant { value, .. } => SymbolValue::from_constant_inferred(value)
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
.map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([
format!("unsupported type expression at {}", expr.location),
])),
}
}
impl dyn SymbolResolver + Send + Sync {
pub fn parse_type_annotation<T>(
&self,
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &Expr<T>,
) -> Result<Type, HashSet<String>> {
parse_type_annotation(self, top_level_defs, unifier, primitives, expr)
}
pub fn get_type_name(
&self,
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
ty: Type,
) -> String {
unifier.internal_stringify(
ty,
&mut |id| {
let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else {
unreachable!("expected class definition")
};
name.to_string()
},
&mut |id| format!("typevar{id}"),
&mut None,
)
}
}
impl Debug for dyn SymbolResolver + Send + Sync {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "")
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,693 +0,0 @@
use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::typecheck::typedef::{Mapping, VarMap};
use nac3parser::ast::{Constant, Location};
use super::*;
/// Structure storing [`DefinitionId`] for primitive types.
#[derive(Clone, Copy)]
pub struct PrimitiveDefinitionIds {
pub int32: DefinitionId,
pub int64: DefinitionId,
pub uint32: DefinitionId,
pub uint64: DefinitionId,
pub float: DefinitionId,
pub bool: DefinitionId,
pub none: DefinitionId,
pub range: DefinitionId,
pub str: DefinitionId,
pub exception: DefinitionId,
pub option: DefinitionId,
pub ndarray: DefinitionId,
}
impl PrimitiveDefinitionIds {
/// Returns all [`DefinitionId`] of primitives as a [`Vec`].
///
/// There are no guarantees on ordering of the IDs.
#[must_use]
fn as_vec(&self) -> Vec<DefinitionId> {
vec![
self.int32,
self.int64,
self.uint32,
self.uint64,
self.float,
self.bool,
self.none,
self.range,
self.str,
self.exception,
self.option,
self.ndarray,
]
}
/// Returns an iterator over all [`DefinitionId`]s of this instance in indeterminate order.
pub fn iter(&self) -> impl Iterator<Item=DefinitionId> {
self.as_vec().into_iter()
}
/// Returns the primitive with the largest [`DefinitionId`].
#[must_use]
pub fn max_id(&self) -> DefinitionId {
self.iter().max().unwrap()
}
}
/// The [definition IDs][DefinitionId] for primitive types.
pub const PRIMITIVE_DEF_IDS: PrimitiveDefinitionIds = PrimitiveDefinitionIds {
int32: DefinitionId(0),
int64: DefinitionId(1),
uint32: DefinitionId(8),
uint64: DefinitionId(9),
float: DefinitionId(2),
bool: DefinitionId(3),
none: DefinitionId(4),
range: DefinitionId(5),
str: DefinitionId(6),
exception: DefinitionId(7),
option: DefinitionId(10),
ndarray: DefinitionId(14),
};
impl TopLevelDef {
pub fn to_string(&self, unifier: &mut Unifier) -> String {
match self {
TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => {
let fields_str = fields
.iter()
.map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty)))
.collect_vec();
let methods_str = methods
.iter()
.map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id))
.collect_vec();
format!(
"Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}",
name,
ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(),
fields_str.iter().map(|(a, _)| a).collect_vec(),
methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(),
type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(),
)
}
TopLevelDef::Function { name, signature, var_id, .. } => format!(
"Function {{\nname: {:?},\nsig: {:?},\nvar_id: {:?}\n}}",
name,
unifier.stringify(*signature),
{
// preserve the order for debug output and test
let mut r = var_id.clone();
r.sort_unstable();
r
}
),
}
}
}
impl TopLevelComposer {
#[must_use]
pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int32,
fields: HashMap::new(),
params: VarMap::new(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int64,
fields: HashMap::new(),
params: VarMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.float,
fields: HashMap::new(),
params: VarMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.bool,
fields: HashMap::new(),
params: VarMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.none,
fields: HashMap::new(),
params: VarMap::new(),
});
let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.range,
fields: HashMap::new(),
params: VarMap::new(),
});
let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.str,
fields: HashMap::new(),
params: VarMap::new(),
});
let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.exception,
fields: vec![
("__name__".into(), (int32, true)),
("__file__".into(), (str, true)),
("__line__".into(), (int32, true)),
("__col__".into(), (int32, true)),
("__func__".into(), (str, true)),
("__message__".into(), (str, true)),
("__param0__".into(), (int64, true)),
("__param1__".into(), (int64, true)),
("__param2__".into(), (int64, true)),
]
.into_iter()
.collect::<HashMap<_, _>>(),
params: VarMap::new(),
});
let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint32,
fields: HashMap::new(),
params: VarMap::new(),
});
let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint64,
fields: HashMap::new(),
params: VarMap::new(),
});
let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None);
let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: bool,
vars: VarMap::from([(option_type_var.1, option_type_var.0)]),
}));
let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: option_type_var.0,
vars: VarMap::from([(option_type_var.1, option_type_var.0)]),
}));
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.option,
fields: vec![
("is_some".into(), (is_some_type_fun_ty, true)),
("is_none".into(), (is_some_type_fun_ty, true)),
("unwrap".into(), (unwrap_fun_ty, true)),
]
.into_iter()
.collect::<HashMap<_, _>>(),
params: VarMap::from([(option_type_var.1, option_type_var.0)]),
});
let size_t_ty = match size_t {
32 => uint32,
64 => uint64,
_ => unreachable!(),
};
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: ndarray_copy_fun_ret_ty.0,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "value".into(),
ty: ndarray_dtype_tvar.0,
default_value: None,
},
],
ret: none,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
fields: Mapping::from([
("copy".into(), (ndarray_copy_fun_ty, true)),
("fill".into(), (ndarray_fill_fun_ty, true)),
]),
params: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
});
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
let primitives = PrimitiveStore {
int32,
int64,
uint32,
uint64,
float,
bool,
none,
range,
str,
exception,
option,
ndarray,
size_t,
};
unifier.put_primitive_store(&primitives);
crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier);
(primitives, unifier)
}
/// already include the `definition_id` of itself inside the ancestors vector
/// when first registering, the `type_vars`, fields, methods, ancestors are invalid
#[must_use]
pub fn make_top_level_class_def(
obj_id: DefinitionId,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
name: StrRef,
constructor: Option<Type>,
loc: Option<Location>,
) -> TopLevelDef {
TopLevelDef::Class {
name,
object_id: obj_id,
type_vars: Vec::default(),
fields: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
constructor,
resolver,
loc,
}
}
/// when first registering, the type is a invalid value
#[must_use]
pub fn make_top_level_function_def(
name: String,
simple_name: StrRef,
ty: Type,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
loc: Option<Location>,
) -> TopLevelDef {
TopLevelDef::Function {
name,
simple_name,
signature: ty,
var_id: Vec::default(),
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver,
codegen_callback: None,
loc,
}
}
#[must_use]
pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String {
class_name.push('.');
class_name.push_str(method_name);
class_name
}
pub fn get_class_method_def_info(
class_methods_def: &[(StrRef, Type, DefinitionId)],
method_name: StrRef,
) -> Result<(Type, DefinitionId), HashSet<String>> {
for (name, ty, def_id) in class_methods_def {
if name == &method_name {
return Ok((*ty, *def_id));
}
}
Err(HashSet::from([format!("no method {method_name} in the current class")]))
}
/// get all base class def id of a class, excluding itself. \
/// this function should called only after the direct parent is set
/// and before all the ancestors are set
/// and when we allow single inheritance \
/// the order of the returned list is from the child to the deepest ancestor
pub fn get_all_ancestors_helper(
child: &TypeAnnotation,
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
) -> Result<Vec<TypeAnnotation>, HashSet<String>> {
let mut result: Vec<TypeAnnotation> = Vec::new();
let mut parent = Self::get_parent(child, temp_def_list);
while let Some(p) = parent {
parent = Self::get_parent(&p, temp_def_list);
let p_id = if let TypeAnnotation::CustomClass { id, .. } = &p {
*id
} else {
unreachable!("must be class kind annotation")
};
// check cycle
let no_cycle = result.iter().all(|x| {
let TypeAnnotation::CustomClass { id, .. } = x else {
unreachable!("must be class kind annotation")
};
id.0 != p_id.0
});
if no_cycle {
result.push(p);
} else {
return Err(HashSet::from(["cyclic inheritance detected".into()]));
}
}
Ok(result)
}
/// should only be called when finding all ancestors, so panic when wrong
fn get_parent(
child: &TypeAnnotation,
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
) -> Option<TypeAnnotation> {
let child_id = if let TypeAnnotation::CustomClass { id, .. } = child {
*id
} else {
unreachable!("should be class type annotation")
};
let child_def = temp_def_list.get(child_id.0).unwrap();
let child_def = child_def.read();
let TopLevelDef::Class { ancestors, .. } = &*child_def else {
unreachable!("child must be top level class def")
};
if ancestors.is_empty() {
None
} else {
Some(ancestors[0].clone())
}
}
/// get the `var_id` of a given `TVar` type
pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<u32, HashSet<String>> {
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
Ok(*id)
} else {
Err(HashSet::from([
"not type var".to_string(),
]))
}
}
pub fn check_overload_function_type(
this: Type,
other: Type,
unifier: &mut Unifier,
type_var_to_concrete_def: &HashMap<Type, TypeAnnotation>,
) -> bool {
let this = unifier.get_ty(this);
let this = this.as_ref();
let other = unifier.get_ty(other);
let other = other.as_ref();
let (
TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }),
TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }),
) = (this, other) else {
unreachable!("this function must be called with function type")
};
// check args
let args_ok = this_args
.iter()
.map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap()))
.zip(other_args.iter().map(|FuncArg { name, ty, .. }| {
(name, type_var_to_concrete_def.get(ty).unwrap())
}))
.all(|(this, other)| {
if this.0 == &"self".into() && this.0 == other.0 {
true
} else {
this.0 == other.0
&& check_overload_type_annotation_compatible(this.1, other.1, unifier)
}
});
// check rets
let ret_ok = check_overload_type_annotation_compatible(
type_var_to_concrete_def.get(this_ret).unwrap(),
type_var_to_concrete_def.get(other_ret).unwrap(),
unifier,
);
// return
args_ok && ret_ok
}
pub fn check_overload_field_type(
this: Type,
other: Type,
unifier: &mut Unifier,
type_var_to_concrete_def: &HashMap<Type, TypeAnnotation>,
) -> bool {
check_overload_type_annotation_compatible(
type_var_to_concrete_def.get(&this).unwrap(),
type_var_to_concrete_def.get(&other).unwrap(),
unifier,
)
}
pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> {
let mut result = HashSet::new();
for s in stmts {
match &s.node {
ast::StmtKind::AnnAssign { target, .. }
if {
if let ast::ExprKind::Attribute { value, .. } = &target.node {
if let ast::ExprKind::Name { id, .. } = &value.node {
id == &"self".into()
} else {
false
}
} else {
false
}
} =>
{
return Err(HashSet::from([
format!(
"redundant type annotation for class fields at {}",
s.location
),
]))
}
ast::StmtKind::Assign { targets, .. } => {
for t in targets {
if let ast::ExprKind::Attribute { value, attr, .. } = &t.node {
if let ast::ExprKind::Name { id, .. } = &value.node {
if id == &"self".into() {
result.insert(*attr);
}
}
}
}
}
// TODO: do not check for For and While?
ast::StmtKind::For { body, orelse, .. }
| ast::StmtKind::While { body, orelse, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?);
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
}
ast::StmtKind::If { body, orelse, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure);
}
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
}
ast::StmtKind::With { body, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?);
}
ast::StmtKind::Pass { .. }
| ast::StmtKind::Assert { .. }
| ast::StmtKind::Expr { .. } => {}
_ => {
unimplemented!()
}
}
}
Ok(result)
}
pub fn parse_parameter_default_value(
default: &ast::Expr,
resolver: &(dyn SymbolResolver + Send + Sync),
) -> Result<SymbolValue, HashSet<String>> {
parse_parameter_default_value(default, resolver)
}
pub fn check_default_param_type(
val: &SymbolValue,
ty: &TypeAnnotation,
primitive: &PrimitiveStore,
unifier: &mut Unifier,
) -> Result<(), String> {
fn is_compatible(
found: &TypeAnnotation,
expect: &TypeAnnotation,
unifier: &mut Unifier,
primitive: &PrimitiveStore,
) -> bool {
match (found, expect) {
(TypeAnnotation::Primitive(f), TypeAnnotation::Primitive(e)) => {
unifier.unioned(*f, *e)
}
(
TypeAnnotation::CustomClass { id: f_id, params: f_param },
TypeAnnotation::CustomClass { id: e_id, params: e_param },
) => {
*f_id == *e_id
&& *f_id == primitive.option.obj_id(unifier).unwrap()
&& (f_param.is_empty()
|| (f_param.len() == 1
&& e_param.len() == 1
&& is_compatible(&f_param[0], &e_param[0], unifier, primitive)))
}
(TypeAnnotation::Tuple(f), TypeAnnotation::Tuple(e)) => {
f.len() == e.len()
&& f.iter()
.zip(e.iter())
.all(|(f, e)| is_compatible(f, e, unifier, primitive))
}
_ => false,
}
}
let found = val.get_type_annotation(primitive, unifier);
if is_compatible(&found, ty, unifier, primitive) {
Ok(())
} else {
Err(format!(
"incompatible default parameter type, expect {}, found {}",
ty.stringify(unifier),
found.stringify(unifier),
))
}
}
}
pub fn parse_parameter_default_value(
default: &ast::Expr,
resolver: &(dyn SymbolResolver + Send + Sync),
) -> Result<SymbolValue, HashSet<String>> {
fn handle_constant(val: &Constant, loc: &Location) -> Result<SymbolValue, HashSet<String>> {
match val {
Constant::Int(v) => {
if let Ok(v) = (*v).try_into() {
Ok(SymbolValue::I32(v))
} else {
Err(HashSet::from([format!("integer value out of range at {loc}")]))
}
}
Constant::Float(v) => Ok(SymbolValue::Double(*v)),
Constant::Bool(v) => Ok(SymbolValue::Bool(*v)),
Constant::Tuple(tuple) => Ok(SymbolValue::Tuple(
tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?,
)),
Constant::None => Err(HashSet::from([
format!(
"`None` is not supported, use `none` for option type instead ({loc})"
),
])),
_ => unimplemented!("this constant is not supported at {}", loc),
}
}
match &default.node {
ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location),
ast::ExprKind::Call { func, args, .. } if args.len() == 1 => {
match &func.node {
ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<i64, _> = (*v).try_into();
match v {
Ok(v) => Ok(SymbolValue::I64(v)),
_ => Err(HashSet::from([
format!("default param value out of range at {}", default.location)
])),
}
}
_ => Err(HashSet::from([
format!("only allow constant integer here at {}", default.location),
]))
}
ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<u32, _> = (*v).try_into();
match v {
Ok(v) => Ok(SymbolValue::U32(v)),
_ => Err(HashSet::from([
format!("default param value out of range at {}", default.location),
])),
}
}
_ => Err(HashSet::from([
format!("only allow constant integer here at {}", default.location),
]))
}
ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<u64, _> = (*v).try_into();
match v {
Ok(v) => Ok(SymbolValue::U64(v)),
_ => Err(HashSet::from([
format!("default param value out of range at {}", default.location),
])),
}
}
_ => Err(HashSet::from([
format!("only allow constant integer here at {}", default.location),
]))
}
ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok(
SymbolValue::OptionSome(
Box::new(parse_parameter_default_value(&args[0], resolver)?)
)
),
_ => Err(HashSet::from([
format!("unsupported default parameter at {}", default.location),
])),
}
}
ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts
.iter()
.map(|x| parse_parameter_default_value(x, resolver))
.collect::<Result<Vec<_>, _>>()?
)),
ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone),
ast::ExprKind::Name { id, .. } => {
resolver.get_default_param_value(default).ok_or_else(
|| HashSet::from([
format!(
"`{}` cannot be used as a default parameter at {} \
(not primitive type, option or tuple / not defined?)",
id,
default.location
),
])
)
}
_ => Err(HashSet::from([
format!(
"unsupported default parameter (not primitive type, option or tuple) at {}",
default.location
),
]))
}
}

View File

@ -1,149 +0,0 @@
use std::{
borrow::BorrowMut,
collections::{HashMap, HashSet},
fmt::Debug,
iter::FromIterator,
sync::Arc,
};
use super::codegen::CodeGenContext;
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap};
use crate::{
codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{type_inferencer::CodeLocation, typedef::CallId},
};
use inkwell::values::BasicValueEnum;
use itertools::{izip, Itertools};
use nac3parser::ast::{self, Location, Stmt, StrRef};
use parking_lot::RwLock;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize);
pub mod builtins;
pub mod composer;
pub mod helper;
pub mod numpy;
pub mod type_annotation;
use composer::*;
use type_annotation::*;
#[cfg(test)]
mod test;
type GenCallCallback =
dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>,
Option<(Type, ValueEnum<'ctx>)>,
(&FunSignature, DefinitionId),
Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
&mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String>
+ Send
+ Sync;
pub struct GenCall {
fp: Box<GenCallCallback>,
}
impl GenCall {
#[must_use]
pub fn new(fp: Box<GenCallCallback>) -> GenCall {
GenCall { fp }
}
/// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given
/// `reason`.
#[must_use]
pub fn create_dummy(reason: String) -> GenCall {
Self::new(Box::new(move |_, _, _, _, _| unreachable!("{reason}")))
}
pub fn run<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
(self.fp)(ctx, obj, fun, args, generator)
}
}
impl Debug for GenCall {
fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Ok(())
}
}
#[derive(Clone, Debug)]
pub struct FunInstance {
pub body: Arc<Vec<Stmt<Option<Type>>>>,
pub calls: Arc<HashMap<CodeLocation, CallId>>,
pub subst: VarMap,
pub unifier_id: usize,
}
#[derive(Debug, Clone)]
pub enum TopLevelDef {
Class {
/// Name for error messages and symbols.
name: StrRef,
/// Object ID used for [TypeEnum].
object_id: DefinitionId,
/// type variables bounded to the class.
type_vars: Vec<Type>,
/// Class fields.
///
/// Name and type is mutable.
fields: Vec<(StrRef, Type, bool)>,
/// Class methods, pointing to the corresponding function definition.
methods: Vec<(StrRef, Type, DefinitionId)>,
/// Ancestor classes, including itself.
ancestors: Vec<TypeAnnotation>,
/// Symbol resolver of the module defined the class; [None] if it is built-in type.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Constructor type.
constructor: Option<Type>,
/// Definition location.
loc: Option<Location>,
},
Function {
/// Prefix for symbol, should be unique globally.
name: String,
/// Simple name, the same as in method/function definition.
simple_name: StrRef,
/// Function signature.
signature: Type,
/// Instantiated type variable IDs.
var_id: Vec<u32>,
/// Function instance to symbol mapping
///
/// * Key: String representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class.
/// * Value: Function symbol name.
instance_to_symbol: HashMap<String, String>,
/// Function instances to annotated AST mapping
///
/// * Key: String representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. Excluding rigid type
/// variables.
///
/// Rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, FunInstance>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Custom code generation callback.
codegen_callback: Option<Arc<GenCall>>,
/// Definition location.
loc: Option<Location>,
},
}
pub struct TopLevelContext {
pub definitions: Arc<RwLock<Vec<Arc<RwLock<TopLevelDef>>>>>,
pub unifiers: Arc<RwLock<Vec<(SharedUnifier, PrimitiveStore)>>>,
pub personality_symbol: Option<String>,
}

View File

@ -1,103 +0,0 @@
use itertools::Itertools;
use crate::{
toplevel::helper::PRIMITIVE_DEF_IDS,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
/// Creates a `ndarray` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn make_ndarray_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
if dtype.is_none() && ndims.is_none() {
return ndarray
}
let tvar_ids = params.iter()
.map(|(obj_id, _)| *obj_id)
.collect_vec();
debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new();
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
if let Some(ndims) = ndims {
tvar_subst.insert(tvar_ids[1], ndims);
}
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
) -> Vec<(u32, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
debug_assert_eq!(params.len(), 2);
params.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(
unifier: &mut Unifier,
ndarray: Type,
) -> (u32, u32) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.0)
.collect_tuple()
.unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(
unifier: &mut Unifier,
ndarray: Type,
) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.1)
.collect_tuple()
.unwrap()
}

View File

@ -1,12 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [238]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
]

View File

@ -1,15 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar227]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar227\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
]

View File

@ -1,13 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [240]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [245]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
]

View File

@ -1,13 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar226, typevar227]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar226\", \"typevar227\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n",
]

View File

@ -1,17 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [246]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [254]\n}\n",
]

View File

@ -1,9 +0,0 @@
---
source: nac3core/src/toplevel/test.rs
assertion_line: 549
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nmethods: [],\ntype_vars: []\n}\n",
]

View File

@ -1,818 +0,0 @@
use crate::{
codegen::CodeGenContext,
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::DefinitionId,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, Unifier},
},
};
use indoc::indoc;
use nac3parser::{ast::fold::Fold, parser::parse_program};
use parking_lot::Mutex;
use std::{collections::HashMap, sync::Arc};
use test_case::test_case;
use super::*;
struct ResolverInternal {
id_to_type: Mutex<HashMap<StrRef, Type>>,
id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
class_names: Mutex<HashMap<StrRef, Type>>,
}
impl ResolverInternal {
fn add_id_def(&self, id: StrRef, def: DefinitionId) {
self.id_to_def.lock().insert(id, def);
}
fn add_id_type(&self, id: StrRef, ty: Type) {
self.id_to_type.lock().insert(id, ty);
}
}
struct Resolver(Arc<ResolverInternal>);
impl SymbolResolver for Resolver {
fn get_default_param_value(
&self,
_: &ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!()
}
fn get_symbol_type(
&self,
_: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>],
_: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String> {
self.0
.id_to_type
.lock()
.get(&str)
.cloned()
.ok_or_else(|| format!("cannot find symbol `{}`", str))
}
fn get_symbol_value<'ctx, 'a>(
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, 'a>,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.0.id_to_def.lock().get(&id).cloned()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
}
fn get_string_id(&self, _: &str) -> i32 {
unimplemented!()
}
fn get_exception_id(&self, _tyid: usize) -> usize {
unimplemented!()
}
}
#[test_case(
vec![
indoc! {"
def fun(a: int32) -> int32:
return a
"},
indoc! {"
class A:
def __init__(self):
self.a: int32 = 3
"},
indoc! {"
class B:
def __init__(self):
self.b: float = 4.3
def fun(self):
self.b = self.b + 3.0
"},
indoc! {"
def foo(a: float):
a + 1.0
"},
indoc! {"
class C(B):
def __init__(self):
self.c: int32 = 4
self.a: bool = True
"},
];
"register"
)]
fn test_simple_register(source: Vec<&str>) {
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
for s in source {
let ast = parse_program(s, Default::default()).unwrap();
let ast = ast[0].clone();
composer.register_top_level(ast, None, "".into(), false).unwrap();
}
}
#[test_case(
indoc! {"
class A:
def foo(self):
pass
a = A()
"};
"register"
)]
fn test_simple_register_without_constructor(source: &str) {
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let ast = parse_program(source, Default::default()).unwrap();
let ast = ast[0].clone();
composer.register_top_level(ast, None, "".into(), true).unwrap();
}
#[test_case(
vec![
indoc! {"
def fun(a: int32) -> int32:
return a
"},
indoc! {"
def foo(a: float):
a + 1.0
"},
indoc! {"
def f(b: int64) -> int32:
return 3
"},
],
vec![
"fn[[a:0], 0]",
"fn[[a:2], 4]",
"fn[[b:1], 0]",
],
vec![
"fun",
"foo",
"f"
];
"function compose"
)]
fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&str>) {
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = Arc::new(ResolverInternal {
id_to_def: Default::default(),
id_to_type: Default::default(),
class_names: Default::default(),
});
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
let ast = parse_program(s, Default::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) =
composer.register_top_level(ast, Some(resolver.clone()), "".into(), false).unwrap();
internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
}
composer.start_analysis(true).unwrap();
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.builtin_num).enumerate()
{
let def = &*def.read();
if let TopLevelDef::Function { signature, name, .. } = def {
let ty_str = composer.unifier.internal_stringify(
*signature,
&mut |id| id.to_string(),
&mut |id| id.to_string(),
&mut None,
);
assert_eq!(ty_str, tys[i]);
assert_eq!(name, names[i]);
}
}
}
#[test_case(
vec![
indoc! {"
class A():
a: int32
def __init__(self):
self.a = 3
def fun(self, b: B):
pass
def foo(self, a: T, b: V):
pass
"},
indoc! {"
class B(C):
def __init__(self):
pass
"},
indoc! {"
class C(A):
def __init__(self):
pass
def fun(self, b: B):
a = 1
pass
"},
indoc! {"
def foo(a: A):
pass
"},
indoc! {"
def ff(a: T) -> V:
pass
"}
],
vec![];
"simple class compose"
)]
#[test_case(
vec![
indoc! {"
class Generic_A(Generic[V], B):
a: int64
def __init__(self):
self.a = 123123123123
def fun(self, a: int32) -> V:
pass
"},
indoc! {"
class B:
aa: bool
def __init__(self):
self.aa = False
def foo(self, b: T):
pass
"}
],
vec![];
"generic class"
)]
#[test_case(
vec![
indoc! {"
def foo(a: list[int32], b: tuple[T, float]) -> A[B, bool]:
pass
"},
indoc! {"
class A(Generic[T, V]):
a: T
b: V
def __init__(self, v: V):
self.a = 1
self.b = v
def fun(self, a: T) -> V:
pass
"},
indoc! {"
def gfun(a: A[list[float], int32]):
pass
"},
indoc! {"
class B:
def __init__(self):
pass
"}
],
vec![];
"list tuple generic"
)]
#[test_case(
vec![
indoc! {"
class A(Generic[T, V]):
a: A[float, bool]
b: B
def __init__(self, a: A[float, bool], b: B):
self.a = a
self.b = b
def fun(self, a: A[float, bool]) -> A[bool, int32]:
pass
"},
indoc! {"
class B(A[int64, bool]):
def __init__(self):
pass
def foo(self, b: B) -> B:
pass
def bar(self, a: A[list[B], int32]) -> tuple[A[virtual[A[B, int32]], bool], B]:
pass
"}
],
vec![];
"self1"
)]
#[test_case(
vec![
indoc! {"
class A(Generic[T]):
a: int32
b: T
c: A[int64]
def __init__(self, t: T):
self.a = 3
self.b = T
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
pass
def foo(self, c: C):
pass
"},
indoc! {"
class B(Generic[V], A[float]):
d: C
def __init__(self):
pass
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
# override
pass
"},
indoc! {"
class C(B[bool]):
e: int64
def __init__(self):
pass
"}
],
vec![];
"inheritance_override"
)]
#[test_case(
vec![
indoc! {"
class A(Generic[T]):
def __init__(self):
pass
def fun(self, a: A[T]) -> A[T]:
pass
"}
],
vec!["application of type vars to generic class is not currently supported (at unknown:4:24)"];
"err no type var in generic app"
)]
#[test_case(
vec![
indoc! {"
class A(B):
def __init__(self):
pass
"},
indoc! {"
class B(A):
def __init__(self):
pass
"}
],
vec!["cyclic inheritance detected"];
"cyclic1"
)]
#[test_case(
vec![
indoc! {"
class A(B[bool, int64]):
def __init__(self):
pass
"},
indoc! {"
class B(Generic[V, T], C[int32]):
def __init__(self):
pass
"},
indoc! {"
class C(Generic[T], A):
def __init__(self):
pass
"},
],
vec!["cyclic inheritance detected"];
"cyclic2"
)]
#[test_case(
vec![
indoc! {"
class A:
pass
"}
],
vec!["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"];
"simple pass in class"
)]
#[test_case(
vec![indoc! {"
class A:
def __init__():
pass
"}],
vec!["__init__ method must have a `self` parameter (at unknown:2:5)"];
"err no self_1"
)]
#[test_case(
vec![
indoc! {"
class A(B, Generic[T], C):
def __init__(self):
pass
"},
indoc! {"
class B:
def __init__(self):
pass
"},
indoc! {"
class C:
def __init__(self):
pass
"}
],
vec!["a class definition can only have at most one base class declaration and one generic declaration (at unknown:1:24)"];
"err multiple inheritance"
)]
#[test_case(
vec![
indoc! {"
class A(Generic[T]):
a: int32
b: T
c: A[int64]
def __init__(self, t: T):
self.a = 3
self.b = T
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
pass
"},
indoc! {"
class B(Generic[V], A[float]):
def __init__(self):
pass
def fun(self, a: int32, b: T) -> list[virtual[B[int32]]]:
# override
pass
"}
],
vec!["method fun has same name as ancestors' method, but incompatible type"];
"err_incompatible_inheritance_method"
)]
#[test_case(
vec![
indoc! {"
class A(Generic[T]):
a: int32
b: T
c: A[int64]
def __init__(self, t: T):
self.a = 3
self.b = T
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
pass
"},
indoc! {"
class B(Generic[V], A[float]):
a: int32
def __init__(self):
pass
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
# override
pass
"}
],
vec!["field `a` has already declared in the ancestor classes"];
"err_incompatible_inheritance_field"
)]
#[test_case(
vec![
indoc! {"
class A:
def __init__(self):
pass
"},
indoc! {"
class A:
a: int32
def __init__(self):
pass
"}
],
vec!["duplicate definition of class `A` (at unknown:1:1)"];
"class same name"
)]
fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
let print = false;
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar(
vec![
("T".into(), vec![]),
("V".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int32]),
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
],
&mut composer.unifier,
print,
);
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
let ast = parse_program(s, Default::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "".into(), false) {
Ok(x) => x,
Err(msg) => {
if print {
println!("{}", msg);
} else {
assert_eq!(res[0], msg);
}
return;
}
}
};
internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
}
if let Err(msg) = composer.start_analysis(false) {
if print {
println!("{}", msg.iter().sorted().join("\n----------\n"));
} else {
assert_eq!(res[0], msg.iter().next().unwrap());
}
} else {
// skip 5 to skip primitives
let mut res_vec: Vec<String> = Vec::new();
for (def, _) in composer.definition_ast_list.iter().skip(composer.builtin_num) {
let def = &*def.read();
res_vec.push(format!("{}\n", def.to_string(composer.unifier.borrow_mut())));
}
insta::assert_debug_snapshot!(res_vec);
}
}
#[test_case(
vec![
indoc! {"
def fun(a: int32, b: int32) -> int32:
return a + b
"},
indoc! {"
def fib(n: int32) -> int32:
if n <= 2:
return 1
a = fib(n - 1)
b = fib(n - 2)
return fib(n - 1)
"}
],
vec![];
"simple function"
)]
#[test_case(
vec![
indoc! {"
class A:
a: int32
def __init__(self):
self.a = 3
def fun(self) -> int32:
b = self.a + 3
return b * self.a
def clone(self) -> A:
SELF = self
return SELF
def sum(self) -> int32:
if self.a == 0:
return self.a
else:
a = self.a
self.a = self.a - 1
return a + self.sum()
def fib(self, a: int32) -> int32:
if a <= 2:
return 1
return self.fib(a - 1) + self.fib(a - 2)
"},
indoc! {"
def fun(a: A) -> int32:
return a.fun() + 2
"}
],
vec![];
"simple class body"
)]
#[test_case(
vec![
indoc! {"
def fun(a: V, c: G, t: T) -> V:
b = a
cc = c
ret = fun(b, cc, t)
return ret * ret
"},
indoc! {"
def sum_three(l: list[V]) -> V:
return l[0] + l[1] + l[2]
"},
indoc! {"
def sum_sq_pair(p: tuple[V, V]) -> list[V]:
a = p[0]
b = p[1]
a = a**a
b = b**b
return [a, b]
"}
],
vec![];
"type var fun"
)]
#[test_case(
vec![
indoc! {"
class A(Generic[G]):
a: G
b: bool
def __init__(self, aa: G):
self.a = aa
if 2 > 1:
self.b = True
else:
# self.b = False
pass
def fun(self, a: G) -> list[G]:
ret = [a, self.a]
return ret if self.b else self.fun(self.a)
"}
],
vec![];
"type var class"
)]
#[test_case(
vec![
indoc! {"
class A:
def fun(self):
pass
"},
indoc!{"
class B:
a: int32
b: bool
def __init__(self):
# self.b = False
if 3 > 2:
self.a = 3
self.b = False
else:
self.a = 4
self.b = True
"}
],
vec![];
"no_init_inst_check"
)]
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
let print = true;
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar(
vec![
("T".into(), vec![]),
(
"V".into(),
vec![
composer.primitives_ty.float,
composer.primitives_ty.int32,
composer.primitives_ty.int64,
],
),
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
],
&mut composer.unifier,
print,
);
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
let ast = parse_program(s, Default::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "".into(), false) {
Ok(x) => x,
Err(msg) => {
if print {
println!("{}", msg);
} else {
assert_eq!(res[0], msg);
}
return;
}
}
};
internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
}
if let Err(msg) = composer.start_analysis(true) {
if print {
println!("{}", msg.iter().sorted().join("\n----------\n"));
} else {
assert_eq!(res[0], msg.iter().next().unwrap());
}
} else {
// skip 5 to skip primitives
let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier };
for (_i, (def, _)) in
composer.definition_ast_list.iter().skip(composer.builtin_num).enumerate()
{
let def = &*def.read();
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
println!(
"=========`{}`: number of instances: {}===========",
name,
instance_to_stmt.len()
);
for inst in instance_to_stmt.iter() {
let ast = &inst.1.body;
for b in ast.iter() {
println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap());
println!("--------------------");
}
println!("\n");
}
}
}
}
}
fn make_internal_resolver_with_tvar(
tvars: Vec<(StrRef, Vec<Type>)>,
unifier: &mut Unifier,
print: bool,
) -> Arc<ResolverInternal> {
let res: Arc<ResolverInternal> = ResolverInternal {
id_to_def: Default::default(),
id_to_type: tvars
.into_iter()
.map(|(name, range)| {
(name, {
let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice(), None, None);
if print {
println!("{}: {:?}, typevar{}", name, ty, id);
}
ty
})
})
.collect::<HashMap<_, _>>()
.into(),
class_names: Default::default(),
}
.into();
if print {
println!();
}
res
}
struct TypeToStringFolder<'a> {
unifier: &'a mut Unifier,
}
impl<'a> Fold<Option<Type>> for TypeToStringFolder<'a> {
type TargetU = String;
type Error = String;
fn map_user(&mut self, user: Option<Type>) -> Result<Self::TargetU, Self::Error> {
Ok(if let Some(ty) = user {
self.unifier.internal_stringify(
ty,
&mut |id| format!("class{}", id.to_string()),
&mut |id| format!("typevar{}", id.to_string()),
&mut None,
)
} else {
"None".into()
})
}
}

View File

@ -1,644 +0,0 @@
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::typecheck::typedef::VarMap;
use super::*;
use nac3parser::ast::Constant;
#[derive(Clone, Debug)]
pub enum TypeAnnotation {
Primitive(Type),
// we use type vars kind at params to represent self type
CustomClass {
id: DefinitionId,
// params can also be type var
params: Vec<TypeAnnotation>,
},
// can only be CustomClassKind
Virtual(Box<TypeAnnotation>),
TypeVar(Type),
/// A `Literal` allowing a subset of literals.
Literal(Vec<Constant>),
List(Box<TypeAnnotation>),
Tuple(Vec<TypeAnnotation>),
}
impl TypeAnnotation {
pub fn stringify(&self, unifier: &mut Unifier) -> String {
use TypeAnnotation::*;
match self {
Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty),
CustomClass { id, params } => {
let class_name = if let Some(ref top) = unifier.top_level {
if let TopLevelDef::Class { name, .. } =
&*top.definitions.read()[id.0].read()
{
(*name).into()
} else {
unreachable!()
}
} else {
format!("class_def_{}", id.0)
};
format!(
"{}{}",
class_name,
{
let param_list = params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
if param_list.is_empty() {
String::new()
} else {
format!("[{param_list}]")
}
}
)
}
Literal(values) => format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", ")),
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => {
format!("tuple[{}]", types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", "))
}
}
}
}
/// Parses an AST expression `expr` into a [`TypeAnnotation`].
///
/// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all
/// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [`None`] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T>(
resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here
locked: HashMap<DefinitionId, Vec<Type>>,
) -> Result<TypeAnnotation, HashSet<String>> {
let name_handle = |id: &StrRef,
unifier: &mut Unifier,
locked: HashMap<DefinitionId, Vec<Type>>| {
if id == &"int32".into() {
Ok(TypeAnnotation::Primitive(primitives.int32))
} else if id == &"int64".into() {
Ok(TypeAnnotation::Primitive(primitives.int64))
} else if id == &"uint32".into() {
Ok(TypeAnnotation::Primitive(primitives.uint32))
} else if id == &"uint64".into() {
Ok(TypeAnnotation::Primitive(primitives.uint64))
} else if id == &"float".into() {
Ok(TypeAnnotation::Primitive(primitives.float))
} else if id == &"bool".into() {
Ok(TypeAnnotation::Primitive(primitives.bool))
} else if id == &"str".into() {
Ok(TypeAnnotation::Primitive(primitives.str))
} else if id == &"Exception".into() {
Ok(TypeAnnotation::CustomClass { id: PRIMITIVE_DEF_IDS.exception, params: Vec::default() })
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read();
if let Some(def_read) = def_read {
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone()
} else {
return Err(HashSet::from([
format!(
"function cannot be used as a type (at {})",
expr.location
),
]))
}
} else {
locked.get(&obj_id).unwrap().clone()
}
};
// check param number here
if !type_vars.is_empty() {
return Err(HashSet::from([
format!(
"expect {} type variable parameter but got 0 (at {})",
type_vars.len(),
expr.location,
),
]))
}
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).0;
unifier.unify(var, ty).unwrap();
Ok(TypeAnnotation::TypeVar(ty))
} else {
Err(HashSet::from([
format!("`{}` is not a valid type annotation (at {})", id, expr.location),
]))
}
} else {
Err(HashSet::from([
format!("`{}` is not a valid type annotation (at {})", id, expr.location),
]))
}
};
let class_name_handle =
|id: &StrRef,
slice: &ast::Expr<T>,
unifier: &mut Unifier,
mut locked: HashMap<DefinitionId, Vec<Type>>| {
if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()].contains(id)
{
return Err(HashSet::from([
format!("keywords cannot be class name (at {})", expr.location),
]))
}
let obj_id = resolver.get_identifier_def(*id)?;
let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read();
if let Some(def_read) = def_read {
let TopLevelDef::Class { type_vars, .. } = &*def_read else {
unreachable!("must be class here")
};
type_vars.clone()
} else {
locked.get(&obj_id).unwrap().clone()
}
};
// we do not check whether the application of type variables are compatible here
let param_type_infos = {
let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.iter().collect_vec()
} else {
vec![slice]
};
if type_vars.len() != params_ast.len() {
return Err(HashSet::from([
format!(
"expect {} type parameters but got {} (at {})",
type_vars.len(),
params_ast.len(),
params_ast[0].location,
),
]))
}
let result = params_ast
.iter()
.map(|x| {
parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
x,
{
locked.insert(obj_id, type_vars.clone());
locked.clone()
},
)
})
.collect::<Result<Vec<_>, _>>()?;
// make sure the result do not contain any type vars
let no_type_var =
result.iter().all(|x| get_type_var_contained_in_type_annotation(x).is_empty());
if no_type_var {
result
} else {
return Err(HashSet::from([
format!(
"application of type vars to generic class is not currently supported (at {})",
params_ast[0].location
),
]))
}
};
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
};
match &expr.node {
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
// virtual
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"virtual".into())
} =>
{
let def = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
if !matches!(def, TypeAnnotation::CustomClass { .. }) {
unreachable!("must be concretized custom class kind in the virtual")
}
Ok(TypeAnnotation::Virtual(def.into()))
}
// list
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into())
} =>
{
let def_ann = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
Ok(TypeAnnotation::List(def_ann.into()))
}
// option
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Option".into())
} =>
{
let def_ann = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
let id =
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
*obj_id
} else {
unreachable!()
};
Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] })
}
// tuple
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"tuple".into())
} =>
{
let tup_elts = {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.as_slice()
} else {
std::slice::from_ref(slice.as_ref())
}
};
let type_annotations = tup_elts
.iter()
.map(|e| {
parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(TypeAnnotation::Tuple(type_annotations))
}
// Literal
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into())
} => {
let tup_elts = {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.as_slice()
} else {
std::slice::from_ref(slice.as_ref())
}
};
let type_annotations = tup_elts
.iter()
.map(|e| {
match &e.node {
ast::ExprKind::Constant { value, .. } => Ok(
TypeAnnotation::Literal(vec![value.clone()]),
),
_ => parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
),
}
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flat_map(|type_ann| match type_ann {
TypeAnnotation::Literal(values) => values,
_ => unreachable!(),
})
.collect_vec();
if type_annotations.len() == 1 {
Ok(TypeAnnotation::Literal(type_annotations))
} else {
Err(HashSet::from([
format!("multiple literal bounds are currently unsupported (at {})", value.location)
]))
}
}
// custom class
ast::ExprKind::Subscript { value, slice, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node {
class_name_handle(id, slice, unifier, locked)
} else {
Err(HashSet::from([
format!("unsupported expression type for class name (at {})", value.location)
]))
}
}
ast::ExprKind::Constant { value, .. } => {
Ok(TypeAnnotation::Literal(vec![value.clone()]))
}
_ => Err(HashSet::from([
format!("unsupported expression for type annotation (at {})", expr.location),
])),
}
}
// no need to have the `locked` parameter, unlike the `parse_ast_to_type_annotation_kinds`, since
// when calling this function, there should be no topleveldef::class being write, and this function
// also only read the toplevedefs
pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>
) -> Result<Type, HashSet<String>> {
match ann {
TypeAnnotation::CustomClass { id: obj_id, params } => {
let def_read = top_level_defs[obj_id.0].read();
let class_def: &TopLevelDef = &def_read;
let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else {
unreachable!("should be class def here")
};
if type_vars.len() != params.len() {
return Err(HashSet::from([
format!(
"unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
params.len()
),
]))
}
let param_ty = params
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
x,
subst_list
)
})
.collect::<Result<Vec<_>, _>>()?;
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.0, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([
format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)
]))
}
}
TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_const_generic_var(
ty,
*name,
*loc,
);
unifier.unify(temp.0, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([
format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
),
]))
}
}
_ => unreachable!("must be generic type var"),
}
}
result
};
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
}
}
Ok(ty)
}
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Literal(values) => {
let values = values.iter()
.map(SymbolValue::from_constant_inferred)
.collect::<Result<Vec<_>, _>>()
.map_err(|err| HashSet::from([err]))?;
let var = unifier.get_fresh_literal(values, None);
Ok(var)
}
TypeAnnotation::Virtual(ty) => {
let ty = get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
ty.as_ref(),
subst_list
)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
}
TypeAnnotation::List(ty) => {
let ty = get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
ty.as_ref(),
subst_list
)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
}
TypeAnnotation::Tuple(tys) => {
let tys = tys
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))
}
}
}
/// given an def id, return a type annotation of self \
/// ```python
/// class A(Generic[T, V]):
/// def fun(self):
/// ```
/// the type of `self` should be similar to `A[T, V]`, where `T`, `V`
/// considered to be type variables associated with the class \
/// \
/// But note that here we do not make a duplication of `T`, `V`, we directly
/// use them as they are in the [`TopLevelDef::Class`] since those in the
/// `TopLevelDef::Class.type_vars` will be substitute later when seeing applications/instantiations
/// the Type of their fields and methods will also be subst when application/instantiation
#[must_use]
pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation {
TypeAnnotation::CustomClass {
id: object_id,
params: type_vars.iter().map(|ty| TypeAnnotation::TypeVar(*ty)).collect_vec(),
}
}
/// get all the occurences of type vars contained in a type annotation
/// e.g. `A[int, B[T], V, virtual[C[G]]]` => [T, V, G]
/// this function will not make a duplicate of type var
#[must_use]
pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<TypeAnnotation> {
let mut result: Vec<TypeAnnotation> = Vec::new();
match ann {
TypeAnnotation::TypeVar(..) => result.push(ann.clone()),
TypeAnnotation::Virtual(ann) | TypeAnnotation::List(ann) => {
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()));
}
TypeAnnotation::CustomClass { params, .. } => {
for p in params {
result.extend(get_type_var_contained_in_type_annotation(p));
}
}
TypeAnnotation::Tuple(anns) => {
for a in anns {
result.extend(get_type_var_contained_in_type_annotation(a));
}
}
TypeAnnotation::Primitive(..) | TypeAnnotation::Literal { .. } => {}
}
result
}
/// check the type compatibility for overload
pub fn check_overload_type_annotation_compatible(
this: &TypeAnnotation,
other: &TypeAnnotation,
unifier: &mut Unifier,
) -> bool {
match (this, other) {
(TypeAnnotation::Primitive(a), TypeAnnotation::Primitive(b)) => a == b,
(TypeAnnotation::TypeVar(a), TypeAnnotation::TypeVar(b)) => {
let a = unifier.get_ty(*a);
let a = &*a;
let b = unifier.get_ty(*b);
let b = &*b;
let (
TypeEnum::TVar { id: a, fields: None, .. },
TypeEnum::TVar { id: b, fields: None, .. },
) = (a, b) else {
unreachable!("must be type var")
};
a == b
}
(TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b))
| (TypeAnnotation::List(a), TypeAnnotation::List(b)) => {
check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier)
}
(TypeAnnotation::Tuple(a), TypeAnnotation::Tuple(b)) => {
a.len() == b.len() && {
a.iter()
.zip(b)
.all(|(a, b)| check_overload_type_annotation_compatible(a, b, unifier))
}
}
(
TypeAnnotation::CustomClass { id: a, params: a_p },
TypeAnnotation::CustomClass { id: b, params: b_p },
) => {
a.0 == b.0 && {
a_p.len() == b_p.len() && {
a_p.iter()
.zip(b_p)
.all(|(a, b)| check_overload_type_annotation_compatible(a, b, unifier))
}
}
}
_ => false,
}
}

View File

@ -0,0 +1,228 @@
use super::super::typedef::*;
use super::TopLevelContext;
use std::boxed::Box;
use std::collections::HashMap;
struct ContextStack<'a> {
/// stack level, starts from 0
level: u32,
/// stack of variable definitions containing (id, def, level) where `def` is the original
/// definition in `level-1`.
var_defs: Vec<(usize, VarDef<'a>, u32)>,
/// stack of symbol definitions containing (name, level) where `level` is the smallest level
/// where the name is assigned a value
sym_def: Vec<(&'a str, u32)>,
}
pub struct InferenceContext<'a> {
/// top level context
top_level: TopLevelContext<'a>,
/// list of primitive instances
primitives: Vec<Type>,
/// list of variable instances
variables: Vec<Type>,
/// identifier to type mapping.
sym_table: HashMap<&'a str, Type>,
/// resolution function reference, that may resolve unbounded identifiers to some type
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
/// stack
stack: ContextStack<'a>,
/// return type
result: Option<Type>,
}
// non-trivial implementations here
impl<'a> InferenceContext<'a> {
/// return a new `InferenceContext` from `TopLevelContext` and resolution function.
pub fn new(
top_level: TopLevelContext,
resolution_fn: Box<dyn FnMut(&str) -> Result<Type, String>>,
) -> InferenceContext {
let primitives = (0..top_level.primitive_defs.len())
.map(|v| TypeEnum::PrimitiveType(PrimitiveId(v)).into())
.collect();
let variables = (0..top_level.var_defs.len())
.map(|v| TypeEnum::TypeVariable(VariableId(v)).into())
.collect();
InferenceContext {
top_level,
primitives,
variables,
sym_table: HashMap::new(),
resolution_fn,
stack: ContextStack {
level: 0,
var_defs: Vec::new(),
sym_def: Vec::new(),
},
result: None,
}
}
/// execute the function with new scope.
/// variable assignment would be limited within the scope (not readable outside), and type
/// variable type guard would be limited within the scope.
/// returns the list of variables assigned within the scope, and the result of the function
pub fn with_scope<F, R>(&mut self, f: F) -> (Vec<(&'a str, Type)>, R)
where
F: FnOnce(&mut Self) -> R,
{
self.stack.level += 1;
let result = f(self);
self.stack.level -= 1;
while !self.stack.var_defs.is_empty() {
let (_, _, level) = self.stack.var_defs.last().unwrap();
if *level > self.stack.level {
let (id, def, _) = self.stack.var_defs.pop().unwrap();
self.top_level.var_defs[id] = def;
} else {
break;
}
}
let mut poped_names = Vec::new();
while !self.stack.sym_def.is_empty() {
let (_, level) = self.stack.sym_def.last().unwrap();
if *level > self.stack.level {
let (name, _) = self.stack.sym_def.pop().unwrap();
let ty = self.sym_table.remove(name).unwrap();
poped_names.push((name, ty));
} else {
break;
}
}
(poped_names, result)
}
/// assign a type to an identifier.
/// may return error if the identifier was defined but with different type
pub fn assign(&mut self, name: &'a str, ty: Type) -> Result<Type, String> {
if let Some(t) = self.sym_table.get_mut(name) {
if t == &ty {
Ok(ty)
} else {
Err("different types".into())
}
} else {
self.stack.sym_def.push((name, self.stack.level));
self.sym_table.insert(name, ty.clone());
Ok(ty)
}
}
/// check if an identifier is already defined
pub fn defined(&self, name: &str) -> bool {
self.sym_table.get(name).is_some()
}
/// get the type of an identifier
/// may return error if the identifier is not defined, and cannot be resolved with the
/// resolution function.
pub fn resolve(&mut self, name: &str) -> Result<Type, String> {
if let Some(t) = self.sym_table.get(name) {
Ok(t.clone())
} else {
self.resolution_fn.as_mut()(name)
}
}
/// restrict the bound of a type variable by replacing its definition.
/// used for implementing type guard
pub fn restrict(&mut self, id: VariableId, mut def: VarDef<'a>) {
std::mem::swap(self.top_level.var_defs.get_mut(id.0).unwrap(), &mut def);
self.stack.var_defs.push((id.0, def, self.stack.level));
}
pub fn set_result(&mut self, result: Option<Type>) {
self.result = result;
}
}
// trivial getters:
impl<'a> InferenceContext<'a> {
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
self.primitives.get(id.0).unwrap().clone()
}
pub fn get_variable(&self, id: VariableId) -> Type {
self.variables.get(id.0).unwrap().clone()
}
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
self.top_level.fn_table.get(name)
}
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
self.top_level.primitive_defs.get(id.0).unwrap()
}
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
self.top_level.class_defs.get(id.0).unwrap()
}
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
self.top_level.parametric_defs.get(id.0).unwrap()
}
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
self.top_level.var_defs.get(id.0).unwrap()
}
pub fn get_type(&self, name: &str) -> Option<Type> {
self.top_level.get_type(name)
}
pub fn get_result(&self) -> Option<Type> {
self.result.clone()
}
}
impl TypeEnum {
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> TypeEnum {
match self {
TypeEnum::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(),
TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType(
*id,
params
.iter()
.map(|v| v.as_ref().subst(map).into())
.collect(),
),
_ => self.clone(),
}
}
pub fn inv_subst(&self, map: &[(Type, Type)]) -> Type {
for (from, to) in map.iter() {
if self == from.as_ref() {
return to.clone();
}
}
match self {
TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType(
*id,
params.iter().map(|v| v.as_ref().inv_subst(map)).collect(),
),
_ => self.clone(),
}
.into()
}
pub fn get_subst(&self, ctx: &InferenceContext) -> HashMap<VariableId, Type> {
match self {
TypeEnum::ParametricType(id, params) => {
let vars = &ctx.get_parametric_def(*id).params;
vars.iter()
.zip(params)
.map(|(v, p)| (*v, p.as_ref().clone().into()))
.collect()
}
// if this proves to be slow, we can use option type
_ => HashMap::new(),
}
}
pub fn get_base<'a>(&'a self, ctx: &'a InferenceContext) -> Option<&'a TypeDef> {
match self {
TypeEnum::PrimitiveType(id) => Some(ctx.get_primitive_def(*id)),
TypeEnum::ClassType(id) | TypeEnum::VirtualClassType(id) => {
Some(&ctx.get_class_def(*id).base)
}
TypeEnum::ParametricType(id, _) => Some(&ctx.get_parametric_def(*id).base),
_ => None,
}
}
}

View File

@ -0,0 +1,4 @@
mod inference_context;
mod top_level_context;
pub use inference_context::InferenceContext;
pub use top_level_context::TopLevelContext;

View File

@ -0,0 +1,138 @@
use super::super::typedef::*;
use std::collections::HashMap;
use std::rc::Rc;
/// Structure for storing top-level type definitions.
/// Used for collecting type signature from source code.
/// Can be converted to `InferenceContext` for type inference in functions.
pub struct TopLevelContext<'a> {
/// List of primitive definitions.
pub(super) primitive_defs: Vec<TypeDef<'a>>,
/// List of class definitions.
pub(super) class_defs: Vec<ClassDef<'a>>,
/// List of parametric type definitions.
pub(super) parametric_defs: Vec<ParametricDef<'a>>,
/// List of type variable definitions.
pub(super) var_defs: Vec<VarDef<'a>>,
/// Function name to signature mapping.
pub(super) fn_table: HashMap<&'a str, FnDef>,
/// Type name to type mapping.
pub(super) sym_table: HashMap<&'a str, Type>,
primitives: Vec<Type>,
variables: Vec<Type>,
}
impl<'a> TopLevelContext<'a> {
pub fn new(primitive_defs: Vec<TypeDef<'a>>) -> TopLevelContext {
let mut sym_table = HashMap::new();
let mut primitives = Vec::new();
for (i, t) in primitive_defs.iter().enumerate() {
primitives.push(TypeEnum::PrimitiveType(PrimitiveId(i)).into());
sym_table.insert(t.name, TypeEnum::PrimitiveType(PrimitiveId(i)).into());
}
TopLevelContext {
primitive_defs,
class_defs: Vec::new(),
parametric_defs: Vec::new(),
var_defs: Vec::new(),
fn_table: HashMap::new(),
sym_table,
primitives,
variables: Vec::new(),
}
}
pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId {
self.sym_table.insert(
def.base.name,
TypeEnum::ClassType(ClassId(self.class_defs.len())).into(),
);
self.class_defs.push(def);
ClassId(self.class_defs.len() - 1)
}
pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId {
let params = def
.params
.iter()
.map(|&v| Rc::new(TypeEnum::TypeVariable(v)))
.collect();
self.sym_table.insert(
def.base.name,
TypeEnum::ParametricType(ParamId(self.parametric_defs.len()), params).into(),
);
self.parametric_defs.push(def);
ParamId(self.parametric_defs.len() - 1)
}
pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId {
self.sym_table.insert(
def.name,
TypeEnum::TypeVariable(VariableId(self.var_defs.len())).into(),
);
self.add_variable_private(def)
}
pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId {
self.var_defs.push(def);
self.variables
.push(TypeEnum::TypeVariable(VariableId(self.var_defs.len() - 1)).into());
VariableId(self.var_defs.len() - 1)
}
pub fn add_fn(&mut self, name: &'a str, def: FnDef) {
self.fn_table.insert(name, def);
}
pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> {
self.fn_table.get(name)
}
pub fn get_primitive_def_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
self.primitive_defs.get_mut(id.0).unwrap()
}
pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef {
self.primitive_defs.get(id.0).unwrap()
}
pub fn get_class_def_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> {
self.class_defs.get_mut(id.0).unwrap()
}
pub fn get_class_def(&self, id: ClassId) -> &ClassDef {
self.class_defs.get(id.0).unwrap()
}
pub fn get_parametric_def_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> {
self.parametric_defs.get_mut(id.0).unwrap()
}
pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef {
self.parametric_defs.get(id.0).unwrap()
}
pub fn get_variable_def_mut(&mut self, id: VariableId) -> &mut VarDef<'a> {
self.var_defs.get_mut(id.0).unwrap()
}
pub fn get_variable_def(&self, id: VariableId) -> &VarDef {
self.var_defs.get(id.0).unwrap()
}
pub fn get_primitive(&self, id: PrimitiveId) -> Type {
self.primitives.get(id.0).unwrap().clone()
}
pub fn get_variable(&self, id: VariableId) -> Type {
self.variables.get(id.0).unwrap().clone()
}
pub fn get_type(&self, name: &str) -> Option<Type> {
// TODO: handle name visibility
// possibly by passing a function from outside to tell what names are allowed, and what are
// not...
self.sym_table.get(name).cloned()
}
}

View File

@ -0,0 +1,967 @@
use super::context::InferenceContext;
use super::inference_core::resolve_call;
use super::magic_methods::*;
use super::primitives::*;
use super::typedef::{Type, TypeEnum::*};
use rustpython_parser::ast::{
Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator,
UnaryOperator,
};
use std::convert::TryInto;
type ParserResult = Result<Option<Type>, String>;
pub fn infer_expr<'a>(
ctx: &mut InferenceContext<'a>,
expr: &'a Expression,
) -> ParserResult {
match &expr.node {
ExpressionType::Number { value } => infer_constant(ctx, value),
ExpressionType::Identifier { name } => infer_identifier(ctx, name),
ExpressionType::List { elements } => infer_list(ctx, elements),
ExpressionType::Tuple { elements } => infer_tuple(ctx, elements),
ExpressionType::Attribute { value, name } => infer_attribute(ctx, value, name),
ExpressionType::BoolOp { values, .. } => infer_bool_ops(ctx, values),
ExpressionType::Binop { a, b, op } => infer_bin_ops(ctx, op, a, b),
ExpressionType::Unop { op, a } => infer_unary_ops(ctx, op, a),
ExpressionType::Compare { vals, ops } => infer_compare(ctx, vals, ops),
ExpressionType::Call {
args,
function,
keywords,
} => {
if !keywords.is_empty() {
Err("keyword is not supported".into())
} else {
infer_call(ctx, &args, &function)
}
}
ExpressionType::Subscript { a, b } => infer_subscript(ctx, a, b),
ExpressionType::IfExpression { test, body, orelse } => {
infer_if_expr(ctx, &test, &body, orelse)
}
ExpressionType::Comprehension { kind, generators } => match kind.as_ref() {
ComprehensionKind::List { element } => {
if generators.len() == 1 {
infer_list_comprehension(ctx, element, &generators[0])
} else {
Err("only 1 generator statement is supported".into())
}
}
_ => Err("only list comprehension is supported".into()),
},
ExpressionType::True | ExpressionType::False => Ok(Some(ctx.get_primitive(BOOL_TYPE))),
_ => Err("not supported".into()),
}
}
fn infer_constant(
ctx: &mut InferenceContext,
value: &rustpython_parser::ast::Number,
) -> ParserResult {
use rustpython_parser::ast::Number;
match value {
Number::Integer { value } => {
let int32: Result<i32, _> = value.try_into();
if int32.is_ok() {
Ok(Some(ctx.get_primitive(INT32_TYPE)))
} else {
let int64: Result<i64, _> = value.try_into();
if int64.is_ok() {
Ok(Some(ctx.get_primitive(INT64_TYPE)))
} else {
Err("integer out of range".into())
}
}
}
Number::Float { .. } => Ok(Some(ctx.get_primitive(FLOAT_TYPE))),
_ => Err("not supported".into()),
}
}
fn infer_identifier(ctx: &mut InferenceContext, name: &str) -> ParserResult {
Ok(Some(ctx.resolve(name)?))
}
fn infer_list<'a>(
ctx: &mut InferenceContext<'a>,
elements: &'a [Expression],
) -> ParserResult {
if elements.is_empty() {
return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into()));
}
let mut types = elements.iter().map(|v| infer_expr(ctx, v));
let head = types.next().unwrap()?;
if head.is_none() {
return Err("list elements must have some type".into());
}
for v in types {
if v? != head {
return Err("inhomogeneous list is not allowed".into());
}
}
Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into()))
}
fn infer_tuple<'a>(
ctx: &mut InferenceContext<'a>,
elements: &'a [Expression],
) -> ParserResult {
let types: Result<Option<Vec<_>>, String> =
elements.iter().map(|v| infer_expr(ctx, v)).collect();
if let Some(t) = types? {
Ok(Some(ParametricType(TUPLE_TYPE, t).into()))
} else {
Err("tuple elements must have some type".into())
}
}
fn infer_attribute<'a>(
ctx: &mut InferenceContext<'a>,
value: &'a Expression,
name: &str,
) -> ParserResult {
let value = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?;
if let TypeVariable(id) = value.as_ref() {
let v = ctx.get_variable_def(*id);
if v.bound.is_empty() {
return Err("no fields on unbounded type variable".into());
}
let ty = v.bound[0].get_base(ctx).and_then(|v| v.fields.get(name));
if ty.is_none() {
return Err("unknown field".into());
}
for x in v.bound[1..].iter() {
let ty1 = x.get_base(ctx).and_then(|v| v.fields.get(name));
if ty1 != ty {
return Err("unknown field (type mismatch between variants)".into());
}
}
return Ok(Some(ty.unwrap().clone()));
}
match value.get_base(ctx) {
Some(b) => match b.fields.get(name) {
Some(t) => Ok(Some(t.clone())),
None => Err("no such field".into()),
},
None => Err("this object has no fields".into()),
}
}
fn infer_bool_ops<'a>(ctx: &mut InferenceContext<'a>, values: &'a [Expression]) -> ParserResult {
assert_eq!(values.len(), 2);
let left = infer_expr(ctx, &values[0])?.ok_or_else(|| "no value".to_string())?;
let right = infer_expr(ctx, &values[1])?.ok_or_else(|| "no value".to_string())?;
let b = ctx.get_primitive(BOOL_TYPE);
if left == b && right == b {
Ok(Some(b))
} else {
Err("bool operands must be bool".into())
}
}
fn infer_bin_ops<'a>(
ctx: &mut InferenceContext<'a>,
op: &Operator,
left: &'a Expression,
right: &'a Expression,
) -> ParserResult {
let left = infer_expr(ctx, left)?.ok_or_else(|| "no value".to_string())?;
let right = infer_expr(ctx, right)?.ok_or_else(|| "no value".to_string())?;
let fun = binop_name(op);
resolve_call(ctx, Some(left), fun, &[right])
}
fn infer_unary_ops<'a>(
ctx: &mut InferenceContext<'a>,
op: &UnaryOperator,
obj: &'a Expression,
) -> ParserResult {
let ty = infer_expr(ctx, obj)?.ok_or_else(|| "no value".to_string())?;
if let UnaryOperator::Not = op {
if ty == ctx.get_primitive(BOOL_TYPE) {
Ok(Some(ty))
} else {
Err("logical not must be applied to bool".into())
}
} else {
resolve_call(ctx, Some(ty), unaryop_name(op), &[])
}
}
fn infer_compare<'a>(
ctx: &mut InferenceContext<'a>,
vals: &'a [Expression],
ops: &'a [Comparison],
) -> ParserResult {
let types: Result<Option<Vec<_>>, _> = vals.iter().map(|v| infer_expr(ctx, v)).collect();
let types = types?;
if types.is_none() {
return Err("comparison operands must have type".into());
}
let types = types.unwrap();
let boolean = ctx.get_primitive(BOOL_TYPE);
let left = &types[..types.len() - 1];
let right = &types[1..];
for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) {
let fun = comparison_name(op).ok_or_else(|| "unsupported comparison".to_string())?;
let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?;
if ty.is_none() || ty.unwrap() != boolean {
return Err("comparison result must be boolean".into());
}
}
Ok(Some(boolean))
}
fn infer_call<'a>(
ctx: &mut InferenceContext<'a>,
args: &'a [Expression],
function: &'a Expression,
) -> ParserResult {
let types: Result<Option<Vec<_>>, _> = args.iter().map(|v| infer_expr(ctx, v)).collect();
let types = types?;
if types.is_none() {
return Err("function params must have type".into());
}
let (obj, fun) = match &function.node {
ExpressionType::Identifier { name } => (None, name),
ExpressionType::Attribute { value, name } => (
Some(infer_expr(ctx, &value)?.ok_or_else(|| "no value".to_string())?),
name,
),
_ => return Err("not supported".into()),
};
resolve_call(ctx, obj, fun.as_str(), &types.unwrap())
}
fn infer_subscript<'a>(
ctx: &mut InferenceContext<'a>,
a: &'a Expression,
b: &'a Expression,
) -> ParserResult {
let a = infer_expr(ctx, a)?.ok_or_else(|| "no value".to_string())?;
let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() {
ls[0].clone()
} else {
return Err("subscript is not supported for types other than list".into());
};
match &b.node {
ExpressionType::Slice { elements } => {
let int32 = ctx.get_primitive(INT32_TYPE);
let types: Result<Option<Vec<_>>, _> = elements
.iter()
.map(|v| {
if let ExpressionType::None = v.node {
Ok(Some(int32.clone()))
} else {
infer_expr(ctx, v)
}
})
.collect();
let types = types?.ok_or_else(|| "slice must have type".to_string())?;
if types.iter().all(|v| v == &int32) {
Ok(Some(a))
} else {
Err("slice must be int32 type".into())
}
}
_ => {
let b = infer_expr(ctx, b)?.ok_or_else(|| "no value".to_string())?;
if b == ctx.get_primitive(INT32_TYPE) {
Ok(Some(t))
} else {
Err("index must be either slice or int32".into())
}
}
}
}
fn infer_if_expr<'a>(
ctx: &mut InferenceContext<'a>,
test: &'a Expression,
body: &'a Expression,
orelse: &'a Expression,
) -> ParserResult {
let test = infer_expr(ctx, test)?.ok_or_else(|| "no value".to_string())?;
if test != ctx.get_primitive(BOOL_TYPE) {
return Err("test should be bool".into());
}
let body = infer_expr(ctx, body)?;
let orelse = infer_expr(ctx, orelse)?;
if body.as_ref() == orelse.as_ref() {
Ok(body)
} else {
Err("divergent type".into())
}
}
pub fn infer_simple_binding<'a>(
ctx: &mut InferenceContext<'a>,
name: &'a Expression,
ty: Type,
) -> Result<(), String> {
match &name.node {
ExpressionType::Identifier { name } => {
if name == "_" {
Ok(())
} else if ctx.defined(name.as_str()) {
Err("duplicated naming".into())
} else {
ctx.assign(name.as_str(), ty)?;
Ok(())
}
}
ExpressionType::Tuple { elements } => {
if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() {
if elements.len() == ls.len() {
for (a, b) in elements.iter().zip(ls.iter()) {
infer_simple_binding(ctx, a, b.clone())?;
}
Ok(())
} else {
Err("different length".into())
}
} else {
Err("not supported".into())
}
}
_ => Err("not supported".into()),
}
}
fn infer_list_comprehension<'a>(
ctx: &mut InferenceContext<'a>,
element: &'a Expression,
comprehension: &'a Comprehension,
) -> ParserResult {
if comprehension.is_async {
return Err("async is not supported".into());
}
let iter = infer_expr(ctx, &comprehension.iter)?.ok_or_else(|| "no value".to_string())?;
if let ParametricType(LIST_TYPE, ls) = iter.as_ref() {
ctx.with_scope(|ctx| {
infer_simple_binding(ctx, &comprehension.target, ls[0].clone())?;
let boolean = ctx.get_primitive(BOOL_TYPE);
for test in comprehension.ifs.iter() {
let result =
infer_expr(ctx, test)?.ok_or_else(|| "no value in test".to_string())?;
if result != boolean {
return Err("test must be bool".into());
}
}
let result = infer_expr(ctx, element)?.ok_or_else(|| "no value")?;
Ok(Some(ParametricType(LIST_TYPE, vec![result]).into()))
})
.1
} else {
Err("iteration is supported for list only".into())
}
}
#[cfg(test)]
mod test {
use super::{
super::{context::*, typedef::*},
*,
};
use rustpython_parser::parser::parse_expression;
use std::collections::HashMap;
use std::rc::Rc;
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
}
#[test]
fn test_constants() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("123").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("2147483647").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("2147483648").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
let ast = parse_expression("9223372036854775807").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE));
let ast = parse_expression("9223372036854775808").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("integer out of range".into()));
let ast = parse_expression("123.456").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(FLOAT_TYPE));
let ast = parse_expression("True").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
}
#[test]
fn test_identifier() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
let ast = parse_expression("abc").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("ab").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unbounded identifier".into()));
}
#[test]
fn test_list() {
let mut ctx = basic_ctx();
ctx.add_fn(
"foo",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
// def is reserved...
ctx.assign("efg", ctx.get_primitive(INT32_TYPE)).unwrap();
ctx.assign("xyz", ctx.get_primitive(FLOAT_TYPE)).unwrap();
let ast = parse_expression("[]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![BotType.into()]).into()
);
let ast = parse_expression("[abc]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[abc, efg]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[abc, efg, xyz]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("inhomogeneous list is not allowed".into()));
let ast = parse_expression("[foo()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("list elements must have some type".into()));
}
#[test]
fn test_tuple() {
let mut ctx = basic_ctx();
ctx.add_fn(
"foo",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap();
ctx.assign("efg", ctx.get_primitive(FLOAT_TYPE)).unwrap();
let ast = parse_expression("(abc, efg)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(
TUPLE_TYPE,
vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)]
)
.into()
);
let ast = parse_expression("(abc, efg, foo())").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("tuple elements must have some type".into()));
}
#[test]
fn test_attribute() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let int32 = ctx.get_primitive(INT32_TYPE);
let float = ctx.get_primitive(FLOAT_TYPE);
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let foo_def = ctx.get_class_def_mut(foo);
foo_def.base.fields.insert("a", int32.clone());
foo_def.base.fields.insert("b", ClassType(foo).into());
foo_def.base.fields.insert("c", int32.clone());
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "Bar",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let bar_def = ctx.get_class_def_mut(bar);
bar_def.base.fields.insert("a", int32);
bar_def.base.fields.insert("b", ClassType(bar).into());
bar_def.base.fields.insert("c", float);
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![],
});
let v1 = ctx.add_variable(VarDef {
name: "v1",
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
.unwrap();
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
ctx.assign("bot", Rc::new(BotType)).unwrap();
let ast = parse_expression("foo.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("foo.d").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such field".into()));
let ast = parse_expression("foobar.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("v0.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no fields on unbounded type variable".into()));
let ast = parse_expression("v1.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
// shall we support this?
let ast = parse_expression("v1.b").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("unknown field (type mismatch between variants)".into())
);
// assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into());
let ast = parse_expression("v1.c").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("unknown field (type mismatch between variants)".into())
);
let ast = parse_expression("v1.d").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unknown field".into()));
let ast = parse_expression("none().a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("bot.a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("this object has no fields".into()));
}
#[test]
fn test_bool_ops() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("True and False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("True and none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("True and 123").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("bool operands must be bool".into()));
}
#[test]
fn test_bin_ops() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("1 + 2 + 3").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("a + a + a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into());
}
#[test]
fn test_unary_ops() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("-(123)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("-a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into());
let ast = parse_expression("not True").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("not (1)").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("logical not must be applied to bool".into()));
}
#[test]
fn test_compare() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("a", TypeVariable(v0).into()).unwrap();
let ast = parse_expression("a == a == a").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE));
let ast = parse_expression("a == a == 1").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("different types".into()));
let ast = parse_expression("True > False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such function".into()));
let ast = parse_expression("True in False").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unsupported comparison".into()));
}
#[test]
fn test_call() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let foo_def = ctx.get_class_def_mut(foo);
foo_def.base.methods.insert(
"a",
FnDef {
args: vec![],
result: Some(Rc::new(ClassType(foo))),
},
);
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "Bar",
fields: HashMap::new(),
methods: HashMap::new(),
},
parents: vec![],
});
let bar_def = ctx.get_class_def_mut(bar);
bar_def.base.methods.insert(
"a",
FnDef {
args: vec![],
result: Some(Rc::new(ClassType(bar))),
},
);
let v0 = ctx.add_variable(VarDef {
name: "v0",
bound: vec![],
});
let v1 = ctx.add_variable(VarDef {
name: "v1",
bound: vec![ClassType(foo).into(), ClassType(bar).into()],
});
let v2 = ctx.add_variable(VarDef {
name: "v2",
bound: vec![
ClassType(foo).into(),
ClassType(bar).into(),
ctx.get_primitive(INT32_TYPE),
],
});
let mut ctx = get_inference_context(ctx);
ctx.assign("foo", Rc::new(ClassType(foo))).unwrap();
ctx.assign("bar", Rc::new(ClassType(bar))).unwrap();
ctx.assign("foobar", Rc::new(VirtualClassType(foo)))
.unwrap();
ctx.assign("v0", ctx.get_variable(v0)).unwrap();
ctx.assign("v1", ctx.get_variable(v1)).unwrap();
ctx.assign("v2", ctx.get_variable(v2)).unwrap();
ctx.assign("bot", Rc::new(BotType)).unwrap();
let ast = parse_expression("foo.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
let ast = parse_expression("v1.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into());
let ast = parse_expression("foobar.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ClassType(foo).into());
let ast = parse_expression("none().a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("bot.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("[][0].a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("not supported".into()));
let ast = parse_expression("v0.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unbounded type var".into()));
let ast = parse_expression("v2.a()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no such function".into()));
}
#[test]
fn infer_subscript() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("[1, 2, 3][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("[[1]][0][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("[1, 2, 3][1:2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[1, 2, 3][1:2:2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into()
);
let ast = parse_expression("[1, 2, 3][1:1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("slice must be int32 type".into()));
let ast = parse_expression("[1, 2, 3][1:none()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("slice must have type".into()));
let ast = parse_expression("[1, 2, 3][1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("index must be either slice or int32".into()));
let ast = parse_expression("[1, 2, 3][none()]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("none()[1.2]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("123[1]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("subscript is not supported for types other than list".into())
);
}
#[test]
fn test_if_expr() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let mut ctx = get_inference_context(ctx);
let ast = parse_expression("1 if True else 0").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE));
let ast = parse_expression("none() if True else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap(), None);
let ast = parse_expression("none() if 1 else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("test should be bool".into()));
let ast = parse_expression("1 if True else none()").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("divergent type".into()));
}
#[test]
fn test_list_comp() {
let mut ctx = basic_ctx();
ctx.add_fn(
"none",
FnDef {
args: vec![],
result: None,
},
);
let int32 = ctx.get_primitive(INT32_TYPE);
let mut ctx = get_inference_context(ctx);
ctx.assign("z", int32.clone()).unwrap();
let ast = parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result.unwrap().unwrap(),
ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into()
);
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), int32);
let ast =
parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result.unwrap().unwrap(), int32);
let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("test must be bool".into()));
let ast = parse_expression("[y for x in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("unbounded identifier".into()));
let ast = parse_expression("[none() for x in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("no value".into()));
let ast = parse_expression("[z for z in []][0]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(result, Err("duplicated naming".into()));
let ast = parse_expression("[x for x in [] for y in []]").unwrap();
let result = infer_expr(&mut ctx, &ast);
assert_eq!(
result,
Err("only 1 generator statement is supported".into())
);
}
}

View File

@ -0,0 +1,617 @@
use super::context::InferenceContext;
use super::typedef::{TypeEnum::*, *};
use std::collections::HashMap;
use thiserror::Error;
#[derive(Error, Debug)]
enum SubstError {
#[error("different type variables after substitution")]
DifferentSubstVar(VariableId, VariableId),
#[error("cannot substitute unbounded type variable into bounded one")]
UnboundedTypeVar(VariableId, VariableId),
#[error("incompatible bound for type variables")]
IncompatibleBound(VariableId, VariableId),
#[error("only subtype of virtual class can be substituted into virtual class type")]
NotVirtualClassSubtype(Type, ClassId),
#[error("different types")]
DifferentTypes(Type, Type),
}
fn find_subst(
ctx: &InferenceContext,
valuation: &Option<(VariableId, Type)>,
sub: &mut HashMap<VariableId, Type>,
mut a: Type,
mut b: Type,
) -> Result<(), SubstError> {
if let TypeVariable(id) = a.as_ref() {
if let Some((assumption_id, t)) = valuation {
if assumption_id == id {
a = t.clone();
}
}
}
let mut substituted = false;
if let TypeVariable(id) = b.as_ref() {
if let Some(c) = sub.get(&id) {
b = c.clone();
substituted = true;
}
}
match (a.as_ref(), b.as_ref()) {
(BotType, _) => Ok(()),
(TypeVariable(id_a), TypeVariable(id_b)) => {
if substituted {
return if id_a == id_b {
Ok(())
} else {
Err(SubstError::DifferentSubstVar(*id_a, *id_b))
};
}
let v_a = ctx.get_variable_def(*id_a);
let v_b = ctx.get_variable_def(*id_b);
if !v_b.bound.is_empty() {
if v_a.bound.is_empty() {
return Err(SubstError::UnboundedTypeVar(*id_a, *id_b));
} else {
let diff: Vec<_> = v_a
.bound
.iter()
.filter(|x| !v_b.bound.contains(x))
.collect();
if !diff.is_empty() {
return Err(SubstError::IncompatibleBound(*id_a, *id_b));
}
}
}
sub.insert(*id_b, a.clone());
Ok(())
}
(TypeVariable(id_a), _) => {
let v_a = ctx.get_variable_def(*id_a);
if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() {
Ok(())
} else {
Err(SubstError::DifferentTypes(a.clone(), b.clone()))
}
}
(_, TypeVariable(id_b)) => {
let v_b = ctx.get_variable_def(*id_b);
if v_b.bound.is_empty() || v_b.bound.contains(&a) {
sub.insert(*id_b, a.clone());
Ok(())
} else {
Err(SubstError::DifferentTypes(a.clone(), b.clone()))
}
}
(_, VirtualClassType(id_b)) => {
let mut parents;
match a.as_ref() {
ClassType(id_a) => {
parents = [*id_a].to_vec();
}
VirtualClassType(id_a) => {
parents = [*id_a].to_vec();
}
_ => {
return Err(SubstError::NotVirtualClassSubtype(a.clone(), *id_b));
}
};
while !parents.is_empty() {
if *id_b == parents[0] {
return Ok(());
}
let c = ctx.get_class_def(parents.remove(0));
parents.extend_from_slice(&c.parents);
}
Err(SubstError::NotVirtualClassSubtype(a.clone(), *id_b))
}
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
if id_a != id_b || param_a.len() != param_b.len() {
Err(SubstError::DifferentTypes(a.clone(), b.clone()))
} else {
for (x, y) in param_a.iter().zip(param_b.iter()) {
find_subst(ctx, valuation, sub, x.clone(), y.clone())?;
}
Ok(())
}
}
(_, _) => {
if a == b {
Ok(())
} else {
Err(SubstError::DifferentTypes(a.clone(), b.clone()))
}
}
}
}
fn resolve_call_rec(
ctx: &InferenceContext,
valuation: &Option<(VariableId, Type)>,
obj: Option<Type>,
func: &str,
args: &[Type],
) -> Result<Option<Type>, String> {
let mut subst = obj
.as_ref()
.map(|v| v.get_subst(ctx))
.unwrap_or_else(HashMap::new);
let fun = match &obj {
Some(obj) => {
let base = match obj.as_ref() {
TypeVariable(id) => {
let v = ctx.get_variable_def(*id);
if v.bound.is_empty() {
return Err("unbounded type var".to_string());
}
let results: Result<Vec<_>, String> = v
.bound
.iter()
.map(|ins| {
resolve_call_rec(
ctx,
&Some((*id, ins.clone())),
Some(ins.clone()),
func,
args.clone(),
)
})
.collect();
let results = results?;
if results.iter().all(|v| v == &results[0]) {
return Ok(results[0].clone());
}
let mut results = results.iter().zip(v.bound.iter()).map(|(r, ins)| {
r.as_ref()
.map(|v| v.inv_subst(&[(ins.clone(), obj.clone())]))
});
let first = results.next().unwrap();
if results.all(|v| v == first) {
return Ok(first);
} else {
return Err("divergent type after substitution".to_string());
}
}
PrimitiveType(id) => &ctx.get_primitive_def(*id),
ClassType(id) | VirtualClassType(id) => &ctx.get_class_def(*id).base,
ParametricType(id, _) => &ctx.get_parametric_def(*id).base,
_ => return Err("not supported".to_string()),
};
base.methods.get(func)
}
None => ctx.get_fn_def(func),
}
.ok_or_else(|| "no such function".to_string())?;
if args.len() != fun.args.len() {
return Err("incorrect parameter number".to_string());
}
for (a, b) in args.iter().zip(fun.args.iter()) {
find_subst(ctx, valuation, &mut subst, a.clone(), b.clone()).map_err(|v| v.to_string())?;
}
let result = fun.result.as_ref().map(|v| v.subst(&subst));
Ok(result.map(|result| {
if let SelfType = result {
obj.unwrap()
} else {
result.into()
}
}))
}
pub fn resolve_call(
ctx: &InferenceContext,
obj: Option<Type>,
func: &str,
args: &[Type],
) -> Result<Option<Type>, String> {
resolve_call_rec(ctx, &None, obj, func, args)
}
#[cfg(test)]
mod tests {
use super::{
super::{context::*, primitives::*},
*,
};
use std::matches;
use std::rc::Rc;
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
}
#[test]
fn test_simple_generic() {
let mut ctx = basic_ctx();
let v1 = ctx.add_variable(VarDef {
name: "V1",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)],
});
let v1 = ctx.get_variable(v1);
let v2 = ctx.add_variable(VarDef {
name: "V2",
bound: vec![
ctx.get_primitive(BOOL_TYPE),
ctx.get_primitive(INT32_TYPE),
ctx.get_primitive(FLOAT_TYPE),
],
});
let v2 = ctx.get_variable(v2);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(&ctx, None, "int32", &[ctx.get_primitive(FLOAT_TYPE)]),
Ok(Some(ctx.get_primitive(INT32_TYPE)))
);
assert_eq!(
resolve_call(&ctx, None, "int32", &[ctx.get_primitive(INT32_TYPE)],),
Ok(Some(ctx.get_primitive(INT32_TYPE)))
);
assert_eq!(
resolve_call(&ctx, None, "float", &[ctx.get_primitive(INT32_TYPE)]),
Ok(Some(ctx.get_primitive(FLOAT_TYPE)))
);
assert!(matches!(
resolve_call(&ctx, None, "float", &[ctx.get_primitive(BOOL_TYPE)]),
Err(..)
));
assert!(matches!(
resolve_call(&ctx, None, "float", &[]),
Err(..)
));
assert_eq!(
resolve_call(&ctx, None, "float", &[v1]),
Ok(Some(ctx.get_primitive(FLOAT_TYPE)))
);
assert!(matches!(
resolve_call(&ctx, None, "float", &[v2]),
Err(..)
));
}
#[test]
fn test_methods() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "V0",
bound: vec![],
});
let v0 = ctx.get_variable(v0);
let v1 = ctx.add_variable(VarDef {
name: "V1",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)],
});
let v1 = ctx.get_variable(v1);
let v2 = ctx.add_variable(VarDef {
name: "V2",
bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)],
});
let v2 = ctx.get_variable(v2);
let v3 = ctx.add_variable(VarDef {
name: "V3",
bound: vec![
ctx.get_primitive(BOOL_TYPE),
ctx.get_primitive(INT32_TYPE),
ctx.get_primitive(FLOAT_TYPE),
],
});
let v3 = ctx.get_variable(v3);
let int32 = ctx.get_primitive(INT32_TYPE);
let int64 = ctx.get_primitive(INT64_TYPE);
let ctx = get_inference_context(ctx);
// simple cases
assert_eq!(
resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]),
Ok(Some(int32.clone()))
);
assert_ne!(
resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]),
Ok(Some(int64.clone()))
);
assert!(matches!(
resolve_call(&ctx, Some(int32), "__add__", &[int64]),
Err(..)
));
// with type variables
assert_eq!(
resolve_call(&ctx, Some(v1.clone()), "__add__", &[v1.clone()]),
Ok(Some(v1.clone()))
);
assert!(matches!(
resolve_call(&ctx, Some(v0.clone()), "__add__", &[v2.clone()]),
Err(..)
));
assert!(matches!(
resolve_call(&ctx, Some(v1.clone()), "__add__", &[v0]),
Err(..)
));
assert!(matches!(
resolve_call(&ctx, Some(v1.clone()), "__add__", &[v2]),
Err(..)
));
assert!(matches!(
resolve_call(&ctx, Some(v1.clone()), "__add__", &[v3.clone()]),
Err(..)
));
assert!(matches!(
resolve_call(&ctx, Some(v3.clone()), "__add__", &[v1]),
Err(..)
));
assert!(matches!(
resolve_call(&ctx, Some(v3.clone()), "__add__", &[v3]),
Err(..)
));
}
#[test]
fn test_multi_generic() {
let mut ctx = basic_ctx();
let v0 = ctx.add_variable(VarDef {
name: "V0",
bound: vec![],
});
let v0 = ctx.get_variable(v0);
let v1 = ctx.add_variable(VarDef {
name: "V1",
bound: vec![],
});
let v1 = ctx.get_variable(v1);
let v2 = ctx.add_variable(VarDef {
name: "V2",
bound: vec![],
});
let v2 = ctx.get_variable(v2);
let v3 = ctx.add_variable(VarDef {
name: "V3",
bound: vec![],
});
let v3 = ctx.get_variable(v3);
ctx.add_fn(
"foo",
FnDef {
args: vec![v0.clone(), v0.clone(), v1.clone()],
result: Some(v0.clone()),
},
);
ctx.add_fn(
"foo1",
FnDef {
args: vec![ParametricType(TUPLE_TYPE, vec![v0.clone(), v0.clone(), v1]).into()],
result: Some(v0),
},
);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]),
Ok(Some(v2.clone()))
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]),
Ok(Some(v2.clone()))
);
assert!(matches!(
resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]),
Err(..)
));
assert_eq!(
resolve_call(
&ctx,
None,
"foo1",
&[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()]
),
Ok(Some(v2.clone()))
);
assert_eq!(
resolve_call(
&ctx,
None,
"foo1",
&[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()]
),
Ok(Some(v2.clone()))
);
assert!(matches!(
resolve_call(
&ctx,
None,
"foo1",
&[ParametricType(TUPLE_TYPE, vec![v2, v3.clone(), v3]).into()]
),
Err(..)
));
}
#[test]
fn test_class_generics() {
let mut ctx = basic_ctx();
let list = ctx.get_parametric_def_mut(LIST_TYPE);
let t = Rc::new(TypeVariable(list.params[0]));
list.base.methods.insert(
"head",
FnDef {
args: vec![],
result: Some(t.clone()),
},
);
list.base.methods.insert(
"append",
FnDef {
args: vec![t],
result: None,
},
);
let v0 = ctx.add_variable(VarDef {
name: "V0",
bound: vec![],
});
let v0 = ctx.get_variable(v0);
let v1 = ctx.add_variable(VarDef {
name: "V1",
bound: vec![],
});
let v1 = ctx.get_variable(v1);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(
&ctx,
Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()),
"head",
&[]
),
Ok(Some(v0.clone()))
);
assert_eq!(
resolve_call(
&ctx,
Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()),
"append",
&[v0.clone()]
),
Ok(None)
);
assert!(matches!(
resolve_call(
&ctx,
Some(ParametricType(LIST_TYPE, vec![v0]).into()),
"append",
&[v1]
),
Err(..)
));
}
#[test]
fn test_virtual_class() {
let mut ctx = basic_ctx();
let foo = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![],
});
let foo1 = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo1",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![foo],
});
let foo2 = ctx.add_class(ClassDef {
base: TypeDef {
name: "Foo2",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![foo1],
});
let bar = ctx.add_class(ClassDef {
base: TypeDef {
name: "bar",
methods: HashMap::new(),
fields: HashMap::new(),
},
parents: vec![],
});
ctx.add_fn(
"foo",
FnDef {
args: vec![VirtualClassType(foo).into()],
result: None,
},
);
ctx.add_fn(
"foo1",
FnDef {
args: vec![VirtualClassType(foo1).into()],
result: None,
},
);
let ctx = get_inference_context(ctx);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(foo).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]),
Ok(None)
);
assert!(matches!(
resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]),
Err(..)
));
assert_eq!(
resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]),
Ok(None)
);
assert!(matches!(
resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]),
Err(..)
));
// virtual class substitution
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]),
Ok(None)
);
assert_eq!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]),
Ok(None)
);
assert!(matches!(
resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]),
Err(..)
));
}
}

View File

@ -0,0 +1,58 @@
use rustpython_parser::ast::{Comparison, Operator, UnaryOperator};
pub fn binop_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__add__",
Operator::Sub => "__sub__",
Operator::Div => "__truediv__",
Operator::Mod => "__mod__",
Operator::Mult => "__mul__",
Operator::Pow => "__pow__",
Operator::BitOr => "__or__",
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__",
Operator::RShift => "__rshift__",
Operator::FloorDiv => "__floordiv__",
Operator::MatMult => "__matmul__",
}
}
pub fn binop_assign_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__iadd__",
Operator::Sub => "__isub__",
Operator::Div => "__itruediv__",
Operator::Mod => "__imod__",
Operator::Mult => "__imul__",
Operator::Pow => "__ipow__",
Operator::BitOr => "__ior__",
Operator::BitXor => "__ixor__",
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
}
pub fn unaryop_name(op: &UnaryOperator) -> &'static str {
match op {
UnaryOperator::Pos => "__pos__",
UnaryOperator::Neg => "__neg__",
UnaryOperator::Not => "__not__",
UnaryOperator::Inv => "__inv__",
}
}
pub fn comparison_name(op: &Comparison) -> Option<&'static str> {
match op {
Comparison::Less => Some("__lt__"),
Comparison::LessOrEqual => Some("__le__"),
Comparison::Greater => Some("__gt__"),
Comparison::GreaterOrEqual => Some("__ge__"),
Comparison::Equal => Some("__eq__"),
Comparison::NotEqual => Some("__ne__"),
_ => None,
}
}

View File

@ -0,0 +1,9 @@
pub mod context;
pub mod expression_inference;
pub mod inference_core;
mod magic_methods;
pub mod primitives;
pub mod statement_check;
pub mod typedef;
pub mod signature;

View File

@ -0,0 +1,201 @@
use super::context::*;
use super::typedef::{TypeEnum::*, *};
use std::collections::HashMap;
pub const PRIMITIVES: [&str; 6] = ["int32", "int64", "float", "bool", "list", "tuple"];
pub const TUPLE_TYPE: ParamId = ParamId(0);
pub const LIST_TYPE: ParamId = ParamId(1);
pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0);
pub const INT32_TYPE: PrimitiveId = PrimitiveId(1);
pub const INT64_TYPE: PrimitiveId = PrimitiveId(2);
pub const FLOAT_TYPE: PrimitiveId = PrimitiveId(3);
fn impl_math(def: &mut TypeDef, ty: &Type) {
let result = Some(ty.clone());
let fun = FnDef {
args: vec![ty.clone()],
result: result.clone(),
};
def.methods.insert("__add__", fun.clone());
def.methods.insert("__iadd__", fun.clone());
def.methods.insert("__sub__", fun.clone());
def.methods.insert("__isub__", fun.clone());
def.methods.insert("__mul__", fun.clone());
def.methods.insert("__imul__", fun.clone());
def.methods.insert(
"__neg__",
FnDef {
args: vec![],
result,
},
);
def.methods.insert(
"__truediv__",
FnDef {
args: vec![ty.clone()],
result: Some(PrimitiveType(FLOAT_TYPE).into()),
},
);
if ty.as_ref() == &PrimitiveType(FLOAT_TYPE) {
def.methods.insert(
"__itruediv__",
FnDef {
args: vec![ty.clone()],
result: Some(PrimitiveType(FLOAT_TYPE).into()),
},
);
}
def.methods.insert("__floordiv__", fun.clone());
def.methods.insert("__ifloordiv__", fun.clone());
def.methods.insert("__mod__", fun.clone());
def.methods.insert("__imod__", fun.clone());
def.methods.insert("__pow__", fun.clone());
def.methods.insert("__ipow__", fun);
}
fn impl_bits(def: &mut TypeDef, ty: &Type) {
let result = Some(ty.clone());
let fun = FnDef {
args: vec![PrimitiveType(INT32_TYPE).into()],
result,
};
def.methods.insert("__lshift__", fun.clone());
def.methods.insert("__rshift__", fun);
def.methods.insert(
"__xor__",
FnDef {
args: vec![ty.clone()],
result: Some(ty.clone()),
},
);
}
fn impl_eq(def: &mut TypeDef, ty: &Type) {
let fun = FnDef {
args: vec![ty.clone()],
result: Some(PrimitiveType(BOOL_TYPE).into()),
};
def.methods.insert("__eq__", fun.clone());
def.methods.insert("__ne__", fun);
}
fn impl_order(def: &mut TypeDef, ty: &Type) {
let fun = FnDef {
args: vec![ty.clone()],
result: Some(PrimitiveType(BOOL_TYPE).into()),
};
def.methods.insert("__lt__", fun.clone());
def.methods.insert("__gt__", fun.clone());
def.methods.insert("__le__", fun.clone());
def.methods.insert("__ge__", fun);
}
pub fn basic_ctx() -> TopLevelContext<'static> {
let primitives = [
TypeDef {
name: "bool",
fields: HashMap::new(),
methods: HashMap::new(),
},
TypeDef {
name: "int32",
fields: HashMap::new(),
methods: HashMap::new(),
},
TypeDef {
name: "int64",
fields: HashMap::new(),
methods: HashMap::new(),
},
TypeDef {
name: "float",
fields: HashMap::new(),
methods: HashMap::new(),
},
]
.to_vec();
let mut ctx = TopLevelContext::new(primitives);
let b = ctx.get_primitive(BOOL_TYPE);
let b_def = ctx.get_primitive_def_mut(BOOL_TYPE);
impl_eq(b_def, &b);
let int32 = ctx.get_primitive(INT32_TYPE);
let int32_def = ctx.get_primitive_def_mut(INT32_TYPE);
impl_math(int32_def, &int32);
impl_bits(int32_def, &int32);
impl_order(int32_def, &int32);
impl_eq(int32_def, &int32);
let int64 = ctx.get_primitive(INT64_TYPE);
let int64_def = ctx.get_primitive_def_mut(INT64_TYPE);
impl_math(int64_def, &int64);
impl_bits(int64_def, &int64);
impl_order(int64_def, &int64);
impl_eq(int64_def, &int64);
let float = ctx.get_primitive(FLOAT_TYPE);
let float_def = ctx.get_primitive_def_mut(FLOAT_TYPE);
impl_math(float_def, &float);
impl_order(float_def, &float);
impl_eq(float_def, &float);
let t = ctx.add_variable_private(VarDef {
name: "T",
bound: vec![],
});
ctx.add_parametric(ParametricDef {
base: TypeDef {
name: "tuple",
fields: HashMap::new(),
methods: HashMap::new(),
},
// we have nothing for tuple, so no param def
params: vec![],
});
ctx.add_parametric(ParametricDef {
base: TypeDef {
name: "list",
fields: HashMap::new(),
methods: HashMap::new(),
},
params: vec![t],
});
let i = ctx.add_variable_private(VarDef {
name: "I",
bound: vec![
PrimitiveType(INT32_TYPE).into(),
PrimitiveType(INT64_TYPE).into(),
PrimitiveType(FLOAT_TYPE).into(),
],
});
let args = vec![TypeVariable(i).into()];
ctx.add_fn(
"int32",
FnDef {
args: args.clone(),
result: Some(PrimitiveType(INT32_TYPE).into()),
},
);
ctx.add_fn(
"int64",
FnDef {
args: args.clone(),
result: Some(PrimitiveType(INT64_TYPE).into()),
},
);
ctx.add_fn(
"float",
FnDef {
args,
result: Some(PrimitiveType(FLOAT_TYPE).into()),
},
);
ctx
}

View File

@ -0,0 +1,503 @@
/// obtain class and function signature from AST
use super::context::TopLevelContext;
use super::primitives::*;
use super::typedef::*;
use rustpython_parser::ast::{
ComprehensionKind, ExpressionType, Statement, StatementType, StringGroup,
};
use std::collections::HashMap;
// TODO: fix condition checking, return error message instead of panic...
fn typename_from_expr<'a>(typenames: &mut Vec<&'a str>, expr: &'a ExpressionType) {
match expr {
ExpressionType::Identifier { name } => typenames.push(&name),
ExpressionType::String { value } => match value {
StringGroup::Constant { value } => typenames.push(&value),
_ => unimplemented!(),
},
ExpressionType::Subscript { a, b } => {
typename_from_expr(typenames, &b.node);
typename_from_expr(typenames, &a.node)
}
_ => unimplemented!(),
}
}
fn typename_from_fn<'a>(typenames: &mut Vec<&'a str>, fun: &'a StatementType) {
match fun {
StatementType::FunctionDef { args, returns, .. } => {
for arg in args.args.iter() {
if let Some(ann) = &arg.annotation {
typename_from_expr(typenames, &ann.node);
}
}
if let Some(returns) = &returns {
typename_from_expr(typenames, &returns.node);
}
}
_ => unreachable!(),
}
}
fn name_from_expr<'a>(expr: &'a ExpressionType) -> &'a str {
match &expr {
ExpressionType::Identifier { name } => &name,
ExpressionType::String { value } => match value {
StringGroup::Constant { value } => &value,
_ => unimplemented!(),
},
_ => unimplemented!(),
}
}
fn type_from_expr<'a>(ctx: &'a TopLevelContext, expr: &'a ExpressionType) -> Result<Type, String> {
match expr {
ExpressionType::Identifier { name } => {
ctx.get_type(name).ok_or_else(|| "no such type".into())
}
ExpressionType::String { value } => match value {
StringGroup::Constant { value } => {
ctx.get_type(&value).ok_or_else(|| "no such type".into())
}
_ => unimplemented!(),
},
ExpressionType::Subscript { a, b } => {
if let ExpressionType::Identifier { name } = &a.node {
match name.as_str() {
"list" => {
let ty = type_from_expr(ctx, &b.node)?;
Ok(TypeEnum::ParametricType(LIST_TYPE, vec![ty]).into())
}
"tuple" => {
if let ExpressionType::Tuple { elements } = &b.node {
let ty_list: Result<Vec<_>, _> = elements
.iter()
.map(|v| type_from_expr(ctx, &v.node))
.collect();
Ok(TypeEnum::ParametricType(TUPLE_TYPE, ty_list?).into())
} else {
Err("unsupported format".into())
}
}
_ => Err("no such parameterized type".into()),
}
} else {
// we require a to be an identifier, for a[b]
Err("unsupported format".into())
}
}
_ => Err("unsupported format".into()),
}
}
pub fn get_typenames<'a>(stmts: &'a [Statement]) -> (Vec<&'a str>, Vec<&'a str>) {
let mut classes = Vec::new();
let mut typenames = Vec::new();
for stmt in stmts.iter() {
match &stmt.node {
StatementType::ClassDef {
name, body, bases, ..
} => {
// check if class is not duplicated...
// and annotations
classes.push(&name[..]);
for base in bases.iter() {
let name = name_from_expr(&base.node);
typenames.push(name);
}
// may check if fields/functions are not duplicated
for stmt in body.iter() {
match &stmt.node {
StatementType::AnnAssign { annotation, .. } => {
typename_from_expr(&mut typenames, &annotation.node)
}
StatementType::FunctionDef { .. } => {
typename_from_fn(&mut typenames, &stmt.node);
}
_ => unimplemented!(),
}
}
}
StatementType::FunctionDef { .. } => {
// may check annotations
typename_from_fn(&mut typenames, &stmt.node);
}
_ => (),
}
}
let mut unknowns = Vec::new();
for n in typenames {
if !PRIMITIVES.contains(&n) && !classes.contains(&n) && !unknowns.contains(&n) {
unknowns.push(n);
}
}
(classes, unknowns)
}
fn resolve_function<'a>(
ctx: &'a TopLevelContext,
fun: &'a StatementType,
method: bool,
) -> Result<FnDef, String> {
if let StatementType::FunctionDef { args, returns, .. } = &fun {
let args = if method {
args.args[1..].iter()
} else {
args.args.iter()
};
let args: Result<Vec<_>, _> = args
.map(|arg| type_from_expr(ctx, &arg.annotation.as_ref().unwrap().node))
.collect();
let args = args?;
let result = match returns {
Some(v) => Some(type_from_expr(ctx, &v.node)?),
None => None,
};
Ok(FnDef { args, result })
} else {
unreachable!()
}
}
fn get_expr_unknowns<'a>(
defined: &mut Vec<&'a str>,
unknowns: &mut Vec<&'a str>,
expr: &'a ExpressionType,
) {
match expr {
ExpressionType::BoolOp { values, .. } => {
for v in values.iter() {
get_expr_unknowns(defined, unknowns, &v.node)
}
}
ExpressionType::Binop { a, b, .. } => {
get_expr_unknowns(defined, unknowns, &a.node);
get_expr_unknowns(defined, unknowns, &b.node);
}
ExpressionType::Subscript { a, b } => {
get_expr_unknowns(defined, unknowns, &a.node);
get_expr_unknowns(defined, unknowns, &b.node);
}
ExpressionType::Unop { a, .. } => {
get_expr_unknowns(defined, unknowns, &a.node);
}
ExpressionType::Compare { vals, .. } => {
for v in vals.iter() {
get_expr_unknowns(defined, unknowns, &v.node)
}
}
ExpressionType::Attribute { value, .. } => {
get_expr_unknowns(defined, unknowns, &value.node);
}
ExpressionType::Call { function, args, .. } => {
get_expr_unknowns(defined, unknowns, &function.node);
for v in args.iter() {
get_expr_unknowns(defined, unknowns, &v.node)
}
}
ExpressionType::List { elements } => {
for v in elements.iter() {
get_expr_unknowns(defined, unknowns, &v.node)
}
}
ExpressionType::Tuple { elements } => {
for v in elements.iter() {
get_expr_unknowns(defined, unknowns, &v.node)
}
}
ExpressionType::Comprehension { kind, generators } => {
if generators.len() != 1 {
unimplemented!()
}
let g = &generators[0];
get_expr_unknowns(defined, unknowns, &g.iter.node);
let mut scoped = defined.clone();
get_expr_unknowns(defined, &mut scoped, &g.target.node);
for if_expr in g.ifs.iter() {
get_expr_unknowns(&mut scoped, unknowns, &if_expr.node);
}
match kind.as_ref() {
ComprehensionKind::List { element } => {
get_expr_unknowns(&mut scoped, unknowns, &element.node);
}
_ => unimplemented!(),
}
}
ExpressionType::Slice { elements } => {
for v in elements.iter() {
get_expr_unknowns(defined, unknowns, &v.node);
}
}
ExpressionType::Identifier { name } => {
if !defined.contains(&name.as_str()) && !unknowns.contains(&name.as_str()) {
unknowns.push(name);
}
}
ExpressionType::IfExpression { test, body, orelse } => {
get_expr_unknowns(defined, unknowns, &test.node);
get_expr_unknowns(defined, unknowns, &body.node);
get_expr_unknowns(defined, unknowns, &orelse.node);
}
_ => (),
};
}
struct ExprPattern<'a>(&'a ExpressionType, Vec<usize>, bool);
impl<'a> ExprPattern<'a> {
fn new(expr: &'a ExpressionType) -> ExprPattern {
let mut pattern = ExprPattern(expr, Vec::new(), true);
pattern.find_leaf();
pattern
}
fn pointed(&mut self) -> &'a ExpressionType {
let mut current = self.0;
for v in self.1.iter() {
if let ExpressionType::Tuple { elements } = current {
current = &elements[*v].node
} else {
unreachable!()
}
}
current
}
fn find_leaf(&mut self) {
let mut current = self.pointed();
while let ExpressionType::Tuple { elements } = current {
if elements.is_empty() {
break;
}
current = &elements[0].node;
self.1.push(0);
}
}
fn inc(&mut self) -> bool {
loop {
if self.1.is_empty() {
return false;
}
let ind = self.1.pop().unwrap() + 1;
let parent = self.pointed();
if let ExpressionType::Tuple { elements } = parent {
if ind < elements.len() {
self.1.push(ind);
self.find_leaf();
return true;
}
} else {
unreachable!()
}
}
}
}
impl<'a> Iterator for ExprPattern<'a> {
type Item = &'a ExpressionType;
fn next(&mut self) -> Option<Self::Item> {
if self.2 {
self.2 = false;
Some(self.pointed())
} else if self.inc() {
Some(self.pointed())
} else {
None
}
}
}
fn get_stmt_unknowns<'a>(
defined: &mut Vec<&'a str>,
unknowns: &mut Vec<&'a str>,
stmts: &'a [Statement],
) {
for stmt in stmts.iter() {
match &stmt.node {
StatementType::Return { value } => {
if let Some(value) = value {
get_expr_unknowns(defined, unknowns, &value.node);
}
}
StatementType::Assign { targets, value } => {
get_expr_unknowns(defined, unknowns, &value.node);
for target in targets.iter() {
for node in ExprPattern::new(&target.node).into_iter() {
if let ExpressionType::Identifier { name } = node {
let name = name.as_str();
if !defined.contains(&name) {
defined.push(name);
}
} else {
get_expr_unknowns(defined, unknowns, node);
}
}
}
}
StatementType::AugAssign { target, value, .. } => {
get_expr_unknowns(defined, unknowns, &target.node);
get_expr_unknowns(defined, unknowns, &value.node);
}
StatementType::AnnAssign { target, value, .. } => {
get_expr_unknowns(defined, unknowns, &target.node);
if let Some(value) = value {
get_expr_unknowns(defined, unknowns, &value.node);
}
}
StatementType::Expression { expression } => {
get_expr_unknowns(defined, unknowns, &expression.node);
}
StatementType::Global { names } => {
for name in names.iter() {
let name = name.as_str();
if !unknowns.contains(&name) {
unknowns.push(name);
}
}
}
StatementType::If { test, body, orelse }
| StatementType::While { test, body, orelse } => {
get_expr_unknowns(defined, unknowns, &test.node);
// we are not very strict at this point...
// some identifiers not treated as unknowns may not be resolved
// but should be checked during type inference
get_stmt_unknowns(defined, unknowns, body.as_slice());
if let Some(orelse) = orelse {
get_stmt_unknowns(defined, unknowns, orelse.as_slice());
}
}
StatementType::For { is_async, target, iter, body, orelse } => {
if *is_async {
unimplemented!()
}
get_expr_unknowns(defined, unknowns, &iter.node);
for node in ExprPattern::new(&target.node).into_iter() {
if let ExpressionType::Identifier { name } = node {
let name = name.as_str();
if !defined.contains(&name) {
defined.push(name);
}
} else {
get_expr_unknowns(defined, unknowns, node);
}
}
get_stmt_unknowns(defined, unknowns, body.as_slice());
if let Some(orelse) = orelse {
get_stmt_unknowns(defined, unknowns, orelse.as_slice());
}
}
_ => (),
}
}
}
pub fn resolve_signatures<'a>(ctx: &mut TopLevelContext<'a>, stmts: &'a [Statement]) {
for stmt in stmts.iter() {
match &stmt.node {
StatementType::ClassDef {
name, bases, body, ..
} => {
let mut parents = Vec::new();
for base in bases.iter() {
let name = name_from_expr(&base.node);
let c = ctx.get_type(name).unwrap();
let id = if let TypeEnum::ClassType(id) = c.as_ref() {
*id
} else {
unreachable!()
};
parents.push(id);
}
let mut fields = HashMap::new();
let mut functions = HashMap::new();
for stmt in body.iter() {
match &stmt.node {
StatementType::AnnAssign {
target, annotation, ..
} => {
let name = name_from_expr(&target.node);
let ty = type_from_expr(ctx, &annotation.node).unwrap();
fields.insert(name, ty);
}
StatementType::FunctionDef { name, .. } => {
functions.insert(
&name[..],
resolve_function(ctx, &stmt.node, true).unwrap(),
);
}
_ => unimplemented!(),
}
}
let class = ctx.get_type(name).unwrap();
let class = if let TypeEnum::ClassType(id) = class.as_ref() {
ctx.get_class_def_mut(*id)
} else {
unreachable!()
};
class.parents.extend_from_slice(&parents);
class.base.fields.clone_from(&fields);
class.base.methods.clone_from(&functions);
}
StatementType::FunctionDef { name, .. } => {
ctx.add_fn(&name[..], resolve_function(ctx, &stmt.node, false).unwrap());
}
_ => unimplemented!(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use indoc::indoc;
use rustpython_parser::parser::{parse_program, parse_statement};
#[test]
fn test_get_classes() {
let ast = parse_program(indoc! {"
class Foo:
a: int32
b: Test
def test(self, a: int32) -> Test2:
return self.b
class Bar(Foo, 'FooBar'):
def test2(self, a: list[Foo]) -> Test2:
return self.b
def test3(self, a: list[FooBar2]) -> Test2:
return self.b
" })
.unwrap();
let (mut classes, mut unknowns) = get_typenames(&ast.statements);
let classes_count = classes.len();
let unknowns_count = unknowns.len();
classes.sort();
unknowns.sort();
assert_eq!(classes.len(), classes_count);
assert_eq!(unknowns.len(), unknowns_count);
assert_eq!(&classes, &["Bar", "Foo"]);
assert_eq!(&unknowns, &["FooBar", "FooBar2", "Test", "Test2"]);
}
#[test]
fn test_assignment() {
let ast = parse_statement(indoc! {"
((a, b), c[i]) = core.foo(x, get_y())
" })
.unwrap();
let mut defined = Vec::new();
let mut unknowns = Vec::new();
get_stmt_unknowns(&mut defined, &mut unknowns, ast.as_slice());
defined.sort();
unknowns.sort();
assert_eq!(defined.as_slice(), &["a", "b"]);
assert_eq!(unknowns.as_slice(), &["c", "core", "get_y", "i", "x"]);
}
}

View File

@ -0,0 +1,572 @@
use super::context::InferenceContext;
use super::expression_inference::{infer_expr, infer_simple_binding};
use super::inference_core::resolve_call;
use super::magic_methods::binop_assign_name;
use super::primitives::*;
use super::typedef::{Type, TypeEnum::*};
use rustpython_parser::ast::*;
pub fn check_stmts<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
stmts: &'b [Statement],
) -> Result<bool, String> {
for stmt in stmts.iter() {
match &stmt.node {
StatementType::Assign { targets, value } => {
check_assign(ctx, targets.as_slice(), &value)?;
}
StatementType::AugAssign { target, op, value } => {
check_aug_assign(ctx, &target, op, &value)?;
}
StatementType::If { test, body, orelse } => {
if check_if(ctx, test, body.as_slice(), orelse)? {
return Ok(true);
}
}
StatementType::While { test, body, orelse } => {
check_while_stmt(ctx, test, body.as_slice(), orelse)?;
}
StatementType::For {
is_async,
target,
iter,
body,
orelse,
} => {
if *is_async {
return Err("async for is not supported".to_string());
}
check_for_stmt(ctx, target, iter, body.as_slice(), orelse)?;
}
StatementType::Return { value } => {
let result = ctx.get_result();
let t = if let Some(value) = value {
infer_expr(ctx, value)?
} else {
None
};
return if t == result {
Ok(true)
} else {
Err("return type mismatch".to_string())
};
}
StatementType::Continue | StatementType::Break => {
continue;
}
_ => return Err("not supported".to_string()),
}
}
Ok(false)
}
fn get_target_type<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
target: &'b Expression,
) -> Result<Type, String> {
match &target.node {
ExpressionType::Subscript { a, b } => {
let int32 = ctx.get_primitive(INT32_TYPE);
if infer_expr(ctx, &a)? == Some(int32) {
let b = get_target_type(ctx, &b)?;
if let ParametricType(LIST_TYPE, t) = b.as_ref() {
Ok(t[0].clone())
} else {
Err("subscript is only supported for list".to_string())
}
} else {
Err("subscript must be int32".to_string())
}
}
ExpressionType::Attribute { value, name } => {
let t = get_target_type(ctx, &value)?;
let base = t.get_base(ctx).ok_or_else(|| "no attributes".to_string())?;
Ok(base
.fields
.get(name.as_str())
.ok_or_else(|| "no such attribute")?
.clone())
}
ExpressionType::Identifier { name } => Ok(ctx.resolve(name.as_str())?),
_ => Err("not supported".to_string()),
}
}
fn check_stmt_binding<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
target: &'b Expression,
ty: Type,
) -> Result<(), String> {
match &target.node {
ExpressionType::Identifier { name } => {
if name.as_str() == "_" {
Ok(())
} else {
match ctx.resolve(name.as_str()) {
Ok(t) if t == ty => Ok(()),
Err(_) => {
ctx.assign(name.as_str(), ty).unwrap();
Ok(())
}
_ => Err("conflicting type".into()),
}
}
}
ExpressionType::Tuple { elements } => {
if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() {
if ls.len() != elements.len() {
return Err("incorrect pattern length".into());
}
for (x, y) in elements.iter().zip(ls.iter()) {
check_stmt_binding(ctx, x, y.clone())?;
}
Ok(())
} else {
Err("pattern matching supports tuple only".into())
}
}
_ => {
let t = get_target_type(ctx, target)?;
if ty == t {
Ok(())
} else {
Err("type mismatch".into())
}
}
}
}
fn check_assign<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
targets: &'b [Expression],
value: &'b Expression,
) -> Result<(), String> {
let ty = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?;
for t in targets.iter() {
check_stmt_binding(ctx, t, ty.clone())?;
}
Ok(())
}
fn check_aug_assign<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
target: &'b Expression,
op: &'b Operator,
value: &'b Expression,
) -> Result<(), String> {
let left = infer_expr(ctx, target)?.ok_or_else(|| "no value".to_string())?;
let right = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?;
let fun = binop_assign_name(op);
resolve_call(ctx, Some(left), fun, &[right])?;
Ok(())
}
fn check_if<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
test: &'b Expression,
body: &'b [Statement],
orelse: &'b Option<Suite>,
) -> Result<bool, String> {
let boolean = ctx.get_primitive(BOOL_TYPE);
let t = infer_expr(ctx, test)?;
if t == Some(boolean) {
let (names, result) = ctx.with_scope(|ctx| check_stmts(ctx, body));
let returned = result?;
if let Some(orelse) = orelse {
let (names2, result) = ctx.with_scope(|ctx| check_stmts(ctx, orelse.as_slice()));
let returned = returned && result?;
for (name, ty) in names.iter() {
for (name2, ty2) in names2.iter() {
if *name == *name2 && ty == ty2 {
ctx.assign(name, ty.clone()).unwrap();
}
}
}
Ok(returned)
} else {
Ok(false)
}
} else {
Err("condition should be bool".to_string())
}
}
fn check_while_stmt<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
test: &'b Expression,
body: &'b [Statement],
orelse: &'b Option<Suite>,
) -> Result<bool, String> {
let boolean = ctx.get_primitive(BOOL_TYPE);
let t = infer_expr(ctx, test)?;
if t == Some(boolean) {
// to check what variables are defined, we would have to do a graph analysis...
// not implemented now
let (_, result) = ctx.with_scope(|ctx| check_stmts(ctx, body));
result?;
if let Some(orelse) = orelse {
let (_, result) = ctx.with_scope(|ctx| check_stmts(ctx, orelse.as_slice()));
result?;
}
// to check whether the loop returned on every possible path, we need to analyse the graph,
// not implemented now
Ok(false)
} else {
Err("condition should be bool".to_string())
}
}
fn check_for_stmt<'b: 'a, 'a>(
ctx: &mut InferenceContext<'a>,
target: &'b Expression,
iter: &'b Expression,
body: &'b [Statement],
orelse: &'b Option<Suite>,
) -> Result<bool, String> {
let ty = infer_expr(ctx, iter)?.ok_or_else(|| "no value".to_string())?;
if let ParametricType(LIST_TYPE, ls) = ty.as_ref() {
let (_, result) = ctx.with_scope(|ctx| {
infer_simple_binding(ctx, target, ls[0].clone())?;
check_stmts(ctx, body)
});
result?;
if let Some(orelse) = orelse {
let (_, result) = ctx.with_scope(|ctx| check_stmts(ctx, orelse.as_slice()));
result?;
}
// to check whether the loop returned on every possible path, we need to analyse the graph,
// not implemented now
Ok(false)
} else {
Err("only list can be iterated over".to_string())
}
}
#[cfg(test)]
mod test {
use super::{super::context::*, *};
use indoc::indoc;
use rustpython_parser::parser::parse_program;
fn get_inference_context(ctx: TopLevelContext) -> InferenceContext {
InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into())))
}
#[test]
fn test_assign() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
let ast = parse_program(indoc! {"
a = 1
b = a * 2
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
a = 1
b = b * 2
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(
Err("unbounded identifier".to_string()),
check_stmts(ctx, ast.statements.as_slice())
);
});
let ast = parse_program(indoc! {"
b = a = 1
b = b * 2
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
b = a = 1
b = [a]
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(
Err("conflicting type".to_string()),
check_stmts(ctx, ast.statements.as_slice())
);
});
}
#[test]
fn test_if() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
let ast = parse_program(indoc! {"
a = 1
b = a * 2
if b > a:
c = 1
else:
c = 0
d = c
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
a = 1
b = a * 2
if b > a:
c = 1
else:
d = 0
d = c
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(
Err("unbounded identifier".to_string()),
check_stmts(ctx, ast.statements.as_slice())
);
});
let ast = parse_program(indoc! {"
a = 1
b = a * 2
if b > a:
c = 1
d = c
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(
Err("unbounded identifier".to_string()),
check_stmts(ctx, ast.statements.as_slice())
);
});
let ast = parse_program(indoc! {"
a = 1
b = a * 2
if a:
b = 0
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(
Err("condition should be bool".to_string()),
check_stmts(ctx, ast.statements.as_slice())
);
});
let ast = parse_program(indoc! {"
a = 1
b = a * 2
if b > a:
c = 1
c = [1]
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
a = 1
b = a * 2
if b > a:
c = 1
else:
c = 0
c = [1]
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(
Err("conflicting type".to_string()),
check_stmts(ctx, ast.statements.as_slice())
);
});
}
#[test]
fn test_while() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
let ast = parse_program(indoc! {"
a = 1
b = 1
while a < 10:
a += 1
b *= a
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
a = 1
b = 1
while a < 10:
a += 1
b *= a
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
a = 1
b = 1
while a < 10:
a += 1
b *= a
else:
a += 1
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
a = 1
b = 1
while a:
a += 1
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(
Err("condition should be bool".to_string()),
check_stmts(ctx, ast.statements.as_slice())
);
});
let ast = parse_program(indoc! {"
a = 1
b = 1
while a < 10:
a += 1
c = a*2
else:
c = a*2
b = c
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(
Err("unbounded identifier".to_string()),
check_stmts(ctx, ast.statements.as_slice())
);
});
}
#[test]
fn test_for() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
let ast = parse_program(indoc! {"
b = 1
for a in [0, 1, 2, 3, 4, 5]:
b *= a
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
b = 1
for a, a1 in [(0, 1), (2, 3), (4, 5)]:
b *= a
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
}
#[test]
fn test_return() {
let ctx = basic_ctx();
let mut ctx = get_inference_context(ctx);
let ast = parse_program(indoc! {"
b = 1
return
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(true), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
b = 1
if b > 0:
return
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
b = 1
if b > 0:
return
else:
return
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(true), check_stmts(ctx, ast.statements.as_slice()));
});
let ast = parse_program(indoc! {"
b = 1
while b > 0:
return
else:
return
" })
.unwrap();
ctx.with_scope(|ctx| {
// with sophisticated analysis, this one should be Ok(true)
// but with our simple implementation, this is Ok(false)
// as we don't analyse the control flow
assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice()));
});
ctx.set_result(Some(ctx.get_primitive(INT32_TYPE)));
let ast = parse_program(indoc! {"
b = 1
return 1
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(Ok(true), check_stmts(ctx, ast.statements.as_slice()));
});
ctx.set_result(Some(ctx.get_primitive(INT32_TYPE)));
let ast = parse_program(indoc! {"
b = 1
return [1]
" })
.unwrap();
ctx.with_scope(|ctx| {
assert_eq!(
Err("return type mismatch".to_string()),
check_stmts(ctx, ast.statements.as_slice())
);
});
}
}

View File

@ -0,0 +1,60 @@
use std::collections::HashMap;
use std::rc::Rc;
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct PrimitiveId(pub(crate) usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct ClassId(pub(crate) usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct ParamId(pub(crate) usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct VariableId(pub(crate) usize);
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
pub enum TypeEnum {
BotType,
SelfType,
PrimitiveType(PrimitiveId),
ClassType(ClassId),
VirtualClassType(ClassId),
ParametricType(ParamId, Vec<Rc<TypeEnum>>),
TypeVariable(VariableId),
}
pub type Type = Rc<TypeEnum>;
#[derive(Clone)]
pub struct FnDef {
// we assume methods first argument to be SelfType,
// so the first argument is not contained here
pub args: Vec<Type>,
pub result: Option<Type>,
}
#[derive(Clone)]
pub struct TypeDef<'a> {
pub name: &'a str,
pub fields: HashMap<&'a str, Type>,
pub methods: HashMap<&'a str, FnDef>,
}
#[derive(Clone)]
pub struct ClassDef<'a> {
pub base: TypeDef<'a>,
pub parents: Vec<ClassId>,
}
#[derive(Clone)]
pub struct ParametricDef<'a> {
pub base: TypeDef<'a>,
pub params: Vec<VariableId>,
}
#[derive(Clone)]
pub struct VarDef<'a> {
pub name: &'a str,
pub bound: Vec<Type>,
}

View File

@ -1,377 +0,0 @@
use crate::typecheck::typedef::TypeEnum;
use super::type_inferencer::Inferencer;
use super::typedef::Type;
use nac3parser::ast::{self, Constant, Expr, ExprKind, Operator::{LShift, RShift}, Stmt, StmtKind, StrRef};
use std::{collections::HashSet, iter::once};
impl<'a> Inferencer<'a> {
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> {
if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) {
Err(HashSet::from([
format!("Error at {}: cannot have value none", expr.location),
]))
} else {
Ok(())
}
}
fn check_pattern(
&mut self,
pattern: &Expr<Option<Type>>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> {
match &pattern.node {
ExprKind::Name { id, .. } if id == &"none".into() => Err(HashSet::from([
format!("cannot assign to a `none` (at {})", pattern.location),
])),
ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) {
defined_identifiers.insert(*id);
}
self.should_have_value(pattern)?;
Ok(())
}
ExprKind::Tuple { elts, .. } => {
for elt in elts {
self.check_pattern(elt, defined_identifiers)?;
self.should_have_value(elt)?;
}
Ok(())
}
ExprKind::Subscript { value, slice, .. } => {
self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?;
self.check_expr(slice, defined_identifiers)?;
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
return Err(HashSet::from([
format!(
"Error at {}: cannot assign to tuple element",
value.location
),
]))
}
Ok(())
}
ExprKind::Constant { .. } => {
Err(HashSet::from([
format!("cannot assign to a constant (at {})", pattern.location),
]))
}
_ => self.check_expr(pattern, defined_identifiers),
}
}
fn check_expr(
&mut self,
expr: &Expr<Option<Type>>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> {
// there are some cases where the custom field is None
if let Some(ty) = &expr.custom {
if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
return Err(HashSet::from([
format!(
"expected concrete type at {} but got {}",
expr.location,
self.unifier.get_ty(*ty).get_type_name()
)
]))
}
}
match &expr.node {
ExprKind::Name { id, .. } => {
if id == &"none".into() {
return Ok(());
}
self.should_have_value(expr)?;
if !defined_identifiers.contains(id) {
match self.function_data.resolver.get_symbol_type(
self.unifier,
&self.top_level.definitions.read(),
self.primitives,
*id,
) {
Ok(_) => {
self.defined_identifiers.insert(*id);
}
Err(e) => {
return Err(HashSet::from([
format!(
"type error at identifier `{}` ({}) at {}",
id, e, expr.location
)
]))
}
}
}
}
ExprKind::List { elts, .. }
| ExprKind::Tuple { elts, .. }
| ExprKind::BoolOp { values: elts, .. } => {
for elt in elts {
self.check_expr(elt, defined_identifiers)?;
self.should_have_value(elt)?;
}
}
ExprKind::Attribute { value, .. } => {
self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?;
}
ExprKind::BinOp { left, op, right } => {
self.check_expr(left, defined_identifiers)?;
self.check_expr(right, defined_identifiers)?;
self.should_have_value(left)?;
self.should_have_value(right)?;
// Check whether a bitwise shift has a negative RHS constant value
if *op == LShift || *op == RShift {
if let ExprKind::Constant { value, .. } = &right.node {
let Constant::Int(rhs_val) = value else {
unreachable!()
};
if *rhs_val < 0 {
return Err(HashSet::from([
format!(
"shift count is negative at {}",
right.location
),
]))
}
}
}
}
ExprKind::UnaryOp { operand, .. } => {
self.check_expr(operand, defined_identifiers)?;
self.should_have_value(operand)?;
}
ExprKind::Compare { left, comparators, .. } => {
for elt in once(left.as_ref()).chain(comparators.iter()) {
self.check_expr(elt, defined_identifiers)?;
self.should_have_value(elt)?;
}
}
ExprKind::Subscript { value, slice, .. } => {
self.should_have_value(value)?;
self.check_expr(value, defined_identifiers)?;
self.check_expr(slice, defined_identifiers)?;
}
ExprKind::IfExp { test, body, orelse } => {
self.check_expr(test, defined_identifiers)?;
self.check_expr(body, defined_identifiers)?;
self.check_expr(orelse, defined_identifiers)?;
}
ExprKind::Slice { lower, upper, step } => {
for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
self.should_have_value(elt)?;
self.check_expr(elt, defined_identifiers)?;
}
}
ExprKind::Lambda { args, body } => {
let mut defined_identifiers = defined_identifiers.clone();
for arg in &args.args {
// TODO: should we check the types here?
if !defined_identifiers.contains(&arg.node.arg) {
defined_identifiers.insert(arg.node.arg);
}
}
self.check_expr(body, &mut defined_identifiers)?;
}
ExprKind::ListComp { elt, generators, .. } => {
// in our type inference stage, we already make sure that there is only 1 generator
let ast::Comprehension { target, iter, ifs, .. } = &generators[0];
self.check_expr(iter, defined_identifiers)?;
self.should_have_value(iter)?;
let mut defined_identifiers = defined_identifiers.clone();
self.check_pattern(target, &mut defined_identifiers)?;
self.should_have_value(target)?;
for term in once(elt.as_ref()).chain(ifs.iter()) {
self.check_expr(term, &mut defined_identifiers)?;
self.should_have_value(term)?;
}
}
ExprKind::Call { func, args, keywords } => {
for expr in once(func.as_ref())
.chain(args.iter())
.chain(keywords.iter().map(|v| v.node.value.as_ref()))
{
self.check_expr(expr, defined_identifiers)?;
self.should_have_value(expr)?;
}
}
ExprKind::Constant { .. } => {}
_ => {
unimplemented!()
}
}
Ok(())
}
/// Check that the return value is a non-`alloca` type, effectively only allowing primitive types.
///
/// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
/// is freed when the function returns.
fn check_return_value_ty(&mut self, ret_ty: Type) -> bool {
match &*self.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { .. } => {
[
self.primitives.int32,
self.primitives.int64,
self.primitives.uint32,
self.primitives.uint64,
self.primitives.float,
self.primitives.bool,
].iter().any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty))
}
TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false,
}
}
// check statements for proper identifier def-use and return on all paths
fn check_stmt(
&mut self,
stmt: &Stmt<Option<Type>>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, HashSet<String>> {
match &stmt.node {
StmtKind::For { target, iter, body, orelse, .. } => {
self.check_expr(iter, defined_identifiers)?;
self.should_have_value(iter)?;
let mut local_defined_identifiers = defined_identifiers.clone();
for stmt in orelse {
self.check_stmt(stmt, &mut local_defined_identifiers)?;
}
let mut local_defined_identifiers = defined_identifiers.clone();
self.check_pattern(target, &mut local_defined_identifiers)?;
self.should_have_value(target)?;
for stmt in body {
self.check_stmt(stmt, &mut local_defined_identifiers)?;
}
Ok(false)
}
StmtKind::If { test, body, orelse, .. } => {
self.check_expr(test, defined_identifiers)?;
self.should_have_value(test)?;
let mut body_identifiers = defined_identifiers.clone();
let mut orelse_identifiers = defined_identifiers.clone();
let body_returned = self.check_block(body, &mut body_identifiers)?;
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
for ident in &body_identifiers {
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
defined_identifiers.insert(*ident);
}
}
Ok(body_returned && orelse_returned)
}
StmtKind::While { test, body, orelse, .. } => {
self.check_expr(test, defined_identifiers)?;
self.should_have_value(test)?;
let mut defined_identifiers = defined_identifiers.clone();
self.check_block(body, &mut defined_identifiers)?;
self.check_block(orelse, &mut defined_identifiers)?;
Ok(false)
}
StmtKind::With { items, body, .. } => {
let mut new_defined_identifiers = defined_identifiers.clone();
for item in items {
self.check_expr(&item.context_expr, defined_identifiers)?;
if let Some(var) = item.optional_vars.as_ref() {
self.check_pattern(var, &mut new_defined_identifiers)?;
}
}
self.check_block(body, &mut new_defined_identifiers)?;
Ok(false)
}
StmtKind::Try { body, handlers, orelse, finalbody, .. } => {
self.check_block(body, &mut defined_identifiers.clone())?;
self.check_block(orelse, &mut defined_identifiers.clone())?;
for handler in handlers {
let mut defined_identifiers = defined_identifiers.clone();
let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node;
if let Some(name) = name {
defined_identifiers.insert(*name);
}
self.check_block(body, &mut defined_identifiers)?;
}
self.check_block(finalbody, defined_identifiers)?;
Ok(false)
}
StmtKind::Expr { value, .. } => {
self.check_expr(value, defined_identifiers)?;
Ok(false)
}
StmtKind::Assign { targets, value, .. } => {
self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?;
for target in targets {
self.check_pattern(target, defined_identifiers)?;
}
Ok(false)
}
StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value {
self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?;
self.check_pattern(target, defined_identifiers)?;
}
Ok(false)
}
StmtKind::Return { value, .. } => {
if let Some(value) = value {
self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?;
// Check that the return value is a non-`alloca` type, effectively only allowing primitive types.
// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
// is freed when the function returns.
if let Some(ret_ty) = value.custom {
// Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually
// inferred and just generates an unconditional assertion
if matches!(value.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) {
return Ok(true)
}
if !self.check_return_value_ty(ret_ty) {
return Err(HashSet::from([
format!(
"return value of type {} must be a primitive or a tuple of primitives at {}",
self.unifier.stringify(ret_ty),
value.location,
),
]))
}
}
}
Ok(true)
}
StmtKind::Raise { exc, .. } => {
if let Some(value) = exc {
self.check_expr(value, defined_identifiers)?;
}
Ok(true)
}
// break, raise, etc.
_ => Ok(false),
}
}
pub fn check_block(
&mut self,
block: &[Stmt<Option<Type>>],
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, HashSet<String>> {
let mut ret = false;
for stmt in block {
if ret {
eprintln!("warning: dead code at {}\n", stmt.location);
}
if self.check_stmt(stmt, defined_identifiers)? {
ret = true;
}
}
Ok(ret)
}
}

View File

@ -1,660 +0,0 @@
use std::cmp::max;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
};
use nac3parser::ast::StrRef;
use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::collections::HashMap;
use std::rc::Rc;
use itertools::Itertools;
#[must_use]
pub fn binop_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__add__",
Operator::Sub => "__sub__",
Operator::Div => "__truediv__",
Operator::Mod => "__mod__",
Operator::Mult => "__mul__",
Operator::Pow => "__pow__",
Operator::BitOr => "__or__",
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__",
Operator::RShift => "__rshift__",
Operator::FloorDiv => "__floordiv__",
Operator::MatMult => "__matmul__",
}
}
#[must_use]
pub fn binop_assign_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__iadd__",
Operator::Sub => "__isub__",
Operator::Div => "__itruediv__",
Operator::Mod => "__imod__",
Operator::Mult => "__imul__",
Operator::Pow => "__ipow__",
Operator::BitOr => "__ior__",
Operator::BitXor => "__ixor__",
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
}
#[must_use]
pub fn unaryop_name(op: &Unaryop) -> &'static str {
match op {
Unaryop::UAdd => "__pos__",
Unaryop::USub => "__neg__",
Unaryop::Not => "__not__",
Unaryop::Invert => "__inv__",
}
}
#[must_use]
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> {
match op {
Cmpop::Lt => Some("__lt__"),
Cmpop::LtE => Some("__le__"),
Cmpop::Gt => Some("__gt__"),
Cmpop::GtE => Some("__ge__"),
Cmpop::Eq => Some("__eq__"),
Cmpop::NotEq => Some("__ne__"),
_ => None,
}
}
pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F)
where
F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>),
{
let (id, mut fields, params) =
if let TypeEnum::TObj { obj_id, fields, params } = &*unifier.get_ty(ty) {
(*obj_id, fields.clone(), params.clone())
} else {
unreachable!()
};
f(unifier, &mut fields);
unsafe {
let unification_table = unifier.get_unification_table();
unification_table.set_value(ty, Rc::new(TypeEnum::TObj { obj_id: id, fields, params }));
}
}
pub fn impl_binop(
unifier: &mut Unifier,
_store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
ops: &[Operator],
) {
with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None)
} else {
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(ty, Some(var_id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
} else {
VarMap::new()
};
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops {
fields.insert(binop_name(op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
})),
false,
)
});
fields.insert(binop_assign_name(op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
})),
false,
)
});
}
});
}
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) {
with_fields(unifier, ty, |unifier, fields| {
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops {
fields.insert(
unaryop_name(op).into(),
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: VarMap::new(),
args: vec![],
})),
false,
),
);
}
});
}
pub fn impl_cmpop(
unifier: &mut Unifier,
_store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ops: &[Cmpop],
ret_ty: Option<Type>,
) {
with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None)
} else {
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(ty, Some(var_id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
} else {
VarMap::new()
};
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops {
fields.insert(
comparison_name(op).unwrap().into(),
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
})),
false,
),
);
}
});
}
/// `Add`, `Sub`, `Mult`
pub fn impl_basic_arithmetic(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(
unifier,
store,
ty,
other_ty,
ret_ty,
&[Operator::Add, Operator::Sub, Operator::Mult],
);
}
/// `Pow`
pub fn impl_pow(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]);
}
/// `BitOr`, `BitXor`, `BitAnd`
pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(
unifier,
store,
ty,
&[ty],
Some(ty),
&[Operator::BitAnd, Operator::BitOr, Operator::BitXor],
);
}
/// `LShift`, `RShift`
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[store.int32, store.uint32], Some(ty), &[Operator::LShift, Operator::RShift]);
}
/// `Div`
pub fn impl_div(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]);
}
/// `FloorDiv`
pub fn impl_floordiv(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]);
}
/// `Mod`
pub fn impl_mod(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
}
/// [`Operator::MatMult`]
pub fn impl_matmul(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult]);
}
/// `UAdd`, `USub`
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
}
/// `Invert`
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Invert]);
}
/// `Not`
pub fn impl_not(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]);
}
/// `Lt`, `LtE`, `Gt`, `GtE`
pub fn impl_comparison(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop(
unifier,
store,
ty,
other_ty,
&[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE],
ret_ty,
);
}
/// `Eq`, `NotEq`
pub fn impl_eq(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
}
/// Returns the expected return type of binary operations with at least one `ndarray` operand.
pub fn typeof_ndarray_broadcast(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
left: Type,
right: Type,
) -> Result<Type, String> {
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
assert!(is_left_ndarray || is_right_ndarray);
if is_left_ndarray && is_right_ndarray {
// Perform broadcasting on two ndarray operands.
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let res_ndims = left_ty_ndims.into_iter()
.cartesian_product(right_ty_ndims)
.map(|(left, right)| {
let left_val = u64::try_from(left).unwrap();
let right_val = u64::try_from(right).unwrap();
max(left_val, right_val)
})
.unique()
.map(SymbolValue::U64)
.collect_vec();
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
} else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray {
(left, right)
} else {
(right, left)
};
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
Ok(ndarray_ty)
} else {
let (expected_ty, actual_ty) = if is_left_ndarray {
(ndarray_ty_dtype, scalar_ty)
} else {
(scalar_ty, ndarray_ty_dtype)
};
Err(format!(
"Expected right-hand side operand to be {}, got {}",
unifier.stringify(expected_ty),
unifier.stringify(actual_ty),
))
}
}
}
/// Returns the return type given a binary operator and its primitive operands.
pub fn typeof_binop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: &Operator,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
Ok(Some(match op {
Operator::Add
| Operator::Sub
| Operator::Mult
| Operator::Mod
| Operator::FloorDiv => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None)
}
}
Operator::MatMult => {
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
}
_ => unreachable!(),
};
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
}
_ => unreachable!(),
};
match (lhs_ndims, rhs_ndims) {
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
(lhs, rhs) if lhs == 0 || rhs == 0 => {
return Err(format!(
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
(rhs == 0) as u8
))
}
(lhs, rhs) => {
return Err(format!("ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"))
}
}
}
Operator::Div => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
primitives.float
} else {
return Ok(None)
}
}
Operator::Pow => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if [primitives.int32, primitives.int64, primitives.uint32, primitives.uint64, primitives.float].into_iter().any(|ty| unifier.unioned(lhs, ty)) {
lhs
} else {
return Ok(None)
}
}
Operator::LShift
| Operator::RShift => lhs,
Operator::BitOr
| Operator::BitXor
| Operator::BitAnd => {
if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None)
}
}
}))
}
pub fn typeof_unaryop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: &Unaryop,
operand: Type,
) -> Result<Option<Type>, String> {
let operand_obj_id = operand.obj_id(unifier);
if *op == Unaryop::Not && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
return Err("The truth value of an array with more than one element is ambiguous".to_string())
}
Ok(match *op {
Unaryop::Not => {
match operand_obj_id {
Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand),
Some(_) => Some(primitives.bool),
_ => None
}
}
Unaryop::Invert => {
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
Some(operand)
} else {
None
}
}
Unaryop::UAdd
| Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
if dtype.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
return Err(if *op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
} else {
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
})
}
Some(operand)
} else if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
Some(operand)
} else {
None
}
}
})
}
/// Returns the return type given a comparison operator and its primitive operands.
pub fn typeof_cmpop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
_op: &Cmpop,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = rhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
} else if unifier.unioned(lhs, rhs) {
primitives.bool
} else {
return Ok(None)
}))
}
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
let PrimitiveStore {
int32: int32_t,
int64: int64_t,
float: float_t,
bool: bool_t,
uint32: uint32_t,
uint64: uint64_t,
ndarray: ndarray_t,
..
} = *store;
let size_t = store.usize();
/* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] {
let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t, ndarray_int_t], None);
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
}
for t in [int32_t, int64_t] {
impl_sign(unifier, store, t, Some(t));
}
/* float ======== */
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_sign(unifier, store, float_t, Some(float_t));
impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
/* bool ======== */
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
impl_invert(unifier, store, bool_t, Some(int32_t));
impl_not(unifier, store, bool_t, Some(bool_t));
impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* ndarray ===== */
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
}

View File

@ -1,6 +0,0 @@
mod function_check;
pub mod magic_methods;
pub mod type_error;
pub mod type_inferencer;
pub mod typedef;
mod unification_table;

View File

@ -1,187 +0,0 @@
use std::collections::HashMap;
use std::fmt::Display;
use crate::typecheck::typedef::TypeEnum;
use super::typedef::{RecordKey, Type, Unifier};
use nac3parser::ast::{Location, StrRef};
#[derive(Debug, Clone)]
pub enum TypeErrorKind {
TooManyArguments {
expected: usize,
got: usize,
},
MissingArgs(String),
UnknownArgName(StrRef),
IncorrectArgType {
name: StrRef,
expected: Type,
got: Type,
},
FieldUnificationError {
field: RecordKey,
types: (Type, Type),
loc: (Option<Location>, Option<Location>),
},
IncompatibleRange(Type, Vec<Type>),
IncompatibleTypes(Type, Type),
MutationError(RecordKey, Type),
NoSuchField(RecordKey, Type),
TupleIndexOutOfBounds {
index: i32,
len: i32,
},
RequiresTypeAnn,
PolymorphicFunctionPointer,
}
#[derive(Debug, Clone)]
pub struct TypeError {
pub kind: TypeErrorKind,
pub loc: Option<Location>,
}
impl TypeError {
#[must_use]
pub fn new(kind: TypeErrorKind, loc: Option<Location>) -> TypeError {
TypeError { kind, loc }
}
#[must_use]
pub fn at(mut self, loc: Option<Location>) -> TypeError {
self.loc = self.loc.or(loc);
self
}
#[must_use]
pub fn to_display(self, unifier: &Unifier) -> DisplayTypeError {
DisplayTypeError { err: self, unifier }
}
}
pub struct DisplayTypeError<'a> {
pub err: TypeError,
pub unifier: &'a Unifier,
}
fn loc_to_str(loc: Option<Location>) -> String {
match loc {
Some(loc) => format!("(in {loc})"),
None => String::new(),
}
}
impl<'a> Display for DisplayTypeError<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
use TypeErrorKind::*;
let mut notes = Some(HashMap::new());
match &self.err.kind {
TooManyArguments { expected, got } => {
write!(f, "Too many arguments. Expected {expected} but got {got}")
}
MissingArgs(args) => {
write!(f, "Missing arguments: {args}")
}
UnknownArgName(name) => {
write!(f, "Unknown argument name: {name}")
}
IncorrectArgType { name, expected, got } => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!(
f,
"Incorrect argument type for {name}. Expected {expected}, but got {got}"
)
}
FieldUnificationError { field, types, loc } => {
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
let rhs = self.unifier.stringify_with_notes(types.1, &mut notes);
write!(
f,
"Unable to unify field {}: Got types {}{} and {}{}",
field,
lhs,
loc_to_str(loc.0),
rhs,
loc_to_str(loc.1)
)
}
IncompatibleRange(t, ts) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
let ts = ts
.iter()
.map(|t| self.unifier.stringify_with_notes(*t, &mut notes))
.collect::<Vec<_>>();
write!(f, "Expected any one of these types: {}, but got {}", ts.join(", "), t)
}
IncompatibleTypes(t1, t2) => {
let type1 = self.unifier.get_ty_immutable(*t1);
let type2 = self.unifier.get_ty_immutable(*t2);
match (&*type1, &*type2) {
(TypeEnum::TCall(calls), _) => {
let loc = self.unifier.calls[calls[0].0].loc;
let result = write!(
f,
"{} is not callable",
self.unifier.stringify_with_notes(*t2, &mut notes)
);
if let Some(loc) = loc {
result?;
write!(f, " (in {loc})")?;
return Ok(());
}
result
}
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 })
if ty1.len() != ty2.len() =>
{
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Tuple length mismatch: got {t1} and {t2}")
}
_ => {
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Incompatible types: {t1} and {t2}")
}
}
}
MutationError(name, t) => {
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty_immutable(*t) {
write!(f, "Cannot assign to an element of a tuple")
} else {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "Cannot assign to field {name} of {t}, which is immutable")
}
}
NoSuchField(name, t) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` field/method does not exist")
}
TupleIndexOutOfBounds { index, len } => {
write!(
f,
"Tuple index out of bounds. Got {index} but tuple has only {len} elements"
)
}
RequiresTypeAnn => {
write!(f, "Unable to infer virtual object type: Type annotation required")
}
PolymorphicFunctionPointer => {
write!(f, "Polymorphic function pointers is not supported")
}
}?;
if let Some(loc) = self.err.loc {
write!(f, " at {loc}")?;
}
let notes = notes.unwrap();
if !notes.is_empty() {
write!(f, "\n\nNotes:")?;
for line in notes.values() {
write!(f, "\n {line}")?;
}
}
Ok(())
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,736 +0,0 @@
use super::super::{magic_methods::with_fields, typedef::*};
use super::*;
use crate::{
codegen::CodeGenContext,
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, helper::PRIMITIVE_DEF_IDS, TopLevelDef},
};
use indoc::indoc;
use std::iter::zip;
use nac3parser::parser::parse_program;
use parking_lot::RwLock;
use test_case::test_case;
struct Resolver {
id_to_type: HashMap<StrRef, Type>,
id_to_def: HashMap<StrRef, DefinitionId>,
class_names: HashMap<StrRef, Type>,
}
impl SymbolResolver for Resolver {
fn get_default_param_value(
&self,
_: &ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!()
}
fn get_symbol_type(
&self,
_: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>],
_: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String> {
self.id_to_type.get(&str).cloned().ok_or_else(|| format!("cannot find symbol `{}`", str))
}
fn get_symbol_value<'ctx, 'a>(
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, 'a>,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.id_to_def.get(&id).cloned()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
}
fn get_string_id(&self, _: &str) -> i32 {
unimplemented!()
}
fn get_exception_id(&self, _tyid: usize) -> usize {
unimplemented!()
}
}
struct TestEnvironment {
pub unifier: Unifier,
pub function_data: FunctionData,
pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, StrRef>,
pub identifier_mapping: HashMap<StrRef, Type>,
pub virtual_checks: Vec<(Type, Type, Location)>,
pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext,
}
impl TestEnvironment {
pub fn basic_test_env() -> TestEnvironment {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int32,
fields: HashMap::new(),
params: VarMap::new(),
});
with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32,
vars: VarMap::new(),
}));
fields.insert("__add__".into(), (add_ty, false));
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int64,
fields: HashMap::new(),
params: VarMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.float,
fields: HashMap::new(),
params: VarMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.bool,
fields: HashMap::new(),
params: VarMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.none,
fields: HashMap::new(),
params: VarMap::new(),
});
let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.range,
fields: HashMap::new(),
params: VarMap::new(),
});
let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.str,
fields: HashMap::new(),
params: VarMap::new(),
});
let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.exception,
fields: HashMap::new(),
params: VarMap::new(),
});
let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint32,
fields: HashMap::new(),
params: VarMap::new(),
});
let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint64,
fields: HashMap::new(),
params: VarMap::new(),
});
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.option,
fields: HashMap::new(),
params: VarMap::new(),
});
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
fields: HashMap::new(),
params: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
});
let primitives = PrimitiveStore {
int32,
int64,
float,
bool,
none,
range,
str,
exception,
uint32,
uint64,
option,
ndarray,
size_t: 64,
};
unifier.put_primitive_store(&primitives);
set_primitives_magic_methods(&primitives, &mut unifier);
let id_to_name = [
(0, "int32".into()),
(1, "int64".into()),
(2, "float".into()),
(3, "bool".into()),
(4, "none".into()),
(5, "range".into()),
(6, "str".into()),
(7, "exception".into()),
]
.iter()
.cloned()
.collect();
let mut identifier_mapping = HashMap::new();
identifier_mapping.insert("None".into(), none);
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
top_level: TopLevelContext {
definitions: Default::default(),
unifiers: Default::default(),
personality_symbol: None,
},
unifier,
function_data: FunctionData {
resolver,
bound_variables: Vec::new(),
return_type: None,
},
primitives,
id_to_name,
identifier_mapping,
virtual_checks: Vec::new(),
calls: HashMap::new(),
}
}
fn new() -> TestEnvironment {
let mut unifier = Unifier::new();
let mut identifier_mapping = HashMap::new();
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int32,
fields: HashMap::new(),
params: VarMap::new(),
});
with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32,
vars: VarMap::new(),
}));
fields.insert("__add__".into(), (add_ty, false));
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int64,
fields: HashMap::new(),
params: VarMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.float,
fields: HashMap::new(),
params: VarMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.bool,
fields: HashMap::new(),
params: VarMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.none,
fields: HashMap::new(),
params: VarMap::new(),
});
let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.range,
fields: HashMap::new(),
params: VarMap::new(),
});
let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.str,
fields: HashMap::new(),
params: VarMap::new(),
});
let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.exception,
fields: HashMap::new(),
params: VarMap::new(),
});
let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint32,
fields: HashMap::new(),
params: VarMap::new(),
});
let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint64,
fields: HashMap::new(),
params: VarMap::new(),
});
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.option,
fields: HashMap::new(),
params: VarMap::new(),
});
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
fields: HashMap::new(),
params: VarMap::new(),
});
identifier_mapping.insert("None".into(), none);
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
.iter()
.enumerate()
{
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
name: (*name).into(),
object_id: DefinitionId(i),
type_vars: Default::default(),
fields: Default::default(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
constructor: None,
loc: None,
})
.into(),
);
}
let defs = 7;
let primitives = PrimitiveStore {
int32,
int64,
float,
bool,
none,
range,
str,
exception,
uint32,
uint64,
option,
ndarray,
size_t: 64,
};
unifier.put_primitive_store(&primitives);
let (v0, id) = unifier.get_dummy_var();
let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 1),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(),
params: [(id, v0)].iter().cloned().collect::<VarMap>(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
name: "Foo".into(),
object_id: DefinitionId(defs + 1),
type_vars: vec![v0],
fields: [("a".into(), v0, true)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
constructor: None,
loc: None,
})
.into(),
);
identifier_mapping.insert(
"Foo".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: foo_ty,
vars: [(id, v0)].iter().cloned().collect(),
})),
);
let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: int32,
vars: Default::default(),
}));
let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 2),
fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))]
.iter()
.cloned()
.collect::<HashMap<_, _>>(),
params: Default::default(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
name: "Bar".into(),
object_id: DefinitionId(defs + 2),
type_vars: Default::default(),
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
constructor: None,
loc: None,
})
.into(),
);
identifier_mapping.insert(
"Bar".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: bar,
vars: Default::default(),
})),
);
let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 3),
fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))]
.iter()
.cloned()
.collect::<HashMap<_, _>>(),
params: Default::default(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
name: "Bar2".into(),
object_id: DefinitionId(defs + 3),
type_vars: Default::default(),
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
constructor: None,
loc: None,
})
.into(),
);
identifier_mapping.insert(
"Bar2".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: bar2,
vars: Default::default(),
})),
);
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();
let id_to_name = [
"int32".into(),
"int64".into(),
"float".into(),
"bool".into(),
"none".into(),
"range".into(),
"str".into(),
"exception".into(),
"Foo".into(),
"Bar".into(),
"Bar2".into(),
]
.iter()
.enumerate()
.map(|(a, b)| (a, *b))
.collect();
let top_level = TopLevelContext {
definitions: Arc::new(top_level_defs.into()),
unifiers: Default::default(),
personality_symbol: None,
};
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: [
("Foo".into(), DefinitionId(defs + 1)),
("Bar".into(), DefinitionId(defs + 2)),
("Bar2".into(), DefinitionId(defs + 3)),
]
.iter()
.cloned()
.collect(),
class_names,
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
unifier,
top_level,
function_data: FunctionData {
resolver,
bound_variables: Vec::new(),
return_type: None,
},
primitives,
id_to_name,
identifier_mapping,
virtual_checks: Vec::new(),
calls: HashMap::new(),
}
}
fn get_inferencer(&mut self) -> Inferencer {
Inferencer {
top_level: &self.top_level,
function_data: &mut self.function_data,
unifier: &mut self.unifier,
variable_mapping: Default::default(),
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls,
defined_identifiers: Default::default(),
in_handler: false,
}
}
}
#[test_case(indoc! {"
a = 1234
b = int64(2147483648)
c = 1.234
d = True
"},
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(),
&[]
; "primitives test")]
#[test_case(indoc! {"
a = lambda x, y: x
b = lambda x: a(x, x)
c = 1.234
d = b(c)
"},
[("a", "fn[[x:float, y:float], float]"), ("b", "fn[[x:float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(),
&[]
; "lambda test")]
#[test_case(indoc! {"
a = lambda x: x + x
b = lambda x: a(x) + x
a = b
c = b(1)
"},
[("a", "fn[[x:int32], int32]"), ("b", "fn[[x:int32], int32]"), ("c", "int32")].iter().cloned().collect(),
&[]
; "lambda test 2")]
#[test_case(indoc! {"
a = lambda x: x
b = lambda x: x
foo1 = Foo()
foo2 = Foo()
c = a(foo1.a)
d = b(foo2.a)
a(True)
b(123)
"},
[("a", "fn[[x:bool], bool]"), ("b", "fn[[x:int32], int32]"), ("c", "bool"),
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(),
&[]
; "obj test")]
#[test_case(indoc! {"
a = [1, 2, 3]
b = [x + x for x in a]
"},
[("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(),
&[]
; "listcomp test")]
#[test_case(indoc! {"
a = virtual(Bar(), Bar)
b = a.b()
a = virtual(Bar2())
"},
[("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual test")]
#[test_case(indoc! {"
a = [virtual(Bar(), Bar), virtual(Bar2())]
b = [x.b() for x in a]
"},
[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual list test")]
fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
println!("source:\n{}", source);
let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone();
let statements = parse_program(source, Default::default()).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.internal_stringify(
*v,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v),
&mut None,
);
println!("{}: {}", k, name);
}
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap();
let name = inferencer.unifier.internal_stringify(
*ty,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v),
&mut None,
);
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
assert_eq!(inferencer.virtual_checks.len(), virtuals.len());
for ((a, b, _), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
let a = inferencer.unifier.internal_stringify(
*a,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v),
&mut None,
);
let b = inferencer.unifier.internal_stringify(
*b,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v),
&mut None,
);
assert_eq!(&a, x);
assert_eq!(&b, y);
}
}
#[test_case(indoc! {"
a = 2
b = 2
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
"},
[("a", "int32"),
("b", "int32"),
("c", "int32"),
("d", "int32"),
("e", "int32"),
("f", "float"),
("g", "int32"),
("h", "int32")].iter().cloned().collect()
; "int32")]
#[test_case(
indoc! {"
a = 2.4
b = 3.6
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
i = a ** b
ii = 3
j = a ** b
"},
[("a", "float"),
("b", "float"),
("c", "float"),
("d", "float"),
("e", "float"),
("f", "float"),
("g", "float"),
("h", "float"),
("i", "float"),
("ii", "int32"),
("j", "float")].iter().cloned().collect()
; "float"
)]
#[test_case(
indoc! {"
a = int64(12312312312)
b = int64(24242424424)
c = a + b
d = a - b
e = a * b
f = a / b
g = a // b
h = a % b
i = a == b
j = a > b
k = a < b
l = a != b
"},
[("a", "int64"),
("b", "int64"),
("c", "int64"),
("d", "int64"),
("e", "int64"),
("f", "float"),
("g", "int64"),
("h", "int64"),
("i", "bool"),
("j", "bool"),
("k", "bool"),
("l", "bool")].iter().cloned().collect()
; "int64"
)]
#[test_case(
indoc! {"
a = True
b = False
c = a == b
d = not a
e = a != b
"},
[("a", "bool"),
("b", "bool"),
("c", "bool"),
("d", "bool"),
("e", "bool")].iter().cloned().collect()
; "boolean"
)]
fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
println!("source:\n{}", source);
let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone();
let statements = parse_program(source, Default::default()).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.internal_stringify(
*v,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v),
&mut None,
);
println!("{}: {}", k, name);
}
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap();
let name = inferencer.unifier.internal_stringify(
*ty,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v),
&mut None,
);
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,570 +0,0 @@
use super::super::magic_methods::with_fields;
use super::*;
use indoc::indoc;
use itertools::Itertools;
use std::collections::HashMap;
use test_case::test_case;
impl Unifier {
/// Check whether two types are equal.
fn eq(&mut self, a: Type, b: Type) -> bool {
if a == b {
return true;
}
let (ty_a, ty_b) = {
let table = &mut self.unification_table;
if table.unioned(a, b) {
return true;
}
(table.probe_value(a).clone(), table.probe_value(b).clone())
};
match (&*ty_a, &*ty_b) {
(
TypeEnum::TVar { fields: None, id: id1, .. },
TypeEnum::TVar { fields: None, id: id2, .. },
) => id1 == id2,
(
TypeEnum::TVar { fields: Some(map1), .. },
TypeEnum::TVar { fields: Some(map2), .. },
) => self.map_eq2(map1, map2),
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
}
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
self.eq(*ty1, *ty2)
}
(
TypeEnum::TObj { obj_id: id1, params: params1, .. },
TypeEnum::TObj { obj_id: id2, params: params2, .. },
) => id1 == id2 && self.map_eq(params1, params2),
// TLiteral, TCall and TFunc are not yet implemented
_ => false,
}
}
fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool
where
K: std::hash::Hash + Eq + Clone
{
if map1.len() != map2.len() {
return false;
}
for (k, v) in map1.iter() {
if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) {
return false;
}
}
true
}
fn map_eq2<K>(&mut self, map1: &Mapping<K, RecordField>, map2: &Mapping<K, RecordField>) -> bool
where
K: std::hash::Hash + Eq + Clone,
{
if map1.len() != map2.len() {
return false;
}
for (k, v) in map1.iter() {
if !map2.get(k).map(|v1| self.eq(v.ty, v1.ty)).unwrap_or(false) {
return false;
}
}
true
}
}
struct TestEnvironment {
pub unifier: Unifier,
pub type_mapping: HashMap<String, Type>,
}
impl TestEnvironment {
fn new() -> TestEnvironment {
let mut unifier = Unifier::new();
let mut type_mapping = HashMap::new();
type_mapping.insert(
"int".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new(),
params: VarMap::new(),
}),
);
type_mapping.insert(
"float".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new(),
params: VarMap::new(),
}),
);
type_mapping.insert(
"bool".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new(),
params: VarMap::new(),
}),
);
let (v0, id) = unifier.get_dummy_var();
type_mapping.insert(
"Foo".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(),
params: [(id, v0)].iter().cloned().collect::<VarMap>(),
}),
);
TestEnvironment { unifier, type_mapping }
}
fn parse(&mut self, typ: &str, mapping: &Mapping<String>) -> Type {
let result = self.internal_parse(typ, mapping);
assert!(result.1.is_empty());
result.0
}
fn internal_parse<'a, 'b>(
&'a mut self,
typ: &'b str,
mapping: &Mapping<String>,
) -> (Type, &'b str) {
// for testing only, so we can just panic when the input is malformed
let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or_else(|| typ.len());
match &typ[..end] {
"tuple" => {
let mut s = &typ[end..];
assert_eq!(&s[0..1], "[");
let mut ty = Vec::new();
while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping);
ty.push(result.0);
s = result.1;
}
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
}
"list" => {
assert_eq!(&typ[end..end + 1], "[");
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
assert_eq!(&s[0..1], "]");
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
}
"Record" => {
let mut s = &typ[end..];
assert_eq!(&s[0..1], "[");
let mut fields = HashMap::new();
while &s[0..1] != "]" {
let eq = s.find('=').unwrap();
let key = s[1..eq].into();
let result = self.internal_parse(&s[eq + 1..], mapping);
fields.insert(key, RecordField::new(result.0, true, None));
s = result.1;
}
(self.unifier.add_record(fields), &s[1..])
}
x => {
let mut s = &typ[end..];
let ty = mapping.get(x).cloned().unwrap_or_else(|| {
// mapping should be type variables, type_mapping should be concrete types
// we should not resolve the type of type variables.
let mut ty = *self.type_mapping.get(x).unwrap();
let te = self.unifier.get_ty(ty);
if let TypeEnum::TObj { params, .. } = &*te.as_ref() {
if !params.is_empty() {
assert_eq!(&s[0..1], "[");
let mut p = Vec::new();
while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping);
p.push(result.0);
s = result.1;
}
s = &s[1..];
ty = self
.unifier
.subst(ty, &params.keys().cloned().zip(p.into_iter()).collect())
.unwrap_or(ty);
}
}
ty
});
(ty, s)
}
}
}
fn unify(&mut self, typ1: Type, typ2: Type) -> Result<(), String> {
self.unifier.unify(typ1, typ2).map_err(|e| e.to_display(&self.unifier).to_string())
}
}
#[test_case(2,
&[("v1", "v2"), ("v2", "float")],
&[("v1", "float"), ("v2", "float")]
; "simple variable"
)]
#[test_case(2,
&[("v1", "list[v2]"), ("v1", "list[float]")],
&[("v1", "list[float]"), ("v2", "float")]
; "list element"
)]
#[test_case(3,
&[
("v1", "Record[a=v3,b=v3]"),
("v2", "Record[b=float,c=v3]"),
("v1", "v2")
],
&[
("v1", "Record[a=float,b=float,c=float]"),
("v2", "Record[a=float,b=float,c=float]"),
("v3", "float")
]
; "record merge"
)]
#[test_case(3,
&[
("v1", "Record[a=float]"),
("v2", "Foo[v3]"),
("v1", "v2")
],
&[
("v1", "Foo[float]"),
("v3", "float")
]
; "record obj merge"
)]
/// Test cases for valid unifications.
fn test_unify(
variable_count: u32,
unify_pairs: &[(&'static str, &'static str)],
verify_pairs: &[(&'static str, &'static str)],
) {
let unify_count = unify_pairs.len();
// test all permutations...
for perm in unify_pairs.iter().permutations(unify_count) {
let mut env = TestEnvironment::new();
let mut mapping = HashMap::new();
for i in 1..=variable_count {
let v = env.unifier.get_dummy_var();
mapping.insert(format!("v{}", i), v.0);
}
// unification may have side effect when we do type resolution, so freeze the types
// before doing unification.
let mut pairs = Vec::new();
for (a, b) in perm.iter() {
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
pairs.push((t1, t2));
}
for (t1, t2) in pairs {
env.unifier.unify(t1, t2).unwrap();
}
for (a, b) in verify_pairs.iter() {
println!("{} = {}", a, b);
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
println!("a = {}, b = {}", env.unifier.stringify(t1), env.unifier.stringify(t2));
assert!(env.unifier.eq(t1, t2));
}
}
}
#[test_case(2,
&[
("v1", "tuple[int]"),
("v2", "list[int]"),
],
(("v1", "v2"), "Incompatible types: list[0] and tuple[0]")
; "type mismatch"
)]
#[test_case(2,
&[
("v1", "tuple[int]"),
("v2", "tuple[float]"),
],
(("v1", "v2"), "Incompatible types: tuple[0] and tuple[1]")
; "tuple parameter mismatch"
)]
#[test_case(2,
&[
("v1", "tuple[int,int]"),
("v2", "tuple[int]"),
],
(("v1", "v2"), "Tuple length mismatch: got tuple[0, 0] and tuple[0]")
; "tuple length mismatch"
)]
#[test_case(3,
&[
("v1", "Record[a=float,b=int]"),
("v2", "Foo[v3]"),
],
(("v1", "v2"), "`3[typevar4]::b` field/method does not exist")
; "record obj merge"
)]
/// Test cases for invalid unifications.
fn test_invalid_unification(
variable_count: u32,
unify_pairs: &[(&'static str, &'static str)],
erroneous_pair: ((&'static str, &'static str), &'static str),
) {
let mut env = TestEnvironment::new();
let mut mapping = HashMap::new();
for i in 1..=variable_count {
let v = env.unifier.get_dummy_var();
mapping.insert(format!("v{}", i), v.0);
}
// unification may have side effect when we do type resolution, so freeze the types
// before doing unification.
let mut pairs = Vec::new();
for (a, b) in unify_pairs.iter() {
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
pairs.push((t1, t2));
}
let (t1, t2) =
(env.parse(erroneous_pair.0 .0, &mapping), env.parse(erroneous_pair.0 .1, &mapping));
for (a, b) in pairs {
env.unifier.unify(a, b).unwrap();
}
assert_eq!(env.unify(t1, t2), Err(erroneous_pair.1.to_string()));
}
#[test]
fn test_recursive_subst() {
let mut env = TestEnvironment::new();
let int = *env.type_mapping.get("int").unwrap();
let foo_id = *env.type_mapping.get("Foo").unwrap();
let foo_ty = env.unifier.get_ty(foo_id);
with_fields(&mut env.unifier, foo_id, |_unifier, fields| {
fields.insert("rec".into(), (foo_id, true));
});
let TypeEnum::TObj { params, .. } = &*foo_ty else {
unreachable!()
};
let mapping = params.iter().map(|(id, _)| (*id, int)).collect();
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
let instantiated_ty = env.unifier.get_ty(instantiated);
let TypeEnum::TObj { fields, .. } = &*instantiated_ty else {
unreachable!()
};
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
}
#[test]
fn test_virtual() {
let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new());
let fun = env.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: int,
vars: VarMap::new(),
}));
let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5),
fields: [("f".into(), (fun, false)), ("a".into(), (int, false))]
.iter()
.cloned()
.collect::<HashMap<StrRef, _>>(),
params: VarMap::new(),
});
let v0 = env.unifier.get_dummy_var().0;
let v1 = env.unifier.get_dummy_var().0;
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
let c = env
.unifier
.add_record([("f".into(), RecordField::new(v1, false, None))].iter().cloned().collect());
env.unifier.unify(a, b).unwrap();
env.unifier.unify(b, c).unwrap();
assert!(env.unifier.eq(v1, fun));
let d = env
.unifier
.add_record([("a".into(), RecordField::new(v1, true, None))].iter().cloned().collect());
assert_eq!(env.unify(b, d), Err("`virtual[5]::a` field/method does not exist".to_string()));
let d = env
.unifier
.add_record([("b".into(), RecordField::new(v1, true, None))].iter().cloned().collect());
assert_eq!(env.unify(b, d), Err("`virtual[5]::b` field/method does not exist".to_string()));
}
#[test]
fn test_typevar_range() {
let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new());
let boolean = env.parse("bool", &HashMap::new());
let float = env.parse("float", &HashMap::new());
let int_list = env.parse("list[int]", &HashMap::new());
let float_list = env.parse("list[float]", &HashMap::new());
// unification between v and int
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
env.unifier.unify(int, v).unwrap();
// unification between v and list[int]
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
assert_eq!(
env.unify(int_list, v),
Err("Expected any one of these types: 0, 2, but got list[0]".to_string())
);
// unification between v and float
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
assert_eq!(
env.unify(float, v),
Err("Expected any one of these types: 0, 2, but got 1".to_string())
);
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 });
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0;
// unification between v and int
// where v in (int, list[v1]), v1 in (int, bool)
env.unifier.unify(int, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0;
// unification between v and list[int]
// where v in (int, list[v1]), v1 in (int, bool)
env.unifier.unify(int_list, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0;
// unification between v and list[float]
// where v in (int, list[v1]), v1 in (int, bool)
assert_eq!(
env.unify(float_list, v),
Err("Expected any one of these types: 0, list[typevar5], but got list[1]\n\nNotes:\n typevar5 ∈ {0, 2}".to_string())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
env.unifier.unify(a, b).unwrap();
env.unifier.unify(a, float).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
env.unifier.unify(a, b).unwrap();
assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into()));
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).0;
env.unifier.unify(a_list, b_list).unwrap();
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float });
env.unifier.unify(a_list, float_list).unwrap();
// previous unifications should not affect a and b
env.unifier.unify(a, int).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap();
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int });
assert_eq!(
env.unify(a_list, int_list),
Err("Incompatible types: list[typevar22] and list[0]\
\n\nNotes:\n typevar22 {1}".into())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_dummy_var().0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap();
assert_eq!(
env.unify(b, boolean),
Err("Expected any one of these types: 0, 1, but got 2".into())
);
}
#[test]
fn test_rigid_var() {
let mut env = TestEnvironment::new();
let a = env.unifier.get_fresh_rigid_var(None, None).0;
let b = env.unifier.get_fresh_rigid_var(None, None).0;
let x = env.unifier.get_dummy_var().0;
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a });
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x });
let int = env.parse("int", &HashMap::new());
let list_int = env.parse("list[int]", &HashMap::new());
assert_eq!(env.unify(a, b), Err("Incompatible types: typevar3 and typevar2".to_string()));
env.unifier.unify(list_a, list_x).unwrap();
assert_eq!(env.unify(list_x, list_int), Err("Incompatible types: list[typevar2] and list[0]".to_string()));
env.unifier.replace_rigid_var(a, int);
env.unifier.unify(list_x, list_int).unwrap();
}
#[test]
fn test_instantiation() {
let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new());
let boolean = env.parse("bool", &HashMap::new());
let float = env.parse("float", &HashMap::new());
let list_int = env.parse("list[int]", &HashMap::new());
let obj_map: HashMap<_, _> =
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).0;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).0;
let t = env.unifier.get_dummy_var().0;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).0;
// t = TypeVar('t')
// v = TypeVar('v', int, bool)
// v1 = TypeVar('v1', 'list[v]', int)
// v2 = TypeVar('v2', 'list[int]', float)
// v3 = TypeVar('v3', tuple[v, v1, v2], t)
// what values can v3 take?
let types = env.unifier.get_instantiations(v3).unwrap();
let expected_types = indoc! {"
tuple[bool, int, float]
tuple[bool, int, list[int]]
tuple[bool, list[bool], float]
tuple[bool, list[bool], list[int]]
tuple[bool, list[int], float]
tuple[bool, list[int], list[int]]
tuple[int, int, float]
tuple[int, int, list[int]]
tuple[int, list[bool], float]
tuple[int, list[bool], list[int]]
tuple[int, list[int], float]
tuple[int, list[int], list[int]]
v5"
}
.split('\n')
.collect_vec();
let types = types
.iter()
.map(|ty| {
env.unifier.internal_stringify(
*ty,
&mut |i| obj_map.get(&i).unwrap().to_string(),
&mut |i| format!("v{}", i),
&mut None,
)
})
.sorted()
.collect_vec();
assert_eq!(expected_types, types);
}

View File

@ -1,169 +0,0 @@
use std::rc::Rc;
use itertools::izip;
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct UnificationKey(usize);
#[derive(Clone)]
pub struct UnificationTable<V> {
parents: Vec<usize>,
ranks: Vec<u32>,
values: Vec<Option<V>>,
log: Vec<Action<V>>,
generation: u32,
}
#[derive(Clone, Debug)]
enum Action<V> {
Parent {
key: usize,
original_parent: usize,
},
Value {
key: usize,
original_value: Option<V>,
},
Rank {
key: usize,
original_rank: u32,
},
Marker {
generation: u32,
}
}
impl<V> Default for UnificationTable<V> {
fn default() -> Self {
Self::new()
}
}
impl<V> UnificationTable<V> {
pub fn new() -> UnificationTable<V> {
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 }
}
pub fn new_key(&mut self, v: V) -> UnificationKey {
let index = self.parents.len();
self.parents.push(index);
self.ranks.push(0);
self.values.push(Some(v));
UnificationKey(index)
}
pub fn unify(&mut self, a: UnificationKey, b: UnificationKey) {
let mut a = self.find(a);
let mut b = self.find(b);
if a == b {
return;
}
if self.ranks[a] < self.ranks[b] {
std::mem::swap(&mut a, &mut b);
}
self.log.push(Action::Parent { key: b, original_parent: self.parents[b] });
self.parents[b] = a;
if self.ranks[a] == self.ranks[b] {
self.log.push(Action::Rank { key: a, original_rank: self.ranks[a] });
self.ranks[a] += 1;
}
}
pub fn probe_value_immutable(&self, key: UnificationKey) -> &V {
let mut root = key.0;
let mut parent = self.parents[root];
while root != parent {
root = parent;
// parent = root.parent
parent = self.parents[parent];
}
self.values[parent].as_ref().unwrap()
}
pub fn probe_value(&mut self, a: UnificationKey) -> &V {
let index = self.find(a);
self.values[index].as_ref().unwrap()
}
pub fn set_value(&mut self, a: UnificationKey, v: V) {
let index = self.find(a);
let original_value = self.values[index].replace(v);
self.log.push(Action::Value { key: index, original_value });
}
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
self.find(a) == self.find(b)
}
pub fn get_representative(&mut self, key: UnificationKey) -> UnificationKey {
UnificationKey(self.find(key))
}
fn find(&mut self, key: UnificationKey) -> usize {
let mut root = key.0;
let mut parent = self.parents[root];
while root != parent {
// a = parent.parent
let a = self.parents[parent];
// root.parent = parent.parent
self.log.push(Action::Parent { key: root, original_parent: self.parents[root] });
self.parents[root] = a;
root = parent;
// parent = root.parent
parent = a;
}
parent
}
pub fn get_snapshot(&mut self) -> (usize, u32) {
let generation = self.generation;
self.log.push(Action::Marker { generation });
self.generation += 1;
(self.log.len(), generation)
}
pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot restoration error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error");
for action in self.log.drain(log_len - 1..).rev() {
match action {
Action::Parent { key, original_parent } => {
self.parents[key] = original_parent;
}
Action::Value { key, original_value } => {
self.values[key] = original_value;
}
Action::Rank { key, original_rank } => {
self.ranks[key] = original_rank;
}
Action::Marker { .. } => {}
}
}
}
pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot discard error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot discard error");
self.log.clear();
}
}
impl<V> UnificationTable<Rc<V>>
where
V: Clone,
{
pub fn get_send(&self) -> UnificationTable<V> {
let values = izip!(self.values.iter(), self.parents.iter())
.enumerate()
.map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None })
.collect();
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 }
}
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect();
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 }
}
}

15
nac3embedded/Cargo.toml Normal file
View File

@ -0,0 +1,15 @@
[package]
name = "nac3embedded"
version = "0.1.0"
authors = ["M-Labs"]
edition = "2018"
[lib]
name = "nac3embedded"
crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.12.4", features = ["extension-module"] }
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] }
rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" }
nac3core = { path = "../nac3core" }

11
nac3embedded/demo.py Normal file
View File

@ -0,0 +1,11 @@
from language import *
class Demo:
@kernel
def run(self: bool) -> bool:
return False
if __name__ == "__main__":
Demo().run()

19
nac3embedded/language.py Normal file
View File

@ -0,0 +1,19 @@
from functools import wraps
import nac3embedded
__all__ = ["kernel", "portable"]
def kernel(function):
@wraps(function)
def run_on_core(self, *args, **kwargs):
nac3 = nac3embedded.NAC3()
nac3.register_host_object(self)
nac3.compile_method(self, function.__name__)
return run_on_core
def portable(function):
return fn

View File

@ -0,0 +1 @@
../target/release/libnac3embedded.so

116
nac3embedded/src/lib.rs Normal file
View File

@ -0,0 +1,116 @@
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use pyo3::prelude::*;
use pyo3::exceptions;
use rustpython_parser::{ast, parser};
use inkwell::context::Context;
use inkwell::targets::*;
use nac3core::CodeGen;
fn runs_on_core(decorator_list: &[ast::Expression]) -> bool {
for decorator in decorator_list.iter() {
if let ast::ExpressionType::Identifier { name } = &decorator.node {
if name == "kernel" || name == "portable" {
return true
}
}
}
false
}
#[pyclass(name=NAC3)]
struct Nac3 {
type_definitions: HashMap<i64, ast::Program>,
host_objects: HashMap<i64, i64>,
}
#[pymethods]
impl Nac3 {
#[new]
fn new() -> Self {
Nac3 {
type_definitions: HashMap::new(),
host_objects: HashMap::new(),
}
}
fn register_host_object(&mut self, obj: PyObject) -> PyResult<()> {
Python::with_gil(|py| -> PyResult<()> {
let obj: &PyAny = obj.extract(py)?;
let obj_type = obj.get_type();
let builtins = PyModule::import(py, "builtins")?;
let type_id = builtins.call1("id", (obj_type, ))?.extract()?;
let entry = self.type_definitions.entry(type_id);
if let Entry::Vacant(entry) = entry {
let source = PyModule::import(py, "inspect")?.call1("getsource", (obj_type, ))?;
let ast = parser::parse_program(source.extract()?).map_err(|e|
exceptions::PySyntaxError::new_err(format!("failed to parse host object source: {}", e)))?;
entry.insert(ast);
// TODO: examine AST and recursively register dependencies
};
let obj_id = builtins.call1("id", (obj, ))?.extract()?;
match self.host_objects.entry(obj_id) {
Entry::Vacant(entry) => entry.insert(type_id),
Entry::Occupied(_) => return Err(
exceptions::PyValueError::new_err("host object registered twice")),
};
// TODO: collect other information about host object, e.g. value of fields
Ok(())
})
}
fn compile_method(&self, obj: PyObject, name: String) -> PyResult<()> {
Python::with_gil(|py| -> PyResult<()> {
let obj: &PyAny = obj.extract(py)?;
let builtins = PyModule::import(py, "builtins")?;
let obj_id = builtins.call1("id", (obj, ))?.extract()?;
let type_id = self.host_objects.get(&obj_id).ok_or_else(||
exceptions::PyKeyError::new_err("type of host object not found"))?;
let ast = self.type_definitions.get(&type_id).ok_or_else(||
exceptions::PyKeyError::new_err("type definition not found"))?;
if let ast::StatementType::ClassDef {
name: _,
body,
bases: _,
keywords: _,
decorator_list: _ } = &ast.statements[0].node {
for statement in body.iter() {
if let ast::StatementType::FunctionDef {
is_async: _,
name: funcdef_name,
args: _,
body: _,
decorator_list,
returns: _ } = &statement.node {
if runs_on_core(decorator_list) && funcdef_name == &name {
let context = Context::create();
let mut codegen = CodeGen::new(&context);
codegen.compile_toplevel(&body[0]).map_err(|e|
exceptions::PyRuntimeError::new_err(format!("compilation failed: {}", e)))?;
codegen.print_ir();
}
}
}
} else {
return Err(exceptions::PyValueError::new_err("expected ClassDef for type definition"));
}
Ok(())
})
}
}
#[pymodule]
fn nac3embedded(_py: Python, m: &PyModule) -> PyResult<()> {
Target::initialize_all(&InitializationConfig::default());
m.add_class::<Nac3>()?;
Ok(())
}

View File

@ -1,8 +0,0 @@
[package]
name = "nac3ld"
version = "0.1.0"
authors = ["M-Labs"]
edition = "2021"
[dependencies]
byteorder = { version = "1.5", default-features = false }

View File

@ -1,523 +0,0 @@
#![allow(non_camel_case_types, non_upper_case_globals)]
use std::mem;
use byteorder::{ByteOrder, LittleEndian};
pub const DW_EH_PE_omit: u8 = 0xFF;
pub const DW_EH_PE_absptr: u8 = 0x00;
pub const DW_EH_PE_uleb128: u8 = 0x01;
pub const DW_EH_PE_udata2: u8 = 0x02;
pub const DW_EH_PE_udata4: u8 = 0x03;
pub const DW_EH_PE_udata8: u8 = 0x04;
pub const DW_EH_PE_sleb128: u8 = 0x09;
pub const DW_EH_PE_sdata2: u8 = 0x0A;
pub const DW_EH_PE_sdata4: u8 = 0x0B;
pub const DW_EH_PE_sdata8: u8 = 0x0C;
pub const DW_EH_PE_pcrel: u8 = 0x10;
pub const DW_EH_PE_textrel: u8 = 0x20;
pub const DW_EH_PE_datarel: u8 = 0x30;
pub const DW_EH_PE_funcrel: u8 = 0x40;
pub const DW_EH_PE_aligned: u8 = 0x50;
pub const DW_EH_PE_indirect: u8 = 0x80;
pub struct DwarfReader<'a> {
pub slice: &'a [u8],
pub virt_addr: u32,
base_slice: &'a [u8],
base_virt_addr: u32,
}
impl<'a> DwarfReader<'a> {
pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader {
DwarfReader { slice, virt_addr, base_slice: slice, base_virt_addr: virt_addr }
}
/// Creates a new instance from another instance of [DwarfReader], optionally removing any
/// offsets previously applied to the other instance.
pub fn from_reader(other: &DwarfReader<'a>, reset_offset: bool) -> DwarfReader<'a> {
if reset_offset {
DwarfReader::new(other.base_slice, other.base_virt_addr)
} else {
DwarfReader::new(other.slice, other.virt_addr)
}
}
pub fn offset(&mut self, offset: u32) {
self.slice = &self.slice[offset as usize..];
self.virt_addr = self.virt_addr.wrapping_add(offset);
}
/// ULEB128 and SLEB128 encodings are defined in Section 7.6 - "Variable Length Data" of the
/// [DWARF-4 Manual](https://dwarfstd.org/doc/DWARF4.pdf).
pub fn read_uleb128(&mut self) -> u64 {
let mut shift: usize = 0;
let mut result: u64 = 0;
let mut byte: u8;
loop {
byte = self.read_u8();
result |= ((byte & 0x7F) as u64) << shift;
shift += 7;
if byte & 0x80 == 0 {
break;
}
}
result
}
pub fn read_sleb128(&mut self) -> i64 {
let mut shift: u32 = 0;
let mut result: u64 = 0;
let mut byte: u8;
loop {
byte = self.read_u8();
result |= ((byte & 0x7F) as u64) << shift;
shift += 7;
if byte & 0x80 == 0 {
break;
}
}
// sign-extend
if shift < u64::BITS && (byte & 0x40) != 0 {
result |= (!0u64) << shift;
}
result as i64
}
pub fn read_u8(&mut self) -> u8 {
let val = self.slice[0];
self.slice = &self.slice[1..];
val
}
}
macro_rules! impl_read_fn {
( $($type: ty, $byteorder_fn: ident);* ) => {
impl<'a> DwarfReader<'a> {
$(
pub fn $byteorder_fn(&mut self) -> $type {
let val = LittleEndian::$byteorder_fn(self.slice);
self.slice = &self.slice[mem::size_of::<$type>()..];
val
}
)*
}
}
}
impl_read_fn!(
u16, read_u16;
u32, read_u32;
u64, read_u64;
i16, read_i16;
i32, read_i32;
i64, read_i64
);
pub struct DwarfWriter<'a> {
pub slice: &'a mut [u8],
pub offset: usize,
}
impl<'a> DwarfWriter<'a> {
pub fn new(slice: &mut [u8]) -> DwarfWriter {
DwarfWriter { slice, offset: 0 }
}
pub fn write_u8(&mut self, data: u8) {
self.slice[self.offset] = data;
self.offset += 1;
}
pub fn write_u32(&mut self, data: u32) {
LittleEndian::write_u32(&mut self.slice[self.offset..], data);
self.offset += 4;
}
}
fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result<usize, ()> {
if encoding == DW_EH_PE_omit {
return Err(());
}
// DW_EH_PE_aligned implies it's an absolute pointer value
// However, we are linking library for 32-bits architecture
// The size of variable should be 4 bytes instead
if encoding == DW_EH_PE_aligned {
let shifted_virt_addr = round_up(reader.virt_addr as usize, mem::size_of::<u32>())?;
let addr_inc = shifted_virt_addr - reader.virt_addr as usize;
reader.slice = &reader.slice[addr_inc..];
reader.virt_addr = shifted_virt_addr as u32;
return Ok(reader.read_u32() as usize);
}
match encoding & 0x0F {
DW_EH_PE_absptr => Ok(reader.read_u32() as usize),
DW_EH_PE_uleb128 => Ok(reader.read_uleb128() as usize),
DW_EH_PE_udata2 => Ok(reader.read_u16() as usize),
DW_EH_PE_udata4 => Ok(reader.read_u32() as usize),
DW_EH_PE_udata8 => Ok(reader.read_u64() as usize),
DW_EH_PE_sleb128 => Ok(reader.read_sleb128() as usize),
DW_EH_PE_sdata2 => Ok(reader.read_i16() as usize),
DW_EH_PE_sdata4 => Ok(reader.read_i32() as usize),
DW_EH_PE_sdata8 => Ok(reader.read_i64() as usize),
_ => Err(()),
}
}
fn read_encoded_pointer_with_pc(
reader: &mut DwarfReader,
encoding: u8,
) -> Result<usize, ()> {
let entry_virt_addr = reader.virt_addr;
let mut result = read_encoded_pointer(reader, encoding)?;
// DW_EH_PE_aligned implies it's an absolute pointer value
if encoding == DW_EH_PE_aligned {
return Ok(result);
}
result = match encoding & 0x70 {
DW_EH_PE_pcrel => result.wrapping_add(entry_virt_addr as usize),
// .eh_frame normally would not have these kinds of relocations
// These would not be supported by a dedicated linker relocation schemes for RISC-V
DW_EH_PE_textrel | DW_EH_PE_datarel | DW_EH_PE_funcrel | DW_EH_PE_aligned => {
unimplemented!()
}
// Other values should be impossible
_ => unreachable!(),
};
if encoding & DW_EH_PE_indirect != 0 {
// There should not be a need for indirect addressing, as assembly code from
// the dynamic library should not be freely moved relative to the EH frame.
unreachable!()
}
Ok(result)
}
#[inline]
fn round_up(unrounded: usize, align: usize) -> Result<usize, ()> {
if align.is_power_of_two() {
Ok((unrounded + align - 1) & !(align - 1))
} else {
Err(())
}
}
/// Minimalistic structure to store everything needed for parsing FDEs to synthesize `.eh_frame_hdr`
/// section.
///
/// Refer to [The Linux Standard Base Core Specification, Generic Part](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html)
/// for more information.
pub struct EH_Frame<'a> {
reader: DwarfReader<'a>,
}
impl<'a> EH_Frame<'a> {
/// Creates an [EH_Frame] using the bytes in the `.eh_frame` section and its address in the ELF
/// file.
pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> Result<EH_Frame, ()> {
Ok(EH_Frame { reader: DwarfReader::new(eh_frame_slice, eh_frame_addr) })
}
/// Returns an [Iterator] over all Call Frame Information (CFI) records.
pub fn cfi_records(&self) -> CFI_Records<'a> {
let reader = DwarfReader::from_reader(&self.reader, true);
let len = reader.slice.len();
CFI_Records {
reader,
available: len,
}
}
}
/// A single Call Frame Information (CFI) record.
///
/// From the [specification](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html):
///
/// > Each CFI record contains a Common Information Entry (CIE) record followed by 1 or more Frame
/// Description Entry (FDE) records.
pub struct CFI_Record<'a> {
// It refers to the augmentation data that corresponds to 'R' in the augmentation string
fde_pointer_encoding: u8,
fde_reader: DwarfReader<'a>,
}
impl<'a> CFI_Record<'a> {
pub fn from_reader(cie_reader: &mut DwarfReader<'a>) -> Result<CFI_Record<'a>, ()> {
let length = cie_reader.read_u32();
let fde_reader = match length {
// eh_frame with 0 lengths means the CIE is terminated
0 => panic!("Cannot create an EH_Frame from a termination CIE"),
// length == u32::MAX means that the length is only representable with 64 bits,
// which does not make sense in a system with 32-bit address.
0xFFFFFFFF => unimplemented!(),
_ => {
let mut fde_reader = DwarfReader::from_reader(cie_reader, false);
fde_reader.offset(length);
fde_reader
}
};
// Routine check on the .eh_frame well-formness, in terms of CIE ID & Version args.
let cie_ptr = cie_reader.read_u32();
assert_eq!(cie_ptr, 0);
assert_eq!(cie_reader.read_u8(), 1);
// Parse augmentation string
// The first character must be 'z', there is no way to proceed otherwise
assert_eq!(cie_reader.read_u8(), b'z');
// Establish a pointer that skips ahead of the string
// Skip code/data alignment factors & return address register along the way as well
// We only tackle the case where 'z' and 'R' are part of the augmentation string, otherwise
// we cannot get the addresses to make .eh_frame_hdr
let mut aug_data_reader = DwarfReader::from_reader(cie_reader, false);
let mut aug_str_len = 0;
loop {
if aug_data_reader.read_u8() == b'\0' {
break;
}
aug_str_len += 1;
}
if aug_str_len == 0 {
unimplemented!();
}
aug_data_reader.read_uleb128(); // Code alignment factor
aug_data_reader.read_sleb128(); // Data alignment factor
aug_data_reader.read_uleb128(); // Return address register
aug_data_reader.read_uleb128(); // Augmentation data length
let mut fde_pointer_encoding = DW_EH_PE_omit;
for _ in 0..aug_str_len {
match cie_reader.read_u8() {
b'L' => {
aug_data_reader.read_u8();
}
b'P' => {
let encoding = aug_data_reader.read_u8();
read_encoded_pointer(&mut aug_data_reader, encoding)?;
}
b'R' => {
fde_pointer_encoding = aug_data_reader.read_u8();
}
// Other characters are not supported
_ => unimplemented!(),
}
}
assert_ne!(fde_pointer_encoding, DW_EH_PE_omit);
Ok(CFI_Record {
fde_pointer_encoding,
fde_reader,
})
}
/// Returns a [DwarfReader] initialized to the first Frame Description Entry (FDE) of this CFI
/// record.
pub fn get_fde_reader(&self) -> DwarfReader<'a> {
DwarfReader::from_reader(&self.fde_reader, true)
}
/// Returns an [Iterator] over all Frame Description Entries (FDEs).
pub fn fde_records(&self) -> FDE_Records<'a> {
let reader = self.get_fde_reader();
let len = reader.slice.len();
FDE_Records {
pointer_encoding: self.fde_pointer_encoding,
reader,
available: len,
}
}
}
/// [Iterator] over Call Frame Information (CFI) records in an
/// [Exception Handling (EH) frame][EH_Frame].
pub struct CFI_Records<'a> {
reader: DwarfReader<'a>,
available: usize,
}
impl<'a> Iterator for CFI_Records<'a> {
type Item = CFI_Record<'a>;
fn next(&mut self) -> Option<Self::Item> {
loop {
if self.available == 0 {
return None;
}
let mut this_reader = DwarfReader::from_reader(&self.reader, false);
// Remove the length of the header and the content from the counter
let length = self.reader.read_u32();
let length = match length {
// eh_frame with 0-length means the CIE is terminated
0 => return None,
0xFFFFFFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
other => other,
} as usize;
// Remove the length of the header and the content from the counter
self.available -= length + mem::size_of::<u32>();
let mut next_reader = DwarfReader::from_reader(&self.reader, false);
next_reader.offset(length as u32);
let cie_ptr = self.reader.read_u32();
self.reader = next_reader;
// Skip this record if it is a FDE
if cie_ptr == 0 {
// Rewind back to the start of the CFI Record
return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap())
}
}
}
}
/// [Iterator] over Frame Description Entries (FDEs) in an
/// [Exception Handling (EH) frame][EH_Frame].
pub struct FDE_Records<'a> {
pointer_encoding: u8,
reader: DwarfReader<'a>,
available: usize,
}
impl<'a> Iterator for FDE_Records<'a> {
type Item = (u32, u32);
fn next(&mut self) -> Option<Self::Item> {
// Parse each FDE to obtain the starting address that the FDE applies to
// Send the FDE offset and the mentioned address to a callback that write up the
// .eh_frame_hdr section
if self.available == 0 {
return None;
}
// Remove the length of the header and the content from the counter
let length = match self.reader.read_u32() {
// eh_frame with 0-length means the CIE is terminated
0 => return None,
0xFFFFFFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
other => other,
} as usize;
// Remove the length of the header and the content from the counter
self.available -= length + mem::size_of::<u32>();
let mut next_fde_reader = DwarfReader::from_reader(&self.reader, false);
next_fde_reader.offset(length as u32);
let cie_ptr = self.reader.read_u32();
let next_val = if cie_ptr != 0 {
let pc_begin = read_encoded_pointer_with_pc(&mut self.reader, self.pointer_encoding)
.expect("Failed to read PC Begin");
Some((pc_begin as u32, self.reader.virt_addr))
} else {
None
};
self.reader = next_fde_reader;
next_val
}
}
pub struct EH_Frame_Hdr<'a> {
fde_writer: DwarfWriter<'a>,
eh_frame_hdr_addr: u32,
fdes: Vec<(u32, u32)>,
}
impl<'a> EH_Frame_Hdr<'a> {
/// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory.
///
/// Load address is not known at this point.
pub fn new(
eh_frame_hdr_slice: &mut [u8],
eh_frame_hdr_addr: u32,
eh_frame_addr: u32,
) -> EH_Frame_Hdr {
let mut writer = DwarfWriter::new(eh_frame_hdr_slice);
writer.write_u8(1); // version
writer.write_u8(0x1B); // eh_frame_ptr_enc - PC-relative 4-byte signed value
writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value
writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value
let eh_frame_offset = eh_frame_addr
.wrapping_sub(eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4));
writer.write_u32(eh_frame_offset); // eh_frame_ptr
writer.write_u32(0); // `fde_count`, will be written in finalize_fde
EH_Frame_Hdr { fde_writer: writer, eh_frame_hdr_addr, fdes: Vec::new() }
}
/// The offset of the `fde_count` value relative to the start of the `.eh_frame_hdr` section in
/// bytes.
fn fde_count_offset() -> usize {
8
}
pub fn add_fde(&mut self, init_loc: u32, addr: u32) {
self.fdes.push((
init_loc.wrapping_sub(self.eh_frame_hdr_addr),
addr.wrapping_sub(self.eh_frame_hdr_addr),
));
}
pub fn finalize_fde(mut self) {
self.fdes
.sort_by(|(left_init_loc, _), (right_init_loc, _)| left_init_loc.cmp(right_init_loc));
for (init_loc, addr) in &self.fdes {
self.fde_writer.write_u32(*init_loc);
self.fde_writer.write_u32(*addr);
}
LittleEndian::write_u32(&mut self.fde_writer.slice[Self::fde_count_offset()..], self.fdes.len() as u32);
}
pub fn size_from_eh_frame(eh_frame: &[u8]) -> usize {
// The virtual address of the EH frame does not matter in this case
// Calculation of size does not involve modifying any headers
let mut reader = DwarfReader::new(eh_frame, 0);
let mut fde_count = 0;
while !reader.slice.is_empty() {
// The original length field should be able to hold the entire value.
// The device memory space is limited to 32-bits addresses anyway.
let entry_length = reader.read_u32();
if entry_length == 0 || entry_length == 0xFFFFFFFF {
unimplemented!()
}
// This slot stores the CIE ID (for CIE)/CIE Pointer (for FDE).
// This value must be non-zero for FDEs.
let cie_ptr = reader.read_u32();
if cie_ptr != 0 {
fde_count += 1;
}
reader.offset(entry_length - mem::size_of::<u32>() as u32)
}
12 + fde_count * 8
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,24 +0,0 @@
[package]
name = "nac3parser"
version = "0.1.2"
description = "Parser for python code."
authors = [ "RustPython Team", "M-Labs" ]
build = "build.rs"
license = "MIT"
edition = "2021"
[build-dependencies]
lalrpop = "0.20"
[dependencies]
nac3ast = { path = "../nac3ast" }
lalrpop-util = "0.20"
log = "0.4"
unic-emoji-char = "0.9"
unic-ucd-ident = "0.9"
unicode_names2 = "1.2"
phf = { version = "0.11", features = ["macros"] }
ahash = "0.8"
[dev-dependencies]
insta = "=1.11.0"

View File

@ -1,56 +0,0 @@
# nac3parser
This directory has the code for python lexing, parsing and generating Abstract Syntax Trees (AST).
This is the RustPython parser with modifications for NAC3.
The steps are:
- Lexical analysis: splits the source code into tokens.
- Parsing and generating the AST: transforms those tokens into an AST. Uses `LALRPOP`, a Rust parser generator framework.
The RustPython team wrote [a blog post](https://rustpython.github.io/2020/04/02/thing-explainer-parser.html) with screenshots and an explanation to help you understand the steps by seeing them in action.
For more information on LALRPOP, here is a link to the [LALRPOP book](https://github.com/lalrpop/lalrpop).
There is a readme in the `src` folder with the details of each file.
## Directory content
`build.rs`: The build script.
`Cargo.toml`: The config file.
The `src` directory has:
**lib.rs**
This is the crate's root.
**lexer.rs**
This module takes care of lexing python source text. This means source code is translated into separate tokens.
**parser.rs**
A python parsing module. Use this module to parse python code into an AST. There are three ways to parse python code. You could parse a whole program, a single statement, or a single expression.
**ast.rs**
Implements abstract syntax tree (AST) nodes for the python language. Roughly equivalent to [the python AST](https://docs.python.org/3/library/ast.html).
**python.lalrpop**
Python grammar.
**token.rs**
Different token definitions. Loosely based on token.h from CPython source.
**errors.rs**
Define internal parse error types. The goal is to provide a matching and a safe error API, masking errors from LALR.
**fstring.rs**
Format strings.
**function.rs**
Collection of functions for parsing parameters, arguments.
**location.rs**
Datatypes to support source location information.
**mode.rs**
Execution mode check. Allowed modes are `exec`, `eval` or `single`.

View File

@ -1,3 +0,0 @@
fn main() {
lalrpop::process_root().unwrap()
}

View File

@ -1,85 +0,0 @@
use lalrpop_util::ParseError;
use nac3ast::*;
use crate::ast::Ident;
use crate::ast::Location;
use crate::token::Tok;
use crate::error::*;
pub fn make_config_comment(
com_loc: Location,
stmt_loc: Location,
nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident>
) -> Result<Vec<Ident>, ParseError<Location, Tok, LexicalError>> {
if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() {
return Err(ParseError::User {
error: LexicalError {
location: com_loc,
error: LexicalErrorType::OtherError(
format!(
"config comment at top must have the same indentation with what it applies (comment at {}, statement at {})",
com_loc,
stmt_loc,
)
)
}
})
};
Ok(
nac3com_above
.into_iter()
.map(|(com, _)| com)
.chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter()))
.collect()
)
}
pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, Tok)>, nac3com_end: Option<Ident>, com_above_loc: Location) -> Result<(), ParseError<Location, Tok, LexicalError>> {
if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() {
return Err(ParseError::User {
error: LexicalError {
location: com_above_loc,
error: LexicalErrorType::OtherError(
format!(
"config comment at top must have the same indentation with what it applies (comment at {}, statement at {})",
com_above_loc,
stmts[0].location,
)
)
}
})
}
apply_config_comments(
&mut stmts[0],
nac3com_above
.into_iter()
.map(|(com, _)| com).collect()
);
apply_config_comments(
stmts.last_mut().unwrap(),
nac3com_end.map_or_else(Vec::new, |com| vec![com])
);
Ok(())
}
fn apply_config_comments<U>(stmt: &mut Stmt<U>, comments: Vec<Ident>) {
match &mut stmt.node {
StmtKind::Pass { config_comment, .. }
| StmtKind::Delete { config_comment, .. }
| StmtKind::Expr { config_comment, .. }
| StmtKind::Assign { config_comment, .. }
| StmtKind::AugAssign { config_comment, .. }
| StmtKind::AnnAssign { config_comment, .. }
| StmtKind::Break { config_comment, .. }
| StmtKind::Continue { config_comment, .. }
| StmtKind::Return { config_comment, .. }
| StmtKind::Raise { config_comment, .. }
| StmtKind::Import { config_comment, .. }
| StmtKind::ImportFrom { config_comment, .. }
| StmtKind::Global { config_comment, .. }
| StmtKind::Nonlocal { config_comment, .. }
| StmtKind::Assert { config_comment, .. } => config_comment.extend(comments),
_ => { unreachable!("only small statements should call this function") }
}
}

View File

@ -1,239 +0,0 @@
//! Define internal parse error types
//! The goal is to provide a matching and a safe error API, maksing errors from LALR
use lalrpop_util::ParseError as LalrpopError;
use crate::ast::Location;
use crate::token::Tok;
use std::error::Error;
use std::fmt;
/// Represents an error during lexical scanning.
#[derive(Debug, PartialEq)]
pub struct LexicalError {
pub error: LexicalErrorType,
pub location: Location,
}
#[derive(Debug, PartialEq)]
pub enum LexicalErrorType {
StringError,
UnicodeError,
NestingError,
IndentationError,
TabError,
TabsAfterSpaces,
DefaultArgumentError,
PositionalArgumentError,
DuplicateKeywordArgumentError,
UnrecognizedToken { tok: char },
FStringError(FStringErrorType),
LineContinuationError,
Eof,
OtherError(String),
}
impl fmt::Display for LexicalErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
LexicalErrorType::StringError => write!(f, "Got unexpected string"),
LexicalErrorType::FStringError(error) => write!(f, "Got error in f-string: {}", error),
LexicalErrorType::UnicodeError => write!(f, "Got unexpected unicode"),
LexicalErrorType::NestingError => write!(f, "Got unexpected nesting"),
LexicalErrorType::IndentationError => {
write!(f, "unindent does not match any outer indentation level")
}
LexicalErrorType::TabError => {
write!(f, "inconsistent use of tabs and spaces in indentation")
}
LexicalErrorType::TabsAfterSpaces => {
write!(f, "Tabs not allowed as part of indentation after spaces")
}
LexicalErrorType::DefaultArgumentError => {
write!(f, "non-default argument follows default argument")
}
LexicalErrorType::DuplicateKeywordArgumentError => {
write!(f, "keyword argument repeated")
}
LexicalErrorType::PositionalArgumentError => {
write!(f, "positional argument follows keyword argument")
}
LexicalErrorType::UnrecognizedToken { tok } => {
write!(f, "Got unexpected token {}", tok)
}
LexicalErrorType::LineContinuationError => {
write!(f, "unexpected character after line continuation character")
}
LexicalErrorType::Eof => write!(f, "unexpected EOF while parsing"),
LexicalErrorType::OtherError(msg) => write!(f, "{}", msg),
}
}
}
// TODO: consolidate these with ParseError
#[derive(Debug, PartialEq)]
pub struct FStringError {
pub error: FStringErrorType,
pub location: Location,
}
#[derive(Debug, PartialEq)]
pub enum FStringErrorType {
UnclosedLbrace,
UnopenedRbrace,
ExpectedRbrace,
InvalidExpression(Box<ParseErrorType>),
InvalidConversionFlag,
EmptyExpression,
MismatchedDelimiter,
ExpressionNestedTooDeeply,
}
impl fmt::Display for FStringErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FStringErrorType::UnclosedLbrace => write!(f, "Unclosed '{{'"),
FStringErrorType::UnopenedRbrace => write!(f, "Unopened '}}'"),
FStringErrorType::ExpectedRbrace => write!(f, "Expected '}}' after conversion flag."),
FStringErrorType::InvalidExpression(error) => {
write!(f, "Invalid expression: {}", error)
}
FStringErrorType::InvalidConversionFlag => write!(f, "Invalid conversion flag"),
FStringErrorType::EmptyExpression => write!(f, "Empty expression"),
FStringErrorType::MismatchedDelimiter => write!(f, "Mismatched delimiter"),
FStringErrorType::ExpressionNestedTooDeeply => {
write!(f, "expressions nested too deeply")
}
}
}
}
impl From<FStringError> for LalrpopError<Location, Tok, LexicalError> {
fn from(err: FStringError) -> Self {
lalrpop_util::ParseError::User {
error: LexicalError {
error: LexicalErrorType::FStringError(err.error),
location: err.location,
},
}
}
}
/// Represents an error during parsing
#[derive(Debug, PartialEq)]
pub struct ParseError {
pub error: ParseErrorType,
pub location: Location,
}
#[derive(Debug, PartialEq)]
pub enum ParseErrorType {
/// Parser encountered an unexpected end of input
Eof,
/// Parser encountered an extra token
ExtraToken(Tok),
/// Parser encountered an invalid token
InvalidToken,
/// Parser encountered an unexpected token
UnrecognizedToken(Tok, Option<String>),
/// Maps to `User` type from `lalrpop-util`
Lexical(LexicalErrorType),
}
/// Convert `lalrpop_util::ParseError` to our internal type
impl From<LalrpopError<Location, Tok, LexicalError>> for ParseError {
fn from(err: LalrpopError<Location, Tok, LexicalError>) -> Self {
match err {
// TODO: Are there cases where this isn't an EOF?
LalrpopError::InvalidToken { location } => ParseError {
error: ParseErrorType::Eof,
location,
},
LalrpopError::ExtraToken { token } => ParseError {
error: ParseErrorType::ExtraToken(token.1),
location: token.0,
},
LalrpopError::User { error } => ParseError {
error: ParseErrorType::Lexical(error.error),
location: error.location,
},
LalrpopError::UnrecognizedToken { token, expected } => {
// Hacky, but it's how CPython does it. See PyParser_AddToken,
// in particular "Only one possible expected token" comment.
let expected = if expected.len() == 1 {
Some(expected[0].clone())
} else {
None
};
ParseError {
error: ParseErrorType::UnrecognizedToken(token.1, expected),
location: token.0,
}
}
LalrpopError::UnrecognizedEof { location, .. } => ParseError {
error: ParseErrorType::Eof,
location,
},
}
}
}
impl fmt::Display for ParseError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} at {}", self.error, self.location)
}
}
impl fmt::Display for ParseErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ParseErrorType::Eof => write!(f, "Got unexpected EOF"),
ParseErrorType::ExtraToken(ref tok) => write!(f, "Got extraneous token: {:?}", tok),
ParseErrorType::InvalidToken => write!(f, "Got invalid token"),
ParseErrorType::UnrecognizedToken(ref tok, ref expected) => {
if *tok == Tok::Indent {
write!(f, "unexpected indent")
} else if expected.as_deref() == Some("Indent") {
write!(f, "expected an indented block")
} else {
write!(f, "Got unexpected token {}", tok)
}
}
ParseErrorType::Lexical(ref error) => write!(f, "{}", error),
}
}
}
impl Error for ParseErrorType {}
impl ParseErrorType {
pub fn is_indentation_error(&self) -> bool {
match self {
ParseErrorType::Lexical(LexicalErrorType::IndentationError) => true,
ParseErrorType::UnrecognizedToken(token, expected) => {
*token == Tok::Indent || expected.clone() == Some("Indent".to_owned())
}
_ => false,
}
}
pub fn is_tab_error(&self) -> bool {
matches!(
self,
ParseErrorType::Lexical(LexicalErrorType::TabError)
| ParseErrorType::Lexical(LexicalErrorType::TabsAfterSpaces)
)
}
}
impl std::ops::Deref for ParseError {
type Target = ParseErrorType;
fn deref(&self) -> &Self::Target {
&self.error
}
}
impl Error for ParseError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
None
}
}

View File

@ -1,405 +0,0 @@
use std::iter;
use std::mem;
use std::str;
use crate::ast::{Constant, ConversionFlag, Expr, ExprKind, Location};
use crate::error::{FStringError, FStringErrorType, ParseError};
use crate::parser::parse_expression;
use self::FStringErrorType::*;
struct FStringParser<'a> {
chars: iter::Peekable<str::Chars<'a>>,
str_location: Location,
}
impl<'a> FStringParser<'a> {
fn new(source: &'a str, str_location: Location) -> Self {
Self {
chars: source.chars().peekable(),
str_location,
}
}
#[inline]
fn expr(&self, node: ExprKind) -> Expr {
Expr::new(self.str_location, node)
}
fn parse_formatted_value(&mut self) -> Result<Vec<Expr>, FStringErrorType> {
let mut expression = String::new();
let mut spec = None;
let mut delims = Vec::new();
let mut conversion = None;
let mut pred_expression_text = String::new();
let mut trailing_seq = String::new();
while let Some(ch) = self.chars.next() {
match ch {
// can be integrated better with the remainign code, but as a starting point ok
// in general I would do here a tokenizing of the fstrings to omit this peeking.
'!' if self.chars.peek() == Some(&'=') => {
expression.push_str("!=");
self.chars.next();
}
'=' if self.chars.peek() == Some(&'=') => {
expression.push_str("==");
self.chars.next();
}
'>' if self.chars.peek() == Some(&'=') => {
expression.push_str(">=");
self.chars.next();
}
'<' if self.chars.peek() == Some(&'=') => {
expression.push_str("<=");
self.chars.next();
}
'!' if delims.is_empty() && self.chars.peek() != Some(&'=') => {
if expression.trim().is_empty() {
return Err(EmptyExpression);
}
conversion = Some(match self.chars.next() {
Some('s') => ConversionFlag::Str,
Some('a') => ConversionFlag::Ascii,
Some('r') => ConversionFlag::Repr,
Some(_) => {
return Err(InvalidConversionFlag);
}
None => {
return Err(ExpectedRbrace);
}
});
if let Some(&peek) = self.chars.peek() {
if peek != '}' && peek != ':' {
return Err(ExpectedRbrace);
}
} else {
return Err(ExpectedRbrace);
}
}
// match a python 3.8 self documenting expression
// format '{' PYTHON_EXPRESSION '=' FORMAT_SPECIFIER? '}'
'=' if self.chars.peek() != Some(&'=') && delims.is_empty() => {
pred_expression_text = expression.to_string(); // safe expression before = to print it
}
':' if delims.is_empty() => {
let mut nested = false;
let mut in_nested = false;
let mut spec_expression = String::new();
while let Some(&next) = self.chars.peek() {
match next {
'{' => {
if in_nested {
return Err(ExpressionNestedTooDeeply);
}
in_nested = true;
nested = true;
self.chars.next();
continue;
}
'}' => {
if in_nested {
in_nested = false;
self.chars.next();
}
break;
}
_ => (),
}
spec_expression.push(next);
self.chars.next();
}
if in_nested {
return Err(UnclosedLbrace);
}
spec = Some(if nested {
Box::new(
self.expr(ExprKind::FormattedValue {
value: Box::new(
parse_fstring_expr(&spec_expression)
.map_err(|e| InvalidExpression(Box::new(e.error)))?,
),
conversion: None,
format_spec: None,
}),
)
} else {
Box::new(self.expr(ExprKind::Constant {
value: spec_expression.to_owned().into(),
kind: None,
}))
})
}
'(' | '{' | '[' => {
expression.push(ch);
delims.push(ch);
}
')' => {
if delims.pop() != Some('(') {
return Err(MismatchedDelimiter);
}
expression.push(ch);
}
']' => {
if delims.pop() != Some('[') {
return Err(MismatchedDelimiter);
}
expression.push(ch);
}
'}' if !delims.is_empty() => {
if delims.pop() != Some('{') {
return Err(MismatchedDelimiter);
}
expression.push(ch);
}
'}' => {
if expression.is_empty() {
return Err(EmptyExpression);
}
let ret = if pred_expression_text.is_empty() {
vec![self.expr(ExprKind::FormattedValue {
value: Box::new(
parse_fstring_expr(&expression)
.map_err(|e| InvalidExpression(Box::new(e.error)))?,
),
conversion,
format_spec: spec,
})]
} else {
vec![
self.expr(ExprKind::Constant {
value: Constant::Str(pred_expression_text + "="),
kind: None,
}),
self.expr(ExprKind::Constant {
value: trailing_seq.into(),
kind: None,
}),
self.expr(ExprKind::FormattedValue {
value: Box::new(
parse_fstring_expr(&expression)
.map_err(|e| InvalidExpression(Box::new(e.error)))?,
),
conversion,
format_spec: spec,
}),
]
};
return Ok(ret);
}
'"' | '\'' => {
expression.push(ch);
for next in &mut self.chars {
expression.push(next);
if next == ch {
break;
}
}
}
' ' if !pred_expression_text.is_empty() => {
trailing_seq.push(ch);
}
_ => {
expression.push(ch);
}
}
}
Err(UnclosedLbrace)
}
fn parse(mut self) -> Result<Expr, FStringErrorType> {
let mut content = String::new();
let mut values = vec![];
while let Some(ch) = self.chars.next() {
match ch {
'{' => {
if let Some('{') = self.chars.peek() {
self.chars.next();
content.push('{');
} else {
if !content.is_empty() {
values.push(self.expr(ExprKind::Constant {
value: mem::take(&mut content).into(),
kind: None,
}));
}
values.extend(self.parse_formatted_value()?);
}
}
'}' => {
if let Some('}') = self.chars.peek() {
self.chars.next();
content.push('}');
} else {
return Err(UnopenedRbrace);
}
}
_ => {
content.push(ch);
}
}
}
if !content.is_empty() {
values.push(self.expr(ExprKind::Constant {
value: content.into(),
kind: None,
}))
}
let s = match values.len() {
0 => self.expr(ExprKind::Constant {
value: String::new().into(),
kind: None,
}),
1 => values.into_iter().next().unwrap(),
_ => self.expr(ExprKind::JoinedStr { values }),
};
Ok(s)
}
}
fn parse_fstring_expr(source: &str) -> Result<Expr, ParseError> {
let fstring_body = format!("({})", source);
parse_expression(&fstring_body)
}
/// Parse an fstring from a string, located at a certain position in the sourcecode.
/// In case of errors, we will get the location and the error returned.
pub fn parse_located_fstring(source: &str, location: Location) -> Result<Expr, FStringError> {
FStringParser::new(source, location)
.parse()
.map_err(|error| FStringError { error, location })
}
#[cfg(test)]
mod tests {
use super::*;
fn parse_fstring(source: &str) -> Result<Expr, FStringErrorType> {
FStringParser::new(source, Location::default()).parse()
}
#[test]
fn test_parse_fstring() {
let source = "{a}{ b }{{foo}}";
let parse_ast = parse_fstring(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_fstring_nested_spec() {
let source = "{foo:{spec}}";
let parse_ast = parse_fstring(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_fstring_not_nested_spec() {
let source = "{foo:spec}";
let parse_ast = parse_fstring(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_empty_fstring() {
insta::assert_debug_snapshot!(parse_fstring("").unwrap());
}
#[test]
fn test_fstring_parse_selfdocumenting_base() {
let src = "{user=}";
let parse_ast = parse_fstring(&src).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_fstring_parse_selfdocumenting_base_more() {
let src = "mix {user=} with text and {second=}";
let parse_ast = parse_fstring(&src).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_fstring_parse_selfdocumenting_format() {
let src = "{user=:>10}";
let parse_ast = parse_fstring(&src).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_invalid_fstring() {
assert_eq!(parse_fstring("{5!a"), Err(ExpectedRbrace));
assert_eq!(parse_fstring("{5!a1}"), Err(ExpectedRbrace));
assert_eq!(parse_fstring("{5!"), Err(ExpectedRbrace));
assert_eq!(parse_fstring("abc{!a 'cat'}"), Err(EmptyExpression));
assert_eq!(parse_fstring("{!a"), Err(EmptyExpression));
assert_eq!(parse_fstring("{ !a}"), Err(EmptyExpression));
assert_eq!(parse_fstring("{5!}"), Err(InvalidConversionFlag));
assert_eq!(parse_fstring("{5!x}"), Err(InvalidConversionFlag));
assert_eq!(parse_fstring("{a:{a:{b}}"), Err(ExpressionNestedTooDeeply));
assert_eq!(parse_fstring("{a:b}}"), Err(UnopenedRbrace));
assert_eq!(parse_fstring("}"), Err(UnopenedRbrace));
assert_eq!(parse_fstring("{a:{b}"), Err(UnclosedLbrace));
assert_eq!(parse_fstring("{"), Err(UnclosedLbrace));
assert_eq!(parse_fstring("{}"), Err(EmptyExpression));
// TODO: check for InvalidExpression enum?
assert!(parse_fstring("{class}").is_err());
}
#[test]
fn test_parse_fstring_not_equals() {
let source = "{1 != 2}";
let parse_ast = parse_fstring(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_fstring_equals() {
let source = "{42 == 42}";
let parse_ast = parse_fstring(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_fstring_selfdoc_prec_space() {
let source = "{x =}";
let parse_ast = parse_fstring(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_fstring_selfdoc_trailing_space() {
let source = "{x= }";
let parse_ast = parse_fstring(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_fstring_yield_expr() {
let source = "{yield}";
let parse_ast = parse_fstring(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
}

View File

@ -1,96 +0,0 @@
use ahash::RandomState;
use std::collections::HashSet;
use crate::ast;
use crate::error::{LexicalError, LexicalErrorType};
pub struct ArgumentList {
pub args: Vec<ast::Expr>,
pub keywords: Vec<ast::Keyword>,
}
type ParameterDefs = (Vec<ast::Arg>, Vec<ast::Arg>, Vec<ast::Expr>);
type ParameterDef = (ast::Arg, Option<ast::Expr>);
pub fn parse_params(
params: (Vec<ParameterDef>, Vec<ParameterDef>),
) -> Result<ParameterDefs, LexicalError> {
let mut posonly = Vec::with_capacity(params.0.len());
let mut names = Vec::with_capacity(params.1.len());
let mut defaults = vec![];
let mut try_default = |name: &ast::Arg, default| {
if let Some(default) = default {
defaults.push(default);
} else if !defaults.is_empty() {
// Once we have started with defaults, all remaining arguments must
// have defaults
return Err(LexicalError {
error: LexicalErrorType::DefaultArgumentError,
location: name.location,
});
}
Ok(())
};
for (name, default) in params.0 {
try_default(&name, default)?;
posonly.push(name);
}
for (name, default) in params.1 {
try_default(&name, default)?;
names.push(name);
}
Ok((posonly, names, defaults))
}
type FunctionArgument = (Option<(ast::Location, Option<String>)>, ast::Expr);
pub fn parse_args(func_args: Vec<FunctionArgument>) -> Result<ArgumentList, LexicalError> {
let mut args = vec![];
let mut keywords = vec![];
let mut keyword_names = HashSet::with_capacity_and_hasher(func_args.len(), RandomState::new());
for (name, value) in func_args {
match name {
Some((location, name)) => {
if let Some(keyword_name) = &name {
if keyword_names.contains(keyword_name) {
return Err(LexicalError {
error: LexicalErrorType::DuplicateKeywordArgumentError,
location,
});
}
keyword_names.insert(keyword_name.clone());
}
keywords.push(ast::Keyword::new(
location,
ast::KeywordData {
arg: name.map(|name| name.into()),
value: Box::new(value),
},
));
}
None => {
// Allow starred args after keyword arguments.
if !keywords.is_empty() && !is_starred(&value) {
return Err(LexicalError {
error: LexicalErrorType::PositionalArgumentError,
location: value.location,
});
}
args.push(value);
}
}
}
Ok(ArgumentList { args, keywords })
}
fn is_starred(exp: &ast::Expr) -> bool {
matches!(exp.node, ast::ExprKind::Starred { .. })
}

File diff suppressed because it is too large Load Diff

View File

@ -1,35 +0,0 @@
//! This crate can be used to parse python sourcecode into a so
//! called AST (abstract syntax tree).
//!
//! The stages involved in this process are lexical analysis and
//! parsing. The lexical analysis splits the sourcecode into
//! tokens, and the parsing transforms those tokens into an AST.
//!
//! For example, one could do this:
//!
//! ```
//! use nac3parser::{parser, ast};
//!
//! let python_source = "print('Hello world')";
//! let python_ast = parser::parse_expression(python_source).unwrap();
//!
//! ```
#[macro_use]
extern crate log;
use lalrpop_util::lalrpop_mod;
pub use nac3ast as ast;
pub mod error;
mod fstring;
mod function;
pub mod lexer;
pub mod mode;
pub mod parser;
lalrpop_mod!(
#[allow(clippy::all)]
#[allow(unused)]
python
);
pub mod token;
pub mod config_comment_helper;

View File

@ -1,40 +0,0 @@
use crate::token::Tok;
#[derive(Clone, Copy)]
pub enum Mode {
Module,
Interactive,
Expression,
}
impl Mode {
pub(crate) fn to_marker(self) -> Tok {
match self {
Self::Module => Tok::StartModule,
Self::Interactive => Tok::StartInteractive,
Self::Expression => Tok::StartExpression,
}
}
}
impl std::str::FromStr for Mode {
type Err = ModeParseError;
fn from_str(s: &str) -> Result<Self, ModeParseError> {
match s {
"exec" | "single" => Ok(Mode::Module),
"eval" => Ok(Mode::Expression),
_ => Err(ModeParseError { _priv: () }),
}
}
}
#[derive(Debug)]
pub struct ModeParseError {
_priv: (),
}
impl std::fmt::Display for ModeParseError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, r#"mode should be "exec", "eval", or "single""#)
}
}

View File

@ -1,229 +0,0 @@
//! Python parsing.
//!
//! Use this module to parse python code into an AST.
//! There are three ways to parse python code. You could
//! parse a whole program, a single statement, or a single
//! expression.
use std::iter;
use crate::ast::{self, FileName};
use crate::error::ParseError;
use crate::lexer;
pub use crate::mode::Mode;
use crate::python;
/*
* Parse python code.
* Grammar may be inspired by antlr grammar for python:
* https://github.com/antlr/grammars-v4/tree/master/python3
*/
/// Parse a full python program, containing usually multiple lines.
pub fn parse_program(source: &str, file: FileName) -> Result<ast::Suite, ParseError> {
parse(source, Mode::Module, file).map(|top| match top {
ast::Mod::Module { body, .. } => body,
_ => unreachable!(),
})
}
/// Parses a python expression
///
/// # Example
/// ```
/// use nac3parser::{parser, ast};
/// let expr = parser::parse_expression("1 + 2").unwrap();
///
/// assert_eq!(
/// expr,
/// ast::Expr {
/// location: ast::Location::new(1, 3, Default::default()),
/// custom: (),
/// node: ast::ExprKind::BinOp {
/// left: Box::new(ast::Expr {
/// location: ast::Location::new(1, 1, Default::default()),
/// custom: (),
/// node: ast::ExprKind::Constant {
/// value: ast::Constant::Int(1.into()),
/// kind: None,
/// }
/// }),
/// op: ast::Operator::Add,
/// right: Box::new(ast::Expr {
/// location: ast::Location::new(1, 5, Default::default()),
/// custom: (),
/// node: ast::ExprKind::Constant {
/// value: ast::Constant::Int(2.into()),
/// kind: None,
/// }
/// })
/// }
/// },
/// );
///
/// ```
pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> {
parse(source, Mode::Expression, Default::default()).map(|top| match top {
ast::Mod::Expression { body } => *body,
_ => unreachable!(),
})
}
// Parse a given source code
pub fn parse(source: &str, mode: Mode, file: FileName) -> Result<ast::Mod, ParseError> {
let lxr = lexer::make_tokenizer(source, file);
let marker_token = (Default::default(), mode.to_marker(), Default::default());
let tokenizer = iter::once(Ok(marker_token)).chain(lxr);
python::TopParser::new()
.parse(tokenizer)
.map_err(ParseError::from)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_empty() {
let parse_ast = parse_program("", Default::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_print_hello() {
let source = String::from("print('Hello world')");
let parse_ast = parse_program(&source, Default::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_print_2() {
let source = String::from("print('Hello world', 2)");
let parse_ast = parse_program(&source, Default::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_kwargs() {
let source = String::from("my_func('positional', keyword=2)");
let parse_ast = parse_program(&source, Default::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_if_elif_else() {
let source = String::from("if 1: 10\nelif 2: 20\nelse: 30");
let parse_ast = parse_program(&source, Default::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_lambda() {
let source = "lambda x, y: x * y"; // lambda(x, y): x * y";
let parse_ast = parse_program(source, Default::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_tuples() {
let source = "a, b = 4, 5";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap());
}
#[test]
fn test_parse_class() {
let source = "\
class Foo(A, B):
def __init__(self):
pass
def method_with_default(self, arg='default'):
pass";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap());
}
#[test]
fn test_parse_dict_comprehension() {
let source = String::from("{x1: x2 for y in z}");
let parse_ast = parse_expression(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_list_comprehension() {
let source = String::from("[x for y in z]");
let parse_ast = parse_expression(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_parse_double_list_comprehension() {
let source = String::from("[x for y, y2 in z for a in b if a < 5 if a > 10]");
let parse_ast = parse_expression(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_more_comment() {
let source = "\
a: int # nac3: sf1
# nac3: sdf4
for i in (1, '12'): # nac3: sf2
a: int
# nac3: 3
# nac3: 5
while i < 2: # nac3: 4
# nac3: real pass
pass
# nac3: expr1
# nac3: expr3
1 + 2 # nac3: expr2
# nac3: if3
# nac3: if1
if 1: # nac3: if2
3";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap());
}
#[test]
fn test_sample_comment() {
let source = "\
# nac3: while1
# nac3: while2
# normal comment
while test: # nac3: while3
# nac3: simple assign0
a = 3 # nac3: simple assign1
";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap());
}
#[test]
fn test_comment_ambiguity() {
let source = "\
if a: d; # nac3: for d
if b: c # nac3: for c
if d: # nac3: for if d
b; b + 3; # nac3: for b + 3
a = 3; a + 3; b = a; # nac3: notif
# nac3: smallsingle1
# nac3: smallsingle3
aa = 3 # nac3: smallsingle2
if a: # nac3: small2
a
for i in a: # nac3: for1
pass
";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap());
}
#[test]
fn test_comment_should_fail() {
let source = "\
if a: # nac3: something
a = 3
";
assert!(parse_program(source, Default::default()).is_err());
}
}

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More