Compare commits
12 Commits
master
...
ndstrides-
Author | SHA1 | Date | |
---|---|---|---|
eac164ce11 | |||
e8aa6129f0 | |||
e1abb819c6 | |||
db9e9586b5 | |||
2525369760 | |||
06b64ea888 | |||
37a022e156 | |||
17e8eeaa46 | |||
eb3aa18ce1 | |||
8068bd5cb0 | |||
3a8c4009fd | |||
ca13f56404 |
@ -1,32 +0,0 @@
|
|||||||
BasedOnStyle: LLVM
|
|
||||||
|
|
||||||
Language: Cpp
|
|
||||||
Standard: Cpp11
|
|
||||||
|
|
||||||
AccessModifierOffset: -1
|
|
||||||
AlignEscapedNewlines: Left
|
|
||||||
AlwaysBreakAfterReturnType: None
|
|
||||||
AlwaysBreakTemplateDeclarations: Yes
|
|
||||||
AllowAllParametersOfDeclarationOnNextLine: false
|
|
||||||
AllowShortFunctionsOnASingleLine: Inline
|
|
||||||
BinPackParameters: false
|
|
||||||
BreakBeforeBinaryOperators: NonAssignment
|
|
||||||
BreakBeforeTernaryOperators: true
|
|
||||||
BreakConstructorInitializers: AfterColon
|
|
||||||
BreakInheritanceList: AfterColon
|
|
||||||
ColumnLimit: 120
|
|
||||||
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
|
||||||
ContinuationIndentWidth: 4
|
|
||||||
DerivePointerAlignment: false
|
|
||||||
IndentCaseLabels: true
|
|
||||||
IndentPPDirectives: None
|
|
||||||
IndentWidth: 4
|
|
||||||
MaxEmptyLinesToKeep: 1
|
|
||||||
PointerAlignment: Left
|
|
||||||
ReflowComments: true
|
|
||||||
SortIncludes: false
|
|
||||||
SortUsingDeclarations: true
|
|
||||||
SpaceAfterTemplateKeyword: false
|
|
||||||
SpacesBeforeTrailingComments: 2
|
|
||||||
TabWidth: 4
|
|
||||||
UseTab: Never
|
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,3 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
/target
|
/target
|
||||||
/nac3standalone/demo/linalg/target
|
|
||||||
nix/windows/msys2
|
nix/windows/msys2
|
||||||
|
@ -1,24 +1,24 @@
|
|||||||
# See https://pre-commit.com for more information
|
# See https://pre-commit.com for more information
|
||||||
# See https://pre-commit.com/hooks.html for more hooks
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
|
|
||||||
default_stages: [pre-commit]
|
default_stages: [commit]
|
||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
- id: nac3-cargo-fmt
|
- id: nac3-cargo-fmt
|
||||||
name: nac3 cargo format
|
name: nac3 cargo format
|
||||||
entry: nix
|
entry: cargo
|
||||||
language: system
|
language: system
|
||||||
types: [file, rust]
|
types: [file, rust]
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
description: Runs cargo fmt on the codebase.
|
description: Runs cargo fmt on the codebase.
|
||||||
args: [develop, -c, cargo, fmt, --all]
|
args: [fmt]
|
||||||
- id: nac3-cargo-clippy
|
- id: nac3-cargo-clippy
|
||||||
name: nac3 cargo clippy
|
name: nac3 cargo clippy
|
||||||
entry: nix
|
entry: cargo
|
||||||
language: system
|
language: system
|
||||||
types: [file, rust]
|
types: [file, rust]
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
description: Runs cargo clippy on the codebase.
|
description: Runs cargo clippy on the codebase.
|
||||||
args: [develop, -c, cargo, clippy, --tests]
|
args: [clippy, --tests]
|
||||||
|
661
Cargo.lock
generated
661
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -4,7 +4,6 @@ members = [
|
|||||||
"nac3ast",
|
"nac3ast",
|
||||||
"nac3parser",
|
"nac3parser",
|
||||||
"nac3core",
|
"nac3core",
|
||||||
"nac3core/nac3core_derive",
|
|
||||||
"nac3standalone",
|
"nac3standalone",
|
||||||
"nac3artiq",
|
"nac3artiq",
|
||||||
"runkernel",
|
"runkernel",
|
||||||
|
6
flake.lock
generated
6
flake.lock
generated
@ -2,11 +2,11 @@
|
|||||||
"nodes": {
|
"nodes": {
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1736798957,
|
"lastModified": 1720418205,
|
||||||
"narHash": "sha256-qwpCtZhSsSNQtK4xYGzMiyEDhkNzOCz/Vfu4oL2ETsQ=",
|
"narHash": "sha256-cPJoFPXU44GlhWg4pUk9oUPqurPlCFZ11ZQPk21GTPU=",
|
||||||
"owner": "NixOS",
|
"owner": "NixOS",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "9abb87b552b7f55ac8916b6fc9e5cb486656a2f3",
|
"rev": "655a58a72a6601292512670343087c2d75d859c1",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
42
flake.nix
42
flake.nix
@ -6,7 +6,6 @@
|
|||||||
outputs = { self, nixpkgs }:
|
outputs = { self, nixpkgs }:
|
||||||
let
|
let
|
||||||
pkgs = import nixpkgs { system = "x86_64-linux"; };
|
pkgs = import nixpkgs { system = "x86_64-linux"; };
|
||||||
pkgs32 = import nixpkgs { system = "i686-linux"; };
|
|
||||||
in rec {
|
in rec {
|
||||||
packages.x86_64-linux = rec {
|
packages.x86_64-linux = rec {
|
||||||
llvm-nac3 = pkgs.callPackage ./nix/llvm {};
|
llvm-nac3 = pkgs.callPackage ./nix/llvm {};
|
||||||
@ -14,24 +13,9 @@
|
|||||||
''
|
''
|
||||||
mkdir -p $out/bin
|
mkdir -p $out/bin
|
||||||
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
|
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
|
||||||
|
ln -s ${pkgs.llvmPackages_14.clang}/bin/clang $out/bin/clang-irrt-test
|
||||||
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
|
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
|
||||||
'';
|
'';
|
||||||
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
|
|
||||||
name = "demo-linalg-stub";
|
|
||||||
src = ./nac3standalone/demo/linalg;
|
|
||||||
cargoLock = {
|
|
||||||
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
|
|
||||||
};
|
|
||||||
doCheck = false;
|
|
||||||
};
|
|
||||||
demo-linalg-stub32 = pkgs32.rustPlatform.buildRustPackage {
|
|
||||||
name = "demo-linalg-stub32";
|
|
||||||
src = ./nac3standalone/demo/linalg;
|
|
||||||
cargoLock = {
|
|
||||||
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
|
|
||||||
};
|
|
||||||
doCheck = false;
|
|
||||||
};
|
|
||||||
nac3artiq = pkgs.python3Packages.toPythonModule (
|
nac3artiq = pkgs.python3Packages.toPythonModule (
|
||||||
pkgs.rustPlatform.buildRustPackage rec {
|
pkgs.rustPlatform.buildRustPackage rec {
|
||||||
name = "nac3artiq";
|
name = "nac3artiq";
|
||||||
@ -40,8 +24,9 @@
|
|||||||
cargoLock = {
|
cargoLock = {
|
||||||
lockFile = ./Cargo.lock;
|
lockFile = ./Cargo.lock;
|
||||||
};
|
};
|
||||||
|
cargoTestFlags = [ "--features" "test" ];
|
||||||
passthru.cargoLock = cargoLock;
|
passthru.cargoLock = cargoLock;
|
||||||
nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
|
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
|
||||||
buildInputs = [ pkgs.python3 llvm-nac3 ];
|
buildInputs = [ pkgs.python3 llvm-nac3 ];
|
||||||
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
|
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
|
||||||
checkPhase =
|
checkPhase =
|
||||||
@ -49,9 +34,7 @@
|
|||||||
echo "Checking nac3standalone demos..."
|
echo "Checking nac3standalone demos..."
|
||||||
pushd nac3standalone/demo
|
pushd nac3standalone/demo
|
||||||
patchShebangs .
|
patchShebangs .
|
||||||
export DEMO_LINALG_STUB=${demo-linalg-stub}/lib/liblinalg.a
|
./check_demos.sh
|
||||||
export DEMO_LINALG_STUB32=${demo-linalg-stub32}/lib/liblinalg.a
|
|
||||||
./check_demos.sh -i686
|
|
||||||
popd
|
popd
|
||||||
echo "Running Cargo tests..."
|
echo "Running Cargo tests..."
|
||||||
cargoCheckHook
|
cargoCheckHook
|
||||||
@ -107,18 +90,18 @@
|
|||||||
(pkgs.fetchFromGitHub {
|
(pkgs.fetchFromGitHub {
|
||||||
owner = "m-labs";
|
owner = "m-labs";
|
||||||
repo = "sipyco";
|
repo = "sipyco";
|
||||||
rev = "094a6cd63ffa980ef63698920170e50dc9ba77fd";
|
rev = "939f84f9b5eef7efbf7423c735d1834783b6140e";
|
||||||
sha256 = "sha256-PPnAyDedUQ7Og/Cby9x5OT9wMkNGTP8GS53V6N/dk4w=";
|
sha256 = "sha256-15Nun4EY35j+6SPZkjzZtyH/ncxLS60KuGJjFh5kSTc=";
|
||||||
})
|
})
|
||||||
(pkgs.fetchFromGitHub {
|
(pkgs.fetchFromGitHub {
|
||||||
owner = "m-labs";
|
owner = "m-labs";
|
||||||
repo = "artiq";
|
repo = "artiq";
|
||||||
rev = "28c9de3e251daa89a8c9fd79d5ab64a3ec03bac6";
|
rev = "923ca3377d42c815f979983134ec549dc39d3ca0";
|
||||||
sha256 = "sha256-vAvpbHc5B+1wtG8zqN7j9dQE1ON+i22v+uqA+tw6Gak=";
|
sha256 = "sha256-oJoEeNEeNFSUyh6jXG8Tzp6qHVikeHS0CzfE+mODPgw=";
|
||||||
})
|
})
|
||||||
];
|
];
|
||||||
buildInputs = [
|
buildInputs = [
|
||||||
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ]))
|
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb nac3artiq-instrumented ]))
|
||||||
pkgs.llvmPackages_14.llvm.out
|
pkgs.llvmPackages_14.llvm.out
|
||||||
];
|
];
|
||||||
phases = [ "buildPhase" "installPhase" ];
|
phases = [ "buildPhase" "installPhase" ];
|
||||||
@ -168,7 +151,7 @@
|
|||||||
buildInputs = with pkgs; [
|
buildInputs = with pkgs; [
|
||||||
# build dependencies
|
# build dependencies
|
||||||
packages.x86_64-linux.llvm-nac3
|
packages.x86_64-linux.llvm-nac3
|
||||||
(pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos
|
llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos
|
||||||
packages.x86_64-linux.llvm-tools-irrt
|
packages.x86_64-linux.llvm-tools-irrt
|
||||||
cargo
|
cargo
|
||||||
rustc
|
rustc
|
||||||
@ -181,11 +164,6 @@
|
|||||||
pre-commit
|
pre-commit
|
||||||
rustfmt
|
rustfmt
|
||||||
];
|
];
|
||||||
shellHook =
|
|
||||||
''
|
|
||||||
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
|
|
||||||
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
|
|
||||||
'';
|
|
||||||
};
|
};
|
||||||
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
||||||
name = "nac3-dev-shell-msys2";
|
name = "nac3-dev-shell-msys2";
|
||||||
|
@ -12,10 +12,15 @@ crate-type = ["cdylib"]
|
|||||||
itertools = "0.13"
|
itertools = "0.13"
|
||||||
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
|
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
tempfile = "3.13"
|
tempfile = "3.10"
|
||||||
|
nac3parser = { path = "../nac3parser" }
|
||||||
nac3core = { path = "../nac3core" }
|
nac3core = { path = "../nac3core" }
|
||||||
nac3ld = { path = "../nac3ld" }
|
nac3ld = { path = "../nac3ld" }
|
||||||
|
|
||||||
|
[dependencies.inkwell]
|
||||||
|
version = "0.4"
|
||||||
|
default-features = false
|
||||||
|
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
init-llvm-profile = []
|
init-llvm-profile = []
|
||||||
no-escape-analysis = ["nac3core/no-escape-analysis"]
|
|
||||||
|
66
nac3artiq/demo/embedding_map.py
Normal file
66
nac3artiq/demo/embedding_map.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
class EmbeddingMap:
|
||||||
|
def __init__(self):
|
||||||
|
self.object_inverse_map = {}
|
||||||
|
self.object_map = {}
|
||||||
|
self.string_map = {}
|
||||||
|
self.string_reverse_map = {}
|
||||||
|
self.function_map = {}
|
||||||
|
self.attributes_writeback = []
|
||||||
|
|
||||||
|
# preallocate exception names
|
||||||
|
self.preallocate_runtime_exception_names(["RuntimeError",
|
||||||
|
"RTIOUnderflow",
|
||||||
|
"RTIOOverflow",
|
||||||
|
"RTIODestinationUnreachable",
|
||||||
|
"DMAError",
|
||||||
|
"I2CError",
|
||||||
|
"CacheError",
|
||||||
|
"SPIError",
|
||||||
|
"0:ZeroDivisionError",
|
||||||
|
"0:IndexError",
|
||||||
|
"0:ValueError",
|
||||||
|
"0:RuntimeError",
|
||||||
|
"0:AssertionError",
|
||||||
|
"0:KeyError",
|
||||||
|
"0:NotImplementedError",
|
||||||
|
"0:OverflowError",
|
||||||
|
"0:IOError",
|
||||||
|
"0:UnwrapNoneError"])
|
||||||
|
|
||||||
|
def preallocate_runtime_exception_names(self, names):
|
||||||
|
for i, name in enumerate(names):
|
||||||
|
if ":" not in name:
|
||||||
|
name = "0:artiq.coredevice.exceptions." + name
|
||||||
|
exn_id = self.store_str(name)
|
||||||
|
assert exn_id == i
|
||||||
|
|
||||||
|
def store_function(self, key, fun):
|
||||||
|
self.function_map[key] = fun
|
||||||
|
return key
|
||||||
|
|
||||||
|
def store_object(self, obj):
|
||||||
|
obj_id = id(obj)
|
||||||
|
if obj_id in self.object_inverse_map:
|
||||||
|
return self.object_inverse_map[obj_id]
|
||||||
|
key = len(self.object_map) + 1
|
||||||
|
self.object_map[key] = obj
|
||||||
|
self.object_inverse_map[obj_id] = key
|
||||||
|
return key
|
||||||
|
|
||||||
|
def store_str(self, s):
|
||||||
|
if s in self.string_reverse_map:
|
||||||
|
return self.string_reverse_map[s]
|
||||||
|
key = len(self.string_map)
|
||||||
|
self.string_map[key] = s
|
||||||
|
self.string_reverse_map[s] = key
|
||||||
|
return key
|
||||||
|
|
||||||
|
def retrieve_function(self, key):
|
||||||
|
return self.function_map[key]
|
||||||
|
|
||||||
|
def retrieve_object(self, key):
|
||||||
|
return self.object_map[key]
|
||||||
|
|
||||||
|
def retrieve_str(self, key):
|
||||||
|
return self.string_map[key]
|
||||||
|
|
@ -6,6 +6,7 @@ from typing import Generic, TypeVar
|
|||||||
from math import floor, ceil
|
from math import floor, ceil
|
||||||
|
|
||||||
import nac3artiq
|
import nac3artiq
|
||||||
|
from embedding_map import EmbeddingMap
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -111,15 +112,10 @@ def extern(function):
|
|||||||
register_function(function)
|
register_function(function)
|
||||||
return function
|
return function
|
||||||
|
|
||||||
|
def rpc(function):
|
||||||
def rpc(arg=None, flags={}):
|
"""Decorates a function declaration defined by the core device runtime."""
|
||||||
"""Decorates a function or method to be executed on the host interpreter."""
|
register_function(function)
|
||||||
if arg is None:
|
return function
|
||||||
def inner_decorator(function):
|
|
||||||
return rpc(function, flags)
|
|
||||||
return inner_decorator
|
|
||||||
register_function(arg)
|
|
||||||
return arg
|
|
||||||
|
|
||||||
def kernel(function_or_method):
|
def kernel(function_or_method):
|
||||||
"""Decorates a function or method to be executed on the core device."""
|
"""Decorates a function or method to be executed on the core device."""
|
||||||
@ -192,46 +188,6 @@ def print_int64(x: int64):
|
|||||||
raise NotImplementedError("syscall not simulated")
|
raise NotImplementedError("syscall not simulated")
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingMap:
|
|
||||||
def __init__(self):
|
|
||||||
self.object_inverse_map = {}
|
|
||||||
self.object_map = {}
|
|
||||||
self.string_map = {}
|
|
||||||
self.string_reverse_map = {}
|
|
||||||
self.function_map = {}
|
|
||||||
self.attributes_writeback = []
|
|
||||||
|
|
||||||
def store_function(self, key, fun):
|
|
||||||
self.function_map[key] = fun
|
|
||||||
return key
|
|
||||||
|
|
||||||
def store_object(self, obj):
|
|
||||||
obj_id = id(obj)
|
|
||||||
if obj_id in self.object_inverse_map:
|
|
||||||
return self.object_inverse_map[obj_id]
|
|
||||||
key = len(self.object_map) + 1
|
|
||||||
self.object_map[key] = obj
|
|
||||||
self.object_inverse_map[obj_id] = key
|
|
||||||
return key
|
|
||||||
|
|
||||||
def store_str(self, s):
|
|
||||||
if s in self.string_reverse_map:
|
|
||||||
return self.string_reverse_map[s]
|
|
||||||
key = len(self.string_map)
|
|
||||||
self.string_map[key] = s
|
|
||||||
self.string_reverse_map[s] = key
|
|
||||||
return key
|
|
||||||
|
|
||||||
def retrieve_function(self, key):
|
|
||||||
return self.function_map[key]
|
|
||||||
|
|
||||||
def retrieve_object(self, key):
|
|
||||||
return self.object_map[key]
|
|
||||||
|
|
||||||
def retrieve_str(self, key):
|
|
||||||
return self.string_map[key]
|
|
||||||
|
|
||||||
|
|
||||||
@nac3
|
@nac3
|
||||||
class Core:
|
class Core:
|
||||||
ref_period: KernelInvariant[float]
|
ref_period: KernelInvariant[float]
|
||||||
@ -245,7 +201,7 @@ class Core:
|
|||||||
embedding = EmbeddingMap()
|
embedding = EmbeddingMap()
|
||||||
|
|
||||||
if allow_registration:
|
if allow_registration:
|
||||||
compiler.analyze(registered_functions, registered_classes, set())
|
compiler.analyze(registered_functions, registered_classes)
|
||||||
allow_registration = False
|
allow_registration = False
|
||||||
|
|
||||||
if hasattr(method, "__self__"):
|
if hasattr(method, "__self__"):
|
||||||
|
@ -1,26 +0,0 @@
|
|||||||
from min_artiq import *
|
|
||||||
from numpy import int32
|
|
||||||
|
|
||||||
# Global Variable Definition
|
|
||||||
X: Kernel[int32] = 1
|
|
||||||
|
|
||||||
# TopLevelFunction Defintion
|
|
||||||
@kernel
|
|
||||||
def display_X():
|
|
||||||
print_int32(X)
|
|
||||||
|
|
||||||
# TopLevel Class Definition
|
|
||||||
@nac3
|
|
||||||
class A:
|
|
||||||
@kernel
|
|
||||||
def __init__(self):
|
|
||||||
self.set_x(1)
|
|
||||||
|
|
||||||
@kernel
|
|
||||||
def set_x(self, new_val: int32):
|
|
||||||
global X
|
|
||||||
X = new_val
|
|
||||||
|
|
||||||
@kernel
|
|
||||||
def get_X(self) -> int32:
|
|
||||||
return X
|
|
@ -1,26 +0,0 @@
|
|||||||
from min_artiq import *
|
|
||||||
import module as module_definition
|
|
||||||
|
|
||||||
@nac3
|
|
||||||
class TestModuleSupport:
|
|
||||||
core: KernelInvariant[Core]
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.core = Core()
|
|
||||||
|
|
||||||
@kernel
|
|
||||||
def run(self):
|
|
||||||
# Accessing classes
|
|
||||||
obj = module_definition.A()
|
|
||||||
obj.get_X()
|
|
||||||
obj.set_x(2)
|
|
||||||
|
|
||||||
# Calling functions
|
|
||||||
module_definition.display_X()
|
|
||||||
|
|
||||||
# Updating global variables
|
|
||||||
module_definition.X = 9
|
|
||||||
module_definition.display_X()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
TestModuleSupport().run()
|
|
@ -1,29 +0,0 @@
|
|||||||
from min_artiq import *
|
|
||||||
import numpy
|
|
||||||
from numpy import int32
|
|
||||||
|
|
||||||
|
|
||||||
@nac3
|
|
||||||
class NumpyBoolDecay:
|
|
||||||
core: KernelInvariant[Core]
|
|
||||||
np_true: KernelInvariant[bool]
|
|
||||||
np_false: KernelInvariant[bool]
|
|
||||||
np_int: KernelInvariant[int32]
|
|
||||||
np_float: KernelInvariant[float]
|
|
||||||
np_str: KernelInvariant[str]
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.core = Core()
|
|
||||||
self.np_true = numpy.True_
|
|
||||||
self.np_false = numpy.False_
|
|
||||||
self.np_int = numpy.int32(0)
|
|
||||||
self.np_float = numpy.float64(0.0)
|
|
||||||
self.np_str = numpy.str_("")
|
|
||||||
|
|
||||||
@kernel
|
|
||||||
def run(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
NumpyBoolDecay().run()
|
|
@ -1,26 +0,0 @@
|
|||||||
from min_artiq import *
|
|
||||||
from numpy import ndarray, zeros as np_zeros
|
|
||||||
|
|
||||||
|
|
||||||
@nac3
|
|
||||||
class StrFail:
|
|
||||||
core: KernelInvariant[Core]
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.core = Core()
|
|
||||||
|
|
||||||
@kernel
|
|
||||||
def hello(self, arg: str):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@kernel
|
|
||||||
def consume_ndarray(self, arg: ndarray[str, 1]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
self.hello("world")
|
|
||||||
self.consume_ndarray(np_zeros([10], dtype=str))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
StrFail().run()
|
|
@ -3,22 +3,22 @@ from numpy import int32
|
|||||||
|
|
||||||
|
|
||||||
@nac3
|
@nac3
|
||||||
class EmptyList:
|
class Demo:
|
||||||
core: KernelInvariant[Core]
|
core: KernelInvariant[Core]
|
||||||
|
attr1: KernelInvariant[str]
|
||||||
|
attr2: KernelInvariant[int32]
|
||||||
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.core = Core()
|
self.core = Core()
|
||||||
|
self.attr2 = 32
|
||||||
@rpc
|
self.attr1 = "SAMPLE"
|
||||||
def get_empty(self) -> list[int32]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def run(self):
|
def run(self):
|
||||||
a: list[int32] = self.get_empty()
|
print_int32(self.attr2)
|
||||||
if a != []:
|
self.attr1
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
EmptyList().run()
|
Demo().run()
|
40
nac3artiq/demo/support_class_attr_issue102.py
Normal file
40
nac3artiq/demo/support_class_attr_issue102.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from min_artiq import *
|
||||||
|
from numpy import int32
|
||||||
|
|
||||||
|
|
||||||
|
@nac3
|
||||||
|
class Demo:
|
||||||
|
attr1: KernelInvariant[int32] = 2
|
||||||
|
attr2: int32 = 4
|
||||||
|
attr3: Kernel[int32]
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def __init__(self):
|
||||||
|
self.attr3 = 8
|
||||||
|
|
||||||
|
|
||||||
|
@nac3
|
||||||
|
class NAC3Devices:
|
||||||
|
core: KernelInvariant[Core]
|
||||||
|
attr4: KernelInvariant[int32] = 16
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.core = Core()
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def run(self):
|
||||||
|
Demo.attr1 # Supported
|
||||||
|
# Demo.attr2 # Field not accessible on Kernel
|
||||||
|
# Demo.attr3 # Only attributes can be accessed in this way
|
||||||
|
# Demo.attr1 = 2 # Attributes are immutable
|
||||||
|
|
||||||
|
self.attr4 # Attributes can be accessed within class
|
||||||
|
|
||||||
|
obj = Demo()
|
||||||
|
obj.attr1 # Attributes can be accessed by class objects
|
||||||
|
|
||||||
|
NAC3Devices.attr4 # Attributes accessible for classes without __init__
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
NAC3Devices().run()
|
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,10 @@
|
|||||||
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
|
#![deny(
|
||||||
|
future_incompatible,
|
||||||
|
let_underscore,
|
||||||
|
nonstandard_style,
|
||||||
|
rust_2024_compatibility,
|
||||||
|
clippy::all
|
||||||
|
)]
|
||||||
#![warn(clippy::pedantic)]
|
#![warn(clippy::pedantic)]
|
||||||
#![allow(
|
#![allow(
|
||||||
unsafe_op_in_unsafe_fn,
|
unsafe_op_in_unsafe_fn,
|
||||||
@ -10,65 +16,63 @@
|
|||||||
clippy::wildcard_imports
|
clippy::wildcard_imports
|
||||||
)]
|
)]
|
||||||
|
|
||||||
use std::{
|
use std::collections::{HashMap, HashSet};
|
||||||
collections::{HashMap, HashSet},
|
use std::fs;
|
||||||
fs,
|
use std::io::Write;
|
||||||
io::Write,
|
use std::process::Command;
|
||||||
process::Command,
|
use std::rc::Rc;
|
||||||
rc::Rc,
|
use std::sync::Arc;
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
|
|
||||||
use itertools::Itertools;
|
use inkwell::{
|
||||||
use parking_lot::{Mutex, RwLock};
|
memory_buffer::MemoryBuffer,
|
||||||
use pyo3::{
|
module::{Linkage, Module},
|
||||||
create_exception, exceptions,
|
passes::PassBuilderOptions,
|
||||||
prelude::*,
|
support::is_multithreaded,
|
||||||
types::{PyBytes, PyDict, PyNone, PySet},
|
targets::*,
|
||||||
|
OptimizationLevel,
|
||||||
};
|
};
|
||||||
use tempfile::{self, TempDir};
|
use itertools::Itertools;
|
||||||
|
use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
|
||||||
|
use nac3core::toplevel::builtins::get_exn_constructor;
|
||||||
|
use nac3core::typecheck::typedef::{TypeEnum, Unifier, VarMap};
|
||||||
|
use nac3parser::{
|
||||||
|
ast::{ExprKind, Stmt, StmtKind, StrRef},
|
||||||
|
parser::parse_program,
|
||||||
|
};
|
||||||
|
use pyo3::create_exception;
|
||||||
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
|
||||||
|
|
||||||
|
use parking_lot::{Mutex, RwLock};
|
||||||
|
|
||||||
use nac3core::{
|
use nac3core::{
|
||||||
codegen::{
|
codegen::irrt::load_irrt,
|
||||||
concrete_type::ConcreteTypeStore, gen_func_impl, irrt::load_irrt, CodeGenLLVMOptions,
|
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
|
||||||
CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, WithCall, WorkerRegistry,
|
|
||||||
},
|
|
||||||
inkwell::{
|
|
||||||
context::Context,
|
|
||||||
memory_buffer::MemoryBuffer,
|
|
||||||
module::{FlagBehavior, Linkage, Module},
|
|
||||||
passes::PassBuilderOptions,
|
|
||||||
support::is_multithreaded,
|
|
||||||
targets::*,
|
|
||||||
OptimizationLevel,
|
|
||||||
},
|
|
||||||
nac3parser::{
|
|
||||||
ast::{self, Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
|
|
||||||
parser::parse_program,
|
|
||||||
},
|
|
||||||
symbol_resolver::SymbolResolver,
|
symbol_resolver::SymbolResolver,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
builtins::get_exn_constructor,
|
composer::{ComposerConfig, TopLevelComposer},
|
||||||
composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
|
|
||||||
DefinitionId, GenCall, TopLevelDef,
|
DefinitionId, GenCall, TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::typedef::{FunSignature, FuncArg},
|
||||||
type_inferencer::PrimitiveStore,
|
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
|
||||||
typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use nac3ld::Linker;
|
use nac3ld::Linker;
|
||||||
|
|
||||||
use codegen::{
|
use tempfile::{self, TempDir};
|
||||||
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
|
|
||||||
|
use crate::codegen::attributes_writeback;
|
||||||
|
use crate::{
|
||||||
|
codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
|
||||||
|
symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
|
||||||
};
|
};
|
||||||
use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver};
|
|
||||||
use timeline::TimeFns;
|
|
||||||
|
|
||||||
mod codegen;
|
mod codegen;
|
||||||
mod symbol_resolver;
|
mod symbol_resolver;
|
||||||
mod timeline;
|
mod timeline;
|
||||||
|
|
||||||
|
use timeline::TimeFns;
|
||||||
|
|
||||||
#[derive(PartialEq, Clone, Copy)]
|
#[derive(PartialEq, Clone, Copy)]
|
||||||
enum Isa {
|
enum Isa {
|
||||||
Host,
|
Host,
|
||||||
@ -78,62 +82,14 @@ enum Isa {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Isa {
|
impl Isa {
|
||||||
/// Returns the [`TargetTriple`] used for compiling to this ISA.
|
/// Returns the number of bits in `size_t` for the [`Isa`].
|
||||||
pub fn get_llvm_target_triple(self) -> TargetTriple {
|
fn get_size_type(self) -> u32 {
|
||||||
match self {
|
if self == Isa::Host {
|
||||||
Isa::Host => TargetMachine::get_default_triple(),
|
64u32
|
||||||
Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"),
|
} else {
|
||||||
Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"),
|
32u32
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the [`String`] representing the target CPU used for compiling to this ISA.
|
|
||||||
pub fn get_llvm_target_cpu(self) -> String {
|
|
||||||
match self {
|
|
||||||
Isa::Host => TargetMachine::get_host_cpu_name().to_string(),
|
|
||||||
Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(),
|
|
||||||
Isa::CortexA9 => "cortex-a9".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the [`String`] representing the target features used for compiling to this ISA.
|
|
||||||
pub fn get_llvm_target_features(self) -> String {
|
|
||||||
match self {
|
|
||||||
Isa::Host => TargetMachine::get_host_cpu_features().to_string(),
|
|
||||||
Isa::RiscV32G => "+a,+m,+f,+d".to_string(),
|
|
||||||
Isa::RiscV32IMA => "+a,+m".to_string(),
|
|
||||||
Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine
|
|
||||||
/// options used for compiling to this ISA.
|
|
||||||
pub fn get_llvm_target_options(self) -> CodeGenTargetMachineOptions {
|
|
||||||
CodeGenTargetMachineOptions {
|
|
||||||
triple: self.get_llvm_target_triple().as_str().to_string_lossy().into_owned(),
|
|
||||||
cpu: self.get_llvm_target_cpu(),
|
|
||||||
features: self.get_llvm_target_features(),
|
|
||||||
reloc_mode: RelocMode::PIC,
|
|
||||||
..CodeGenTargetMachineOptions::from_host()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns an instance of [`TargetMachine`] used in compiling and linking of a program of this
|
|
||||||
/// ISA.
|
|
||||||
pub fn create_llvm_target_machine(self, opt_level: OptimizationLevel) -> TargetMachine {
|
|
||||||
self.get_llvm_target_options()
|
|
||||||
.create_target_machine(opt_level)
|
|
||||||
.expect("couldn't create target machine")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the number of bits in `size_t` for this ISA.
|
|
||||||
fn get_size_type(self, ctx: &Context) -> u32 {
|
|
||||||
ctx.ptr_sized_int_type(
|
|
||||||
&self.create_llvm_target_machine(OptimizationLevel::Default).get_target_data(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.get_bit_width()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -159,7 +115,6 @@ pub struct PrimitivePythonId {
|
|||||||
generic_alias: (u64, u64),
|
generic_alias: (u64, u64),
|
||||||
virtual_id: u64,
|
virtual_id: u64,
|
||||||
option: u64,
|
option: u64,
|
||||||
module: u64,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TopLevelComponent = (Stmt, String, PyObject);
|
type TopLevelComponent = (Stmt, String, PyObject);
|
||||||
@ -171,7 +126,7 @@ struct Nac3 {
|
|||||||
isa: Isa,
|
isa: Isa,
|
||||||
time_fns: &'static (dyn TimeFns + Sync),
|
time_fns: &'static (dyn TimeFns + Sync),
|
||||||
primitive: PrimitiveStore,
|
primitive: PrimitiveStore,
|
||||||
builtins: Vec<BuiltinFuncSpec>,
|
builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>,
|
||||||
pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
|
pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
|
||||||
primitive_ids: PrimitivePythonId,
|
primitive_ids: PrimitivePythonId,
|
||||||
working_directory: TempDir,
|
working_directory: TempDir,
|
||||||
@ -191,32 +146,14 @@ impl Nac3 {
|
|||||||
module: &PyObject,
|
module: &PyObject,
|
||||||
registered_class_ids: &HashSet<u64>,
|
registered_class_ids: &HashSet<u64>,
|
||||||
) -> PyResult<()> {
|
) -> PyResult<()> {
|
||||||
let (module_name, source_file, source) =
|
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
|
||||||
Python::with_gil(|py| -> PyResult<(String, String, String)> {
|
let module: &PyAny = module.extract(py)?;
|
||||||
let module: &PyAny = module.extract(py)?;
|
Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?))
|
||||||
let source_file = module.getattr("__file__");
|
})?;
|
||||||
let (source_file, source) = if let Ok(source_file) = source_file {
|
|
||||||
let source_file = source_file.extract()?;
|
|
||||||
(
|
|
||||||
source_file,
|
|
||||||
fs::read_to_string(source_file).map_err(|e| {
|
|
||||||
exceptions::PyIOError::new_err(format!(
|
|
||||||
"failed to read input file: {e}"
|
|
||||||
))
|
|
||||||
})?,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
// kernels submitted by content have no file
|
|
||||||
// but still can provide source by StringLoader
|
|
||||||
let get_src_fn = module
|
|
||||||
.getattr("__loader__")?
|
|
||||||
.extract::<PyObject>()?
|
|
||||||
.getattr(py, "get_source")?;
|
|
||||||
("<expcontent>", get_src_fn.call1(py, (PyNone::get(py),))?.extract(py)?)
|
|
||||||
};
|
|
||||||
Ok((module.getattr("__name__")?.extract()?, source_file.to_string(), source))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
|
let source = fs::read_to_string(&source_file).map_err(|e| {
|
||||||
|
exceptions::PyIOError::new_err(format!("failed to read input file: {e}"))
|
||||||
|
})?;
|
||||||
let parser_result = parse_program(&source, source_file.into())
|
let parser_result = parse_program(&source, source_file.into())
|
||||||
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
|
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
|
||||||
|
|
||||||
@ -256,8 +193,10 @@ impl Nac3 {
|
|||||||
body.retain(|stmt| {
|
body.retain(|stmt| {
|
||||||
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
|
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
|
||||||
decorator_list.iter().any(|decorator| {
|
decorator_list.iter().any(|decorator| {
|
||||||
if let Some(id) = decorator_id_string(decorator) {
|
if let ExprKind::Name { id, .. } = decorator.node {
|
||||||
id == "kernel" || id == "portable" || id == "rpc"
|
id.to_string() == "kernel"
|
||||||
|
|| id.to_string() == "portable"
|
||||||
|
|| id.to_string() == "rpc"
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
@ -270,17 +209,14 @@ impl Nac3 {
|
|||||||
}
|
}
|
||||||
StmtKind::FunctionDef { ref decorator_list, .. } => {
|
StmtKind::FunctionDef { ref decorator_list, .. } => {
|
||||||
decorator_list.iter().any(|decorator| {
|
decorator_list.iter().any(|decorator| {
|
||||||
if let Some(id) = decorator_id_string(decorator) {
|
if let ExprKind::Name { id, .. } = decorator.node {
|
||||||
id == "extern" || id == "kernel" || id == "portable" || id == "rpc"
|
let id = id.to_string();
|
||||||
|
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// Allow global variable declaration with `Kernel` type annotation
|
|
||||||
StmtKind::AnnAssign { ref annotation, .. } => {
|
|
||||||
matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into()))
|
|
||||||
}
|
|
||||||
_ => false,
|
_ => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -328,7 +264,7 @@ impl Nac3 {
|
|||||||
arg_names.len(),
|
arg_names.len(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
for (i, FuncArg { ty, default_value, name, .. }) in args.iter().enumerate() {
|
for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() {
|
||||||
let in_name = match arg_names.get(i) {
|
let in_name = match arg_names.get(i) {
|
||||||
Some(n) => n,
|
Some(n) => n,
|
||||||
None if default_value.is_none() => {
|
None if default_value.is_none() => {
|
||||||
@ -364,64 +300,6 @@ impl Nac3 {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a [`Vec`] of builtins that needs to be initialized during method compilation time.
|
|
||||||
fn get_lateinit_builtins() -> Vec<Box<BuiltinFuncCreator>> {
|
|
||||||
vec![
|
|
||||||
Box::new(|primitives, unifier| {
|
|
||||||
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
|
|
||||||
|
|
||||||
(
|
|
||||||
"core_log".into(),
|
|
||||||
FunSignature {
|
|
||||||
args: vec![FuncArg {
|
|
||||||
name: "arg".into(),
|
|
||||||
ty: arg_ty.ty,
|
|
||||||
default_value: None,
|
|
||||||
is_vararg: false,
|
|
||||||
}],
|
|
||||||
ret: primitives.none,
|
|
||||||
vars: into_var_map([arg_ty]),
|
|
||||||
},
|
|
||||||
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
|
|
||||||
gen_core_log(ctx, obj.as_ref(), fun, &args, generator)?;
|
|
||||||
|
|
||||||
Ok(None)
|
|
||||||
}))),
|
|
||||||
)
|
|
||||||
}),
|
|
||||||
Box::new(|primitives, unifier| {
|
|
||||||
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
|
|
||||||
|
|
||||||
(
|
|
||||||
"rtio_log".into(),
|
|
||||||
FunSignature {
|
|
||||||
args: vec![
|
|
||||||
FuncArg {
|
|
||||||
name: "channel".into(),
|
|
||||||
ty: primitives.str,
|
|
||||||
default_value: None,
|
|
||||||
is_vararg: false,
|
|
||||||
},
|
|
||||||
FuncArg {
|
|
||||||
name: "arg".into(),
|
|
||||||
ty: arg_ty.ty,
|
|
||||||
default_value: None,
|
|
||||||
is_vararg: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
ret: primitives.none,
|
|
||||||
vars: into_var_map([arg_ty]),
|
|
||||||
},
|
|
||||||
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
|
|
||||||
gen_rtio_log(ctx, obj.as_ref(), fun, &args, generator)?;
|
|
||||||
|
|
||||||
Ok(None)
|
|
||||||
}))),
|
|
||||||
)
|
|
||||||
}),
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compile_method<T>(
|
fn compile_method<T>(
|
||||||
&self,
|
&self,
|
||||||
obj: &PyAny,
|
obj: &PyAny,
|
||||||
@ -431,10 +309,9 @@ impl Nac3 {
|
|||||||
py: Python,
|
py: Python,
|
||||||
link_fn: &dyn Fn(&Module) -> PyResult<T>,
|
link_fn: &dyn Fn(&Module) -> PyResult<T>,
|
||||||
) -> PyResult<T> {
|
) -> PyResult<T> {
|
||||||
let size_t = self.isa.get_size_type(&Context::create());
|
let size_t = self.isa.get_size_type();
|
||||||
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
|
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
|
||||||
self.builtins.clone(),
|
self.builtins.clone(),
|
||||||
Self::get_lateinit_builtins(),
|
|
||||||
ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
|
ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
|
||||||
size_t,
|
size_t,
|
||||||
);
|
);
|
||||||
@ -474,14 +351,12 @@ impl Nac3 {
|
|||||||
];
|
];
|
||||||
add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names);
|
add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names);
|
||||||
|
|
||||||
// Stores a mapping from module id to attributes
|
|
||||||
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
|
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
|
||||||
|
|
||||||
let mut rpc_ids = vec![];
|
let mut rpc_ids = vec![];
|
||||||
for (stmt, path, module) in &self.top_levels {
|
for (stmt, path, module) in &self.top_levels {
|
||||||
let py_module: &PyAny = module.extract(py)?;
|
let py_module: &PyAny = module.extract(py)?;
|
||||||
let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
|
let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
|
||||||
let module_name: String = py_module.getattr("__name__")?.extract()?;
|
|
||||||
let helper = helper.clone();
|
let helper = helper.clone();
|
||||||
let class_obj;
|
let class_obj;
|
||||||
if let StmtKind::ClassDef { name, .. } = &stmt.node {
|
if let StmtKind::ClassDef { name, .. } = &stmt.node {
|
||||||
@ -496,7 +371,7 @@ impl Nac3 {
|
|||||||
} else {
|
} else {
|
||||||
class_obj = None;
|
class_obj = None;
|
||||||
}
|
}
|
||||||
let (name_to_pyid, resolver, _, _) =
|
let (name_to_pyid, resolver) =
|
||||||
module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| {
|
module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| {
|
||||||
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
|
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
|
||||||
let members: &PyDict =
|
let members: &PyDict =
|
||||||
@ -513,6 +388,7 @@ impl Nac3 {
|
|||||||
pyid_to_type: pyid_to_type.clone(),
|
pyid_to_type: pyid_to_type.clone(),
|
||||||
primitive_ids: self.primitive_ids.clone(),
|
primitive_ids: self.primitive_ids.clone(),
|
||||||
global_value_ids: global_value_ids.clone(),
|
global_value_ids: global_value_ids.clone(),
|
||||||
|
class_names: Mutex::default(),
|
||||||
name_to_pyid: name_to_pyid.clone(),
|
name_to_pyid: name_to_pyid.clone(),
|
||||||
module: module.clone(),
|
module: module.clone(),
|
||||||
id_to_pyval: RwLock::default(),
|
id_to_pyval: RwLock::default(),
|
||||||
@ -525,17 +401,9 @@ impl Nac3 {
|
|||||||
})))
|
})))
|
||||||
as Arc<dyn SymbolResolver + Send + Sync>;
|
as Arc<dyn SymbolResolver + Send + Sync>;
|
||||||
let name_to_pyid = Rc::new(name_to_pyid);
|
let name_to_pyid = Rc::new(name_to_pyid);
|
||||||
let module_location = ast::Location::new(1, 1, stmt.location.file);
|
module_to_resolver_cache
|
||||||
module_to_resolver_cache.insert(
|
.insert(module_id, (name_to_pyid.clone(), resolver.clone()));
|
||||||
module_id,
|
(name_to_pyid, resolver)
|
||||||
(
|
|
||||||
name_to_pyid.clone(),
|
|
||||||
resolver.clone(),
|
|
||||||
module_name.clone(),
|
|
||||||
Some(module_location),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
(name_to_pyid, resolver, module_name, Some(module_location))
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let (name, def_id, ty) = composer
|
let (name, def_id, ty) = composer
|
||||||
@ -551,25 +419,9 @@ impl Nac3 {
|
|||||||
|
|
||||||
match &stmt.node {
|
match &stmt.node {
|
||||||
StmtKind::FunctionDef { decorator_list, .. } => {
|
StmtKind::FunctionDef { decorator_list, .. } => {
|
||||||
if decorator_list
|
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
|
||||||
.iter()
|
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap();
|
||||||
.any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string()))
|
rpc_ids.push((None, def_id));
|
||||||
{
|
|
||||||
store_fun
|
|
||||||
.call1(
|
|
||||||
py,
|
|
||||||
(
|
|
||||||
def_id.0.into_py(py),
|
|
||||||
module.getattr(py, name.to_string().as_str()).unwrap(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let is_async = decorator_list.iter().any(|decorator| {
|
|
||||||
decorator_get_flags(decorator)
|
|
||||||
.iter()
|
|
||||||
.any(|constant| *constant == Constant::Str("async".into()))
|
|
||||||
});
|
|
||||||
rpc_ids.push((None, def_id, is_async));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
StmtKind::ClassDef { name, body, .. } => {
|
StmtKind::ClassDef { name, body, .. } => {
|
||||||
@ -577,26 +429,19 @@ impl Nac3 {
|
|||||||
let class_obj = module.getattr(py, class_name.as_str()).unwrap();
|
let class_obj = module.getattr(py, class_name.as_str()).unwrap();
|
||||||
for stmt in body {
|
for stmt in body {
|
||||||
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
|
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
|
||||||
if decorator_list.iter().any(|decorator| {
|
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
|
||||||
decorator_id_string(decorator) == Some("rpc".to_string())
|
|
||||||
}) {
|
|
||||||
let is_async = decorator_list.iter().any(|decorator| {
|
|
||||||
decorator_get_flags(decorator)
|
|
||||||
.iter()
|
|
||||||
.any(|constant| *constant == Constant::Str("async".into()))
|
|
||||||
});
|
|
||||||
if name == &"__init__".into() {
|
if name == &"__init__".into() {
|
||||||
return Err(CompileError::new_err(format!(
|
return Err(CompileError::new_err(format!(
|
||||||
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
|
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
|
||||||
class_name, stmt.location
|
class_name, stmt.location
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async));
|
rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => (),
|
_ => ()
|
||||||
}
|
}
|
||||||
|
|
||||||
let id = *name_to_pyid.get(&name).unwrap();
|
let id = *name_to_pyid.get(&name).unwrap();
|
||||||
@ -609,24 +454,6 @@ impl Nac3 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adding top level module definitions
|
|
||||||
for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in
|
|
||||||
module_to_resolver_cache
|
|
||||||
{
|
|
||||||
let def_id = composer
|
|
||||||
.register_top_level_module(
|
|
||||||
&module_name,
|
|
||||||
&module_name_to_pyid,
|
|
||||||
module_resolver,
|
|
||||||
module_location,
|
|
||||||
)
|
|
||||||
.map_err(|e| {
|
|
||||||
CompileError::new_err(format!("compilation failed\n----------\n{e}"))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
self.pyid_to_def.write().insert(module_id, def_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
let id_fun = PyModule::import(py, "builtins")?.getattr("id")?;
|
let id_fun = PyModule::import(py, "builtins")?.getattr("id")?;
|
||||||
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
|
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
|
||||||
let module = PyModule::new(py, "tmp")?;
|
let module = PyModule::new(py, "tmp")?;
|
||||||
@ -653,12 +480,13 @@ impl Nac3 {
|
|||||||
pyid_to_type: pyid_to_type.clone(),
|
pyid_to_type: pyid_to_type.clone(),
|
||||||
primitive_ids: self.primitive_ids.clone(),
|
primitive_ids: self.primitive_ids.clone(),
|
||||||
global_value_ids: global_value_ids.clone(),
|
global_value_ids: global_value_ids.clone(),
|
||||||
|
class_names: Mutex::default(),
|
||||||
id_to_pyval: RwLock::default(),
|
id_to_pyval: RwLock::default(),
|
||||||
id_to_primitive: RwLock::default(),
|
id_to_primitive: RwLock::default(),
|
||||||
field_to_val: RwLock::default(),
|
field_to_val: RwLock::default(),
|
||||||
name_to_pyid,
|
name_to_pyid,
|
||||||
module: module.to_object(py),
|
module: module.to_object(py),
|
||||||
helper: helper.clone(),
|
helper,
|
||||||
string_store: self.string_store.clone(),
|
string_store: self.string_store.clone(),
|
||||||
exception_ids: self.exception_ids.clone(),
|
exception_ids: self.exception_ids.clone(),
|
||||||
deferred_eval_store: self.deferred_eval_store.clone(),
|
deferred_eval_store: self.deferred_eval_store.clone(),
|
||||||
@ -669,10 +497,6 @@ impl Nac3 {
|
|||||||
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
|
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Process IRRT
|
|
||||||
let context = Context::create();
|
|
||||||
let irrt = load_irrt(&context, resolver.as_ref());
|
|
||||||
|
|
||||||
let fun_signature =
|
let fun_signature =
|
||||||
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
|
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
|
||||||
let mut store = ConcreteTypeStore::new();
|
let mut store = ConcreteTypeStore::new();
|
||||||
@ -710,12 +534,13 @@ impl Nac3 {
|
|||||||
let top_level = Arc::new(composer.make_top_level_context());
|
let top_level = Arc::new(composer.make_top_level_context());
|
||||||
|
|
||||||
{
|
{
|
||||||
|
let rpc_codegen = rpc_codegen_callback();
|
||||||
let defs = top_level.definitions.read();
|
let defs = top_level.definitions.read();
|
||||||
for (class_data, id, is_async) in &rpc_ids {
|
for (class_data, id) in &rpc_ids {
|
||||||
let mut def = defs[id.0].write();
|
let mut def = defs[id.0].write();
|
||||||
match &mut *def {
|
match &mut *def {
|
||||||
TopLevelDef::Function { codegen_callback, .. } => {
|
TopLevelDef::Function { codegen_callback, .. } => {
|
||||||
*codegen_callback = Some(rpc_codegen_callback(*is_async));
|
*codegen_callback = Some(rpc_codegen.clone());
|
||||||
}
|
}
|
||||||
TopLevelDef::Class { methods, .. } => {
|
TopLevelDef::Class { methods, .. } => {
|
||||||
let (class_def, method_name) = class_data.as_ref().unwrap();
|
let (class_def, method_name) = class_data.as_ref().unwrap();
|
||||||
@ -726,7 +551,7 @@ impl Nac3 {
|
|||||||
if let TopLevelDef::Function { codegen_callback, .. } =
|
if let TopLevelDef::Function { codegen_callback, .. } =
|
||||||
&mut *defs[id.0].write()
|
&mut *defs[id.0].write()
|
||||||
{
|
{
|
||||||
*codegen_callback = Some(rpc_codegen_callback(*is_async));
|
*codegen_callback = Some(rpc_codegen.clone());
|
||||||
store_fun
|
store_fun
|
||||||
.call1(
|
.call1(
|
||||||
py,
|
py,
|
||||||
@ -741,14 +566,6 @@ impl Nac3 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TopLevelDef::Variable { .. } => {
|
|
||||||
return Err(CompileError::new_err(String::from(
|
|
||||||
"Unsupported @rpc annotation on global variable",
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
TopLevelDef::Module { .. } => {
|
|
||||||
unreachable!("Type module cannot be decorated with @rpc")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -769,12 +586,33 @@ impl Nac3 {
|
|||||||
let task = CodeGenTask {
|
let task = CodeGenTask {
|
||||||
subst: Vec::default(),
|
subst: Vec::default(),
|
||||||
symbol_name: "__modinit__".to_string(),
|
symbol_name: "__modinit__".to_string(),
|
||||||
|
body: instance.body,
|
||||||
|
signature,
|
||||||
|
resolver: resolver.clone(),
|
||||||
|
store,
|
||||||
|
unifier_index: instance.unifier_id,
|
||||||
|
calls: instance.calls,
|
||||||
|
id: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut store = ConcreteTypeStore::new();
|
||||||
|
let mut cache = HashMap::new();
|
||||||
|
let signature = store.from_signature(
|
||||||
|
&mut composer.unifier,
|
||||||
|
&self.primitive,
|
||||||
|
&fun_signature,
|
||||||
|
&mut cache,
|
||||||
|
);
|
||||||
|
let signature = store.add_cty(signature);
|
||||||
|
let attributes_writeback_task = CodeGenTask {
|
||||||
|
subst: Vec::default(),
|
||||||
|
symbol_name: "attributes_writeback".to_string(),
|
||||||
body: Arc::new(Vec::default()),
|
body: Arc::new(Vec::default()),
|
||||||
signature,
|
signature,
|
||||||
resolver,
|
resolver,
|
||||||
store,
|
store,
|
||||||
unifier_index: instance.unifier_id,
|
unifier_index: instance.unifier_id,
|
||||||
calls: instance.calls,
|
calls: Arc::new(HashMap::default()),
|
||||||
id: 0,
|
id: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -787,47 +625,25 @@ impl Nac3 {
|
|||||||
let buffer = buffer.as_slice().into();
|
let buffer = buffer.as_slice().into();
|
||||||
membuffer.lock().push(buffer);
|
membuffer.lock().push(buffer);
|
||||||
})));
|
})));
|
||||||
|
let size_t = if self.isa == Isa::Host { 64 } else { 32 };
|
||||||
let num_threads = if is_multithreaded() { 4 } else { 1 };
|
let num_threads = if is_multithreaded() { 4 } else { 1 };
|
||||||
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
|
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
|
||||||
let threads: Vec<_> = thread_names
|
let threads: Vec<_> = thread_names
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| {
|
.map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns)))
|
||||||
Box::new(ArtiqCodeGenerator::with_target_machine(
|
|
||||||
s.to_string(),
|
|
||||||
&context,
|
|
||||||
&self.get_llvm_target_machine(),
|
|
||||||
self.time_fns,
|
|
||||||
))
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let membuffer = membuffers.clone();
|
let membuffer = membuffers.clone();
|
||||||
let mut has_return = false;
|
|
||||||
py.allow_threads(|| {
|
py.allow_threads(|| {
|
||||||
let (registry, handles) =
|
let (registry, handles) =
|
||||||
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
|
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
|
||||||
|
registry.add_task(task);
|
||||||
|
registry.wait_tasks_complete(handles);
|
||||||
|
|
||||||
let context = Context::create();
|
let mut generator =
|
||||||
let mut generator = ArtiqCodeGenerator::with_target_machine(
|
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
|
||||||
"main".to_string(),
|
let context = inkwell::context::Context::create();
|
||||||
&context,
|
let module = context.create_module("attributes_writeback");
|
||||||
&self.get_llvm_target_machine(),
|
|
||||||
self.time_fns,
|
|
||||||
);
|
|
||||||
let module = context.create_module("main");
|
|
||||||
let target_machine = self.llvm_options.create_target_machine().unwrap();
|
|
||||||
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
|
|
||||||
module.set_triple(&target_machine.get_triple());
|
|
||||||
module.add_basic_value_flag(
|
|
||||||
"Debug Info Version",
|
|
||||||
FlagBehavior::Warning,
|
|
||||||
context.i32_type().const_int(3, false),
|
|
||||||
);
|
|
||||||
module.add_basic_value_flag(
|
|
||||||
"Dwarf Version",
|
|
||||||
FlagBehavior::Warning,
|
|
||||||
context.i32_type().const_int(4, false),
|
|
||||||
);
|
|
||||||
let builder = context.create_builder();
|
let builder = context.create_builder();
|
||||||
let (_, module, _) = gen_func_impl(
|
let (_, module, _) = gen_func_impl(
|
||||||
&context,
|
&context,
|
||||||
@ -835,27 +651,9 @@ impl Nac3 {
|
|||||||
®istry,
|
®istry,
|
||||||
builder,
|
builder,
|
||||||
module,
|
module,
|
||||||
task,
|
attributes_writeback_task,
|
||||||
|generator, ctx| {
|
|generator, ctx| {
|
||||||
assert_eq!(instance.body.len(), 1, "toplevel module should have 1 statement");
|
attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes)
|
||||||
let StmtKind::Expr { value: ref expr, .. } = instance.body[0].node else {
|
|
||||||
unreachable!("toplevel statement must be an expression")
|
|
||||||
};
|
|
||||||
let ExprKind::Call { .. } = expr.node else {
|
|
||||||
unreachable!("toplevel expression must be a function call")
|
|
||||||
};
|
|
||||||
|
|
||||||
let return_obj =
|
|
||||||
generator.gen_expr(ctx, expr)?.map(|value| (expr.custom.unwrap(), value));
|
|
||||||
has_return = return_obj.is_some();
|
|
||||||
registry.wait_tasks_complete(handles);
|
|
||||||
attributes_writeback(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
inner_resolver.as_ref(),
|
|
||||||
&host_attributes,
|
|
||||||
return_obj,
|
|
||||||
)
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -864,24 +662,37 @@ impl Nac3 {
|
|||||||
membuffer.lock().push(buffer);
|
membuffer.lock().push(buffer);
|
||||||
});
|
});
|
||||||
|
|
||||||
embedding_map.setattr("expects_return", has_return).unwrap();
|
let context = inkwell::context::Context::create();
|
||||||
|
|
||||||
// Link all modules into `main`.
|
|
||||||
let buffers = membuffers.lock();
|
let buffers = membuffers.lock();
|
||||||
let main = context
|
let main = context
|
||||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(
|
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
||||||
buffers.last().unwrap(),
|
|
||||||
"main",
|
|
||||||
))
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
for buffer in buffers.iter().rev().skip(1) {
|
for buffer in buffers.iter().skip(1) {
|
||||||
let other = context
|
let other = context
|
||||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
|
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
|
main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||||
}
|
}
|
||||||
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
|
let builder = context.create_builder();
|
||||||
|
let modinit_return = main
|
||||||
|
.get_function("__modinit__")
|
||||||
|
.unwrap()
|
||||||
|
.get_last_basic_block()
|
||||||
|
.unwrap()
|
||||||
|
.get_terminator()
|
||||||
|
.unwrap();
|
||||||
|
builder.position_before(&modinit_return);
|
||||||
|
builder
|
||||||
|
.build_call(
|
||||||
|
main.get_function("attributes_writeback").unwrap(),
|
||||||
|
&[],
|
||||||
|
"attributes_writeback",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
main.link_in_module(load_irrt(&context))
|
||||||
|
.map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||||
|
|
||||||
let mut function_iter = main.get_first_function();
|
let mut function_iter = main.get_first_function();
|
||||||
while let Some(func) = function_iter {
|
while let Some(func) = function_iter {
|
||||||
@ -915,65 +726,58 @@ impl Nac3 {
|
|||||||
panic!("Failed to run optimization for module `main`: {}", err.to_string());
|
panic!("Failed to run optimization for module `main`: {}", err.to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
Python::with_gil(|py| {
|
|
||||||
let string_store = self.string_store.read();
|
|
||||||
let mut string_store_vec = string_store.iter().collect::<Vec<_>>();
|
|
||||||
string_store_vec.sort_by(|(_s1, key1), (_s2, key2)| key1.cmp(key2));
|
|
||||||
for (s, key) in string_store_vec {
|
|
||||||
let embed_key: i32 = helper.store_str.call1(py, (s,)).unwrap().extract(py).unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
embed_key, *key,
|
|
||||||
"string {s} is out of sync between embedding map (key={embed_key}) and \
|
|
||||||
the internal string store (key={key})"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
link_fn(&main)
|
link_fn(&main)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the [`TargetTriple`] used for compiling to [isa].
|
||||||
|
fn get_llvm_target_triple(isa: Isa) -> TargetTriple {
|
||||||
|
match isa {
|
||||||
|
Isa::Host => TargetMachine::get_default_triple(),
|
||||||
|
Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"),
|
||||||
|
Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the [`String`] representing the target CPU used for compiling to [isa].
|
||||||
|
fn get_llvm_target_cpu(isa: Isa) -> String {
|
||||||
|
match isa {
|
||||||
|
Isa::Host => TargetMachine::get_host_cpu_name().to_string(),
|
||||||
|
Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(),
|
||||||
|
Isa::CortexA9 => "cortex-a9".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the [`String`] representing the target features used for compiling to [isa].
|
||||||
|
fn get_llvm_target_features(isa: Isa) -> String {
|
||||||
|
match isa {
|
||||||
|
Isa::Host => TargetMachine::get_host_cpu_features().to_string(),
|
||||||
|
Isa::RiscV32G => "+a,+m,+f,+d".to_string(),
|
||||||
|
Isa::RiscV32IMA => "+a,+m".to_string(),
|
||||||
|
Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine
|
||||||
|
/// options used for compiling to [isa].
|
||||||
|
fn get_llvm_target_options(isa: Isa) -> CodeGenTargetMachineOptions {
|
||||||
|
CodeGenTargetMachineOptions {
|
||||||
|
triple: Nac3::get_llvm_target_triple(isa).as_str().to_string_lossy().into_owned(),
|
||||||
|
cpu: Nac3::get_llvm_target_cpu(isa),
|
||||||
|
features: Nac3::get_llvm_target_features(isa),
|
||||||
|
reloc_mode: RelocMode::PIC,
|
||||||
|
..CodeGenTargetMachineOptions::from_host()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns an instance of [`TargetMachine`] used in compiling and linking of a program to the
|
/// Returns an instance of [`TargetMachine`] used in compiling and linking of a program to the
|
||||||
/// target [ISA][isa].
|
/// target [isa].
|
||||||
fn get_llvm_target_machine(&self) -> TargetMachine {
|
fn get_llvm_target_machine(&self) -> TargetMachine {
|
||||||
self.isa.create_llvm_target_machine(self.llvm_options.opt_level)
|
Nac3::get_llvm_target_options(self.isa)
|
||||||
|
.create_target_machine(self.llvm_options.opt_level)
|
||||||
|
.expect("couldn't create target machine")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retrieves the Name.id from a decorator, supports decorators with arguments.
|
|
||||||
fn decorator_id_string(decorator: &Located<ExprKind>) -> Option<String> {
|
|
||||||
if let ExprKind::Name { id, .. } = decorator.node {
|
|
||||||
// Bare decorator
|
|
||||||
return Some(id.to_string());
|
|
||||||
} else if let ExprKind::Call { func, .. } = &decorator.node {
|
|
||||||
// Decorators that are calls (e.g. "@rpc()") have Call for the node,
|
|
||||||
// need to extract the id from within.
|
|
||||||
if let ExprKind::Name { id, .. } = func.node {
|
|
||||||
return Some(id.to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Retrieves flags from a decorator, if any.
|
|
||||||
fn decorator_get_flags(decorator: &Located<ExprKind>) -> Vec<Constant> {
|
|
||||||
let mut flags = vec![];
|
|
||||||
if let ExprKind::Call { keywords, .. } = &decorator.node {
|
|
||||||
for keyword in keywords {
|
|
||||||
if keyword.node.arg != Some("flags".into()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if let ExprKind::Set { elts } = &keyword.node.value.node {
|
|
||||||
for elt in elts {
|
|
||||||
if let ExprKind::Constant { value, .. } = &elt.node {
|
|
||||||
flags.push(value.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
flags
|
|
||||||
}
|
|
||||||
|
|
||||||
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
|
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
|
||||||
let linker_args = vec![
|
let linker_args = vec![
|
||||||
"-shared".to_string(),
|
"-shared".to_string(),
|
||||||
@ -1043,8 +847,7 @@ impl Nac3 {
|
|||||||
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
|
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
|
||||||
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
|
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
|
||||||
};
|
};
|
||||||
let (primitive, _) =
|
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(isa.get_size_type()).0;
|
||||||
TopLevelComposer::make_primitives(isa.get_size_type(&Context::create()));
|
|
||||||
let builtins = vec![
|
let builtins = vec![
|
||||||
(
|
(
|
||||||
"now_mu".into(),
|
"now_mu".into(),
|
||||||
@ -1060,7 +863,6 @@ impl Nac3 {
|
|||||||
name: "t".into(),
|
name: "t".into(),
|
||||||
ty: primitive.int64,
|
ty: primitive.int64,
|
||||||
default_value: None,
|
default_value: None,
|
||||||
is_vararg: false,
|
|
||||||
}],
|
}],
|
||||||
ret: primitive.none,
|
ret: primitive.none,
|
||||||
vars: VarMap::new(),
|
vars: VarMap::new(),
|
||||||
@ -1080,7 +882,6 @@ impl Nac3 {
|
|||||||
name: "dt".into(),
|
name: "dt".into(),
|
||||||
ty: primitive.int64,
|
ty: primitive.int64,
|
||||||
default_value: None,
|
default_value: None,
|
||||||
is_vararg: false,
|
|
||||||
}],
|
}],
|
||||||
ret: primitive.none,
|
ret: primitive.none,
|
||||||
vars: VarMap::new(),
|
vars: VarMap::new(),
|
||||||
@ -1132,54 +933,11 @@ impl Nac3 {
|
|||||||
tuple: get_attr_id(builtins_mod, "tuple"),
|
tuple: get_attr_id(builtins_mod, "tuple"),
|
||||||
exception: get_attr_id(builtins_mod, "Exception"),
|
exception: get_attr_id(builtins_mod, "Exception"),
|
||||||
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
|
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
|
||||||
module: get_attr_id(types_mod, "ModuleType"),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
|
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
|
||||||
fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap();
|
fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap();
|
||||||
|
|
||||||
let mut string_store: HashMap<String, i32> = HashMap::default();
|
|
||||||
|
|
||||||
// Keep this list of exceptions in sync with `EXCEPTION_ID_LOOKUP` in `artiq::firmware::ksupport::eh_artiq`
|
|
||||||
// The exceptions declared here must be defined in `artiq.coredevice.exceptions`
|
|
||||||
// Verify synchronization by running the test cases in `artiq.test.coredevice.test_exceptions`
|
|
||||||
let runtime_exception_names = [
|
|
||||||
"RTIOUnderflow",
|
|
||||||
"RTIOOverflow",
|
|
||||||
"RTIODestinationUnreachable",
|
|
||||||
"DMAError",
|
|
||||||
"I2CError",
|
|
||||||
"CacheError",
|
|
||||||
"SPIError",
|
|
||||||
"SubkernelError",
|
|
||||||
"0:AssertionError",
|
|
||||||
"0:AttributeError",
|
|
||||||
"0:IndexError",
|
|
||||||
"0:IOError",
|
|
||||||
"0:KeyError",
|
|
||||||
"0:NotImplementedError",
|
|
||||||
"0:OverflowError",
|
|
||||||
"0:RuntimeError",
|
|
||||||
"0:TimeoutError",
|
|
||||||
"0:TypeError",
|
|
||||||
"0:ValueError",
|
|
||||||
"0:ZeroDivisionError",
|
|
||||||
"0:LinAlgError",
|
|
||||||
"UnwrapNoneError",
|
|
||||||
];
|
|
||||||
|
|
||||||
// Preallocate runtime exception names
|
|
||||||
for (i, name) in runtime_exception_names.iter().enumerate() {
|
|
||||||
let exn_name = if name.find(':').is_none() {
|
|
||||||
format!("0:artiq.coredevice.exceptions.{name}")
|
|
||||||
} else {
|
|
||||||
(*name).to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
let id = i32::try_from(i).unwrap();
|
|
||||||
string_store.insert(exn_name, id);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Nac3 {
|
Ok(Nac3 {
|
||||||
isa,
|
isa,
|
||||||
time_fns,
|
time_fns,
|
||||||
@ -1189,22 +947,17 @@ impl Nac3 {
|
|||||||
top_levels: Vec::default(),
|
top_levels: Vec::default(),
|
||||||
pyid_to_def: Arc::default(),
|
pyid_to_def: Arc::default(),
|
||||||
working_directory,
|
working_directory,
|
||||||
string_store: Arc::new(string_store.into()),
|
string_store: Arc::default(),
|
||||||
exception_ids: Arc::default(),
|
exception_ids: Arc::default(),
|
||||||
deferred_eval_store: DeferredEvaluationStore::new(),
|
deferred_eval_store: DeferredEvaluationStore::new(),
|
||||||
llvm_options: CodeGenLLVMOptions {
|
llvm_options: CodeGenLLVMOptions {
|
||||||
opt_level: OptimizationLevel::Default,
|
opt_level: OptimizationLevel::Default,
|
||||||
target: isa.get_llvm_target_options(),
|
target: Nac3::get_llvm_target_options(isa),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn analyze(
|
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
|
||||||
&mut self,
|
|
||||||
functions: &PySet,
|
|
||||||
classes: &PySet,
|
|
||||||
content_modules: &PySet,
|
|
||||||
) -> PyResult<()> {
|
|
||||||
let (modules, class_ids) =
|
let (modules, class_ids) =
|
||||||
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
|
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
|
||||||
let mut modules: HashMap<u64, PyObject> = HashMap::new();
|
let mut modules: HashMap<u64, PyObject> = HashMap::new();
|
||||||
@ -1214,21 +967,13 @@ impl Nac3 {
|
|||||||
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
|
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
|
||||||
|
|
||||||
for function in functions {
|
for function in functions {
|
||||||
let module: PyObject = getmodule_fn.call1((function,))?.extract()?;
|
let module = getmodule_fn.call1((function,))?.extract()?;
|
||||||
if !module.is_none(py) {
|
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for class in classes {
|
for class in classes {
|
||||||
let module: PyObject = getmodule_fn.call1((class,))?.extract()?;
|
let module = getmodule_fn.call1((class,))?.extract()?;
|
||||||
if !module.is_none(py) {
|
|
||||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
|
||||||
}
|
|
||||||
class_ids.insert(id_fn.call1((class,))?.extract()?);
|
|
||||||
}
|
|
||||||
for module in content_modules {
|
|
||||||
let module: PyObject = module.extract()?;
|
|
||||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||||
|
class_ids.insert(id_fn.call1((class,))?.extract()?);
|
||||||
}
|
}
|
||||||
Ok((modules, class_ids))
|
Ok((modules, class_ids))
|
||||||
})?;
|
})?;
|
||||||
|
@ -1,32 +1,14 @@
|
|||||||
use std::{
|
use inkwell::{
|
||||||
collections::{HashMap, HashSet},
|
types::{BasicType, BasicTypeEnum},
|
||||||
sync::{
|
values::BasicValueEnum,
|
||||||
atomic::{AtomicBool, Ordering::Relaxed},
|
AddressSpace,
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use parking_lot::RwLock;
|
|
||||||
use pyo3::{
|
|
||||||
types::{PyDict, PyTuple},
|
|
||||||
PyAny, PyErr, PyObject, PyResult, Python,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::PrimitivePythonId;
|
|
||||||
use nac3core::{
|
use nac3core::{
|
||||||
codegen::{
|
codegen::{
|
||||||
types::{ndarray::NDArrayType, ProxyType},
|
classes::{NDArrayType, ProxyType},
|
||||||
values::ndarray::make_contiguous_strides,
|
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
inkwell::{
|
|
||||||
module::Linkage,
|
|
||||||
types::{BasicType, BasicTypeEnum},
|
|
||||||
values::{BasicValue, BasicValueEnum},
|
|
||||||
AddressSpace,
|
|
||||||
},
|
|
||||||
nac3parser::ast::{self, StrRef},
|
|
||||||
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PrimDef,
|
helper::PrimDef,
|
||||||
@ -38,6 +20,21 @@ use nac3core::{
|
|||||||
typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap},
|
typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
use nac3parser::ast::{self, StrRef};
|
||||||
|
use parking_lot::{Mutex, RwLock};
|
||||||
|
use pyo3::{
|
||||||
|
types::{PyDict, PyTuple},
|
||||||
|
PyAny, PyObject, PyResult, Python,
|
||||||
|
};
|
||||||
|
use std::{
|
||||||
|
collections::{HashMap, HashSet},
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering::Relaxed},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::PrimitivePythonId;
|
||||||
|
|
||||||
pub enum PrimitiveValue {
|
pub enum PrimitiveValue {
|
||||||
I32(i32),
|
I32(i32),
|
||||||
@ -82,6 +79,7 @@ pub struct InnerResolver {
|
|||||||
pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>,
|
pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>,
|
||||||
pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>,
|
pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>,
|
||||||
pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>,
|
pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>,
|
||||||
|
pub class_names: Mutex<HashMap<StrRef, Type>>,
|
||||||
pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
|
pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
|
||||||
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
|
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
|
||||||
pub primitive_ids: PrimitivePythonId,
|
pub primitive_ids: PrimitivePythonId,
|
||||||
@ -135,8 +133,6 @@ impl StaticValue for PythonValue {
|
|||||||
format!("{}_const", self.id).as_str(),
|
format!("{}_const", self.id).as_str(),
|
||||||
);
|
);
|
||||||
global.set_constant(true);
|
global.set_constant(true);
|
||||||
// Set linkage of global to private to avoid name collisions
|
|
||||||
global.set_linkage(Linkage::Private);
|
|
||||||
global.set_initializer(&ctx.ctx.const_struct(
|
global.set_initializer(&ctx.ctx.const_struct(
|
||||||
&[ctx.ctx.i32_type().const_int(u64::from(id), false).into()],
|
&[ctx.ctx.i32_type().const_int(u64::from(id), false).into()],
|
||||||
false,
|
false,
|
||||||
@ -167,7 +163,7 @@ impl StaticValue for PythonValue {
|
|||||||
PrimitiveValue::Bool(val) => {
|
PrimitiveValue::Bool(val) => {
|
||||||
ctx.ctx.i8_type().const_int(u64::from(*val), false).into()
|
ctx.ctx.i8_type().const_int(u64::from(*val), false).into()
|
||||||
}
|
}
|
||||||
PrimitiveValue::Str(val) => ctx.gen_string(generator, val).into(),
|
PrimitiveValue::Str(val) => ctx.ctx.const_string(val.as_bytes(), true).into(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
if let Some(global) = ctx.module.get_global(&self.id.to_string()) {
|
if let Some(global) = ctx.module.get_global(&self.id.to_string()) {
|
||||||
@ -355,7 +351,7 @@ impl InnerResolver {
|
|||||||
Ok(Ok((ndarray, false)))
|
Ok(Ok((ndarray, false)))
|
||||||
} else if ty_id == self.primitive_ids.tuple {
|
} else if ty_id == self.primitive_ids.tuple {
|
||||||
// do not handle type var param and concrete check here
|
// do not handle type var param and concrete check here
|
||||||
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false }), false)))
|
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
|
||||||
} else if ty_id == self.primitive_ids.option {
|
} else if ty_id == self.primitive_ids.option {
|
||||||
Ok(Ok((primitives.option, false)))
|
Ok(Ok((primitives.option, false)))
|
||||||
} else if ty_id == self.primitive_ids.none {
|
} else if ty_id == self.primitive_ids.none {
|
||||||
@ -559,10 +555,7 @@ impl InnerResolver {
|
|||||||
Err(err) => return Ok(Err(err)),
|
Err(err) => return Ok(Err(err)),
|
||||||
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
|
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
|
||||||
};
|
};
|
||||||
Ok(Ok((
|
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true)))
|
||||||
unifier.add_ty(TypeEnum::TTuple { ty: args, is_vararg_ctx: false }),
|
|
||||||
true,
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { params, obj_id, .. } => {
|
TypeEnum::TObj { params, obj_id, .. } => {
|
||||||
let subst = {
|
let subst = {
|
||||||
@ -674,48 +667,6 @@ impl InnerResolver {
|
|||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
// check if obj is module
|
|
||||||
if self.helper.id_fn.call1(py, (ty.clone(),))?.extract::<u64>(py)?
|
|
||||||
== self.primitive_ids.module
|
|
||||||
&& self.pyid_to_def.read().contains_key(&py_obj_id)
|
|
||||||
{
|
|
||||||
let def_id = self.pyid_to_def.read()[&py_obj_id];
|
|
||||||
let def = defs[def_id.0].read();
|
|
||||||
let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } =
|
|
||||||
&*def
|
|
||||||
else {
|
|
||||||
unreachable!("must be a module here");
|
|
||||||
};
|
|
||||||
// Construct the module return type
|
|
||||||
let mut module_attributes = HashMap::new();
|
|
||||||
for (name, _) in attributes {
|
|
||||||
let attribute_obj = obj.getattr(name.to_string().as_str())?;
|
|
||||||
let attribute_ty =
|
|
||||||
self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?;
|
|
||||||
if let Ok(attribute_ty) = attribute_ty {
|
|
||||||
module_attributes.insert(*name, (attribute_ty, false));
|
|
||||||
} else {
|
|
||||||
return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for name in methods.keys() {
|
|
||||||
let method_obj = obj.getattr(name.to_string().as_str())?;
|
|
||||||
let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?;
|
|
||||||
if let Ok(method_ty) = method_ty {
|
|
||||||
module_attributes.insert(*name, (method_ty, true));
|
|
||||||
} else {
|
|
||||||
return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let module_ty =
|
|
||||||
TypeEnum::TModule { module_id: *module_id, attributes: module_attributes };
|
|
||||||
|
|
||||||
let ty = unifier.add_ty(module_ty);
|
|
||||||
return Ok(Ok(ty));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ty) = constructor_ty {
|
if let Some(ty) = constructor_ty {
|
||||||
self.pyid_to_type.write().insert(py_obj_id, ty);
|
self.pyid_to_type.write().insert(py_obj_id, ty);
|
||||||
return Ok(Ok(ty));
|
return Ok(Ok(ty));
|
||||||
@ -846,9 +797,7 @@ impl InnerResolver {
|
|||||||
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))
|
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))
|
||||||
.collect();
|
.collect();
|
||||||
let types = types?;
|
let types = types?;
|
||||||
Ok(types.map(|types| {
|
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
|
||||||
unifier.add_ty(TypeEnum::TTuple { ty: types, is_vararg_ctx: false })
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
// special handling for option type since its class member layout in python side
|
// special handling for option type since its class member layout in python side
|
||||||
// is special and cannot be mapped directly to a nac3 type as below
|
// is special and cannot be mapped directly to a nac3 type as below
|
||||||
@ -973,13 +922,10 @@ impl InnerResolver {
|
|||||||
|_| Ok(Ok(extracted_ty)),
|
|_| Ok(Ok(extracted_ty)),
|
||||||
)
|
)
|
||||||
} else if unifier.unioned(extracted_ty, primitives.bool) {
|
} else if unifier.unioned(extracted_ty, primitives.bool) {
|
||||||
if obj.extract::<bool>().is_ok()
|
obj.extract::<bool>().map_or_else(
|
||||||
|| obj.call_method("__bool__", (), None)?.extract::<bool>().is_ok()
|
|_| Ok(Err(format!("{obj} is not in the range of bool"))),
|
||||||
{
|
|_| Ok(Ok(extracted_ty)),
|
||||||
Ok(Ok(extracted_ty))
|
)
|
||||||
} else {
|
|
||||||
Ok(Err(format!("{obj} is not in the range of bool")))
|
|
||||||
}
|
|
||||||
} else if unifier.unioned(extracted_ty, primitives.float) {
|
} else if unifier.unioned(extracted_ty, primitives.float) {
|
||||||
obj.extract::<f64>().map_or_else(
|
obj.extract::<f64>().map_or_else(
|
||||||
|_| Ok(Err(format!("{obj} is not in the range of float64"))),
|
|_| Ok(Err(format!("{obj} is not in the range of float64"))),
|
||||||
@ -1019,18 +965,14 @@ impl InnerResolver {
|
|||||||
let val: u64 = obj.extract().unwrap();
|
let val: u64 = obj.extract().unwrap();
|
||||||
self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val));
|
self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val));
|
||||||
Ok(Some(ctx.ctx.i64_type().const_int(val, false).into()))
|
Ok(Some(ctx.ctx.i64_type().const_int(val, false).into()))
|
||||||
} else if ty_id == self.primitive_ids.bool {
|
} else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ {
|
||||||
let val: bool = obj.extract().unwrap();
|
let val: bool = obj.extract().unwrap();
|
||||||
self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val));
|
self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val));
|
||||||
Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into()))
|
Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into()))
|
||||||
} else if ty_id == self.primitive_ids.np_bool_ {
|
|
||||||
let val: bool = obj.call_method("__bool__", (), None)?.extract().unwrap();
|
|
||||||
self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val));
|
|
||||||
Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into()))
|
|
||||||
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
|
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
|
||||||
let val: String = obj.extract().unwrap();
|
let val: String = obj.extract().unwrap();
|
||||||
self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone()));
|
self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone()));
|
||||||
Ok(Some(ctx.gen_string(generator, val).into()))
|
Ok(Some(ctx.ctx.const_string(val.as_bytes(), true).into()))
|
||||||
} else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 {
|
} else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 {
|
||||||
let val: f64 = obj.extract().unwrap();
|
let val: f64 = obj.extract().unwrap();
|
||||||
self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val));
|
self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val));
|
||||||
@ -1049,15 +991,8 @@ impl InnerResolver {
|
|||||||
}
|
}
|
||||||
_ => unreachable!("must be list"),
|
_ => unreachable!("must be list"),
|
||||||
};
|
};
|
||||||
let size_t = ctx.get_size_type();
|
let ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let ty = if len == 0
|
let size_t = generator.get_size_type(ctx.ctx);
|
||||||
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
|
|
||||||
{
|
|
||||||
// The default type for zero-length lists of unknown element type is size_t
|
|
||||||
size_t.into()
|
|
||||||
} else {
|
|
||||||
ctx.get_llvm_type(generator, elem_ty)
|
|
||||||
};
|
|
||||||
let arr_ty = ctx
|
let arr_ty = ctx
|
||||||
.ctx
|
.ctx
|
||||||
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);
|
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);
|
||||||
@ -1134,19 +1069,18 @@ impl InnerResolver {
|
|||||||
} else {
|
} else {
|
||||||
unreachable!("must be ndarray")
|
unreachable!("must be ndarray")
|
||||||
};
|
};
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
let (ndarray_dtype, ndarray_ndims) =
|
||||||
|
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
||||||
|
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||||
let llvm_usize = ctx.get_size_type();
|
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
|
||||||
let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty);
|
|
||||||
let dtype = llvm_ndarray.element_type();
|
|
||||||
|
|
||||||
{
|
{
|
||||||
if self.global_value_ids.read().contains_key(&id) {
|
if self.global_value_ids.read().contains_key(&id) {
|
||||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||||
ctx.module.add_global(
|
ctx.module.add_global(
|
||||||
llvm_ndarray.as_base_type().get_element_type().into_struct_type(),
|
ndarray_llvm_ty.as_underlying_type(),
|
||||||
Some(AddressSpace::default()),
|
Some(AddressSpace::default()),
|
||||||
&id_str,
|
&id_str,
|
||||||
)
|
)
|
||||||
@ -1156,44 +1090,40 @@ impl InnerResolver {
|
|||||||
self.global_value_ids.write().insert(id, obj.into());
|
self.global_value_ids.write().insert(id, obj.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndims = llvm_ndarray.ndims();
|
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
||||||
|
else {
|
||||||
|
unreachable!("Expected Literal for ndarray_ndims")
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndarray_ndims = if values.len() == 1 {
|
||||||
|
values[0].clone()
|
||||||
|
} else {
|
||||||
|
todo!("Unpacking literal of more than one element unimplemented")
|
||||||
|
};
|
||||||
|
let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else {
|
||||||
|
unreachable!("Expected u64 value for ndarray_ndims")
|
||||||
|
};
|
||||||
|
|
||||||
// Obtain the shape of the ndarray
|
// Obtain the shape of the ndarray
|
||||||
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
|
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
|
||||||
assert_eq!(shape_tuple.len(), ndims as usize);
|
assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
|
||||||
|
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
|
||||||
// The Rust type inferencer cannot figure this out
|
|
||||||
let shape_values = shape_tuple
|
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, elem)| {
|
.map(|(i, elem)| {
|
||||||
let value = self
|
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err(
|
||||||
.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize())
|
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")),
|
||||||
.map_err(|e| {
|
)
|
||||||
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
|
|
||||||
})?
|
|
||||||
.unwrap();
|
|
||||||
let value = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_z_extend(value.into_int_value(), llvm_usize, "")
|
|
||||||
.unwrap();
|
|
||||||
Ok(value)
|
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
.collect();
|
||||||
|
let shape_values = shape_values?.unwrap();
|
||||||
// Also use this opportunity to get the constant values of `shape_values` for calculating strides.
|
let shape_values = llvm_usize.const_array(
|
||||||
let shape_u64s = shape_values
|
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
|
||||||
.iter()
|
);
|
||||||
.map(|dim| {
|
|
||||||
assert!(dim.is_const());
|
|
||||||
dim.get_zero_extended_constant().unwrap()
|
|
||||||
})
|
|
||||||
.collect_vec();
|
|
||||||
let shape_values = llvm_usize.const_array(&shape_values);
|
|
||||||
|
|
||||||
// create a global for ndarray.shape and initialize it using the shape
|
// create a global for ndarray.shape and initialize it using the shape
|
||||||
let shape_global = ctx.module.add_global(
|
let shape_global = ctx.module.add_global(
|
||||||
llvm_usize.array_type(ndims as u32),
|
llvm_usize.array_type(ndarray_ndims as u32),
|
||||||
Some(AddressSpace::default()),
|
Some(AddressSpace::default()),
|
||||||
&(id_str.clone() + ".shape"),
|
&(id_str.clone() + ".shape"),
|
||||||
);
|
);
|
||||||
@ -1201,25 +1131,17 @@ impl InnerResolver {
|
|||||||
|
|
||||||
// Obtain the (flattened) elements of the ndarray
|
// Obtain the (flattened) elements of the ndarray
|
||||||
let sz: usize = obj.getattr("size")?.extract()?;
|
let sz: usize = obj.getattr("size")?.extract()?;
|
||||||
let data: Vec<_> = (0..sz)
|
let data: Result<Option<Vec<_>>, _> = (0..sz)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
obj.getattr("flat")?.get_item(i).and_then(|elem| {
|
obj.getattr("flat")?.get_item(i).and_then(|elem| {
|
||||||
let value = self
|
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| {
|
||||||
.get_obj_value(py, elem, ctx, generator, ndarray_dtype)
|
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
|
||||||
.map_err(|e| {
|
})
|
||||||
super::CompileError::new_err(format!(
|
|
||||||
"Error getting element {i}: {e}"
|
|
||||||
))
|
|
||||||
})?
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(value.get_type(), dtype);
|
|
||||||
Ok(value)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.try_collect()?;
|
.collect();
|
||||||
let data = data.into_iter();
|
let data = data?.unwrap().into_iter();
|
||||||
let data = match dtype {
|
let data = match ndarray_dtype_llvm_ty {
|
||||||
BasicTypeEnum::ArrayType(ty) => {
|
BasicTypeEnum::ArrayType(ty) => {
|
||||||
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
|
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
|
||||||
}
|
}
|
||||||
@ -1244,102 +1166,37 @@ impl InnerResolver {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// create a global for ndarray.data and initialize it using the elements
|
// create a global for ndarray.data and initialize it using the elements
|
||||||
//
|
|
||||||
// NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`.
|
|
||||||
// We will have to cast it to an `u8*` later.
|
|
||||||
let data_global = ctx.module.add_global(
|
let data_global = ctx.module.add_global(
|
||||||
dtype.array_type(sz as u32),
|
ndarray_dtype_llvm_ty.array_type(sz as u32),
|
||||||
Some(AddressSpace::default()),
|
Some(AddressSpace::default()),
|
||||||
&(id_str.clone() + ".data"),
|
&(id_str.clone() + ".data"),
|
||||||
);
|
);
|
||||||
data_global.set_initializer(&data);
|
data_global.set_initializer(&data);
|
||||||
|
|
||||||
// Get the constant itemsize.
|
|
||||||
//
|
|
||||||
// NOTE: dtype.size_of() may return a non-constant, where `TargetData::get_store_size`
|
|
||||||
// will always return a constant size.
|
|
||||||
let itemsize = ctx
|
|
||||||
.registry
|
|
||||||
.llvm_options
|
|
||||||
.create_target_machine()
|
|
||||||
.map(|tm| tm.get_target_data().get_store_size(&dtype))
|
|
||||||
.unwrap();
|
|
||||||
assert_ne!(itemsize, 0);
|
|
||||||
|
|
||||||
// Create the strides needed for ndarray.strides
|
|
||||||
let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s);
|
|
||||||
let strides =
|
|
||||||
strides.into_iter().map(|stride| llvm_usize.const_int(stride, false)).collect_vec();
|
|
||||||
let strides = llvm_usize.const_array(&strides);
|
|
||||||
|
|
||||||
// create a global for ndarray.strides and initialize it
|
|
||||||
let strides_global = ctx.module.add_global(
|
|
||||||
llvm_usize.array_type(ndims as u32),
|
|
||||||
Some(AddressSpace::default()),
|
|
||||||
&format!("${id_str}.strides"),
|
|
||||||
);
|
|
||||||
strides_global.set_initializer(&strides);
|
|
||||||
|
|
||||||
// create a global for the ndarray object and initialize it
|
// create a global for the ndarray object and initialize it
|
||||||
|
let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[
|
||||||
|
llvm_usize.const_int(ndarray_ndims, false).into(),
|
||||||
|
shape_global
|
||||||
|
.as_pointer_value()
|
||||||
|
.const_cast(llvm_usize.ptr_type(AddressSpace::default()))
|
||||||
|
.into(),
|
||||||
|
data_global
|
||||||
|
.as_pointer_value()
|
||||||
|
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
|
||||||
|
.into(),
|
||||||
|
]);
|
||||||
|
|
||||||
// NOTE: data_global is an array of dtype, we want a `u8*`.
|
let ndarray = ctx.module.add_global(
|
||||||
let ndarray_data = data_global.as_pointer_value();
|
ndarray_llvm_ty.as_underlying_type(),
|
||||||
let ndarray_data = ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
|
|
||||||
|
|
||||||
let ndarray_itemsize = llvm_usize.const_int(itemsize, false);
|
|
||||||
|
|
||||||
let ndarray_ndims = llvm_usize.const_int(ndims, false);
|
|
||||||
|
|
||||||
// calling as_pointer_value on shape and strides returns [i64 x ndims]*
|
|
||||||
// convert into i64* to conform with expected layout of ndarray
|
|
||||||
|
|
||||||
let ndarray_shape = shape_global.as_pointer_value();
|
|
||||||
let ndarray_shape = unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
ndarray_shape,
|
|
||||||
&[llvm_usize.const_zero(), llvm_usize.const_zero()],
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
let ndarray_strides = strides_global.as_pointer_value();
|
|
||||||
let ndarray_strides = unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
ndarray_strides,
|
|
||||||
&[llvm_usize.const_zero(), llvm_usize.const_zero()],
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
let ndarray = llvm_ndarray
|
|
||||||
.as_base_type()
|
|
||||||
.get_element_type()
|
|
||||||
.into_struct_type()
|
|
||||||
.const_named_struct(&[
|
|
||||||
ndarray_itemsize.into(),
|
|
||||||
ndarray_ndims.into(),
|
|
||||||
ndarray_shape.into(),
|
|
||||||
ndarray_strides.into(),
|
|
||||||
ndarray_data.into(),
|
|
||||||
]);
|
|
||||||
|
|
||||||
let ndarray_global = ctx.module.add_global(
|
|
||||||
llvm_ndarray.as_base_type().get_element_type().into_struct_type(),
|
|
||||||
Some(AddressSpace::default()),
|
Some(AddressSpace::default()),
|
||||||
&id_str,
|
&id_str,
|
||||||
);
|
);
|
||||||
ndarray_global.set_initializer(&ndarray);
|
ndarray.set_initializer(&value);
|
||||||
|
|
||||||
Ok(Some(ndarray_global.as_pointer_value().into()))
|
Ok(Some(ndarray.as_pointer_value().into()))
|
||||||
} else if ty_id == self.primitive_ids.tuple {
|
} else if ty_id == self.primitive_ids.tuple {
|
||||||
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
||||||
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {
|
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
|
|
||||||
let tup_tys = ty.iter();
|
let tup_tys = ty.iter();
|
||||||
let elements: &PyTuple = obj.downcast()?;
|
let elements: &PyTuple = obj.downcast()?;
|
||||||
@ -1415,77 +1272,6 @@ impl InnerResolver {
|
|||||||
None => Ok(None),
|
None => Ok(None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if ty_id == self.primitive_ids.module {
|
|
||||||
let id_str = id.to_string();
|
|
||||||
|
|
||||||
if let Some(global) = ctx.module.get_global(&id_str) {
|
|
||||||
return Ok(Some(global.as_pointer_value().into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let top_level_defs = ctx.top_level.definitions.read();
|
|
||||||
let ty = self
|
|
||||||
.get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)?
|
|
||||||
.unwrap();
|
|
||||||
let ty = ctx
|
|
||||||
.get_llvm_type(generator, ty)
|
|
||||||
.into_pointer_type()
|
|
||||||
.get_element_type()
|
|
||||||
.into_struct_type();
|
|
||||||
|
|
||||||
{
|
|
||||||
if self.global_value_ids.read().contains_key(&id) {
|
|
||||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
|
||||||
ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
|
|
||||||
});
|
|
||||||
return Ok(Some(global.as_pointer_value().into()));
|
|
||||||
}
|
|
||||||
self.global_value_ids.write().insert(id, obj.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let fields = {
|
|
||||||
let definition =
|
|
||||||
top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read();
|
|
||||||
let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() };
|
|
||||||
attributes
|
|
||||||
.iter()
|
|
||||||
.filter_map(|f| {
|
|
||||||
let definition = top_level_defs.get(f.1 .0).unwrap().read();
|
|
||||||
if let TopLevelDef::Variable { ty, .. } = &*definition {
|
|
||||||
Some((f.0, *ty))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect_vec()
|
|
||||||
};
|
|
||||||
|
|
||||||
let values: Result<Option<Vec<_>>, _> = fields
|
|
||||||
.iter()
|
|
||||||
.map(|(name, ty)| {
|
|
||||||
self.get_obj_value(
|
|
||||||
py,
|
|
||||||
obj.getattr(name.to_string().as_str())?,
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
*ty,
|
|
||||||
)
|
|
||||||
.map_err(|e| {
|
|
||||||
super::CompileError::new_err(format!("Error getting field {name}: {e}"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let values = values?;
|
|
||||||
|
|
||||||
if let Some(values) = values {
|
|
||||||
let val = ty.const_named_struct(&values);
|
|
||||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
|
||||||
ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
|
|
||||||
});
|
|
||||||
global.set_initializer(&val);
|
|
||||||
Ok(Some(global.as_pointer_value().into()))
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
let id_str = id.to_string();
|
let id_str = id.to_string();
|
||||||
|
|
||||||
@ -1565,12 +1351,9 @@ impl InnerResolver {
|
|||||||
} else if ty_id == self.primitive_ids.uint64 {
|
} else if ty_id == self.primitive_ids.uint64 {
|
||||||
let val: u64 = obj.extract()?;
|
let val: u64 = obj.extract()?;
|
||||||
Ok(SymbolValue::U64(val))
|
Ok(SymbolValue::U64(val))
|
||||||
} else if ty_id == self.primitive_ids.bool {
|
} else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ {
|
||||||
let val: bool = obj.extract()?;
|
let val: bool = obj.extract()?;
|
||||||
Ok(SymbolValue::Bool(val))
|
Ok(SymbolValue::Bool(val))
|
||||||
} else if ty_id == self.primitive_ids.np_bool_ {
|
|
||||||
let val: bool = obj.call_method("__bool__", (), None)?.extract()?;
|
|
||||||
Ok(SymbolValue::Bool(val))
|
|
||||||
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
|
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
|
||||||
let val: String = obj.extract()?;
|
let val: String = obj.extract()?;
|
||||||
Ok(SymbolValue::Str(val))
|
Ok(SymbolValue::Str(val))
|
||||||
@ -1668,50 +1451,8 @@ impl SymbolResolver for Resolver {
|
|||||||
fn get_symbol_value<'ctx>(
|
fn get_symbol_value<'ctx>(
|
||||||
&self,
|
&self,
|
||||||
id: StrRef,
|
id: StrRef,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
_: &mut CodeGenContext<'ctx, '_>,
|
||||||
generator: &mut dyn CodeGenerator,
|
|
||||||
) -> Option<ValueEnum<'ctx>> {
|
) -> Option<ValueEnum<'ctx>> {
|
||||||
if let Some(def_id) = self.0.id_to_def.read().get(&id) {
|
|
||||||
let top_levels = ctx.top_level.definitions.read();
|
|
||||||
if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) {
|
|
||||||
let module_val = &self.0.module;
|
|
||||||
let ret = Python::with_gil(|py| -> PyResult<Result<BasicValueEnum, String>> {
|
|
||||||
let module_val = module_val.as_ref(py);
|
|
||||||
|
|
||||||
let ty = self.0.get_obj_type(
|
|
||||||
py,
|
|
||||||
module_val,
|
|
||||||
&mut ctx.unifier,
|
|
||||||
&top_levels,
|
|
||||||
&ctx.primitives,
|
|
||||||
)?;
|
|
||||||
if let Err(ty) = ty {
|
|
||||||
return Ok(Err(ty));
|
|
||||||
}
|
|
||||||
let ty = ty.unwrap();
|
|
||||||
let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap();
|
|
||||||
let (idx, _) = ctx.get_attr_index(ty, id);
|
|
||||||
let ret = unsafe {
|
|
||||||
ctx.builder.build_gep(
|
|
||||||
obj.into_pointer_value(),
|
|
||||||
&[
|
|
||||||
ctx.ctx.i32_type().const_zero(),
|
|
||||||
ctx.ctx.i32_type().const_int(idx as u64, false),
|
|
||||||
],
|
|
||||||
id.to_string().as_str(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
.unwrap();
|
|
||||||
Ok(Ok(ret.as_basic_value_enum()))
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
if ret.is_err() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
return Some(ret.unwrap().into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let sym_value = {
|
let sym_value = {
|
||||||
let id_to_val = self.0.id_to_pyval.read();
|
let id_to_val = self.0.id_to_pyval.read();
|
||||||
id_to_val.get(&id).cloned()
|
id_to_val.get(&id).cloned()
|
||||||
@ -1772,7 +1513,10 @@ impl SymbolResolver for Resolver {
|
|||||||
if let Some(id) = string_store.get(s) {
|
if let Some(id) = string_store.get(s) {
|
||||||
*id
|
*id
|
||||||
} else {
|
} else {
|
||||||
let id = i32::try_from(string_store.len()).unwrap();
|
let id = Python::with_gil(|py| -> PyResult<i32> {
|
||||||
|
self.0.helper.store_str.call1(py, (s,))?.extract(py)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
string_store.insert(s.into(), id);
|
string_store.insert(s.into(), id);
|
||||||
id
|
id
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
use itertools::Either;
|
use inkwell::{
|
||||||
|
values::{BasicValueEnum, CallSiteValue},
|
||||||
use nac3core::{
|
AddressSpace, AtomicOrdering,
|
||||||
codegen::CodeGenContext,
|
|
||||||
inkwell::{
|
|
||||||
values::{BasicValueEnum, CallSiteValue},
|
|
||||||
AddressSpace, AtomicOrdering,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
use itertools::Either;
|
||||||
|
use nac3core::codegen::CodeGenContext;
|
||||||
|
|
||||||
/// Functions for manipulating the timeline.
|
/// Functions for manipulating the timeline.
|
||||||
pub trait TimeFns {
|
pub trait TimeFns {
|
||||||
@ -34,7 +31,7 @@ impl TimeFns for NowPinningTimeFns64 {
|
|||||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||||
let now_hiptr = ctx
|
let now_hiptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -83,7 +80,7 @@ impl TimeFns for NowPinningTimeFns64 {
|
|||||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||||
let now_hiptr = ctx
|
let now_hiptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -112,7 +109,7 @@ impl TimeFns for NowPinningTimeFns64 {
|
|||||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||||
let now_hiptr = ctx
|
let now_hiptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -210,7 +207,7 @@ impl TimeFns for NowPinningTimeFns {
|
|||||||
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
|
||||||
let now_hiptr = ctx
|
let now_hiptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -261,7 +258,7 @@ impl TimeFns for NowPinningTimeFns {
|
|||||||
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
|
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
|
||||||
let now_hiptr = ctx
|
let now_hiptr = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ constant-optimization = ["fold"]
|
|||||||
fold = []
|
fold = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
lazy_static = "1.5"
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
string-interner = "0.17"
|
string-interner = "0.17"
|
||||||
fxhash = "0.2"
|
fxhash = "0.2"
|
||||||
|
@ -5,12 +5,14 @@ pub use crate::location::Location;
|
|||||||
|
|
||||||
use fxhash::FxBuildHasher;
|
use fxhash::FxBuildHasher;
|
||||||
use parking_lot::{Mutex, MutexGuard};
|
use parking_lot::{Mutex, MutexGuard};
|
||||||
use std::{cell::RefCell, collections::HashMap, fmt, sync::LazyLock};
|
use std::{cell::RefCell, collections::HashMap, fmt};
|
||||||
use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner};
|
use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner};
|
||||||
|
|
||||||
pub type Interner = StringInterner<DefaultBackend, FxBuildHasher>;
|
pub type Interner = StringInterner<DefaultBackend, FxBuildHasher>;
|
||||||
static INTERNER: LazyLock<Mutex<Interner>> =
|
lazy_static! {
|
||||||
LazyLock::new(|| Mutex::new(StringInterner::with_hasher(FxBuildHasher::default())));
|
static ref INTERNER: Mutex<Interner> =
|
||||||
|
Mutex::new(StringInterner::with_hasher(FxBuildHasher::default()));
|
||||||
|
}
|
||||||
|
|
||||||
thread_local! {
|
thread_local! {
|
||||||
static LOCAL_INTERNER: RefCell<HashMap<String, StrRef>> = RefCell::default();
|
static LOCAL_INTERNER: RefCell<HashMap<String, StrRef>> = RefCell::default();
|
||||||
|
@ -1,4 +1,10 @@
|
|||||||
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
|
#![deny(
|
||||||
|
future_incompatible,
|
||||||
|
let_underscore,
|
||||||
|
nonstandard_style,
|
||||||
|
rust_2024_compatibility,
|
||||||
|
clippy::all
|
||||||
|
)]
|
||||||
#![warn(clippy::pedantic)]
|
#![warn(clippy::pedantic)]
|
||||||
#![allow(
|
#![allow(
|
||||||
clippy::missing_errors_doc,
|
clippy::missing_errors_doc,
|
||||||
@ -8,6 +14,9 @@
|
|||||||
clippy::wildcard_imports
|
clippy::wildcard_imports
|
||||||
)]
|
)]
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
extern crate lazy_static;
|
||||||
|
|
||||||
mod ast_gen;
|
mod ast_gen;
|
||||||
mod constant;
|
mod constant;
|
||||||
#[cfg(feature = "fold")]
|
#[cfg(feature = "fold")]
|
||||||
|
@ -1,29 +1,26 @@
|
|||||||
|
[features]
|
||||||
|
test = []
|
||||||
|
|
||||||
[package]
|
[package]
|
||||||
name = "nac3core"
|
name = "nac3core"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
authors = ["M-Labs"]
|
authors = ["M-Labs"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[features]
|
|
||||||
default = ["derive"]
|
|
||||||
derive = ["dep:nac3core_derive"]
|
|
||||||
no-escape-analysis = []
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
itertools = "0.13"
|
itertools = "0.13"
|
||||||
crossbeam = "0.8"
|
crossbeam = "0.8"
|
||||||
indexmap = "2.6"
|
indexmap = "2.2"
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
rayon = "1.10"
|
rayon = "1.8"
|
||||||
nac3core_derive = { path = "nac3core_derive", optional = true }
|
|
||||||
nac3parser = { path = "../nac3parser" }
|
nac3parser = { path = "../nac3parser" }
|
||||||
strum = "0.26"
|
strum = "0.26.2"
|
||||||
strum_macros = "0.26"
|
strum_macros = "0.26.4"
|
||||||
|
|
||||||
[dependencies.inkwell]
|
[dependencies.inkwell]
|
||||||
version = "0.5"
|
version = "0.4"
|
||||||
default-features = false
|
default-features = false
|
||||||
features = ["llvm14-0-prefer-dynamic", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
test-case = "1.2.0"
|
test-case = "1.2.0"
|
||||||
|
@ -1,36 +1,51 @@
|
|||||||
|
use regex::Regex;
|
||||||
use std::{
|
use std::{
|
||||||
env,
|
env,
|
||||||
fs::File,
|
fs::File,
|
||||||
io::Write,
|
io::Write,
|
||||||
path::Path,
|
path::{Path, PathBuf},
|
||||||
process::{Command, Stdio},
|
process::{Command, Stdio},
|
||||||
};
|
};
|
||||||
|
|
||||||
use regex::Regex;
|
const CMD_IRRT_CLANG: &str = "clang-irrt";
|
||||||
|
const CMD_IRRT_CLANG_TEST: &str = "clang-irrt-test";
|
||||||
|
const CMD_IRRT_LLVM_AS: &str = "llvm-as-irrt";
|
||||||
|
|
||||||
fn main() {
|
fn get_out_dir() -> PathBuf {
|
||||||
let out_dir = env::var("OUT_DIR").unwrap();
|
PathBuf::from(env::var("OUT_DIR").unwrap())
|
||||||
let out_dir = Path::new(&out_dir);
|
}
|
||||||
let irrt_dir = Path::new("irrt");
|
|
||||||
|
|
||||||
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
|
fn get_irrt_dir() -> &'static Path {
|
||||||
|
Path::new("irrt")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compile `irrt.cpp` for use in `src/codegen`
|
||||||
|
fn compile_irrt_cpp() {
|
||||||
|
let out_dir = get_out_dir();
|
||||||
|
let irrt_dir = get_irrt_dir();
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
||||||
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
||||||
*/
|
*/
|
||||||
let mut flags: Vec<&str> = vec![
|
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
|
||||||
|
let flags: &[&str] = &[
|
||||||
"--target=wasm32",
|
"--target=wasm32",
|
||||||
"-x",
|
"-x",
|
||||||
"c++",
|
"c++",
|
||||||
"-std=c++20",
|
|
||||||
"-fno-discard-value-names",
|
"-fno-discard-value-names",
|
||||||
"-fno-exceptions",
|
"-fno-exceptions",
|
||||||
"-fno-rtti",
|
"-fno-rtti",
|
||||||
|
match env::var("PROFILE").as_deref() {
|
||||||
|
Ok("debug") => "-O0",
|
||||||
|
Ok("release") => "-O3",
|
||||||
|
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
||||||
|
},
|
||||||
"-emit-llvm",
|
"-emit-llvm",
|
||||||
"-S",
|
"-S",
|
||||||
"-Wall",
|
"-Wall",
|
||||||
"-Wextra",
|
"-Wextra",
|
||||||
|
"-Werror=return-type",
|
||||||
"-o",
|
"-o",
|
||||||
"-",
|
"-",
|
||||||
"-I",
|
"-I",
|
||||||
@ -38,26 +53,16 @@ fn main() {
|
|||||||
irrt_cpp_path.to_str().unwrap(),
|
irrt_cpp_path.to_str().unwrap(),
|
||||||
];
|
];
|
||||||
|
|
||||||
match env::var("PROFILE").as_deref() {
|
|
||||||
Ok("debug") => {
|
|
||||||
flags.push("-O0");
|
|
||||||
flags.push("-DIRRT_DEBUG_ASSERT");
|
|
||||||
}
|
|
||||||
Ok("release") => {
|
|
||||||
flags.push("-O3");
|
|
||||||
}
|
|
||||||
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
|
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
|
||||||
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
|
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
|
||||||
|
|
||||||
// Compile IRRT and capture the LLVM IR output
|
// Compile IRRT and capture the LLVM IR output
|
||||||
let output = Command::new("clang-irrt")
|
let output = Command::new(CMD_IRRT_CLANG)
|
||||||
.args(flags)
|
.args(flags)
|
||||||
.output()
|
.output()
|
||||||
.inspect(|o| {
|
.map(|o| {
|
||||||
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
|
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
|
||||||
|
o
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -98,7 +103,9 @@ fn main() {
|
|||||||
file.write_all(filtered_output.as_bytes()).unwrap();
|
file.write_all(filtered_output.as_bytes()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut llvm_as = Command::new("llvm-as-irrt")
|
// Assemble the emitted and filtered IR to .bc
|
||||||
|
// That .bc will be integrated into nac3core's codegen
|
||||||
|
let mut llvm_as = Command::new(CMD_IRRT_LLVM_AS)
|
||||||
.stdin(Stdio::piped())
|
.stdin(Stdio::piped())
|
||||||
.arg("-o")
|
.arg("-o")
|
||||||
.arg(out_dir.join("irrt.bc"))
|
.arg(out_dir.join("irrt.bc"))
|
||||||
@ -107,3 +114,48 @@ fn main() {
|
|||||||
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
|
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
|
||||||
assert!(llvm_as.wait().unwrap().success());
|
assert!(llvm_as.wait().unwrap().success());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compile `irrt_test.cpp` for testing
|
||||||
|
fn compile_irrt_test_cpp() {
|
||||||
|
let out_dir = get_out_dir();
|
||||||
|
let irrt_dir = get_irrt_dir();
|
||||||
|
|
||||||
|
let exe_path = out_dir.join("irrt_test.out"); // Output path of the compiled test executable
|
||||||
|
let irrt_test_cpp_path = irrt_dir.join("irrt_test.cpp");
|
||||||
|
let flags: &[&str] = &[
|
||||||
|
irrt_test_cpp_path.to_str().unwrap(),
|
||||||
|
"-x",
|
||||||
|
"c++",
|
||||||
|
"-I",
|
||||||
|
irrt_dir.to_str().unwrap(),
|
||||||
|
"-g",
|
||||||
|
"-fno-discard-value-names",
|
||||||
|
"-O0",
|
||||||
|
"-Wall",
|
||||||
|
"-Wextra",
|
||||||
|
"-Werror=return-type",
|
||||||
|
"-lm", // for `tgamma()`, `lgamma()`
|
||||||
|
"-o",
|
||||||
|
exe_path.to_str().unwrap(),
|
||||||
|
];
|
||||||
|
|
||||||
|
Command::new(CMD_IRRT_CLANG_TEST)
|
||||||
|
.args(flags)
|
||||||
|
.output()
|
||||||
|
.map(|o| {
|
||||||
|
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
|
||||||
|
o
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
compile_irrt_cpp();
|
||||||
|
|
||||||
|
// https://github.com/rust-lang/cargo/issues/2549
|
||||||
|
// `cargo test -F test` to also build `irrt_test.cpp
|
||||||
|
if cfg!(feature = "test") {
|
||||||
|
compile_irrt_test_cpp();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,15 +1,9 @@
|
|||||||
#include "irrt/exception.hpp"
|
#define IRRT_DEFINE_TYPEDEF_INTS
|
||||||
#include "irrt/list.hpp"
|
#include <irrt_everything.hpp>
|
||||||
#include "irrt/math.hpp"
|
|
||||||
#include "irrt/range.hpp"
|
/*
|
||||||
#include "irrt/slice.hpp"
|
All IRRT implementations.
|
||||||
#include "irrt/string.hpp"
|
|
||||||
#include "irrt/ndarray/basic.hpp"
|
We don't have any pre-compiled objects, so we are writing all implementations in headers and
|
||||||
#include "irrt/ndarray/def.hpp"
|
concatenate them with `#include` into one massive source file that contains all the IRRT stuff.
|
||||||
#include "irrt/ndarray/iter.hpp"
|
*/
|
||||||
#include "irrt/ndarray/indexing.hpp"
|
|
||||||
#include "irrt/ndarray/array.hpp"
|
|
||||||
#include "irrt/ndarray/reshape.hpp"
|
|
||||||
#include "irrt/ndarray/broadcast.hpp"
|
|
||||||
#include "irrt/ndarray/transpose.hpp"
|
|
||||||
#include "irrt/ndarray/matmul.hpp"
|
|
334
nac3core/irrt/irrt/core.hpp
Normal file
334
nac3core/irrt/irrt/core.hpp
Normal file
@ -0,0 +1,334 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/utils.hpp>
|
||||||
|
|
||||||
|
// NDArray indices are always `uint32_t`.
|
||||||
|
using NDIndex = uint32_t;
|
||||||
|
// The type of an index or a value describing the length of a
|
||||||
|
// range/slice is always `int32_t`.
|
||||||
|
using SliceIndex = int32_t;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// adapted from GNU Scientific Library:
|
||||||
|
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||||
|
// need to make sure `exp >= 0` before calling this function
|
||||||
|
template <typename T>
|
||||||
|
T __nac3_int_exp_impl(T base, T exp) {
|
||||||
|
T res = 1;
|
||||||
|
/* repeated squaring method */
|
||||||
|
do {
|
||||||
|
if (exp & 1) {
|
||||||
|
res *= base; /* for n odd */
|
||||||
|
}
|
||||||
|
exp >>= 1;
|
||||||
|
base *= base;
|
||||||
|
} while (exp);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
||||||
|
__builtin_assume(end_idx <= list_len);
|
||||||
|
|
||||||
|
SizeT num_elems = 1;
|
||||||
|
for (SizeT i = begin_idx; i < end_idx; ++i) {
|
||||||
|
SizeT val = list_data[i];
|
||||||
|
__builtin_assume(val > 0);
|
||||||
|
num_elems *= val;
|
||||||
|
}
|
||||||
|
return num_elems;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
|
||||||
|
SizeT stride = 1;
|
||||||
|
for (SizeT dim = 0; dim < num_dims; dim++) {
|
||||||
|
SizeT i = num_dims - dim - 1;
|
||||||
|
__builtin_assume(dims[i] > 0);
|
||||||
|
idxs[i] = (index / stride) % dims[i];
|
||||||
|
stride *= dims[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
|
||||||
|
SizeT idx = 0;
|
||||||
|
SizeT stride = 1;
|
||||||
|
for (SizeT i = 0; i < num_dims; ++i) {
|
||||||
|
SizeT ri = num_dims - i - 1;
|
||||||
|
if (ri < num_indices) {
|
||||||
|
idx += stride * indices[ri];
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_assume(dims[i] > 0);
|
||||||
|
stride *= dims[ri];
|
||||||
|
}
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_broadcast_impl(
|
||||||
|
const SizeT* lhs_dims, SizeT lhs_ndims, const SizeT* rhs_dims, SizeT rhs_ndims, SizeT* out_dims
|
||||||
|
) {
|
||||||
|
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < max_ndims; ++i) {
|
||||||
|
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
||||||
|
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
||||||
|
|
||||||
|
SizeT* out_dim = &out_dims[max_ndims - i - 1];
|
||||||
|
|
||||||
|
if (lhs_dim_sz == nullptr) {
|
||||||
|
*out_dim = *rhs_dim_sz;
|
||||||
|
} else if (rhs_dim_sz == nullptr) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else if (*lhs_dim_sz == 1) {
|
||||||
|
*out_dim = *rhs_dim_sz;
|
||||||
|
} else if (*rhs_dim_sz == 1) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else {
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx_impl(
|
||||||
|
const SizeT* src_dims, SizeT src_ndims, const NDIndex* in_idx, NDIndex* out_idx
|
||||||
|
) {
|
||||||
|
for (SizeT i = 0; i < src_ndims; ++i) {
|
||||||
|
SizeT src_i = src_ndims - i - 1;
|
||||||
|
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
#define DEF_nac3_int_exp_(T) \
|
||||||
|
T __nac3_int_exp_##T(T base, T exp) { \
|
||||||
|
return __nac3_int_exp_impl(base, exp); \
|
||||||
|
}
|
||||||
|
|
||||||
|
DEF_nac3_int_exp_(int32_t);
|
||||||
|
DEF_nac3_int_exp_(int64_t);
|
||||||
|
DEF_nac3_int_exp_(uint32_t);
|
||||||
|
DEF_nac3_int_exp_(uint64_t);
|
||||||
|
|
||||||
|
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
||||||
|
if (i < 0) {
|
||||||
|
i = len + i;
|
||||||
|
}
|
||||||
|
if (i < 0) {
|
||||||
|
return 0;
|
||||||
|
} else if (i > len) {
|
||||||
|
return len;
|
||||||
|
}
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
|
||||||
|
SliceIndex diff = end - start;
|
||||||
|
if (diff > 0 && step > 0) {
|
||||||
|
return ((diff - 1) / step) + 1;
|
||||||
|
} else if (diff < 0 && step < 0) {
|
||||||
|
return ((diff + 1) / step) + 1;
|
||||||
|
} else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle list assignment and dropping part of the list when
|
||||||
|
// both dest_step and src_step are +1.
|
||||||
|
// - All the index must *not* be out-of-bound or negative,
|
||||||
|
// - The end index is *inclusive*,
|
||||||
|
// - The length of src and dest slice size should already
|
||||||
|
// be checked: if dest.step == 1 then len(src) <= len(dest) else
|
||||||
|
// len(src) == len(dest)
|
||||||
|
SliceIndex __nac3_list_slice_assign_var_size(
|
||||||
|
SliceIndex dest_start,
|
||||||
|
SliceIndex dest_end,
|
||||||
|
SliceIndex dest_step,
|
||||||
|
uint8_t* dest_arr,
|
||||||
|
SliceIndex dest_arr_len,
|
||||||
|
SliceIndex src_start,
|
||||||
|
SliceIndex src_end,
|
||||||
|
SliceIndex src_step,
|
||||||
|
uint8_t* src_arr,
|
||||||
|
SliceIndex src_arr_len,
|
||||||
|
const SliceIndex size
|
||||||
|
) {
|
||||||
|
/* if dest_arr_len == 0, do nothing since we do not support
|
||||||
|
* extending list
|
||||||
|
*/
|
||||||
|
if (dest_arr_len == 0)
|
||||||
|
return dest_arr_len;
|
||||||
|
/* if both step is 1, memmove directly, handle the dropping of
|
||||||
|
* the list, and shrink size */
|
||||||
|
if (src_step == dest_step && dest_step == 1) {
|
||||||
|
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
||||||
|
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
||||||
|
if (src_len > 0) {
|
||||||
|
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
|
||||||
|
}
|
||||||
|
if (dest_len > 0) {
|
||||||
|
/* dropping */
|
||||||
|
__builtin_memmove(
|
||||||
|
dest_arr + (dest_start + src_len) * size,
|
||||||
|
dest_arr + (dest_end + 1) * size,
|
||||||
|
(dest_arr_len - dest_end - 1) * size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
/* shrink size */
|
||||||
|
return dest_arr_len - (dest_len - src_len);
|
||||||
|
}
|
||||||
|
/* if two range overlaps, need alloca */
|
||||||
|
uint8_t need_alloca = (dest_arr == src_arr)
|
||||||
|
&& !(max(dest_start, dest_end) < min(src_start, src_end) || max(src_start, src_end) < min(dest_start, dest_end)
|
||||||
|
);
|
||||||
|
if (need_alloca) {
|
||||||
|
uint8_t* tmp = reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
|
||||||
|
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
||||||
|
src_arr = tmp;
|
||||||
|
}
|
||||||
|
SliceIndex src_ind = src_start;
|
||||||
|
SliceIndex dest_ind = dest_start;
|
||||||
|
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
|
||||||
|
/* for constant optimization */
|
||||||
|
if (size == 1) {
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
|
||||||
|
} else if (size == 4) {
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
|
||||||
|
} else if (size == 8) {
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
|
||||||
|
} else {
|
||||||
|
/* memcpy for var size, cannot overlap after previous
|
||||||
|
* alloca */
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/* only dest_step == 1 can we shrink the dest list. */
|
||||||
|
/* size should be ensured prior to calling this function */
|
||||||
|
if (dest_step == 1 && dest_end >= dest_start) {
|
||||||
|
__builtin_memmove(
|
||||||
|
dest_arr + dest_ind * size,
|
||||||
|
dest_arr + (dest_end + 1) * size,
|
||||||
|
(dest_arr_len - dest_end - 1) * size + size + size + size
|
||||||
|
);
|
||||||
|
return dest_arr_len - (dest_end - dest_ind) - 1;
|
||||||
|
}
|
||||||
|
return dest_arr_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t __nac3_isinf(double x) {
|
||||||
|
return __builtin_isinf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t __nac3_isnan(double x) {
|
||||||
|
return __builtin_isnan(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
double tgamma(double arg);
|
||||||
|
|
||||||
|
double __nac3_gamma(double z) {
|
||||||
|
// Handling for denormals
|
||||||
|
// | x | Python gamma(x) | C tgamma(x) |
|
||||||
|
// --- | ----------------- | --------------- | ----------- |
|
||||||
|
// (1) | nan | nan | nan |
|
||||||
|
// (2) | -inf | -inf | inf |
|
||||||
|
// (3) | inf | inf | inf |
|
||||||
|
// (4) | 0.0 | inf | inf |
|
||||||
|
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
||||||
|
|
||||||
|
// (1)-(3)
|
||||||
|
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
||||||
|
return z;
|
||||||
|
}
|
||||||
|
|
||||||
|
double v = tgamma(z);
|
||||||
|
|
||||||
|
// (4)-(5)
|
||||||
|
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
||||||
|
}
|
||||||
|
|
||||||
|
double lgamma(double arg);
|
||||||
|
|
||||||
|
double __nac3_gammaln(double x) {
|
||||||
|
// libm's handling of value overflows differs from scipy:
|
||||||
|
// - scipy: gammaln(-inf) -> -inf
|
||||||
|
// - libm : lgamma(-inf) -> inf
|
||||||
|
|
||||||
|
if (__builtin_isinf(x)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
return lgamma(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
double j0(double x);
|
||||||
|
|
||||||
|
double __nac3_j0(double x) {
|
||||||
|
// libm's handling of value overflows differs from scipy:
|
||||||
|
// - scipy: j0(inf) -> nan
|
||||||
|
// - libm : j0(inf) -> 0.0
|
||||||
|
|
||||||
|
if (__builtin_isinf(x)) {
|
||||||
|
return __builtin_nan("");
|
||||||
|
}
|
||||||
|
|
||||||
|
return j0(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
|
||||||
|
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t
|
||||||
|
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
|
||||||
|
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
|
||||||
|
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
|
||||||
|
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t
|
||||||
|
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
|
||||||
|
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t
|
||||||
|
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
|
||||||
|
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast(
|
||||||
|
const uint32_t* lhs_dims, uint32_t lhs_ndims, const uint32_t* rhs_dims, uint32_t rhs_ndims, uint32_t* out_dims
|
||||||
|
) {
|
||||||
|
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast64(
|
||||||
|
const uint64_t* lhs_dims, uint64_t lhs_ndims, const uint64_t* rhs_dims, uint64_t rhs_ndims, uint64_t* out_dims
|
||||||
|
) {
|
||||||
|
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx(
|
||||||
|
const uint32_t* src_dims, uint32_t src_ndims, const NDIndex* in_idx, NDIndex* out_idx
|
||||||
|
) {
|
||||||
|
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx64(
|
||||||
|
const uint64_t* src_dims, uint64_t src_ndims, const NDIndex* in_idx, NDIndex* out_idx
|
||||||
|
) {
|
||||||
|
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||||
|
}
|
||||||
|
} // extern "C"
|
@ -1,9 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
|
|
||||||
template<typename SizeT>
|
|
||||||
struct CSlice {
|
|
||||||
void* base;
|
|
||||||
SizeT len;
|
|
||||||
};
|
|
@ -1,25 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
// Set in nac3core/build.rs
|
|
||||||
#ifdef IRRT_DEBUG_ASSERT
|
|
||||||
#define IRRT_DEBUG_ASSERT_BOOL true
|
|
||||||
#else
|
|
||||||
#define IRRT_DEBUG_ASSERT_BOOL false
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \
|
|
||||||
raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
|
|
||||||
|
|
||||||
#define debug_assert_eq(SizeT, lhs, rhs) \
|
|
||||||
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
|
|
||||||
if ((lhs) != (rhs)) { \
|
|
||||||
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
|
|
||||||
} \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define debug_assert(SizeT, expr) \
|
|
||||||
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
|
|
||||||
if (!(expr)) { \
|
|
||||||
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
|
|
||||||
} \
|
|
||||||
}
|
|
87
nac3core/irrt/irrt/error_context.hpp
Normal file
87
nac3core/irrt/irrt/error_context.hpp
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/utils.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// nac3core's "str" struct type definition
|
||||||
|
template <typename SizeT>
|
||||||
|
struct Str {
|
||||||
|
const char* content;
|
||||||
|
SizeT length;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A limited set of errors IRRT could use.
|
||||||
|
typedef uint32_t ErrorId;
|
||||||
|
struct ErrorIds {
|
||||||
|
ErrorId index_error;
|
||||||
|
ErrorId value_error;
|
||||||
|
ErrorId assertion_error;
|
||||||
|
ErrorId runtime_error;
|
||||||
|
ErrorId type_error;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ErrorContext {
|
||||||
|
// Context
|
||||||
|
ErrorIds* error_ids;
|
||||||
|
|
||||||
|
// Error thrown by IRRT
|
||||||
|
ErrorId error_id;
|
||||||
|
const char* message_template; // MUST BE `&'static`
|
||||||
|
uint64_t param1;
|
||||||
|
uint64_t param2;
|
||||||
|
uint64_t param3;
|
||||||
|
|
||||||
|
void initialize(ErrorIds* error_ids) {
|
||||||
|
this->error_ids = error_ids;
|
||||||
|
clear_error();
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear_error() {
|
||||||
|
// Point the message_template to an empty str. Don't set it to nullptr as a sentinel
|
||||||
|
this->message_template = "";
|
||||||
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
set_error(ErrorId error_id, const char* message, uint64_t param1 = 0, uint64_t param2 = 0, uint64_t param3 = 0) {
|
||||||
|
this->error_id = error_id;
|
||||||
|
this->message_template = message;
|
||||||
|
this->param1 = param1;
|
||||||
|
this->param2 = param2;
|
||||||
|
this->param3 = param3;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_error() { return !cstr_utils::is_empty(message_template); }
|
||||||
|
|
||||||
|
/// Get a nac3core-understanding `Str<SizeT>` that containing
|
||||||
|
/// the error message template
|
||||||
|
template <typename SizeT>
|
||||||
|
void get_error_str(Str<SizeT>* dst_str) {
|
||||||
|
dst_str->content = message_template;
|
||||||
|
dst_str->length = (SizeT)cstr_utils::length(message_template);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
void __nac3_error_context_initialize(ErrorContext* errctx, ErrorIds* error_ids) {
|
||||||
|
errctx->initialize(error_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool __nac3_error_context_has_no_error(ErrorContext* errctx) {
|
||||||
|
return !errctx->has_error();
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_error_context_get_error_str(ErrorContext* errctx, Str<int32_t>* dst_str) {
|
||||||
|
errctx->get_error_str<int32_t>(dst_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_error_context_get_error_str64(ErrorContext* errctx, Str<int64_t>* dst_str) {
|
||||||
|
errctx->get_error_str<int64_t>(dst_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used for testing
|
||||||
|
void __nac3_error_dummy_raise(ErrorContext* errctx) {
|
||||||
|
errctx->set_error(errctx->error_ids->runtime_error, "Error thrown from __nac3_error_dummy_raise");
|
||||||
|
}
|
||||||
|
}
|
@ -1,85 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/cslice.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The int type of ARTIQ exception IDs.
|
|
||||||
*/
|
|
||||||
using ExceptionId = int32_t;
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Set of exceptions C++ IRRT can use.
|
|
||||||
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
|
|
||||||
*/
|
|
||||||
extern "C" {
|
|
||||||
ExceptionId EXN_INDEX_ERROR;
|
|
||||||
ExceptionId EXN_VALUE_ERROR;
|
|
||||||
ExceptionId EXN_ASSERTION_ERROR;
|
|
||||||
ExceptionId EXN_TYPE_ERROR;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Extern function to `__nac3_raise`
|
|
||||||
*
|
|
||||||
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
|
|
||||||
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
|
|
||||||
*/
|
|
||||||
extern "C" void __nac3_raise(void* err);
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/**
|
|
||||||
* @brief NAC3's Exception struct
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
struct Exception {
|
|
||||||
ExceptionId id;
|
|
||||||
CSlice<SizeT> filename;
|
|
||||||
int32_t line;
|
|
||||||
int32_t column;
|
|
||||||
CSlice<SizeT> function;
|
|
||||||
CSlice<SizeT> msg;
|
|
||||||
int64_t params[3];
|
|
||||||
};
|
|
||||||
|
|
||||||
constexpr int64_t NO_PARAM = 0;
|
|
||||||
|
|
||||||
template<typename SizeT>
|
|
||||||
void _raise_exception_helper(ExceptionId id,
|
|
||||||
const char* filename,
|
|
||||||
int32_t line,
|
|
||||||
const char* function,
|
|
||||||
const char* msg,
|
|
||||||
int64_t param0,
|
|
||||||
int64_t param1,
|
|
||||||
int64_t param2) {
|
|
||||||
Exception<SizeT> e = {
|
|
||||||
.id = id,
|
|
||||||
.filename = {.base = reinterpret_cast<void*>(const_cast<char*>(filename)),
|
|
||||||
.len = static_cast<SizeT>(__builtin_strlen(filename))},
|
|
||||||
.line = line,
|
|
||||||
.column = 0,
|
|
||||||
.function = {.base = reinterpret_cast<void*>(const_cast<char*>(function)),
|
|
||||||
.len = static_cast<SizeT>(__builtin_strlen(function))},
|
|
||||||
.msg = {.base = reinterpret_cast<void*>(const_cast<char*>(msg)),
|
|
||||||
.len = static_cast<SizeT>(__builtin_strlen(msg))},
|
|
||||||
};
|
|
||||||
e.params[0] = param0;
|
|
||||||
e.params[1] = param1;
|
|
||||||
e.params[2] = param2;
|
|
||||||
__nac3_raise(reinterpret_cast<void*>(&e));
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Raise an exception with location details (location in the IRRT source files).
|
|
||||||
* @param SizeT The runtime `size_t` type.
|
|
||||||
* @param id The ID of the exception to raise.
|
|
||||||
* @param msg A global constant C-string of the error message.
|
|
||||||
*
|
|
||||||
* `param0` to `param2` are optional format arguments of `msg`. They should be set to
|
|
||||||
* `NO_PARAM` to indicate they are unused.
|
|
||||||
*/
|
|
||||||
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
|
|
||||||
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
|
|
12
nac3core/irrt/irrt/int_defs.hpp
Normal file
12
nac3core/irrt/irrt/int_defs.hpp
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
// This is made toggleable since `irrt_test.cpp` itself would include
|
||||||
|
// headers that define these typedefs
|
||||||
|
#ifdef IRRT_DEFINE_TYPEDEF_INTS
|
||||||
|
using int8_t = _BitInt(8);
|
||||||
|
using uint8_t = unsigned _BitInt(8);
|
||||||
|
using int32_t = _BitInt(32);
|
||||||
|
using uint32_t = unsigned _BitInt(32);
|
||||||
|
using int64_t = _BitInt(64);
|
||||||
|
using uint64_t = unsigned _BitInt(64);
|
||||||
|
#endif
|
@ -1,25 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#if __STDC_VERSION__ >= 202000
|
|
||||||
using int8_t = _BitInt(8);
|
|
||||||
using uint8_t = unsigned _BitInt(8);
|
|
||||||
using int32_t = _BitInt(32);
|
|
||||||
using uint32_t = unsigned _BitInt(32);
|
|
||||||
using int64_t = _BitInt(64);
|
|
||||||
using uint64_t = unsigned _BitInt(64);
|
|
||||||
#else
|
|
||||||
|
|
||||||
#pragma clang diagnostic push
|
|
||||||
#pragma clang diagnostic ignored "-Wdeprecated-type"
|
|
||||||
using int8_t = _ExtInt(8);
|
|
||||||
using uint8_t = unsigned _ExtInt(8);
|
|
||||||
using int32_t = _ExtInt(32);
|
|
||||||
using uint32_t = unsigned _ExtInt(32);
|
|
||||||
using int64_t = _ExtInt(64);
|
|
||||||
using uint64_t = unsigned _ExtInt(64);
|
|
||||||
#pragma clang diagnostic pop
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
|
|
||||||
using SliceIndex = int32_t;
|
|
@ -1,96 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/math_util.hpp"
|
|
||||||
#include "irrt/slice.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/**
|
|
||||||
* @brief A list in NAC3.
|
|
||||||
*
|
|
||||||
* The `items` field is opaque. You must rely on external contexts to
|
|
||||||
* know how to interpret it.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
struct List {
|
|
||||||
uint8_t* items;
|
|
||||||
SizeT len;
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
// Handle list assignment and dropping part of the list when
|
|
||||||
// both dest_step and src_step are +1.
|
|
||||||
// - All the index must *not* be out-of-bound or negative,
|
|
||||||
// - The end index is *inclusive*,
|
|
||||||
// - The length of src and dest slice size should already
|
|
||||||
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
|
|
||||||
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
|
||||||
SliceIndex dest_end,
|
|
||||||
SliceIndex dest_step,
|
|
||||||
void* dest_arr,
|
|
||||||
SliceIndex dest_arr_len,
|
|
||||||
SliceIndex src_start,
|
|
||||||
SliceIndex src_end,
|
|
||||||
SliceIndex src_step,
|
|
||||||
void* src_arr,
|
|
||||||
SliceIndex src_arr_len,
|
|
||||||
const SliceIndex size) {
|
|
||||||
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
|
||||||
if (dest_arr_len == 0)
|
|
||||||
return dest_arr_len;
|
|
||||||
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
|
|
||||||
if (src_step == dest_step && dest_step == 1) {
|
|
||||||
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
|
||||||
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
|
||||||
if (src_len > 0) {
|
|
||||||
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_start * size,
|
|
||||||
static_cast<uint8_t*>(src_arr) + src_start * size, src_len * size);
|
|
||||||
}
|
|
||||||
if (dest_len > 0) {
|
|
||||||
/* dropping */
|
|
||||||
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + (dest_start + src_len) * size,
|
|
||||||
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
|
|
||||||
(dest_arr_len - dest_end - 1) * size);
|
|
||||||
}
|
|
||||||
/* shrink size */
|
|
||||||
return dest_arr_len - (dest_len - src_len);
|
|
||||||
}
|
|
||||||
/* if two range overlaps, need alloca */
|
|
||||||
uint8_t need_alloca = (dest_arr == src_arr)
|
|
||||||
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|
|
||||||
|| max(src_start, src_end) < min(dest_start, dest_end));
|
|
||||||
if (need_alloca) {
|
|
||||||
void* tmp = __builtin_alloca(src_arr_len * size);
|
|
||||||
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
|
||||||
src_arr = tmp;
|
|
||||||
}
|
|
||||||
SliceIndex src_ind = src_start;
|
|
||||||
SliceIndex dest_ind = dest_start;
|
|
||||||
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
|
|
||||||
/* for constant optimization */
|
|
||||||
if (size == 1) {
|
|
||||||
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind, static_cast<uint8_t*>(src_arr) + src_ind, 1);
|
|
||||||
} else if (size == 4) {
|
|
||||||
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 4,
|
|
||||||
static_cast<uint8_t*>(src_arr) + src_ind * 4, 4);
|
|
||||||
} else if (size == 8) {
|
|
||||||
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 8,
|
|
||||||
static_cast<uint8_t*>(src_arr) + src_ind * 8, 8);
|
|
||||||
} else {
|
|
||||||
/* memcpy for var size, cannot overlap after previous alloca */
|
|
||||||
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
|
|
||||||
static_cast<uint8_t*>(src_arr) + src_ind * size, size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/* only dest_step == 1 can we shrink the dest list. */
|
|
||||||
/* size should be ensured prior to calling this function */
|
|
||||||
if (dest_step == 1 && dest_end >= dest_start) {
|
|
||||||
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
|
|
||||||
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
|
|
||||||
(dest_arr_len - dest_end - 1) * size);
|
|
||||||
return dest_arr_len - (dest_end - dest_ind) - 1;
|
|
||||||
}
|
|
||||||
return dest_arr_len;
|
|
||||||
}
|
|
||||||
} // extern "C"
|
|
@ -1,95 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
|
||||||
// need to make sure `exp >= 0` before calling this function
|
|
||||||
template<typename T>
|
|
||||||
T __nac3_int_exp_impl(T base, T exp) {
|
|
||||||
T res = 1;
|
|
||||||
/* repeated squaring method */
|
|
||||||
do {
|
|
||||||
if (exp & 1) {
|
|
||||||
res *= base; /* for n odd */
|
|
||||||
}
|
|
||||||
exp >>= 1;
|
|
||||||
base *= base;
|
|
||||||
} while (exp);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
#define DEF_nac3_int_exp_(T) \
|
|
||||||
T __nac3_int_exp_##T(T base, T exp) { \
|
|
||||||
return __nac3_int_exp_impl(base, exp); \
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
|
|
||||||
// Putting semicolons here to make clang-format not reformat this into
|
|
||||||
// a stair shape.
|
|
||||||
DEF_nac3_int_exp_(int32_t);
|
|
||||||
DEF_nac3_int_exp_(int64_t);
|
|
||||||
DEF_nac3_int_exp_(uint32_t);
|
|
||||||
DEF_nac3_int_exp_(uint64_t);
|
|
||||||
|
|
||||||
int32_t __nac3_isinf(double x) {
|
|
||||||
return __builtin_isinf(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t __nac3_isnan(double x) {
|
|
||||||
return __builtin_isnan(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
double tgamma(double arg);
|
|
||||||
|
|
||||||
double __nac3_gamma(double z) {
|
|
||||||
// Handling for denormals
|
|
||||||
// | x | Python gamma(x) | C tgamma(x) |
|
|
||||||
// --- | ----------------- | --------------- | ----------- |
|
|
||||||
// (1) | nan | nan | nan |
|
|
||||||
// (2) | -inf | -inf | inf |
|
|
||||||
// (3) | inf | inf | inf |
|
|
||||||
// (4) | 0.0 | inf | inf |
|
|
||||||
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
|
||||||
|
|
||||||
// (1)-(3)
|
|
||||||
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
|
||||||
return z;
|
|
||||||
}
|
|
||||||
|
|
||||||
double v = tgamma(z);
|
|
||||||
|
|
||||||
// (4)-(5)
|
|
||||||
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
|
||||||
}
|
|
||||||
|
|
||||||
double lgamma(double arg);
|
|
||||||
|
|
||||||
double __nac3_gammaln(double x) {
|
|
||||||
// libm's handling of value overflows differs from scipy:
|
|
||||||
// - scipy: gammaln(-inf) -> -inf
|
|
||||||
// - libm : lgamma(-inf) -> inf
|
|
||||||
|
|
||||||
if (__builtin_isinf(x)) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
return lgamma(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
double j0(double x);
|
|
||||||
|
|
||||||
double __nac3_j0(double x) {
|
|
||||||
// libm's handling of value overflows differs from scipy:
|
|
||||||
// - scipy: j0(inf) -> nan
|
|
||||||
// - libm : j0(inf) -> 0.0
|
|
||||||
|
|
||||||
if (__builtin_isinf(x)) {
|
|
||||||
return __builtin_nan("");
|
|
||||||
}
|
|
||||||
|
|
||||||
return j0(x);
|
|
||||||
}
|
|
||||||
} // namespace
|
|
@ -1,13 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
template<typename T>
|
|
||||||
const T& max(const T& a, const T& b) {
|
|
||||||
return a > b ? a : b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
const T& min(const T& a, const T& b) {
|
|
||||||
return a > b ? b : a;
|
|
||||||
}
|
|
||||||
} // namespace
|
|
@ -1,132 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/list.hpp"
|
|
||||||
#include "irrt/ndarray/basic.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray::array {
|
|
||||||
/**
|
|
||||||
* @brief In the context of `np.array(<list>)`, deduce the ndarray's shape produced by `<list>` and raise
|
|
||||||
* an exception if there is anything wrong with `<shape>` (e.g., inconsistent dimensions `np.array([[1.0, 2.0],
|
|
||||||
* [3.0]])`)
|
|
||||||
*
|
|
||||||
* If this function finds no issues with `<list>`, the deduced shape is written to `shape`. The caller has the
|
|
||||||
* responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because
|
|
||||||
* of implementation details.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_and_validate_list_shape_helper(SizeT axis, List<SizeT>* list, SizeT ndims, SizeT* shape) {
|
|
||||||
if (shape[axis] == -1) {
|
|
||||||
// Dimension is unspecified. Set it.
|
|
||||||
shape[axis] = list->len;
|
|
||||||
} else {
|
|
||||||
// Dimension is specified. Check.
|
|
||||||
if (shape[axis] != list->len) {
|
|
||||||
// Mismatch, throw an error.
|
|
||||||
// NOTE: NumPy's error message is more complex and needs more PARAMS to display.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
|
||||||
"The requested array has an inhomogenous shape "
|
|
||||||
"after {0} dimension(s).",
|
|
||||||
axis, shape[axis], list->len);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (axis + 1 == ndims) {
|
|
||||||
// `list` has type `list[ItemType]`
|
|
||||||
// Do nothing
|
|
||||||
} else {
|
|
||||||
// `list` has type `list[list[...]]`
|
|
||||||
List<SizeT>** lists = (List<SizeT>**)(list->items);
|
|
||||||
for (SizeT i = 0; i < list->len; i++) {
|
|
||||||
set_and_validate_list_shape_helper<SizeT>(axis + 1, lists[i], ndims, shape);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief See `set_and_validate_list_shape_helper`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_and_validate_list_shape(List<SizeT>* list, SizeT ndims, SizeT* shape) {
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
|
||||||
shape[axis] = -1; // Sentinel to say this dimension is unspecified.
|
|
||||||
}
|
|
||||||
set_and_validate_list_shape_helper<SizeT>(0, list, ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief In the context of `np.array(<list>)`, copied the contents stored in `list` to `ndarray`.
|
|
||||||
*
|
|
||||||
* `list` is assumed to be "legal". (i.e., no inconsistent dimensions)
|
|
||||||
*
|
|
||||||
* # Notes on `ndarray`
|
|
||||||
* The caller is responsible for allocating space for `ndarray`.
|
|
||||||
* Here is what this function expects from `ndarray` when called:
|
|
||||||
* - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values.
|
|
||||||
* - `ndarray->itemsize` has to be initialized.
|
|
||||||
* - `ndarray->ndims` has to be initialized.
|
|
||||||
* - `ndarray->shape` has to be initialized.
|
|
||||||
* - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous.
|
|
||||||
* When this function call ends:
|
|
||||||
* - `ndarray->data` is written with contents from `<list>`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void write_list_to_array_helper(SizeT axis, SizeT* index, List<SizeT>* list, NDArray<SizeT>* ndarray) {
|
|
||||||
debug_assert_eq(SizeT, list->len, ndarray->shape[axis]);
|
|
||||||
if (IRRT_DEBUG_ASSERT_BOOL) {
|
|
||||||
if (!ndarray::basic::is_c_contiguous(ndarray)) {
|
|
||||||
raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1],
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (axis + 1 == ndarray->ndims) {
|
|
||||||
// `list` has type `list[scalar]`
|
|
||||||
// `ndarray` is contiguous, so we can do this, and this is fast.
|
|
||||||
uint8_t* dst = static_cast<uint8_t*>(ndarray->data) + (ndarray->itemsize * (*index));
|
|
||||||
__builtin_memcpy(dst, list->items, ndarray->itemsize * list->len);
|
|
||||||
*index += list->len;
|
|
||||||
} else {
|
|
||||||
// `list` has type `list[list[...]]`
|
|
||||||
List<SizeT>** lists = (List<SizeT>**)(list->items);
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < list->len; i++) {
|
|
||||||
write_list_to_array_helper<SizeT>(axis + 1, index, lists[i], ndarray);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief See `write_list_to_array_helper`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void write_list_to_array(List<SizeT>* list, NDArray<SizeT>* ndarray) {
|
|
||||||
SizeT index = 0;
|
|
||||||
write_list_to_array_helper<SizeT>((SizeT)0, &index, list, ndarray);
|
|
||||||
}
|
|
||||||
} // namespace ndarray::array
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::array;
|
|
||||||
|
|
||||||
void __nac3_ndarray_array_set_and_validate_list_shape(List<int32_t>* list, int32_t ndims, int32_t* shape) {
|
|
||||||
set_and_validate_list_shape(list, ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_array_set_and_validate_list_shape64(List<int64_t>* list, int64_t ndims, int64_t* shape) {
|
|
||||||
set_and_validate_list_shape(list, ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_array_write_list_to_array(List<int32_t>* list, NDArray<int32_t>* ndarray) {
|
|
||||||
write_list_to_array(list, ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_array_write_list_to_array64(List<int64_t>* list, NDArray<int64_t>* ndarray) {
|
|
||||||
write_list_to_array(list, ndarray);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,340 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray::basic {
|
|
||||||
/**
|
|
||||||
* @brief Assert that `shape` does not contain negative dimensions.
|
|
||||||
*
|
|
||||||
* @param ndims Number of dimensions in `shape`
|
|
||||||
* @param shape The shape to check on
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
|
||||||
if (shape[axis] < 0) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
|
||||||
"negative dimensions are not allowed; axis {0} "
|
|
||||||
"has dimension {1}",
|
|
||||||
axis, shape[axis], NO_PARAM);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Assert that two shapes are the same in the context of writing output to an ndarray.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void assert_output_shape_same(SizeT ndarray_ndims,
|
|
||||||
const SizeT* ndarray_shape,
|
|
||||||
SizeT output_ndims,
|
|
||||||
const SizeT* output_shape) {
|
|
||||||
if (ndarray_ndims != output_ndims) {
|
|
||||||
// There is no corresponding NumPy error message like this.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
|
|
||||||
output_ndims, ndarray_ndims, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
|
|
||||||
if (ndarray_shape[axis] != output_shape[axis]) {
|
|
||||||
// There is no corresponding NumPy error message like this.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
|
||||||
"Mismatched dimensions on axis {0}, output has "
|
|
||||||
"dimension {1}, but destination ndarray has dimension {2}.",
|
|
||||||
axis, output_shape[axis], ndarray_shape[axis]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the number of elements of an ndarray given its shape.
|
|
||||||
*
|
|
||||||
* @param ndims Number of dimensions in `shape`
|
|
||||||
* @param shape The shape of the ndarray
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
|
||||||
SizeT size = 1;
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++)
|
|
||||||
size *= shape[axis];
|
|
||||||
return size;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
|
|
||||||
*
|
|
||||||
* @param ndims Number of elements in `shape` and `indices`
|
|
||||||
* @param shape The shape of the ndarray
|
|
||||||
* @param indices The returned indices indexing the ndarray with shape `shape`.
|
|
||||||
* @param nth The index of the element of interest.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
SizeT axis = ndims - i - 1;
|
|
||||||
SizeT dim = shape[axis];
|
|
||||||
|
|
||||||
indices[axis] = nth % dim;
|
|
||||||
nth /= dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the number of elements of an `ndarray`
|
|
||||||
*
|
|
||||||
* This function corresponds to `<an_ndarray>.size`
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT size(const NDArray<SizeT>* ndarray) {
|
|
||||||
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return of the number of its content of an `ndarray`.
|
|
||||||
*
|
|
||||||
* This function corresponds to `<an_ndarray>.nbytes`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT nbytes(const NDArray<SizeT>* ndarray) {
|
|
||||||
return size(ndarray) * ndarray->itemsize;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
|
|
||||||
*
|
|
||||||
* This function corresponds to `<an_ndarray>.__len__`.
|
|
||||||
*
|
|
||||||
* @param dst_length The length.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT len(const NDArray<SizeT>* ndarray) {
|
|
||||||
if (ndarray->ndims != 0) {
|
|
||||||
return ndarray->shape[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
// numpy prohibits `__len__` on unsized objects
|
|
||||||
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
|
|
||||||
*
|
|
||||||
* You may want to see ndarray's rules for C-contiguity:
|
|
||||||
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
|
|
||||||
// References:
|
|
||||||
// - tinynumpy's implementation:
|
|
||||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
|
|
||||||
// - ndarray's flags["C_CONTIGUOUS"]:
|
|
||||||
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
|
|
||||||
// - ndarray's rules for C-contiguity:
|
|
||||||
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
|
||||||
|
|
||||||
// From
|
|
||||||
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
|
|
||||||
//
|
|
||||||
// The traditional rule is that for an array to be flagged as C contiguous,
|
|
||||||
// the following must hold:
|
|
||||||
//
|
|
||||||
// strides[-1] == itemsize
|
|
||||||
// strides[i] == shape[i+1] * strides[i + 1]
|
|
||||||
// [...]
|
|
||||||
// According to these rules, a 0- or 1-dimensional array is either both
|
|
||||||
// C- and F-contiguous, or neither; and an array with 2+ dimensions
|
|
||||||
// can be C- or F- contiguous, or neither, but not both. Though there
|
|
||||||
// there are exceptions for arrays with zero or one item, in the first
|
|
||||||
// case the check is relaxed up to and including the first dimension
|
|
||||||
// with shape[i] == 0. In the second case `strides == itemsize` will
|
|
||||||
// can be true for all dimensions and both flags are set.
|
|
||||||
|
|
||||||
if (ndarray->ndims == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (SizeT i = 1; i < ndarray->ndims; i++) {
|
|
||||||
SizeT axis_i = ndarray->ndims - i - 1;
|
|
||||||
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the pointer to the element indexed by `indices` along the ndarray's axes.
|
|
||||||
*
|
|
||||||
* This function does no bound check.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
|
|
||||||
void* element = ndarray->data;
|
|
||||||
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
|
|
||||||
element = static_cast<uint8_t*>(element) + indices[dim_i] * ndarray->strides[dim_i];
|
|
||||||
return element;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view.
|
|
||||||
*
|
|
||||||
* This function does no bound check.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
|
|
||||||
void* element = ndarray->data;
|
|
||||||
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
|
||||||
SizeT axis = ndarray->ndims - i - 1;
|
|
||||||
SizeT dim = ndarray->shape[axis];
|
|
||||||
element = static_cast<uint8_t*>(element) + ndarray->strides[axis] * (nth % dim);
|
|
||||||
nth /= dim;
|
|
||||||
}
|
|
||||||
return element;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous.
|
|
||||||
*
|
|
||||||
* You might want to read https://ajcr.net/stride-guide-part-1/.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
|
|
||||||
SizeT stride_product = 1;
|
|
||||||
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
|
||||||
SizeT axis = ndarray->ndims - i - 1;
|
|
||||||
ndarray->strides[axis] = stride_product * ndarray->itemsize;
|
|
||||||
stride_product *= ndarray->shape[axis];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set an element in `ndarray`.
|
|
||||||
*
|
|
||||||
* @param pelement Pointer to the element in `ndarray` to be set.
|
|
||||||
* @param pvalue Pointer to the value `pelement` will be set to.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void set_pelement_value(NDArray<SizeT>* ndarray, void* pelement, const void* pvalue) {
|
|
||||||
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
|
|
||||||
*
|
|
||||||
* Both ndarrays will be viewed in their flatten views when copying the elements.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
|
||||||
// TODO: Make this faster with memcpy when we see a contiguous segment.
|
|
||||||
// TODO: Handle overlapping.
|
|
||||||
|
|
||||||
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < size(src_ndarray); i++) {
|
|
||||||
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
|
|
||||||
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
|
|
||||||
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace ndarray::basic
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::basic;
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
|
|
||||||
assert_shape_no_negative(ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
|
|
||||||
assert_shape_no_negative(ndims, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
|
|
||||||
const int32_t* ndarray_shape,
|
|
||||||
int32_t output_ndims,
|
|
||||||
const int32_t* output_shape) {
|
|
||||||
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
|
|
||||||
const int64_t* ndarray_shape,
|
|
||||||
int64_t output_ndims,
|
|
||||||
const int64_t* output_shape) {
|
|
||||||
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
|
||||||
return size(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
|
||||||
return size(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
|
||||||
return nbytes(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
|
||||||
return nbytes(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
|
|
||||||
return len(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
|
|
||||||
return len(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
|
|
||||||
return is_c_contiguous(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
|
|
||||||
return is_c_contiguous(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
|
|
||||||
return get_nth_pelement(ndarray, nth);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
|
|
||||||
return get_nth_pelement(ndarray, nth);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
|
|
||||||
return get_pelement_by_indices(ndarray, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
|
|
||||||
return get_pelement_by_indices(ndarray, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
|
||||||
set_strides_by_shape(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
|
||||||
set_strides_by_shape(ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
|
||||||
copy_data(src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
|
||||||
copy_data(src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,165 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
#include "irrt/slice.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
template<typename SizeT>
|
|
||||||
struct ShapeEntry {
|
|
||||||
SizeT ndims;
|
|
||||||
SizeT* shape;
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray::broadcast {
|
|
||||||
/**
|
|
||||||
* @brief Return true if `src_shape` can broadcast to `dst_shape`.
|
|
||||||
*
|
|
||||||
* See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) {
|
|
||||||
if (src_ndims > target_ndims) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < src_ndims; i++) {
|
|
||||||
SizeT target_dim = target_shape[target_ndims - i - 1];
|
|
||||||
SizeT src_dim = src_shape[src_ndims - i - 1];
|
|
||||||
if (!(src_dim == 1 || target_dim == src_dim)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Performs `np.broadcast_shapes(<shapes>)`
|
|
||||||
*
|
|
||||||
* @param num_shapes Number of entries in `shapes`
|
|
||||||
* @param shapes The list of shape to do `np.broadcast_shapes` on.
|
|
||||||
* @param dst_ndims The length of `dst_shape`.
|
|
||||||
* `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it.
|
|
||||||
* for this function since they should already know in order to allocate `dst_shape` in the first place.
|
|
||||||
* @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result
|
|
||||||
* of `np.broadcast_shapes` and write it here.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT>* shapes, SizeT dst_ndims, SizeT* dst_shape) {
|
|
||||||
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
|
|
||||||
dst_shape[dst_axis] = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef IRRT_DEBUG_ASSERT
|
|
||||||
SizeT max_ndims_found = 0;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < num_shapes; i++) {
|
|
||||||
ShapeEntry<SizeT> entry = shapes[i];
|
|
||||||
|
|
||||||
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
|
|
||||||
debug_assert(SizeT, entry.ndims <= dst_ndims);
|
|
||||||
|
|
||||||
#ifdef IRRT_DEBUG_ASSERT
|
|
||||||
max_ndims_found = max(max_ndims_found, entry.ndims);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (SizeT j = 0; j < entry.ndims; j++) {
|
|
||||||
SizeT entry_axis = entry.ndims - j - 1;
|
|
||||||
SizeT dst_axis = dst_ndims - j - 1;
|
|
||||||
|
|
||||||
SizeT entry_dim = entry.shape[entry_axis];
|
|
||||||
SizeT dst_dim = dst_shape[dst_axis];
|
|
||||||
|
|
||||||
if (dst_dim == 1) {
|
|
||||||
dst_shape[dst_axis] = entry_dim;
|
|
||||||
} else if (entry_dim == 1 || entry_dim == dst_dim) {
|
|
||||||
// Do nothing
|
|
||||||
} else {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
|
||||||
"shape mismatch: objects cannot be broadcast "
|
|
||||||
"to a single shape.",
|
|
||||||
NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef IRRT_DEBUG_ASSERT
|
|
||||||
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
|
|
||||||
debug_assert_eq(SizeT, max_ndims_found, dst_ndims);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Perform `np.broadcast_to(<ndarray>, <target_shape>)` and appropriate assertions.
|
|
||||||
*
|
|
||||||
* This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
|
|
||||||
* and return the result by modifying `dst_ndarray`.
|
|
||||||
*
|
|
||||||
* # Notes on `dst_ndarray`
|
|
||||||
* The caller is responsible for allocating space for the resulting ndarray.
|
|
||||||
* Here is what this function expects from `dst_ndarray` when called:
|
|
||||||
* - `dst_ndarray->data` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape`
|
|
||||||
* - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape.
|
|
||||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
|
||||||
* When this function call ends:
|
|
||||||
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
|
||||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
|
||||||
* - `dst_ndarray->ndims` is unchanged.
|
|
||||||
* - `dst_ndarray->shape` is unchanged.
|
|
||||||
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void broadcast_to(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
|
||||||
if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
|
|
||||||
src_ndarray->shape)) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM,
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
dst_ndarray->data = src_ndarray->data;
|
|
||||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < dst_ndarray->ndims; i++) {
|
|
||||||
SizeT src_axis = src_ndarray->ndims - i - 1;
|
|
||||||
SizeT dst_axis = dst_ndarray->ndims - i - 1;
|
|
||||||
if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) {
|
|
||||||
// Freeze the steps in-place
|
|
||||||
dst_ndarray->strides[dst_axis] = 0;
|
|
||||||
} else {
|
|
||||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace ndarray::broadcast
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::broadcast;
|
|
||||||
|
|
||||||
void __nac3_ndarray_broadcast_to(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
|
||||||
broadcast_to(src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_broadcast_to64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
|
||||||
broadcast_to(src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_broadcast_shapes(int32_t num_shapes,
|
|
||||||
const ShapeEntry<int32_t>* shapes,
|
|
||||||
int32_t dst_ndims,
|
|
||||||
int32_t* dst_shape) {
|
|
||||||
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes,
|
|
||||||
const ShapeEntry<int64_t>* shapes,
|
|
||||||
int64_t dst_ndims,
|
|
||||||
int64_t* dst_shape) {
|
|
||||||
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,51 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/**
|
|
||||||
* @brief The NDArray object
|
|
||||||
*
|
|
||||||
* Official numpy implementation:
|
|
||||||
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
|
|
||||||
*
|
|
||||||
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
|
|
||||||
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
|
|
||||||
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
|
|
||||||
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
|
|
||||||
* `data`. There are also minor differences in the struct layout.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
struct NDArray {
|
|
||||||
/**
|
|
||||||
* @brief The number of bytes of a single element in `data`.
|
|
||||||
*/
|
|
||||||
SizeT itemsize;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The number of dimensions of this shape.
|
|
||||||
*/
|
|
||||||
SizeT ndims;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The NDArray shape, with length equal to `ndims`.
|
|
||||||
*
|
|
||||||
* Note that it may contain 0.
|
|
||||||
*/
|
|
||||||
SizeT* shape;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Array strides, with length equal to `ndims`
|
|
||||||
*
|
|
||||||
* The stride values are in units of bytes, not number of elements.
|
|
||||||
*
|
|
||||||
* Note that `strides` can have negative values or contain 0.
|
|
||||||
*/
|
|
||||||
SizeT* strides;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The underlying data this `ndarray` is pointing to.
|
|
||||||
*/
|
|
||||||
void* data;
|
|
||||||
};
|
|
||||||
} // namespace
|
|
@ -1,219 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/basic.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
#include "irrt/range.hpp"
|
|
||||||
#include "irrt/slice.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
typedef uint8_t NDIndexType;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A single element index
|
|
||||||
*
|
|
||||||
* `data` points to a `int32_t`.
|
|
||||||
*/
|
|
||||||
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A slice index
|
|
||||||
*
|
|
||||||
* `data` points to a `Slice<int32_t>`.
|
|
||||||
*/
|
|
||||||
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief `np.newaxis` / `None`
|
|
||||||
*
|
|
||||||
* `data` is unused.
|
|
||||||
*/
|
|
||||||
const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief `Ellipsis` / `...`
|
|
||||||
*
|
|
||||||
* `data` is unused.
|
|
||||||
*/
|
|
||||||
const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief An index used in ndarray indexing
|
|
||||||
*
|
|
||||||
* That is:
|
|
||||||
* ```
|
|
||||||
* my_ndarray[::-1, 3, ..., np.newaxis]
|
|
||||||
* ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex.
|
|
||||||
* ```
|
|
||||||
*/
|
|
||||||
struct NDIndex {
|
|
||||||
/**
|
|
||||||
* @brief Enum tag to specify the type of index.
|
|
||||||
*
|
|
||||||
* Please see the comment of each enum constant.
|
|
||||||
*/
|
|
||||||
NDIndexType type;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The accompanying data associated with `type`.
|
|
||||||
*
|
|
||||||
* Please see the comment of each enum constant.
|
|
||||||
*/
|
|
||||||
uint8_t* data;
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray::indexing {
|
|
||||||
/**
|
|
||||||
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
|
|
||||||
*
|
|
||||||
* This function is very similar to performing `dst_ndarray = src_ndarray[indices]` in Python.
|
|
||||||
*
|
|
||||||
* This function also does proper assertions on `indices` to check for out of bounds access and more.
|
|
||||||
*
|
|
||||||
* # Notes on `dst_ndarray`
|
|
||||||
* The caller is responsible for allocating space for the resulting ndarray.
|
|
||||||
* Here is what this function expects from `dst_ndarray` when called:
|
|
||||||
* - `dst_ndarray->data` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after
|
|
||||||
* indexing `src_ndarray` with `indices`.
|
|
||||||
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
|
|
||||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
|
||||||
* When this function call ends:
|
|
||||||
* - `dst_ndarray->data` is set to `src_ndarray->data`.
|
|
||||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`.
|
|
||||||
* - `dst_ndarray->ndims` is unchanged.
|
|
||||||
* - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed.
|
|
||||||
* - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works.
|
|
||||||
*
|
|
||||||
* @param indices indices to index `src_ndarray`, ordered in the same way you would write them in Python.
|
|
||||||
* @param src_ndarray The NDArray to be indexed.
|
|
||||||
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void index(SizeT num_indices, const NDIndex* indices, const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
|
||||||
// Validate `indices`.
|
|
||||||
|
|
||||||
// Expected value of `dst_ndarray->ndims`.
|
|
||||||
SizeT expected_dst_ndims = src_ndarray->ndims;
|
|
||||||
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
|
|
||||||
SizeT num_indexed = 0;
|
|
||||||
// There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis.
|
|
||||||
SizeT num_ellipsis = 0;
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < num_indices; i++) {
|
|
||||||
if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
|
|
||||||
expected_dst_ndims--;
|
|
||||||
num_indexed++;
|
|
||||||
} else if (indices[i].type == ND_INDEX_TYPE_SLICE) {
|
|
||||||
num_indexed++;
|
|
||||||
} else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS) {
|
|
||||||
expected_dst_ndims++;
|
|
||||||
} else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS) {
|
|
||||||
num_ellipsis++;
|
|
||||||
if (num_ellipsis > 1) {
|
|
||||||
raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM,
|
|
||||||
NO_PARAM, NO_PARAM);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims);
|
|
||||||
|
|
||||||
if (src_ndarray->ndims - num_indexed < 0) {
|
|
||||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
|
||||||
"too many indices for array: array is {0}-dimensional, "
|
|
||||||
"but {1} were indexed",
|
|
||||||
src_ndarray->ndims, num_indices, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
dst_ndarray->data = src_ndarray->data;
|
|
||||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
|
||||||
|
|
||||||
// Reference code:
|
|
||||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
|
||||||
SizeT src_axis = 0;
|
|
||||||
SizeT dst_axis = 0;
|
|
||||||
|
|
||||||
for (int32_t i = 0; i < num_indices; i++) {
|
|
||||||
const NDIndex* index = &indices[i];
|
|
||||||
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
|
|
||||||
SizeT input = (SizeT) * ((int32_t*)index->data);
|
|
||||||
|
|
||||||
SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
|
|
||||||
if (k == -1) {
|
|
||||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
|
||||||
"index {0} is out of bounds for axis {1} "
|
|
||||||
"with size {2}",
|
|
||||||
input, src_axis, src_ndarray->shape[src_axis]);
|
|
||||||
}
|
|
||||||
|
|
||||||
dst_ndarray->data = static_cast<uint8_t*>(dst_ndarray->data) + k * src_ndarray->strides[src_axis];
|
|
||||||
|
|
||||||
src_axis++;
|
|
||||||
} else if (index->type == ND_INDEX_TYPE_SLICE) {
|
|
||||||
Slice<int32_t>* slice = (Slice<int32_t>*)index->data;
|
|
||||||
|
|
||||||
Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
|
|
||||||
|
|
||||||
dst_ndarray->data =
|
|
||||||
static_cast<uint8_t*>(dst_ndarray->data) + (SizeT)range.start * src_ndarray->strides[src_axis];
|
|
||||||
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];
|
|
||||||
dst_ndarray->shape[dst_axis] = (SizeT)range.len<SizeT>();
|
|
||||||
|
|
||||||
dst_axis++;
|
|
||||||
src_axis++;
|
|
||||||
} else if (index->type == ND_INDEX_TYPE_NEWAXIS) {
|
|
||||||
dst_ndarray->strides[dst_axis] = 0;
|
|
||||||
dst_ndarray->shape[dst_axis] = 1;
|
|
||||||
|
|
||||||
dst_axis++;
|
|
||||||
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
|
|
||||||
// The number of ':' entries this '...' implies.
|
|
||||||
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
|
|
||||||
|
|
||||||
for (SizeT j = 0; j < ellipsis_size; j++) {
|
|
||||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
|
||||||
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
|
|
||||||
|
|
||||||
dst_axis++;
|
|
||||||
src_axis++;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
|
|
||||||
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
|
|
||||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
|
||||||
}
|
|
||||||
|
|
||||||
debug_assert_eq(SizeT, src_ndarray->ndims, src_axis);
|
|
||||||
debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis);
|
|
||||||
}
|
|
||||||
} // namespace ndarray::indexing
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::indexing;
|
|
||||||
|
|
||||||
void __nac3_ndarray_index(int32_t num_indices,
|
|
||||||
NDIndex* indices,
|
|
||||||
NDArray<int32_t>* src_ndarray,
|
|
||||||
NDArray<int32_t>* dst_ndarray) {
|
|
||||||
index(num_indices, indices, src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_index64(int64_t num_indices,
|
|
||||||
NDIndex* indices,
|
|
||||||
NDArray<int64_t>* src_ndarray,
|
|
||||||
NDArray<int64_t>* dst_ndarray) {
|
|
||||||
index(num_indices, indices, src_ndarray, dst_ndarray);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,146 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/**
|
|
||||||
* @brief Helper struct to enumerate through an ndarray *efficiently*.
|
|
||||||
*
|
|
||||||
* Example usage (in pseudo-code):
|
|
||||||
* ```
|
|
||||||
* // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double`
|
|
||||||
* NDIter nditer;
|
|
||||||
* nditer.initialize(my_ndarray);
|
|
||||||
* while (nditer.has_element()) {
|
|
||||||
* // This body is run 6 (= my_ndarray.size) times.
|
|
||||||
*
|
|
||||||
* // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end
|
|
||||||
* print(nditer.indices);
|
|
||||||
*
|
|
||||||
* // 0 -> 1 -> 2 -> 3 -> 4 -> 5
|
|
||||||
* print(nditer.nth);
|
|
||||||
*
|
|
||||||
* // <1st element> -> <2nd element> -> ... -> <6th element> -> end
|
|
||||||
* print(*((double *) nditer.element))
|
|
||||||
*
|
|
||||||
* nditer.next(); // Go to next element.
|
|
||||||
* }
|
|
||||||
* ```
|
|
||||||
*
|
|
||||||
* Interesting cases:
|
|
||||||
* - If `my_ndarray.ndims` == 0, there is one iteration.
|
|
||||||
* - If `my_ndarray.shape` contains zeroes, there are no iterations.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
struct NDIter {
|
|
||||||
// Information about the ndarray being iterated over.
|
|
||||||
SizeT ndims;
|
|
||||||
SizeT* shape;
|
|
||||||
SizeT* strides;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The current indices.
|
|
||||||
*
|
|
||||||
* Must be allocated by the caller.
|
|
||||||
*/
|
|
||||||
SizeT* indices;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief The nth (0-based) index of the current indices.
|
|
||||||
*
|
|
||||||
* Initially this is 0.
|
|
||||||
*/
|
|
||||||
SizeT nth;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Pointer to the current element.
|
|
||||||
*
|
|
||||||
* Initially this points to first element of the ndarray.
|
|
||||||
*/
|
|
||||||
void* element;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Cache for the product of shape.
|
|
||||||
*
|
|
||||||
* Could be 0 if `shape` has 0s in it.
|
|
||||||
*/
|
|
||||||
SizeT size;
|
|
||||||
|
|
||||||
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, void* element, SizeT* indices) {
|
|
||||||
this->ndims = ndims;
|
|
||||||
this->shape = shape;
|
|
||||||
this->strides = strides;
|
|
||||||
|
|
||||||
this->indices = indices;
|
|
||||||
this->element = element;
|
|
||||||
|
|
||||||
// Compute size
|
|
||||||
this->size = 1;
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
this->size *= shape[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
// `indices` starts on all 0s.
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++)
|
|
||||||
indices[axis] = 0;
|
|
||||||
nth = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
|
|
||||||
// NOTE: ndarray->data is pointing to the first element, and `NDIter`'s `element` should also point to the first
|
|
||||||
// element as well.
|
|
||||||
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is the current iteration valid?
|
|
||||||
// If true, then `element`, `indices` and `nth` contain details about the current element.
|
|
||||||
bool has_element() { return nth < size; }
|
|
||||||
|
|
||||||
// Go to the next element.
|
|
||||||
void next() {
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
SizeT axis = ndims - i - 1;
|
|
||||||
indices[axis]++;
|
|
||||||
if (indices[axis] >= shape[axis]) {
|
|
||||||
indices[axis] = 0;
|
|
||||||
|
|
||||||
// TODO: There is something called backstrides to speedup iteration.
|
|
||||||
// See https://ajcr.net/stride-guide-part-1/, and
|
|
||||||
// https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
|
|
||||||
element = static_cast<void*>(reinterpret_cast<uint8_t*>(element) - strides[axis] * (shape[axis] - 1));
|
|
||||||
} else {
|
|
||||||
element = static_cast<void*>(reinterpret_cast<uint8_t*>(element) + strides[axis]);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nth++;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray, int32_t* indices) {
|
|
||||||
iter->initialize_by_ndarray(ndarray, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_nditer_initialize64(NDIter<int64_t>* iter, NDArray<int64_t>* ndarray, int64_t* indices) {
|
|
||||||
iter->initialize_by_ndarray(ndarray, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __nac3_nditer_has_element(NDIter<int32_t>* iter) {
|
|
||||||
return iter->has_element();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool __nac3_nditer_has_element64(NDIter<int64_t>* iter) {
|
|
||||||
return iter->has_element();
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_nditer_next(NDIter<int32_t>* iter) {
|
|
||||||
iter->next();
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_nditer_next64(NDIter<int64_t>* iter) {
|
|
||||||
iter->next();
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,98 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/basic.hpp"
|
|
||||||
#include "irrt/ndarray/broadcast.hpp"
|
|
||||||
#include "irrt/ndarray/iter.hpp"
|
|
||||||
|
|
||||||
// NOTE: Everything would be much easier and elegant if einsum is implemented.
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray::matmul {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`.
|
|
||||||
*
|
|
||||||
* Example:
|
|
||||||
* Suppose `a_shape == [1, 97, 4, 2]`
|
|
||||||
* and `b_shape == [99, 98, 1, 2, 5]`,
|
|
||||||
*
|
|
||||||
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
|
|
||||||
* `new_b_shape == [99, 98, 97, 2, 5]`,
|
|
||||||
* and `dst_shape == [99, 98, 97, 4, 5]`.
|
|
||||||
* ^^^^^^^^^^ ^^^^
|
|
||||||
* (broadcasted) (4x2 @ 2x5 => 4x5)
|
|
||||||
*
|
|
||||||
* @param a_ndims Length of `a_shape`.
|
|
||||||
* @param a_shape Shape of `a`.
|
|
||||||
* @param b_ndims Length of `b_shape`.
|
|
||||||
* @param b_shape Shape of `b`.
|
|
||||||
* @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`,
|
|
||||||
* `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void calculate_shapes(SizeT a_ndims,
|
|
||||||
SizeT* a_shape,
|
|
||||||
SizeT b_ndims,
|
|
||||||
SizeT* b_shape,
|
|
||||||
SizeT final_ndims,
|
|
||||||
SizeT* new_a_shape,
|
|
||||||
SizeT* new_b_shape,
|
|
||||||
SizeT* dst_shape) {
|
|
||||||
debug_assert(SizeT, a_ndims >= 2);
|
|
||||||
debug_assert(SizeT, b_ndims >= 2);
|
|
||||||
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
|
|
||||||
|
|
||||||
// Check that a and b are compatible for matmul
|
|
||||||
if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) {
|
|
||||||
// This is a custom error message. Different from NumPy.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
|
|
||||||
a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
const SizeT num_entries = 2;
|
|
||||||
ShapeEntry<SizeT> entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape},
|
|
||||||
{.ndims = b_ndims - 2, .shape = b_shape}};
|
|
||||||
|
|
||||||
// TODO: Optimize this
|
|
||||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_a_shape);
|
|
||||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_b_shape);
|
|
||||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, dst_shape);
|
|
||||||
|
|
||||||
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
|
||||||
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
|
|
||||||
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
|
|
||||||
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
|
||||||
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
|
||||||
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
|
||||||
}
|
|
||||||
} // namespace ndarray::matmul
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::matmul;
|
|
||||||
|
|
||||||
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims,
|
|
||||||
int32_t* a_shape,
|
|
||||||
int32_t b_ndims,
|
|
||||||
int32_t* b_shape,
|
|
||||||
int32_t final_ndims,
|
|
||||||
int32_t* new_a_shape,
|
|
||||||
int32_t* new_b_shape,
|
|
||||||
int32_t* dst_shape) {
|
|
||||||
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims,
|
|
||||||
int64_t* a_shape,
|
|
||||||
int64_t b_ndims,
|
|
||||||
int64_t* b_shape,
|
|
||||||
int64_t final_ndims,
|
|
||||||
int64_t* new_a_shape,
|
|
||||||
int64_t* new_b_shape,
|
|
||||||
int64_t* dst_shape) {
|
|
||||||
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,97 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray::reshape {
|
|
||||||
/**
|
|
||||||
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
|
|
||||||
*
|
|
||||||
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
|
|
||||||
* modified to contain the resolved dimension.
|
|
||||||
*
|
|
||||||
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
|
|
||||||
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
|
|
||||||
*
|
|
||||||
* @param size The `.size` of `<ndarray>`
|
|
||||||
* @param new_ndims Number of elements in `new_shape`
|
|
||||||
* @param new_shape Target shape to reshape to
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) {
|
|
||||||
// Is there a -1 in `new_shape`?
|
|
||||||
bool neg1_exists = false;
|
|
||||||
// Location of -1, only initialized if `neg1_exists` is true
|
|
||||||
SizeT neg1_axis_i;
|
|
||||||
// The computed ndarray size of `new_shape`
|
|
||||||
SizeT new_size = 1;
|
|
||||||
|
|
||||||
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
|
|
||||||
SizeT dim = new_shape[axis_i];
|
|
||||||
if (dim < 0) {
|
|
||||||
if (dim == -1) {
|
|
||||||
if (neg1_exists) {
|
|
||||||
// Multiple `-1` found. Throw an error.
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM,
|
|
||||||
NO_PARAM, NO_PARAM);
|
|
||||||
} else {
|
|
||||||
neg1_exists = true;
|
|
||||||
neg1_axis_i = axis_i;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// TODO: What? In `np.reshape` any negative dimensions is
|
|
||||||
// treated like its `-1`.
|
|
||||||
//
|
|
||||||
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
|
|
||||||
//
|
|
||||||
// It is not documented by numpy.
|
|
||||||
// Throw an error for now...
|
|
||||||
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i,
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
new_size *= dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool can_reshape;
|
|
||||||
if (neg1_exists) {
|
|
||||||
// Let `x` be the unknown dimension
|
|
||||||
// Solve `x * <new_size> = <size>`
|
|
||||||
if (new_size == 0 && size == 0) {
|
|
||||||
// `x` has infinitely many solutions
|
|
||||||
can_reshape = false;
|
|
||||||
} else if (new_size == 0 && size != 0) {
|
|
||||||
// `x` has no solutions
|
|
||||||
can_reshape = false;
|
|
||||||
} else if (size % new_size != 0) {
|
|
||||||
// `x` has no integer solutions
|
|
||||||
can_reshape = false;
|
|
||||||
} else {
|
|
||||||
can_reshape = true;
|
|
||||||
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
can_reshape = (new_size == size);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!can_reshape) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM,
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace ndarray::reshape
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) {
|
|
||||||
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) {
|
|
||||||
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,143 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/ndarray/def.hpp"
|
|
||||||
#include "irrt/slice.hpp"
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Notes on `np.transpose(<array>, <axes>)`
|
|
||||||
*
|
|
||||||
* TODO: `axes`, if specified, can actually contain negative indices,
|
|
||||||
* but it is not documented in numpy.
|
|
||||||
*
|
|
||||||
* Supporting it for now.
|
|
||||||
*/
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace ndarray::transpose {
|
|
||||||
/**
|
|
||||||
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
|
|
||||||
*
|
|
||||||
* Note that `np.transpose`'s `<axe>` argument is optional. If the argument
|
|
||||||
* is specified but the user, use this function to do assertions on it.
|
|
||||||
*
|
|
||||||
* @param ndims The number of dimensions of `<array>`
|
|
||||||
* @param num_axes Number of elements in `<axes>` as specified by the user.
|
|
||||||
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
|
|
||||||
* @param axes The user specified `<axes>`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) {
|
|
||||||
if (ndims != num_axes) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Optimize this
|
|
||||||
bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
|
|
||||||
for (SizeT i = 0; i < ndims; i++)
|
|
||||||
axe_specified[i] = false;
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < ndims; i++) {
|
|
||||||
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
|
|
||||||
if (axis == -1) {
|
|
||||||
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (axe_specified[axis]) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
axe_specified[axis] = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Create a transpose view of `src_ndarray` and perform proper assertions.
|
|
||||||
*
|
|
||||||
* This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
|
|
||||||
* If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
|
|
||||||
*
|
|
||||||
* The transpose view created is returned by modifying `dst_ndarray`.
|
|
||||||
*
|
|
||||||
* The caller is responsible for setting up `dst_ndarray` before calling this function.
|
|
||||||
* Here is what this function expects from `dst_ndarray` when called:
|
|
||||||
* - `dst_ndarray->data` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
|
||||||
* - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
|
|
||||||
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
|
|
||||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
|
||||||
* When this function call ends:
|
|
||||||
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
|
||||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
|
||||||
* - `dst_ndarray->ndims` is unchanged
|
|
||||||
* - `dst_ndarray->shape` is updated according to how `np.transpose` works
|
|
||||||
* - `dst_ndarray->strides` is updated according to how `np.transpose` works
|
|
||||||
*
|
|
||||||
* @param src_ndarray The NDArray to build a transpose view on
|
|
||||||
* @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above,
|
|
||||||
* @param num_axes Number of elements in axes. Unused if `axes` is nullptr.
|
|
||||||
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
void transpose(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
|
|
||||||
debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims);
|
|
||||||
const auto ndims = src_ndarray->ndims;
|
|
||||||
|
|
||||||
if (axes != nullptr)
|
|
||||||
assert_transpose_axes(ndims, num_axes, axes);
|
|
||||||
|
|
||||||
dst_ndarray->data = src_ndarray->data;
|
|
||||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
|
||||||
|
|
||||||
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
|
|
||||||
if (axes == nullptr) {
|
|
||||||
// `np.transpose(<array>, axes=None)`
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Minor note: `np.transpose(<array>, axes=None)` is equivalent to
|
|
||||||
* `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
|
|
||||||
* is reversing the order of strides and shape.
|
|
||||||
*
|
|
||||||
* This is a fast implementation to handle this special (but very common) case.
|
|
||||||
*/
|
|
||||||
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
|
||||||
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
|
|
||||||
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// `np.transpose(<array>, <axes>)`
|
|
||||||
|
|
||||||
// Permute strides and shape according to `axes`, while resolving negative indices in `axes`
|
|
||||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
|
||||||
// `i` cannot be OUT_OF_BOUNDS because of assertions
|
|
||||||
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
|
|
||||||
|
|
||||||
dst_ndarray->shape[axis] = src_ndarray->shape[i];
|
|
||||||
dst_ndarray->strides[axis] = src_ndarray->strides[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace ndarray::transpose
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace ndarray::transpose;
|
|
||||||
void __nac3_ndarray_transpose(const NDArray<int32_t>* src_ndarray,
|
|
||||||
NDArray<int32_t>* dst_ndarray,
|
|
||||||
int32_t num_axes,
|
|
||||||
const int32_t* axes) {
|
|
||||||
transpose(src_ndarray, dst_ndarray, num_axes, axes);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_transpose64(const NDArray<int64_t>* src_ndarray,
|
|
||||||
NDArray<int64_t>* dst_ndarray,
|
|
||||||
int64_t num_axes,
|
|
||||||
const int64_t* axes) {
|
|
||||||
transpose(src_ndarray, dst_ndarray, num_axes, axes);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,47 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace range {
|
|
||||||
template<typename T>
|
|
||||||
T len(T start, T stop, T step) {
|
|
||||||
// Reference:
|
|
||||||
// https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
|
|
||||||
if (step > 0 && start < stop)
|
|
||||||
return 1 + (stop - 1 - start) / step;
|
|
||||||
else if (step < 0 && start > stop)
|
|
||||||
return 1 + (start - 1 - stop) / (-step);
|
|
||||||
else
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
} // namespace range
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A Python range.
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
struct Range {
|
|
||||||
T start;
|
|
||||||
T stop;
|
|
||||||
T step;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Calculate the `len()` of this range.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
T len() {
|
|
||||||
debug_assert(SizeT, step != 0);
|
|
||||||
return range::len(start, stop, step);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
using namespace range;
|
|
||||||
|
|
||||||
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
|
|
||||||
return len(start, end, step);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,156 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/debug.hpp"
|
|
||||||
#include "irrt/exception.hpp"
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
#include "irrt/math_util.hpp"
|
|
||||||
#include "irrt/range.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
namespace slice {
|
|
||||||
/**
|
|
||||||
* @brief Resolve a possibly negative index in a list of a known length.
|
|
||||||
*
|
|
||||||
* Returns -1 if the resolved index is out of the list's bounds.
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
T resolve_index_in_length(T length, T index) {
|
|
||||||
T resolved = index < 0 ? length + index : index;
|
|
||||||
if (0 <= resolved && resolved < length) {
|
|
||||||
return resolved;
|
|
||||||
} else {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Resolve a slice as a range.
|
|
||||||
*
|
|
||||||
* This is equivalent to `range(*slice(start, stop, step).indices(length))` in Python.
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
void indices(bool start_defined,
|
|
||||||
T start,
|
|
||||||
bool stop_defined,
|
|
||||||
T stop,
|
|
||||||
bool step_defined,
|
|
||||||
T step,
|
|
||||||
T length,
|
|
||||||
T* range_start,
|
|
||||||
T* range_stop,
|
|
||||||
T* range_step) {
|
|
||||||
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
|
|
||||||
*range_step = step_defined ? step : 1;
|
|
||||||
bool step_is_negative = *range_step < 0;
|
|
||||||
|
|
||||||
T lower, upper;
|
|
||||||
if (step_is_negative) {
|
|
||||||
lower = -1;
|
|
||||||
upper = length - 1;
|
|
||||||
} else {
|
|
||||||
lower = 0;
|
|
||||||
upper = length;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (start_defined) {
|
|
||||||
*range_start = start < 0 ? max(lower, start + length) : min(upper, start);
|
|
||||||
} else {
|
|
||||||
*range_start = step_is_negative ? upper : lower;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (stop_defined) {
|
|
||||||
*range_stop = stop < 0 ? max(lower, stop + length) : min(upper, stop);
|
|
||||||
} else {
|
|
||||||
*range_stop = step_is_negative ? lower : upper;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace slice
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A Python-like slice with **unresolved** indices.
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
struct Slice {
|
|
||||||
bool start_defined;
|
|
||||||
T start;
|
|
||||||
|
|
||||||
bool stop_defined;
|
|
||||||
T stop;
|
|
||||||
|
|
||||||
bool step_defined;
|
|
||||||
T step;
|
|
||||||
|
|
||||||
Slice() { this->reset(); }
|
|
||||||
|
|
||||||
void reset() {
|
|
||||||
this->start_defined = false;
|
|
||||||
this->stop_defined = false;
|
|
||||||
this->step_defined = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_start(T start) {
|
|
||||||
this->start_defined = true;
|
|
||||||
this->start = start;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_stop(T stop) {
|
|
||||||
this->stop_defined = true;
|
|
||||||
this->stop = stop;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_step(T step) {
|
|
||||||
this->step_defined = true;
|
|
||||||
this->step = step;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Resolve this slice as a range.
|
|
||||||
*
|
|
||||||
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
Range<T> indices(T length) {
|
|
||||||
// Reference:
|
|
||||||
// https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
|
|
||||||
debug_assert(SizeT, length >= 0);
|
|
||||||
|
|
||||||
Range<T> result;
|
|
||||||
slice::indices(start_defined, start, stop_defined, stop, step_defined, step, length, &result.start,
|
|
||||||
&result.stop, &result.step);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Like `.indices()` but with assertions.
|
|
||||||
*/
|
|
||||||
template<typename SizeT>
|
|
||||||
Range<T> indices_checked(T length) {
|
|
||||||
// TODO: Switch to `SizeT length`
|
|
||||||
|
|
||||||
if (length < 0) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM,
|
|
||||||
NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this->step_defined && this->step == 0) {
|
|
||||||
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM);
|
|
||||||
}
|
|
||||||
|
|
||||||
return this->indices<SizeT>(length);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
|
||||||
if (i < 0) {
|
|
||||||
i = len + i;
|
|
||||||
}
|
|
||||||
if (i < 0) {
|
|
||||||
return 0;
|
|
||||||
} else if (i > len) {
|
|
||||||
return len;
|
|
||||||
}
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,23 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
template<typename SizeT>
|
|
||||||
bool __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) {
|
|
||||||
if (len1 != len2) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return __builtin_memcmp(str1, str2, static_cast<SizeT>(len1)) == 0;
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
bool nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) {
|
|
||||||
return __nac3_str_eq_impl<uint32_t>(str1, len1, str2, len2);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) {
|
|
||||||
return __nac3_str_eq_impl<uint64_t>(str1, len1, str2, len2);
|
|
||||||
}
|
|
||||||
}
|
|
77
nac3core/irrt/irrt/utils.hpp
Normal file
77
nac3core/irrt/irrt/utils.hpp
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
const T& max(const T& a, const T& b) {
|
||||||
|
return a > b ? a : b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
const T& min(const T& a, const T& b) {
|
||||||
|
return a > b ? b : a;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool arrays_match(int len, T* as, T* bs) {
|
||||||
|
for (int i = 0; i < len; i++) {
|
||||||
|
if (as[i] != bs[i])
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace cstr_utils {
|
||||||
|
bool is_empty(const char* str) {
|
||||||
|
return str[0] == '\0';
|
||||||
|
}
|
||||||
|
|
||||||
|
int8_t compare(const char* a, const char* b) {
|
||||||
|
uint32_t i = 0;
|
||||||
|
while (true) {
|
||||||
|
if (a[i] < b[i]) {
|
||||||
|
return -1;
|
||||||
|
} else if (a[i] > b[i]) {
|
||||||
|
return 1;
|
||||||
|
} else { // a[i] == b[i]
|
||||||
|
if (a[i] == '\0') {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int8_t equal(const char* a, const char* b) {
|
||||||
|
return compare(a, b) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t length(const char* str) {
|
||||||
|
uint32_t length = 0;
|
||||||
|
while (*str != '\0') {
|
||||||
|
length++;
|
||||||
|
str++;
|
||||||
|
}
|
||||||
|
return length;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool copy(const char* src, char* dst, uint32_t dst_max_size) {
|
||||||
|
for (uint32_t i = 0; i < dst_max_size; i++) {
|
||||||
|
bool is_last = i + 1 == dst_max_size;
|
||||||
|
if (is_last && src[i] != '\0') {
|
||||||
|
dst[i] = '\0';
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (src[i] == '\0') {
|
||||||
|
dst[i] = '\0';
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[i] = src[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
} // namespace cstr_utils
|
||||||
|
} // namespace
|
6
nac3core/irrt/irrt_everything.hpp
Normal file
6
nac3core/irrt/irrt_everything.hpp
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/core.hpp>
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/utils.hpp>
|
||||||
|
#include <irrt/error_context.hpp>
|
13
nac3core/irrt/irrt_test.cpp
Normal file
13
nac3core/irrt/irrt_test.cpp
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
// This file will be compiled like a real C++ program,
|
||||||
|
// and we do have the luxury to use the standard libraries.
|
||||||
|
// That is if the nix flakes do not have issues... especially on msys2...
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include <test/test_core.hpp>
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
test::core::run();
|
||||||
|
return 0;
|
||||||
|
}
|
12
nac3core/irrt/test/includes.hpp
Normal file
12
nac3core/irrt/test/includes.hpp
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include <test/util.hpp>
|
||||||
|
#include <irrt_everything.hpp>
|
||||||
|
|
||||||
|
/*
|
||||||
|
Include this header for every test_*.cpp
|
||||||
|
*/
|
16
nac3core/irrt/test/test_core.hpp
Normal file
16
nac3core/irrt/test/test_core.hpp
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <test/includes.hpp>
|
||||||
|
|
||||||
|
namespace test { namespace core {
|
||||||
|
void test_int_exp() {
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
assert_values_match(125, __nac3_int_exp_impl<int32_t>(5, 3));
|
||||||
|
assert_values_match(3125, __nac3_int_exp_impl<int32_t>(5, 5));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run() {
|
||||||
|
test_int_exp();
|
||||||
|
}
|
||||||
|
}} // namespace test::core
|
114
nac3core/irrt/test/util.hpp
Normal file
114
nac3core/irrt/test/util.hpp
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void print_value(const T& value) {}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(const int8_t& value) {
|
||||||
|
printf("%d", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(const int32_t& value) {
|
||||||
|
printf("%d", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(const uint8_t& value) {
|
||||||
|
printf("%u", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(const uint32_t& value) {
|
||||||
|
printf("%u", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(const float& value) {
|
||||||
|
printf("%f", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void print_value(const double& value) {
|
||||||
|
printf("%f", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __begin_test(const char* function_name, const char* file, int line) {
|
||||||
|
printf("######### Running %s @ %s:%d\n", function_name, file, line);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
|
||||||
|
|
||||||
|
void test_fail() {
|
||||||
|
printf("[!] Test failed. Exiting with status code 1.\n");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void debug_print_array(int len, const T* as) {
|
||||||
|
printf("[");
|
||||||
|
for (int i = 0; i < len; i++) {
|
||||||
|
if (i != 0)
|
||||||
|
printf(", ");
|
||||||
|
print_value(as[i]);
|
||||||
|
}
|
||||||
|
printf("]");
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_assertion_passed(const char* file, int line) {
|
||||||
|
printf("[*] Assertion passed on %s:%d\n", file, line);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_assertion_failed(const char* file, int line) {
|
||||||
|
printf("[!] Assertion failed on %s:%d\n", file, line);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __assert_true(const char* file, int line, bool cond) {
|
||||||
|
if (cond) {
|
||||||
|
print_assertion_passed(file, line);
|
||||||
|
} else {
|
||||||
|
print_assertion_failed(file, line);
|
||||||
|
test_fail();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define assert_true(cond) __assert_true(__FILE__, __LINE__, cond)
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void __assert_arrays_match(const char* file, int line, int len, const T* expected, const T* got) {
|
||||||
|
if (arrays_match(len, expected, got)) {
|
||||||
|
print_assertion_passed(file, line);
|
||||||
|
} else {
|
||||||
|
print_assertion_failed(file, line);
|
||||||
|
printf("Expect = ");
|
||||||
|
debug_print_array(len, expected);
|
||||||
|
printf("\n");
|
||||||
|
printf(" Got = ");
|
||||||
|
debug_print_array(len, got);
|
||||||
|
printf("\n");
|
||||||
|
test_fail();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define assert_arrays_match(len, expected, got) __assert_arrays_match(__FILE__, __LINE__, len, expected, got)
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void __assert_values_match(const char* file, int line, T expected, T got) {
|
||||||
|
if (expected == got) {
|
||||||
|
print_assertion_passed(file, line);
|
||||||
|
} else {
|
||||||
|
print_assertion_failed(file, line);
|
||||||
|
printf("Expect = ");
|
||||||
|
print_value(expected);
|
||||||
|
printf("\n");
|
||||||
|
printf(" Got = ");
|
||||||
|
print_value(got);
|
||||||
|
printf("\n");
|
||||||
|
test_fail();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define assert_values_match(expected, got) __assert_values_match(__FILE__, __LINE__, expected, got)
|
@ -1,21 +0,0 @@
|
|||||||
[package]
|
|
||||||
name = "nac3core_derive"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2021"
|
|
||||||
|
|
||||||
[lib]
|
|
||||||
proc-macro = true
|
|
||||||
|
|
||||||
[[test]]
|
|
||||||
name = "structfields_tests"
|
|
||||||
path = "tests/structfields_test.rs"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
|
||||||
nac3core = { path = ".." }
|
|
||||||
trybuild = { version = "1.0", features = ["diff"] }
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
proc-macro2 = "1.0"
|
|
||||||
proc-macro-error = "1.0"
|
|
||||||
syn = "2.0"
|
|
||||||
quote = "1.0"
|
|
@ -1,320 +0,0 @@
|
|||||||
use proc_macro::TokenStream;
|
|
||||||
use proc_macro_error::{abort, proc_macro_error};
|
|
||||||
use quote::quote;
|
|
||||||
use syn::{
|
|
||||||
parse_macro_input, spanned::Spanned, Data, DataStruct, Expr, ExprField, ExprMethodCall,
|
|
||||||
ExprPath, GenericArgument, Ident, LitStr, Path, PathArguments, Type, TypePath,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Extracts all generic arguments of a [`Type`] into a [`Vec`].
|
|
||||||
///
|
|
||||||
/// Returns [`Some`] of a possibly-empty [`Vec`] if the path of `ty` matches with
|
|
||||||
/// `expected_ty_name`, otherwise returns [`None`].
|
|
||||||
fn extract_generic_args(expected_ty_name: &'static str, ty: &Type) -> Option<Vec<GenericArgument>> {
|
|
||||||
let Type::Path(TypePath { qself: None, path, .. }) = ty else {
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
let segments = &path.segments;
|
|
||||||
if segments.len() != 1 {
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
let segment = segments.iter().next().unwrap();
|
|
||||||
if segment.ident != expected_ty_name {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let PathArguments::AngleBracketed(path_args) = &segment.arguments else {
|
|
||||||
return Some(Vec::new());
|
|
||||||
};
|
|
||||||
let args = &path_args.args;
|
|
||||||
|
|
||||||
Some(args.iter().cloned().collect::<Vec<_>>())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Maps a `path` matching one of the `target_idents` into the `replacement` [`Ident`].
|
|
||||||
fn map_path_to_ident(path: &Path, target_idents: &[&str], replacement: &str) -> Option<Ident> {
|
|
||||||
path.require_ident()
|
|
||||||
.ok()
|
|
||||||
.filter(|ident| target_idents.iter().any(|target| ident == target))
|
|
||||||
.map(|ident| Ident::new(replacement, ident.span()))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Extracts the left-hand side of a dot-expression.
|
|
||||||
fn extract_dot_operand(expr: &Expr) -> Option<&Expr> {
|
|
||||||
match expr {
|
|
||||||
Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
|
|
||||||
| Expr::Field(ExprField { base: operand, .. }) => Some(operand),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Replaces the top-level receiver of a dot-expression with an [`Ident`], returning `Some(&mut expr)` if the
|
|
||||||
/// replacement is performed.
|
|
||||||
///
|
|
||||||
/// The top-level receiver is the left-most receiver expression, e.g. the top-level receiver of `a.b.c.foo()` is `a`.
|
|
||||||
fn replace_top_level_receiver(expr: &mut Expr, ident: Ident) -> Option<&mut Expr> {
|
|
||||||
if let Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
|
|
||||||
| Expr::Field(ExprField { base: operand, .. }) = expr
|
|
||||||
{
|
|
||||||
return if extract_dot_operand(operand).is_some() {
|
|
||||||
if replace_top_level_receiver(operand, ident).is_some() {
|
|
||||||
Some(expr)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
*operand = Box::new(Expr::Path(ExprPath {
|
|
||||||
attrs: Vec::default(),
|
|
||||||
qself: None,
|
|
||||||
path: ident.into(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
Some(expr)
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Iterates all operands to the left-hand side of the `.` of an [expression][`Expr`], i.e. the container operand of all
|
|
||||||
/// [`Expr::Field`] and the receiver operand of all [`Expr::MethodCall`].
|
|
||||||
///
|
|
||||||
/// The iterator will return the operand expressions in reverse order of appearance. For example, `a.b.c.func()` will
|
|
||||||
/// return `vec![c, b, a]`.
|
|
||||||
fn iter_dot_operands(expr: &Expr) -> impl Iterator<Item = &Expr> {
|
|
||||||
let mut o = extract_dot_operand(expr);
|
|
||||||
|
|
||||||
std::iter::from_fn(move || {
|
|
||||||
let this = o;
|
|
||||||
o = o.as_ref().and_then(|o| extract_dot_operand(o));
|
|
||||||
|
|
||||||
this
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Normalizes a value expression for use when creating an instance of this structure, returning a
|
|
||||||
/// [`proc_macro2::TokenStream`] of tokens representing the normalized expression.
|
|
||||||
fn normalize_value_expr(expr: &Expr) -> proc_macro2::TokenStream {
|
|
||||||
match &expr {
|
|
||||||
Expr::Path(ExprPath { qself: None, path, .. }) => {
|
|
||||||
if let Some(ident) = map_path_to_ident(path, &["usize", "size_t"], "llvm_usize") {
|
|
||||||
quote! { #ident }
|
|
||||||
} else {
|
|
||||||
abort!(
|
|
||||||
path,
|
|
||||||
format!(
|
|
||||||
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
|
|
||||||
quote!(#expr).to_string(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Expr::Call(_) => {
|
|
||||||
quote! { ctx.#expr }
|
|
||||||
}
|
|
||||||
|
|
||||||
Expr::MethodCall(_) => {
|
|
||||||
let base_receiver = iter_dot_operands(expr).last();
|
|
||||||
|
|
||||||
match base_receiver {
|
|
||||||
// `usize.{...}`, `size_t.{...}` -> Rewrite the identifiers to `llvm_usize`
|
|
||||||
Some(Expr::Path(ExprPath { qself: None, path, .. }))
|
|
||||||
if map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").is_some() =>
|
|
||||||
{
|
|
||||||
let ident =
|
|
||||||
map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").unwrap();
|
|
||||||
|
|
||||||
let mut expr = expr.clone();
|
|
||||||
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
|
|
||||||
|
|
||||||
quote!(#expr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// `ctx.{...}`, `context.{...}` -> Rewrite the identifiers to `ctx`
|
|
||||||
Some(Expr::Path(ExprPath { qself: None, path, .. }))
|
|
||||||
if map_path_to_ident(path, &["ctx", "context"], "ctx").is_some() =>
|
|
||||||
{
|
|
||||||
let ident = map_path_to_ident(path, &["ctx", "context"], "ctx").unwrap();
|
|
||||||
|
|
||||||
let mut expr = expr.clone();
|
|
||||||
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
|
|
||||||
|
|
||||||
quote!(#expr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// No reserved identifier prefix -> Prepend `ctx.` to the entire expression
|
|
||||||
_ => quote! { ctx.#expr },
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => {
|
|
||||||
abort!(
|
|
||||||
expr,
|
|
||||||
format!(
|
|
||||||
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
|
|
||||||
quote!(#expr).to_string(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Derives an implementation of `codegen::types::structure::StructFields`.
|
|
||||||
///
|
|
||||||
/// The benefit of using `#[derive(StructFields)]` is that all index- or order-dependent logic required by
|
|
||||||
/// `impl StructFields` is automatically generated by this implementation, including the field index as required by
|
|
||||||
/// `StructField::new` and the fields as returned by `StructFields::to_vec`.
|
|
||||||
///
|
|
||||||
/// # Prerequisites
|
|
||||||
///
|
|
||||||
/// In order to derive from [`StructFields`], you must implement (or derive) [`Eq`] and [`Copy`] as required by
|
|
||||||
/// `StructFields`.
|
|
||||||
///
|
|
||||||
/// Moreover, `#[derive(StructFields)]` can only be used for `struct`s with named fields, and may only contain fields
|
|
||||||
/// with either `StructField` or [`PhantomData`] types.
|
|
||||||
///
|
|
||||||
/// # Attributes for [`StructFields`]
|
|
||||||
///
|
|
||||||
/// Each `StructField` field must be declared with the `#[value_type(...)]` attribute. The argument of `value_type`
|
|
||||||
/// accepts one of the following:
|
|
||||||
///
|
|
||||||
/// - An expression returning an instance of `inkwell::types::BasicType` (with or without the receiver `ctx`/`context`).
|
|
||||||
/// For example, `context.i8_type()`, `ctx.i8_type()`, and `i8_type()` all refer to `i8`.
|
|
||||||
/// - The reserved identifiers `usize` and `size_t` referring to an `inkwell::types::IntType` of the platform-dependent
|
|
||||||
/// integer size. `usize` and `size_t` can also be used as the receiver to other method calls, e.g.
|
|
||||||
/// `usize.array_type(3)`.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// The following is an example of an LLVM slice implemented using `#[derive(StructFields)]`.
|
|
||||||
///
|
|
||||||
/// ```rust,ignore
|
|
||||||
/// use nac3core::{
|
|
||||||
/// codegen::types::structure::StructField,
|
|
||||||
/// inkwell::{
|
|
||||||
/// values::{IntValue, PointerValue},
|
|
||||||
/// AddressSpace,
|
|
||||||
/// },
|
|
||||||
/// };
|
|
||||||
/// use nac3core_derive::StructFields;
|
|
||||||
///
|
|
||||||
/// // All classes that implement StructFields must also implement Eq and Copy
|
|
||||||
/// #[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
|
||||||
/// pub struct SliceValue<'ctx> {
|
|
||||||
/// // Declares ptr have a value type of i8*
|
|
||||||
/// //
|
|
||||||
/// // Can also be written as `ctx.i8_type().ptr_type(...)` or `context.i8_type().ptr_type(...)`
|
|
||||||
/// #[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
|
||||||
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
///
|
|
||||||
/// // Declares len have a value type of usize, depending on the target compilation platform
|
|
||||||
/// #[value_type(usize)]
|
|
||||||
/// len: StructField<'ctx, IntValue<'ctx>>,
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
#[proc_macro_derive(StructFields, attributes(value_type))]
|
|
||||||
#[proc_macro_error]
|
|
||||||
pub fn derive(input: TokenStream) -> TokenStream {
|
|
||||||
let input = parse_macro_input!(input as syn::DeriveInput);
|
|
||||||
let ident = &input.ident;
|
|
||||||
|
|
||||||
let Data::Struct(DataStruct { fields, .. }) = &input.data else {
|
|
||||||
abort!(input, "Only structs with named fields are supported");
|
|
||||||
};
|
|
||||||
if let Err(err_span) =
|
|
||||||
fields
|
|
||||||
.iter()
|
|
||||||
.try_for_each(|field| if field.ident.is_some() { Ok(()) } else { Err(field.span()) })
|
|
||||||
{
|
|
||||||
abort!(err_span, "Only structs with named fields are supported");
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check if struct<'ctx>
|
|
||||||
if input.generics.params.len() != 1 {
|
|
||||||
abort!(input.generics, "Expected exactly 1 generic parameter")
|
|
||||||
}
|
|
||||||
|
|
||||||
let phantom_info = fields
|
|
||||||
.iter()
|
|
||||||
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_some())
|
|
||||||
.map(|field| field.ident.as_ref().unwrap())
|
|
||||||
.cloned()
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let field_info = fields
|
|
||||||
.iter()
|
|
||||||
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_none())
|
|
||||||
.map(|field| {
|
|
||||||
let ident = field.ident.as_ref().unwrap();
|
|
||||||
let ty = &field.ty;
|
|
||||||
|
|
||||||
let Some(_) = extract_generic_args("StructField", ty) else {
|
|
||||||
abort!(field, "Only StructField and PhantomData are allowed")
|
|
||||||
};
|
|
||||||
|
|
||||||
let attrs = &field.attrs;
|
|
||||||
let Some(value_type_attr) =
|
|
||||||
attrs.iter().find(|attr| attr.path().is_ident("value_type"))
|
|
||||||
else {
|
|
||||||
abort!(field, "Expected #[value_type(...)] attribute for field");
|
|
||||||
};
|
|
||||||
|
|
||||||
let Ok(value_type_expr) = value_type_attr.parse_args::<Expr>() else {
|
|
||||||
abort!(value_type_attr, "Expected expression in #[value_type(...)]");
|
|
||||||
};
|
|
||||||
|
|
||||||
let value_expr_toks = normalize_value_expr(&value_type_expr);
|
|
||||||
|
|
||||||
(ident.clone(), value_expr_toks)
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
// `<*>::new` impl of `StructField` and `PhantomData` for `StructFields::new`
|
|
||||||
let phantoms_create = phantom_info
|
|
||||||
.iter()
|
|
||||||
.map(|id| quote! { #id: ::std::marker::PhantomData })
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let fields_create = field_info
|
|
||||||
.iter()
|
|
||||||
.map(|(id, ty)| {
|
|
||||||
let id_lit = LitStr::new(&id.to_string(), id.span());
|
|
||||||
quote! {
|
|
||||||
#id: ::nac3core::codegen::types::structure::StructField::create(
|
|
||||||
&mut counter,
|
|
||||||
#id_lit,
|
|
||||||
#ty,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
// `.into()` impl of `StructField` for `StructFields::to_vec`
|
|
||||||
let fields_into =
|
|
||||||
field_info.iter().map(|(id, _)| quote! { self.#id.into() }).collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let impl_block = quote! {
|
|
||||||
impl<'ctx> ::nac3core::codegen::types::structure::StructFields<'ctx> for #ident<'ctx> {
|
|
||||||
fn new(ctx: impl ::nac3core::inkwell::context::AsContextRef<'ctx>, llvm_usize: ::nac3core::inkwell::types::IntType<'ctx>) -> Self {
|
|
||||||
let ctx = unsafe { ::nac3core::inkwell::context::ContextRef::new(ctx.as_ctx_ref()) };
|
|
||||||
|
|
||||||
let mut counter = ::nac3core::codegen::types::structure::FieldIndexCounter::default();
|
|
||||||
|
|
||||||
#ident {
|
|
||||||
#(#fields_create),*
|
|
||||||
#(#phantoms_create),*
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_vec(&self) -> ::std::vec::Vec<(&'static str, ::nac3core::inkwell::types::BasicTypeEnum<'ctx>)> {
|
|
||||||
vec![
|
|
||||||
#(#fields_into),*
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
impl_block.into()
|
|
||||||
}
|
|
@ -1,9 +0,0 @@
|
|||||||
use nac3core_derive::StructFields;
|
|
||||||
use std::marker::PhantomData;
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
|
||||||
pub struct EmptyValue<'ctx> {
|
|
||||||
_phantom: PhantomData<&'ctx ()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {}
|
|
@ -1,20 +0,0 @@
|
|||||||
use nac3core::{
|
|
||||||
codegen::types::structure::StructField,
|
|
||||||
inkwell::{
|
|
||||||
values::{IntValue, PointerValue},
|
|
||||||
AddressSpace,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use nac3core_derive::StructFields;
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
|
||||||
pub struct NDArrayValue<'ctx> {
|
|
||||||
#[value_type(usize)]
|
|
||||||
ndims: StructField<'ctx, IntValue<'ctx>>,
|
|
||||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
|
||||||
shape: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
|
||||||
data: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {}
|
|
@ -1,18 +0,0 @@
|
|||||||
use nac3core::{
|
|
||||||
codegen::types::structure::StructField,
|
|
||||||
inkwell::{
|
|
||||||
values::{IntValue, PointerValue},
|
|
||||||
AddressSpace,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use nac3core_derive::StructFields;
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
|
||||||
pub struct SliceValue<'ctx> {
|
|
||||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
|
||||||
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
#[value_type(usize)]
|
|
||||||
len: StructField<'ctx, IntValue<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {}
|
|
@ -1,18 +0,0 @@
|
|||||||
use nac3core::{
|
|
||||||
codegen::types::structure::StructField,
|
|
||||||
inkwell::{
|
|
||||||
values::{IntValue, PointerValue},
|
|
||||||
AddressSpace,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use nac3core_derive::StructFields;
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
|
||||||
pub struct SliceValue<'ctx> {
|
|
||||||
#[value_type(context.i8_type().ptr_type(AddressSpace::default()))]
|
|
||||||
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
#[value_type(usize)]
|
|
||||||
len: StructField<'ctx, IntValue<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {}
|
|
@ -1,18 +0,0 @@
|
|||||||
use nac3core::{
|
|
||||||
codegen::types::structure::StructField,
|
|
||||||
inkwell::{
|
|
||||||
values::{IntValue, PointerValue},
|
|
||||||
AddressSpace,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use nac3core_derive::StructFields;
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
|
||||||
pub struct SliceValue<'ctx> {
|
|
||||||
#[value_type(ctx.i8_type().ptr_type(AddressSpace::default()))]
|
|
||||||
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
#[value_type(usize)]
|
|
||||||
len: StructField<'ctx, IntValue<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {}
|
|
@ -1,18 +0,0 @@
|
|||||||
use nac3core::{
|
|
||||||
codegen::types::structure::StructField,
|
|
||||||
inkwell::{
|
|
||||||
values::{IntValue, PointerValue},
|
|
||||||
AddressSpace,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use nac3core_derive::StructFields;
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
|
||||||
pub struct SliceValue<'ctx> {
|
|
||||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
|
||||||
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
#[value_type(size_t)]
|
|
||||||
len: StructField<'ctx, IntValue<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {}
|
|
@ -1,10 +0,0 @@
|
|||||||
#[test]
|
|
||||||
fn test_parse_empty() {
|
|
||||||
let t = trybuild::TestCases::new();
|
|
||||||
t.pass("tests/structfields_empty.rs");
|
|
||||||
t.pass("tests/structfields_slice.rs");
|
|
||||||
t.pass("tests/structfields_slice_ctx.rs");
|
|
||||||
t.pass("tests/structfields_slice_context.rs");
|
|
||||||
t.pass("tests/structfields_slice_sizet.rs");
|
|
||||||
t.pass("tests/structfields_ndarray.rs");
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
1763
nac3core/src/codegen/classes.rs
Normal file
1763
nac3core/src/codegen/classes.rs
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,9 +1,3 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use indexmap::IndexMap;
|
|
||||||
|
|
||||||
use nac3parser::ast::StrRef;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
toplevel::DefinitionId,
|
toplevel::DefinitionId,
|
||||||
@ -15,6 +9,10 @@ use crate::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use indexmap::IndexMap;
|
||||||
|
use nac3parser::ast::StrRef;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
pub struct ConcreteTypeStore {
|
pub struct ConcreteTypeStore {
|
||||||
store: Vec<ConcreteTypeEnum>,
|
store: Vec<ConcreteTypeEnum>,
|
||||||
}
|
}
|
||||||
@ -27,7 +25,6 @@ pub struct ConcreteFuncArg {
|
|||||||
pub name: StrRef,
|
pub name: StrRef,
|
||||||
pub ty: ConcreteType,
|
pub ty: ConcreteType,
|
||||||
pub default_value: Option<SymbolValue>,
|
pub default_value: Option<SymbolValue>,
|
||||||
pub is_vararg: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
@ -49,17 +46,12 @@ pub enum ConcreteTypeEnum {
|
|||||||
TPrimitive(Primitive),
|
TPrimitive(Primitive),
|
||||||
TTuple {
|
TTuple {
|
||||||
ty: Vec<ConcreteType>,
|
ty: Vec<ConcreteType>,
|
||||||
is_vararg_ctx: bool,
|
|
||||||
},
|
},
|
||||||
TObj {
|
TObj {
|
||||||
obj_id: DefinitionId,
|
obj_id: DefinitionId,
|
||||||
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
||||||
params: IndexMap<TypeVarId, ConcreteType>,
|
params: IndexMap<TypeVarId, ConcreteType>,
|
||||||
},
|
},
|
||||||
TModule {
|
|
||||||
module_id: DefinitionId,
|
|
||||||
methods: HashMap<StrRef, (ConcreteType, bool)>,
|
|
||||||
},
|
|
||||||
TVirtual {
|
TVirtual {
|
||||||
ty: ConcreteType,
|
ty: ConcreteType,
|
||||||
},
|
},
|
||||||
@ -110,16 +102,8 @@ impl ConcreteTypeStore {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|arg| ConcreteFuncArg {
|
.map(|arg| ConcreteFuncArg {
|
||||||
name: arg.name,
|
name: arg.name,
|
||||||
ty: if arg.is_vararg {
|
ty: self.from_unifier_type(unifier, primitives, arg.ty, cache),
|
||||||
let tuple_ty = unifier
|
|
||||||
.add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true });
|
|
||||||
|
|
||||||
self.from_unifier_type(unifier, primitives, tuple_ty, cache)
|
|
||||||
} else {
|
|
||||||
self.from_unifier_type(unifier, primitives, arg.ty, cache)
|
|
||||||
},
|
|
||||||
default_value: arg.default_value.clone(),
|
default_value: arg.default_value.clone(),
|
||||||
is_vararg: arg.is_vararg,
|
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
|
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
|
||||||
@ -174,12 +158,11 @@ impl ConcreteTypeStore {
|
|||||||
cache.insert(ty, None);
|
cache.insert(ty, None);
|
||||||
let ty_enum = unifier.get_ty(ty);
|
let ty_enum = unifier.get_ty(ty);
|
||||||
let result = match &*ty_enum {
|
let result = match &*ty_enum {
|
||||||
TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple {
|
TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple {
|
||||||
ty: ty
|
ty: ty
|
||||||
.iter()
|
.iter()
|
||||||
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
|
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
|
||||||
.collect(),
|
.collect(),
|
||||||
is_vararg_ctx: *is_vararg_ctx,
|
|
||||||
},
|
},
|
||||||
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
|
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
|
||||||
obj_id: *obj_id,
|
obj_id: *obj_id,
|
||||||
@ -209,19 +192,6 @@ impl ConcreteTypeStore {
|
|||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
},
|
},
|
||||||
TypeEnum::TModule { module_id, attributes } => ConcreteTypeEnum::TModule {
|
|
||||||
module_id: *module_id,
|
|
||||||
methods: attributes
|
|
||||||
.iter()
|
|
||||||
.filter_map(|(name, ty)| match &*unifier.get_ty(ty.0) {
|
|
||||||
TypeEnum::TFunc(..) | TypeEnum::TObj { .. } => None,
|
|
||||||
_ => Some((
|
|
||||||
*name,
|
|
||||||
(self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1),
|
|
||||||
)),
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
},
|
|
||||||
TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual {
|
TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual {
|
||||||
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
||||||
},
|
},
|
||||||
@ -278,12 +248,11 @@ impl ConcreteTypeStore {
|
|||||||
*cache.get_mut(&cty).unwrap() = Some(ty);
|
*cache.get_mut(&cty).unwrap() = Some(ty);
|
||||||
return ty;
|
return ty;
|
||||||
}
|
}
|
||||||
ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple {
|
ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple {
|
||||||
ty: ty
|
ty: ty
|
||||||
.iter()
|
.iter()
|
||||||
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
|
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
|
||||||
.collect(),
|
.collect(),
|
||||||
is_vararg_ctx: *is_vararg_ctx,
|
|
||||||
},
|
},
|
||||||
ConcreteTypeEnum::TVirtual { ty } => {
|
ConcreteTypeEnum::TVirtual { ty } => {
|
||||||
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
||||||
@ -301,15 +270,6 @@ impl ConcreteTypeStore {
|
|||||||
TypeVar { id, ty }
|
TypeVar { id, ty }
|
||||||
})),
|
})),
|
||||||
},
|
},
|
||||||
ConcreteTypeEnum::TModule { module_id, methods } => TypeEnum::TModule {
|
|
||||||
module_id: *module_id,
|
|
||||||
attributes: methods
|
|
||||||
.iter()
|
|
||||||
.map(|(name, cty)| {
|
|
||||||
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
|
|
||||||
})
|
|
||||||
.collect::<HashMap<_, _>>(),
|
|
||||||
},
|
|
||||||
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
|
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
|
||||||
args: args
|
args: args
|
||||||
.iter()
|
.iter()
|
||||||
@ -317,7 +277,6 @@ impl ConcreteTypeStore {
|
|||||||
name: arg.name,
|
name: arg.name,
|
||||||
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
|
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
|
||||||
default_value: arg.default_value.clone(),
|
default_value: arg.default_value.clone(),
|
||||||
is_vararg: false,
|
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
ret: self.to_unifier_type(unifier, primitives, *ret, cache),
|
ret: self.to_unifier_type(unifier, primitives, *ret, cache),
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,10 +1,8 @@
|
|||||||
use inkwell::{
|
use inkwell::attributes::{Attribute, AttributeLoc};
|
||||||
attributes::{Attribute, AttributeLoc},
|
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
|
||||||
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
|
||||||
};
|
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
|
|
||||||
use super::CodeGenContext;
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
/// Macro to generate extern function
|
/// Macro to generate extern function
|
||||||
/// Both function return type and function parameter type are `FloatValue`
|
/// Both function return type and function parameter type are `FloatValue`
|
||||||
@ -15,11 +13,11 @@ use super::CodeGenContext;
|
|||||||
/// * `$extern_fn:literal`: Name of underlying extern function
|
/// * `$extern_fn:literal`: Name of underlying extern function
|
||||||
///
|
///
|
||||||
/// Optional Arguments:
|
/// Optional Arguments:
|
||||||
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function.
|
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function
|
||||||
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly".
|
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly"
|
||||||
/// These will be used unless other attributes are specified
|
/// These will be used unless other attributes are specified
|
||||||
/// * `$(,$args:ident)*`: Operands of the extern function
|
/// * `$(,$args:ident)*`: Operands of the extern function
|
||||||
/// The data type of these operands will be set to `FloatValue`
|
/// The data type of these operands will be set to `FloatValue`
|
||||||
///
|
///
|
||||||
macro_rules! generate_extern_fn {
|
macro_rules! generate_extern_fn {
|
||||||
("unary", $fn_name:ident, $extern_fn:literal) => {
|
("unary", $fn_name:ident, $extern_fn:literal) => {
|
||||||
@ -132,62 +130,3 @@ pub fn call_ldexp<'ctx>(
|
|||||||
.map(Either::unwrap_left)
|
.map(Either::unwrap_left)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Macro to generate `np_linalg` and `sp_linalg` functions
|
|
||||||
/// The function takes as input `NDArray` and returns ()
|
|
||||||
///
|
|
||||||
/// Arguments:
|
|
||||||
/// * `$fn_name:ident`: The identifier of the rust function to be generated
|
|
||||||
/// * `$extern_fn:literal`: Name of underlying extern function
|
|
||||||
/// * (2/3/4): Number of `NDArray` that function takes as input
|
|
||||||
///
|
|
||||||
/// Note:
|
|
||||||
/// The operands and resulting `NDArray` are both passed as input to the funcion
|
|
||||||
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
|
|
||||||
/// The function changes the content of the output `NDArray` in-place
|
|
||||||
macro_rules! generate_linalg_extern_fn {
|
|
||||||
($fn_name:ident, $extern_fn:literal, 2) => {
|
|
||||||
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
|
|
||||||
};
|
|
||||||
($fn_name:ident, $extern_fn:literal, 3) => {
|
|
||||||
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
|
|
||||||
};
|
|
||||||
($fn_name:ident, $extern_fn:literal, 4) => {
|
|
||||||
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
|
|
||||||
};
|
|
||||||
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
|
|
||||||
#[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )]
|
|
||||||
pub fn $fn_name<'ctx>(
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>
|
|
||||||
$(,$input_matrix: BasicValueEnum<'ctx>)*,
|
|
||||||
name: Option<&str>,
|
|
||||||
){
|
|
||||||
const FN_NAME: &str = $extern_fn;
|
|
||||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
|
||||||
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
|
|
||||||
|
|
||||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
|
||||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
|
||||||
func.add_attribute(
|
|
||||||
AttributeLoc::Function,
|
|
||||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
func
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
|
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
|
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
|
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
|
|
||||||
|
@ -1,27 +1,20 @@
|
|||||||
use inkwell::{
|
|
||||||
context::Context,
|
|
||||||
targets::TargetMachine,
|
|
||||||
types::{BasicTypeEnum, IntType},
|
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
|
||||||
};
|
|
||||||
|
|
||||||
use nac3parser::ast::{Expr, Stmt, StrRef};
|
|
||||||
|
|
||||||
use super::{bool_to_i1, bool_to_i8, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext};
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{DefinitionId, TopLevelDef},
|
toplevel::{DefinitionId, TopLevelDef},
|
||||||
typecheck::typedef::{FunSignature, Type},
|
typecheck::typedef::{FunSignature, Type},
|
||||||
};
|
};
|
||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{BasicTypeEnum, IntType},
|
||||||
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
|
};
|
||||||
|
use nac3parser::ast::{Expr, Stmt, StrRef};
|
||||||
|
|
||||||
pub trait CodeGenerator {
|
pub trait CodeGenerator {
|
||||||
/// Return the module name for the code generator.
|
/// Return the module name for the code generator.
|
||||||
fn get_name(&self) -> &str;
|
fn get_name(&self) -> &str;
|
||||||
|
|
||||||
/// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance.
|
|
||||||
///
|
|
||||||
/// Prefer using [`CodeGenContext::get_size_type`] if [`CodeGenContext`] is available, as it is
|
|
||||||
/// equivalent to this function in a more concise syntax.
|
|
||||||
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>;
|
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>;
|
||||||
|
|
||||||
/// Generate function call and returns the function return value.
|
/// Generate function call and returns the function return value.
|
||||||
@ -64,7 +57,6 @@ pub trait CodeGenerator {
|
|||||||
/// - fun: Function signature, definition ID and the substitution key.
|
/// - fun: Function signature, definition ID and the substitution key.
|
||||||
/// - params: Function parameters. Note that this does not include the object even if the
|
/// - params: Function parameters. Note that this does not include the object even if the
|
||||||
/// function is a class method.
|
/// function is a class method.
|
||||||
///
|
|
||||||
/// Note that this function should check if the function is generated in another thread (due to
|
/// Note that this function should check if the function is generated in another thread (due to
|
||||||
/// possible race condition), see the default implementation for an example.
|
/// possible race condition), see the default implementation for an example.
|
||||||
fn gen_func_instance<'ctx>(
|
fn gen_func_instance<'ctx>(
|
||||||
@ -131,45 +123,11 @@ pub trait CodeGenerator {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
target: &Expr<Option<Type>>,
|
target: &Expr<Option<Type>>,
|
||||||
value: ValueEnum<'ctx>,
|
value: ValueEnum<'ctx>,
|
||||||
value_ty: Type,
|
|
||||||
) -> Result<(), String>
|
) -> Result<(), String>
|
||||||
where
|
where
|
||||||
Self: Sized,
|
Self: Sized,
|
||||||
{
|
{
|
||||||
gen_assign(self, ctx, target, value, value_ty)
|
gen_assign(self, ctx, target, value)
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate code for an assignment expression where LHS is a `"target_list"`.
|
|
||||||
///
|
|
||||||
/// See <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
|
|
||||||
fn gen_assign_target_list<'ctx>(
|
|
||||||
&mut self,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
targets: &Vec<Expr<Option<Type>>>,
|
|
||||||
value: ValueEnum<'ctx>,
|
|
||||||
value_ty: Type,
|
|
||||||
) -> Result<(), String>
|
|
||||||
where
|
|
||||||
Self: Sized,
|
|
||||||
{
|
|
||||||
gen_assign_target_list(self, ctx, targets, value, value_ty)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate code for an item assignment.
|
|
||||||
///
|
|
||||||
/// i.e., `target[key] = value`
|
|
||||||
fn gen_setitem<'ctx>(
|
|
||||||
&mut self,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
target: &Expr<Option<Type>>,
|
|
||||||
key: &Expr<Option<Type>>,
|
|
||||||
value: ValueEnum<'ctx>,
|
|
||||||
value_ty: Type,
|
|
||||||
) -> Result<(), String>
|
|
||||||
where
|
|
||||||
Self: Sized,
|
|
||||||
{
|
|
||||||
gen_setitem(self, ctx, target, key, value, value_ty)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate code for a while expression.
|
/// Generate code for a while expression.
|
||||||
@ -274,27 +232,19 @@ pub struct DefaultCodeGenerator {
|
|||||||
|
|
||||||
impl DefaultCodeGenerator {
|
impl DefaultCodeGenerator {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(name: String, size_t: IntType<'_>) -> DefaultCodeGenerator {
|
pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator {
|
||||||
assert!(matches!(size_t.get_bit_width(), 32 | 64));
|
assert!(matches!(size_t, 32 | 64));
|
||||||
DefaultCodeGenerator { name, size_t: size_t.get_bit_width() }
|
DefaultCodeGenerator { name, size_t }
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn with_target_machine(
|
|
||||||
name: String,
|
|
||||||
ctx: &Context,
|
|
||||||
target_machine: &TargetMachine,
|
|
||||||
) -> DefaultCodeGenerator {
|
|
||||||
let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None);
|
|
||||||
Self::new(name, llvm_usize)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CodeGenerator for DefaultCodeGenerator {
|
impl CodeGenerator for DefaultCodeGenerator {
|
||||||
|
/// Returns the name for this [`CodeGenerator`].
|
||||||
fn get_name(&self) -> &str {
|
fn get_name(&self) -> &str {
|
||||||
&self.name
|
&self.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns an LLVM integer type representing `size_t`.
|
||||||
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
|
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
|
||||||
// it should be unsigned, but we don't really need unsigned and this could save us from
|
// it should be unsigned, but we don't really need unsigned and this could save us from
|
||||||
// having to do a bit cast...
|
// having to do a bit cast...
|
||||||
|
185
nac3core/src/codegen/irrt/error_context.rs
Normal file
185
nac3core/src/codegen/irrt/error_context.rs
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
use super::{string::Str, util::get_sized_dependent_function_name};
|
||||||
|
|
||||||
|
/// The [`IntModel`] of nac3core's error ID.
|
||||||
|
///
|
||||||
|
/// It is always [`Int32`].
|
||||||
|
type ErrorId = Int32;
|
||||||
|
|
||||||
|
pub struct ErrorIdsFields {
|
||||||
|
pub index_error: Field<FixedIntModel<ErrorId>>,
|
||||||
|
pub value_error: Field<FixedIntModel<ErrorId>>,
|
||||||
|
pub assertion_error: Field<FixedIntModel<ErrorId>>,
|
||||||
|
pub runtime_error: Field<FixedIntModel<ErrorId>>,
|
||||||
|
pub type_error: Field<FixedIntModel<ErrorId>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Corresponds to IRRT's `struct ErrorIds`
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct ErrorIds;
|
||||||
|
|
||||||
|
impl<'ctx> IsStruct<'ctx> for ErrorIds {
|
||||||
|
type Fields = ErrorIdsFields;
|
||||||
|
|
||||||
|
fn struct_name(&self) -> &'static str {
|
||||||
|
"ErrorIds"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_fields(&self, builder: &mut FieldBuilder) -> Self::Fields {
|
||||||
|
Self::Fields {
|
||||||
|
index_error: builder.add_field_auto("index_error"),
|
||||||
|
value_error: builder.add_field_auto("value_error"),
|
||||||
|
assertion_error: builder.add_field_auto("assertion_error"),
|
||||||
|
runtime_error: builder.add_field_auto("runtime_error"),
|
||||||
|
type_error: builder.add_field_auto("type_error"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ErrorContextFields {
|
||||||
|
pub error_ids: Field<PointerModel<StructModel<ErrorIds>>>,
|
||||||
|
pub error_id: Field<FixedIntModel<ErrorId>>,
|
||||||
|
pub message_template: Field<PointerModel<FixedIntModel<Byte>>>,
|
||||||
|
pub param1: Field<FixedIntModel<Int64>>,
|
||||||
|
pub param2: Field<FixedIntModel<Int64>>,
|
||||||
|
pub param3: Field<FixedIntModel<Int64>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Corresponds to IRRT's `struct ErrorContext`
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct ErrorContext;
|
||||||
|
|
||||||
|
impl<'ctx> IsStruct<'ctx> for ErrorContext {
|
||||||
|
type Fields = ErrorContextFields;
|
||||||
|
|
||||||
|
fn struct_name(&self) -> &'static str {
|
||||||
|
"ErrorIds"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_fields(&self, builder: &mut FieldBuilder) -> Self::Fields {
|
||||||
|
Self::Fields {
|
||||||
|
error_ids: builder.add_field_auto("error_ids"),
|
||||||
|
error_id: builder.add_field_auto("error_id"),
|
||||||
|
message_template: builder.add_field_auto("message_template"),
|
||||||
|
param1: builder.add_field_auto("param1"),
|
||||||
|
param2: builder.add_field_auto("param2"),
|
||||||
|
param3: builder.add_field_auto("param3"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare ErrorIds
|
||||||
|
fn build_error_ids<'ctx>(ctx: &CodeGenContext<'ctx, '_>) -> Pointer<'ctx, StructModel<ErrorIds>> {
|
||||||
|
// ErrorIdsLens.get_fields(ctx.ctx).assertion_error.
|
||||||
|
let error_ids = StructModel(ErrorIds).alloca(ctx, "error_ids");
|
||||||
|
let i32_model = FixedIntModel(Int32);
|
||||||
|
// i32_model.make_constant()
|
||||||
|
|
||||||
|
let get_string_id =
|
||||||
|
|string_id| i32_model.constant(ctx.ctx, ctx.resolver.get_string_id(string_id) as u64);
|
||||||
|
|
||||||
|
error_ids.gep(ctx, |f| f.index_error).store(ctx, get_string_id("0:IndexError"));
|
||||||
|
error_ids.gep(ctx, |f| f.value_error).store(ctx, get_string_id("0:ValueError"));
|
||||||
|
error_ids.gep(ctx, |f| f.assertion_error).store(ctx, get_string_id("0:AssertionError"));
|
||||||
|
error_ids.gep(ctx, |f| f.runtime_error).store(ctx, get_string_id("0:RuntimeError"));
|
||||||
|
error_ids.gep(ctx, |f| f.type_error).store(ctx, get_string_id("0:TypeError"));
|
||||||
|
|
||||||
|
error_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_error_context_initialize<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
perrctx: Pointer<'ctx, StructModel<ErrorContext>>,
|
||||||
|
perror_ids: Pointer<'ctx, StructModel<ErrorIds>>,
|
||||||
|
) {
|
||||||
|
FunctionBuilder::begin(ctx, "__nac3_error_context_initialize")
|
||||||
|
.arg("errctx", PointerModel(StructModel(ErrorContext)), perrctx)
|
||||||
|
.arg("error_ids", PointerModel(StructModel(ErrorIds)), perror_ids)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_error_context_has_no_error<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
errctx: Pointer<'ctx, StructModel<ErrorContext>>,
|
||||||
|
) -> FixedInt<'ctx, Bool> {
|
||||||
|
FunctionBuilder::begin(ctx, "__nac3_error_context_has_no_error")
|
||||||
|
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
|
||||||
|
.returning("has_error", FixedIntModel(Bool))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_error_context_get_error_str<'ctx>(
|
||||||
|
sizet: IntModel<'ctx>,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
errctx: Pointer<'ctx, StructModel<ErrorContext>>,
|
||||||
|
dst_str: Pointer<'ctx, StructModel<Str<'ctx>>>,
|
||||||
|
) {
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(sizet, "__nac3_error_context_get_error_str"),
|
||||||
|
)
|
||||||
|
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
|
||||||
|
.arg("dst_str", PointerModel(StructModel(Str { sizet })), dst_str)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Setup a [`ErrorContext`] that could
|
||||||
|
/// be passed to IRRT functions taking in a `ErrorContext* errctx`
|
||||||
|
/// for error reporting purposes.
|
||||||
|
///
|
||||||
|
/// Also see: [`check_error_context`]
|
||||||
|
pub fn setup_error_context<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Pointer<'ctx, StructModel<ErrorContext>> {
|
||||||
|
let error_ids = build_error_ids(ctx);
|
||||||
|
let errctx_ptr = StructModel(ErrorContext).alloca(ctx, "errctx");
|
||||||
|
call_nac3_error_context_initialize(ctx, errctx_ptr, error_ids);
|
||||||
|
errctx_ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check a [`ErrorContext`] to see
|
||||||
|
/// if it contains error.
|
||||||
|
///
|
||||||
|
/// If there is an error, an LLVM exception will be raised at runtime.
|
||||||
|
pub fn check_error_context<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
errctx_ptr: Pointer<'ctx, StructModel<ErrorContext>>,
|
||||||
|
) {
|
||||||
|
let sizet = IntModel(generator.get_size_type(ctx.ctx));
|
||||||
|
|
||||||
|
// Does ErrorContext contain an error?
|
||||||
|
let has_error = call_nac3_error_context_has_no_error(ctx, errctx_ptr);
|
||||||
|
|
||||||
|
// Get the error message (doesn't matter even if there's actually no error)
|
||||||
|
let pstr = StructModel(Str { sizet }).alloca(ctx, "error_str");
|
||||||
|
call_nac3_error_context_get_error_str(sizet, ctx, errctx_ptr, pstr);
|
||||||
|
|
||||||
|
// Load all the values for `ctx.make_assert_impl_by_id`
|
||||||
|
let error_id = errctx_ptr.gep(ctx, |f| f.error_id).load(ctx, "error_id");
|
||||||
|
let error_str = pstr.load(ctx, "error_str");
|
||||||
|
let param1 = errctx_ptr.gep(ctx, |f| f.param1).load(ctx, "param1");
|
||||||
|
let param2 = errctx_ptr.gep(ctx, |f| f.param2).load(ctx, "param2");
|
||||||
|
let param3 = errctx_ptr.gep(ctx, |f| f.param3).load(ctx, "param3");
|
||||||
|
|
||||||
|
// Make assert
|
||||||
|
ctx.make_assert_impl_by_id(
|
||||||
|
generator,
|
||||||
|
has_error.value,
|
||||||
|
error_id.value,
|
||||||
|
error_str.get_llvm_value(),
|
||||||
|
[Some(param1.value), Some(param2.value), Some(param3.value)],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_dummy_raise<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext,
|
||||||
|
) {
|
||||||
|
let errctx = setup_error_context(ctx);
|
||||||
|
FunctionBuilder::begin(ctx, "__nac3_error_dummy_raise")
|
||||||
|
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx)
|
||||||
|
.returning_void();
|
||||||
|
check_error_context(generator, ctx, errctx);
|
||||||
|
}
|
@ -1,174 +0,0 @@
|
|||||||
use inkwell::{
|
|
||||||
types::BasicTypeEnum,
|
|
||||||
values::{BasicValueEnum, CallSiteValue, IntValue},
|
|
||||||
AddressSpace, IntPredicate,
|
|
||||||
};
|
|
||||||
use itertools::Either;
|
|
||||||
|
|
||||||
use super::calculate_len_for_slice_range;
|
|
||||||
use crate::codegen::{
|
|
||||||
macros::codegen_unreachable,
|
|
||||||
values::{ArrayLikeValue, ListValue},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// This function handles 'end' **inclusively**.
|
|
||||||
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
|
|
||||||
/// Negative index should be handled before entering this function
|
|
||||||
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ty: BasicTypeEnum<'ctx>,
|
|
||||||
dest_arr: ListValue<'ctx>,
|
|
||||||
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
|
||||||
src_arr: ListValue<'ctx>,
|
|
||||||
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
|
||||||
) {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
|
|
||||||
assert_eq!(dest_idx.0.get_type(), llvm_i32);
|
|
||||||
assert_eq!(dest_idx.1.get_type(), llvm_i32);
|
|
||||||
assert_eq!(dest_idx.2.get_type(), llvm_i32);
|
|
||||||
assert_eq!(src_idx.0.get_type(), llvm_i32);
|
|
||||||
assert_eq!(src_idx.1.get_type(), llvm_i32);
|
|
||||||
assert_eq!(src_idx.2.get_type(), llvm_i32);
|
|
||||||
|
|
||||||
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8);
|
|
||||||
let slice_assign_fun = {
|
|
||||||
let ty_vec = vec![
|
|
||||||
llvm_i32.into(), // dest start idx
|
|
||||||
llvm_i32.into(), // dest end idx
|
|
||||||
llvm_i32.into(), // dest step
|
|
||||||
elem_ptr_type.into(), // dest arr ptr
|
|
||||||
llvm_i32.into(), // dest arr len
|
|
||||||
llvm_i32.into(), // src start idx
|
|
||||||
llvm_i32.into(), // src end idx
|
|
||||||
llvm_i32.into(), // src step
|
|
||||||
elem_ptr_type.into(), // src arr ptr
|
|
||||||
llvm_i32.into(), // src arr len
|
|
||||||
llvm_i32.into(), // size
|
|
||||||
];
|
|
||||||
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
|
|
||||||
let fn_t = llvm_i32.fn_type(ty_vec.as_slice(), false);
|
|
||||||
ctx.module.add_function(fun_symbol, fn_t, None)
|
|
||||||
})
|
|
||||||
};
|
|
||||||
|
|
||||||
let zero = llvm_i32.const_zero();
|
|
||||||
let one = llvm_i32.const_int(1, false);
|
|
||||||
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
|
|
||||||
let dest_arr_ptr =
|
|
||||||
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
|
|
||||||
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
|
|
||||||
let dest_len =
|
|
||||||
ctx.builder.build_int_truncate_or_bit_cast(dest_len, llvm_i32, "srclen32").unwrap();
|
|
||||||
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
|
|
||||||
let src_arr_ptr =
|
|
||||||
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
|
|
||||||
let src_len = src_arr.load_size(ctx, Some("src.len"));
|
|
||||||
let src_len =
|
|
||||||
ctx.builder.build_int_truncate_or_bit_cast(src_len, llvm_i32, "srclen32").unwrap();
|
|
||||||
|
|
||||||
// index in bound and positive should be done
|
|
||||||
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
|
|
||||||
// throw exception if not satisfied
|
|
||||||
let src_end = ctx
|
|
||||||
.builder
|
|
||||||
.build_select(
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
|
|
||||||
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
|
|
||||||
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
|
|
||||||
"final_e",
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let dest_end = ctx
|
|
||||||
.builder
|
|
||||||
.build_select(
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
|
|
||||||
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
|
|
||||||
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
|
|
||||||
"final_e",
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let src_slice_len =
|
|
||||||
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
|
|
||||||
let dest_slice_len =
|
|
||||||
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
|
|
||||||
let src_eq_dest = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
|
|
||||||
.unwrap();
|
|
||||||
let src_slt_dest = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
|
|
||||||
.unwrap();
|
|
||||||
let dest_step_eq_one = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::EQ,
|
|
||||||
dest_idx.2,
|
|
||||||
dest_idx.2.get_type().const_int(1, false),
|
|
||||||
"slice_dest_step_eq_one",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
|
|
||||||
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
cond,
|
|
||||||
"0:ValueError",
|
|
||||||
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
|
|
||||||
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
let new_len = {
|
|
||||||
let args = vec![
|
|
||||||
dest_idx.0.into(), // dest start idx
|
|
||||||
dest_idx.1.into(), // dest end idx
|
|
||||||
dest_idx.2.into(), // dest step
|
|
||||||
dest_arr_ptr.into(), // dest arr ptr
|
|
||||||
dest_len.into(), // dest arr len
|
|
||||||
src_idx.0.into(), // src start idx
|
|
||||||
src_idx.1.into(), // src end idx
|
|
||||||
src_idx.2.into(), // src step
|
|
||||||
src_arr_ptr.into(), // src arr ptr
|
|
||||||
src_len.into(), // src arr len
|
|
||||||
{
|
|
||||||
let s = match ty {
|
|
||||||
BasicTypeEnum::FloatType(t) => t.size_of(),
|
|
||||||
BasicTypeEnum::IntType(t) => t.size_of(),
|
|
||||||
BasicTypeEnum::PointerType(t) => t.size_of(),
|
|
||||||
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
|
|
||||||
_ => codegen_unreachable!(ctx),
|
|
||||||
};
|
|
||||||
ctx.builder.build_int_truncate_or_bit_cast(s, llvm_i32, "size").unwrap()
|
|
||||||
}
|
|
||||||
.into(),
|
|
||||||
];
|
|
||||||
ctx.builder
|
|
||||||
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
// update length
|
|
||||||
let need_update =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
|
|
||||||
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
|
|
||||||
let update_bb = ctx.ctx.append_basic_block(current, "update");
|
|
||||||
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
|
|
||||||
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
|
|
||||||
ctx.builder.position_at_end(update_bb);
|
|
||||||
let new_len =
|
|
||||||
ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap();
|
|
||||||
dest_arr.store_size(ctx, new_len);
|
|
||||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
|
||||||
ctx.builder.position_at_end(cont_bb);
|
|
||||||
}
|
|
@ -1,168 +0,0 @@
|
|||||||
use inkwell::{
|
|
||||||
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
|
||||||
IntPredicate,
|
|
||||||
};
|
|
||||||
use itertools::Either;
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
macros::codegen_unreachable,
|
|
||||||
{CodeGenContext, CodeGenerator},
|
|
||||||
};
|
|
||||||
|
|
||||||
// repeated squaring method adapted from GNU Scientific Library:
|
|
||||||
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
|
||||||
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
base: IntValue<'ctx>,
|
|
||||||
exp: IntValue<'ctx>,
|
|
||||||
signed: bool,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
|
|
||||||
(32, 32, true) => "__nac3_int_exp_int32_t",
|
|
||||||
(64, 64, true) => "__nac3_int_exp_int64_t",
|
|
||||||
(32, 32, false) => "__nac3_int_exp_uint32_t",
|
|
||||||
(64, 64, false) => "__nac3_int_exp_uint64_t",
|
|
||||||
_ => codegen_unreachable!(ctx),
|
|
||||||
};
|
|
||||||
let base_type = base.get_type();
|
|
||||||
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
|
|
||||||
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
|
|
||||||
ctx.module.add_function(symbol, fn_type, None)
|
|
||||||
});
|
|
||||||
// throw exception when exp < 0
|
|
||||||
let ge_zero = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::SGE,
|
|
||||||
exp,
|
|
||||||
exp.get_type().const_zero(),
|
|
||||||
"assert_int_pow_ge_0",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ge_zero,
|
|
||||||
"0:ValueError",
|
|
||||||
"integer power must be positive or zero",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
ctx.builder
|
|
||||||
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
|
|
||||||
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
v: FloatValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
assert_eq!(v.get_type(), llvm_f64);
|
|
||||||
|
|
||||||
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false);
|
|
||||||
ctx.module.add_function("__nac3_isinf", fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
let ret = ctx
|
|
||||||
.builder
|
|
||||||
.build_call(intrinsic_fn, &[v.into()], "isinf")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
generator.bool_to_i1(ctx, ret)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
|
|
||||||
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
v: FloatValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
assert_eq!(v.get_type(), llvm_f64);
|
|
||||||
|
|
||||||
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false);
|
|
||||||
ctx.module.add_function("__nac3_isnan", fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
let ret = ctx
|
|
||||||
.builder
|
|
||||||
.build_call(intrinsic_fn, &[v.into()], "isnan")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
generator.bool_to_i1(ctx, ret)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
|
|
||||||
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
assert_eq!(v.get_type(), llvm_f64);
|
|
||||||
|
|
||||||
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
|
||||||
ctx.module.add_function("__nac3_gamma", fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(intrinsic_fn, &[v.into()], "gamma")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
|
|
||||||
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
assert_eq!(v.get_type(), llvm_f64);
|
|
||||||
|
|
||||||
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
|
||||||
ctx.module.add_function("__nac3_gammaln", fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(intrinsic_fn, &[v.into()], "gammaln")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
|
|
||||||
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
assert_eq!(v.get_type(), llvm_f64);
|
|
||||||
|
|
||||||
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
|
||||||
ctx.module.add_function("__nac3_j0", fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(intrinsic_fn, &[v.into()], "j0")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
@ -1,31 +1,33 @@
|
|||||||
|
use crate::typecheck::typedef::Type;
|
||||||
|
|
||||||
|
pub mod error_context;
|
||||||
|
pub mod string;
|
||||||
|
mod test;
|
||||||
|
pub mod util;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
classes::{
|
||||||
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
|
||||||
|
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
||||||
|
},
|
||||||
|
llvm_intrinsics, CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
use crate::codegen::classes::TypedArrayLikeAccessor;
|
||||||
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
context::Context,
|
context::Context,
|
||||||
memory_buffer::MemoryBuffer,
|
memory_buffer::MemoryBuffer,
|
||||||
module::Module,
|
module::Module,
|
||||||
values::{BasicValue, BasicValueEnum, IntValue},
|
types::{BasicTypeEnum, IntType},
|
||||||
IntPredicate,
|
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
||||||
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
|
use itertools::Either;
|
||||||
use nac3parser::ast::Expr;
|
use nac3parser::ast::Expr;
|
||||||
|
|
||||||
use super::{CodeGenContext, CodeGenerator};
|
|
||||||
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
|
|
||||||
pub use list::*;
|
|
||||||
pub use math::*;
|
|
||||||
pub use range::*;
|
|
||||||
pub use slice::*;
|
|
||||||
pub use string::*;
|
|
||||||
|
|
||||||
mod list;
|
|
||||||
mod math;
|
|
||||||
pub mod ndarray;
|
|
||||||
mod range;
|
|
||||||
mod slice;
|
|
||||||
mod string;
|
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
pub fn load_irrt(ctx: &Context) -> Module {
|
||||||
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
||||||
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
|
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
|
||||||
"irrt_bitcode_buffer",
|
"irrt_bitcode_buffer",
|
||||||
@ -41,43 +43,89 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
|
|||||||
let function = irrt_mod.get_function(symbol).unwrap();
|
let function = irrt_mod.get_function(symbol).unwrap();
|
||||||
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
|
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
|
|
||||||
let exn_id_type = ctx.i32_type();
|
|
||||||
let errors = &[
|
|
||||||
("EXN_INDEX_ERROR", "0:IndexError"),
|
|
||||||
("EXN_VALUE_ERROR", "0:ValueError"),
|
|
||||||
("EXN_ASSERTION_ERROR", "0:AssertionError"),
|
|
||||||
("EXN_TYPE_ERROR", "0:TypeError"),
|
|
||||||
];
|
|
||||||
for (irrt_name, symbol_name) in errors {
|
|
||||||
let exn_id = symbol_resolver.get_string_id(symbol_name);
|
|
||||||
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
|
|
||||||
|
|
||||||
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
|
|
||||||
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
|
|
||||||
});
|
|
||||||
global.set_initializer(&exn_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
irrt_mod
|
irrt_mod
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`.
|
// repeated squaring method adapted from GNU Scientific Library:
|
||||||
///
|
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||||
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
|
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
|
generator: &mut G,
|
||||||
#[must_use]
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
pub fn get_usize_dependent_function_name(ctx: &CodeGenContext<'_, '_>, name: &str) -> String {
|
base: IntValue<'ctx>,
|
||||||
let mut name = name.to_owned();
|
exp: IntValue<'ctx>,
|
||||||
match ctx.get_size_type().get_bit_width() {
|
signed: bool,
|
||||||
32 => {}
|
) -> IntValue<'ctx> {
|
||||||
64 => name.push_str("64"),
|
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
|
||||||
bit_width => {
|
(32, 32, true) => "__nac3_int_exp_int32_t",
|
||||||
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
|
(64, 64, true) => "__nac3_int_exp_int64_t",
|
||||||
}
|
(32, 32, false) => "__nac3_int_exp_uint32_t",
|
||||||
}
|
(64, 64, false) => "__nac3_int_exp_uint64_t",
|
||||||
name
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let base_type = base.get_type();
|
||||||
|
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
|
||||||
|
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
|
||||||
|
ctx.module.add_function(symbol, fn_type, None)
|
||||||
|
});
|
||||||
|
// throw exception when exp < 0
|
||||||
|
let ge_zero = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SGE,
|
||||||
|
exp,
|
||||||
|
exp.get_type().const_zero(),
|
||||||
|
"assert_int_pow_ge_0",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ge_zero,
|
||||||
|
"0:ValueError",
|
||||||
|
"integer power must be positive or zero",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
ctx.builder
|
||||||
|
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
start: IntValue<'ctx>,
|
||||||
|
end: IntValue<'ctx>,
|
||||||
|
step: IntValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
const SYMBOL: &str = "__nac3_range_slice_len";
|
||||||
|
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
||||||
|
let i32_t = ctx.ctx.i32_type();
|
||||||
|
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
|
||||||
|
ctx.module.add_function(SYMBOL, fn_t, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
// assert step != 0, throw exception if not
|
||||||
|
let not_zero = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
|
||||||
|
.unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
not_zero,
|
||||||
|
"0:ValueError",
|
||||||
|
"step must not be zero",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
ctx.builder
|
||||||
|
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
||||||
@ -128,11 +176,10 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
|
|||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
length: IntValue<'ctx>,
|
length: IntValue<'ctx>,
|
||||||
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
|
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let int32 = ctx.ctx.i32_type();
|
||||||
|
let zero = int32.const_zero();
|
||||||
let zero = llvm_i32.const_zero();
|
let one = int32.const_int(1, false);
|
||||||
let one = llvm_i32.const_int(1, false);
|
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
|
||||||
let length = ctx.builder.build_int_truncate_or_bit_cast(length, llvm_i32, "leni32").unwrap();
|
|
||||||
Ok(Some(match (start, end, step) {
|
Ok(Some(match (start, end, step) {
|
||||||
(s, e, None) => (
|
(s, e, None) => (
|
||||||
if let Some(s) = s.as_ref() {
|
if let Some(s) = s.as_ref() {
|
||||||
@ -141,7 +188,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
|
|||||||
None => return Ok(None),
|
None => return Ok(None),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
llvm_i32.const_zero()
|
int32.const_zero()
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
let e = if let Some(s) = e.as_ref() {
|
let e = if let Some(s) = e.as_ref() {
|
||||||
@ -246,3 +293,642 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
|
|||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// this function allows index out of range, since python
|
||||||
|
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
|
||||||
|
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
|
||||||
|
i: &Expr<Option<Type>>,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut G,
|
||||||
|
length: IntValue<'ctx>,
|
||||||
|
) -> Result<Option<IntValue<'ctx>>, String> {
|
||||||
|
const SYMBOL: &str = "__nac3_slice_index_bound";
|
||||||
|
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
||||||
|
let i32_t = ctx.ctx.i32_type();
|
||||||
|
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
|
||||||
|
ctx.module.add_function(SYMBOL, fn_t, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
|
||||||
|
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
|
||||||
|
} else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
Ok(Some(
|
||||||
|
ctx.builder
|
||||||
|
.build_call(func, &[i.into(), length.into()], "bounded_ind")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This function handles 'end' **inclusively**.
|
||||||
|
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
|
||||||
|
/// Negative index should be handled before entering this function
|
||||||
|
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ty: BasicTypeEnum<'ctx>,
|
||||||
|
dest_arr: ListValue<'ctx>,
|
||||||
|
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||||
|
src_arr: ListValue<'ctx>,
|
||||||
|
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||||
|
) {
|
||||||
|
let size_ty = generator.get_size_type(ctx.ctx);
|
||||||
|
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||||
|
let int32 = ctx.ctx.i32_type();
|
||||||
|
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
|
||||||
|
let slice_assign_fun = {
|
||||||
|
let ty_vec = vec![
|
||||||
|
int32.into(), // dest start idx
|
||||||
|
int32.into(), // dest end idx
|
||||||
|
int32.into(), // dest step
|
||||||
|
elem_ptr_type.into(), // dest arr ptr
|
||||||
|
int32.into(), // dest arr len
|
||||||
|
int32.into(), // src start idx
|
||||||
|
int32.into(), // src end idx
|
||||||
|
int32.into(), // src step
|
||||||
|
elem_ptr_type.into(), // src arr ptr
|
||||||
|
int32.into(), // src arr len
|
||||||
|
int32.into(), // size
|
||||||
|
];
|
||||||
|
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
|
||||||
|
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
|
||||||
|
ctx.module.add_function(fun_symbol, fn_t, None)
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
let zero = int32.const_zero();
|
||||||
|
let one = int32.const_int(1, false);
|
||||||
|
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
|
||||||
|
let dest_arr_ptr =
|
||||||
|
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
|
||||||
|
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
|
||||||
|
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
|
||||||
|
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
|
||||||
|
let src_arr_ptr =
|
||||||
|
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
|
||||||
|
let src_len = src_arr.load_size(ctx, Some("src.len"));
|
||||||
|
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
|
||||||
|
|
||||||
|
// index in bound and positive should be done
|
||||||
|
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
|
||||||
|
// throw exception if not satisfied
|
||||||
|
let src_end = ctx
|
||||||
|
.builder
|
||||||
|
.build_select(
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
|
||||||
|
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
|
||||||
|
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
|
||||||
|
"final_e",
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let dest_end = ctx
|
||||||
|
.builder
|
||||||
|
.build_select(
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
|
||||||
|
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
|
||||||
|
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
|
||||||
|
"final_e",
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let src_slice_len =
|
||||||
|
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
|
||||||
|
let dest_slice_len =
|
||||||
|
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
|
||||||
|
let src_eq_dest = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
|
||||||
|
.unwrap();
|
||||||
|
let src_slt_dest = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
|
||||||
|
.unwrap();
|
||||||
|
let dest_step_eq_one = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
dest_idx.2,
|
||||||
|
dest_idx.2.get_type().const_int(1, false),
|
||||||
|
"slice_dest_step_eq_one",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
|
||||||
|
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
cond,
|
||||||
|
"0:ValueError",
|
||||||
|
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
|
||||||
|
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let new_len = {
|
||||||
|
let args = vec![
|
||||||
|
dest_idx.0.into(), // dest start idx
|
||||||
|
dest_idx.1.into(), // dest end idx
|
||||||
|
dest_idx.2.into(), // dest step
|
||||||
|
dest_arr_ptr.into(), // dest arr ptr
|
||||||
|
dest_len.into(), // dest arr len
|
||||||
|
src_idx.0.into(), // src start idx
|
||||||
|
src_idx.1.into(), // src end idx
|
||||||
|
src_idx.2.into(), // src step
|
||||||
|
src_arr_ptr.into(), // src arr ptr
|
||||||
|
src_len.into(), // src arr len
|
||||||
|
{
|
||||||
|
let s = match ty {
|
||||||
|
BasicTypeEnum::FloatType(t) => t.size_of(),
|
||||||
|
BasicTypeEnum::IntType(t) => t.size_of(),
|
||||||
|
BasicTypeEnum::PointerType(t) => t.size_of(),
|
||||||
|
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
];
|
||||||
|
ctx.builder
|
||||||
|
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
};
|
||||||
|
// update length
|
||||||
|
let need_update =
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
|
||||||
|
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||||
|
let update_bb = ctx.ctx.append_basic_block(current, "update");
|
||||||
|
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
|
||||||
|
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
|
||||||
|
ctx.builder.position_at_end(update_bb);
|
||||||
|
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
|
||||||
|
dest_arr.store_size(ctx, generator, new_len);
|
||||||
|
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
||||||
|
ctx.builder.position_at_end(cont_bb);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
|
||||||
|
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
v: FloatValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
|
||||||
|
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
|
||||||
|
ctx.module.add_function("__nac3_isinf", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ret = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "isinf")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
generator.bool_to_i1(ctx, ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
|
||||||
|
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
v: FloatValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
|
||||||
|
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
|
||||||
|
ctx.module.add_function("__nac3_isnan", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ret = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "isnan")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
generator.bool_to_i1(ctx, ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
|
||||||
|
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
ctx.module.add_function("__nac3_gamma", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "gamma")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
|
||||||
|
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
ctx.module.add_function("__nac3_gammaln", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "gammaln")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
|
||||||
|
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
ctx.module.add_function("__nac3_j0", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "j0")
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
||||||
|
/// calculated total size.
|
||||||
|
///
|
||||||
|
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
|
||||||
|
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
|
||||||
|
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
|
||||||
|
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
dims: &Dims,
|
||||||
|
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
|
||||||
|
) -> IntValue<'ctx>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Dims: ArrayLikeIndexer<'ctx>,
|
||||||
|
{
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_calc_size",
|
||||||
|
64 => "__nac3_ndarray_calc_size64",
|
||||||
|
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||||
|
};
|
||||||
|
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
|
||||||
|
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
let ndarray_calc_size_fn =
|
||||||
|
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
|
||||||
|
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
|
||||||
|
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_calc_size_fn,
|
||||||
|
&[
|
||||||
|
dims.base_ptr(ctx, generator).into(),
|
||||||
|
dims.size(ctx, generator).into(),
|
||||||
|
begin.into(),
|
||||||
|
end.into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
|
||||||
|
/// containing `i32` indices of the flattened index.
|
||||||
|
///
|
||||||
|
/// * `index` - The index to compute the multidimensional index for.
|
||||||
|
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||||
|
/// `NDArray`.
|
||||||
|
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
index: IntValue<'ctx>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||||
|
let llvm_void = ctx.ctx.void_type();
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_calc_nd_indices",
|
||||||
|
64 => "__nac3_ndarray_calc_nd_indices64",
|
||||||
|
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||||
|
};
|
||||||
|
let ndarray_calc_nd_indices_fn =
|
||||||
|
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_void.fn_type(
|
||||||
|
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
let ndarray_dims = ndarray.dim_sizes();
|
||||||
|
|
||||||
|
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_calc_nd_indices_fn,
|
||||||
|
&[
|
||||||
|
index.into(),
|
||||||
|
ndarray_dims.base_ptr(ctx, generator).into(),
|
||||||
|
ndarray_num_dims.into(),
|
||||||
|
indices.into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
TypedArrayLikeAdapter::from(
|
||||||
|
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
|
||||||
|
Box::new(|_, v| v.into_int_value()),
|
||||||
|
Box::new(|_, v| v.into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
indices: &Indices,
|
||||||
|
) -> IntValue<'ctx>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Indices: ArrayLikeIndexer<'ctx>,
|
||||||
|
{
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
debug_assert_eq!(
|
||||||
|
IntType::try_from(indices.element_type(ctx, generator))
|
||||||
|
.map(IntType::get_bit_width)
|
||||||
|
.unwrap_or_default(),
|
||||||
|
llvm_i32.get_bit_width(),
|
||||||
|
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
||||||
|
);
|
||||||
|
debug_assert_eq!(
|
||||||
|
indices.size(ctx, generator).get_type().get_bit_width(),
|
||||||
|
llvm_usize.get_bit_width(),
|
||||||
|
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_flatten_index",
|
||||||
|
64 => "__nac3_ndarray_flatten_index64",
|
||||||
|
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||||
|
};
|
||||||
|
let ndarray_flatten_index_fn =
|
||||||
|
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_usize.fn_type(
|
||||||
|
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
let ndarray_dims = ndarray.dim_sizes();
|
||||||
|
|
||||||
|
let index = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_flatten_index_fn,
|
||||||
|
&[
|
||||||
|
ndarray_dims.base_ptr(ctx, generator).into(),
|
||||||
|
ndarray_num_dims.into(),
|
||||||
|
indices.base_ptr(ctx, generator).into(),
|
||||||
|
indices.size(ctx, generator).into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
index
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||||
|
/// multidimensional index.
|
||||||
|
///
|
||||||
|
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||||
|
/// `NDArray`.
|
||||||
|
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||||
|
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
indices: &Index,
|
||||||
|
) -> IntValue<'ctx>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Index: ArrayLikeIndexer<'ctx>,
|
||||||
|
{
|
||||||
|
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
|
||||||
|
/// dimension and size of each dimension of the resultant `ndarray`.
|
||||||
|
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
lhs: NDArrayValue<'ctx>,
|
||||||
|
rhs: NDArrayValue<'ctx>,
|
||||||
|
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_calc_broadcast",
|
||||||
|
64 => "__nac3_ndarray_calc_broadcast64",
|
||||||
|
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||||
|
};
|
||||||
|
let ndarray_calc_broadcast_fn =
|
||||||
|
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_usize.fn_type(
|
||||||
|
&[
|
||||||
|
llvm_pusize.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_pusize.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_pusize.into(),
|
||||||
|
],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let lhs_ndims = lhs.load_ndims(ctx);
|
||||||
|
let rhs_ndims = rhs.load_ndims(ctx);
|
||||||
|
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(min_ndims, false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
||||||
|
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
||||||
|
(
|
||||||
|
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
|
||||||
|
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let llvm_usize_const_one = llvm_usize.const_int(1, false);
|
||||||
|
let lhs_eqz = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
|
||||||
|
.unwrap();
|
||||||
|
let rhs_eqz = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
|
||||||
|
.unwrap();
|
||||||
|
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
|
||||||
|
|
||||||
|
let lhs_eq_rhs = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
is_compatible,
|
||||||
|
"0:ValueError",
|
||||||
|
"operands could not be broadcast together",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
|
||||||
|
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
|
||||||
|
let lhs_ndims = lhs.load_ndims(ctx);
|
||||||
|
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
|
||||||
|
let rhs_ndims = rhs.load_ndims(ctx);
|
||||||
|
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
|
||||||
|
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_calc_broadcast_fn,
|
||||||
|
&[
|
||||||
|
lhs_dims.into(),
|
||||||
|
lhs_ndims.into(),
|
||||||
|
rhs_dims.into(),
|
||||||
|
rhs_ndims.into(),
|
||||||
|
out_dims.base_ptr(ctx, generator).into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
TypedArrayLikeAdapter::from(
|
||||||
|
out_dims,
|
||||||
|
Box::new(|_, v| v.into_int_value()),
|
||||||
|
Box::new(|_, v| v.into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
|
||||||
|
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
|
||||||
|
/// array `broadcast_idx`.
|
||||||
|
pub fn call_ndarray_calc_broadcast_index<
|
||||||
|
'ctx,
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
|
||||||
|
>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
array: NDArrayValue<'ctx>,
|
||||||
|
broadcast_idx: &BroadcastIdx,
|
||||||
|
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_calc_broadcast_idx",
|
||||||
|
64 => "__nac3_ndarray_calc_broadcast_idx64",
|
||||||
|
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||||
|
};
|
||||||
|
let ndarray_calc_broadcast_fn =
|
||||||
|
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_usize.fn_type(
|
||||||
|
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let broadcast_size = broadcast_idx.size(ctx, generator);
|
||||||
|
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
||||||
|
|
||||||
|
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
|
||||||
|
let array_ndims = array.load_ndims(ctx);
|
||||||
|
let broadcast_idx_ptr = unsafe {
|
||||||
|
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
};
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_calc_broadcast_fn,
|
||||||
|
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
TypedArrayLikeAdapter::from(
|
||||||
|
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
|
||||||
|
Box::new(|_, v| v.into_int_value()),
|
||||||
|
Box::new(|_, v| v.into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
@ -1,72 +0,0 @@
|
|||||||
use inkwell::{types::BasicTypeEnum, values::IntValue};
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
expr::infer_and_call_function,
|
|
||||||
irrt::get_usize_dependent_function_name,
|
|
||||||
values::{ndarray::NDArrayValue, ListValue, ProxyValue, TypedArrayLikeAccessor},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_array_set_and_validate_list_shape`.
|
|
||||||
///
|
|
||||||
/// Deduces the target shape of the `ndarray` from the provided `list`, raising an exception if
|
|
||||||
/// there is any issue with the resultant `shape`.
|
|
||||||
///
|
|
||||||
/// `shape` must be pre-allocated by the caller of this function to `[usize; ndims]`, and must be
|
|
||||||
/// initialized to all `-1`s.
|
|
||||||
pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
list: ListValue<'ctx>,
|
|
||||||
ndims: IntValue<'ctx>,
|
|
||||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into());
|
|
||||||
assert_eq!(ndims.get_type(), llvm_usize);
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let name =
|
|
||||||
get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_set_and_validate_list_shape");
|
|
||||||
|
|
||||||
infer_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[list.as_base_value().into(), ndims.into(), shape.base_ptr(ctx, generator).into()],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_array_write_list_to_array`.
|
|
||||||
///
|
|
||||||
/// Copies the contents stored in `list` into `ndarray`.
|
|
||||||
///
|
|
||||||
/// The `ndarray` must fulfill the following preconditions:
|
|
||||||
///
|
|
||||||
/// - `ndarray.itemsize`: Must be initialized.
|
|
||||||
/// - `ndarray.ndims`: Must be initialized.
|
|
||||||
/// - `ndarray.shape`: Must be initialized.
|
|
||||||
/// - `ndarray.data`: Must be allocated and contiguous.
|
|
||||||
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
list: ListValue<'ctx>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into());
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_write_list_to_array");
|
|
||||||
|
|
||||||
infer_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[list.as_base_value().into(), ndarray.as_base_value().into()],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,295 +0,0 @@
|
|||||||
use inkwell::{
|
|
||||||
types::BasicTypeEnum,
|
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
|
||||||
AddressSpace,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
expr::{create_and_call_function, infer_and_call_function},
|
|
||||||
irrt::get_usize_dependent_function_name,
|
|
||||||
types::ProxyType,
|
|
||||||
values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_util_assert_shape_no_negative`.
|
|
||||||
///
|
|
||||||
/// Assets that `shape` does not contain negative dimensions.
|
|
||||||
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let name =
|
|
||||||
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_usize.into()),
|
|
||||||
&[
|
|
||||||
(llvm_usize.into(), shape.size(ctx, generator).into()),
|
|
||||||
(llvm_pusize.into(), shape.base_ptr(ctx, generator).into()),
|
|
||||||
],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_util_assert_shape_output_shape_same`.
|
|
||||||
///
|
|
||||||
/// Asserts that `ndarray_shape` and `output_shape` are the same in the context of writing output to
|
|
||||||
/// an `ndarray`.
|
|
||||||
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(ndarray_shape.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(output_shape.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let name =
|
|
||||||
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_usize.into()),
|
|
||||||
&[
|
|
||||||
(llvm_usize.into(), ndarray_shape.size(ctx, generator).into()),
|
|
||||||
(llvm_pusize.into(), ndarray_shape.base_ptr(ctx, generator).into()),
|
|
||||||
(llvm_usize.into(), output_shape.size(ctx, generator).into()),
|
|
||||||
(llvm_pusize.into(), output_shape.base_ptr(ctx, generator).into()),
|
|
||||||
],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_size`.
|
|
||||||
///
|
|
||||||
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an
|
|
||||||
/// `ndarray`, corresponding to the value of `ndarray.size`.
|
|
||||||
pub fn call_nac3_ndarray_size<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_usize.into()),
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
|
||||||
Some("size"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_nbytes`.
|
|
||||||
///
|
|
||||||
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the
|
|
||||||
/// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`.
|
|
||||||
pub fn call_nac3_ndarray_nbytes<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_usize.into()),
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
|
||||||
Some("nbytes"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_len`.
|
|
||||||
///
|
|
||||||
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of
|
|
||||||
/// the `ndarray`, corresponding to the value of `ndarray.__len__`.
|
|
||||||
pub fn call_nac3_ndarray_len<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_usize.into()),
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
|
||||||
Some("len"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_is_c_contiguous`.
|
|
||||||
///
|
|
||||||
/// Returns an `i1` value indicating whether the `ndarray` is C-contiguous.
|
|
||||||
pub fn call_nac3_ndarray_is_c_contiguous<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
|
||||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_i1.into()),
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
|
||||||
Some("is_c_contiguous"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_get_nth_pelement`.
|
|
||||||
///
|
|
||||||
/// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`.
|
|
||||||
pub fn call_nac3_ndarray_get_nth_pelement<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
index: IntValue<'ctx>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
|
||||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
|
||||||
|
|
||||||
assert_eq!(index.get_type(), llvm_usize);
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_pi8.into()),
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())],
|
|
||||||
Some("pelement"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_get_pelement_by_indices`.
|
|
||||||
///
|
|
||||||
/// `indices` must have the same number of elements as the number of dimensions in `ndarray`.
|
|
||||||
///
|
|
||||||
/// Returns a [`PointerValue`] to the element indexed by `indices`.
|
|
||||||
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
|
||||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(llvm_pi8.into()),
|
|
||||||
&[
|
|
||||||
(llvm_ndarray.into(), ndarray.as_base_value().into()),
|
|
||||||
(llvm_pusize.into(), indices.base_ptr(ctx, generator).into()),
|
|
||||||
],
|
|
||||||
Some("pelement"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_set_strides_by_shape`.
|
|
||||||
///
|
|
||||||
/// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous.
|
|
||||||
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let llvm_ndarray = ndarray.get_type().as_base_type();
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_copy_data`.
|
|
||||||
///
|
|
||||||
/// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number
|
|
||||||
/// of elements in `src_ndarray` must be greater than or equal to the number of elements in
|
|
||||||
/// `dst_ndarray`.
|
|
||||||
pub fn call_nac3_ndarray_copy_data<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: NDArrayValue<'ctx>,
|
|
||||||
dst_ndarray: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_copy_data");
|
|
||||||
|
|
||||||
infer_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,81 +0,0 @@
|
|||||||
use inkwell::values::IntValue;
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
expr::infer_and_call_function,
|
|
||||||
irrt::get_usize_dependent_function_name,
|
|
||||||
types::{ndarray::ShapeEntryType, ProxyType},
|
|
||||||
values::{
|
|
||||||
ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor,
|
|
||||||
TypedArrayLikeMutator,
|
|
||||||
},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_broadcast_to`.
|
|
||||||
///
|
|
||||||
/// Attempts to broadcast `src_ndarray` to the new shape defined by `dst_ndarray`.
|
|
||||||
///
|
|
||||||
/// `dst_ndarray` must meet the following preconditions:
|
|
||||||
///
|
|
||||||
/// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`.
|
|
||||||
/// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape.
|
|
||||||
/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values.
|
|
||||||
pub fn call_nac3_ndarray_broadcast_to<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: NDArrayValue<'ctx>,
|
|
||||||
dst_ndarray: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_to");
|
|
||||||
infer_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_broadcast_shapes`.
|
|
||||||
///
|
|
||||||
/// Attempts to calculate the resultant shape from broadcasting all shapes in `shape_entries`,
|
|
||||||
/// writing the result to `dst_shape`.
|
|
||||||
pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
num_shape_entries: IntValue<'ctx>,
|
|
||||||
shape_entries: ArraySliceValue<'ctx>,
|
|
||||||
dst_ndims: IntValue<'ctx>,
|
|
||||||
dst_shape: &Shape,
|
|
||||||
) where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
|
|
||||||
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
|
|
||||||
{
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
|
|
||||||
assert_eq!(num_shape_entries.get_type(), llvm_usize);
|
|
||||||
assert!(ShapeEntryType::is_type(
|
|
||||||
generator,
|
|
||||||
ctx.ctx,
|
|
||||||
shape_entries.base_ptr(ctx, generator).get_type()
|
|
||||||
)
|
|
||||||
.is_ok());
|
|
||||||
assert_eq!(dst_ndims.get_type(), llvm_usize);
|
|
||||||
assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into());
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_shapes");
|
|
||||||
infer_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[
|
|
||||||
num_shape_entries.into(),
|
|
||||||
shape_entries.base_ptr(ctx, generator).into(),
|
|
||||||
dst_ndims.into(),
|
|
||||||
dst_shape.base_ptr(ctx, generator).into(),
|
|
||||||
],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,34 +0,0 @@
|
|||||||
use crate::codegen::{
|
|
||||||
expr::infer_and_call_function,
|
|
||||||
irrt::get_usize_dependent_function_name,
|
|
||||||
values::{ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_index`.
|
|
||||||
///
|
|
||||||
/// Performs [basic indexing](https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
|
|
||||||
/// on `src_ndarray` using `indices`, writing the result to `dst_ndarray`, corresponding to the
|
|
||||||
/// operation `dst_ndarray = src_ndarray[indices]`.
|
|
||||||
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
indices: ArraySliceValue<'ctx>,
|
|
||||||
src_ndarray: NDArrayValue<'ctx>,
|
|
||||||
dst_ndarray: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_index");
|
|
||||||
infer_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[
|
|
||||||
indices.size(ctx, generator).into(),
|
|
||||||
indices.base_ptr(ctx, generator).into(),
|
|
||||||
src_ndarray.as_base_value().into(),
|
|
||||||
dst_ndarray.as_base_value().into(),
|
|
||||||
],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,81 +0,0 @@
|
|||||||
use inkwell::{
|
|
||||||
types::BasicTypeEnum,
|
|
||||||
values::{BasicValueEnum, IntValue},
|
|
||||||
AddressSpace,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
expr::{create_and_call_function, infer_and_call_function},
|
|
||||||
irrt::get_usize_dependent_function_name,
|
|
||||||
types::ProxyType,
|
|
||||||
values::{
|
|
||||||
ndarray::{NDArrayValue, NDIterValue},
|
|
||||||
ProxyValue, TypedArrayLikeAccessor,
|
|
||||||
},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_nditer_initialize`.
|
|
||||||
///
|
|
||||||
/// Initializes the `iter` object.
|
|
||||||
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
iter: NDIterValue<'ctx>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize");
|
|
||||||
|
|
||||||
create_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[
|
|
||||||
(iter.get_type().as_base_type().into(), iter.as_base_value().into()),
|
|
||||||
(ndarray.get_type().as_base_type().into(), ndarray.as_base_value().into()),
|
|
||||||
(llvm_pusize.into(), indices.base_ptr(ctx, generator).into()),
|
|
||||||
],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_nditer_initialize_has_element`.
|
|
||||||
///
|
|
||||||
/// Returns an `i1` value indicating whether there are elements left to traverse for the `iter`
|
|
||||||
/// object.
|
|
||||||
pub fn call_nac3_nditer_has_element<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
iter: NDIterValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_has_element");
|
|
||||||
|
|
||||||
infer_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
Some(ctx.ctx.bool_type().into()),
|
|
||||||
&[iter.as_base_value().into()],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_nditer_next`.
|
|
||||||
///
|
|
||||||
/// Moves `iter` to point to the next element.
|
|
||||||
pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) {
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next");
|
|
||||||
|
|
||||||
infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None);
|
|
||||||
}
|
|
@ -1,65 +0,0 @@
|
|||||||
use inkwell::{types::BasicTypeEnum, values::IntValue};
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
expr::infer_and_call_function, irrt::get_usize_dependent_function_name,
|
|
||||||
values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_matmul_calculate_shapes`.
|
|
||||||
///
|
|
||||||
/// Calculates the broadcasted shapes for `a`, `b`, and the `ndarray` holding the final values of
|
|
||||||
/// `a @ b`.
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
final_ndims: IntValue<'ctx>,
|
|
||||||
new_a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(),
|
|
||||||
llvm_usize.into()
|
|
||||||
);
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes");
|
|
||||||
|
|
||||||
infer_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[
|
|
||||||
a_shape.size(ctx, generator).into(),
|
|
||||||
a_shape.base_ptr(ctx, generator).into(),
|
|
||||||
b_shape.size(ctx, generator).into(),
|
|
||||||
b_shape.base_ptr(ctx, generator).into(),
|
|
||||||
final_ndims.into(),
|
|
||||||
new_a_shape.base_ptr(ctx, generator).into(),
|
|
||||||
new_b_shape.base_ptr(ctx, generator).into(),
|
|
||||||
dst_shape.base_ptr(ctx, generator).into(),
|
|
||||||
],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,17 +0,0 @@
|
|||||||
pub use array::*;
|
|
||||||
pub use basic::*;
|
|
||||||
pub use broadcast::*;
|
|
||||||
pub use indexing::*;
|
|
||||||
pub use iter::*;
|
|
||||||
pub use matmul::*;
|
|
||||||
pub use reshape::*;
|
|
||||||
pub use transpose::*;
|
|
||||||
|
|
||||||
mod array;
|
|
||||||
mod basic;
|
|
||||||
mod broadcast;
|
|
||||||
mod indexing;
|
|
||||||
mod iter;
|
|
||||||
mod matmul;
|
|
||||||
mod reshape;
|
|
||||||
mod transpose;
|
|
@ -1,39 +0,0 @@
|
|||||||
use inkwell::values::IntValue;
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
expr::infer_and_call_function,
|
|
||||||
irrt::get_usize_dependent_function_name,
|
|
||||||
values::{ArrayLikeValue, ArraySliceValue},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_reshape_resolve_and_check_new_shape`.
|
|
||||||
///
|
|
||||||
/// Resolves unknown dimensions in `new_shape` for `numpy.reshape(<ndarray>, new_shape)`, raising an
|
|
||||||
/// assertion if multiple dimensions are unknown (`-1`).
|
|
||||||
pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
size: IntValue<'ctx>,
|
|
||||||
new_ndims: IntValue<'ctx>,
|
|
||||||
new_shape: ArraySliceValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
|
|
||||||
assert_eq!(size.get_type(), llvm_usize);
|
|
||||||
assert_eq!(new_ndims.get_type(), llvm_usize);
|
|
||||||
assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into());
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(
|
|
||||||
ctx,
|
|
||||||
"__nac3_ndarray_reshape_resolve_and_check_new_shape",
|
|
||||||
);
|
|
||||||
infer_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[size.into(), new_ndims.into(), new_shape.base_ptr(ctx, generator).into()],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,48 +0,0 @@
|
|||||||
use inkwell::{values::IntValue, AddressSpace};
|
|
||||||
|
|
||||||
use crate::codegen::{
|
|
||||||
expr::infer_and_call_function,
|
|
||||||
irrt::get_usize_dependent_function_name,
|
|
||||||
values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor},
|
|
||||||
CodeGenContext, CodeGenerator,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_transpose`.
|
|
||||||
///
|
|
||||||
/// Creates a transpose view of `src_ndarray` and writes the result to `dst_ndarray`.
|
|
||||||
///
|
|
||||||
/// `dst_ndarray` must fulfill the following preconditions:
|
|
||||||
///
|
|
||||||
/// - `dst_ndarray.ndims` must be initialized and must be equal to `src_ndarray.ndims`.
|
|
||||||
/// - `dst_ndarray.shape` must be allocated and may contain uninitialized values.
|
|
||||||
/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values.
|
|
||||||
pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
src_ndarray: NDArrayValue<'ctx>,
|
|
||||||
dst_ndarray: NDArrayValue<'ctx>,
|
|
||||||
axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>,
|
|
||||||
) {
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
|
|
||||||
assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize));
|
|
||||||
assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into()));
|
|
||||||
|
|
||||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_transpose");
|
|
||||||
infer_and_call_function(
|
|
||||||
ctx,
|
|
||||||
&name,
|
|
||||||
None,
|
|
||||||
&[
|
|
||||||
src_ndarray.as_base_value().into(),
|
|
||||||
dst_ndarray.as_base_value().into(),
|
|
||||||
axes.map_or(llvm_usize.const_zero(), |axes| axes.size(ctx, generator)).into(),
|
|
||||||
axes.map_or(llvm_usize.ptr_type(AddressSpace::default()).const_null(), |axes| {
|
|
||||||
axes.base_ptr(ctx, generator)
|
|
||||||
})
|
|
||||||
.into(),
|
|
||||||
],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,56 +0,0 @@
|
|||||||
use inkwell::{
|
|
||||||
values::{BasicValueEnum, CallSiteValue, IntValue},
|
|
||||||
IntPredicate,
|
|
||||||
};
|
|
||||||
use itertools::Either;
|
|
||||||
|
|
||||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
|
||||||
|
|
||||||
/// Invokes the `__nac3_range_slice_len` in IRRT.
|
|
||||||
///
|
|
||||||
/// - `start`: The `i32` start value for the slice.
|
|
||||||
/// - `end`: The `i32` end value for the slice.
|
|
||||||
/// - `step`: The `i32` step value for the slice.
|
|
||||||
///
|
|
||||||
/// Returns an `i32` value of the length of the slice.
|
|
||||||
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
start: IntValue<'ctx>,
|
|
||||||
end: IntValue<'ctx>,
|
|
||||||
step: IntValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
const SYMBOL: &str = "__nac3_range_slice_len";
|
|
||||||
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
|
|
||||||
assert_eq!(start.get_type(), llvm_i32);
|
|
||||||
assert_eq!(end.get_type(), llvm_i32);
|
|
||||||
assert_eq!(step.get_type(), llvm_i32);
|
|
||||||
|
|
||||||
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
|
||||||
let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false);
|
|
||||||
ctx.module.add_function(SYMBOL, fn_t, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
// assert step != 0, throw exception if not
|
|
||||||
let not_zero = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
|
|
||||||
.unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
not_zero,
|
|
||||||
"0:ValueError",
|
|
||||||
"step must not be zero",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
@ -1,39 +0,0 @@
|
|||||||
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue};
|
|
||||||
use itertools::Either;
|
|
||||||
|
|
||||||
use nac3parser::ast::Expr;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
codegen::{CodeGenContext, CodeGenerator},
|
|
||||||
typecheck::typedef::Type,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// this function allows index out of range, since python
|
|
||||||
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
|
|
||||||
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
|
|
||||||
i: &Expr<Option<Type>>,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &mut G,
|
|
||||||
length: IntValue<'ctx>,
|
|
||||||
) -> Result<Option<IntValue<'ctx>>, String> {
|
|
||||||
const SYMBOL: &str = "__nac3_slice_index_bound";
|
|
||||||
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
|
||||||
let i32_t = ctx.ctx.i32_type();
|
|
||||||
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
|
|
||||||
ctx.module.add_function(SYMBOL, fn_t, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
|
|
||||||
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
|
|
||||||
} else {
|
|
||||||
return Ok(None);
|
|
||||||
};
|
|
||||||
Ok(Some(
|
|
||||||
ctx.builder
|
|
||||||
.build_call(func, &[i.into(), length.into()], "bounded_ind")
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap(),
|
|
||||||
))
|
|
||||||
}
|
|
@ -1,45 +1,34 @@
|
|||||||
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue};
|
use crate::codegen::model::*;
|
||||||
use itertools::Either;
|
|
||||||
|
|
||||||
use super::get_usize_dependent_function_name;
|
pub struct StrFields<'ctx> {
|
||||||
use crate::codegen::CodeGenContext;
|
/// Pointer to the string. Does not have to be null-terminated.
|
||||||
|
pub content: Field<PointerModel<FixedIntModel<Byte>>>,
|
||||||
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
|
/// Number of bytes this string occupies in space.
|
||||||
pub fn call_string_eq<'ctx>(
|
///
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
/// The [`IntModel`] matches [`Str::sizet`].
|
||||||
str1_ptr: PointerValue<'ctx>,
|
pub length: Field<IntModel<'ctx>>,
|
||||||
str1_len: IntValue<'ctx>,
|
}
|
||||||
str2_ptr: PointerValue<'ctx>,
|
|
||||||
str2_len: IntValue<'ctx>,
|
/// Corresponds to IRRT's `struct Str`
|
||||||
) -> IntValue<'ctx> {
|
///
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
/// nac3core's LLVM representation of a string in memory
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq");
|
pub struct Str<'ctx> {
|
||||||
|
/// The `SizeT` type of this string.
|
||||||
let func = ctx.module.get_function(&func_name).unwrap_or_else(|| {
|
pub sizet: IntModel<'ctx>,
|
||||||
ctx.module.add_function(
|
}
|
||||||
&func_name,
|
|
||||||
llvm_i1.fn_type(
|
impl<'ctx> IsStruct<'ctx> for Str<'ctx> {
|
||||||
&[
|
type Fields = StrFields<'ctx>;
|
||||||
str1_ptr.get_type().into(),
|
|
||||||
str1_len.get_type().into(),
|
fn struct_name(&self) -> &'static str {
|
||||||
str2_ptr.get_type().into(),
|
"Str"
|
||||||
str2_len.get_type().into(),
|
}
|
||||||
],
|
|
||||||
false,
|
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields {
|
||||||
),
|
Self::Fields {
|
||||||
None,
|
content: builder.add_field_auto("content"),
|
||||||
)
|
length: builder.add_field("length", self.sizet),
|
||||||
});
|
}
|
||||||
|
}
|
||||||
ctx.builder
|
|
||||||
.build_call(
|
|
||||||
func,
|
|
||||||
&[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()],
|
|
||||||
"str_eq_call",
|
|
||||||
)
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
}
|
||||||
|
26
nac3core/src/codegen/irrt/test.rs
Normal file
26
nac3core/src/codegen/irrt/test.rs
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::{path::Path, process::Command};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn run_irrt_test() {
|
||||||
|
assert!(
|
||||||
|
cfg!(feature = "test"),
|
||||||
|
"Please do `cargo test -F test` to compile `irrt_test.out` and run test"
|
||||||
|
);
|
||||||
|
|
||||||
|
let irrt_test_out_path = Path::new(concat!(env!("OUT_DIR"), "/irrt_test.out"));
|
||||||
|
let output = Command::new(irrt_test_out_path.to_str().unwrap()).output().unwrap();
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
eprintln!("irrt_test failed with status {}:", output.status);
|
||||||
|
eprintln!("====== stdout ======");
|
||||||
|
eprintln!("{}", String::from_utf8(output.stdout).unwrap());
|
||||||
|
eprintln!("====== stderr ======");
|
||||||
|
eprintln!("{}", String::from_utf8(output.stderr).unwrap());
|
||||||
|
eprintln!("====================");
|
||||||
|
|
||||||
|
panic!("irrt_test failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
79
nac3core/src/codegen/irrt/util.rs
Normal file
79
nac3core/src/codegen/irrt/util.rs
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
use inkwell::{
|
||||||
|
types::{BasicMetadataTypeEnum, BasicType, IntType},
|
||||||
|
values::{AnyValue, BasicMetadataValueEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{model::*, CodeGenContext},
|
||||||
|
util::SizeVariant,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn get_size_variant(ty: IntType) -> SizeVariant {
|
||||||
|
match ty.get_bit_width() {
|
||||||
|
32 => SizeVariant::Bits32,
|
||||||
|
64 => SizeVariant::Bits64,
|
||||||
|
_ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_sized_dependent_function_name(ty: IntModel, fn_name: &str) -> String {
|
||||||
|
let mut fn_name = fn_name.to_owned();
|
||||||
|
match get_size_variant(ty.0) {
|
||||||
|
SizeVariant::Bits32 => {
|
||||||
|
// Do nothing, `fn_name` already has the correct name
|
||||||
|
}
|
||||||
|
SizeVariant::Bits64 => {
|
||||||
|
// Append "64", this is the naming convention
|
||||||
|
fn_name.push_str("64");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn_name
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Variadic argument?
|
||||||
|
pub struct FunctionBuilder<'ctx, 'a> {
|
||||||
|
ctx: &'a CodeGenContext<'ctx, 'a>,
|
||||||
|
fn_name: &'a str,
|
||||||
|
arguments: Vec<(BasicMetadataTypeEnum<'ctx>, BasicMetadataValueEnum<'ctx>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> {
|
||||||
|
pub fn begin(ctx: &'a CodeGenContext<'ctx, 'a>, fn_name: &'a str) -> Self {
|
||||||
|
FunctionBuilder { ctx, fn_name, arguments: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
// The name is for self-documentation
|
||||||
|
#[must_use]
|
||||||
|
pub fn arg<M: Model<'ctx>>(mut self, _name: &'static str, model: M, value: M::Value) -> Self {
|
||||||
|
self.arguments
|
||||||
|
.push((model.get_llvm_type(self.ctx.ctx).into(), value.get_llvm_value().into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn returning<M: Model<'ctx>>(self, name: &'static str, return_model: M) -> M::Value {
|
||||||
|
let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip();
|
||||||
|
|
||||||
|
let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| {
|
||||||
|
let return_type = return_model.get_llvm_type(self.ctx.ctx);
|
||||||
|
let fn_type = return_type.fn_type(¶m_tys, false);
|
||||||
|
self.ctx.module.add_function(self.fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap();
|
||||||
|
return_model.review(self.ctx.ctx, ret.as_any_value_enum())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Code duplication, but otherwise returning<S: Optic<'ctx>> cannot resolve S if return_optic = None
|
||||||
|
pub fn returning_void(self) {
|
||||||
|
let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip();
|
||||||
|
|
||||||
|
let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| {
|
||||||
|
let return_type = self.ctx.ctx.void_type();
|
||||||
|
let fn_type = return_type.fn_type(¶m_tys, false);
|
||||||
|
self.ctx.module.add_function(self.fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
self.ctx.builder.build_call(function, ¶m_vals, "").unwrap();
|
||||||
|
}
|
||||||
|
}
|
@ -1,45 +1,38 @@
|
|||||||
use inkwell::{
|
use crate::codegen::CodeGenContext;
|
||||||
intrinsics::Intrinsic,
|
use inkwell::context::Context;
|
||||||
types::AnyTypeEnum::IntType,
|
use inkwell::intrinsics::Intrinsic;
|
||||||
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
|
use inkwell::types::AnyTypeEnum::IntType;
|
||||||
AddressSpace,
|
use inkwell::types::FloatType;
|
||||||
};
|
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
|
||||||
|
use inkwell::AddressSpace;
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
|
|
||||||
use super::CodeGenContext;
|
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
|
||||||
|
/// functions.
|
||||||
|
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
|
||||||
|
// Standard LLVM floating-point types
|
||||||
|
if ft == ctx.f16_type() {
|
||||||
|
return "f16";
|
||||||
|
}
|
||||||
|
if ft == ctx.f32_type() {
|
||||||
|
return "f32";
|
||||||
|
}
|
||||||
|
if ft == ctx.f64_type() {
|
||||||
|
return "f64";
|
||||||
|
}
|
||||||
|
if ft == ctx.f128_type() {
|
||||||
|
return "f128";
|
||||||
|
}
|
||||||
|
|
||||||
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
|
// Non-standard floating-point types
|
||||||
/// intrinsic.
|
if ft == ctx.x86_f80_type() {
|
||||||
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
|
return "f80";
|
||||||
const FN_NAME: &str = "llvm.va_start";
|
}
|
||||||
|
if ft == ctx.ppc_f128_type() {
|
||||||
|
return "ppcf128";
|
||||||
|
}
|
||||||
|
|
||||||
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
unreachable!()
|
||||||
let llvm_void = ctx.ctx.void_type();
|
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
|
||||||
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
||||||
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
|
|
||||||
|
|
||||||
ctx.module.add_function(FN_NAME, fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invokes the [`llvm.va_end`](https://llvm.org/docs/LangRef.html#llvm-va-end-intrinsic)
|
|
||||||
/// intrinsic.
|
|
||||||
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
|
|
||||||
const FN_NAME: &str = "llvm.va_end";
|
|
||||||
|
|
||||||
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
|
||||||
let llvm_void = ctx.ctx.void_type();
|
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
|
||||||
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
||||||
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
|
|
||||||
|
|
||||||
ctx.module.add_function(FN_NAME, fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
|
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
|
||||||
@ -156,7 +149,7 @@ pub fn call_memcpy_generic<'ctx>(
|
|||||||
dest
|
dest
|
||||||
} else {
|
} else {
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_bit_cast(dest, llvm_p0i8, "")
|
.build_bitcast(dest, llvm_p0i8, "")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
@ -164,7 +157,7 @@ pub fn call_memcpy_generic<'ctx>(
|
|||||||
src
|
src
|
||||||
} else {
|
} else {
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_bit_cast(src, llvm_p0i8, "")
|
.build_bitcast(src, llvm_p0i8, "")
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
@ -172,58 +165,14 @@ pub fn call_memcpy_generic<'ctx>(
|
|||||||
call_memcpy(ctx, dest, src, len, is_volatile);
|
call_memcpy(ctx, dest, src, len, is_volatile);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `llvm.memcpy` intrinsic.
|
|
||||||
///
|
|
||||||
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
|
|
||||||
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
|
|
||||||
/// Moreover, `len` now refers to the number of elements to copy (rather than number of bytes to
|
|
||||||
/// copy).
|
|
||||||
pub fn call_memcpy_generic_array<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
dest: PointerValue<'ctx>,
|
|
||||||
src: PointerValue<'ctx>,
|
|
||||||
len: IntValue<'ctx>,
|
|
||||||
is_volatile: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
|
||||||
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_sizeof_expr_t = llvm_i8.size_of().get_type();
|
|
||||||
|
|
||||||
let dest_elem_t = dest.get_type().get_element_type();
|
|
||||||
let src_elem_t = src.get_type().get_element_type();
|
|
||||||
|
|
||||||
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
|
|
||||||
dest
|
|
||||||
} else {
|
|
||||||
ctx.builder
|
|
||||||
.build_bit_cast(dest, llvm_p0i8, "")
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
|
|
||||||
src
|
|
||||||
} else {
|
|
||||||
ctx.builder
|
|
||||||
.build_bit_cast(src, llvm_p0i8, "")
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
let len = ctx.builder.build_int_z_extend_or_bit_cast(len, llvm_sizeof_expr_t, "").unwrap();
|
|
||||||
let len = ctx.builder.build_int_mul(len, src_elem_t.size_of().unwrap(), "").unwrap();
|
|
||||||
|
|
||||||
call_memcpy(ctx, dest, src, len, is_volatile);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
|
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
|
||||||
///
|
///
|
||||||
/// Arguments:
|
/// Arguments:
|
||||||
/// * `$ctx:ident`: Reference to the current Code Generation Context
|
/// * `$ctx:ident`: Reference to the current Code Generation Context
|
||||||
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
|
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
|
||||||
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
|
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
|
||||||
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type).
|
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type)
|
||||||
/// Use `BasicValueEnum::into_int_value` for Integer return type and
|
/// Use `BasicValueEnum::into_int_value` for Integer return type and `BasicValueEnum::into_float_value` for Float return type
|
||||||
/// `BasicValueEnum::into_float_value` for Float return type
|
|
||||||
/// * `$llvm_ty:ident`: Type of first operand
|
/// * `$llvm_ty:ident`: Type of first operand
|
||||||
/// * `,($val:ident)*`: Comma separated list of operands
|
/// * `,($val:ident)*`: Comma separated list of operands
|
||||||
macro_rules! generate_llvm_intrinsic_fn_body {
|
macro_rules! generate_llvm_intrinsic_fn_body {
|
||||||
@ -239,8 +188,8 @@ macro_rules! generate_llvm_intrinsic_fn_body {
|
|||||||
/// Arguments:
|
/// Arguments:
|
||||||
/// * `float/int`: Indicates the return and argument type of the function
|
/// * `float/int`: Indicates the return and argument type of the function
|
||||||
/// * `$fn_name:ident`: The identifier of the rust function to be generated
|
/// * `$fn_name:ident`: The identifier of the rust function to be generated
|
||||||
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function.
|
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
|
||||||
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
|
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
|
||||||
/// * `$val:ident`: The operand for unary operations
|
/// * `$val:ident`: The operand for unary operations
|
||||||
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations
|
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations
|
||||||
macro_rules! generate_llvm_intrinsic_fn {
|
macro_rules! generate_llvm_intrinsic_fn {
|
||||||
@ -357,25 +306,3 @@ pub fn call_float_powi<'ctx>(
|
|||||||
.map(Either::unwrap_left)
|
.map(Either::unwrap_left)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the [`llvm.ctpop`](https://llvm.org/docs/LangRef.html#llvm-ctpop-intrinsic) intrinsic.
|
|
||||||
pub fn call_int_ctpop<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
src: IntValue<'ctx>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
const FN_NAME: &str = "llvm.ctpop";
|
|
||||||
|
|
||||||
let llvm_src_t = src.get_type();
|
|
||||||
|
|
||||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
|
||||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_src_t.into()]))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(intrinsic_fn, &[src.into()], name.unwrap_or_default())
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
use std::{
|
use crate::{
|
||||||
cell::OnceCell,
|
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
|
||||||
collections::{HashMap, HashSet},
|
symbol_resolver::{StaticValue, SymbolResolver},
|
||||||
sync::{
|
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
|
||||||
atomic::{AtomicBool, Ordering},
|
typecheck::{
|
||||||
Arc,
|
type_inferencer::{CodeLocation, PrimitiveStore},
|
||||||
|
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
||||||
},
|
},
|
||||||
thread,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crossbeam::channel::{unbounded, Receiver, Sender};
|
use crossbeam::channel::{unbounded, Receiver, Sender};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
@ -20,61 +19,39 @@ use inkwell::{
|
|||||||
module::Module,
|
module::Module,
|
||||||
passes::PassBuilderOptions,
|
passes::PassBuilderOptions,
|
||||||
targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple},
|
targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple},
|
||||||
types::{AnyType, BasicType, BasicTypeEnum, IntType},
|
types::{AnyType, BasicType, BasicTypeEnum},
|
||||||
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
|
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
|
use irrt::string::Str;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use parking_lot::{Condvar, Mutex};
|
use model::*;
|
||||||
|
|
||||||
use nac3parser::ast::{Location, Stmt, StrRef};
|
use nac3parser::ast::{Location, Stmt, StrRef};
|
||||||
|
use parking_lot::{Condvar, Mutex};
|
||||||
use crate::{
|
use std::collections::{HashMap, HashSet};
|
||||||
symbol_resolver::{StaticValue, SymbolResolver},
|
use std::sync::{
|
||||||
toplevel::{
|
atomic::{AtomicBool, Ordering},
|
||||||
helper::{extract_ndims, PrimDef},
|
Arc,
|
||||||
numpy::unpack_ndarray_var_tys,
|
|
||||||
TopLevelContext, TopLevelDef,
|
|
||||||
},
|
|
||||||
typecheck::{
|
|
||||||
type_inferencer::{CodeLocation, PrimitiveStore},
|
|
||||||
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
use std::thread;
|
||||||
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
|
||||||
use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, TupleType};
|
|
||||||
|
|
||||||
pub mod builtin_fns;
|
pub mod builtin_fns;
|
||||||
|
pub mod classes;
|
||||||
pub mod concrete_type;
|
pub mod concrete_type;
|
||||||
pub mod expr;
|
pub mod expr;
|
||||||
pub mod extern_fns;
|
pub mod extern_fns;
|
||||||
mod generator;
|
mod generator;
|
||||||
pub mod irrt;
|
pub mod irrt;
|
||||||
pub mod llvm_intrinsics;
|
pub mod llvm_intrinsics;
|
||||||
|
pub mod model;
|
||||||
pub mod numpy;
|
pub mod numpy;
|
||||||
pub mod stmt;
|
pub mod stmt;
|
||||||
pub mod types;
|
|
||||||
pub mod values;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
|
||||||
mod macros {
|
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
||||||
/// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
|
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
||||||
/// its first argument to provide Python source information to indicate the codegen location
|
|
||||||
/// causing the assertion.
|
|
||||||
macro_rules! codegen_unreachable {
|
|
||||||
($ctx:expr $(,)?) => {
|
|
||||||
std::unreachable!("unreachable code while processing {}", &$ctx.current_loc)
|
|
||||||
};
|
|
||||||
($ctx:expr, $($arg:tt)*) => {
|
|
||||||
std::unreachable!("unreachable code while processing {}: {}", &$ctx.current_loc, std::format!("{}", std::format_args!($($arg)+)))
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) use codegen_unreachable;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct StaticValueStore {
|
pub struct StaticValueStore {
|
||||||
@ -94,16 +71,6 @@ pub struct CodeGenLLVMOptions {
|
|||||||
pub target: CodeGenTargetMachineOptions,
|
pub target: CodeGenTargetMachineOptions,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CodeGenLLVMOptions {
|
|
||||||
/// Creates a [`TargetMachine`] using the target options specified by this struct.
|
|
||||||
///
|
|
||||||
/// See [`Target::create_target_machine`].
|
|
||||||
#[must_use]
|
|
||||||
pub fn create_target_machine(&self) -> Option<TargetMachine> {
|
|
||||||
self.target.create_target_machine(self.opt_level)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Additional options for code generation for the target machine.
|
/// Additional options for code generation for the target machine.
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
pub struct CodeGenTargetMachineOptions {
|
pub struct CodeGenTargetMachineOptions {
|
||||||
@ -227,33 +194,14 @@ pub struct CodeGenContext<'ctx, 'a> {
|
|||||||
|
|
||||||
/// The current source location.
|
/// The current source location.
|
||||||
pub current_loc: Location,
|
pub current_loc: Location,
|
||||||
|
|
||||||
/// The cached type of `size_t`.
|
|
||||||
llvm_usize: OnceCell<IntType<'ctx>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> CodeGenContext<'ctx, '_> {
|
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
/// Whether the [current basic block][Builder::get_insert_block] referenced by `builder`
|
/// Whether the [current basic block][Builder::get_insert_block] referenced by `builder`
|
||||||
/// contains a [terminator statement][BasicBlock::get_terminator].
|
/// contains a [terminator statement][BasicBlock::get_terminator].
|
||||||
pub fn is_terminated(&self) -> bool {
|
pub fn is_terminated(&self) -> bool {
|
||||||
self.builder.get_insert_block().and_then(BasicBlock::get_terminator).is_some()
|
self.builder.get_insert_block().and_then(BasicBlock::get_terminator).is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a [`IntType`] representing `size_t` for the compilation target as specified by
|
|
||||||
/// [`self.registry`][WorkerRegistry].
|
|
||||||
pub fn get_size_type(&self) -> IntType<'ctx> {
|
|
||||||
*self.llvm_usize.get_or_init(|| {
|
|
||||||
self.ctx.ptr_sized_int_type(
|
|
||||||
&self
|
|
||||||
.registry
|
|
||||||
.llvm_options
|
|
||||||
.create_target_machine()
|
|
||||||
.map(|tm| tm.get_target_data())
|
|
||||||
.unwrap(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Fp = Box<dyn Fn(&Module) + Send + Sync>;
|
type Fp = Box<dyn Fn(&Module) + Send + Sync>;
|
||||||
@ -393,10 +341,6 @@ impl WorkerRegistry {
|
|||||||
let mut builder = context.create_builder();
|
let mut builder = context.create_builder();
|
||||||
let mut module = context.create_module(generator.get_name());
|
let mut module = context.create_module(generator.get_name());
|
||||||
|
|
||||||
let target_machine = self.llvm_options.create_target_machine().unwrap();
|
|
||||||
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
|
|
||||||
module.set_triple(&target_machine.get_triple());
|
|
||||||
|
|
||||||
module.add_basic_value_flag(
|
module.add_basic_value_flag(
|
||||||
"Debug Info Version",
|
"Debug Info Version",
|
||||||
inkwell::module::FlagBehavior::Warning,
|
inkwell::module::FlagBehavior::Warning,
|
||||||
@ -420,10 +364,6 @@ impl WorkerRegistry {
|
|||||||
errors.insert(e);
|
errors.insert(e);
|
||||||
// create a new empty module just to continue codegen and collect errors
|
// create a new empty module just to continue codegen and collect errors
|
||||||
module = context.create_module(&format!("{}_recover", generator.get_name()));
|
module = context.create_module(&format!("{}_recover", generator.get_name()));
|
||||||
|
|
||||||
let target_machine = self.llvm_options.create_target_machine().unwrap();
|
|
||||||
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
|
|
||||||
module.set_triple(&target_machine.get_triple());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*self.task_count.lock() -= 1;
|
*self.task_count.lock() -= 1;
|
||||||
@ -489,7 +429,7 @@ pub struct CodeGenTask {
|
|||||||
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
module: &Module<'ctx>,
|
module: &Module<'ctx>,
|
||||||
generator: &G,
|
generator: &mut G,
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
top_level: &TopLevelContext,
|
top_level: &TopLevelContext,
|
||||||
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||||
@ -501,38 +441,6 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| {
|
type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| {
|
||||||
let ty_enum = unifier.get_ty(ty);
|
let ty_enum = unifier.get_ty(ty);
|
||||||
let result = match &*ty_enum {
|
let result = match &*ty_enum {
|
||||||
TModule {module_id, attributes} => {
|
|
||||||
let top_level_defs = top_level.definitions.read();
|
|
||||||
let definition = top_level_defs.get(module_id.0).unwrap();
|
|
||||||
let TopLevelDef::Module { name, attributes: attribute_fields, .. } = &*definition.read() else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
let ty: BasicTypeEnum<'_> = if let Some(t) = module.get_struct_type(&name.to_string()) {
|
|
||||||
t.ptr_type(AddressSpace::default()).into()
|
|
||||||
} else {
|
|
||||||
let struct_type = ctx.opaque_struct_type(&name.to_string());
|
|
||||||
type_cache.insert(
|
|
||||||
unifier.get_representative(ty),
|
|
||||||
struct_type.ptr_type(AddressSpace::default()).into(),
|
|
||||||
);
|
|
||||||
let module_fields: Vec<BasicTypeEnum<'_>> = attribute_fields.iter()
|
|
||||||
.map(|f| {
|
|
||||||
get_llvm_type(
|
|
||||||
ctx,
|
|
||||||
module,
|
|
||||||
generator,
|
|
||||||
unifier,
|
|
||||||
top_level,
|
|
||||||
type_cache,
|
|
||||||
attributes[&f.0].0,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect_vec();
|
|
||||||
struct_type.set_body(&module_fields, false);
|
|
||||||
struct_type.ptr_type(AddressSpace::default()).into()
|
|
||||||
};
|
|
||||||
return ty;
|
|
||||||
},
|
|
||||||
TObj { obj_id, fields, .. } => {
|
TObj { obj_id, fields, .. } => {
|
||||||
// check to avoid treating non-class primitives as classes
|
// check to avoid treating non-class primitives as classes
|
||||||
if PrimDef::contains_id(*obj_id) {
|
if PrimDef::contains_id(*obj_id) {
|
||||||
@ -562,17 +470,16 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
*params.iter().next().unwrap().1,
|
*params.iter().next().unwrap().1,
|
||||||
);
|
);
|
||||||
|
|
||||||
ListType::new_with_generator(generator, ctx, element_type).as_base_type().into()
|
ListType::new(generator, ctx, element_type).as_base_type().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (dtype, ndims) = unpack_ndarray_var_tys(unifier, ty);
|
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
|
||||||
let ndims = extract_ndims(unifier, ndims);
|
|
||||||
let element_type = get_llvm_type(
|
let element_type = get_llvm_type(
|
||||||
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
||||||
);
|
);
|
||||||
|
|
||||||
NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_base_type().into()
|
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => unreachable!(
|
_ => unreachable!(
|
||||||
@ -616,17 +523,15 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
};
|
};
|
||||||
return ty;
|
return ty;
|
||||||
}
|
}
|
||||||
TTuple { ty, is_vararg_ctx } => {
|
TTuple { ty } => {
|
||||||
// a struct with fields in the order present in the tuple
|
// a struct with fields in the order present in the tuple
|
||||||
assert!(!is_vararg_ctx, "Tuples in vararg context must be instantiated with the correct number of arguments before calling get_llvm_type");
|
|
||||||
|
|
||||||
let fields = ty
|
let fields = ty
|
||||||
.iter()
|
.iter()
|
||||||
.map(|ty| {
|
.map(|ty| {
|
||||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
|
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
TupleType::new_with_generator(generator, ctx, &fields).as_base_type().into()
|
ctx.struct_type(&fields, false).into()
|
||||||
}
|
}
|
||||||
TVirtual { .. } => unimplemented!(),
|
TVirtual { .. } => unimplemented!(),
|
||||||
_ => unreachable!("{}", ty_enum.get_type_name()),
|
_ => unreachable!("{}", ty_enum.get_type_name()),
|
||||||
@ -649,7 +554,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
|
fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
module: &Module<'ctx>,
|
module: &Module<'ctx>,
|
||||||
generator: &G,
|
generator: &mut G,
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
top_level: &TopLevelContext,
|
top_level: &TopLevelContext,
|
||||||
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||||
@ -658,11 +563,11 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
) -> BasicTypeEnum<'ctx> {
|
) -> BasicTypeEnum<'ctx> {
|
||||||
// If the type is used in the definition of a function, return `i1` instead of `i8` for ABI
|
// If the type is used in the definition of a function, return `i1` instead of `i8` for ABI
|
||||||
// consistency.
|
// consistency.
|
||||||
if unifier.unioned(ty, primitives.bool) {
|
return if unifier.unioned(ty, primitives.bool) {
|
||||||
ctx.bool_type().into()
|
ctx.bool_type().into()
|
||||||
} else {
|
} else {
|
||||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
|
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Whether `sret` is needed for a return value with type `ty`.
|
/// Whether `sret` is needed for a return value with type `ty`.
|
||||||
@ -687,40 +592,6 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
|
|||||||
need_sret_impl(ty, true)
|
need_sret_impl(ty, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the [`BasicTypeEnum`] representing a `va_list` struct for variadic arguments.
|
|
||||||
fn get_llvm_valist_type<'ctx>(ctx: &'ctx Context, triple: &TargetTriple) -> BasicTypeEnum<'ctx> {
|
|
||||||
let triple = TargetMachine::normalize_triple(triple);
|
|
||||||
let triple = triple.as_str().to_str().unwrap();
|
|
||||||
let arch = triple.split('-').next().unwrap();
|
|
||||||
|
|
||||||
let llvm_pi8 = ctx.i8_type().ptr_type(AddressSpace::default());
|
|
||||||
|
|
||||||
// Referenced from parseArch() in llvm/lib/Support/Triple.cpp
|
|
||||||
match arch {
|
|
||||||
"i386" | "i486" | "i586" | "i686" | "riscv32" => {
|
|
||||||
ctx.i8_type().ptr_type(AddressSpace::default()).into()
|
|
||||||
}
|
|
||||||
"amd64" | "x86_64" | "x86_64h" => {
|
|
||||||
let llvm_i32 = ctx.i32_type();
|
|
||||||
|
|
||||||
let va_list_tag = ctx.opaque_struct_type("struct.__va_list_tag");
|
|
||||||
va_list_tag.set_body(
|
|
||||||
&[llvm_i32.into(), llvm_i32.into(), llvm_pi8.into(), llvm_pi8.into()],
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
va_list_tag.into()
|
|
||||||
}
|
|
||||||
"armv7" => {
|
|
||||||
let va_list = ctx.opaque_struct_type("struct.__va_list");
|
|
||||||
va_list.set_body(&[llvm_pi8.into()], false);
|
|
||||||
va_list.into()
|
|
||||||
}
|
|
||||||
triple => {
|
|
||||||
todo!("Unsupported platform for varargs: {triple}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Implementation for generating LLVM IR for a function.
|
/// Implementation for generating LLVM IR for a function.
|
||||||
pub fn gen_func_impl<
|
pub fn gen_func_impl<
|
||||||
'ctx,
|
'ctx,
|
||||||
@ -786,19 +657,8 @@ pub fn gen_func_impl<
|
|||||||
(primitives.float, context.f64_type().into()),
|
(primitives.float, context.f64_type().into()),
|
||||||
(primitives.bool, context.i8_type().into()),
|
(primitives.bool, context.i8_type().into()),
|
||||||
(primitives.str, {
|
(primitives.str, {
|
||||||
let name = "str";
|
let sizet = IntModel(generator.get_size_type(context));
|
||||||
match module.get_struct_type(name) {
|
StructModel(Str { sizet }).get_llvm_type(context)
|
||||||
None => {
|
|
||||||
let str_type = context.opaque_struct_type("str");
|
|
||||||
let fields = [
|
|
||||||
context.i8_type().ptr_type(AddressSpace::default()).into(),
|
|
||||||
generator.get_size_type(context).into(),
|
|
||||||
];
|
|
||||||
str_type.set_body(&fields, false);
|
|
||||||
str_type.into()
|
|
||||||
}
|
|
||||||
Some(t) => t.as_basic_type_enum(),
|
|
||||||
}
|
|
||||||
}),
|
}),
|
||||||
(primitives.range, RangeType::new(context).as_base_type().into()),
|
(primitives.range, RangeType::new(context).as_base_type().into()),
|
||||||
(primitives.exception, {
|
(primitives.exception, {
|
||||||
@ -832,7 +692,6 @@ pub fn gen_func_impl<
|
|||||||
name: arg.name,
|
name: arg.name,
|
||||||
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache),
|
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache),
|
||||||
default_value: arg.default_value.clone(),
|
default_value: arg.default_value.clone(),
|
||||||
is_vararg: arg.is_vararg,
|
|
||||||
})
|
})
|
||||||
.collect_vec(),
|
.collect_vec(),
|
||||||
task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache),
|
task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache),
|
||||||
@ -855,10 +714,7 @@ pub fn gen_func_impl<
|
|||||||
let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
|
let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
|
||||||
let mut params = args
|
let mut params = args
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|arg| !arg.is_vararg)
|
|
||||||
.map(|arg| {
|
.map(|arg| {
|
||||||
debug_assert!(!arg.is_vararg);
|
|
||||||
|
|
||||||
get_llvm_abi_type(
|
get_llvm_abi_type(
|
||||||
context,
|
context,
|
||||||
&module,
|
&module,
|
||||||
@ -877,12 +733,9 @@ pub fn gen_func_impl<
|
|||||||
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
|
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
|
||||||
}
|
}
|
||||||
|
|
||||||
debug_assert!(matches!(args.iter().filter(|arg| arg.is_vararg).count(), 0..=1));
|
|
||||||
let vararg_arg = args.iter().find(|arg| arg.is_vararg);
|
|
||||||
|
|
||||||
let fn_type = match ret_type {
|
let fn_type = match ret_type {
|
||||||
Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, vararg_arg.is_some()),
|
Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, false),
|
||||||
_ => context.void_type().fn_type(¶ms, vararg_arg.is_some()),
|
_ => context.void_type().fn_type(¶ms, false),
|
||||||
};
|
};
|
||||||
|
|
||||||
let symbol = &task.symbol_name;
|
let symbol = &task.symbol_name;
|
||||||
@ -910,10 +763,9 @@ pub fn gen_func_impl<
|
|||||||
builder.position_at_end(init_bb);
|
builder.position_at_end(init_bb);
|
||||||
let body_bb = context.append_basic_block(fn_val, "body");
|
let body_bb = context.append_basic_block(fn_val, "body");
|
||||||
|
|
||||||
// Store non-vararg argument values into local variables
|
|
||||||
let mut var_assignment = HashMap::new();
|
let mut var_assignment = HashMap::new();
|
||||||
let offset = u32::from(has_sret);
|
let offset = u32::from(has_sret);
|
||||||
for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) {
|
for (n, arg) in args.iter().enumerate() {
|
||||||
let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
|
let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
|
||||||
let local_type = get_llvm_type(
|
let local_type = get_llvm_type(
|
||||||
context,
|
context,
|
||||||
@ -946,8 +798,6 @@ pub fn gen_func_impl<
|
|||||||
var_assignment.insert(arg.name, (alloca, None, 0));
|
var_assignment.insert(arg.name, (alloca, None, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Save vararg parameters as list
|
|
||||||
|
|
||||||
let return_buffer = if has_sret {
|
let return_buffer = if has_sret {
|
||||||
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
|
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
|
||||||
} else {
|
} else {
|
||||||
@ -1039,20 +889,8 @@ pub fn gen_func_impl<
|
|||||||
need_sret: has_sret,
|
need_sret: has_sret,
|
||||||
current_loc: Location::default(),
|
current_loc: Location::default(),
|
||||||
debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()),
|
debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()),
|
||||||
llvm_usize: OnceCell::default(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let target_llvm_usize = context.ptr_sized_int_type(
|
|
||||||
®istry.llvm_options.create_target_machine().map(|tm| tm.get_target_data()).unwrap(),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
let generator_llvm_usize = generator.get_size_type(context);
|
|
||||||
assert_eq!(
|
|
||||||
generator_llvm_usize,
|
|
||||||
target_llvm_usize,
|
|
||||||
"CodeGenerator (size_t = {generator_llvm_usize}) is not compatible with CodeGen Target (size_t = {target_llvm_usize})",
|
|
||||||
);
|
|
||||||
|
|
||||||
let loc = code_gen_context.debug_info.0.create_debug_location(
|
let loc = code_gen_context.debug_info.0.create_debug_location(
|
||||||
context,
|
context,
|
||||||
row as u32,
|
row as u32,
|
||||||
@ -1182,112 +1020,3 @@ fn gen_in_range_check<'ctx>(
|
|||||||
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
|
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments
|
|
||||||
/// passed to the variadic function.
|
|
||||||
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
|
|
||||||
format!("__{}_va_count", &arg_name).into()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the alignment of the type.
|
|
||||||
///
|
|
||||||
/// This is necessary as `get_alignment` is not implemented as part of [`BasicType`].
|
|
||||||
pub fn get_type_alignment<'ctx>(ty: impl Into<BasicTypeEnum<'ctx>>) -> IntValue<'ctx> {
|
|
||||||
match ty.into() {
|
|
||||||
BasicTypeEnum::ArrayType(ty) => ty.get_alignment(),
|
|
||||||
BasicTypeEnum::FloatType(ty) => ty.get_alignment(),
|
|
||||||
BasicTypeEnum::IntType(ty) => ty.get_alignment(),
|
|
||||||
BasicTypeEnum::PointerType(ty) => ty.get_alignment(),
|
|
||||||
BasicTypeEnum::StructType(ty) => ty.get_alignment(),
|
|
||||||
BasicTypeEnum::VectorType(ty) => ty.get_alignment(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Inserts an `alloca` instruction with allocation `size` given in bytes and the alignment of the
|
|
||||||
/// given type.
|
|
||||||
///
|
|
||||||
/// The returned [`PointerValue`] will have a type of `i8*`, a size of at least `size`, and will be
|
|
||||||
/// aligned with the alignment of `align_ty`.
|
|
||||||
pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
align_ty: impl Into<BasicTypeEnum<'ctx>>,
|
|
||||||
size: IntValue<'ctx>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
/// Round `val` up to its modulo `power_of_two`.
|
|
||||||
fn round_up<'ctx>(
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
val: IntValue<'ctx>,
|
|
||||||
power_of_two: IntValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
debug_assert_eq!(
|
|
||||||
val.get_type().get_bit_width(),
|
|
||||||
power_of_two.get_type().get_bit_width(),
|
|
||||||
"`val` ({}) and `power_of_two` ({}) must be the same type",
|
|
||||||
val.get_type(),
|
|
||||||
power_of_two.get_type(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let llvm_val_t = val.get_type();
|
|
||||||
|
|
||||||
let max_rem =
|
|
||||||
ctx.builder.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "").unwrap();
|
|
||||||
ctx.builder
|
|
||||||
.build_and(
|
|
||||||
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
|
|
||||||
ctx.builder.build_not(max_rem, "").unwrap(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
|
||||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_usize = ctx.get_size_type();
|
|
||||||
let align_ty = align_ty.into();
|
|
||||||
|
|
||||||
let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap();
|
|
||||||
|
|
||||||
debug_assert_eq!(
|
|
||||||
size.get_type().get_bit_width(),
|
|
||||||
llvm_usize.get_bit_width(),
|
|
||||||
"Expected size_t ({}) for parameter `size` of `aligned_alloca`, got {}",
|
|
||||||
llvm_usize,
|
|
||||||
size.get_type(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let alignment = get_type_alignment(align_ty);
|
|
||||||
let alignment = ctx.builder.build_int_truncate_or_bit_cast(alignment, llvm_usize, "").unwrap();
|
|
||||||
|
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
||||||
let alignment_bitcount = llvm_intrinsics::call_int_ctpop(ctx, alignment, None);
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::EQ,
|
|
||||||
alignment_bitcount,
|
|
||||||
alignment_bitcount.get_type().const_int(1, false),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap(),
|
|
||||||
"0:AssertionError",
|
|
||||||
"Expected power-of-two alignment for aligned_alloca, got {0}",
|
|
||||||
[Some(alignment), None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let buffer_size = round_up(ctx, size, alignment);
|
|
||||||
let aligned_slices = ctx.builder.build_int_unsigned_div(buffer_size, alignment, "").unwrap();
|
|
||||||
|
|
||||||
// Just to be absolutely sure, alloca in [i8 x alignment] slices
|
|
||||||
let buffer = ctx.builder.build_array_alloca(align_ty, aligned_slices, "").unwrap();
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_bit_cast(buffer, llvm_pi8, name.unwrap_or_default())
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
118
nac3core/src/codegen/model/core.rs
Normal file
118
nac3core/src/codegen/model/core.rs
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{AnyTypeEnum, BasicTypeEnum},
|
||||||
|
values::{AnyValue, AnyValueEnum, BasicValueEnum, PointerValue},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
use super::{slice::ArraySlice, Int, Pointer};
|
||||||
|
|
||||||
|
/// A value that belongs to/produced by a [`Model<'ctx>`]
|
||||||
|
pub trait ModelValue<'ctx>: Clone + Copy {
|
||||||
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx>;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have been within [`Model<'ctx>`],
|
||||||
|
// but rust object safety requirements made it necessary to
|
||||||
|
// split this interface out
|
||||||
|
pub trait CanCheckLLVMType<'ctx> {
|
||||||
|
/// Check if `scrutinee` matches the same LLVM type of this [`Model<'ctx>`].
|
||||||
|
///
|
||||||
|
/// If they don't not match, a human-readable error message is returned.
|
||||||
|
fn check_llvm_type(
|
||||||
|
&self,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
scrutinee: AnyTypeEnum<'ctx>,
|
||||||
|
) -> Result<(), String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A [`Model`] is a type-safe concrete representation of a complex LLVM type.
|
||||||
|
pub trait Model<'ctx>: Clone + Copy + CanCheckLLVMType<'ctx> + Sized {
|
||||||
|
/// The values that inhabit this [`Model<'ctx>`].
|
||||||
|
///
|
||||||
|
/// ...that is the type of wrapper that wraps the LLVM values that inhabit [`Model<'ctx>::get_llvm_type()`].
|
||||||
|
type Value: ModelValue<'ctx>;
|
||||||
|
|
||||||
|
/// Get the [`BasicTypeEnum<'ctx>`] this [`Model<'ctx>`] is representing.
|
||||||
|
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
|
||||||
|
|
||||||
|
/// Cast an [`AnyValueEnum<'ctx>`] into [`Self::Value`].
|
||||||
|
///
|
||||||
|
/// Panics if `value` cannot pass [`CanCheckLLVMType::check_llvm_type()`].
|
||||||
|
fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value;
|
||||||
|
|
||||||
|
/// Check if [`Self::Value`] has the correct type described by this [`Model<'ctx>`]
|
||||||
|
///
|
||||||
|
/// For example:
|
||||||
|
/// ```ignore
|
||||||
|
/// let ctx: &CodeGenContext<'ctx, '_>;
|
||||||
|
/// let my_i32 = IntModel(ctx.ctx.i32_type());
|
||||||
|
/// let my_i64 = IntModel(ctx.ctx.i64_type());
|
||||||
|
/// let value1 = my_i32.constant(3);
|
||||||
|
/// let value2 = my_i64.constant(3);
|
||||||
|
/// // Both value1 and value2 have type `IntModel<'ctx>`!
|
||||||
|
/// // There is no type constraints to tell which value has what int type.
|
||||||
|
/// my_i32.check(value1); // ok
|
||||||
|
/// my_i64.check(value2); // ok
|
||||||
|
///
|
||||||
|
/// my_i32.check(value2); // PANIC
|
||||||
|
/// my_i64.check(value1); // PANIC
|
||||||
|
/// ```
|
||||||
|
fn check(&self, ctx: &'ctx Context, value: Self::Value) {
|
||||||
|
self.review(ctx, value.get_llvm_value().as_any_value_enum());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build an instruction to allocate a value of [`Model::get_llvm_type`].
|
||||||
|
fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> {
|
||||||
|
Pointer {
|
||||||
|
element: *self,
|
||||||
|
value: ctx.builder.build_alloca(self.get_llvm_type(ctx.ctx), name).unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build an instruction to allocate an array of [`Model::get_llvm_type`].
|
||||||
|
fn array_alloca(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
count: Int<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> ArraySlice<'ctx, Self> {
|
||||||
|
ArraySlice {
|
||||||
|
num_elements: count,
|
||||||
|
pointer: Pointer {
|
||||||
|
element: *self,
|
||||||
|
value: ctx
|
||||||
|
.builder
|
||||||
|
.build_array_alloca(self.get_llvm_type(ctx.ctx), count.0, name)
|
||||||
|
.unwrap(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Do [`CodeGenerator::gen_var_alloc`] with the LLVM type of this [`Model<'ctx>`].
|
||||||
|
fn var_alloc<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> Result<Pointer<'ctx, Self>, String> {
|
||||||
|
let value = generator.gen_var_alloc(ctx, self.get_llvm_type(ctx.ctx), name)?;
|
||||||
|
Ok(Pointer { element: *self, value })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Do [`CodeGenerator::gen_array_var_alloc`] with the LLVM type of this [`Model<'ctx>`].
|
||||||
|
fn array_var_alloc<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
size: Int<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Result<Pointer<'ctx, Self>, String> {
|
||||||
|
let slice =
|
||||||
|
generator.gen_array_var_alloc(ctx, self.get_llvm_type(ctx.ctx), size.0, name)?;
|
||||||
|
let ptr = PointerValue::from(slice); // TODO: Remove ArraySliceValue
|
||||||
|
|
||||||
|
Ok(Pointer { element: *self, value: ptr })
|
||||||
|
}
|
||||||
|
}
|
159
nac3core/src/codegen/model/fixed_int.rs
Normal file
159
nac3core/src/codegen/model/fixed_int.rs
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
|
||||||
|
values::{AnyValueEnum, BasicValue, BasicValueEnum, IntValue},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
core::*,
|
||||||
|
int_util::{check_int_llvm_type, review_int_llvm_value},
|
||||||
|
Int, IntModel,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A marker trait to mark singleton struct that describes a particular fixed integer type.
|
||||||
|
/// See [`Bool`], [`Byte`], [`Int32`], etc.
|
||||||
|
///
|
||||||
|
/// The [`Default`] trait is to enable auto-derivations for utilities like
|
||||||
|
/// [`FieldBuilder::add_field_auto`]
|
||||||
|
pub trait IsFixedInt: Clone + Copy + Default {
|
||||||
|
fn get_int_type(ctx: &Context) -> IntType<'_>;
|
||||||
|
fn get_bit_width() -> u32; // This is required, instead of only relying on get_int_type
|
||||||
|
|
||||||
|
fn constant<'ctx>(&self, ctx: &'ctx Context, value: u64) -> FixedInt<'ctx, Self> {
|
||||||
|
FixedInt { int: *self, value: Self::get_int_type(ctx).const_int(value, false) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some pre-defined fixed integers
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct Bool;
|
||||||
|
pub type BoolModel = FixedIntModel<Bool>;
|
||||||
|
|
||||||
|
impl IsFixedInt for Bool {
|
||||||
|
fn get_int_type(ctx: &Context) -> IntType<'_> {
|
||||||
|
ctx.bool_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_bit_width() -> u32 {
|
||||||
|
1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct Byte;
|
||||||
|
pub type ByteModel = FixedIntModel<Byte>;
|
||||||
|
|
||||||
|
impl IsFixedInt for Byte {
|
||||||
|
fn get_int_type(ctx: &Context) -> IntType<'_> {
|
||||||
|
ctx.i8_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_bit_width() -> u32 {
|
||||||
|
8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct Int32;
|
||||||
|
pub type Int32Model = FixedIntModel<Int32>;
|
||||||
|
|
||||||
|
impl IsFixedInt for Int32 {
|
||||||
|
fn get_int_type(ctx: &Context) -> IntType<'_> {
|
||||||
|
ctx.i32_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_bit_width() -> u32 {
|
||||||
|
32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct Int64;
|
||||||
|
pub type Int64Model = FixedIntModel<Int64>;
|
||||||
|
|
||||||
|
impl IsFixedInt for Int64 {
|
||||||
|
fn get_int_type(ctx: &Context) -> IntType<'_> {
|
||||||
|
ctx.i64_type()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_bit_width() -> u32 {
|
||||||
|
64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A model representing a compile-time known [`IntType<'ctx>`].
|
||||||
|
///
|
||||||
|
/// Also see [`IntModel`], which is less constrained than [`FixedIntModel`],
|
||||||
|
/// but enables one to handle [`IntType<'ctx>`] that could be dynamic
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct FixedIntModel<T>(pub T);
|
||||||
|
|
||||||
|
// FixedIntModel's implementation
|
||||||
|
|
||||||
|
impl<'ctx, T: IsFixedInt> CanCheckLLVMType<'ctx> for FixedIntModel<T> {
|
||||||
|
fn check_llvm_type(
|
||||||
|
&self,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
scrutinee: AnyTypeEnum<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
check_int_llvm_type(scrutinee, T::get_int_type(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel<T> {
|
||||||
|
type Value = FixedInt<'ctx, T>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
|
T::get_int_type(ctx).as_basic_type_enum()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||||
|
let value = review_int_llvm_value(value, T::get_int_type(ctx)).unwrap();
|
||||||
|
FixedInt { int: self.0, value }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: IsFixedInt> FixedIntModel<T> {
|
||||||
|
pub fn to_int_model(self, ctx: &Context) -> IntModel<'_> {
|
||||||
|
IntModel(T::get_int_type(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An inhabitant of [`FixedIntModel<'ctx>`]
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct FixedInt<'ctx, T: IsFixedInt> {
|
||||||
|
pub int: T,
|
||||||
|
pub value: IntValue<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// FixedInt's Implementation
|
||||||
|
|
||||||
|
impl<'ctx, T: IsFixedInt> ModelValue<'ctx> for FixedInt<'ctx, T> {
|
||||||
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||||
|
self.value.as_basic_value_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T: IsFixedInt> FixedInt<'ctx, T> {
|
||||||
|
pub fn to_int(self) -> Int<'ctx> {
|
||||||
|
Int(self.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn signed_cast_to_fixed<R: IsFixedInt>(
|
||||||
|
self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
target_fixed_int: R,
|
||||||
|
name: &str,
|
||||||
|
) -> FixedInt<'ctx, R> {
|
||||||
|
FixedInt {
|
||||||
|
int: target_fixed_int,
|
||||||
|
value: ctx
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend_or_bit_cast(self.value, R::get_int_type(ctx.ctx), name)
|
||||||
|
.unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
55
nac3core/src/codegen/model/function_builder.rs
Normal file
55
nac3core/src/codegen/model/function_builder.rs
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
use inkwell::{
|
||||||
|
types::{BasicMetadataTypeEnum, BasicType},
|
||||||
|
values::{AnyValue, BasicMetadataValueEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::{model::*, CodeGenContext};
|
||||||
|
|
||||||
|
// TODO: Variadic argument?
|
||||||
|
pub struct FunctionBuilder<'ctx, 'a> {
|
||||||
|
ctx: &'a CodeGenContext<'ctx, 'a>,
|
||||||
|
fn_name: &'a str,
|
||||||
|
arguments: Vec<(BasicMetadataTypeEnum<'ctx>, BasicMetadataValueEnum<'ctx>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> {
|
||||||
|
pub fn begin(ctx: &'a CodeGenContext<'ctx, 'a>, fn_name: &'a str) -> Self {
|
||||||
|
FunctionBuilder { ctx, fn_name, arguments: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: `_name` is for self-documentation
|
||||||
|
#[must_use]
|
||||||
|
pub fn arg<M: Model<'ctx>>(mut self, _name: &'static str, model: M, value: M::Value) -> Self {
|
||||||
|
model.check(self.ctx.ctx, value); // Panics if the passed `value` has the incorrect type.
|
||||||
|
|
||||||
|
self.arguments
|
||||||
|
.push((model.get_llvm_type(self.ctx.ctx).into(), value.get_llvm_value().into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn returning<M: Model<'ctx>>(self, name: &'static str, return_model: M) -> M::Value {
|
||||||
|
let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip();
|
||||||
|
|
||||||
|
let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| {
|
||||||
|
let return_type = return_model.get_llvm_type(self.ctx.ctx);
|
||||||
|
let fn_type = return_type.fn_type(¶m_tys, false);
|
||||||
|
self.ctx.module.add_function(self.fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap();
|
||||||
|
return_model.review(self.ctx.ctx, ret.as_any_value_enum())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Code duplication, but otherwise returning<S: Optic<'ctx>> cannot resolve S if return_optic = None
|
||||||
|
pub fn returning_void(self) {
|
||||||
|
let (param_tys, param_vals): (Vec<_>, Vec<_>) = self.arguments.into_iter().unzip();
|
||||||
|
|
||||||
|
let function = self.ctx.module.get_function(self.fn_name).unwrap_or_else(|| {
|
||||||
|
let return_type = self.ctx.ctx.void_type();
|
||||||
|
let fn_type = return_type.fn_type(¶m_tys, false);
|
||||||
|
self.ctx.module.add_function(self.fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
self.ctx.builder.build_call(function, ¶m_vals, "").unwrap();
|
||||||
|
}
|
||||||
|
}
|
97
nac3core/src/codegen/model/int.rs
Normal file
97
nac3core/src/codegen/model/int.rs
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
|
||||||
|
values::{AnyValueEnum, BasicValue, BasicValueEnum, IntValue},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
use super::{core::*, int_util::check_int_llvm_type, FixedInt, IsFixedInt};
|
||||||
|
|
||||||
|
/// A model representing an [`IntType<'ctx>`].
|
||||||
|
///
|
||||||
|
/// Also see [`FixedIntModel`], which is more constrained than [`IntModel`]
|
||||||
|
/// but provides more type-safe mechanisms and even auto-derivation of [`BasicTypeEnum<'ctx>`]
|
||||||
|
/// for creating LLVM structures.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct IntModel<'ctx>(pub IntType<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> CanCheckLLVMType<'ctx> for IntModel<'ctx> {
|
||||||
|
fn check_llvm_type(
|
||||||
|
&self,
|
||||||
|
_ctx: &'ctx Context,
|
||||||
|
scrutinee: AnyTypeEnum<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
check_int_llvm_type(scrutinee, self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> Model<'ctx> for IntModel<'ctx> {
|
||||||
|
type Value = Int<'ctx>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
|
self.0.as_basic_type_enum()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||||
|
let int = value.into_int_value();
|
||||||
|
self.check_llvm_type(ctx, int.get_type().as_any_type_enum()).unwrap();
|
||||||
|
Int(int)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> IntModel<'ctx> {
|
||||||
|
/// Create a constant value that inhabits this [`IntModel<'ctx>`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn constant(&self, value: u64) -> Int<'ctx> {
|
||||||
|
Int(self.0.const_int(value, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if `other` is fully compatible with this [`IntModel<'ctx>`].
|
||||||
|
///
|
||||||
|
/// This simply checks if the underlying [`IntType<'ctx>`] has
|
||||||
|
/// the same number of bits.
|
||||||
|
#[must_use]
|
||||||
|
pub fn same_as(&self, other: IntModel<'ctx>) -> bool {
|
||||||
|
// TODO: or `self.0 == other.0` would also work?
|
||||||
|
self.0.get_bit_width() == other.0.get_bit_width()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An inhabitant of an [`IntModel<'ctx>`]
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct Int<'ctx>(pub IntValue<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> ModelValue<'ctx> for Int<'ctx> {
|
||||||
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||||
|
self.0.as_basic_value_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> Int<'ctx> {
|
||||||
|
#[must_use]
|
||||||
|
pub fn signed_cast_to_int(
|
||||||
|
self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
target_int: IntModel<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Int<'ctx> {
|
||||||
|
Int(ctx.builder.build_int_s_extend_or_bit_cast(self.0, target_int.0, name).unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn signed_cast_to_fixed<T: IsFixedInt>(
|
||||||
|
self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
target_fixed: T,
|
||||||
|
name: &str,
|
||||||
|
) -> FixedInt<'ctx, T> {
|
||||||
|
FixedInt {
|
||||||
|
int: target_fixed,
|
||||||
|
value: ctx
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend_or_bit_cast(self.0, T::get_int_type(ctx.ctx), name)
|
||||||
|
.unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
39
nac3core/src/codegen/model/int_util.rs
Normal file
39
nac3core/src/codegen/model/int_util.rs
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
use inkwell::{
|
||||||
|
types::{AnyType, AnyTypeEnum, IntType},
|
||||||
|
values::{AnyValueEnum, IntValue},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Helper function to check if `scrutinee` is the same as `expected_int_type`
|
||||||
|
pub fn check_int_llvm_type<'ctx>(
|
||||||
|
scrutinee: AnyTypeEnum<'ctx>,
|
||||||
|
expected_int_type: IntType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// Check if llvm_type is int type
|
||||||
|
let AnyTypeEnum::IntType(scrutinee) = scrutinee else {
|
||||||
|
return Err(format!("Expecting an int type but got {scrutinee:?}"));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check bit width
|
||||||
|
if scrutinee.get_bit_width() != expected_int_type.get_bit_width() {
|
||||||
|
return Err(format!(
|
||||||
|
"Expecting an int type of {}-bit(s) but got int type {}-bit(s)",
|
||||||
|
expected_int_type.get_bit_width(),
|
||||||
|
scrutinee.get_bit_width()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to cast `scrutinee` is into an [`IntValue<'ctx>`].
|
||||||
|
/// The LLVM type of `scrutinee` will be checked with [`check_int_llvm_type`].
|
||||||
|
pub fn review_int_llvm_value<'ctx>(
|
||||||
|
value: AnyValueEnum<'ctx>,
|
||||||
|
expected_int_type: IntType<'ctx>,
|
||||||
|
) -> Result<IntValue<'ctx>, String> {
|
||||||
|
// Check if value is of int type, error if that is anything else
|
||||||
|
check_int_llvm_type(value.get_type().as_any_type_enum(), expected_int_type)?;
|
||||||
|
|
||||||
|
// Ok, it is must be an int
|
||||||
|
Ok(value.into_int_value())
|
||||||
|
}
|
18
nac3core/src/codegen/model/mod.rs
Normal file
18
nac3core/src/codegen/model/mod.rs
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
pub mod core;
|
||||||
|
pub mod fixed_int;
|
||||||
|
pub mod function_builder;
|
||||||
|
pub mod int;
|
||||||
|
mod int_util;
|
||||||
|
pub mod opaque;
|
||||||
|
pub mod pointer;
|
||||||
|
pub mod slice;
|
||||||
|
pub mod structure;
|
||||||
|
|
||||||
|
pub use core::*;
|
||||||
|
pub use fixed_int::*;
|
||||||
|
pub use function_builder::*;
|
||||||
|
pub use int::*;
|
||||||
|
pub use opaque::*;
|
||||||
|
pub use pointer::*;
|
||||||
|
pub use slice::*;
|
||||||
|
pub use structure::*;
|
46
nac3core/src/codegen/model/opaque.rs
Normal file
46
nac3core/src/codegen/model/opaque.rs
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{AnyTypeEnum, BasicTypeEnum},
|
||||||
|
values::{AnyValueEnum, BasicValueEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct OpaqueModel<'ctx>(pub BasicTypeEnum<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> CanCheckLLVMType<'ctx> for OpaqueModel<'ctx> {
|
||||||
|
fn check_llvm_type(
|
||||||
|
&self,
|
||||||
|
_ctx: &'ctx Context,
|
||||||
|
scrutinee: AnyTypeEnum<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
match BasicTypeEnum::try_from(scrutinee) {
|
||||||
|
Ok(_) => Ok(()),
|
||||||
|
Err(_err) => Err(format!("Expecting a BasicTypeEnum, but got {scrutinee:?}")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> Model<'ctx> for OpaqueModel<'ctx> {
|
||||||
|
type Value = Opaque<'ctx>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||||
|
self.check_llvm_type(ctx, value.get_type()).unwrap();
|
||||||
|
let value = BasicValueEnum::try_from(value).unwrap(); // Must work
|
||||||
|
Opaque(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct Opaque<'ctx>(pub BasicValueEnum<'ctx>);
|
||||||
|
|
||||||
|
impl<'ctx> ModelValue<'ctx> for Opaque<'ctx> {
|
||||||
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
83
nac3core/src/codegen/model/pointer.rs
Normal file
83
nac3core/src/codegen/model/pointer.rs
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum},
|
||||||
|
values::{AnyValue, AnyValueEnum, BasicValue, BasicValueEnum, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
use super::{core::*, OpaqueModel};
|
||||||
|
|
||||||
|
/// A [`Model<'ctx>`] representing an LLVM [`PointerType<'ctx>`]
|
||||||
|
/// with *full* information on the element u
|
||||||
|
///
|
||||||
|
/// [`self.0`] contains [`Model<'ctx>`] that represents the
|
||||||
|
/// LLVM type of element of the [`PointerType<'ctx>`] is pointing at
|
||||||
|
/// (like `PointerType<'ctx>::get_element_type()`, but abstracted as a [`Model<'ctx>`]).
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct PointerModel<E>(pub E);
|
||||||
|
|
||||||
|
impl<'ctx, E: Model<'ctx>> CanCheckLLVMType<'ctx> for PointerModel<E> {
|
||||||
|
fn check_llvm_type(
|
||||||
|
&self,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
scrutinee: AnyTypeEnum<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// Check if scrutinee is even a PointerValue
|
||||||
|
let AnyTypeEnum::PointerType(scrutinee) = scrutinee else {
|
||||||
|
return Err(format!("Expecting a pointer value, but got {scrutinee:?}"));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check the type of what the pointer is pointing at
|
||||||
|
// TODO: This will be deprecated by inkwell > llvm14 because `get_element_type()` will be gone
|
||||||
|
self.0.check_llvm_type(ctx, scrutinee.get_element_type())?; // TODO: Include backtrace?
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel<E> {
|
||||||
|
type Value = Pointer<'ctx, E>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
|
self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||||
|
self.check_llvm_type(ctx, value.get_type()).unwrap();
|
||||||
|
|
||||||
|
// TODO: Check get_element_type(). For inkwell LLVM 14 at least...
|
||||||
|
Pointer { element: self.0, value: value.into_pointer_value() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An inhabitant of [`PointerModel<E>`]
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct Pointer<'ctx, E: Model<'ctx>> {
|
||||||
|
pub element: E,
|
||||||
|
pub value: PointerValue<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, E: Model<'ctx>> ModelValue<'ctx> for Pointer<'ctx, E> {
|
||||||
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||||
|
self.value.as_basic_value_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> {
|
||||||
|
/// Build an instruction to store a value into this pointer
|
||||||
|
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, val: E::Value) {
|
||||||
|
ctx.builder.build_store(self.value, val.get_llvm_value()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build an instruction to load a value from this pointer
|
||||||
|
pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> E::Value {
|
||||||
|
let val = ctx.builder.build_load(self.value, name).unwrap();
|
||||||
|
self.element.review(ctx.ctx, val.as_any_value_enum())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_opaque(self, ctx: &'ctx Context) -> Pointer<'ctx, OpaqueModel<'ctx>> {
|
||||||
|
Pointer { element: OpaqueModel(self.element.get_llvm_type(ctx)), value: self.value }
|
||||||
|
}
|
||||||
|
}
|
73
nac3core/src/codegen/model/slice.rs
Normal file
73
nac3core/src/codegen/model/slice.rs
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
use super::{Int, Model, Pointer};
|
||||||
|
|
||||||
|
pub struct ArraySlice<'ctx, E: Model<'ctx>> {
|
||||||
|
pub num_elements: Int<'ctx>,
|
||||||
|
pub pointer: Pointer<'ctx, E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, E: Model<'ctx>> ArraySlice<'ctx, E> {
|
||||||
|
pub fn ix_unchecked(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
idx: Int<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Pointer<'ctx, E> {
|
||||||
|
let element_addr =
|
||||||
|
unsafe { ctx.builder.build_in_bounds_gep(self.pointer.value, &[idx.0], name).unwrap() };
|
||||||
|
Pointer { value: element_addr, element: self.pointer.element }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ix<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
idx: Int<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> Pointer<'ctx, E> {
|
||||||
|
let int_type = self.num_elements.0.get_type(); // NOTE: Weird get_type(), see comment under `trait Ixed`
|
||||||
|
|
||||||
|
assert_eq!(int_type.get_bit_width(), idx.0.get_type().get_bit_width()); // Might as well check bit width to catch bugs
|
||||||
|
|
||||||
|
// TODO: SGE or UGE? or make it defined by the implementee?
|
||||||
|
|
||||||
|
// Check `0 <= index`
|
||||||
|
let lower_bounded = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
inkwell::IntPredicate::SLE,
|
||||||
|
int_type.const_zero(),
|
||||||
|
idx.0,
|
||||||
|
"lower_bounded",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Check `index < num_elements`
|
||||||
|
let upper_bounded = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
inkwell::IntPredicate::SLT,
|
||||||
|
idx.0,
|
||||||
|
self.num_elements.0,
|
||||||
|
"upper_bounded",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Compute `0 <= index && index < num_elements`
|
||||||
|
let bounded = ctx.builder.build_and(lower_bounded, upper_bounded, "bounded").unwrap();
|
||||||
|
|
||||||
|
// Assert `bounded`
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
bounded,
|
||||||
|
"0:IndexError",
|
||||||
|
"nac3core LLVM codegen attempting to access out of bounds array index {0}. Must satisfy 0 <= index < {2}",
|
||||||
|
[ Some(idx.0), Some(self.num_elements.0), None],
|
||||||
|
ctx.current_loc
|
||||||
|
);
|
||||||
|
|
||||||
|
// ...and finally do indexing
|
||||||
|
self.ix_unchecked(ctx, idx, name)
|
||||||
|
}
|
||||||
|
}
|
219
nac3core/src/codegen/model/structure.rs
Normal file
219
nac3core/src/codegen/model/structure.rs
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, StructType},
|
||||||
|
values::{AnyValueEnum, BasicValue, BasicValueEnum, StructValue},
|
||||||
|
};
|
||||||
|
use itertools::{izip, Itertools};
|
||||||
|
|
||||||
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
|
use super::{core::CanCheckLLVMType, Model, ModelValue, Pointer};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct Field<E> {
|
||||||
|
pub gep_index: u64,
|
||||||
|
pub name: &'static str,
|
||||||
|
pub element: E,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FieldLLVM<'ctx> {
|
||||||
|
gep_index: u64,
|
||||||
|
name: &'ctx str,
|
||||||
|
llvm_type: BasicTypeEnum<'ctx>,
|
||||||
|
|
||||||
|
// Only CanCheckLLVMType is needed, dont use `Model<'ctx>`
|
||||||
|
llvm_type_model: Box<dyn CanCheckLLVMType<'ctx> + 'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FieldBuilder<'ctx> {
|
||||||
|
pub ctx: &'ctx Context,
|
||||||
|
gep_index_counter: u64,
|
||||||
|
struct_name: &'ctx str,
|
||||||
|
fields: Vec<FieldLLVM<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> FieldBuilder<'ctx> {
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self {
|
||||||
|
FieldBuilder { ctx, gep_index_counter: 0, struct_name, fields: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn next_gep_index(&mut self) -> u64 {
|
||||||
|
let index = self.gep_index_counter;
|
||||||
|
self.gep_index_counter += 1;
|
||||||
|
index
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_field<E: Model<'ctx> + 'ctx>(&mut self, name: &'static str, element: E) -> Field<E> {
|
||||||
|
let gep_index = self.next_gep_index();
|
||||||
|
|
||||||
|
self.fields.push(FieldLLVM {
|
||||||
|
gep_index,
|
||||||
|
name,
|
||||||
|
llvm_type: element.get_llvm_type(self.ctx),
|
||||||
|
llvm_type_model: Box::new(element),
|
||||||
|
});
|
||||||
|
|
||||||
|
Field { gep_index, name, element }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_field_auto<E: Model<'ctx> + Default + 'ctx>(
|
||||||
|
&mut self,
|
||||||
|
name: &'static str,
|
||||||
|
) -> Field<E> {
|
||||||
|
self.add_field(name, E::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A marker trait to mark singleton struct that describes a particular LLVM structure.
|
||||||
|
pub trait IsStruct<'ctx>: Clone + Copy {
|
||||||
|
/// The type of the Rust `struct` that holds all the fields of this LLVM struct.
|
||||||
|
type Fields;
|
||||||
|
|
||||||
|
/// A cosmetic name for this struct.
|
||||||
|
/// TODO: Currently unused. To be used in error reporting.
|
||||||
|
fn struct_name(&self) -> &'static str;
|
||||||
|
|
||||||
|
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields;
|
||||||
|
|
||||||
|
fn get_fields(&self, ctx: &'ctx Context) -> Self::Fields {
|
||||||
|
let mut builder = FieldBuilder::new(ctx, self.struct_name());
|
||||||
|
self.build_fields(&mut builder)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the LLVM struct type this [`IsStruct<'ctx>`] is representing.
|
||||||
|
fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
|
||||||
|
let mut builder = FieldBuilder::new(ctx, self.struct_name());
|
||||||
|
self.build_fields(&mut builder); // Self::Fields is discarded
|
||||||
|
|
||||||
|
let field_types = builder.fields.iter().map(|f| f.llvm_type).collect_vec();
|
||||||
|
ctx.struct_type(&field_types, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if `scrutinee` matches the [`StructType<'ctx>`] this [`IsStruct<'ctx>`] is representing.
|
||||||
|
fn check_struct_type(
|
||||||
|
&self,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
scrutinee: StructType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// Details about scrutinee
|
||||||
|
let scrutinee_field_types = scrutinee.get_field_types();
|
||||||
|
|
||||||
|
// Details about the defined specifications of this struct
|
||||||
|
// We will access them through builder
|
||||||
|
let mut builder = FieldBuilder::new(ctx, self.struct_name());
|
||||||
|
self.build_fields(&mut builder);
|
||||||
|
|
||||||
|
// Check # of fields
|
||||||
|
if builder.fields.len() != scrutinee_field_types.len() {
|
||||||
|
return Err(format!(
|
||||||
|
"Expecting struct to have {} field(s), but scrutinee has {} field(s)",
|
||||||
|
builder.fields.len(),
|
||||||
|
scrutinee_field_types.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the types of each field
|
||||||
|
// TODO: Traceback?
|
||||||
|
for (f, scrutinee_field_type) in izip!(builder.fields, scrutinee_field_types) {
|
||||||
|
f.llvm_type_model.check_llvm_type(ctx, scrutinee_field_type.as_any_type_enum())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A [`Model<'ctx>`] that represents an LLVM struct.
|
||||||
|
///
|
||||||
|
/// `self.0` contains a [`IsStruct<'ctx>`] that gives the details of the LLVM struct.
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct StructModel<S>(pub S);
|
||||||
|
|
||||||
|
impl<'ctx, S: IsStruct<'ctx>> CanCheckLLVMType<'ctx> for StructModel<S> {
|
||||||
|
fn check_llvm_type(
|
||||||
|
&self,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
scrutinee: AnyTypeEnum<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// Check if scrutinee is even a struct type
|
||||||
|
let AnyTypeEnum::StructType(scrutinee) = scrutinee else {
|
||||||
|
return Err(format!("Expecting a struct type, but got {scrutinee:?}"));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Ok. now check the struct type *thoroughly*
|
||||||
|
self.0.check_struct_type(ctx, scrutinee)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel<S> {
|
||||||
|
type Value = Struct<'ctx, S>;
|
||||||
|
|
||||||
|
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
|
self.0.get_struct_type(ctx).as_basic_type_enum()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||||
|
// Check that `value` is not some bogus values or an incorrect StructValue
|
||||||
|
self.check_llvm_type(ctx, value.get_type()).unwrap();
|
||||||
|
|
||||||
|
Struct { structure: self.0, value: value.into_struct_value() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct Struct<'ctx, S> {
|
||||||
|
pub structure: S,
|
||||||
|
pub value: StructValue<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, S: IsStruct<'ctx>> ModelValue<'ctx> for Struct<'ctx, S> {
|
||||||
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||||
|
self.value.as_basic_value_enum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, S: IsStruct<'ctx>> Pointer<'ctx, StructModel<S>> {
|
||||||
|
/// Build an instruction that does `getelementptr` on an LLVM structure referenced by this pointer.
|
||||||
|
///
|
||||||
|
/// This provides a nice syntax to chain up `getelementptr` in an intuitive and type-safe way:
|
||||||
|
///
|
||||||
|
/// ```ignore
|
||||||
|
/// let ctx: &CodeGenContext<'ctx, '_>;
|
||||||
|
/// let ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>;
|
||||||
|
/// ndarray.gep(ctx, |f| f.ndims).store();
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// You might even write chains `gep`, i.e.,
|
||||||
|
/// ```ignore
|
||||||
|
/// my_struct
|
||||||
|
/// .gep(ctx, |f| f.thing1)
|
||||||
|
/// .gep(ctx, |f| f.value)
|
||||||
|
/// .store(ctx, my_value) // Equivalent to `my_struct.thing1.value = my_value`
|
||||||
|
/// ```
|
||||||
|
pub fn gep<E, GetFieldFn>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
get_field: GetFieldFn,
|
||||||
|
) -> Pointer<'ctx, E>
|
||||||
|
where
|
||||||
|
E: Model<'ctx>,
|
||||||
|
GetFieldFn: FnOnce(S::Fields) -> Field<E>,
|
||||||
|
{
|
||||||
|
let fields = self.element.0.get_fields(ctx.ctx);
|
||||||
|
let field = get_field(fields);
|
||||||
|
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to *just* use i32 for GEP like that
|
||||||
|
|
||||||
|
let ptr = unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.value,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)],
|
||||||
|
field.name,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
Pointer { element: field.element, value: ptr }
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user